In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import RobertaTokenizer, RobertaModel
import networkx as nx
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from PIL import Image
from matplotlib.patches import Rectangle
import pandas as pd
import os
from tqdm.auto import tqdm
import matplotlib.ticker as ticker

def _resolve_indices(requested, total):
    """Keep the entries of ``requested`` that validly index a length-``total`` sequence.

    Negative entries are converted to their non-negative equivalent (``-1`` -> ``total - 1``)
    so figure titles and file names report the true 1-based layer/head number.
    Out-of-range entries are dropped and duplicates (after conversion) are removed,
    preserving first-seen order.
    """
    resolved = []
    for idx in requested:
        if -total <= idx < total and idx % total not in resolved:
            resolved.append(idx % total)
    return resolved


def _strip_special_tokens(tokens, attentions):
    """Drop RoBERTa ``<s>``/``</s>`` tokens and the matching attention rows/columns.

    The byte-level BPE word-start marker ``Ġ`` is also removed from every kept token.

    Args:
        tokens: token strings for the non-padding positions.
        attentions: tuple of per-layer tensors indexed ``[batch, head, from, to]``.

    Returns:
        ``(clean_tokens, attentions)``. When special tokens were present, each layer is
        restricted to the kept positions (which also drops any padding positions).
    """
    keep = [i for i, tok in enumerate(tokens) if tok not in ('<s>', '</s>')]
    clean_tokens = [tokens[i].replace('Ġ', '') for i in keep]
    if len(keep) == len(tokens):
        return clean_tokens, attentions
    trimmed = tuple(layer[0, :, keep, :][:, :, keep].unsqueeze(0) for layer in attentions)
    return clean_tokens, trimmed


def _graph_feature_dim(model, default=305):
    """Node-feature width expected by the model's first GCN layer (``default`` if unknown)."""
    try:
        return int(model.gcn1.gc.in_channels)
    except AttributeError:
        return default


def _row_entropy_sum(matrix):
    """Sum over rows of the Shannon entropy of each (re-normalised) attention row.

    Higher values mean attention is spread over more tokens (less focused).
    """
    rows = matrix / (matrix.sum(axis=1, keepdims=True) + 1e-10)
    return float(-np.sum(rows * np.log(rows + 1e-10)))


def _gini(dist):
    """Gini coefficient of a 1-D non-negative distribution.

    0 for a perfectly uniform distribution, ``(n - 1) / n`` when all mass is on one entry.
    Returns 0.0 for an empty or all-zero input.
    """
    values = np.sort(np.asarray(dist, dtype=np.float64))
    n = values.size
    total = values.sum()
    if n == 0 or total <= 0:
        return 0.0
    ranks = np.arange(1, n + 1)
    return float((n + 1 - 2 * np.sum((n + 1 - ranks) * values) / total) / n)


