In [None]:
!pip install torch torchvision torchaudio

In [None]:
!pip install transformers

In [None]:
!pip install matplotlib seaborn

In [None]:
!pip install plotly

In [None]:
!pip install umap-learn

In [None]:
!pip install scikit-learn

In [None]:
pip install pandas numpy

In [None]:
!pip install ipywidgets

In [None]:
# BERT Visualization Setup - Alternative to BertViz
# This version works without the bertviz package

import torch
from transformers import BertTokenizer, BertModel
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import umap
import warnings
warnings.filterwarnings('ignore')

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

# =============================================================================
# 1. BERT MODEL INITIALIZATION
# =============================================================================

def initialize_bert_model(model_name='bert-base-uncased'):
    """Load a BERT tokenizer and model configured to return attention weights.

    Args:
        model_name: Hugging Face model identifier (or local path) handed to
            ``from_pretrained`` for both the tokenizer and the model.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been put in eval mode.
    """
    print(f"Loading {model_name}...")

    bert_tokenizer = BertTokenizer.from_pretrained(model_name)
    # output_attentions=True makes each forward pass return the per-layer
    # attention tensors that the attention visualizations read.
    bert_model = BertModel.from_pretrained(model_name, output_attentions=True)
    bert_model.eval()  # inference mode: disables dropout

    print("Model loaded successfully!")
    return bert_tokenizer, bert_model

# =============================================================================
# 2. CUSTOM ATTENTION VISUALIZATION (Replaces BertViz functionality)
# =============================================================================

def create_attention_heatmap(text, tokenizer, model, layer=11, head=0):
    """Plot one attention head as an annotated seaborn heatmap.

    Args:
        text: Input text; special tokens are added before encoding.
        tokenizer: Tokenizer providing ``encode_plus`` and
            ``convert_ids_to_tokens``.
        model: Model whose outputs include ``attentions`` (one tensor per layer).
        layer: Index into ``outputs.attentions``.
        head: Attention head index within that layer.

    Rows of the heatmap are query tokens and columns are key tokens.
    """
    # truncation=True caps the input at the tokenizer's model_max_length, so
    # long text (e.g. pasted into the widget) does not exceed BERT's position
    # embeddings and fail inside the forward pass.
    inputs = tokenizer.encode_plus(text, return_tensors='pt',
                                   add_special_tokens=True, truncation=True)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    with torch.no_grad():
        outputs = model(**inputs)
        attention = outputs.attentions[layer][0, head].numpy()  # [query, key]

    plt.figure(figsize=(12, 10))
    sns.heatmap(attention,
                xticklabels=tokens,
                yticklabels=tokens,
                cmap='Blues',
                annot=True,
                fmt='.2f',
                square=True)
    plt.title(f'BERT Attention Heatmap - Layer {layer}, Head {head}')
    plt.xlabel('Key Tokens')
    plt.ylabel('Query Tokens')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

def create_interactive_attention_plot(text, tokenizer, model, layer=11, head=0):
    """Show one attention head as an interactive Plotly heatmap.

    Args:
        text: Input text; special tokens are added before encoding.
        tokenizer: Tokenizer providing ``encode_plus`` and
            ``convert_ids_to_tokens``.
        model: Model whose outputs include ``attentions`` (one tensor per layer).
        layer: Index into ``outputs.attentions``.
        head: Attention head index within that layer.

    Cells are placed at integer positions and the axes are labelled with the
    tokens through tick text. Using the raw tokens as categorical ``x``/``y``
    values would merge repeated tokens (e.g. the two "the" tokens in the demo
    sentence) into a single category, drawing their rows/columns on top of
    each other.
    """
    # truncation=True keeps arbitrary input within the model's maximum length.
    inputs = tokenizer.encode_plus(text, return_tensors='pt',
                                   add_special_tokens=True, truncation=True)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    with torch.no_grad():
        outputs = model(**inputs)
        attention = outputs.attentions[layer][0, head].numpy()  # [query, key]

    positions = list(range(len(tokens)))
    # The axes are numeric, so each cell's hover names its actual token pair.
    hover_text = [
        [f"Query: {query_tok}<br>Key: {key_tok}<br>Attention: {attention[q, k]:.3f}"
         for k, key_tok in enumerate(tokens)]
        for q, query_tok in enumerate(tokens)
    ]

    fig = go.Figure(data=go.Heatmap(
        z=attention,
        x=positions,
        y=positions,
        colorscale='Blues',
        text=attention,
        texttemplate="%{text:.2f}",
        textfont={"size": 8},
        customdata=hover_text,
        hovertemplate="%{customdata}<extra></extra>",
        hoverongaps=False
    ))

    token_ticks = dict(tickmode='array', tickvals=positions, ticktext=tokens)
    fig.update_layout(
        title=f'Interactive BERT Attention - Layer {layer}, Head {head}',
        xaxis=dict(title='Key Tokens', **token_ticks),
        # Reversed so the first query token sits at the top, matching the
        # orientation of the seaborn heatmap.
        yaxis=dict(title='Query Tokens', autorange='reversed', **token_ticks),
        width=800,
        height=700
    )

    fig.show()

def visualize_multi_head_attention(text, tokenizer, model, layer=11, ncols=4):
    """Plot every attention head of one layer as a grid of heatmaps.

    Args:
        text: Input text; special tokens are added before encoding.
        tokenizer: Tokenizer providing ``encode_plus`` and
            ``convert_ids_to_tokens``.
        model: Model whose outputs include ``attentions`` (one tensor per layer).
        layer: Index into ``outputs.attentions``.
        ncols: Number of subplot columns. Rows are added as needed so that
            every head is shown, whatever the model's head count.
    """
    # truncation=True keeps arbitrary input within the model's maximum length.
    inputs = tokenizer.encode_plus(text, return_tensors='pt',
                                   add_special_tokens=True, truncation=True)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    with torch.no_grad():
        outputs = model(**inputs)
        layer_attention = outputs.attentions[layer][0]  # [num_heads, seq_len, seq_len]

    num_heads = layer_attention.shape[0]
    ncols = max(1, min(ncols, num_heads))
    nrows = -(-num_heads // ncols)  # ceiling division
    # squeeze=False keeps `axes` 2-D even for one row/column, so flatten()
    # always yields an array of Axes.
    _, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    axes = axes.flatten()

    tick_positions = range(len(tokens))
    for head in range(num_heads):
        ax = axes[head]
        ax.imshow(layer_attention[head].numpy(), cmap='Blues')
        ax.set_title(f'Head {head}')
        ax.set_xticks(tick_positions)
        ax.set_yticks(tick_positions)
        ax.set_xticklabels(tokens, rotation=45, ha='right', fontsize=8)
        ax.set_yticklabels(tokens, fontsize=8)

    # Hide grid cells left over when num_heads is not a multiple of ncols.
    for ax in axes[num_heads:]:
        ax.axis('off')

    plt.tight_layout()
    plt.suptitle(f'All Attention Heads - Layer {layer}', y=1.02, fontsize=16)
    plt.show()

# =============================================================================
# 3. ATTENTION FLOW VISUALIZATION
# =============================================================================

def visualize_attention_flow(text, tokenizer, model, target_token_idx=1):
    """Plot, for every layer, how much each token attends to one target token.

    For each layer the attention weights are averaged over heads and the key
    column of ``target_token_idx`` is taken. Point *i* of a layer's trace is
    therefore the head-averaged attention that query token *i* pays to the
    target.

    Args:
        text: Input text; special tokens are added before encoding.
        tokenizer: Tokenizer providing ``encode_plus`` and
            ``convert_ids_to_tokens``.
        model: Model whose outputs include ``attentions`` (one tensor per layer).
        target_token_idx: Position of the target token in the encoded sequence.
            Special tokens count as positions, and negative indices count from
            the end.

    Raises:
        IndexError: If ``target_token_idx`` is outside the encoded sequence.
    """
    # truncation=True keeps arbitrary input within the model's maximum length.
    inputs = tokenizer.encode_plus(text, return_tensors='pt',
                                   add_special_tokens=True, truncation=True)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    seq_len = len(tokens)

    # Fail fast with a readable message instead of after the forward pass.
    if not -seq_len <= target_token_idx < seq_len:
        raise IndexError(
            f"target_token_idx={target_token_idx} is out of range for "
            f"{seq_len} tokens: {tokens}"
        )

    with torch.no_grad():
        attentions = model(**inputs).attentions

    # Tokens are drawn at integer x positions with token tick labels: as
    # categorical x values, repeated tokens would share one category and the
    # lines would jump back to the first occurrence.
    positions = list(range(seq_len))
    fig = go.Figure()
    for layer_idx, layer_attention in enumerate(attentions):
        # layer_attention[0] is [heads, query, key]: average the heads, then
        # keep the target's key column.
        scores = layer_attention[0].mean(dim=0)[:, target_token_idx].numpy()
        fig.add_trace(go.Scatter(
            x=positions,
            y=scores,
            customdata=tokens,
            mode='lines+markers',
            name=f'Layer {layer_idx}',
            line=dict(width=2),
            marker=dict(size=6),
            hovertemplate='%{customdata}: %{y:.3f}'
        ))

    fig.update_layout(
        title=f'Attention Flow to Token: "{tokens[target_token_idx]}"',
        xaxis=dict(title='Tokens', tickmode='array',
                   tickvals=positions, ticktext=tokens),
        yaxis_title='Attention Score',
        width=900,
        height=600,
        hovermode='x unified'
    )

    fig.show()

# =============================================================================
# 4. EMBEDDING ANALYSIS
# =============================================================================

def get_bert_embeddings(texts, tokenizer, model, pooling='cls'):
    """Encode each text with BERT and pool it into one fixed-size vector.

    Args:
        texts: Iterable of strings to embed. Each is encoded separately.
        tokenizer: Tokenizer providing ``encode_plus``.
        model: Model whose output has ``last_hidden_state``.
        pooling: ``'cls'`` takes the final hidden state of the first ([CLS])
            token. ``'mean'`` averages the final hidden states of all tokens
            whose attention mask is 1. Any other value falls back to ``'cls'``.

    Returns:
        ``np.ndarray`` with one row per text, of shape
        ``(len(texts), hidden_size)``; an empty array if ``texts`` is empty.
    """
    # If the model exposes a device (e.g. it was moved to a GPU), send the
    # inputs there too. Results are copied back to CPU before numpy conversion.
    device = getattr(model, 'device', None)
    embeddings = []

    for text in texts:
        # Texts are encoded one at a time, so no padding is needed; padding to
        # max_length would make every forward pass process 512 positions
        # regardless of text length. Truncation keeps long texts within
        # BERT's 512-position limit.
        inputs = tokenizer.encode_plus(text,
                                       return_tensors='pt',
                                       add_special_tokens=True,
                                       max_length=512,
                                       truncation=True)
        if device is not None:
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            hidden_states = outputs.last_hidden_state[0]  # [seq_len, hidden_size]

            if pooling == 'mean':
                # Weight each position by its attention mask so any masked
                # positions are excluded from the average.
                mask = inputs['attention_mask'][0].unsqueeze(-1).to(hidden_states.dtype)
                embedding = (hidden_states * mask).sum(dim=0) / mask.sum()
            else:
                embedding = hidden_states[0]  # [CLS] token ('cls' and default)

        embeddings.append(embedding.cpu().numpy())

    return np.array(embeddings)

def visualize_embeddings_2d(embeddings, labels=None, texts=None, method='umap'):
    """Project embeddings to 2-D and draw them as an interactive scatter plot.

    Args:
        embeddings: Array-like of shape ``(n_samples, n_features)``, e.g. the
            output of ``get_bert_embeddings``.
        labels: Optional per-sample labels. Each distinct label becomes one
            coloured trace, in order of first appearance.
        texts: Optional per-sample strings shown on hover; ``"Point <i>"`` is
            shown instead when omitted or empty.
        method: ``'umap'`` or ``'tsne'``; any other value uses PCA.
    """
    embeddings = np.asarray(embeddings)
    n_samples = len(embeddings)

    if method == 'umap':
        reducer = umap.UMAP(n_components=2, random_state=42)
        title = "BERT Embeddings - UMAP"
    elif method == 'tsne':
        # t-SNE requires perplexity to be smaller than the number of samples.
        reducer = TSNE(n_components=2, random_state=42, perplexity=min(30, n_samples - 1))
        title = "BERT Embeddings - t-SNE"
    else:  # PCA
        reducer = PCA(n_components=2, random_state=42)
        title = "BERT Embeddings - PCA"

    embedding_2d = reducer.fit_transform(embeddings)

    # Explicit None/length check rather than truthiness, so numpy arrays or
    # pandas Series of texts don't raise "truth value is ambiguous".
    has_texts = texts is not None and len(texts) > 0

    def point_text(idx):
        return texts[idx] if has_texts else f'Point {idx}'

    fig = go.Figure()

    if labels is not None:
        labels = list(labels)
        # dict.fromkeys preserves first-appearance order. set() order depends
        # on string hash randomisation, so trace order and colours would
        # change between kernel restarts.
        unique_labels = list(dict.fromkeys(labels))
        palette = px.colors.qualitative.Set1

        for i, label in enumerate(unique_labels):
            member_idx = [j for j, lab in enumerate(labels) if lab == label]
            hover_text = [f"Label: {label}<br>Text: {point_text(j)}" for j in member_idx]

            fig.add_trace(go.Scatter(
                x=embedding_2d[member_idx, 0],
                y=embedding_2d[member_idx, 1],
                mode='markers',
                name=str(label),
                # Cycle the palette: Set1 has only 9 colours, and indexing
                # past its end would raise IndexError for more labels.
                marker=dict(size=10, color=palette[i % len(palette)]),
                text=hover_text,
                hovertemplate='%{text}<extra></extra>'
            ))
    else:
        hover_text = [f"Text: {point_text(i)}" for i in range(n_samples)]
        fig.add_trace(go.Scatter(
            x=embedding_2d[:, 0],
            y=embedding_2d[:, 1],
            mode='markers',
            marker=dict(size=10, color='blue'),
            text=hover_text,
            hovertemplate='%{text}<extra></extra>'
        ))

    fig.update_layout(
        title=title,
        xaxis_title=f'{method.upper()} 1',
        yaxis_title=f'{method.upper()} 2',
        width=800,
        height=600
    )

    fig.show()

# =============================================================================
# 5. LAYER-WISE ANALYSIS
# =============================================================================

def analyze_layer_representations(text, tokenizer, model):
    """Plot how the [CLS] representation changes from one hidden state to the next.

    The position-0 ([CLS]) vector is taken from every entry of
    ``outputs.hidden_states``. For Hugging Face BERT models these entries are
    the embedding output followed by each encoder layer's output. The cosine
    similarity between consecutive entries is then plotted: transition *i*
    compares hidden state *i-1* with hidden state *i*.

    Args:
        text: Input text; special tokens are added before encoding.
        tokenizer: Tokenizer providing ``encode_plus``.
        model: Model that accepts ``output_hidden_states=True``.

    Returns:
        ``np.ndarray`` of shape ``(num_hidden_states, hidden_size)`` with the
        [CLS] vector taken from each hidden state.
    """
    # truncation=True keeps arbitrary input (e.g. from the widget) within the
    # model's maximum length instead of failing in the forward pass.
    inputs = tokenizer.encode_plus(text, return_tensors='pt',
                                   add_special_tokens=True, truncation=True)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states

    # [CLS] is position 0 of the single batch item in every hidden state.
    cls_representations = np.stack(
        [layer_output[0, 0, :].numpy() for layer_output in hidden_states]
    )

    # Row-wise cosine similarity between each [CLS] vector and the next one.
    current, following = cls_representations[:-1], cls_representations[1:]
    similarities = (current * following).sum(axis=1) / (
        np.linalg.norm(current, axis=1) * np.linalg.norm(following, axis=1)
    )

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=list(range(1, len(similarities) + 1)),
        y=similarities,
        mode='lines+markers',
        line=dict(width=3, color='blue'),
        marker=dict(size=8, color='red')
    ))

    fig.update_layout(
        title='Cosine Similarity Between Consecutive BERT Layers',
        xaxis_title='Layer Transition',
        yaxis_title='Cosine Similarity',
        width=800,
        height=500,
        showlegend=False
    )

    fig.show()
    return cls_representations

# =============================================================================
# 6. DEMO FUNCTION
# =============================================================================

def run_bert_visualization_demo():
    """Run the notebook's visualizations on built-in sample inputs.

    Loads the default model via ``initialize_bert_model`` and then runs, in
    order:

    * the attention heatmap, interactive attention plot, multi-head grid and
      layer analysis, all for one fixed sentence;
    * a UMAP projection of eight short texts labelled 'Tech' or 'Literature'.

    A progress line is printed before each step.
    """
    print("=== BERT Visualization Demo (Without BertViz) ===\n")

    tokenizer, model = initialize_bert_model()

    sample_text = "The quick brown fox jumps over the lazy dog."
    print(f"Analyzing text: '{sample_text}'\n")

    # (progress message, step) pairs for the single-sentence visualizations.
    sentence_steps = (
        ("1. Creating attention heatmap...",
         lambda: create_attention_heatmap(sample_text, tokenizer, model, layer=11, head=0)),
        ("2. Creating interactive attention plot...",
         lambda: create_interactive_attention_plot(sample_text, tokenizer, model, layer=11, head=0)),
        ("3. Visualizing all attention heads...",
         lambda: visualize_multi_head_attention(sample_text, tokenizer, model, layer=11)),
        ("4. Analyzing layer representations...",
         lambda: analyze_layer_representations(sample_text, tokenizer, model)),
    )
    for message, step in sentence_steps:
        print(message)
        step()

    print("5. Visualizing embeddings...")
    # Each sample text paired with the category used to colour it.
    labelled_samples = [
        ("I love machine learning and AI", 'Tech'),
        ("Machine learning is fascinating to study", 'Tech'),
        ("Deep learning models are powerful", 'Tech'),
        ("Natural language processing is amazing", 'Tech'),
        ("NLP applications are everywhere", 'Tech'),
        ("I enjoy reading books", 'Literature'),
        ("Literature and poetry are beautiful", 'Literature'),
        ("Writing stories is creative", 'Literature'),
    ]
    sample_texts = [sample for sample, _ in labelled_samples]
    labels = [category for _, category in labelled_samples]

    embeddings = get_bert_embeddings(sample_texts, tokenizer, model)
    visualize_embeddings_2d(embeddings, labels, sample_texts, method='umap')

    print("\n=== Demo Complete! ===")

# =============================================================================
# 7. SIMPLIFIED WIDGET INTERFACE
# =============================================================================

def create_simple_bert_explorer(model_name='bert-base-uncased'):
    """Display an ipywidgets panel for exploring BERT attention interactively.

    Args:
        model_name: Model passed to ``initialize_bert_model``. The layer and
            head sliders are sized from this model's config.
    """
    import ipywidgets as widgets
    from IPython.display import display, clear_output

    # Load the model once; every button click reuses it.
    tokenizer, model = initialize_bert_model(model_name)

    # Slider ranges come from the config rather than a hard-coded 11, so
    # models of a different size (e.g. bert-large: 24 layers, 16 heads)
    # expose all of their layers and heads.
    last_layer = model.config.num_hidden_layers - 1
    last_head = model.config.num_attention_heads - 1

    # Widgets
    text_input = widgets.Textarea(
        value="The quick brown fox jumps over the lazy dog.",
        placeholder="Enter text to analyze...",
        description="Text:",
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='70%', height='80px')
    )

    layer_slider = widgets.IntSlider(value=last_layer, min=0, max=last_layer, description="Layer:")
    head_slider = widgets.IntSlider(value=0, min=0, max=last_head, description="Head:")

    # Dropdown label -> callable(text, layer, head) that draws that view.
    visualizations = {
        'Attention Heatmap':
            lambda text, layer, head: create_attention_heatmap(text, tokenizer, model, layer, head),
        'Interactive Attention':
            lambda text, layer, head: create_interactive_attention_plot(text, tokenizer, model, layer, head),
        'Multi-Head View':
            lambda text, layer, head: visualize_multi_head_attention(text, tokenizer, model, layer),
        'Layer Analysis':
            lambda text, layer, head: analyze_layer_representations(text, tokenizer, model),
    }

    viz_dropdown = widgets.Dropdown(
        options=list(visualizations),
        value='Attention Heatmap',
        description="Visualization:"
    )

    button = widgets.Button(description="Generate Visualization", button_style='success')
    output = widgets.Output()

    def on_button_click(b):
        with output:
            clear_output(wait=True)
            render = visualizations.get(viz_dropdown.value)
            if render is None:
                return
            try:
                render(text_input.value, layer_slider.value, head_slider.value)
            except Exception as e:
                # UI boundary: show the error in the output panel rather than
                # letting the widget callback fail silently.
                print(f"Error: {e}")

    button.on_click(on_button_click)

    display(text_input, layer_slider, head_slider, viz_dropdown, button, output)

if __name__ == "__main__":
    # Banner listing the notebook's two entry points.
    banner_lines = (
        "BERT Visualization Setup Complete!",
        "Available functions:",
        "- run_bert_visualization_demo()",
        "- create_simple_bert_explorer()",
        "\nNote: This version works without bertviz!",
    )
    for line in banner_lines:
        print(line)

In [None]:
create_simple_bert_explorer()

In [None]:
!jupyter nbextension enable --py widgetsnbextension

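The 12 Attention Head "Specialists" (an Illustrative Analogy)

Caveat: the roles below are a teaching metaphor, not a measured map of bert-base-uncased. A head's number carries no built-in meaning. The specialised behaviours researchers have actually observed belong to specific layer–head pairs, not to a head number across all layers (see, e.g., Clark et al., 2019, "What Does BERT Look At?"). Heads are numbered 1–12 here, whereas the explorer's Head slider is 0-indexed (0–11).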
Early Heads (1-3): "The Basics Crew"

Head 1: "Word Boundaries" - Focuses on separating words and subwords
Head 2: "Local Grammar" - Looks at immediate word relationships (article → noun)
Head 3: "Punctuation Tracker" - Pays attention to commas, periods, special tokens

Middle Heads (4-6): "The Syntax Squad"

Head 4: "Subject-Verb Matcher" - Connects subjects with their verbs
Head 5: "Modifier Linker" - Links adjectives to nouns, adverbs to verbs
Head 6: "Dependency Parser" - Tracks grammatical dependencies

Deep Heads (7-9): "The Meaning Makers"

Head 7: "Coreference Resolver" - Connects pronouns to what they refer to
Head 8: "Semantic Similarity" - Groups words with similar meanings
Head 9: "Long-Distance Relations" - Handles complex sentence structures

Final Heads (10-12): "The Big Picture Team"

Head 10: "Entity Tracker" - Focuses on named entities and important concepts
Head 11: "Sentence Integrator" - Combines information for overall meaning
Head 12: "Context Collector" - Gathers global context for the [CLS] token

Important Notes:

Layer Matters: These patterns vary by layer! A head's role is not tied to its number — Head 1 in Layer 3 can do something entirely different from Head 1 in Layer 11.
Context Dependent: The same head might focus on different things in different sentences
Collaborative: Real understanding comes from all heads working together across all 12 layers
Emergent Behavior: BERT wasn't programmed for these roles - they emerged naturally from training

For Teaching:
"Think of BERT as having 144 different reading specialists (12 heads × 12 layers) that each developed their own expertise just from reading lots of text. Some became grammar experts, others meaning detectives, others relationship trackers - all working together to understand language!"
Try This in Your Visualization:

Early layers (0-3): Look for basic word-level patterns
Middle layers (4-8): Watch for syntactic relationships
Late layers (9-11): Observe semantic and contextual integration

Use your interactive explorer to see these specialists in action!