# Sentiment Analysis with Attention Visualization Playground

This notebook lets you interact with a sentiment classification model that uses an attention mechanism to weight different parts of each sentence. The model was trained on the SST-2 (Stanford Sentiment Treebank) dataset, which consists of sentences from movie reviews annotated with binary (positive/negative) sentiment labels.

## What You'll Learn
- How to load and use a pre-trained sentiment analysis model
- How to visualize attention weights to see which words the model focuses on
- How to analyze model predictions and confidence scores
- How attention mechanisms work in natural language processing

## Setup and Imports

First, let's import the necessary libraries and modules.

In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from transformers import AutoTokenizer

# Import model and visualization modules
from models.model import SentimentClassificationModel
from visualization.attention_viz import visualize_multiple_examples, analyze_attention_patterns
from config import config

## Load the Model and Tokenizer

Let's load the pre-trained model and tokenizer.

In [None]:
# Set device (CPU, CUDA, or MPS)
# Uncomment the appropriate line for your system
device = torch.device("cpu")  # Default to CPU
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # CUDA (NVIDIA GPUs)
# device = torch.device("mps") if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else torch.device("cpu")  # MPS (Apple Silicon)

print(f"Using device: {device}")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Initialize model with the same parameters as the training configuration
model = SentimentClassificationModel(
    vocab_size=tokenizer.vocab_size,
    embedding_dim=config.embedding_dim,
    hidden_dim=config.hidden_dim,
    output_dim=config.output_dim,
    pad_idx=tokenizer.pad_token_id,
    bidirectional=True  # Using bidirectional GRU as in the trained model
)

# Load the pre-trained weights
checkpoint_path = "weights/2025-04-15_11-13-14__last.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.to(device)
model.eval()  # Set the model to evaluation mode

print("Model loaded successfully!")
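
You don't have to uncomment lines by hand: the device can be chosen automatically. The following is a minimal sketch that assumes only PyTorch's standard availability checks (`torch.cuda.is_available()` and `torch.backends.mps.is_available()`). `pick_device` is a hypothetical helper name and is not part of this project.

```python
import torch

def pick_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then Apple MPS, then fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps only exists in newer PyTorch releases, so guard the lookup
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
print(f"Using device: {device}")
```

If you switch devices this way, re-run the model-loading cell so that `map_location` and `model.to(device)` pick up the new value.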

## Example Sentences for Visualization

Here are some example sentences with different sentiments. Feel free to modify this list or add your own examples.

In [None]:
example_sentences = [
    "I love this movie. It's amazing!",
    "This is the worst movie I've ever seen.",
    "The plot was boring and predictable.",
    "The acting was top-notch and the cinematography was stunning.",
    "I wouldn't recommend this film to anyone."
]

# You can add or modify sentences here
# example_sentences.append("Your own sentence goes here")

## Visualize Attention Weights

Now let's visualize the attention weights for each of our example sentences. The attention weights show which words the model focuses on when making its prediction.

In [None]:
# Visualize attention weights for the example sentences
fig = visualize_multiple_examples(
    model,
    tokenizer,
    example_sentences,
    device,
    figsize=(15, 4*len(example_sentences))
)

## Analyze Attention Patterns

We can also perform a more detailed analysis of the attention patterns across all the example sentences.

In [None]:
analysis_results = analyze_attention_patterns(model, tokenizer, example_sentences, device)

# Display the results in a more readable format
for i, result in enumerate(analysis_results):
    print(f"\nSentence {i+1}: {result['sentence']}")
    print(f"Prediction: {result['prediction']} (Confidence: {result['confidence']:.4f})")
    print(f"Token with highest attention: '{result['max_attention_token']}' (Attention: {result['max_attention_value']:.4f})")
    print(f"Mean attention: {result['mean_attention']:.4f}, Std: {result['std_attention']:.4f}")
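
You can reproduce the per-sentence summary above from a single attention vector. This sketch assumes a 1D array with one weight per token and a matching token list. `summarize_attention` is a hypothetical stand-in, and the real `analyze_attention_patterns` may compute these values differently (for example, by excluding special tokens).

```python
import numpy as np

def summarize_attention(tokens, attention):
    """Hypothetical stand-in for the per-sentence statistics printed above."""
    weights = np.asarray(attention, dtype=float).ravel()
    if len(tokens) != weights.size:
        raise ValueError("need exactly one attention weight per token")
    max_idx = int(np.argmax(weights))  # position of the most-attended token
    return {
        "max_attention_token": tokens[max_idx],
        "max_attention_value": float(weights[max_idx]),
        "mean_attention": float(weights.mean()),
        "std_attention": float(weights.std()),  # population std (ddof=0)
    }

# Toy weights for illustration only; not real model output
tokens = ["[CLS]", "great", "movie", "[SEP]"]
stats = summarize_attention(tokens, [0.1, 0.6, 0.2, 0.1])
print(stats)
```

If the weights come from a softmax over the tokens, they sum to 1. The mean is then always 1 / (number of tokens) and mostly reflects sentence length, so the standard deviation is the more informative measure of how peaked the attention is.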

## Interactive Prediction

Let's create a simple function to make predictions on new sentences interactively.

In [None]:
def predict_sentiment(sentence):
    """Predict the sentiment of a sentence and visualize attention weights."""
    # Tokenize the sentence
    token_ids = tokenizer.encode(sentence, add_special_tokens=True)
    input_ids = torch.tensor([token_ids]).to(device)
    
    # Make prediction
    with torch.no_grad():
        outputs, attention_weights = model(input_ids, return_attention=True)
    
    # Get probabilities and prediction
    probs = torch.nn.functional.softmax(outputs, dim=1)
    predicted_class = torch.argmax(probs, dim=1).item()
    confidence = probs[0][predicted_class].item()
    sentiment = "Positive" if predicted_class == 1 else "Negative"
    
    # Print prediction and confidence
    print(f"Sentence: '{sentence}'")
    print(f"Prediction: {sentiment} (Confidence: {confidence:.4f})")
    
    # Convert the IDs back to WordPiece tokens for the axis labels
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    
    # Extract attention weights: drop the batch dimension and move to NumPy
    attention = attention_weights.squeeze(0).cpu().numpy()
    
    # The heatmap expects a 2D array (rows x tokens); a 1D vector becomes a single row
    attention = np.atleast_2d(attention)
    
    # Find the token with the highest attention (the column of the global maximum)
    row_idx, max_attention_idx = np.unravel_index(np.argmax(attention), attention.shape)
    max_attention_value = attention[row_idx, max_attention_idx]
    max_attention_token = tokens[max_attention_idx]
    print(f"Token with highest attention: '{max_attention_token}' (Attention: {max_attention_value:.4f})")
    
    # Visualize attention
    plt.figure(figsize=(12, 3))
    sns.heatmap(
        attention,
        cmap="YlOrRd",
        annot=True,
        fmt=".3f",
        cbar=False,
        xticklabels=tokens,
        yticklabels=["Attention"]
    )
    plt.xticks(rotation=45, ha="right", rotation_mode="anchor")
    plt.title(f"Prediction: {sentiment} (Confidence: {confidence:.4f})")
    plt.tight_layout()
    plt.show()
    
    return sentiment, confidence, attention
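
Inside `predict_sentiment`, the prediction step applies a softmax to the output logits and reports the largest probability as the confidence. The NumPy sketch below performs the same arithmetic. It assumes a two-class output (`config.output_dim == 2`) in which index 1 means positive, as in the function above. The logits are made-up values.

```python
import numpy as np

def logits_to_prediction(logits):
    """Mirror of the softmax/argmax step in predict_sentiment, for one example."""
    logits = np.asarray(logits, dtype=float)
    # Subtract the max before exponentiating for numerical stability
    exp = np.exp(logits - logits.max())
    probs = exp / exp.sum()
    predicted_class = int(np.argmax(probs))
    sentiment = "Positive" if predicted_class == 1 else "Negative"
    return sentiment, float(probs[predicted_class]), probs

sentiment, confidence, probs = logits_to_prediction([-1.0, 2.0])
print(sentiment, round(confidence, 4))  # Positive 0.9526
```

With two classes the confidence can never fall below 0.5, so a value close to 0.5 means the model is uncertain.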

## Try It Yourself!

Now you can try the model on your own sentences. Replace the text of `your_sentence` with any sentence and re-run the cell. The model will predict the sentence's sentiment and show which tokens it attends to most.

In [None]:
# Try with your own sentence
your_sentence = "The food at this restaurant was absolutely delicious and the service was excellent!"
sentiment, confidence, attention = predict_sentiment(your_sentence)

## Experiment with Different Sentences

Try different types of sentences and see how the model performs. Here are some suggestions:

1. Use sentences with clear positive or negative sentiment
2. Try more neutral sentences
3. Use sentences with negation ("not bad", "isn't great")
4. Try sentences with mixed sentiment

See which words get the highest attention in each case!

In [None]:
# Try with some more examples
sentences_to_try = [
    "The movie wasn't bad at all.",
    "This book is not particularly exciting, but it's informative.",
    "While I enjoyed the beginning, the ending was disappointing.",
    "Despite some flaws, the overall experience was positive."
]

for sentence in sentences_to_try:
    predict_sentiment(sentence)
    print("\n" + "-"*50 + "\n")
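
Comparing heatmaps by eye gets tedious when you try many sentences. A small helper can list the most-attended tokens instead. This sketch assumes one attention weight per token (as in the heatmap) and skips BERT's special `[CLS]`, `[SEP]` and `[PAD]` tokens. `top_attended_tokens` is a hypothetical helper, and the token list and weights are toy values.

```python
import numpy as np

def top_attended_tokens(tokens, attention, k=3, skip=("[CLS]", "[SEP]", "[PAD]")):
    """Return the k (token, weight) pairs with the highest attention, ignoring special tokens."""
    weights = np.asarray(attention, dtype=float).ravel()
    pairs = [(tok, float(w)) for tok, w in zip(tokens, weights) if tok not in skip]
    # Sort by weight, highest first
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs[:k]

# Toy values for illustration only
tokens = ["[CLS]", "the", "movie", "wasn", "'", "t", "bad", "[SEP]"]
weights = [0.30, 0.02, 0.05, 0.15, 0.01, 0.12, 0.25, 0.10]
print(top_attended_tokens(tokens, weights, k=2))  # [('bad', 0.25), ('wasn', 0.15)]
```

In the notebook, you can rebuild the tokens with `tokenizer.convert_ids_to_tokens(tokenizer.encode(sentence))` and pair them with the attention array that `predict_sentiment` returns as its third value.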

## How Attention Works in This Model

The attention mechanism in this model assigns a weight to each token in the input sentence. These weights indicate how much each token's representation contributes to the summary vector used for the final sentiment prediction.

Here's a simplified explanation of how the attention mechanism works:

1. The tokenized sentence is first passed through an embedding layer and then a bidirectional GRU (Gated Recurrent Unit) layer, which produces a hidden state for each token.
2. The attention mechanism calculates a score for each hidden state, indicating its importance.
3. These scores are normalized using a softmax function to create attention weights that sum to 1.
4. The final context vector is a weighted sum of all hidden states, where the weights are the attention weights.
5. This context vector is then used for the final prediction.
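
Steps 2 to 4 fit in a few lines of code. This NumPy sketch assumes simple dot-product scoring with a single learned vector `w` over the GRU hidden states. The project's actual scoring layer may differ (for example, an additive layer with a `tanh`). Random arrays stand in for real hidden states, and steps 1 and 5 are omitted.

```python
import numpy as np

def attention_pool(hidden_states, w):
    """Attention pooling over per-token hidden states (dot-product scoring, an assumption).

    hidden_states: (seq_len, hidden) array, e.g. GRU outputs for one sentence
    w: (hidden,) scoring vector
    """
    scores = hidden_states @ w                      # step 2: one score per token
    exp_scores = np.exp(scores - scores.max())      # stabilise the exponentials
    weights = exp_scores / exp_scores.sum()         # step 3: softmax, sums to 1
    context = weights @ hidden_states               # step 4: weighted sum of hidden states
    return context, weights

rng = np.random.default_rng(0)
seq_len, hidden = 6, 8  # hidden would be 2 * hidden_dim for a bidirectional GRU
H = rng.normal(size=(seq_len, hidden))
w = rng.normal(size=hidden)
context, weights = attention_pool(H, w)
print(weights.round(3), context.shape)
```

The `weights` array has the same form as the vectors in the heatmaps (one value per token), and `context` is the vector that step 5 passes to the output layer.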

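The heatmaps above visualize these weights, showing which tokens the model attended to most when making its prediction. A higher weight means a token's hidden state contributes more to the context vector. This suggests, but does not prove, greater influence on the final decision: because the GRU is bidirectional, each hidden state already carries information from the surrounding tokens.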

## Conclusion

In this notebook, we've explored a sentiment classification model that uses attention mechanisms to focus on important words in sentences. We've seen how attention weights can provide interpretability, showing us which parts of the input the model is focusing on when making predictions.

Key takeaways:
- Attention mechanisms help models focus on relevant parts of the input
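- Visualizing attention weights offers a partial view of the model's decision-making
- Words with strong sentiment polarity are natural candidates for high attention; the heatmaps let you check whether the model actually attends to them, especially in negated or mixed-sentiment sentences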
- Understanding attention can help us build more interpretable and effective NLP models