# Exploring the LSTM Sentiment Analysis Model

This notebook allows you to explore the trained LSTM model, load it from a configuration file, and test it with your own text input using an interactive widget.

In [None]:
import os
import yaml
import torch
import numpy as np
import pickle
import ipywidgets as widgets
from IPython.display import display, HTML
from src.data import YelpDataProcessor
from src.models import LSTMSentimentModel
import logging

# Keep notebook output readable: configure the root logger at WARNING level
# so that INFO/DEBUG messages are not shown.
logging.basicConfig(level=logging.WARNING)

## 1. Load Configuration and Models

First, let's load the model configuration and trained model.

In [None]:
# Available model configurations as (dropdown label, YAML path) pairs
CONFIG_CHOICES = [
    ('Default LSTM Config', 'model_configs/lstm_default.yaml'),
    ('Tuned LSTM Config', 'model_configs/lstm_tuning_v1.yaml'),
]

# Dropdown used to pick the configuration; the first entry is preselected
config_selector = widgets.Dropdown(
    options=CONFIG_CHOICES,
    value=CONFIG_CHOICES[0][1],
    description='Config:',
    style={'description_width': 'initial'},
)

display(config_selector)

In [None]:
def load_config(config_path):
    """Load a configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed YAML content. An empty file (which ``yaml.safe_load``
        parses as ``None``) is returned as an empty dict so that callers can
        use ``.get()`` / ``.items()`` without a ``None`` check.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return {} if config is None else config

def load_model_and_processor(config_path):
    """Build the LSTM model and its data processor from a YAML config.

    Reads the ``data``, ``model`` and ``hyperparameter_tuning`` sections of
    the config, restores the label encoder and the tokenizer/vocabulary from
    the ``models/`` directory, constructs an ``LSTMSentimentModel`` and, when
    a checkpoint exists at ``models/<name>/model_best_f1.pt``, loads its
    weights.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        ``(model, data_processor)`` on success, with the model moved to the
        selected device and switched to evaluation mode. ``(None, None)`` when
        the tokenizer or vocabulary file required by the configured
        tokenization method is missing. If no checkpoint is found, the model
        is returned without trained weights and a message is printed.
    """
    print(f"Loading configuration from {config_path}")
    # An empty YAML file or an empty section (e.g. "data:" with no value)
    # parses as None; fall back to {} so the .get() lookups below use their
    # defaults instead of raising AttributeError.
    config = load_config(config_path) or {}
    data_config = config.get('data') or {}
    model_config = config.get('model') or {}
    hp_tuning_config = config.get('hyperparameter_tuning') or {}
    best_params = hp_tuning_config.get('best_params') or {}

    # Tuned hyperparameters, when present, override the base model settings.
    effective_model_config = dict(model_config)
    for param in ('embedding_dim', 'hidden_size', 'dropout', 'bidirectional'):
        if param in best_params:
            effective_model_config[param] = best_params[param]
    # The tuning section names this flag 'attention'; the model constructor
    # takes it as 'use_attention'.
    if 'attention' in best_params:
        effective_model_config['use_attention'] = best_params['attention']

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Read once; both values are needed again further down.
    tokenization_method = data_config.get('tokenization_method', 'bpe')
    max_length = data_config.get('max_length', 128)

    # Initialize data processor
    data_processor = YelpDataProcessor(
        data_path=data_config.get('path'),
        max_length=max_length,
        batch_size=data_config.get('batch_size', 32),
        tokenization_method=tokenization_method
    )

    # Load label encoder
    label_encoder_path = 'models/label_encoder.pkl'
    if os.path.exists(label_encoder_path):
        # SECURITY NOTE: pickle.load can execute arbitrary code embedded in
        # the file. Only use label encoders produced by your own training run.
        with open(label_encoder_path, 'rb') as f:
            data_processor.label_encoder = pickle.load(f)
    else:
        print("Label encoder not found. Loading sample data to create one.")
        df = data_processor.load_data()
        data_processor.prepare_data_lstm(df, max_vocab_size=data_config.get('max_vocab_size', 10000))

    # Load tokenizer based on tokenization method
    if tokenization_method == 'bpe':
        tokenizer_path = 'models/bpe_tokenizer.json'
        if not os.path.exists(tokenizer_path):
            print("BPE tokenizer not found. Please train the model first.")
            return None, None
        data_processor.load_bpe_tokenizer(tokenizer_path)
    else:  # word tokenization
        vocab_path = 'models/word_vocab.json'
        if not os.path.exists(vocab_path):
            print("Word vocabulary not found. Please train the model first.")
            return None, None
        # NOTE(review): the file has a .json extension but is parsed with the
        # YAML loader. Confirm against the code that writes this file before
        # switching parsers.
        with open(vocab_path, 'r') as f:
            vocab_data = yaml.safe_load(f)
        data_processor.word_to_idx = vocab_data.get('word_to_idx', {})
        data_processor.idx_to_word = vocab_data.get('idx_to_word', {})
        data_processor.vocab_size = vocab_data.get('vocab_size', 0)

    # Determine model name/directory
    model_name = config.get('name', 'lstm_model')
    model_dir = f"models/{model_name}"

    # Build model with configuration
    model = LSTMSentimentModel(
        vocab_size=data_processor.vocab_size,
        embedding_dim=effective_model_config.get('embedding_dim', 128),
        hidden_size=effective_model_config.get('hidden_size', 64),
        num_classes=len(data_processor.label_encoder.classes_),
        bidirectional=effective_model_config.get('bidirectional', True),
        dropout=effective_model_config.get('dropout', 0.2),
        use_attention=effective_model_config.get('use_attention', True),
        max_length=max_length,
        padding_idx=0,
        num_layers=effective_model_config.get('num_layers', 1)
    )

    # Load model weights if available
    best_model_path = os.path.join(model_dir, 'model_best_f1.pt')
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
        print(f"Loaded trained model from {best_model_path}")
    else:
        print(f"Trained model not found at {best_model_path}. Using untrained model.")

    model.to(device)
    model.eval()  # Set model to evaluation mode

    return model, data_processor

In [None]:
# Build the model and data processor for the configuration currently chosen in the dropdown
selected_config_path = config_selector.value
model, data_processor = load_model_and_processor(selected_config_path)

## 2. Model Architecture and Configuration

Let's examine our model architecture and the configuration that was used.

In [None]:
# Show the layer structure of the loaded model
print("LSTM Sentiment Analysis Model Architecture:")
print(model)

# Dump every section of the selected configuration file
config = load_config(config_selector.value)
print("\nConfiguration:")
for section, params in config.items():
    print(f"\n{section.upper()}:")
    if not isinstance(params, dict):
        # Scalar top-level entries (e.g. a name) are printed as-is
        print(f"  {params}")
        continue
    for key, setting in params.items():
        print(f"  {key}: {setting}")

## 3. Interactive Text Classification

Now let's create a widget to input your own text and see the sentiment prediction.

In [None]:
def predict_sentiment(text, model, data_processor, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Classify one piece of text and report per-class probabilities.

    Args:
        text: Raw input text.
        model: Callable model; it is switched to evaluation mode first.
        data_processor: Object providing preprocessing, tokenization, padding
            and the fitted label encoder.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        A tuple ``(predicted_label, probs_dict, processed_text)`` where
        ``probs_dict`` maps each class name to its softmax probability.
    """
    # Inference only: make sure the model is in evaluation mode
    model.eval()

    cleaned = data_processor.preprocess_text(text)

    # Pick the tokenizer that matches how the processor was configured
    if data_processor.tokenization_method == 'bpe':
        to_sequences = data_processor.texts_to_sequences_bpe
    else:
        to_sequences = data_processor.texts_to_sequences_word
    token_ids = to_sequences([cleaned])[0]

    # Pad to the processor's fixed length and add a batch dimension of one
    padded = data_processor.pad_sequences([token_ids], maxlen=data_processor.max_length)[0]
    batch = torch.tensor(padded, dtype=torch.long).unsqueeze(0).to(device)

    # Forward pass without gradient tracking
    with torch.no_grad():
        logits = model(batch)
        class_probs = torch.nn.functional.softmax(logits, dim=1)[0]
        winner = torch.argmax(class_probs).item()

    # Translate class indices back to label names
    class_names = data_processor.label_encoder.classes_
    probs_dict = {class_names[idx]: score for idx, score in enumerate(class_probs.tolist())}

    return class_names[winner], probs_dict, cleaned

In [None]:
# --- Widgets for the interactive classifier ---

# Review pre-filled in the text box so the demo produces output immediately
DEFAULT_REVIEW = 'This restaurant was amazing! The food was delicious and the service was excellent.'

# Free-text input for the review to classify
text_input = widgets.Textarea(
    value=DEFAULT_REVIEW,
    placeholder='Enter your text here...',
    description='Review:',
    layout=widgets.Layout(width='100%', height='100px'),
)

# Triggers a prediction for the current text
run_button = widgets.Button(
    description='Analyze Sentiment',
    button_style='primary',
    tooltip='Click to analyze the sentiment of the text',
)

# Reloads the model using whatever the config dropdown currently selects
config_change_button = widgets.Button(
    description='Change Config',
    button_style='info',
    tooltip='Click to load model from the selected config',
)

# Target area for everything the click handlers print or display
output_area = widgets.Output()

# Text styling for output
def style_prediction(label, probabilities):
    """Render a prediction and its per-class confidence scores as HTML.

    Args:
        label: Predicted class label. Converted with ``str()`` before
            rendering, so non-string labels (e.g. integers) are accepted.
        probabilities: Mapping of class label to probability in ``[0, 1]``.

    Returns:
        An HTML string with a coloured headline followed by one horizontal
        bar per class, ordered from most to least probable. Labels are
        HTML-escaped; known sentiment labels get a fixed colour, any other
        label is drawn in blue.
    """
    # Local import keeps this cell self-contained.
    from html import escape

    colors = {
        'positive': 'green',
        'neutral': 'orange',
        'negative': 'red'
    }

    headline = escape(str(label).upper())
    parts = [
        f"<h3>Prediction: <span style='color:{colors.get(label, 'blue')}'>{headline}</span></h3>",
        "<h4>Confidence Scores:</h4>",
    ]

    # A distinct loop name avoids shadowing the `label` parameter.
    ranked = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)
    for class_name, prob in ranked:
        percentage = prob * 100
        color = colors.get(class_name, 'blue')
        parts.append("<div style='margin-bottom:5px;'>")
        parts.append(f"<span style='display:inline-block; width:100px;'>{escape(str(class_name))}:</span>")
        parts.append(f"<div style='display:inline-block; width:{percentage}%; background-color:{color}; height:20px;'></div>")
        parts.append(f"<span style='margin-left:10px;'>{percentage:.2f}%</span>")
        parts.append("</div>")

    return "".join(parts)

# Define button click handlers
def on_run_button_clicked(b):
    """Classify the text in the input box and show the result in the output area.

    Prints an error instead when no model/processor is loaded, or a prompt
    when the text box is empty or only whitespace.
    """
    with output_area:
        output_area.clear_output()
        review_text = text_input.value

        if model is None or data_processor is None:
            print("Error: Model or data processor not loaded properly.")
        elif not review_text.strip():
            print("Please enter some text to analyze.")
        else:
            predicted_label, probabilities, processed_text = predict_sentiment(
                review_text, model, data_processor
            )

            # Echo the raw and preprocessed text, then the styled prediction
            print(f"Original text: {review_text}")
            print(f"Processed text: {processed_text}")
            display(HTML(style_prediction(predicted_label, probabilities)))

def on_config_change_clicked(b):
    """Reload the model and data processor from the selected configuration.

    The module-level ``model`` and ``data_processor`` are replaced only when
    loading succeeds. If the loader raises or returns ``(None, None)`` (which
    it does when a tokenizer/vocabulary file is missing), an error is printed
    and the previously loaded model stays in place.
    """
    global model, data_processor
    with output_area:
        output_area.clear_output()
        try:
            new_model, new_processor = load_model_and_processor(config_selector.value)
        except Exception as e:
            print(f"Error loading model: {e}")
            return

        # The loader signals missing files by returning (None, None) rather
        # than raising, so that case must not be reported as success.
        if new_model is None or new_processor is None:
            print("Error loading model: required files were not found. Keeping the previously loaded model.")
            return

        model, data_processor = new_model, new_processor
        print("Model and data processor loaded successfully!")

# Wire each button to its handler
run_button.on_click(on_run_button_clicked)
config_change_button.on_click(on_config_change_clicked)

# Render the review box, the button row and the output area, in that order
button_row = widgets.HBox([run_button, config_change_button])
display(text_input, button_row, output_area)

# Run once so the default review is analyzed right away
on_run_button_clicked(None)

## 4. Exploring Model Predictions on Sample Reviews

Let's look at some sample reviews and how the model predicts them.

In [None]:
# Hand-written reviews ranging from clearly negative to clearly positive
sample_reviews = [
    "The food was absolutely terrible. I'll never come back to this restaurant again.",
    "The service was okay, but the food was mediocre. Not worth the price.",
    "It was an average experience. Nothing special but not bad either.",
    "The staff was friendly and the atmosphere was nice, but the food was just decent.",
    "Amazing experience! The chef prepared the best meal I've had in years.",
]

for sample_number, review in enumerate(sample_reviews, start=1):
    print(f"\nSample {sample_number}: {review}")
    predicted_label, probabilities, processed_text = predict_sentiment(
        review, model, data_processor
    )
    print(f"Prediction: {predicted_label}")
    print("Probabilities:")
    # Most likely class first
    ranked = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
    for label, prob in ranked:
        print(f"  {label}: {prob:.4f} ({prob*100:.2f}%)")

## 5. Exploring Token Attention (if model uses attention)

If the model uses attention, we can visualize which words the model pays attention to when making predictions.

In [None]:
def get_attention_weights(text, model, data_processor, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Run the model on ``text`` and capture the attention layer's weights.

    A temporary forward hook on ``model.attention`` records the second element
    of that module's output during a single forward pass. The hook is always
    removed afterwards, even if the forward pass raises.

    Args:
        text: Raw input text.
        model: Model exposing ``use_attention`` and an ``attention`` submodule.
        data_processor: Object providing preprocessing, tokenization, padding
            and the fitted label encoder.
        device: Device the input tensor is moved to before the forward pass.

    Returns:
        ``(tokens, weights, predicted_label)`` where ``weights`` holds one
        value per returned token, or ``(None, None, None)`` (after printing a
        message) when the model has no usable attention layer.
    """
    # Guard clauses: bail out before doing any work if there is no attention.
    if not model.use_attention:
        print("This model doesn't use attention mechanism.")
        return None, None, None
    if not hasattr(model, 'attention'):
        print("This model doesn't have the expected attention structure.")
        return None, None, None

    # Use evaluation mode, as predict_sentiment does.
    model.eval()

    # Preprocess the text
    processed_text = data_processor.preprocess_text(text)

    # Tokens are kept for display; the id sequence is what the model consumes.
    if data_processor.tokenization_method == 'bpe':
        tokens = data_processor.lstm_tokenizer.encode(processed_text).tokens
        sequence = data_processor.texts_to_sequences_bpe([processed_text])[0]
    else:
        tokens = processed_text.split()
        sequence = data_processor.texts_to_sequences_word([processed_text])[0]

    # Pad sequence and add a batch dimension of one
    padded_sequence = data_processor.pad_sequences([sequence], maxlen=data_processor.max_length)[0]
    input_tensor = torch.tensor(padded_sequence, dtype=torch.long).unsqueeze(0).to(device)

    captured = []

    def hook_fn(module, input, output):
        # assumes the attention module returns (context, weights) — TODO confirm
        captured.append(output[1].detach().cpu().numpy())

    hook = model.attention.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            output = model(input_tensor)
            probabilities = torch.nn.functional.softmax(output, dim=1)[0]
            predicted_class = torch.argmax(probabilities).item()
    finally:
        # Never leave the hook attached: a leaked hook would keep appending
        # to `captured` on every later forward pass of this model.
        hook.remove()

    if not captured:
        # The attention submodule exists but was not invoked by the forward pass.
        print("This model doesn't have the expected attention structure.")
        return None, None, None

    predicted_label = data_processor.label_encoder.classes_[predicted_class]

    # Drop the padded positions.
    # NOTE(review): assumes padding is appended at the end and that tokens
    # line up one-to-one with the id sequence — confirm against pad_sequences.
    valid_token_length = min(len(tokens), data_processor.max_length)
    tokens = tokens[:valid_token_length]
    weights = captured[0][0][:valid_token_length]

    return tokens, weights, predicted_label

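## 6. Interactive Attention Visualization

The widget below highlights each token according to the attention weight the model assigned to it. If you have models trained with both BPE and word tokenization, you can load each configuration in turn with the **Change Config** button in Section 3 and compare their predictions.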

In [None]:
# Create a function to visualize attention weights
def visualize_attention(text):
    """Display the tokens of ``text`` shaded by their attention weight.

    Uses the module-level ``model`` and ``data_processor``. Each token is
    drawn on a white-to-red background: the token with the largest weight is
    the most saturated red. Prints a message instead when no attention
    weights are available or the text produced no tokens.

    Args:
        text: Raw input text to visualize.
    """
    # Local import keeps this cell self-contained.
    from html import escape

    tokens, weights, predicted_label = get_attention_weights(text, model, data_processor)

    if tokens is None or weights is None:
        print("Couldn't extract attention weights.")
        return
    if len(tokens) == 0 or len(weights) == 0:
        # max() on an empty sequence would raise; nothing to draw anyway.
        print("No tokens to visualize.")
        return

    # Scale so the largest weight maps to 1.0. If every weight is zero,
    # dividing would yield NaN, so fall back to an unshaded rendering.
    max_weight = float(max(weights))
    if max_weight > 0:
        norm_weights = [float(w) / max_weight for w in weights]
    else:
        norm_weights = [0.0 for _ in weights]

    # Tokens originate from user-entered text, so escape them before
    # embedding them in HTML.
    heading = escape(str(predicted_label).upper())
    markup = f"<h3>Attention Visualization for: <span style='color:blue'>{heading}</span></h3>"
    markup += "<div style='line-height: 2.5; font-family: monospace; font-size: 16px;'>"

    for token, weight in zip(tokens, norm_weights):
        # Higher weight -> lower green/blue channels -> stronger red
        color_intensity = int(255 * (1 - weight))
        background_color = f"rgb(255, {color_intensity}, {color_intensity})"

        markup += f"<span style='background-color: {background_color}; padding: 3px; margin: 2px; border-radius: 3px;'>{escape(str(token))}</span>"

    markup += "</div>"
    display(HTML(markup))

# --- Widgets for the attention visualization ---

# Text whose token-level attention will be drawn
attention_text = widgets.Textarea(
    value='The food was delicious but the service was terrible.',
    placeholder='Enter text to visualize attention...',
    description='Text:',
    layout=widgets.Layout(width='100%', height='100px'),
)

# Triggers a redraw for the current text
attention_button = widgets.Button(
    description='Visualize Attention',
    button_style='success',
    tooltip='Click to visualize token attention',
)

# Target area for the rendered visualization
attention_output = widgets.Output()


def on_attention_button_clicked(b):
    """Clear the output area and redraw the visualization for the current text."""
    with attention_output:
        attention_output.clear_output()
        visualize_attention(attention_text.value)


attention_button.on_click(on_attention_button_clicked)

# Show the text box, the button and the output area, in that order
print("\nVisualize which words the model pays attention to:")
display(attention_text, attention_button, attention_output)

# Draw once for the default text
on_attention_button_clicked(None)