def _save_heatmap(matrix, tokens, title, out_file, figsize=(10, 8)):
    """Save a token-by-token attention heatmap (rows attend *from*, columns attend *to*)."""
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(matrix, cmap='viridis', xticklabels=tokens, yticklabels=tokens, ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Token (attention to)")
    ax.set_ylabel("Token (attention from)")
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()
    fig.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.close(fig)


def _save_flow_graph(attention_matrix, tokens, attention_threshold, title, out_file):
    """Save a directed graph of every attention weight above ``attention_threshold``.

    All drawing targets one explicit Axes. The colorbar is attached with
    ``fig.colorbar(..., ax=ax)`` instead of a manually placed axes, so the title and
    ``axis('off')`` apply to the graph rather than to the colorbar.
    """
    graph = nx.DiGraph()
    for i, tok in enumerate(tokens):
        graph.add_node(i, token=tok)
    n = len(tokens)
    for i in range(n):
        for j in range(n):
            if attention_matrix[i, j] > attention_threshold:
                graph.add_edge(i, j, weight=attention_matrix[i, j])

    fig, ax = plt.subplots(figsize=(14, 12))
    weighted_edges = list(graph.edges(data='weight'))
    if weighted_edges:
        # Fixed seed so the layout is reproducible across runs.
        pos = nx.spring_layout(graph, k=0.5, iterations=50, seed=42)
        nx.draw_networkx_nodes(graph, pos, node_size=700, node_color='skyblue', alpha=0.8, ax=ax)
        nx.draw_networkx_labels(graph, pos, labels={i: d['token'] for i, d in graph.nodes(data=True)}, ax=ax)

        edge_list = [(u, v) for u, v, _ in weighted_edges]
        weights = [w for _, _, w in weighted_edges]
        nx.draw_networkx_edges(graph, pos, edgelist=edge_list, width=[w * 10 for w in weights],
                               alpha=0.7, edge_color=weights, edge_cmap=plt.cm.Reds,
                               connectionstyle='arc3,rad=0.1', ax=ax)  # curved edges for visibility

        sm = ScalarMappable(cmap=plt.cm.Reds, norm=Normalize(vmin=min(weights), vmax=max(weights)))
        sm.set_array([])
        fig.colorbar(sm, ax=ax, label='Attention Weight', fraction=0.03, pad=0.02)
    else:
        ax.text(0.5, 0.5, f"No significant attention connections found above threshold {attention_threshold}",
                ha='center', va='center', fontsize=12, transform=ax.transAxes,
                bbox=dict(boxstyle="round", fc="white", ec="gray", alpha=0.8))

    ax.set_title(title)
    ax.axis('off')
    # bbox_inches='tight' trims the margins; tight_layout is not used (it warns with colorbar axes).
    fig.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.close(fig)


def _build_report(comment, context, prediction, confidence, n_layers, n_heads,
                  token_attention_stats, aggregated_attention, tokens,
                  connection_threshold=0.1, top_k=5):
    """Return the markdown summary report as a string."""
    def cell(text):
        # A literal '|' would break the markdown table row.
        return str(text).replace('|', '\\|')

    lines = [
        "",
        "# Attention Analysis Report for Sarcasm Detection",
        "",
        "## Overview",
        f'- **Text**: "{comment}"',
        f'- **Context**: "{context}"',
        f"- **Prediction**: {prediction} (Confidence: {confidence:.4f})",
        f"- **Analyzed**: {n_layers} layers, {n_heads} heads per layer",
        "",
        "## Key Findings",
        "",
        "### Tokens with Highest Self-Attention",
        "The following tokens attend most strongly to themselves "
        "(average over the analyzed layers and heads):",
        "",
        "| Token | Position | Self-Attention | Attention Focus |",
        "|-------|----------|----------------|-----------------|",
    ]
    ranked = sorted(token_attention_stats, key=lambda s: s['avg_self_attention'], reverse=True)
    for stat in ranked[:top_k]:
        lines.append(f"| {cell(stat['token'])} | {stat['token_idx']} | "
                     f"{stat['avg_self_attention']:.4f} | {stat['avg_attention_focus']:.4f} |")

    lines += ["", "### Attention Patterns", ""]

    # Off-diagonal pairs of the aggregated matrix above the threshold, strongest first
    # (stable sort keeps row-major order for ties).
    n = len(tokens)
    connections = sorted(
        ((aggregated_attention[i, j], i, j)
         for i in range(n) for j in range(n)
         if i != j and aggregated_attention[i, j] > connection_threshold),
        key=lambda c: c[0], reverse=True)
    if connections:
        lines += [
            "#### Strong Token Connections",
            "The following token pairs show strong attention connections:",
            "",
            "| From | To | Attention Weight |",
            "|------|----|-----------------:|",
        ]
        for weight, i, j in connections[:top_k]:
            lines.append(f"| {cell(tokens[i])} | {cell(tokens[j])} | {weight:.4f} |")
    return "\n".join(lines) + "\n"


def visualize_attention_weights(model, tokenizer, comment, context='', model_path=None, 
                               device='cuda', layer_indices=None, head_indices=None,
                               save_path='attention_visualization', attention_threshold=0.03):
    """Visualize attention weights from the RoBERTa component of the sarcasm detection model.

    Writes these files, all prefixed with ``save_path``:
    one heatmap per selected (layer, head), an aggregated heatmap, an attention-flow
    graph for the highest-entropy head, a layer comparison (when >1 layer is
    selected), a per-token statistics plot and a markdown report.

    Args:
        model: SarcasmGCNLSTMDetector instance, or None if ``model_path`` is provided.
        tokenizer: RoBERTa tokenizer.
        comment: The comment text to analyze.
        context: Context for the comment (string, list of strings, or None/empty).
        model_path: Checkpoint (state_dict) to load when ``model`` is None.
        device: Device to run the model on ('cuda' or 'cpu').
        layer_indices: Layer indices to visualize; negatives count from the end.
            Invalid entries are dropped; None/empty -> last layer.
        head_indices: Head indices to visualize; negatives count from the end.
            Invalid entries are dropped; None/empty -> all heads.
        save_path: Base path (directory + file prefix) for the output files.
        attention_threshold: Minimum weight for an edge in the flow graph.

    Returns:
        Dict with keys 'prediction', 'confidence', 'tokens', 'attention_data',
        'aggregated_attention', 'token_attention_stats', 'best_layer', 'best_head'
        (layer/head indices are 0-based).

    Raises:
        ValueError: if neither ``model`` nor ``model_path`` is given, or the input has
            no tokens besides RoBERTa's special tokens.
    """
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)

    if model is None:
        if not model_path:
            raise ValueError("Either `model` or `model_path` must be provided")
        model = SarcasmGCNLSTMDetector().to(device)
        # weights_only=True refuses to unpickle arbitrary objects from the checkpoint.
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.eval()

    if context is None:
        context = ''
    elif isinstance(context, list):
        context = " ".join(str(c) for c in context if c)

    # NOTE(review): SarcasmGraphDataset always builds "Context: ... Comment: ..." even for an
    # empty context, whereas here an empty context feeds the bare comment. Confirm this
    # matches how the checkpoint was trained.
    combined_text = f"Context: {context} Comment: {comment}" if context.strip() else comment

    encoding = tokenizer(
        combined_text,
        truncation=True,
        padding='max_length',
        max_length=128,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    n_real_tokens = int(attention_mask.sum().item())

    with torch.no_grad():
        attentions = model.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True
        ).attentions

        # NOTE(review): the prediction uses a single all-zero placeholder node instead of the
        # text's dependency graph, so it may differ from the prediction on the real graph.
        try:
            graph_x = torch.zeros((1, _graph_feature_dim(model)), dtype=torch.float, device=device)
            graph_edge_index = torch.zeros((2, 0), dtype=torch.long, device=device)
            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                graph_x=graph_x,
                graph_edge_index=graph_edge_index
            )
            prediction_prob = torch.sigmoid(logits).item()
            prediction = "Sarcastic" if prediction_prob > 0.5 else "Not Sarcastic"
            confidence = prediction_prob if prediction_prob > 0.5 else 1 - prediction_prob
        except Exception as e:  # best-effort: attention plots are still produced
            print(f"Warning: Could not get full model prediction due to error: {str(e)}")
            prediction = "Unknown"
            confidence = 0.0

    raw_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])[:n_real_tokens]
    tokens, attentions = _strip_special_tokens(raw_tokens, attentions)
    n_tokens = len(tokens)
    if n_tokens == 0:
        raise ValueError("Input contains no tokens besides RoBERTa's special tokens")

    num_layers = len(attentions)
    num_heads = attentions[0].size(1)
    layer_indices = _resolve_indices(layer_indices, num_layers) if layer_indices is not None else []
    if not layer_indices:
        layer_indices = [num_layers - 1]
    head_indices = _resolve_indices(head_indices, num_heads) if head_indices is not None else []
    if not head_indices:
        head_indices = list(range(num_heads))

    print(f"Model prediction: {prediction} (Confidence: {confidence:.4f})")
    print(f"Number of layers: {num_layers}, Number of heads per layer: {num_heads}")
    print(f"Analyzing {len(layer_indices)} layers and {len(head_indices)} heads per layer")

    # Move each selected head's matrix to NumPy once (layer-major, head-minor order).
    head_mats = {
        (layer_idx, head_idx): attentions[layer_idx][0, head_idx, :n_tokens, :n_tokens].cpu().numpy()
        for layer_idx in layer_indices
        for head_idx in head_indices
    }

    # 1. Heatmap per selected head.
    attention_data = {}
    for (layer_idx, head_idx), matrix in head_mats.items():
        _save_heatmap(matrix, tokens,
                      f"Attention Weights - Layer {layer_idx+1}, Head {head_idx+1}",
                      f"{save_path}_layer{layer_idx+1}_head{head_idx+1}.png")
        attention_data[f"layer{layer_idx+1}_head{head_idx+1}"] = matrix

    # 2. Mean over all selected (layer, head) pairs.
    aggregated_attention = np.zeros((n_tokens, n_tokens))
    for matrix in head_mats.values():
        aggregated_attention += matrix
    count = len(head_mats)
    aggregated_attention /= count
    _save_heatmap(aggregated_attention, tokens,
                  f"Aggregated Attention Weights (Average across {count} heads)",
                  f"{save_path}_aggregated.png", figsize=(12, 10))
    attention_data["aggregated"] = aggregated_attention

    # 3. Flow graph for the head with the HIGHEST row entropy, i.e. the most spread-out
    #    attention; ties go to the first head in iteration order.
    best_layer, best_head = max(head_mats, key=lambda key: _row_entropy_sum(head_mats[key]))
    _save_flow_graph(head_mats[(best_layer, best_head)], tokens, attention_threshold,
                     f"Attention Flow Graph (Layer {best_layer+1}, Head {best_head+1})",
                     f"{save_path}_flow_graph.png")

    # 4. Per-layer mean over the selected heads, one subplot per layer (at most 3 per row).
    if len(layer_indices) > 1:
        n_layers = len(layer_indices)
        n_cols = min(3, n_layers)
        n_rows = (n_layers + n_cols - 1) // n_cols
        fig = plt.figure(figsize=(15, 10))
        for plot_pos, layer_idx in enumerate(layer_indices, start=1):
            layer_mean = np.mean([head_mats[(layer_idx, h)] for h in head_indices], axis=0)
            ax = fig.add_subplot(n_rows, n_cols, plot_pos)
            sns.heatmap(layer_mean, cmap='viridis', xticklabels=False, yticklabels=False, ax=ax)
            ax.set_title(f"Layer {layer_idx+1}")
        fig.suptitle("Attention Patterns Across Layers (Averaged Across Heads)")
        fig.tight_layout()
        fig.savefig(f"{save_path}_layer_comparison.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

    # 5. Per-token statistics across all selected heads.
    token_attention_stats = []
    for token_idx, token in enumerate(tokens):
        self_attention = [float(m[token_idx, token_idx]) for m in head_mats.values()]
        focus = [_gini(m[token_idx]) for m in head_mats.values()]
        token_attention_stats.append({
            'token': token,
            'token_idx': token_idx,
            'self_attention': self_attention,  # how much this token attends to itself
            'avg_self_attention': np.mean(self_attention),
            'avg_attention_focus': np.mean(focus),  # Gini: higher = attention on fewer tokens
        })

    positions = [s['token_idx'] for s in token_attention_stats]
    fig, (ax_self, ax_focus) = plt.subplots(2, 1, figsize=(14, 8))
    for ax, key, color, title, ylabel in (
        (ax_self, 'avg_self_attention', 'cornflowerblue',
         'Average Self-Attention by Token', 'Self-Attention Weight'),
        (ax_focus, 'avg_attention_focus', 'coral',
         'Attention Focus by Token (higher = more focused)', 'Attention Focus (Gini)'),
    ):
        ax.bar(positions, [s[key] for s in token_attention_stats], color=color)
        ax.set_xticks(positions)
        ax.set_xticklabels(tokens, rotation=45, ha='right')
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.grid(axis='y', linestyle='--', alpha=0.7)
    fig.tight_layout()
    fig.savefig(f"{save_path}_token_analysis.png", dpi=300, bbox_inches='tight')
    plt.close(fig)

    # 6. Markdown report (best-effort: a write failure is reported, not raised).
    report = _build_report(comment, context, prediction, confidence,
                           len(layer_indices), len(head_indices),
                           token_attention_stats, aggregated_attention, tokens)
    try:
        # Explicit UTF-8: tokens and comments may contain non-ASCII characters.
        with open(f"{save_path}_report.md", 'w', encoding='utf-8') as f:
            f.write(report)
        print(f"Full attention analysis report saved to {save_path}_report.md")
    except OSError as e:
        print(f"Error generating report: {str(e)}")

    return {
        'prediction': prediction,
        'confidence': confidence,
        'tokens': tokens,
        'attention_data': attention_data,
        'aggregated_attention': aggregated_attention,
        'token_attention_stats': token_attention_stats,
        'best_layer': best_layer,
        'best_head': best_head
    }

In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    RobertaTokenizer,
    RobertaModel,
    # AdamW,
    get_linear_schedule_with_warmup,
)
from torch.optim import AdamW
from sklearn.metrics import (
    classification_report,
    accuracy_score,
    f1_score,
    confusion_matrix,
)
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import ast
import os
import gc
import networkx as nx
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
import numpy as np
import re
import nltk
from nltk.tokenize import word_tokenize
from sklearn.utils import resample
import spacy
from senticnet.senticnet import SenticNet
import gensim.downloader as gensim_downloader

