<a href="https://colab.research.google.com/github/patemotter/demystifying-ai/blob/main/notebooks/session_4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<div align="center">

#  
  
# Demystifying AI - Session 4
## Transformers and Attention


### Pate Motter, PhD  

AI Performance Engineer @ Google

[LinkedIn](https://www.linkedin.com/in/patemotter/) | [GitHub](https://github.com/patemotter)

---

</div>


# About This Notebook

In Session 3, we saw how RNNs process sequences step-by-step, passing a hidden state forward like a baton in a relay race. This sequential nature, while intuitive, has drawbacks:

* **Long-Range Dependencies:** Information from early parts of a long sequence can get "diluted" or lost by the time the RNN reaches the end (the "vanishing gradient" problem in training). Remembering the subject of a paragraph from the first sentence when generating the last sentence is hard.
* **Parallelization Limits:** The sequential calculation (output at step `t` depends on step `t-1`) makes it difficult to fully utilize parallel processors (like GPUs/TPUs) during training, as each step must wait for the previous one.






In [None]:
# @title Setup and Imports
# Run this cell first to install and import necessary libraries.
# Install plotly for interactive visualizations if needed
# !pip install plotly numpy torch matplotlib

from google.colab import output
output.enable_custom_widget_manager()

import numpy as np
import torch
import torch.nn as nn
import plotly.graph_objects as go
import plotly.subplots as sp
import matplotlib.pyplot as plt
from IPython.display import display, HTML
import math # For positional encoding visualization
import re # For parsing color strings
from transformers import AutoTokenizer, AutoModel

# Consistent color scheme (similar to Session 3)
COLOR_INPUT = 'rgba(99, 110, 250, 0.7)'     # Blue
COLOR_HIDDEN = 'rgba(239, 85, 59, 0.7)'     # Red (used for Feed-Forward)
COLOR_OUTPUT = 'rgba(0, 204, 150, 0.7)'     # Green
COLOR_EDGE = 'rgba(210, 210, 210, 0.8)'     # Light Gray
COLOR_ATTENTION = 'rgba(255, 127, 14, 0.8)' # Orange (for Attention layers)
COLOR_POS_ENC = 'rgba(148, 103, 189, 0.7)'  # Purple (for Positional Encoding)
COLOR_ADDNORM = 'rgba(140, 86, 75, 0.6)'    # Brown (for Add & Norm)
BACKGROUND = 'rgba(248, 248, 248, 0.95)'    # Light background


# Choose a pre-trained model (here, the base uncased BERT model)
model_name = "bert-base-uncased"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the PyTorch model with output_attentions=True
model = AutoModel.from_pretrained(model_name, output_attentions=True)

# Example input sequence
input_sequence_text = "The cat sat on the mat."

# Tokenize the input sequence
inputs = tokenizer(input_sequence_text, return_tensors="pt")
input_ids = inputs["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

# Run the model and get the attention outputs
with torch.no_grad():
    outputs = model(**inputs)

# Extract the attention weights from the last layer (you can choose other layers)
# The 'attentions' output is a tuple of attention maps, one for each layer
attention_weights = outputs.attentions[-1]  # Get the last layer's attention

## The Transformer Solution: Parallel Processing with Attention

The Transformer architecture proposed a radical shift: **process all elements of the sequence simultaneously**.

**The Core Idea:** Instead of a sequential chain, the Transformer allows every element in the sequence to directly interact with every other element. This is achieved through the **Self-Attention** mechanism:
- Look at ALL words in the sequence simultaneously
- For each word, determine which OTHER words are most relevant to understanding it
- Focus "attention" on those relevant words regardless of distance

Because every token can interact with every other token directly, and all positions are computed at once, Transformers capture long-range dependencies better and train much faster on parallel hardware.

This addresses both problems described above:
1. Short paths between distant words - word #50 can attend directly to word #1 instead of relying on a signal passed through 49 hidden states
2. Parallel processing - the entire sequence is processed simultaneously
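
The contrast between the two processing styles can be sketched in a few lines of NumPy. This is only an illustration, not the model used elsewhere in this notebook: the sizes, the random inputs, and the weight matrices `W_x` and `W_h` are assumptions made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d = 6, 4                      # hypothetical sizes for illustration
X = rng.normal(size=(seq_len, d))      # one row per token (random stand-in embeddings)

# RNN-style: each hidden state depends on the previous one, so the loop
# over time steps cannot be run in parallel.
W_x = rng.normal(size=(d, d)) * 0.5    # assumed input weights
W_h = rng.normal(size=(d, d)) * 0.5    # assumed recurrent weights
h = np.zeros(d)
rnn_states = []
for x_t in X:
    h = np.tanh(x_t @ W_x + h @ W_h)
    rnn_states.append(h)
rnn_states = np.stack(rnn_states)

# Attention-style: every token is compared with every other token in a
# single matrix product, with no dependence between positions.
pairwise_scores = X @ X.T

print(rnn_states.shape)       # (6, 4)
print(pairwise_scores.shape)  # (6, 6)
```

The loop has to run six times in order, while the pairwise comparison is one matrix multiplication that parallel hardware can compute all at once.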


In [None]:
# @title RNN vs Transformer Flow
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

def create_rnn_vs_transformer_viz():
    """
    Create a visualization comparing RNN and Transformer approaches.
    """
    # Create a sample sentence
    sentence = "The quick brown fox jumps over the lazy dog."
    words = sentence.split()
    n_words = len(words)

    # Create the figure with two subplots
    fig = make_subplots(
        rows=2,
        cols=1,
        subplot_titles=("RNN: Sequential Processing", "Transformer: Parallel Processing with Attention"),
        vertical_spacing=0.2,
        row_heights=[0.4, 0.6]
    )

    # --- RNN Visualization ---

    # Node positions for RNN
    x_positions = np.arange(n_words)
    y_rnn = np.zeros(n_words)

    # Hidden state positions (slightly above the word nodes)
    x_hidden = np.arange(n_words)
    y_hidden = np.ones(n_words) * 0.6

    # Add word nodes
    for i, word in enumerate(words):
        # Add the word node
        fig.add_trace(
            go.Scatter(
                x=[x_positions[i]],
                y=[y_rnn[i]],
                mode='markers+text',
                marker=dict(size=30, color='rgba(0, 0, 255, 0.5)'),
                text=[word],
                textposition="bottom center",
                name="Words" if i == 0 else None,  # Show only once in legend
                legendgroup="words",
                hoverinfo='skip',
                showlegend=(i == 0)
            ),
            row=1, col=1
        )

        # Add hidden state node
        if i > 0:  # No hidden state before the first word
            fig.add_trace(
                go.Scatter(
                    x=[x_hidden[i-1]],
                    y=[y_hidden[i-1]],
                    mode='markers',
                    marker=dict(size=20, color='rgba(255, 0, 0, 0.5)'),
                    name="Hidden States" if i == 1 else None,
                    legendgroup="hidden",
                    hoverinfo='text',
                    hovertext=f"Hidden state after word: {words[i-1]}",
                    showlegend=(i == 1)
                ),
                row=1, col=1
            )

            # Add arrows from previous hidden state to current word
            fig.add_trace(
                go.Scatter(
                    x=[x_hidden[i-1], x_positions[i]],
                    y=[y_hidden[i-1], y_rnn[i]],
                    mode='lines',
                    line=dict(color='rgba(100, 100, 100, 0.8)', width=2),
                    showlegend=False
                ),
                row=1, col=1
            )

        # Add arrows from current word to current hidden state
        if i < n_words - 1:  # No hidden state after the last word (for simplicity)
            fig.add_trace(
                go.Scatter(
                    x=[x_positions[i], x_hidden[i]],
                    y=[y_rnn[i], y_hidden[i]],
                    mode='lines',
                    line=dict(color='rgba(100, 100, 100, 0.8)', width=2),
                    showlegend=False
                ),
                row=1, col=1
            )

            # Add arrows from current hidden state to next hidden state
            if i < n_words - 2:
                fig.add_trace(
                    go.Scatter(
                        x=[x_hidden[i], x_hidden[i+1]],
                        y=[y_hidden[i], y_hidden[i+1]],
                        mode='lines',
                        line=dict(color='rgba(255, 0, 0, 0.8)', width=2),
                        showlegend=False
                    ),
                    row=1, col=1
                )

    # --- Transformer Visualization ---

    # Node positions for Transformer visualization
    x_trans = np.arange(n_words)
    y_trans = np.zeros(n_words)

    # Add word nodes
    for i, word in enumerate(words):
        fig.add_trace(
            go.Scatter(
                x=[x_trans[i]],
                y=[y_trans[i]],
                mode='markers+text',
                marker=dict(size=30, color='rgba(0, 0, 255, 0.5)'),
                text=[word],
                textposition="bottom center",
                name=word if i == 0 else None,
                legendgroup="words_trans",
                hoverinfo='skip',
                showlegend=False
            ),
            row=2, col=1
        )

    # Add attention connections
    # Hand-picked example attention patterns (illustrative, not model output):
    # 1. Verb to subject ("jumps" → "fox")
    # 2. Adjective to noun ("quick" → "fox", "brown" → "fox", "lazy" → "dog.")
    # 3. Preposition to verb ("over" → "jumps")
    # 4. Article to noun ("the" → "dog.")

    attention_pairs = [
        # Format: (from_word, to_word, strength, color)
        ("jumps", "fox", 0.8, 'rgba(255, 0, 0, 0.7)'),
        ("quick", "fox", 0.6, 'rgba(0, 255, 0, 0.7)'),
        ("brown", "fox", 0.7, 'rgba(0, 255, 0, 0.7)'),
        ("lazy", "dog.", 0.7, 'rgba(0, 255, 0, 0.7)'),
        ("over", "jumps", 0.5, 'rgba(255, 165, 0, 0.7)'),
        ("the", "dog.", 0.6, 'rgba(128, 0, 128, 0.7)')
    ]

    added_legend_items = set()

    for from_word, to_word, strength, color in attention_pairs:
        from_idx = words.index(from_word)
        to_idx = words.index(to_word)

        # Calculate control points for curved lines
        x0, y0 = x_trans[from_idx], y_trans[from_idx]
        x1, y1 = x_trans[to_idx], y_trans[to_idx]

        # Higher arc for longer distances
        arc_height = 0.5 + 0.1 * abs(from_idx - to_idx)

        # Determine control points for quadratic Bezier curve
        if from_idx < to_idx:
            xcp = (x0 + x1) / 2
            ycp = arc_height
        else:
            xcp = (x0 + x1) / 2
            ycp = -arc_height

        # Create points for the curve
        t = np.linspace(0, 1, 20)
        x_curve = (1-t)**2 * x0 + 2*(1-t)*t * xcp + t**2 * x1
        y_curve = (1-t)**2 * y0 + 2*(1-t)*t * ycp + t**2 * y1

        # Determine legend name based on relationship type
        legend_name = None
        relationship_type = ""

        if from_word == "jumps" and to_word == "fox":
            relationship_type = "verb_subject"
            legend_name = "Verb-Subject Attention" if "verb_subject" not in added_legend_items else None
        elif from_word in ["quick", "brown"] and to_word == "fox":
            relationship_type = "adj_noun"
            legend_name = "Adjective-Noun Attention" if "adj_noun" not in added_legend_items else None
        elif from_word == "lazy" and to_word == "dog.":
            relationship_type = "adj_noun"
            legend_name = "Adjective-Noun Attention" if "adj_noun" not in added_legend_items else None
        elif from_word == "the" and to_word == "dog.":
            relationship_type = "article_noun"
            legend_name = "Article-Noun Attention" if "article_noun" not in added_legend_items else None
        elif from_word == "over" and to_word == "jumps":
            relationship_type = "preposition_verb"
            legend_name = "Preposition-Verb Attention" if "preposition_verb" not in added_legend_items else None

        if legend_name:
            added_legend_items.add(relationship_type)

        # Add the curve
        fig.add_trace(
            go.Scatter(
                x=x_curve,
                y=y_curve,
                mode='lines',
                line=dict(color=color, width=3 * strength),
                name=legend_name,
                showlegend=legend_name is not None,
                hoverinfo='text',
                hovertext=f"Attention: {from_word} → {to_word} ({strength:.1f})"
            ),
            row=2, col=1
        )

    # Update layout
    fig.update_layout(
        title="RNN vs Transformer: Different Ways to Process Sequences",
        width=1000,
        height=800,
        legend=dict(
            x=1.1,
            y=0.9,
            traceorder="grouped"
        )
    )

    # Update axes
    fig.update_xaxes(showticklabels=False, row=1, col=1)
    fig.update_yaxes(showticklabels=False, row=1, col=1)
    fig.update_xaxes(showticklabels=False, row=2, col=1)
    fig.update_yaxes(showticklabels=False, row=2, col=1)

    # Add annotations
    fig.add_annotation(
        x=4, y=1.1,
        text="RNNs process words sequentially, with information flowing through a chain of hidden states",
        showarrow=False,
        row=1, col=1
    )

    fig.add_annotation(
        x=4, y=1.1,
        text="Transformers process all words at once, with attention connecting related words directly",
        showarrow=False,
        row=2, col=1
    )

    return fig

# Create and display the visualization
fig = create_rnn_vs_transformer_viz()
fig.show()

## Analogy of RNN and Transformer

### RNN - The Lone Detective

* Imagine a detective receiving reports and evidence one piece at a time over several weeks. They read each new report, trying to remember previous details and keep a running mental summary.

* When investigating a clue found in Week 10, their understanding heavily relies on the reports from Week 9. To connect it back to a specific detail mentioned in a Week 1 report, they'd have to manually sift back through the entire stack of files in order, hoping they remember or spot the relevant connection. Information is processed sequentially, and early context can easily get lost or buried.

### Transformer - The Detective Team

* Now, picture a detective team standing in front of a giant evidence board where all the case files, photos, witness statements, and clues are displayed at once.

* To understand the significance of one specific clue (say, a peculiar muddy footprint), the lead detective points to it and asks the team, "What's the story with this specific footprint?"

* Simultaneously, different team members quickly scan the entire board and all files, but each focuses on finding connections specifically relevant to that footprint:
  * One checks suspect shoe sizes and types.
  * Another checks soil analysis reports from various locations.
  * Another scans witness statements mentioning muddy shoes or relevant locations.
  * Another checks security logs for relevant times.

* They ignore unrelated information. They quickly pull out only the pieces of evidence directly relevant to that footprint (e.g., "This mud matches the park," "Suspect X owns these shoes," "Witness saw muddy shoes near the park entrance").

* The lead detective instantly sees all these targeted, relevant connections drawn from across the entire case and synthesizes them to grasp the footprint's importance. This focused, parallel search happens for every clue they analyze.

---
# Core Concept #1: Positional Encoding

## 1.1 The Challenge of Language

Human language is inherently sequential - we communicate through ordered words, and
changing the order changes the meaning.

**Subject-Object Reversal:**
- "Dog bites man" vs. "Man bites dog"
- "The child gave the parent a gift" vs. "The parent gave the child a gift"

**Modifier Scope:**
- "Only I saw the movie yesterday" vs. "I only saw the movie yesterday"
  ("no one else saw it" vs "I did nothing else but see it")

**Negation Position:**
- "Everyone is not invited" vs. "Not everyone is invited"
  ("no one is invited" vs "some people are invited")

**Question vs. Statement:**
- "Is the train arriving?" vs. "The train is arriving"

**Different Meanings with Same Words/Structure:**
- "Turn down the offer" (reject it) vs. "Turn down the street" (change direction) vs. "Turn down the volume" (lower the sound)
- "Stand by your friend" (support them) vs. "Stand by your friend" (physically wait near them)

AI needs to process these sequences in a way that:
1. Understands the meaning of individual words
2. Captures how words relate to each other
3. Preserves important information over long distances
4. Recognizes the significance of word order

Since Transformers process all words in parallel (rather than sequentially), they need a way
to understand word order. Without this, similar but fundamentally different sentences would look nearly identical.

## 1.2 The Solution: Positional Encoding

Since the Transformer processes all input tokens (e.g., words or sub-words) in parallel without inherent recurrence, it loses the natural sense of order provided by RNNs. "The cat chased the dog" and "The dog chased the cat" would look identical to the core self-attention mechanism without additional information.

To solve this, Transformers inject information about the *position* of each token in the sequence. This is done using **Positional Encodings**.
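
The order-blindness described above can be checked with a small NumPy sketch. Everything here is an assumption for illustration: the embeddings and projection matrices are random rather than learned, and `self_attention` is a hypothetical single-head helper, not a function from this notebook.

```python
import numpy as np

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention with no positional information."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
d = 8
vocab = ["the", "cat", "chased", "dog"]
embed = {w: rng.normal(size=d) for w in vocab}   # random stand-in embeddings
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))

sent_a = ["the", "cat", "chased", "the", "dog"]
sent_b = ["the", "dog", "chased", "the", "cat"]
out_a = self_attention(np.stack([embed[w] for w in sent_a]), W_q, W_k, W_v)
out_b = self_attention(np.stack([embed[w] for w in sent_b]), W_q, W_k, W_v)

# "chased" sits at index 2 in both sentences and gets the same output vector,
# even though subject and object have swapped.
same_without_positions = np.allclose(out_a[2], out_b[2])
print(same_without_positions)
```

Each word attends over the same collection of words in both sentences, so its output does not change when the order does. Adding a position-dependent vector to each embedding breaks this symmetry.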

### How it Works
1.  Each token in the input sequence is first converted into a vector (an **embedding**), similar to what we saw in RNNs/MLPs. This vector represents the token's meaning.
2.  A second vector, the **positional encoding**, which depends only on the token's position in the sequence (e.g., 1st word, 2nd word, etc.), is generated.
3.  This positional encoding vector is **added** to the token's embedding vector.

The result is an enriched embedding that contains information about both the token's meaning *and* its position.

### Why it Works
Positional encodings use a clever mathematical pattern (based on sine and cosine functions)
with important properties:

1. **Each position gets a unique pattern**: Position #1 is distinct from position #2, #3, etc.
2. **Similar positions have similar patterns**: Position #5 and #6 have more similar patterns
   than position #5 and #100
3. **Defined for any length**: The formula can generate encodings for positions beyond the
   longest sequence seen in training (though model quality at those lengths is not guaranteed)
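
These properties can be checked numerically. The sketch below assumes the standard sinusoidal formula (sine on even dimensions, cosine on odd dimensions, base 10000); `sinusoidal_encoding` and `cosine_similarity` are helper names introduced here for illustration, and the sizes are arbitrary.

```python
import numpy as np

def sinusoidal_encoding(num_positions, d_model):
    """Sinusoidal positional encoding: sine on even dims, cosine on odd dims."""
    pos = np.arange(num_positions)[:, np.newaxis]
    i = np.arange(d_model)[np.newaxis, :]
    angles = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
    pe = np.zeros((num_positions, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])
    pe[:, 1::2] = np.cos(angles[:, 1::2])
    return pe

def cosine_similarity(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

pe = sinusoidal_encoding(200, 128)

# 1. Uniqueness: count distinct position vectors among the 200 rows.
n_unique = len(np.unique(pe.round(6), axis=0))

# 2. Locality: compare position 5 with a neighbor and with a distant position.
sim_near = cosine_similarity(pe[5], pe[6])
sim_far = cosine_similarity(pe[5], pe[100])

# 3. Any length: the same formula produces encodings for more positions.
pe_longer = sinusoidal_encoding(1000, 128)

print(n_unique, sim_near > sim_far, pe_longer.shape)
```

Extending the table to 1000 positions leaves the first 200 rows unchanged, because each row depends only on its own position index.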



In [None]:
# @title Heatmap of Positional Encodings
# Defines 'pos_encoding' and 'embedding_dimension', which the next cell reuses
import plotly.express as px
import numpy as np
from IPython.display import display, Markdown # Ensure these are available

# --- Compute the positional encoding matrix used by this cell ---
def get_positional_encoding(max_seq_len, d_model):
    """Return a (max_seq_len, d_model) matrix of sinusoidal positional encodings."""
    pos = np.arange(max_seq_len)[:, np.newaxis]  # shape (max_seq_len, 1)
    i = np.arange(d_model)[np.newaxis, :]        # shape (1, d_model)
    # Each sin/cos pair of dimensions shares one frequency
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    angle_rads = pos * angle_rates
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])  # even dimensions: sine
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])  # odd dimensions: cosine
    return angle_rads
max_sequence_length = 512
embedding_dimension = 128
pos_encoding = get_positional_encoding(max_sequence_length, embedding_dimension)
# ---

# --- Parameters for Heatmap ---
num_positions_heatmap = 512  # How many positions (now columns) to show
num_dimensions_heatmap = 128 # How many dimensions (now rows) to show

# --- Create Heatmap (Transposed) ---
# Slice the positional encoding matrix to the desired size
pos_encoding_slice = pos_encoding[:min(num_positions_heatmap, pos_encoding.shape[0]),
                                  :min(num_dimensions_heatmap, pos_encoding.shape[1])]

# Transpose the slice so position is on the x-axis
pos_encoding_slice_transposed = pos_encoding_slice.T

fig_heatmap_transposed = px.imshow(
    pos_encoding_slice_transposed,
    labels=dict(x="Position Index", y="Dimension Index", color="Encoding Value"),
    x=np.arange(pos_encoding_slice_transposed.shape[1]), # Label position indices (original rows) correctly
    y=np.arange(pos_encoding_slice_transposed.shape[0]), # Label dimension indices (original columns) correctly
    title=f"Heatmap of Positional Encodings (First {pos_encoding_slice_transposed.shape[0]} Dimensions, First {pos_encoding_slice_transposed.shape[1]} Positions)",
    color_continuous_scale='viridis',
    aspect="auto"
)

fig_heatmap_transposed.update_layout(height=600, width=1000)

explanation_heatmap_transposed = """
* This **heatmap** visualizes the positional encoding values.
* **Columns:** Represent the position in the sequence (Position 0 at the left).
* **Rows:** Represent the dimension index within the embedding vector (Dimension 0 at the top).
* **Color:** Indicates the encoding value (ranging from -1 to 1, see color bar).
* **Takeaway:** Observe the distinct patterns. Early dimensions (top rows) use high frequencies, so their values alternate rapidly from one position to the next; later dimensions (bottom rows) use low frequencies, so they change slowly and appear as broad horizontal bands. Crucially, each column (position) has a unique pattern across the dimensions.
"""
display(Markdown(explanation_heatmap_transposed))
fig_heatmap_transposed.show()

In [None]:
# @title Computing the Positional Encoding Vectors
# Requires 'pos_encoding', 'embedding_dimension' from previous cells
import plotly.graph_objects as go

# --- Parameters ---
dims_to_show = 30 # How many dimensions to show
positions_to_show = [0, 2, 4] # Which position vectors to plot

# --- Create Plot ---
fig = go.Figure()

# Plot selected P_pos vectors
for pos in positions_to_show:
    fig.add_trace(go.Scatter(x=np.arange(dims_to_show), y=pos_encoding[pos, :dims_to_show],
                             mode='lines+markers', name=f'Positional Vector P_{pos}'))

# Update layout
fig.update_layout(
    title="Each Position Gets a Unique Vector",
    xaxis_title="Dimension Index", yaxis_title="Encoding Value",
    yaxis=dict(range=[-1.1, 1.1]), height=600, width=1200,
    legend=dict(
        yanchor="bottom",
        y=0.01,
        xanchor="right",
        x=0.98,
    )
)

explanation = f"""
* This plot shows the positional encoding vectors for positions {positions_to_show}.
* Each line represents the values across the first {dims_to_show} dimensions for that specific position.
* **Takeaway:** Notice that the line shape (the vector's 'signature') is clearly different for each position. This uniqueness is essential. (Each line corresponds to the top part of one column of the heatmap.)
"""
display(Markdown(explanation))
fig.show()

In [None]:
# @title Adding Positional Encodings to Embeddings

import numpy as np
import plotly.graph_objects as go
from IPython.display import display, Markdown

def get_positional_encoding(max_seq_len, d_model):
    """Return a (max_seq_len, d_model) matrix of sinusoidal positional encodings."""
    pos = np.arange(max_seq_len)[:, np.newaxis]  # shape (max_seq_len, 1)
    i = np.arange(d_model)[np.newaxis, :]        # shape (1, d_model)
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    angle_rads = pos * angle_rates
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])  # even dimensions: sine
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])  # odd dimensions: cosine
    return angle_rads

# --- Parameters ---
max_sequence_length = 10
embedding_dimension = 128 # Full dimension
vis_dimension = 2 # Only visualize first 2 dims for THIS cell's plot

# --- Simulate Word Embeddings (Clearer Dummy Data) ---
E_words = np.array([[1, 1], [1, 3], [3, 1], [3, 3]])
word_names = ["'cat'", "'sat'", "'on'", "'mat'"]
e_word_labels = [f"E_{name}" for name in word_names] # Labels for blue points

# --- Get Real Positional Encodings ---
# Calculate the FULL matrix and store it as 'pos_encoding'
pos_encoding = get_positional_encoding(max_sequence_length, embedding_dimension)

# Extract the first few vectors/dimensions needed just for this cell's visualization
P_vectors = pos_encoding[:4, :vis_dimension] # Use pos_encoding here
pos_labels = [f"P_{i}" for i in range(4)] # Generate P_0, P_1, ...

# --- Calculate Combined Vectors ---
Input_vectors = E_words + P_vectors
# Create more descriptive labels for the green points
input_labels = [f" {e_word_labels[i]}+{pos_labels[i]}" for i in range(len(word_names))]

# --- Create Plot ---
fig = go.Figure()

# Plot E_words
fig.add_trace(go.Scatter(x=E_words[:, 0], y=E_words[:, 1], mode='markers+text', name='Original Word Embedding (E)',
                         text=e_word_labels, textposition="middle right", marker=dict(color='blue', size=12)))

# Plot Combined Input_vectors using the new labels
fig.add_trace(go.Scatter(x=Input_vectors[:, 0], y=Input_vectors[:, 1], mode='markers+text', name='Modified Embedding (E+P)',
                         text=input_labels, textposition="middle right", marker=dict(color='green', size=12)))

# Add arrows showing the addition (E -> E+P)
for i in range(len(Input_vectors)):
    fig.add_annotation(ax=E_words[i, 0], ay=E_words[i, 1], x=Input_vectors[i, 0], y=Input_vectors[i, 1],
                       xref='x', yref='y', axref='x', ayref='y',
                       showarrow=True, arrowhead=2, arrowcolor='purple', arrowwidth=2)

# Update layout
fig.update_layout(
    title_text="Combining Word Meaning (E) with Position Signal (P)",
    xaxis_title="Dim 1", yaxis_title="Dim 2",
    legend=dict(yanchor="bottom", y=0.99, xanchor="left", x=0.01),
    height=500, width=600
)
fig.update_yaxes(scaleanchor='x', scaleratio=1)

# --- Explanation Section ---
explanation = """
**Takeaway:** The final input vectors (green) incorporate both original word meaning (blue) and a position-specific signal (purple vector).
* **Blue Points:** Represent the 'meaning' vectors (embeddings) for different words (e.g., E_'cat').
* **Green Points:** Represent the final vectors fed into the Transformer model.
    The label (e.g., `E_'cat'+P_0`) explicitly shows it's the sum of the word embedding (E) for 'cat' and the positional encoding for position 0 (P_0).
* **Purple Arrows:** Show the vector addition process (adding the specific P vector).
"""

display(Markdown(explanation))
fig.show()

---
# Core Concept #2: Attention

## 2.1 The Intuition Behind Attention

**Attention** is the mechanism that allows Transformers to focus on relevant information
when processing each word.

### Human Attention Analogy

When you read the sentence: "After running the marathon, she was so tired that **SHE** couldn't
even lift **HER** water bottle," you naturally understand:
- "**SHE**" refers to the marathon runner
- "**HER**" also refers to the same person

Your brain performs attention by connecting related parts of the sentence, even when they're
separated by several words.

## 2.2 Self-Attention - Letting Inputs Talk to Each Other
The heart of the Transformer is the **Self-Attention** mechanism. It allows the model, when processing one token (e.g., a word), to look at *all tokens* in the *same* input sequence (including the token itself) and determine how relevant each one is.

This allows the model to understand context. For example, in the sentence "The **animal** didn't cross the street because **it** was too tired," self-attention helps the model understand that "**it**" refers to "**animal**".


**Analogy: Library Research**
Imagine you're researching a topic (**Query**). You go to a library catalogue and compare your query against the keywords or titles of books (**Keys**). When you find a strong match, you retrieve the actual book content (**Value**). Self-attention works similarly.


### Query, Key, Value (QKV)

For each input token (after adding positional encoding), the model learns to generate three distinct vectors:
1.  **Query (Q):** Represents the current token's "question" or what it's looking for.
  * "I am the word 'it', who might I refer to?"
2.  **Key (K):** Represents a token's "label" or "identifier" used for matching.
  * "I am the word 'animal', this is what I represent."
3.  **Value (V):** Represents the actual content or meaning of the token.
  * "I am the word 'animal', here's my semantic information."
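
As a rough sketch of how the three vectors are produced, each one is a linear projection of the same token representation. The tokens, sizes, and random matrices below are assumptions for illustration; in a trained model `W_q`, `W_k`, and `W_v` are learned.

```python
import numpy as np

rng = np.random.default_rng(0)
tokens = ["the", "animal", "was", "tired"]   # hypothetical example tokens
d_model, d_k = 8, 4                          # assumed sizes for illustration

# Stand-in for "embedding + positional encoding", one row per token
X = rng.normal(size=(len(tokens), d_model))

# Three separate projection matrices (random here, learned in practice)
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

Q = X @ W_q   # one query vector per token
K = X @ W_k   # one key vector per token
V = X @ W_v   # one value vector per token

print(Q.shape, K.shape, V.shape)  # (4, 4) (4, 4) (4, 4)
```

The same input row yields three different vectors because it passes through three different matrices, which lets a token ask one thing (query) while advertising another (key).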


### The Attention Process

1.  **Calculate Scores:**
  * The Query vector of the current token is compared with the Key vectors of *all* tokens in the sequence (including itself).
  * This comparison is typically done using a **dot product**, similar to how we saw matrix multiplication relates inputs and weights in MLPs.
  * A higher dot product result means a better match (higher relevance).
2.  **Normalize Scores (Softmax):** These raw scores are scaled (in the standard Transformer, divided by the square root of the key dimension) and then converted into probabilities using a **softmax** function. This ensures the resulting attention weights for a given query sum to 1, representing the "distribution of attention".
3.  **Weighted Sum of Values:** The softmax scores are used as weights. Each Value vector is multiplied by its corresponding attention weight, and the results are summed up. Tokens with higher attention scores contribute more of their Value (meaning) to the final output representation for the current token.

This process happens *for every token* in the sequence, allowing each token to gather context from all other tokens based on learned relevance.
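
The three steps can be written compactly for a whole sequence at once. This is a minimal single-head sketch with random inputs; the function name `scaled_dot_product_attention` and the sizes are chosen here for illustration.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Return (output, weights) for single-head attention over a whole sequence."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                          # step 1: scaled dot products
    scores = scores - scores.max(axis=-1, keepdims=True)     # stabilizes the exponentials
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # step 2: softmax per query
    return weights @ V, weights                              # step 3: weighted sum of values

rng = np.random.default_rng(0)
n_tokens, d_k = 5, 4     # assumed sizes
Q = rng.normal(size=(n_tokens, d_k))
K = rng.normal(size=(n_tokens, d_k))
V = rng.normal(size=(n_tokens, d_k))

output, weights = scaled_dot_product_attention(Q, K, V)
print(weights.sum(axis=-1))  # each row sums to 1 (up to floating-point error)
print(output.shape)          # (5, 4)
```

Row `i` of `weights` is the attention distribution for token `i`, and row `i` of `output` is that token's new, context-aware representation.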



In [None]:
# @title Self-Attention Visualization
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import traceback

# Function definition with fixes
def create_qkv_visualization_clearer(sentence, focus_word_index):
    """
    Create a clearer interactive visualization of Query-Key-Value attention mechanism.
    (Fixes for cutoff text, legend clutter)
    """
    words = sentence.split()
    n_words = len(words)
    embedding_dim = 4
    np.random.seed(42)
    # --- IMPORTANT NOTE ---
    # Embeddings and Weight matrices (W_q, W_k, W_v) are RANDOM.
    # In a real Transformer, these are learned during training.
    # Therefore, the attention patterns shown here are illustrative of the
    # *mechanism*, not necessarily semantically meaningful for the example sentence.
    # ---
    embeddings = np.random.randn(n_words, embedding_dim) * 0.5
    W_q = np.random.randn(embedding_dim, embedding_dim)
    W_k = np.random.randn(embedding_dim, embedding_dim)
    W_v = np.random.randn(embedding_dim, embedding_dim)
    Q = np.matmul(embeddings, W_q)
    K = np.matmul(embeddings, W_k)
    V = np.matmul(embeddings, W_v)
    query_vector = Q[focus_word_index]
    scores = np.matmul(query_vector, K.T)
    scaled_scores = scores / np.sqrt(embedding_dim)
    exp_scores = np.exp(scaled_scores - np.max(scaled_scores))
    attention_weights = exp_scores / exp_scores.sum()
    weighted_values = V * attention_weights[:, np.newaxis]
    output = weighted_values.sum(axis=0)

    # --- Create Visualization ---
    fig = make_subplots(
        rows=3, cols=1,
        subplot_titles=(
            f"<b>Step 1: Calculate Scores</b> (Query '<i>{words[focus_word_index]}</i>' dot Keys)", # Slightly shortened
            f"<b>Step 2: Calculate Attention Weights</b> (Softmax of Scores)", # Slightly shortened
            f"<b>Step 3: Combine Values using Weights</b> (Weighted Sum)"
        ),
        vertical_spacing=0.18, row_heights=[0.3, 0.3, 0.4]
    )
    focus_color = 'rgba(220, 50, 50, 0.8)' # Red
    other_color = 'rgba(50, 100, 200, 0.7)' # Blue
    bar_colors = [other_color] * n_words
    bar_colors[focus_word_index] = focus_color

    # --- Plot 1: Scores ---
    # Add trace WITHOUT name/legend entry
    fig.add_trace(go.Bar(
            x=words,
            y=scaled_scores,
            # name="Scaled Scores", # REMOVED name to declutter legend
            showlegend=False,      # Explicitly hide from legend
            text=[f"{s:.2f}" for s in scaled_scores],
            textposition="outside",
            marker_color=bar_colors,
            hovertemplate="<b>Word (Key):</b> %{x}<br><b>Scaled Score:</b> %{y:.3f}<extra></extra>"
        ),
        row=1, col=1
    )

    # --- Plot 2: Weights ---
     # Add trace WITHOUT name/legend entry
    fig.add_trace(go.Bar(
            x=words,
            y=attention_weights,
            # name="Attention Weights", # REMOVED name to declutter legend
            showlegend=False,         # Explicitly hide from legend
            text=[f"{w:.2f}" for w in attention_weights],
            textposition="outside", # Weights are 0-1, outside usually works
            marker_color=bar_colors,
            hovertemplate="<b>Word:</b> %{x}<br><b>Attention Weight:</b> %{y:.3f}<extra></extra>"
        ),
        row=2, col=1
    )

    # --- Plot 3: Weighted Values (Stacked) & Output ---
    dim_labels = [f"Dim {j+1}" for j in range(embedding_dim)]
    # Add weighted value bars WITHOUT legend entries
    for i, word in enumerate(words):
        fig.add_trace(go.Bar(
                x=dim_labels,
                y=weighted_values[i], # Plot weighted values
                name=f"Weighted V: {word}", # Keep name for hover info
                showlegend=False,           # Explicitly hide from legend
                marker_color=bar_colors[i],
                hovertemplate=(f"<b>Word:</b> {word}<br>" +
                               f"<b>Attn W:</b> {attention_weights[i]:.3f}<br>" +
                               f"<b>Dim:</b> %{{x}}<br>" +
                               f"<b>Weighted V:</b> %{{y:.3f}}<extra></extra>")
            ),
            row=3, col=1
        )
    # Add Output Line/Markers - THIS WILL BE IN THE LEGEND
    fig.add_trace(go.Scatter(
            x=dim_labels,
            y=output,
            mode='markers+lines',
            marker=dict(size=12, color='black', symbol='star'),
            line=dict(width=4, color='black', dash='dashdot'),
            name="Output (Sum of Weighted Values)", # Keep name for legend
            hovertemplate=("<b>Output Vector</b><br>" +
                           "<b>Dim:</b> %{x}<br>" +
                           "<b>Value:</b> %{y:.3f}<extra></extra>")
        ),
        row=3, col=1
    )

    # --- Update Layout ---
    fig.update_layout(
        title={
            'text': f"<b>Self-Attention Visualization: Focus Word = '{words[focus_word_index]}'</b>"
                    f"<br>(Focus Word: <span style='color:{focus_color};'>Red</span>, Other Words: <span style='color:{other_color};'>Blue</span>)" # Color key in title
                    f"<br><sup><i>Note: Embeddings & Weights are random.</i></sup>",
            'x': 0.5, 'y': 0.98, 'xanchor': 'center', 'yanchor': 'top'
        },
        width=900, height=950,
        barmode='stack', # Apply 'stack' globally - primarily affects Plot 3
        bargap=0.3,
        # showlegend=True, # Default is True, only named traces with showlegend=True appear
        legend=dict(
            traceorder='reversed', # Show Output first if multiple items were present
            itemsizing='constant'  # Keep legend marker size constant
        ),
        hovermode='x unified'
    )

    # --- Update Axes ---
    # Calculate dynamic range for Plot 1 y-axis to prevent cutoff text
    min_score = np.min(scaled_scores)
    max_score = np.max(scaled_scores)
    score_range = max_score - min_score
    # Add buffer (e.g., 15% of range, or a minimum buffer if range is tiny)
    y_buffer_plot1 = max(score_range * 0.15, 0.1)
    y_range_plot1 = [min_score - y_buffer_plot1, max_score + y_buffer_plot1]

    fig.update_yaxes(title_text="Scaled Score", range=y_range_plot1, row=1, col=1) # Apply calculated range
    fig.update_yaxes(title_text="Attention Weight", range=[0, 1.05], row=2, col=1) # Range 0-1, small buffer for text
    fig.update_yaxes(title_text="Weighted Value", row=3, col=1) # Auto-range for stacked bars

    fig.update_xaxes(title_text="Word (Key Source)", row=1, col=1)
    fig.update_xaxes(title_text="Word (Value Source)", row=2, col=1)
    fig.update_xaxes(title_text="Embedding Dimension", row=3, col=1)

    fig.update_layout(hoverlabel=dict(bgcolor="white", font_size=12))

    return fig


# === Example Execution Code ===
import traceback  # needed by the error handlers below

# === Example 1: Focusing on 'it' ===
print("--- Running Example 1 ('it') ---")
example = "The animal didn't cross the street because it was too tired."
words = example.split()
focus_word_1 = "it"
it_index = -1
try:
    it_index = words.index(focus_word_1)
    print(f"Found '{focus_word_1}' at index {it_index}.")
except ValueError:
    print(f"ERROR: Word '{focus_word_1}' not found in the sentence: '{example}'")

if it_index != -1:
    try:
        print(f"Visualizing attention for word '{words[it_index]}' (index {it_index})...")
        fig = create_qkv_visualization_clearer(example, it_index)
        fig.show()
        print(f"Example 1 visualization displayed.")
    except Exception as e:
        print(f"\n!!! ERROR during visualization/display for Example 1 ('{focus_word_1}') !!!")
        print(f"Error Type: {type(e).__name__}")
        print(f"Error Details: {e}")
        print("Traceback:")
        traceback.print_exc()
else:
    print("Skipping visualization for Example 1 because word was not found.")
print("--- Finished Example 1 ---")


# === Example 2: Focusing on 'cat' ===
# After split() the token is 'cat.' (with the trailing period), so we search by substring.
print("\n--- Running Example 2 ('cat') ---")
example2 = "The big red dog chased the small cat."
words2 = example2.split()
focus_word_2 = "cat"
cat_index = -1
actual_word_found = None
try:
    actual_word_found = next((w for w in words2 if focus_word_2 in w), None)
    if actual_word_found:
        cat_index = words2.index(actual_word_found)
        print(f"Found '{actual_word_found}' (containing '{focus_word_2}') at index {cat_index}.")
    else:
        print(f"ERROR: Word containing '{focus_word_2}' not found in the sentence: '{example2}'")
except ValueError:
     print(f"ERROR: Value error occurred trying to find index for '{actual_word_found}' in sentence: '{example2}'")
except Exception as e:
     print(f"\n!!! UNEXPECTED ERROR during word search for Example 2 ('{focus_word_2}') !!!")
     print(f"Error Type: {type(e).__name__}")
     print(f"Error Details: {e}")
     traceback.print_exc()

if cat_index != -1 and actual_word_found:
    try:
        print(f"Visualizing attention for word '{actual_word_found}' (index {cat_index})...")
        fig2 = create_qkv_visualization_clearer(example2, cat_index)
        fig2.show()
        print(f"Example 2 visualization displayed.")
    except Exception as e:
        print(f"\n!!! ERROR during visualization/display for Example 2 ('{actual_word_found}') !!!")
        print(f"Error Type: {type(e).__name__}")
        print(f"Error Details: {e}")
        print("Traceback:")
        traceback.print_exc()
else:
     print("Skipping visualization for Example 2 because word was not found or index search failed.")
print("--- Finished Example 2 ---")

## 2.3 Multi-Head Attention: Different Perspectives 🧐🤯🤓
Transformers don't just calculate attention once. They use **Multi-Head Attention**.

**The Idea:** Instead of having one set of Q, K, V weight matrices, Multi-Head Attention has multiple sets (e.g., 8 or 12 "heads"). Each head learns a *different* set of Wq, Wk, Wv matrices.

1.  The input embedding (with positional encoding) is passed through each head independently.
2.  Each head performs the self-attention calculation (QKV dot products, softmax, weighted sum) in parallel, potentially focusing on different types of relationships or aspects of the sequence.
3.  The output vectors from all heads are concatenated together.
4.  This concatenated vector is passed through one final linear layer (matrix multiplication) to produce the final output of the Multi-Head Attention block.

> **Analogy: Expert Panel** 🧑‍🏫👩‍🔬👨‍🎨
> Imagine analyzing a complex sentence. Instead of one person trying to understand everything, you have a panel of experts (attention heads):
> * One expert focuses on grammatical dependencies (subject-verb, pronoun references).
> * Another focuses on semantic similarity (synonyms, related concepts).
> * Another might focus on positional relationships ("the word before X").
>
> Each head provides its own contextual understanding based on its learned specialty. Combining their outputs gives a much richer, multi-faceted representation than a single attention calculation could.

Multi-Head Attention allows the model to jointly attend to information from different representation subspaces at different positions.
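
The four steps above can be sketched in NumPy. This is a minimal illustration rather than a real model's implementation: the weights are random instead of learned, and the names `multi_head_attention`, `n_heads`, and `d_head` are hypothetical. It also assumes the common convention that each head works in a smaller subspace of size `d_model / n_heads`.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, n_heads, rng):
    """Illustrative multi-head self-attention with random (untrained) weights."""
    seq_len, d_model = X.shape
    d_head = d_model // n_heads  # assumption: d_model is divisible by n_heads
    head_outputs = []
    for _ in range(n_heads):
        # Steps 1-2: each head has its own Wq, Wk, Wv and runs attention independently
        Wq, Wk, Wv = (rng.normal(size=(d_model, d_head)) for _ in range(3))
        Q, K, V = X @ Wq, X @ Wk, X @ Wv
        weights = softmax(Q @ K.T / np.sqrt(d_head))  # (seq_len, seq_len)
        head_outputs.append(weights @ V)              # (seq_len, d_head)
    # Step 3: concatenate the head outputs
    concat = np.concatenate(head_outputs, axis=-1)    # (seq_len, d_model)
    # Step 4: one final linear layer mixes information across heads
    Wo = rng.normal(size=(d_model, d_model))
    return concat @ Wo

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 16))   # 5 tokens, 16-dimensional embeddings
mha_out = multi_head_attention(X, n_heads=4, rng=rng)
print(mha_out.shape)  # (5, 16)
```

The output has the same shape as the input, which is what lets this block be stacked and combined with residual connections later on.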


In [None]:
#@title Attention visualizer (expand to see code)

import warnings
warnings.filterwarnings("ignore")

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from IPython.display import display, HTML, clear_output, Markdown
import torch
import ipywidgets as widgets
import seaborn as sns

def get_attention_patterns(model, tokenizer, text):
    """
    Extract attention patterns from the model for visualization.
    """
    # Tokenize input
    inputs = tokenizer(text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Get model outputs with attention
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # Extract attention weights and convert to numpy
    attention_weights = []
    for layer_attentions in outputs.attentions:
        layer_attentions = layer_attentions.squeeze(0)
        layer_attentions = layer_attentions.cpu().numpy()
        attention_weights.append(layer_attentions)

    # Get tokens
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    return attention_weights, tokens

def create_attention_visualizer(model, tokenizer, examples=None):
    # Create informative explanation of attention visualization
    explanation = """
    <div style="background-color: #1e1e1e; padding: 20px; border-radius: 8px; margin-bottom: 20px; color: #e1e1e1;">
        <h3 style="color: #00b4d8; margin-top: 0;">How to Read This Attention Heatmap</h3>

        <p>This visualization shows how each token (word piece) pays attention to other tokens and itself when processing text:</p>

        <ul style="margin-bottom: 15px;">
            <li><strong style="color: #00b4d8;">Reading the Grid:</strong> Each row shows how much attention one token pays to the tokens in the columns (previous tokens and itself)</li>
            <li><strong style="color: #00b4d8;">Numbers & Colors:</strong> Darker blue and higher numbers (0-1) indicate stronger attention</li>
            <li><strong style="color: #00b4d8;">Blank Upper Area:</strong> The upper triangle (attention to later tokens) is hidden to keep the grid readable. Causal models such as GPT cannot attend to later tokens at all; bidirectional models such as BERT can, so for BERT the visible cells in a row may sum to less than 1</li>
        </ul>

        <p><strong style="color: #00b4d8;">Example Pattern:</strong> In the sentence "The cat sits", you might see:</p>
        <ul>
            <li>"cat" paying strong attention to "The" (article-noun relationship)</li>
            <li>"sits" paying attention to "cat" (subject-verb relationship)</li>
            <li>Each token typically paying some attention to itself</li>
        </ul>
    </div>
    """

    if examples is None:
        examples = {
            "Complex Relationship": "The dog chased its tail because it was wagging",
            "Question-Answer": "Q: What is the capital of France? A: Paris",
            "Local Grammar": "The red and blue car",
            "Grammar Pattern": "The red and blue car drove fast",
            "Completion": "The students studied hard for their final",
            "Simple Example": "The cat sits on the mat",
            "Comparison": "Although it was expensive, the quality was excellent",
        }

    # Create widgets
    text_input = widgets.Text(
        value='',
        placeholder='Enter custom text...',
        description='Input:',
        layout=widgets.Layout(width='80%')
    )

    examples_dropdown = widgets.Dropdown(
        options=examples,
        description='Examples:',
        layout=widgets.Layout(width='80%')
    )

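    layer_dropdown = widgets.Dropdown(
        # Use the model's actual layer count so every option is a valid index
        options=[(f'Layer {i}', i) for i in range(model.config.num_hidden_layers)],
        value=0,
        description='Layer:',
        layout=widgets.Layout(width='200px')
    )

    head_dropdown = widgets.Dropdown(
        # Use the model's actual head count so every option is a valid index
        options=[(f'Head {i}', i) for i in range(model.config.num_attention_heads)],
        value=0,
        description='Head:',
        layout=widgets.Layout(width='200px')
    )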

    viz_output = widgets.Output()
    rec_output = widgets.Output()

    def update_display(change=None):
        """Update both visualization and recommendations"""
        text = text_input.value if text_input.value else examples_dropdown.value
        layer = layer_dropdown.value
        head = head_dropdown.value

        attention_weights, tokens = get_attention_patterns(model, tokenizer, text)

        with rec_output:
            clear_output(wait=True)
            interesting_patterns = find_interesting_patterns(attention_weights, tokens)
            print("\nRecommended interesting patterns to explore:")
            for l, h, score, reason in interesting_patterns:
                print(f"Layer {l}, Head {h}: {reason} (score: {score:.2f})")

        with viz_output:
            clear_output(wait=True)

            n_tokens = len(tokens)
            fig_size = max(6, n_tokens * 0.5)
            plt.figure(figsize=(fig_size, fig_size))

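            # Hide the upper triangle (attention to later tokens) for readability.
            # Note: BERT is bidirectional, so the hidden weights are generally non-zero.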
            mask = np.triu(np.ones_like(attention_weights[layer][head]), k=1)

            # Plot heatmap with mask
            sns.heatmap(attention_weights[layer][head],
                       xticklabels=tokens,
                       yticklabels=tokens,
                       cmap='Blues',
                       vmin=0, vmax=1,  # attention weights are probabilities in [0, 1]
                       square=True,
                       fmt='.2f',
                       annot=True,
                       mask=mask,  # Apply mask to hide upper triangle
                       annot_kws={'size': 8},
                       cbar_kws={'shrink': .8})

            plt.title(f'Attention Pattern (Layer {layer}, Head {head})')
            plt.xticks(rotation=45, ha='right')
            plt.yticks(rotation=0)
            plt.tight_layout()
            plt.show()

    # Connect callbacks
    text_input.observe(update_display, names='value')
    examples_dropdown.observe(update_display, names='value')
    layer_dropdown.observe(update_display, names='value')
    head_dropdown.observe(update_display, names='value')

    # Create layout with dropdowns side by side
    controls = widgets.VBox([
        widgets.HTML(explanation),  # Add explanation at the top
        examples_dropdown,
        text_input,
        widgets.HBox([layer_dropdown, head_dropdown])
    ])

    # Display everything
    display(widgets.HTML("<h2>Attention Pattern Visualizer</h2>"))
    display(controls)
    display(viz_output)
    display(rec_output)

    # Show initial visualization
    update_display()

#@title Function to find interesting patterns in attention heatmaps
def find_interesting_patterns(attention_weights, tokens, top_k=5):
    scores = []
    n_layers = len(attention_weights)
    n_heads = attention_weights[0].shape[0]

    for layer in range(n_layers):
        for head in range(n_heads):
            matrix = attention_weights[layer][head]

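            # Heuristic 1: the first token (for BERT, the special [CLS] token)
            # receives more than half of the attention on average
            first_token_pattern = np.mean(matrix[:, 0]) > 0.5

            # Heuristic 2: mean attention per row does not jump between neighboring tokens.
            # Rough check only: each row sums to 1, so the row means are nearly identical.
            decreasing_pattern = np.all(np.diff(matrix.mean(axis=1)) < 0.1)

            # Heuristic 3: the last token spreads its attention fairly evenly
            last_token_distribution = matrix[-1].std() < 0.3 and matrix[-1].mean() > 0.1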

            # Calculate linguistic pattern score
            linguistic_score = (
                first_token_pattern * 0.4 +
                decreasing_pattern * 0.3 +
                last_token_distribution * 0.3
            )

            # Determine pattern type and score
            if linguistic_score > 0.6:
                reason = "Shows grammatical structure patterns"
                score = linguistic_score
            elif first_token_pattern:
                reason = "Shows strong attention to the first token"
                score = linguistic_score
            else:
                # Calculate general interest metrics as fallback
                attention_spread = len(matrix[matrix > 0.05]) / matrix.size
                peak_contrast = np.max(matrix) - np.mean(matrix)
                weighted_distances = np.mean(np.abs(np.arange(len(tokens))[:, None] - np.arange(len(tokens))) * matrix)

                score = (
                    attention_spread * 0.4 +
                    peak_contrast * 0.3 +
                    weighted_distances * 0.3
                )

                if attention_spread > 0.3:
                    reason = "Shows distributed attention"
                elif peak_contrast > 0.4:
                    reason = "Shows focused attention peaks"
                elif weighted_distances > len(tokens)/3:
                    reason = "Shows long-range connections"
                else:
                    continue

            scores.append((layer, head, score, reason))

    # Sort by score and return top_k unique patterns
    scores.sort(key=lambda x: x[2], reverse=True)
    seen_reasons = set()
    filtered_scores = []
    for score in scores:
        if score[3] not in seen_reasons and len(filtered_scores) < top_k:
            filtered_scores.append(score)
            seen_reasons.add(score[3])

    return filtered_scores


model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)
examples = {
    "Grammar Pattern": "The red and blue car drove fast",
    "Complex Relationship": "The dog chased its tail because it was wagging",
    "Question-Answer": "Q: What is the capital of France? A: Paris",
    "Completion": "The students studied hard for their final",
    "Simple Example": "The cat sits on the mat",
    "Comparison": "Although it was expensive, the quality was excellent",
}

create_attention_visualizer(model, tokenizer, examples=examples)

---
# Core Concept #3: The Transformer Architecture

Now that we understand positional encoding and the powerful attention mechanism, let's assemble these pieces into the full Transformer architecture. Transformers are typically built by stacking multiple identical layers, often referred to as **Transformer Blocks**.

For a full-featured interactive visualization of the Transformer, see the [Transformer Explainer](https://poloclub.github.io/transformer-explainer/).

## 3.1 The Building Block: A Single Transformer Layer

Think of a Transformer model (like BERT or GPT) as a skyscraper built from many identical floors. Each "floor" is a Transformer Block (or Layer) that processes the sequence representation passed up from the floor below. A standard Transformer block typically contains two main sub-layers:

1.  **Multi-Head Self-Attention Sub-layer:**
    *   This is where the core attention mechanism happens (as discussed in Core Concept #2).
    *   Input tokens (with positional encodings) generate Queries, Keys, and Values.
    *   Multiple "heads" calculate attention scores in parallel, allowing the model to focus on different types of relationships simultaneously.
    *   The outputs of the heads are combined to produce an output representation for each token that incorporates context from the entire sequence.

2.  **Position-wise Feed-Forward Network (FFN) Sub-layer:**
    *   After the attention mechanism gathers context, this sub-layer processes each token's representation *independently*.
    *   It's usually a simple two-layer fully-connected network (Linear -> ReLU Activation -> Linear).
    *   Think of it as further "thinking" about the context-rich representation derived from the attention step for each position separately.
    *   *Important:* The *same* FFN (same weights) is applied to each position, but it acts on each position's vector one by one.
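
To make "position-wise" concrete, here is a small NumPy sketch. The sizes and random weights are illustrative assumptions, and `ffn` is a hypothetical helper. It checks that running the FFN on the whole sequence at once gives the same result as running it on each token separately with the same weights.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_model, d_ff = 4, 8, 32   # illustrative sizes

# One shared set of FFN weights, reused for every position
W1, b1 = rng.normal(size=(d_model, d_ff)), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d_ff, d_model)), np.zeros(d_model)

def ffn(x):
    # Linear -> ReLU -> Linear
    return np.maximum(0, x @ W1 + b1) @ W2 + b2

tokens = rng.normal(size=(seq_len, d_model))

all_at_once = ffn(tokens)                         # whole sequence in one call
one_by_one = np.stack([ffn(t) for t in tokens])   # each position separately

print(np.allclose(all_at_once, one_by_one))  # True
```

Because no position looks at any other position here, all mixing of information between tokens happens in the attention sub-layer.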

**Connecting the Sub-layers: Add & Norm**

Crucially, after *each* of these two sub-layers, an **Add & Norm** step is applied:

*   **Add (Residual Connection):** The input *to* the sub-layer (e.g., the input to the Multi-Head Attention) is added directly to the output *of* that sub-layer. This creates a "shortcut" or "residual connection".
    *   *Why?* This simple addition makes it much easier to train very deep networks. It allows information to bypass a layer if needed and helps gradients flow backwards during training without vanishing.
*   **Norm (Layer Normalization):** After the addition, Layer Normalization is applied. It rescales the values within each token's vector independently across the feature dimension.
    *   *Why?* This helps stabilize the learning process, making the model less sensitive to the scale of parameters and gradients, often leading to faster convergence.
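
As a rough illustration of Add & Norm, the sketch below applies a residual addition followed by layer normalization in NumPy. The helper names `layer_norm` and `add_and_norm` are hypothetical, and the learned scale and shift parameters that real layer normalization includes are omitted for brevity.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each token's vector (last axis) to zero mean and unit variance.
    # The learned scale and shift parameters are omitted in this sketch.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def add_and_norm(sublayer_input, sublayer_output):
    # Residual connection followed by layer normalization
    return layer_norm(sublayer_input + sublayer_output)

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 8))                   # 3 tokens, 8 features each
sublayer_out = rng.normal(size=(3, 8)) * 10   # stand-in for an attention or FFN output
y = add_and_norm(x, sublayer_out)

print(np.allclose(y.mean(axis=-1), 0, atol=1e-6))  # True: each token has mean ~0
print(np.allclose(y.std(axis=-1), 1, atol=1e-3))   # True: each token has std ~1
```

Even though the stand-in sub-layer output is about ten times larger than the input, every token's vector ends up on the same scale after normalization.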

So, the flow through one complete Transformer Block looks like this:
`Input -> Multi-Head Attention -> Add & Norm -> Feed-Forward Network -> Add & Norm -> Output`
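
The flow above can be sketched as a small PyTorch module. This is a simplified illustration under stated assumptions: the class name `TransformerBlockSketch` and the sizes are hypothetical, dropout is left out, and normalization is applied after each residual addition, as described in this section.

```python
import torch
import torch.nn as nn

class TransformerBlockSketch(nn.Module):
    """Hypothetical minimal block: attention -> Add & Norm -> FFN -> Add & Norm."""
    def __init__(self, d_model=32, n_heads=4, d_ff=64):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)   # self-attention: Q, K, V all come from x
        x = self.norm1(x + attn_out)       # Add & Norm
        x = self.norm2(x + self.ffn(x))    # Add & Norm
        return x

torch.manual_seed(0)
block = TransformerBlockSketch()
x_in = torch.randn(2, 5, 32)   # (batch, tokens, d_model)
x_out = block(x_in)
print(x_out.shape)  # torch.Size([2, 5, 32])
```

The input and output shapes match, which is what allows many such blocks to be stacked.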

## 3.2 Stacking Blocks: Encoders and Decoders

These Transformer Blocks are stacked one on top of another to form the main components of the architecture: the Encoder and the Decoder.

### The Encoder Stack

*   **Purpose:** To read and "understand" the input sequence, creating a rich, context-aware representation.
*   **Structure:** A stack of N identical Encoder Blocks (e.g., N=6 or 12).
*   **Input:** The sequence of input embeddings (word embeddings + positional encodings).
*   **Process:** The input flows through the stack. Each Encoder Block applies self-attention (allowing all input tokens to attend to each other) and feed-forward layers as described above. The output of one block becomes the input to the next.
*   **Output:** A sequence of context-rich vectors, one for each input token, representing the "meaning" of the input sequence.
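
A stack like this can be built from PyTorch's `nn.TransformerEncoderLayer` and `nn.TransformerEncoder`. The sketch below uses small, illustrative sizes and a random tensor as a stand-in for real embeddings, so it only demonstrates the stacking and the shapes involved.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads, n_layers = 32, 4, 6   # illustrative sizes

encoder_layer = nn.TransformerEncoderLayer(
    d_model=d_model, nhead=n_heads, dim_feedforward=64, batch_first=True
)
encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
encoder.eval()  # disable dropout for a deterministic forward pass

# Stand-in for word embeddings + positional encodings: (batch, tokens, d_model)
embeddings = torch.randn(1, 7, d_model)
with torch.no_grad():
    encoded = encoder(embeddings)

print(len(encoder.layers))  # 6
print(encoded.shape)        # torch.Size([1, 7, 32])
```

The encoder returns one vector per input token, matching the **Output** bullet above.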



### The Decoder Stack

*   **Purpose:** To generate an output sequence (e.g., a translation, a summary, or the next word in a sentence), often conditioned on the Encoder's output.
*   **Structure:** A stack of N identical Decoder Blocks.
*   **Input:** Typically takes the previously generated output token embeddings (plus positional encodings) and the final output from the Encoder stack.
*   **Process:** Decoder Blocks are slightly different from Encoder Blocks. They have *three* main sub-layers (each followed by Add & Norm):
    1.  **Masked Multi-Head Self-Attention:** The Decoder performs self-attention on the sequence it has generated *so far*. The "Masking" is crucial: it prevents a position from attending to *future* positions. This is essential because during generation, the model should only use the words it has already produced, not the "correct" future words.
    2.  **Multi-Head Cross-Attention:** This is where the Decoder interacts with the Encoder's output. The *Queries* come from the Decoder's Masked Self-Attention output, while the *Keys* and *Values* come from the final output of the *Encoder stack*. This allows the Decoder to focus on relevant parts of the *input* sequence while generating the *output* sequence (e.g., when translating, it can focus on the source words relevant to the word being generated).
    3.  **Position-wise Feed-Forward Network:** Identical in function to the FFN in the Encoder block, processing the output of the cross-attention step for each position.
*   **Output:** After the final Decoder Block, a linear layer and a softmax function are typically used to predict the probabilities for the *next* token in the output sequence.
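
The three decoder sub-layers can be sketched in PyTorch as follows. The class name `DecoderBlockSketch` and all sizes are hypothetical, dropout is omitted, and random tensors stand in for the encoder output and the target embeddings. The point is to show where the causal mask goes and where the Queries, Keys, and Values come from.

```python
import torch
import torch.nn as nn

class DecoderBlockSketch(nn.Module):
    """Hypothetical minimal decoder block with three sub-layers."""
    def __init__(self, d_model=32, n_heads=4, d_ff=64):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, tgt, memory):
        t = tgt.size(1)
        # Boolean mask: True marks positions that may NOT be attended to (the future)
        causal_mask = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1)
        # 1. Masked self-attention over the tokens generated so far
        sa, sa_weights = self.self_attn(tgt, tgt, tgt, attn_mask=causal_mask)
        x = self.norm1(tgt + sa)
        # 2. Cross-attention: Queries from the decoder, Keys/Values from the encoder
        ca, _ = self.cross_attn(x, memory, memory)
        x = self.norm2(x + ca)
        # 3. Position-wise feed-forward network
        x = self.norm3(x + self.ffn(x))
        return x, sa_weights

torch.manual_seed(0)
decoder_block = DecoderBlockSketch()
memory = torch.randn(1, 6, 32)   # stand-in encoder output: 6 source tokens
tgt = torch.randn(1, 4, 32)      # stand-in embeddings of 4 target tokens
dec_out, sa_weights = decoder_block(tgt, memory)

print(dec_out.shape)                               # torch.Size([1, 4, 32])
print(bool((sa_weights[0].triu(1) == 0).all()))    # True: no attention to future tokens
```

Note that the source and target sequences can have different lengths (6 and 4 here), because cross-attention only requires that they share the same embedding size.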

In [None]:
# @title Plotly Visualization: Attention Masking Heatmap
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import plotly.subplots as sp # Import subplots

# Create a dummy attention score matrix (e.g., for a sequence of 5 tokens)
# These would normally be the result of Q dot K^T
np.random.seed(42)
seq_len_mask = 5
attention_scores = np.random.randn(seq_len_mask, seq_len_mask)

# Create the mask (lower triangular matrix including diagonal allows attention)
# Mask value set for entries we want to *ignore* (upper triangle)
mask_value = -1e9 # Use a large negative number for masking
# np.triu -> Upper triangle. k=1 means exclude the diagonal.
mask = np.triu(np.ones((seq_len_mask, seq_len_mask)) * mask_value, k=1)

# Apply the mask by adding it to the scores. Where the mask is 0, scores are unchanged.
# Where the mask is -1e9, scores become very large negative numbers (softmax maps them to ~0).
masked_scores = attention_scores + mask

# --- Visualization ---
fig = sp.make_subplots(rows=1, cols=2, subplot_titles=("Raw Attention Scores", "Masked Attention Scores (before Softmax)"))

# Plot Raw Scores
heatmap_raw = go.Heatmap(
    z=attention_scores,
    x=[f'Key {i+1}' for i in range(seq_len_mask)],
    y=[f'Query {i+1}' for i in range(seq_len_mask)],
    colorscale='viridis',
    zmin=-3, zmax=3, # Consistent color range
    text=np.round(attention_scores, 2),
    texttemplate="%{text}",
    showscale=False,
    name="Raw Scores"
)
fig.add_trace(heatmap_raw, row=1, col=1)

# Plot Masked Scores
# Create text labels: show score where not masked, '-inf' where masked
masked_text = np.full_like(masked_scores, "", dtype=object)
not_masked_indices = (mask != mask_value)
masked_indices = (mask == mask_value)
masked_text[not_masked_indices] = np.round(masked_scores[not_masked_indices], 2)
masked_text[masked_indices] = "-inf" # Indicate masked cells visually

heatmap_masked = go.Heatmap(
    # Use original scores for coloring non-masked, and a distinct value (NaN) for masked
    z=np.where(masked_indices, np.nan, attention_scores),
    x=[f'Key {i+1}' for i in range(seq_len_mask)],
    y=[f'Query {i+1}' for i in range(seq_len_mask)],
    colorscale='viridis', # Same scale for comparison
    zmin=-3, zmax=3,
    text=masked_text,
    texttemplate="%{text}",
    hoverongaps=False, # Don't show hover info for masked cells
    showscale=True,
    colorbar_title="Score",
    name="Masked Scores",
    coloraxis="coloraxis" # Link color axis if needed, though separate is fine here
)
# Add thin gaps between cells; NaN (masked) cells are left blank rather than interpolated
heatmap_masked.update(xgap=1, ygap=1, connectgaps=False)
fig.layout.coloraxis.colorbar.title = 'Score'
fig.layout.coloraxis.colorscale = 'Viridis'
fig.layout.coloraxis.cmin = -3
fig.layout.coloraxis.cmax = 3


fig.add_trace(heatmap_masked, row=1, col=2)


fig.update_layout(
    title_text="Masked Self-Attention: Preventing Future Peeking",
    height=400, width=850,
    # Explicitly set background for masked cells if desired (e.g., light grey)
    # This requires more complex setup or post-processing usually.
    # Using NaN and connectgaps=False provides a visual distinction.
    plot_bgcolor='white'
)
# Ensure y-axis is reversed to match matrix layout (Query 1 at top)
fig.update_yaxes(autorange="reversed")
fig.show()

*Explanation:*
*   *Left Heatmap:* Raw scores calculated by comparing each Query (row) to each Key (column).
*   *Right Heatmap:* Mask applied. Scores in the upper right (where a Query tries to attend to a Key further down the sequence) are effectively set to negative infinity (visualized as gaps / "-inf" text).
*   *Effect:* After the softmax function is applied to these masked scores, the masked positions will have a probability of zero, ensuring tokens only attend to themselves and previous tokens.
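
The *Effect* step can be checked numerically. The sketch below rebuilds the same kind of masked score matrix (random scores, `-1e9` in the upper triangle) and applies a row-wise softmax. The variable names are illustrative and are not taken from the cell above.

```python
import numpy as np

np.random.seed(42)
n = 5
scores = np.random.randn(n, n)                        # stand-in for Q dot K^T
causal_mask = np.triu(np.ones((n, n)) * -1e9, k=1)    # -1e9 above the diagonal, 0 elsewhere
masked = scores + causal_mask

# Row-wise softmax (subtract the row max for numerical stability)
exp = np.exp(masked - masked.max(axis=1, keepdims=True))
weights = exp / exp.sum(axis=1, keepdims=True)

print(np.round(weights, 2))
print(np.allclose(np.triu(weights, k=1), 0))   # True: future positions get zero weight
print(np.allclose(weights.sum(axis=1), 1))     # True: each row is still a probability distribution
```

The first row has a single nonzero weight of 1.0, because the first token can only attend to itself.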

**Visualization: Decoder Stack with Cross-Attention**

## 3.3 Common Architectures & Use Cases

Based on these components, three main types of Transformer architectures are prevalent:

1.  **Encoder-Only Models (e.g., BERT, RoBERTa):**
    *   Use only the Encoder stack.
    *   Excellent at tasks requiring deep understanding of the input text.
    *   **Applications:** Text classification, sentiment analysis, named entity recognition, question answering (where the answer is extracted from the context).

2.  **Decoder-Only Models (e.g., GPT series, LLaMA, Gemini):**
    *   Use only the Decoder stack with its masked self-attention; the cross-attention sub-layer is dropped because there is no encoder output to attend to.
    *   Excellent at generating coherent text following a prompt.
    *   **Applications:** Text generation, autocompletion, chatbots, creative writing, summarization (as a generation task).

3.  **Encoder-Decoder Models (e.g., Original Transformer "Attention Is All You Need", T5, BART):**
    *   Use both stacks connected via the cross-attention mechanism.
    *   Designed for sequence-to-sequence tasks where an input sequence needs to be transformed into an output sequence.
    *   **Applications:** Machine translation, text summarization (as a transformation task), question answering (generating the answer).

## 3.4 What Makes Transformers So Powerful? (Summary)

The architecture combines several key innovations:

1.  **Parallelism**: Processes the entire sequence at once (within layers).
2.  **Self-Attention**: Directly models relationships between words regardless of distance.
3.  **Positional Encoding**: Preserves sequence order information despite parallel processing.
4.  **Multi-Head Attention**: Captures different types of relationships simultaneously.
5.  **Deep Architecture**: Stacks multiple layers (enabled by Add & Norm) for sophisticated processing.
6.  **Contextualization**: Cross-attention (in Encoder-Decoder models) effectively links input and output sequences.

This design allows Transformers to capture the complex patterns and long-range dependencies in data like human language, leading to their state-of-the-art performance on many tasks.

## Section 4: Impact and Conclusion 🌟
---
The Transformer architecture has fundamentally changed the landscape of AI, especially in Natural Language Processing (NLP).

**Key Advantages:**
* **Parallelism:** Processes sequences in parallel, leading to significantly faster training times compared to RNNs on suitable hardware.
* **Long-Range Dependencies:** Self-attention allows direct interaction between any two positions in the sequence, making it much better at capturing long-range context than traditional RNNs.
* **State-of-the-Art Performance:** Transformers form the basis of models that have achieved top results across a vast range of NLP tasks, including:
    * Machine Translation (e.g., Google Translate improvements)
    * Text Summarization
    * Question Answering
    * Text Generation (e.g., GPT, Gemini)
    * Language Understanding (e.g., BERT)
* **Transfer Learning:** Pre-trained Transformer models (like BERT or GPT) can be fine-tuned on specific downstream tasks with relatively small amounts of data, achieving excellent performance.
* **Beyond NLP:** The core ideas, particularly attention, have been successfully adapted to other domains like computer vision (Vision Transformers - ViT), audio processing, and even biology.

**Conclusion:**
We've seen how Transformers move beyond the limitations of sequential processing by leveraging parallel computation and the powerful **Self-Attention** mechanism. By allowing every part of the input to directly attend to every other part, and using techniques like **Positional Encoding** to retain order information and **Multi-Head Attention** to capture diverse relationships, Transformers build deep contextual understanding. The **Encoder-Decoder** structure provides a flexible framework for various sequence-to-sequence tasks.

While the internal workings involve complex linear algebra and learned parameters, the core concepts of parallel processing and attention provide an intuition for why these models are so effective. They represent a significant step forward in our ability to build AI that can understand and generate complex, sequential data like human language.

This concludes our planned sessions on core AI architectures. We've journeyed from individual neurons to MLPs, CNNs, RNNs, and now Transformers, seeing how different ways of connecting these components lead to specialized capabilities.


## Glossary
---
* **Transformer:** A neural network architecture based on self-attention mechanisms, processing sequences in parallel.
* **Self-Attention:** A mechanism allowing inputs to interact with each other and compute weights (attention scores) indicating their relevance to each other.
* **Query (Q):** In attention, a representation of the current token used to score against keys.
* **Key (K):** In attention, a representation of a token that queries are matched against.
* **Value (V):** In attention, a representation of a token's content, weighted by attention scores to produce the output.
* **Multi-Head Attention:** Performing the attention mechanism multiple times in parallel with different learned linear projections (different sets of Q, K, V) and concatenating the results.
* **Positional Encoding:** Information added to input embeddings to provide the model with knowledge of the token's position in the sequence.
* **Encoder (Transformer):** The part of the Transformer that processes the input sequence to build contextualized representations. Contains self-attention and feed-forward layers.
* **Decoder (Transformer):** The part of the Transformer that generates the output sequence, using masked self-attention and cross-attention to the encoder output.
* **Masked Self-Attention:** Self-attention where future positions in the sequence are masked out, so each position can attend only to itself and earlier positions. Used in decoders to prevent the model from looking ahead at tokens it has not yet generated.
* **Cross-Attention:** Attention mechanism where queries come from one sequence (e.g., decoder) and keys/values come from another (e.g., encoder output).
* **Add & Norm:** A layer combining a residual (shortcut) connection with layer normalization, used to stabilize training in deep networks.
* **Layer Normalization:** Normalizes activations across the feature dimension for each sequence element independently.
* **Position-wise Feed-Forward Network (FFN):** An MLP applied independently to each position in the sequence within a Transformer layer.
* **Softmax:** A function that converts a vector of raw scores into a probability distribution (values between 0 and 1 that sum to 1).
* **Embedding:** A learned vector representation of a discrete item (like a word or token).
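
Several of the glossary terms above (Softmax, Masked Self-Attention, Layer Normalization, Add & Norm) can be illustrated together in a few lines. The following is a rough sketch under simplifying assumptions: queries, keys, and values are all set to the input `x` (no learned projections), layer normalization omits its learned scale and shift parameters, and the input is random. The function names are illustrative and do not come from any library.

```python
import numpy as np

def softmax(x, axis=-1):
    # Convert raw scores into a probability distribution along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    # Normalize each position (row) across its feature dimension.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def masked_self_attention(x):
    # Simplification: Q = K = V = x, with no learned projections.
    seq_len, d_k = x.shape
    scores = x @ x.T / np.sqrt(d_k)
    # True above the diagonal marks "future" positions to hide.
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    weights = softmax(scores, axis=-1)   # masked entries become exactly 0
    return weights @ x, weights

def add_and_norm(x, sublayer_out):
    # Residual connection followed by layer normalization.
    return layer_norm(x + sublayer_out)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))              # 4 positions, 8 features (hypothetical)
attn_out, weights = masked_self_attention(x)
y = add_and_norm(x, attn_out)

print(np.round(weights, 2))              # lower-triangular: no attention to the future
print(y.shape)                           # (4, 8)
```

The first row of `weights` has a single nonzero entry, since the first position can attend only to itself.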


## Additional Resources
---
* **Original Paper:** Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). [Attention Is All You Need](https://arxiv.org/abs/1706.03762). Advances in Neural Information Processing Systems, 30.
* **Illustrated Transformer:** Jay Alammar's blog post provides excellent visualizations and explanations: [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/)
* **Illustrated BERT:** Jay Alammar's explanation of BERT (an Encoder-only Transformer): [The Illustrated BERT, ELMo, and co.](http://jalammar.github.io/illustrated-bert/)
* **Illustrated GPT-2:** Jay Alammar's explanation of GPT-2 (a Decoder-only Transformer): [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/)
* **Hugging Face Blog/Course:** Hugging Face provides extensive resources and tutorials on Transformers: [Hugging Face Blog](https://huggingface.co/blog), [Hugging Face Course](https://huggingface.co/course)


## License Information
---
MIT License

Copyright (c) 2025 Pate Motter

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.