In [None]:
import os
# Switch the working directory to the parent of the directory the kernel
# started in. NOTE: the path is relative, so every re-run of this cell moves
# up one more level; restart the kernel before running it again.
os.chdir('../')  # Moving up one directory to the root
import sys
import tensorflow as tf
import numpy as np
import random
import json
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
# Project-local modules.
# NOTE(review): these presumably resolve only because the chdir above has
# already run and the working directory is importable — confirm.
from models.sentiment_model import EnhancedDistilBertForSentiment
from config.model_config import ModelConfig
from main import SentimentAnalyzer
from utils.analysis import SentimentAnalysisVisualizer

In [None]:
# Add project root to Python path.
# The first cell already ran os.chdir('../'), so the current working directory
# *is* the project root; taking os.path.dirname() of it would register the
# root's parent directory instead.
# NOTE: the project imports in the first cell execute before this cell, so this
# entry only affects imports performed later in the session.
project_root = os.getcwd()
if project_root not in sys.path:  # keep the cell idempotent when re-run
    sys.path.append(project_root)

In [None]:
# Build the analyzer, restore a saved checkpoint together with its training
# history, then run data processing (per the log message below, this is what
# produces the test split used by the following cells).
CHECKPOINT_EPOCH = 5  # which saved epoch to restore

analyzer = SentimentAnalyzer()
model, history = analyzer.load_saved_model(epoch=CHECKPOINT_EPOCH)

print("Processing data to get test split...")
analyzer.process_data()

In [None]:
# Visualize training history
# `history` is the second value returned by analyzer.load_saved_model() in the
# previous cell; it is handed unchanged to the project's visualizer.
visualizer = SentimentAnalysisVisualizer()
visualizer.visualize_training_history(history)

In [None]:
# Run the analyzer over every test text, one example at a time, collecting
# the returned predictions in order for the analysis cells that follow.
print("\nEvaluating model on test set...")
test_predictions = []
n_test_examples = len(analyzer.test_texts)
print(f"Processing {n_test_examples} test examples...")

PROGRESS_EVERY = 100  # emit a progress line after this many examples

for done, text in enumerate(analyzer.test_texts, start=1):
    test_predictions.append(analyzer.predict(text))
    if done % PROGRESS_EVERY == 0:
        print(f"Processed {done}/{n_test_examples} examples")

In [None]:
# Sanity check before computing metrics: show the shape of the sentiment
# label container and what a single entry looks like.
sentiment_labels = analyzer.test_labels['sentiment']
print("Test labels shape:", sentiment_labels.shape)
print("Sample test label:", sentiment_labels[0])

In [None]:
# 4. Calculate metrics
# Class ids follow the order of target_names: 0 = Negative, 1 = Neutral,
# 2 = Positive (the true labels are noted as already being integers 0, 1, 2).
target_names = ['Negative', 'Neutral', 'Positive']
sentiment_keys = ['negative', 'neutral', 'positive']  # same order as target_names
class_ids = list(range(len(target_names)))

y_true = analyzer.test_labels['sentiment']  # Already integers (0, 1, 2)
# Predicted class = index of the highest of the three sentiment scores.
y_pred = [
    np.argmax([p['sentiment'][key] for key in sentiment_keys])
    for p in test_predictions
]

# 5. Display metrics
# `labels` is passed explicitly: without it, classification_report infers the
# classes from the data and raises a ValueError when a class is absent from
# both y_true and y_pred, because target_names would then have the wrong length.
print("\nClassification Report:")
print(classification_report(y_true, y_pred, labels=class_ids, target_names=target_names))

# 6. Show confusion matrix
# Pin the class ids so the matrix always has one row/column per entry of
# target_names; without `labels`, a class missing from both y_true and y_pred
# shrinks the matrix and the three tick labels no longer line up with it.
class_labels = list(range(len(target_names)))
cm = confusion_matrix(y_true, y_pred, labels=class_labels)

# Explicit Figure/Axes instead of the implicit pyplot "current axes".
cm_fig, cm_ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=target_names,
            yticklabels=target_names,
            ax=cm_ax)
cm_ax.set_title('Confusion Matrix')
cm_ax.set_xlabel('Predicted')
cm_ax.set_ylabel('True')
plt.show()

# 7. Show detailed examples
# A dedicated, seeded generator makes the displayed examples reproducible
# across runs without touching the global `random` state. The sample size is
# clamped because random.sample raises ValueError when asked for more items
# than the population holds.
N_EXAMPLES = 5
EXAMPLE_SEED = 42