# --- Load external NLP resources used to build node features for the graph branch ---

print("Loading GloVe embeddings...")
# Width of the GloVe vectors and of the random fallback vectors used for OOV words.
EMBEDDING_DIM = 300
try:
    glove_embeddings = gensim_downloader.load("glove-wiki-gigaword-300")
    print(f"Loaded GloVe embeddings with dimension: {EMBEDDING_DIM}")
except Exception as e:  # best-effort: fall back to random vectors
    print(f"Error loading GloVe embeddings: {str(e)}")
    print("Using random embeddings instead")
    glove_embeddings = None

# spaCy model for dependency parsing; download it on first use.
try:
    nlp = spacy.load("en_core_web_sm")
    print("Loaded spaCy model successfully")
except OSError:
    # spacy.load raises OSError when the model package is not installed.
    print("Downloading spaCy model...")
    import subprocess
    import sys

    # Use the kernel's own interpreter so the model is installed where spaCy runs;
    # check=True surfaces a failed download instead of a confusing second OSError.
    subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
    nlp = spacy.load("en_core_web_sm")

# SenticNet lexicon for word-level sentiment features (optional).
try:
    sn = SenticNet()
    print("Loaded SenticNet successfully")
except Exception as e:
    print(f"Error loading SenticNet: {str(e)}")
    sn = None

