# Understanding Attention vs Non-Attention Processing
## MIS 769 - Advanced Big Data Analytics
### Prof. Richard Young (ryoung@unlv.edu)

This notebook shows the difference between:
- **Non-attention processing**: each token is transformed independently
- **Attention processing**: each token can use information from all other tokens

You will run:
1. A toy non-attention layer
2. A toy attention layer
3. A real BERT attention example
4. A side-by-side visualization


## What To Look For

**Goal:** Compare independent token processing (non-attention) with context-aware token processing (attention).

As you run this notebook, focus on these checks:
1. In the **non-attention output**, each token row is transformed without using other token rows.
2. In the **attention weights heatmap**, each row shows how one token distributes focus across all tokens (row values sum to ~1).
3. In the **attention output**, token representations change based on surrounding tokens.
4. In the **BERT [CLS] attention plot**, look for which tokens receive stronger attention and discuss why those tokens might matter semantically.

Teaching takeaway: attention is the mechanism that injects context into token representations.


In [None]:
# Colab/Jupyter setup
import importlib.util
import logging
import os
import subprocess
import sys

# Public models do not need HF_TOKEN; disable implicit token lookup noise.
os.environ["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = "1"
logging.getLogger("huggingface_hub.utils._http").setLevel(logging.ERROR)

# Third-party packages this notebook may have to install itself.
# torch and numpy are imported further down without a check, so they must
# already be present in the environment.
missing = [
    pkg
    for pkg in ["transformers", "seaborn", "matplotlib"]
    if importlib.util.find_spec(pkg) is None
]

if missing:
    # pip treats already-satisfied requirements as a no-op, so the full list is
    # passed every time; this also applies the transformers version floor.
    subprocess.check_call([
        sys.executable,
        "-m",
        "pip",
        "install",
        "-q",
        "transformers>=4.45.0",
        "seaborn",
        "matplotlib",
    ])
    # Packages installed while the kernel is running may not be visible to the
    # import system until its finder caches are cleared.
    importlib.invalidate_caches()
    print("Installed:", ", ".join(missing))
else:
    print("Dependencies already installed.")

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# One shared seed for both libraries so results are reproducible across runs.
SEED = 103
torch.manual_seed(SEED)
np.random.seed(SEED)
print(f"Seed set to {SEED}")



In [None]:
# Example 1: Non-Attention vs Example 2: Attention
class NonAttentionExample(nn.Module):
    """Toy position-wise layer: one shared linear map applied to each token row.

    The layer works on the last dimension only, so no input row influences
    another row of the output. This is the "non-attention" baseline.
    """

    def __init__(self, input_dim=4):
        super().__init__()
        # Square projection: output feature size equals input feature size.
        self.linear = nn.Linear(in_features=input_dim, out_features=input_dim)

    def forward(self, x):
        """Return ``self.linear`` applied to ``x``."""
        projected = self.linear(x)
        return projected


class SimpleAttentionExample(nn.Module):
    """Toy single-head scaled dot-product self-attention.

    Queries, keys and values are each produced by their own linear projection
    of the same input, so every output row is a weighted mix of all value rows.
    """

    def __init__(self, input_dim=4):
        super().__init__()
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(seq_len, input_dim)`` or, with leading batch
                axes, ``(..., seq_len, input_dim)``.

        Returns:
            ``(output, attention_weights)`` where ``output`` has the same shape
            as ``x`` and ``attention_weights`` has shape
            ``(..., seq_len, seq_len)``; each row of the weights sums to 1.
        """
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Transpose only the last two axes so any leading batch axes stay in
        # place; for plain 2-D input this is the ordinary matrix transpose.
        scores = torch.matmul(Q, K.transpose(-2, -1))
        # Scale by sqrt(d_k) to keep the softmax inputs in a moderate range.
        scores = scores / (K.size(-1) ** 0.5)
        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights


# Toy input: 4 tokens (rows), each described by 4 features (columns).
x = torch.tensor(
    [
        [1.0, 0.5, 0.8, 0.3],
        [0.2, 0.9, 0.4, 0.7],
        [0.6, 0.3, 0.2, 0.8],
        [0.1, 0.5, 0.9, 0.4],
    ],
    dtype=torch.float32,
)

# Example 1: the same linear layer is applied to every row on its own.
print("=== Non-Attention Example ===")
non_attention_model = NonAttentionExample()
non_attention_output = non_attention_model(x)
print("Input tokens:", x, sep="\n")
print("\nOutput (independent processing):", non_attention_output, sep="\n")

# Example 2: every output row is a softmax-weighted mix over all rows.
print("\n=== Attention Example ===")
attention_model = SimpleAttentionExample()
attention_output, attention_weights = attention_model(x)
print("Attention weights (token-to-token influence):", attention_weights, sep="\n")
print("\nOutput (context-aware processing):", attention_output, sep="\n")


In [None]:
# Example 3: Real BERT attention on a sentence
from transformers import AutoModel, AutoTokenizer


# (tokenizer, model) pairs keyed by model_id, so repeated calls in this notebook
# reuse the loaded weights instead of reloading them on every call.
_BERT_CACHE = {}


def _load_bert_with_attention(model_id):
    """Return a cached ``(tokenizer, model)`` pair for ``model_id``, loading it on first use."""
    if model_id not in _BERT_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=False)
        model = AutoModel.from_pretrained(
            model_id,
            token=False,
            attn_implementation="eager",  # Required to ensure output_attentions works reliably.
        )
        model.eval()
        _BERT_CACHE[model_id] = (tokenizer, model)
    return _BERT_CACHE[model_id]


def get_bert_cls_attention(sentence, model_id="bert-base-uncased", layer=0, head=0):
    """Print and return the attention that [CLS] pays to each token of ``sentence``.

    Args:
        sentence: Text to tokenize and run through the model.
        model_id: Hugging Face model identifier; loaded once, then reused.
        layer: Index into ``outputs.attentions`` (0 selects the first layer).
        head: Index of the attention head within that layer.

    Returns:
        ``(tokens, cls_attention)``: the token strings produced by the
        tokenizer, and a NumPy array holding one attention weight per token,
        taken from the row at sequence position 0.

    Raises:
        RuntimeError: If the model returns no attention tensors.
    """
    tokenizer, model = _load_bert_with_attention(model_id)

    inputs = tokenizer(sentence, return_tensors="pt")
    # Move the inputs to whatever device the model parameters live on.
    model_device = next(model.parameters()).device
    inputs = {k: v.to(model_device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    if outputs.attentions is None:
        raise RuntimeError(
            "No attention tensors were returned. Ensure attn_implementation='eager' and rerun the cell."
        )

    layer_attention = outputs.attentions[layer]  # [batch, heads, seq, seq]
    # Batch item 0, the selected head, query position 0 (the [CLS] token).
    cls_attention = layer_attention[0, head, 0].detach().float().cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].detach().cpu())

    print("\n=== BERT Attention Example ===")
    print(f"Sentence: {sentence}")
    print(f"Tokens: {tokens}")
    print("\n[CLS] attention to each token:")
    for token, weight in zip(tokens, cls_attention):
        print(f"{token}: {weight:.4f}")

    return tokens, cls_attention


# Run the BERT demo on a short example sentence. The returned tokens and [CLS]
# attention weights are kept as globals and passed to the comparison plot.
sentence = "The cat sat on the mat."
bert_tokens, bert_cls_attention = get_bert_cls_attention(sentence)



In [None]:
def visualize_attention_comparison(non_attention_output, attention_weights, bert_tokens, bert_attention):
    """Draw a three-panel comparison figure and print a short summary.

    Panels, left to right: heatmap of the non-attention output, heatmap of the
    toy attention weights, and a bar chart of BERT [CLS] attention per token.

    Args:
        non_attention_output: 2-D tensor, one row per token and one column per feature.
        attention_weights: 2-D tensor of attention weights (rows attend to columns).
        bert_tokens: Sequence of token strings for the bar chart.
        bert_attention: One attention weight per entry of ``bert_tokens``.

    Raises:
        ValueError: If ``bert_tokens`` and ``bert_attention`` differ in length.
    """
    if len(bert_tokens) != len(bert_attention):
        raise ValueError(
            f"bert_tokens has {len(bert_tokens)} entries but bert_attention has {len(bert_attention)}."
        )

    non_attention_values = non_attention_output.detach().cpu().numpy()
    weight_values = attention_weights.detach().cpu().numpy()

    # Build tick labels from the actual array shapes so inputs with any number
    # of tokens or features are labelled correctly.
    n_tokens, n_features = non_attention_values.shape
    n_queries, n_keys = weight_values.shape
    feature_labels = [f"F{i}" for i in range(1, n_features + 1)]
    token_labels = [f"Token {i}" for i in range(1, n_tokens + 1)]
    query_labels = [f"Token {i}" for i in range(1, n_queries + 1)]
    key_labels = [f"Token {i}" for i in range(1, n_keys + 1)]

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 6))

    sns.heatmap(
        non_attention_values,
        ax=ax1,
        cmap="RdBu_r",
        center=0,
        annot=True,
        fmt=".2f",
        xticklabels=feature_labels,
        yticklabels=token_labels,
    )
    ax1.set_title("Non-Attention Processing\n(Independent Token Processing)")

    sns.heatmap(
        weight_values,
        ax=ax2,
        cmap="viridis",
        annot=True,
        fmt=".2f",
        xticklabels=key_labels,
        yticklabels=query_labels,
    )
    ax2.set_title("Attention Weights\n(How Tokens Influence Each Other)")

    bars = ax3.bar(range(len(bert_tokens)), bert_attention, color="skyblue")
    ax3.set_xticks(range(len(bert_tokens)))
    ax3.set_xticklabels(bert_tokens, rotation=45, ha="right")
    ax3.set_title("BERT [CLS] Attention\n(Real-world Example)")
    ax3.set_ylabel("Attention Score")

    # Write each bar's value just above it.
    for rect in bars:
        height = rect.get_height()
        ax3.text(
            rect.get_x() + rect.get_width() / 2.0,
            height,
            f"{height:.3f}",
            ha="center",
            va="bottom",
            fontsize=8,
        )

    plt.tight_layout()
    plt.show()

    # Each softmax row sums to 1, so the plain mean of the weight matrix is
    # always 1 / n_keys and carries no information. Report instead the average
    # share of attention a token spends on tokens other than itself.
    cross_token_share = (weight_values.sum() - np.trace(weight_values)) / n_queries

    print("\nKey Observations:")
    print("1. Non-Attention: each token is transformed independently.")
    print("2. Attention: each token blends information from all tokens.")
    print(f"3. Mean share of attention each token gives to other tokens: {cross_token_share:.3f}")
    print(f"4. Strongest [CLS] target token: {bert_tokens[int(np.argmax(bert_attention))]}")


# Example 4: side-by-side figure built from the toy layer outputs and the BERT
# [CLS] attention computed in the earlier cells.
visualize_attention_comparison(
    non_attention_output=non_attention_output,
    attention_weights=attention_weights,
    bert_tokens=bert_tokens,
    bert_attention=bert_cls_attention,
)


## [CLS] Token (Classification)
- Added at the start of every input sequence
- Aggregates information from the entire sequence
- Used for classification tasks as a sequence summary

## [SEP] Token (Separator)
- Added at the end of each sequence (and between the two segments of a sentence-pair input)
- Separates sentence segments
- Helps the model understand sequence boundaries


In [None]:
# Try your own sentence: edit custom_sentence and re-run this cell.
# Results go into separate *_custom variables, so the values used by the
# comparison plot are not overwritten.
custom_sentence = "Students learn faster when examples are visual and interactive."
bert_tokens_custom, bert_cls_attention_custom = get_bert_cls_attention(custom_sentence)