print("\nDetailed Examples from Test Set:")
n_available = len(analyzer.test_texts)
sample_indices = random.Random(EXAMPLE_SEED).sample(
    range(n_available), min(N_EXAMPLES, n_available)
)
for idx in sample_indices:
    text = analyzer.test_texts[idx]
    true_sentiment = analyzer.test_labels['sentiment'][idx]
    pred = test_predictions[idx]

    print(f"\nText: {text}")
    print(f"True Sentiment: {target_names[true_sentiment]}")
    print("Predicted probabilities:")
    print(f"- Negative: {pred['sentiment']['negative']:.3f}")
    print(f"- Neutral:  {pred['sentiment']['neutral']:.3f}")
    print(f"- Positive: {pred['sentiment']['positive']:.3f}")
    print("Additional features:")
    print(f"- Sarcasm detected: {pred['sarcasm']['detected']}")
    print(f"- Negation detected: {pred['negation']['detected']}")
    print(f"- Multipolar: {pred['multipolarity']['is_multipolar']}")

In [None]:
# 1. Tally, for each auxiliary feature, how many predictions have its flag
#    set and how many do not.
def _tally_flag(predictions, feature, flag_key):
    """Count predictions whose ``[feature][flag_key]`` value is truthy / falsy.

    Returns a dict with the string keys 'True' and 'False'.
    """
    raised = sum(1 for p in predictions if p[feature][flag_key])
    return {'True': raised, 'False': len(predictions) - raised}


sarcasm_counts = _tally_flag(test_predictions, 'sarcasm', 'detected')
negation_counts = _tally_flag(test_predictions, 'negation', 'detected')
multipolar_counts = _tally_flag(test_predictions, 'multipolarity', 'is_multipolar')

# Print each feature's True/False counts with their share of all predictions.
total = len(test_predictions)

print("\nFeature Distribution in Test Set:")
for feature_title, feature_counts in (
    ("Sarcasm Detection", sarcasm_counts),
    ("Negation Detection", negation_counts),
    ("Multipolarity Detection", multipolar_counts),
):
    print(f"\n{feature_title}:")
    for flag in ('True', 'False'):
        flag_label = f"{flag}:"
        # Left-pad the label to 6 characters so the counts line up.
        print(f"{flag_label:<6} {feature_counts[flag]} ({feature_counts[flag]/total*100:.1f}%)")

# 2. Summarise the raw score behind each feature flag (min / max / mean).
def _print_min_max_mean(values):
    """Print the minimum, maximum and mean of ``values`` to three decimals."""
    print(f"Min: {min(values):.3f}")
    print(f"Max: {max(values):.3f}")
    print(f"Mean: {np.mean(values):.3f}")


print("\nDetailed Feature Analysis:")

print("\nSarcasm Probabilities:")
sarcasm_probs = [p['sarcasm']['probability'] for p in test_predictions]
_print_min_max_mean(sarcasm_probs)

print("\nNegation Probabilities:")
negation_probs = [p['negation']['probability'] for p in test_predictions]
_print_min_max_mean(negation_probs)

print("\nMultipolarity Scores:")
polarity_scores = [p['multipolarity']['score'] for p in test_predictions]
_print_min_max_mean(polarity_scores)

# Count how many test texts contain each special marker substring.
# Only the texts that have a corresponding prediction are inspected.
# NOTE(review): this scans analyzer.test_texts as stored; whether those texts
# are the processed ones that would carry the markers is not visible here.
def _texts_containing(marker):
    """Number of predicted-on test texts in which ``marker`` occurs."""
    return sum(
        1
        for idx in range(len(test_predictions))
        if marker in analyzer.test_texts[idx]
    )


special_tokens = {
    'SARC': _texts_containing('_SARC_'),
    'NEG': _texts_containing('_NEG_'),
}
print("\nSpecial Tokens Found:")
print(f"_SARC_ tokens: {special_tokens['SARC']}")
print(f"_NEG_ tokens: {special_tokens['NEG']}")

# 3. Side-by-side bar charts of the True/False counts for each feature.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

for panel, panel_title, panel_counts in zip(
    (ax1, ax2, ax3),
    ('Sarcasm Detection', 'Negation Detection', 'Multipolarity Detection'),
    (sarcasm_counts, negation_counts, multipolar_counts),
):
    panel.bar(['True', 'False'], [panel_counts['True'], panel_counts['False']])
    panel.set_title(panel_title)

# Only the left-most panel carries the y-axis label.
ax1.set_ylabel('Count')

plt.tight_layout()
plt.show()