class SarcasmGraphDataset(Dataset):
    """Dataset pairing RoBERTa encodings with a word-level graph of the same text.

    Each item joins the parsed context and the comment as
    ``"Context: {context} Comment: {comment}"`` and returns a dict with
    ``input_ids`` / ``attention_mask`` (tokenizer output, flattened),
    ``graph_data`` (a PyG ``Data`` whose ``x`` holds GloVe + 5 sentiment features
    per spaCy token), ``tokens`` (the spaCy token texts) and ``label`` (float tensor).

    Args:
        df: DataFrame with ``comment``, ``context`` and ``label`` columns.
        tokenizer: HuggingFace tokenizer for the transformer input.
        max_length: Padding/truncation length for the tokenizer.
    """

    def __init__(self, df, tokenizer, max_length=128):
        self.tokenizer = tokenizer
        self.comments = df["comment"].values
        self.contexts = df["context"].values
        self.labels = df["label"].values
        self.max_length = max_length
        self.window_size = 2  # each token is linked to the next `window_size` tokens

    def __len__(self):
        return len(self.comments)

    @staticmethod
    def _parse_context(raw_context):
        """Normalise a context cell into a list of context items.

        Strings are parsed with ``ast.literal_eval`` (e.g. ``"['a', 'b']"``); an
        unparsable string is treated as a single item. Lists/tuples/arrays become
        lists, None/NaN become ``[]``, and any other scalar (including a parsed
        string or number) is wrapped in a one-element list, so it is never iterated
        character by character.
        """
        parsed = raw_context
        if isinstance(raw_context, str):
            try:
                parsed = ast.literal_eval(raw_context)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                return [raw_context]
        if parsed is None or (isinstance(parsed, float) and parsed != parsed):  # None / NaN
            return []
        if isinstance(parsed, (list, tuple)):
            return list(parsed)
        if isinstance(parsed, np.ndarray):
            return parsed.tolist()
        return [parsed]

    def get_embedding(self, word):
        """Return the GloVe vector for ``word`` (lower-cased).

        Words missing from GloVe, or every word when GloVe failed to load, get a
        fresh random vector on each call, so their features are not deterministic.
        """
        word = word.lower()
        if glove_embeddings is not None and word in glove_embeddings:
            return torch.tensor(glove_embeddings[word], dtype=torch.float)
        return torch.randn(EMBEDDING_DIM, dtype=torch.float)

    def get_sentiment_features(self, word):
        """Return ``[polarity, is_positive, is_negative, is_neutral, intensity]`` from SenticNet.

        Polarity is read from SenticNet's ``polarity_value``; |polarity| <= 0.1 counts
        as neutral. Unknown words, lookup errors, or a missing SenticNet give zeros.
        """
        if sn is None:
            return torch.zeros(5, dtype=torch.float)
        try:
            polarity = float(sn.concept(word)["polarity_value"])
        except Exception:
            # Best-effort: words SenticNet does not know get all-zero features.
            return torch.zeros(5, dtype=torch.float)
        is_positive = 1.0 if polarity > 0.1 else 0.0
        is_negative = 1.0 if polarity < -0.1 else 0.0
        is_neutral = 1.0 if abs(polarity) <= 0.1 else 0.0
        intensity = abs(polarity)
        return torch.tensor(
            [polarity, is_positive, is_negative, is_neutral, intensity],
            dtype=torch.float,
        )

    def create_graph_from_text(self, text):
        """Build a PyG graph of ``text``'s spaCy tokens.

        Nodes are tokens of the lower-cased text. Edges connect tokens within
        ``window_size`` positions (``edge_type=0``) and each token to its dependency
        head (``edge_type=1``; that type overwrites 0 on a shared edge). Node features
        are the GloVe vector plus 5 sentiment features.

        Returns:
            ``(data, tokens)``. For text with no tokens, ``data`` has a single all-zero
            node and no edges, and ``tokens`` is ``[]``.
        """
        doc = nlp(text.lower())
        tokens = [token.text for token in doc]
        feature_dim = EMBEDDING_DIM + 5  # GloVe + 5 sentiment features

        if not tokens:
            empty_data = Data(
                x=torch.zeros((1, feature_dim), dtype=torch.float),
                edge_index=torch.zeros((2, 0), dtype=torch.long),
            )
            return empty_data, []

        G = nx.Graph()
        for i, token in enumerate(doc):
            G.add_node(i, word=token.text, pos=token.pos_)

        # 1. Window edges between nearby tokens.
        for i in range(len(tokens)):
            for j in range(i + 1, min(i + self.window_size + 1, len(tokens))):
                G.add_edge(i, j, edge_type=0)

        # 2. Dependency edges token -> head (a token that is its own head yields a self-loop).
        for token in doc:
            if token.i < len(tokens) and token.head.i < len(tokens):
                G.add_edge(token.i, token.head.i, edge_type=1)

        data = from_networkx(G)

        features = torch.zeros((len(tokens), feature_dim), dtype=torch.float)
        for i, token_text in enumerate(tokens):
            glove_feature = self.get_embedding(token_text)
            sentiment_feature = self.get_sentiment_features(token_text)
            if len(glove_feature) == EMBEDDING_DIM and len(sentiment_feature) == 5:
                features[i] = torch.cat([glove_feature, sentiment_feature])

        data.x = features
        return data, tokens

    def __getitem__(self, idx):
        comment = str(self.comments[idx])
        context = " ".join(str(c) for c in self._parse_context(self.contexts[idx]))
        combined_text = f"Context: {context} Comment: {comment}"

        graph_data, tokens = self.create_graph_from_text(combined_text)

        encoding = self.tokenizer(
            combined_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "graph_data": graph_data,
            "tokens": tokens,
            "label": torch.tensor(self.labels[idx], dtype=torch.float),
        }

