# 06 - Attention Visualization and Model Explainability

This notebook provides explainability for our CodeBERT vulnerability detection model through attention visualization.

**Objectives:**
- Extract attention weights from the trained model
- Map attention to source code tokens and lines
- Generate highlighted HTML visualizations
- Create attention heatmaps for different vulnerability types
- Analyze which code patterns the model focuses on

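**Outputs:**
- `results/visualizations/attention_<label>_<contract>_<function>.html` - Per-example token-highlighting attention visualizations
- `results/visualizations/attention_heatmap_<label>_<contract>_<function>.html` - Interactive token-to-token attention heatmap (one example)
- `results/visualizations/attention_patterns_analysis.png` - Attention pattern analysis plots
- `results/visualizations/token_importance_analysis.csv` - Token-level importance scores
- `results/explainability_summary.json` - Run configuration, attention statistics, and paths of all generated files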

## Setup and Imports

In [None]:
import os
import json
import pickle
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import re

# HTML and visualization
from IPython.display import HTML, display
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Transformers and ML
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
from sklearn.preprocessing import LabelEncoder
from tqdm.auto import tqdm

# Shared look for every matplotlib/seaborn figure below
# (order matters: the style sheet is applied first, then the palette overrides its colors)
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Record the runtime environment in the cell output
environment_report = (
    "‚úÖ All imports successful!",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
)
print(*environment_report, sep="\n")

## Configuration and Setup

In [None]:
# Configuration
# Project root. Override with the VULN_DETECTOR_BASE_DIR environment variable to run the
# notebook on another machine; the default is the original absolute path.
BASE_DIR = Path(os.environ.get('VULN_DETECTOR_BASE_DIR', '/home/netweb/vasu/smart-contract-vuln-detector'))
DATA_DIR = BASE_DIR / 'data/processed'
RESULTS_DIR = BASE_DIR / 'results'
MODELS_DIR = BASE_DIR / 'models'

# Create output directories; parents=True so this also works when results/ does not exist yet
(RESULTS_DIR / 'visualizations').mkdir(parents=True, exist_ok=True)

# Model configuration
MODEL_NAME = "microsoft/codebert-base"
MAX_LENGTH = 512  # tokenizer truncation/padding length
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Attention visualization settings
ATTENTION_HEAD_AVERAGE = True  # Average across attention heads
LAYER_TO_USE = -1  # Use last layer (-1) or specify layer index
TOP_K_TOKENS = 20  # Number of top tokens to highlight
MIN_ATTENTION_THRESHOLD = 0.01  # Minimum attention score to consider

print(f"Using device: {device}")
print(f"Results directory: {RESULTS_DIR}")
print("Attention configuration:")
print(f"  - Average heads: {ATTENTION_HEAD_AVERAGE}")
print(f"  - Layer: {LAYER_TO_USE}")
print(f"  - Top-K tokens: {TOP_K_TOKENS}")

## Load Model and Data

In [None]:
# Load processed data
print("Loading processed datasets...")
test_df = pd.read_csv(DATA_DIR / 'test_functions.csv')

# Fail early with a clear message if the CSV lacks columns used by later cells
# (dataset class, example selection, reports).
REQUIRED_TEST_COLUMNS = {'source_code', 'labels', 'label_encoded', 'contract_id', 'function_id'}
missing_columns = REQUIRED_TEST_COLUMNS - set(test_df.columns)
assert not missing_columns, f"test_functions.csv is missing columns: {sorted(missing_columns)}"
assert len(test_df) > 0, "test_functions.csv contains no rows"

# Load label encoder
# NOTE(review): pickle.load can execute arbitrary code; only load encoder files produced by
# this project's own preprocessing step.
with open(DATA_DIR / 'label_encoder.pkl', 'rb') as f:
    label_encoder = pickle.load(f)

print(f"Test set: {len(test_df)} samples")
print(f"Classes: {list(label_encoder.classes_)}")

# Show test set distribution
print("\nTest set distribution:")
for label, count in test_df['labels'].value_counts().sort_index().items():
    print(f"  {label}: {count} samples")

## Define Enhanced Model with Attention Extraction

In [None]:
class CodeBERTClassifierWithAttention(nn.Module):
    """CodeBERT encoder plus a linear classification head over the first token's state.

    The encoder is always invoked with ``output_attentions=True``, so every forward pass
    exposes the per-layer attention maps used by the visualization cells.

    The submodule attribute names (``codebert``, ``dropout``, ``classifier``) determine the
    state-dict keys; renaming them would break ``load_state_dict`` on existing checkpoints.
    """

    def __init__(self, model_name: str, num_classes: int, dropout_rate: float = 0.1):
        super().__init__()
        self.num_classes = num_classes
        self.codebert = AutoModel.from_pretrained(model_name)
        hidden_size = self.codebert.config.hidden_size
        # Classification head: dropout, then one linear projection to class logits
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask, return_attention=False):
        """Encode the inputs and classify from the first ([CLS]/<s>) position.

        Args:
            input_ids, attention_mask: forwarded unchanged to the encoder.
            return_attention: accepted for call-site compatibility only; it has no
                effect, attentions are always included in the result.

        Returns:
            dict with 'logits', 'attentions' (the encoder's attention outputs) and
            'last_hidden_state'.
        """
        encoded = self.codebert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=True,
        )
        hidden = encoded.last_hidden_state
        first_token_state = self.dropout(hidden[:, 0])
        return {
            'logits': self.classifier(first_token_state),
            'attentions': encoded.attentions,
            'last_hidden_state': hidden,
        }

# Dataset that keeps the raw source and identifiers next to the tokenized tensors
class VulnDatasetWithSource(Dataset):
    """Tokenized examples that also carry their source code and identifiers.

    Each item is tokenized with truncation and padding to ``max_length`` tokens. Metadata
    (ids, original label string, raw source) is returned with the tensors so visualizations
    can be labelled and traced back to the contract/function.
    """

    def __init__(self, dataframe, tokenizer, max_length=512):
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        code = record['source_code']

        encoded = self.tokenizer(
            code,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        item = {
            'input_ids': encoded['input_ids'].squeeze(),
            'attention_mask': encoded['attention_mask'].squeeze(),
        }
        item['label'] = torch.tensor(record['label_encoded'], dtype=torch.long)
        item.update(
            contract_id=record['contract_id'],
            function_id=record['function_id'],
            original_label=record['labels'],
            source_code=code,
        )
        return item


print("‚úÖ Enhanced model with attention extraction defined!")

## Load Trained Model

In [None]:
# Find and load the trained model.
# Path.glob yields files in arbitrary filesystem order, so candidates are sorted newest-first
# by modification time; otherwise which model gets explained could differ between runs/machines.
checkpoint_files = sorted((RESULTS_DIR / 'checkpoints').glob('best_model_*.pt'),
                          key=lambda p: p.stat().st_mtime, reverse=True)
model_files = sorted(MODELS_DIR.glob('*.pt'), key=lambda p: p.stat().st_mtime, reverse=True)

if checkpoint_files:
    model_path = checkpoint_files[0]
    print(f"Loading checkpoint: {model_path.name}")
elif model_files:
    model_path = model_files[0]
    print(f"Loading model: {model_path.name}")
else:
    raise FileNotFoundError("No trained model found!")

# Initialize tokenizer and model
print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
num_classes = len(label_encoder.classes_)
model = CodeBERTClassifierWithAttention(MODEL_NAME, num_classes)

# Load weights. The file is either a training checkpoint dict (with 'model_state_dict')
# or a bare state dict.
# NOTE(review): torch.load unpickles the file; only load checkpoints you produced yourself.
checkpoint = torch.load(model_path, map_location=device)
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"‚úÖ Loaded checkpoint from step {checkpoint.get('step', 'unknown')}")
else:
    model.load_state_dict(checkpoint)
    print("‚úÖ Loaded model weights")

model = model.to(device)
model.eval()
print(f"Model loaded on: {device}")

## Select Representative Examples

In [None]:
# Select representative examples for visualization
def select_representative_examples(df, label_encoder, n_per_class=2):
    """Pick up to ``n_per_class`` rows per class, spread across source-code length.

    For each class in ``label_encoder.classes_`` that has rows in ``df``:
      * with at least ``n_per_class`` rows: sort by source length and take
        ``n_per_class`` evenly spaced positions (the first is the shortest; the last is
        the longest when ``n_per_class`` > 1). Nothing is random.
      * otherwise: take all of that class's rows in ``df`` order.

    Prints a short report per class.

    Returns:
        list of ``df`` index *labels* (not positional indices), grouped by class in
        ``label_encoder.classes_`` order.
    """
    picked = []

    for class_name in label_encoder.classes_:
        members = df[df['labels'] == class_name]
        if members.empty:
            continue

        members = members.copy()
        members['code_length'] = members['source_code'].str.len()

        if len(members) < n_per_class:
            chosen = members.head(n_per_class)
        else:
            by_length = members.sort_values('code_length')
            positions = np.linspace(0, len(by_length) - 1, n_per_class, dtype=int)
            chosen = by_length.iloc[positions]

        picked.extend(chosen.index.tolist())

        print(f"Selected {len(chosen)} examples for class '{class_name}'")
        for cid, fid, code in zip(chosen['contract_id'], chosen['function_id'], chosen['source_code']):
            print(f"  - {cid}:{fid} (length: {len(code)} chars)")

    return picked

# Select examples
print("Selecting representative examples for visualization...")
selected_indices = select_representative_examples(test_df, label_encoder, n_per_class=1)

# select_representative_examples returns index labels (from `selected.index`), so look them
# up with .loc; .iloc treats them as positions and would pick the wrong rows whenever
# test_df's index is not a default RangeIndex.
visualization_df = test_df.loc[selected_indices].reset_index(drop=True)
assert len(visualization_df) > 0, "No examples selected for visualization"
visualization_dataset = VulnDatasetWithSource(visualization_df, tokenizer, MAX_LENGTH)

print(f"\nüìä Selected {len(visualization_df)} examples for visualization")
print(f"Distribution: {dict(visualization_df['labels'].value_counts())}")

## Attention Extraction Functions

In [None]:
def extract_attention_weights(model, input_ids, attention_mask, layer_idx=-1, average_heads=True):
    """Run ``model`` without gradients and return one layer's attention plus the logits.

    Args:
        model: callable returning a dict with 'attentions' (indexable per layer) and 'logits'.
        input_ids, attention_mask: forwarded positionally to ``model``.
        layer_idx: which entry of the attentions sequence to use (-1 = last layer).
        average_heads: if True, average over dim 1, the head dimension of tensors shaped
            [batch_size, num_heads, seq_len, seq_len].

    Returns:
        (attention, logits)
    """
    with torch.no_grad():
        outputs = model(input_ids, attention_mask, return_attention=True)
        per_layer = outputs['attentions']
        chosen = per_layer[-1] if layer_idx == -1 else per_layer[layer_idx]
        if average_heads:
            chosen = chosen.mean(dim=1)
    return chosen, outputs['logits']

def get_token_importance(attention_weights, attention_mask, method='cls_attention'):
    """Reduce an attention matrix to one importance score per token.

    Args:
        attention_weights: [seq_len, seq_len] tensor, or [num_heads, seq_len, seq_len]
            (heads are averaged). Row i holds the attention paid *by* token i (the HF
            query/key convention).
        attention_mask: [seq_len] mask (1 = real token, 0 = padding). It may live on a
            different device than ``attention_weights``; it is moved to match.
        method: 'cls_attention' (row 0: attention from the first token to each token),
            'sum_attention' (column sums) or 'mean_attention' (column means).

    Returns:
        [seq_len] tensor of scores with padded positions set to 0, on the device of
        ``attention_weights``.

    Raises:
        ValueError: if ``method`` is not one of the options above.
    """
    if attention_weights.dim() == 3:
        attention_weights = attention_weights.mean(dim=0)  # Average heads if needed

    if method == 'cls_attention':
        token_importance = attention_weights[0, :]
    elif method == 'sum_attention':
        token_importance = attention_weights.sum(dim=0)
    elif method == 'mean_attention':
        token_importance = attention_weights.mean(dim=0)
    else:
        raise ValueError(f"Unknown method: {method}")

    # The notebook passes the dataset's CPU mask together with attention that came off the
    # model's device. Multiplying a CUDA tensor by a non-scalar CPU tensor raises, so align
    # devices before zeroing out padded positions.
    mask = attention_mask.to(token_importance.device).float()
    return token_importance * mask

def map_tokens_to_source(tokenizer, input_ids, source_code):
    """Convert ``input_ids`` to tokens and approximately locate each one in ``source_code``.

    Args:
        tokenizer: object with ``convert_ids_to_tokens`` (the HF tokenizer in this notebook).
        input_ids: 1-D tensor or sequence of token ids.
        source_code: the text that was tokenized.

    Returns:
        (tokens, decoded_tokens, token_positions):
          * tokens: raw tokenizer tokens.
          * decoded_tokens: tokens with RoBERTa byte-level BPE whitespace stand-ins turned
            back into real whitespace; special tokens are kept unchanged.
          * token_positions: character offset of each token found by a forward-only
            ``str.find``; -1 for special tokens and once the source is exhausted. This is a
            heuristic: when a token cannot be found, the current cursor is reported.
    """
    if hasattr(input_ids, 'tolist'):
        input_ids = input_ids.tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    special_tokens = {'<s>', '</s>', '<pad>', '<unk>'}
    # Byte-level BPE maps whitespace bytes to printable characters:
    # 'Ġ' (U+0120) = space, 'Ċ' (U+010A) = newline, 'ĉ' (U+0109) = tab, 'č' (U+010D) = CR.
    # Escapes are used so the literals cannot be corrupted by file re-encoding.
    whitespace_markers = (('\u0120', ' '), ('\u010a', '\n'), ('\u0109', '\t'), ('\u010d', '\r'))

    decoded_tokens = []
    token_positions = []
    cursor = 0

    for token in tokens:
        if token in special_tokens:
            decoded_tokens.append(token)
            token_positions.append(-1)
            continue

        token_text = token
        for marker, replacement in whitespace_markers:
            token_text = token_text.replace(marker, replacement)
        decoded_tokens.append(token_text)

        if cursor >= len(source_code):
            token_positions.append(-1)
            continue

        needle = token_text.strip()
        found = source_code.find(needle, cursor)
        if found == -1:
            token_positions.append(cursor)
        else:
            token_positions.append(found)
            cursor = found + len(needle)

    return tokens, decoded_tokens, token_positions


print("‚úÖ Attention extraction functions defined!")

## Generate Attention Visualizations

In [None]:
def create_attention_html(tokens, attention_scores, source_code, prediction, true_label, 
                         contract_id, function_id, confidence, top_k=20):
    """Render one example as a standalone HTML page with attention-highlighted tokens.

    Args:
        tokens: raw tokenizer tokens (RoBERTa byte-level BPE), aligned with ``attention_scores``.
        attention_scores: 1-D tensor of per-token scores (e.g. from ``get_token_importance``).
        source_code: unused; kept for call-site compatibility.
        prediction, true_label: class names; the prediction is shown green when they match,
            red otherwise.
        contract_id, function_id: identifiers shown in the page title and header.
        confidence: predicted-class probability, printed with 3 decimals.
        top_k: number of highest-scoring non-special tokens listed below the code.

    Returns:
        The HTML document as a string. Token text and identifiers are HTML-escaped, so
        Solidity operators such as ``<``, ``>`` and ``&&`` display literally instead of
        breaking the markup.
    """
    import html

    special_tokens = ('<s>', '</s>', '<pad>')

    def _esc(value):
        return html.escape(str(value))

    def _display_text(token):
        # Byte-level BPE whitespace stand-ins: 'Ġ' (U+0120) = space, 'Ċ' (U+010A) = newline
        return token.replace('\u0120', ' ').replace('\u010a', '\n')

    # Normalize so the highest-scoring token has intensity 1.0
    max_attention = attention_scores.max().item() if attention_scores.max() > 0 else 1
    normalized_attention = attention_scores / max_attention

    # Highest-scoring tokens, skipping special tokens so they don't use up the top-k slots
    ranked = attention_scores.argsort(descending=True).tolist()
    top_indices = [i for i in ranked if tokens[i] not in special_tokens][:top_k]

    prediction_color = 'green' if prediction == true_label else 'red'

    # The <pre> opening tag ends the template so no stray indentation leaks into the code view
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>Attention Visualization - {_esc(contract_id)}:{_esc(function_id)}</title>
        <style>
            body {{ font-family: 'Courier New', monospace; margin: 20px; background-color: #f8f9fa; }}
            .header {{ background-color: #343a40; color: white; padding: 15px; border-radius: 5px; margin-bottom: 20px; }}
            .info {{ background-color: #e9ecef; padding: 10px; border-radius: 5px; margin-bottom: 15px; }}
            .code-container {{ background-color: white; padding: 20px; border-radius: 5px; border: 1px solid #dee2e6; }}
            .token {{ display: inline; padding: 2px; margin: 1px; border-radius: 3px; }}
            .legend {{ margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 5px; }}
            .top-tokens {{ margin-top: 15px; }}
            .attention-bar {{ width: 100%; height: 20px; background: linear-gradient(to right, #fff, #ff0000); margin-bottom: 10px; }}
        </style>
    </head>
    <body>
        <div class="header">
            <h2>üîç Attention Visualization</h2>
            <p><strong>Contract:</strong> {_esc(contract_id)} | <strong>Function:</strong> {_esc(function_id)}</p>
        </div>
        
        <div class="info">
            <p><strong>üéØ Prediction:</strong> <span style="color: {prediction_color};">{_esc(prediction)}</span></p>
            <p><strong>‚úì True Label:</strong> {_esc(true_label)}</p>
            <p><strong>üìä Confidence:</strong> {confidence:.3f}</p>
            <p><strong>ü§ñ Model Focus:</strong> Tokens highlighted by attention intensity (red = high attention)</p>
        </div>
        
        <div class="code-container">
            <h3>üíª Source Code with Attention Highlighting:</h3>
            <pre style="white-space: pre-wrap; line-height: 1.6;">"""

    # Add tokens with attention highlighting
    for token, attention in zip(tokens, normalized_attention):
        if token in special_tokens:
            continue

        text = _esc(_display_text(token))
        intensity = attention.item()

        if intensity > 0.1:  # Only highlight significant attention
            # White-to-red ramp matching the legend bar: red stays at 255 while green/blue
            # fade, so low scores are pale pink and the maximum is pure red.
            fade = max(0, int(255 * (1 - intensity)))
            # Dark text stays readable on pale backgrounds; white on saturated ones
            text_color = 'white' if intensity >= 0.5 else 'black'
            html_content += (f'<span class="token" style="background-color: rgb(255, {fade}, {fade}); '
                             f'color: {text_color};">{text}</span>')
        else:
            html_content += text

    html_content += """
            </pre>
        </div>
        
        <div class="legend">
            <h3>üé® Attention Legend:</h3>
            <div class="attention-bar"></div>
            <p><strong>Low Attention</strong> ‚Üê ‚Üí <strong>High Attention</strong></p>
    """

    # Add top attended tokens (newlines shown as a literal "\n")
    html_content += '<div class="top-tokens"><h4>üî• Top Attended Tokens:</h4><ol>'

    for idx in top_indices:
        score = attention_scores[idx].item()
        if score > 0.01:
            clean_token = _esc(_display_text(tokens[idx]).replace('\n', '\\n'))
            html_content += f'<li><code>{clean_token}</code> (attention: {score:.4f})</li>'

    html_content += """
            </ol>
        </div>
        </div>
    </body>
    </html>
    """

    return html_content

# Process examples and generate visualizations
print("Generating attention visualizations...")

attention_results = []
visualization_files = []

for idx in tqdm(range(len(visualization_dataset)), desc="Processing examples"):
    example = visualization_dataset[idx]

    # Prepare inputs: add a batch dimension and move to the model's device
    input_ids = example['input_ids'].unsqueeze(0).to(device)
    attention_mask = example['attention_mask'].unsqueeze(0).to(device)

    # Extract attention
    attention_weights, logits = extract_attention_weights(
        model, input_ids, attention_mask,
        layer_idx=LAYER_TO_USE,
        average_heads=ATTENTION_HEAD_AVERAGE
    )

    # Get token importance. Use the on-device mask: the dataset's mask is a CPU tensor, and
    # multiplying it with CUDA attention weights raises a device-mismatch error.
    token_importance = get_token_importance(
        attention_weights[0], attention_mask[0],
        method='cls_attention'
    )

    # Get prediction
    prediction_idx = torch.argmax(logits, dim=-1).item()
    prediction = label_encoder.classes_[prediction_idx]
    confidence = torch.softmax(logits, dim=-1).max().item()

    # Map tokens
    tokens, decoded_tokens, token_positions = map_tokens_to_source(
        tokenizer, example['input_ids'], example['source_code']
    )

    # Create HTML visualization
    html_content = create_attention_html(
        tokens, token_importance, example['source_code'],
        prediction, example['original_label'],
        example['contract_id'], example['function_id'],
        confidence, top_k=TOP_K_TOKENS
    )

    # Save HTML file. Labels and ids may contain characters that are unsafe in file names
    # ('/', ':', spaces), so replace anything outside [A-Za-z0-9_.-] with '_'.
    file_stem = f"attention_{example['original_label']}_{example['contract_id']}_{example['function_id']}"
    filename = re.sub(r'[^\w.-]+', '_', file_stem) + '.html'
    filepath = RESULTS_DIR / 'visualizations' / filename

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(html_content)

    visualization_files.append(str(filepath))

    # Store results for analysis
    attention_results.append({
        'contract_id': example['contract_id'],
        'function_id': example['function_id'],
        'true_label': example['original_label'],
        'predicted_label': prediction,
        'confidence': confidence,
        'correct': prediction == example['original_label'],
        'attention_scores': token_importance.cpu().numpy(),
        'tokens': tokens,
        'source_code': example['source_code'],
        'html_file': str(filepath)
    })

    print(f"‚úÖ Generated: {filename}")
    print(f"   True: {example['original_label']} | Pred: {prediction} | Conf: {confidence:.3f}")

print(f"\nüéâ Generated {len(visualization_files)} attention visualizations!")
print(f"üìÅ Saved to: {RESULTS_DIR / 'visualizations'}")

## Attention Pattern Analysis

In [None]:
# Analyze attention patterns across vulnerability types
print("Analyzing attention patterns...")

ANALYSIS_SPECIAL_TOKENS = {'<s>', '</s>', '<pad>'}
HIGH_ATTENTION_THRESHOLD = 0.05  # raw score above which a token counts as "high attention"

# Create attention analysis dataframe
attention_analysis = []

for result in attention_results:
    attention_scores = result['attention_scores']
    tokens = result['tokens']

    # Skip special tokens and calculate statistics
    valid_mask = np.array([token not in ANALYSIS_SPECIAL_TOKENS for token in tokens])
    valid_attention = attention_scores[valid_mask]

    if len(valid_attention) == 0:
        continue

    # Shannon entropy needs a distribution that sums to 1. After the special tokens are
    # dropped, the remaining attention sums to less than 1 by an amount that varies per
    # example. Renormalize first, otherwise the "entropy" mostly tracks that leftover mass.
    valid_mass = valid_attention.sum()
    attention_probs = valid_attention / valid_mass if valid_mass > 0 else valid_attention

    attention_analysis.append({
        'vulnerability_type': result['true_label'],
        'prediction_correct': bool(result['correct']),
        'max_attention': valid_attention.max(),
        'mean_attention': valid_attention.mean(),
        'std_attention': valid_attention.std(),
        'attention_entropy': -np.sum(attention_probs * np.log(attention_probs + 1e-8)),
        'top_10_attention_sum': np.sort(valid_attention)[-10:].sum(),
        'num_high_attention_tokens': np.sum(valid_attention > HIGH_ATTENTION_THRESHOLD),
        'contract_id': result['contract_id'],
        'function_id': result['function_id']
    })

attention_df = pd.DataFrame(attention_analysis)
assert not attention_df.empty, "No examples with non-special tokens to analyze"

print("üìä Attention Analysis Summary:")
print(f"Total examples analyzed: {len(attention_df)}")
print("\nBy vulnerability type:")
summary_stats = attention_df.groupby('vulnerability_type').agg({
    'max_attention': ['mean', 'std'],
    'attention_entropy': ['mean', 'std'],
    'num_high_attention_tokens': ['mean', 'std'],
    'prediction_correct': 'mean'
}).round(4)

print(summary_stats)

# Create attention pattern visualizations
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# 1. Max attention by vulnerability type
attention_df.boxplot(column='max_attention', by='vulnerability_type', ax=axes[0,0])
axes[0,0].set_title('Maximum Attention Score by Vulnerability Type')
axes[0,0].set_xlabel('Vulnerability Type')
axes[0,0].set_ylabel('Max Attention Score')

# 2. Attention entropy by vulnerability type
attention_df.boxplot(column='attention_entropy', by='vulnerability_type', ax=axes[0,1])
axes[0,1].set_title('Attention Entropy by Vulnerability Type')
axes[0,1].set_xlabel('Vulnerability Type')
axes[0,1].set_ylabel('Attention Entropy')

# 3. Number of high attention tokens
attention_df.boxplot(column='num_high_attention_tokens', by='vulnerability_type', ax=axes[1,0])
axes[1,0].set_title('Number of High Attention Tokens by Vulnerability Type')
axes[1,0].set_xlabel('Vulnerability Type')
axes[1,0].set_ylabel('Number of High Attention Tokens')

# 4. Attention vs prediction correctness
correct_attention = attention_df[attention_df['prediction_correct']]['max_attention']
incorrect_attention = attention_df[~attention_df['prediction_correct']]['max_attention']

axes[1,1].hist([correct_attention, incorrect_attention],
              label=['Correct Predictions', 'Incorrect Predictions'],
              alpha=0.7, bins=15)
axes[1,1].set_title('Attention Distribution: Correct vs Incorrect Predictions')
axes[1,1].set_xlabel('Max Attention Score')
axes[1,1].set_ylabel('Frequency')
axes[1,1].legend()

plt.suptitle('Attention Pattern Analysis', fontsize=16, fontweight='bold')
plt.tight_layout()

# Save attention analysis plot
attention_plot_path = RESULTS_DIR / 'visualizations' / 'attention_patterns_analysis.png'
plt.savefig(attention_plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"üíæ Attention analysis plot saved to: {attention_plot_path}")

## Token Importance Analysis

In [None]:
# Analyze most important tokens across all examples
print("Analyzing token importance patterns...")


def _strip_bpe_markers(token):
    """Drop RoBERTa byte-level BPE whitespace stand-ins ('Ġ' U+0120 space, 'Ċ' U+010A newline) and trim."""
    return token.replace('\u0120', '').replace('\u010a', '').strip()


# Collect all token-importance pairs
all_token_importance = []

for result in attention_results:
    vulnerability_type = result['true_label']

    for token, score in zip(result['tokens'], result['attention_scores']):
        if token not in ('<s>', '</s>', '<pad>') and score > MIN_ATTENTION_THRESHOLD:
            clean_token = _strip_bpe_markers(token)
            if clean_token:  # Skip tokens that were pure whitespace
                all_token_importance.append({
                    'token': clean_token.lower(),  # Normalize case
                    'raw_token': token,
                    'attention_score': float(score),
                    'vulnerability_type': vulnerability_type
                })

# Explicit columns so the frame keeps its schema even when nothing passed the threshold
token_df = pd.DataFrame(all_token_importance,
                        columns=['token', 'raw_token', 'attention_score', 'vulnerability_type'])

print(f"Collected {len(token_df)} token-attention pairs")

# Find most important tokens overall
print("\nüî• Most Important Tokens (by average attention):")
if token_df.empty:
    top_tokens_overall = pd.DataFrame(columns=['avg_attention', 'frequency', 'std_attention', 'vuln_types'])
else:
    top_tokens_overall = token_df.groupby('token').agg({
        'attention_score': ['mean', 'count', 'std'],
        # sorted(): set iteration order of strings varies between Python processes
        'vulnerability_type': lambda x: sorted(set(x))
    }).round(4)
    top_tokens_overall.columns = ['avg_attention', 'frequency', 'std_attention', 'vuln_types']
    top_tokens_overall = top_tokens_overall[top_tokens_overall['frequency'] >= 2]  # At least 2 occurrences
    top_tokens_overall = top_tokens_overall.sort_values('avg_attention', ascending=False)

print(top_tokens_overall.head(20))

# Find vulnerability-specific important tokens
print("\nüéØ Most Important Tokens by Vulnerability Type:")
for vuln_type in label_encoder.classes_:
    vuln_tokens = token_df[token_df['vulnerability_type'] == vuln_type]
    if len(vuln_tokens) == 0:
        continue

    top_vuln_tokens = (
        vuln_tokens.groupby('token')['attention_score']
        .agg(avg_attention='mean', frequency='count')
        .sort_values('avg_attention', ascending=False)
    )

    print(f"\n{vuln_type}:")
    print(top_vuln_tokens.head(10))

# Save token importance data
token_importance_path = RESULTS_DIR / 'visualizations' / 'token_importance_analysis.csv'
top_tokens_overall.to_csv(token_importance_path)

print(f"\nüíæ Token importance analysis saved to: {token_importance_path}")

## Create Interactive Attention Heatmap

In [None]:
# Create an interactive attention heatmap using plotly
print("Creating interactive attention heatmap...")

# Select one example for the detailed heatmap: the first whose scores exceed 0.1 somewhere,
# falling back to the first example.
detailed_example_idx = next(
    (i for i, result in enumerate(attention_results) if result['attention_scores'].max() > 0.1),
    0,
)
detailed_example = attention_results[detailed_example_idx]

# attention_results was filled in dataset order, so the same index addresses the dataset
example_data = visualization_dataset[detailed_example_idx]
input_ids = example_data['input_ids'].unsqueeze(0).to(device)
attention_mask = example_data['attention_mask'].unsqueeze(0).to(device)

with torch.no_grad():
    outputs = model(input_ids, attention_mask)
    # Same layer as the HTML visualizations (LAYER_TO_USE), all heads kept
    full_attention = outputs['attentions'][LAYER_TO_USE][0]  # [num_heads, seq_len, seq_len]

# Convert to numpy and keep only the real tokens (assumes right padding)
attention_matrix = full_attention.cpu().numpy()
tokens = detailed_example['tokens']
actual_length = int(attention_mask.sum().item())
attention_matrix = attention_matrix[:, :actual_length, :actual_length]
relevant_tokens = tokens[:actual_length]

# Create heatmap for averaged attention
avg_attention = attention_matrix.mean(axis=0)  # Average across heads


def _heatmap_label(position, token):
    """Axis label '<position>:<token text, max 15 chars>' with BPE whitespace markers decoded."""
    text = token.replace('\u0120', ' ').replace('\u010a', '\\n')
    return f"{position}:{text[:15]}"


# Plotly merges identical category labels on an axis, so repeated tokens (e.g. every ' (')
# would collapse into a single row/column. The position prefix keeps every label unique.
display_tokens = [_heatmap_label(pos, token) for pos, token in enumerate(relevant_tokens)]

fig = go.Figure(data=go.Heatmap(
    z=avg_attention,
    x=display_tokens,
    y=display_tokens,
    colorscale='Reds',
    hoverongaps=False,
    hovertemplate='From: %{y}<br>To: %{x}<br>Attention: %{z:.4f}<extra></extra>'
))

fig.update_layout(
    title=f'Attention Heatmap - {detailed_example["true_label"]} ({detailed_example["contract_id"]}:{detailed_example["function_id"]})',
    xaxis_title='Target Tokens',
    yaxis_title='Source Tokens',
    width=800,
    height=800
)

# Save interactive heatmap (file name sanitized: ids/labels may contain '/', ':' or spaces)
heatmap_stem = f'attention_heatmap_{detailed_example["true_label"]}_{detailed_example["contract_id"]}_{detailed_example["function_id"]}'
heatmap_path = RESULTS_DIR / 'visualizations' / (re.sub(r'[^\w.-]+', '_', heatmap_stem) + '.html')
fig.write_html(str(heatmap_path))

# Show in notebook
fig.show()

print(f"üíæ Interactive attention heatmap saved to: {heatmap_path}")
print(f"üìä Showing attention for: {detailed_example['true_label']} example")
print(f"   Contract: {detailed_example['contract_id']}")
print(f"   Function: {detailed_example['function_id']}")
print(f"   Prediction: {detailed_example['predicted_label']} (correct: {detailed_example['correct']})")

## Summary and Results

In [None]:
# Create comprehensive summary of explainability results
SHOW_SAMPLE_INLINE = False  # set True to render the first HTML visualization inside the notebook

has_top_tokens = len(top_tokens_overall) > 0
max_attention_by_type = attention_df.groupby('vulnerability_type')['max_attention'].mean()

explainability_summary = {
    'timestamp': datetime.now().isoformat(),
    'model_path': str(model_path),
    'configuration': {
        'attention_layer': LAYER_TO_USE,
        'average_heads': ATTENTION_HEAD_AVERAGE,
        'top_k_tokens': TOP_K_TOKENS,
        'min_attention_threshold': MIN_ATTENTION_THRESHOLD
    },
    'examples_analyzed': len(attention_results),
    'visualizations_generated': len(visualization_files),
    'attention_statistics': {
        'max_attention_overall': float(attention_df['max_attention'].max()),
        # Mean over examples of each example's mean attention (what the key name says)
        'mean_attention_overall': float(attention_df['mean_attention'].mean()),
        # Mean over examples of each example's maximum attention
        'max_attention_mean': float(attention_df['max_attention'].mean()),
        'attention_entropy_mean': float(attention_df['attention_entropy'].mean()),
        'high_attention_tokens_mean': float(attention_df['num_high_attention_tokens'].mean())
    },
    'files_generated': {
        'html_visualizations': visualization_files,
        'attention_analysis_plot': str(attention_plot_path),
        'token_importance_csv': str(token_importance_path),
        'interactive_heatmap': str(heatmap_path)
    },
    'insights': {
        'most_attended_tokens_overall': top_tokens_overall.head(10).index.tolist(),
        'attention_patterns_by_vulnerability': {str(k): float(v) for k, v in max_attention_by_type.items()},
        'prediction_accuracy_analyzed': float(attention_df['prediction_correct'].mean())
    }
}

# Save summary
summary_path = RESULTS_DIR / 'explainability_summary.json'
with open(summary_path, 'w') as f:
    json.dump(explainability_summary, f, indent=2, default=str)

print("üéâ ATTENTION VISUALIZATION COMPLETED!")
print("=" * 50)
print("üìä Summary:")
print(f"  ‚Ä¢ Analyzed {len(attention_results)} examples")
print(f"  ‚Ä¢ Generated {len(visualization_files)} HTML visualizations")
print("  ‚Ä¢ Created interactive attention heatmap")
print(f"  ‚Ä¢ Analyzed {len(token_df)} token-attention pairs")
print("\nüîç Key Insights:")
print(f"  ‚Ä¢ Average max attention: {attention_df['max_attention'].mean():.4f}")
if has_top_tokens:
    print(f"  ‚Ä¢ Most attended token: '{top_tokens_overall.index[0]}'")
else:
    print("  ‚Ä¢ Most attended token: n/a (no token met the frequency/threshold filters)")
print(f"  ‚Ä¢ Prediction accuracy on analyzed examples: {attention_df['prediction_correct'].mean():.2%}")
print(f"\nüìÅ All results saved to: {RESULTS_DIR / 'visualizations'}")
print(f"üíæ Summary saved to: {summary_path}")

# Display one example visualization in notebook
if visualization_files:
    print("\nüîç Sample visualization (first example):")
    sample_file = visualization_files[0]
    print(f"File: {Path(sample_file).name}")

    if SHOW_SAMPLE_INLINE:
        display(HTML(Path(sample_file).read_text(encoding='utf-8')))

print("\n‚ú® Open the generated HTML files in a browser to view interactive attention visualizations!")