class GCNLayer(nn.Module):
    """One graph-convolution stage: GCNConv -> BatchNorm (when >1 node) -> ReLU -> Dropout(0.2).

    The submodule names ``gc``, ``bn`` and ``dropout`` are part of the checkpoint's
    state_dict keys and must not be renamed.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.gc = GCNConv(in_features, out_features)
        self.bn = nn.BatchNorm1d(out_features)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, edge_index):
        hidden = self.gc(x, edge_index)
        # BatchNorm1d cannot compute batch statistics from a single row, so skip it then.
        hidden = self.bn(hidden) if hidden.size(0) > 1 else hidden
        return self.dropout(F.relu(hidden))


class LSTM(nn.Module):
    """Wrapper around ``nn.LSTM`` (``batch_first=True``) that returns one vector per sequence.

    Args:
        input_dim: Feature size of each time step.
        hidden_dim: Hidden size per direction; the output has ``2 * hidden_dim``
            features when bidirectional, else ``hidden_dim``.
        num_layers: Number of stacked LSTM layers.
        bidirectional: Run forward and backward LSTMs.
        use_final_hidden: How a sequence is summarised.
            ``False`` (default, the original behaviour) returns the output at the last
            time step. For a bidirectional LSTM, the backward half of that vector has
            then seen only the last element. ``True`` concatenates the final hidden
            state of each direction of the top layer, so both halves have seen the
            whole sequence.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=1, bidirectional=True,
                 use_final_hidden=False):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.use_final_hidden = use_final_hidden
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

    def forward(self, x):
        """Summarise ``x`` of shape ``(batch, seq_len, input_dim)`` into ``(batch, out_dim)``.

        No padding mask is applied; every time step is treated as real input.
        """
        lstm_out, (h_n, _) = self.lstm(x)
        if not self.use_final_hidden:
            return lstm_out[:, -1, :]
        if self.bidirectional:
            # h_n is (num_layers * 2, batch, hidden); the last two rows are the top
            # layer's forward and backward final states.
            return torch.cat((h_n[-2], h_n[-1]), dim=1)
        return h_n[-1]


class SarcasmGCNLSTMDetector(nn.Module):
    """Hybrid sarcasm classifier: RoBERTa text encoder fused with a GCN(+BiLSTM) graph encoder.

    The RoBERTa ``pooler_output`` is concatenated with a graph vector, scaled by a
    learned sigmoid gate, and passed through a 2-layer MLP that returns one raw logit
    per example (no sigmoid applied).

    Submodule names (``roberta``, ``gcn1``..``gcn4``, ``lstm``, ``attention``,
    ``dropout``, ``fc1``, ``fc2``) are state_dict keys of saved checkpoints.

    Args:
        pretrained_model: HuggingFace name/path passed to ``RobertaModel.from_pretrained``.
        gcn_hidden_dim: Base width ``h`` of the GCN stack (d -> h -> 2h -> 2h -> h).
        dropout_rate: Dropout probability in the classifier head.
    """

    def __init__(self, pretrained_model="roberta-base", gcn_hidden_dim=64, dropout_rate=0.3):
        super().__init__()
        self.roberta = RobertaModel.from_pretrained(pretrained_model)
        self.hidden_dim = self.roberta.config.hidden_size

        node_feature_dim = EMBEDDING_DIM + 5  # GloVe vector + 5 sentiment features
        wide_dim = gcn_hidden_dim * 2

        self.gcn1 = GCNLayer(node_feature_dim, gcn_hidden_dim)
        self.gcn2 = GCNLayer(gcn_hidden_dim, wide_dim)
        self.gcn3 = GCNLayer(wide_dim, wide_dim)
        self.gcn4 = GCNLayer(wide_dim, gcn_hidden_dim)

        # Bidirectional LSTM with gcn_hidden_dim // 2 units per direction.
        self.lstm = LSTM(gcn_hidden_dim, gcn_hidden_dim // 2, bidirectional=True)

        fused_dim = self.hidden_dim + gcn_hidden_dim
        self.attention = nn.Linear(fused_dim, 1)  # scalar gate over the fused vector
        self.dropout = nn.Dropout(dropout_rate)
        self.fc1 = nn.Linear(fused_dim, 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, input_ids, attention_mask, graph_x, graph_edge_index):
        """Return raw logits of shape ``(batch, 1)``.

        ``graph_x`` / ``graph_edge_index`` describe ONE graph shared by the whole batch.
        NOTE(review): with batch size 1 the node states go through the BiLSTM, but with
        batch size > 1 they are mean-pooled instead and the same vector is broadcast to
        every example. The two paths are not equivalent; confirm this is intended.
        """
        text_vec = self.roberta(input_ids=input_ids, attention_mask=attention_mask).pooler_output

        node_states = graph_x
        for gcn in (self.gcn1, self.gcn2, self.gcn3, self.gcn4):
            node_states = gcn(node_states, graph_edge_index)

        batch_size = text_vec.shape[0]
        if batch_size > 1:
            graph_vec = node_states.mean(dim=0, keepdim=True).expand(batch_size, -1)
        else:
            # Treat the nodes as one sequence: [num_nodes, feat] -> [1, num_nodes, feat].
            graph_vec = self.lstm(node_states.unsqueeze(0))

        fused = torch.cat((text_vec, graph_vec), dim=1)
        gate = torch.sigmoid(self.attention(fused))
        hidden = F.relu(self.fc1(self.dropout(fused * gate)))
        return self.fc2(self.dropout(hidden))

Loading GloVe embeddings...
Loaded GloVe embeddings with dimension: 300
Loaded spaCy model successfully
Loaded SenticNet successfully


In [3]:
def collate_batch(batch):
    """Collate ``SarcasmGraphDataset`` items into a batch for ``SarcasmGCNLSTMDetector``.

    Text tensors and labels are stacked along a new batch dimension. Graphs are
    returned in two forms:

    * ``graph_x`` / ``graph_edge_index``: the graph of the FIRST item only. This is
      what the model's ``forward`` consumes, so every sample in the batch is paired
      with the first sample's graph. Kept unchanged for compatibility.
    * ``batched_graph_x`` / ``batched_graph_edge_index`` / ``graph_batch``: all graphs
      merged into one disjoint graph, PyG-style. Node features are concatenated,
      each graph's edge indices are shifted by its node offset, and ``graph_batch[k]``
      is the sample index that node ``k`` belongs to. Items whose graph has no nodes
      are skipped.

    If no usable graph exists, a single all-zero node (assigned to sample 0) with no
    edges is used.

    Args:
        batch: list of item dicts with ``input_ids``, ``attention_mask``, ``label``,
            ``tokens`` and ``graph_data`` (object with ``.x`` and ``.edge_index``).

    Returns:
        dict with keys ``input_ids``, ``attention_mask``, ``graph_x``,
        ``graph_edge_index``, ``batched_graph_x``, ``batched_graph_edge_index``,
        ``graph_batch``, ``tokens`` and ``label``.
    """
    input_ids = torch.stack([item["input_ids"] for item in batch])
    attention_mask = torch.stack([item["attention_mask"] for item in batch])
    labels = torch.stack([item["label"] for item in batch])
    tokens_list = [item["tokens"] for item in batch]

    graph_xs = [item["graph_data"].x for item in batch]
    graph_edge_indices = [item["graph_data"].edge_index for item in batch]

    def _placeholder_graph():
        # Single all-zero node with no edges; its width matches the GCN input.
        feature_dim = EMBEDDING_DIM + 5  # GloVe + Sentiment
        return (torch.zeros((1, feature_dim), dtype=torch.float),
                torch.zeros((2, 0), dtype=torch.long))

    # Legacy single-graph view (first item only).
    if graph_xs and graph_xs[0] is not None and graph_xs[0].numel() > 0:
        graph_x, graph_edge_index = graph_xs[0], graph_edge_indices[0]
    else:
        graph_x, graph_edge_index = _placeholder_graph()

    # Properly merged view of every item's graph.
    merged_x, merged_edges, node_owner = [], [], []
    offset = 0
    for sample_idx, (x, edge_index) in enumerate(zip(graph_xs, graph_edge_indices)):
        if x is None or x.numel() == 0:
            continue
        num_nodes = x.size(0)
        merged_x.append(x)
        if edge_index is not None:
            merged_edges.append(edge_index + offset)
        node_owner.append(torch.full((num_nodes,), sample_idx, dtype=torch.long))
        offset += num_nodes

    if merged_x:
        batched_graph_x = torch.cat(merged_x, dim=0)
        batched_graph_edge_index = (torch.cat(merged_edges, dim=1) if merged_edges
                                    else torch.zeros((2, 0), dtype=torch.long))
        graph_batch = torch.cat(node_owner)
    else:
        batched_graph_x, batched_graph_edge_index = _placeholder_graph()
        graph_batch = torch.zeros(1, dtype=torch.long)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "graph_x": graph_x,
        "graph_edge_index": graph_edge_index,
        "batched_graph_x": batched_graph_x,
        "batched_graph_edge_index": batched_graph_edge_index,
        "graph_batch": graph_batch,
        "tokens": tokens_list,
        "label": labels,
    }


In [4]:
MODEL_PATH = "../sarcasm_gcn_lstm_detector_best.pt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_model_and_tokenizer():
    """Build the ``roberta-base`` tokenizer and a ``SarcasmGCNLSTMDetector`` on ``DEVICE``.

    Trained weights are loaded from ``MODEL_PATH`` when that file exists. Otherwise a
    warning is printed and the model keeps its freshly initialised weights, so its
    predictions are not meaningful.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.
    """
    print(f"Loading model from {MODEL_PATH}...")

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = SarcasmGCNLSTMDetector().to(DEVICE)

    if os.path.isfile(MODEL_PATH):
        # The loaded object goes straight into load_state_dict. weights_only=True
        # refuses to unpickle arbitrary Python objects from the checkpoint file.
        state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
        model.load_state_dict(state_dict)
        print("Model loaded successfully!")
    else:
        print(f"Warning: Model file not found at {MODEL_PATH}")

    model.eval()
    return model, tokenizer

# Example: analyse one comment without context, looking at layers 1, 6 and 12
# (0-based indices 0, 5, 11) and every attention head.
model, tokenizer = load_model_and_tokenizer()

comment = "Congratulations on stating the obvious. I am sure glaciers will start moving any minute now"
save_path = "./sarcasm_attention/sarcasm_attention"

attention_data = visualize_attention_weights(
    model=model,
    tokenizer=tokenizer,
    comment=comment,
    device=DEVICE,
    layer_indices=[0, 5, 11],
    head_indices=None,  # None -> all heads
    save_path=save_path,
    attention_threshold=0.1,
)


Loading model from ../sarcasm_gcn_lstm_detector_best.pt...


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded successfully!
Model prediction: Sarcastic (Confidence: 0.9531)
Number of layers: 12, Number of heads per layer: 12
Analyzing 3 layers and 12 heads per layer


  plt.tight_layout()


Full attention analysis report saved to ./sarcasm_attention/sarcasm_attention_report.md
