<img src="https://hilpisch.com/tpq_logo.png" alt="The Python Quants" width="35%" align="right" border="0"><br>


# Deep Learning Basics with PyTorch

**Dr. Yves J. Hilpisch with GPT-5**


# Interactive Attention Visualizations (Chapter 14)

Self-contained widgets to build intuition for attention:
- Pick a token to visualize its attention row.
- Switch masks (none / padding / causal).
- Adjust temperature.
- Toggle aggregate over heads or inspect a single head.
- See score → softmax → mix and a paint-mixing analogy.

## Overview

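This notebook provides a concise, hands-on walkthrough of attention mechanics and accompanies Chapter 14 of *Deep Learning Basics with PyTorch*. The demo itself runs on NumPy toy data with ipywidgets controls.
Use it as a companion to the chapter: run each cell, read the short notes,
and try small variations to build intuition.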

Tips:
- Run cells top to bottom; restart kernel if state gets confusing.
- Prefer small, fast experiments; iterate quickly and observe outputs.
- Keep an eye on shapes, dtypes, and devices when using PyTorch.


In [1]:
# If ipywidgets is missing (usually fine on Colab), uncomment and run:
# !pip -q install ipywidgets
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
plt.rcParams["figure.dpi"] = 120  # crisper plots
# Print arrays with 3 decimals and without scientific notation for small values.
np.set_printoptions(precision=3, suppress=True)

In [2]:
import numpy as np


def softmax(x, axis=-1, temp=1.0):
    """Return the temperature-scaled softmax of ``x`` along ``axis``.

    Args:
        x: Array-like of scores. Lists and tuples are accepted; they are
            converted with ``np.asarray`` (existing arrays keep their dtype).
        axis: Axis along which the outputs sum to one.
        temp: Temperature the scores are divided by. Values below ``1e-8``
            (including zero and negatives) are clamped to ``1e-8``.

    Returns:
        Array with the same shape as ``x`` whose entries along ``axis`` are
        non-negative and sum to one.
    """
    # Coerce first so plain Python sequences support the division below.
    x = np.asarray(x) / max(float(temp), 1e-8)
    # Subtract the per-slice maximum so np.exp never overflows.
    x = x - np.max(x, axis=axis, keepdims=True)
    ex = np.exp(x)
    return ex / np.sum(ex, axis=axis, keepdims=True)


def make_toy_data(T=8, d_model=8, heads=4, seed=42):
    """Build random per-head query/key/value arrays for the attention demos.

    Token embeddings and three square projection matrices are drawn from a
    seeded standard-normal generator. Each projected matrix is then split
    into heads, giving a tuple ``(qh, kh, vh)`` whose members are shaped
    ``(heads, T, d_model // heads)``.
    """
    gen = np.random.default_rng(seed)
    # Draw order fixes the random stream: tokens first, then W_q, W_k, W_v.
    embeddings = gen.normal(size=(T, d_model))
    projections = [gen.normal(size=(d_model, d_model)) for _ in range(3)]
    projected = [embeddings @ weight for weight in projections]
    head_dim = d_model // heads

    def split_heads(matrix):
        # (T, d_model) -> (T, heads, head_dim) -> (heads, T, head_dim)
        return np.swapaxes(matrix.reshape(T, heads, head_dim), 0, 1)

    return tuple(split_heads(matrix) for matrix in projected)


def rgb_from_values(values, seed=0):
    """Turn value vectors ``(heads, T, d_head)`` into RGB triples per token.

    A seeded random ``(d_head, 3)`` projection maps each value vector to three
    channels. The result is min-max scaled over the whole array into
    ``[0, 1]`` and has shape ``(heads, T, 3)``.
    """
    # The three-way unpack also enforces that the input is a 3-D array.
    _, _, head_dim = values.shape
    to_rgb = np.random.default_rng(seed).normal(size=(head_dim, 3))
    raw = values @ to_rgb
    lowest = raw.min()
    # The small epsilon keeps the division finite when all entries are equal.
    span = raw.max() - lowest + 1e-8
    return (raw - lowest) / span


def make_masks(length, mask_type='none', pad_len=None):
    """Return a boolean ``(length, length)`` mask; True marks an allowed pair.

    ``'causal'`` yields the lower triangle (diagonal included). ``'padding'``
    keeps only the top-left ``pad_len`` x ``pad_len`` square, with ``pad_len``
    truncated to an int and clamped to ``[0, length]``. Every other case,
    including ``'padding'`` with ``pad_len=None`` and unrecognised mask
    types, yields an all-True mask.
    """
    allow_all = np.ones((length, length), dtype=bool)
    if mask_type == 'causal':
        return np.tril(allow_all)
    if mask_type != 'padding' or pad_len is None:
        return allow_all

    keep = min(length, int(pad_len))
    keep = max(0, keep)
    square = np.zeros_like(allow_all)
    square[:keep, :keep] = True
    return square


def attention_weights(Qh, Kh, mask=None, temp=1.0):
    """Compute scaled dot-product attention scores and weights per head.

    Args:
        Qh: Query array shaped ``(heads, length, d_head)``.
        Kh: Key array with the same layout as ``Qh``.
        mask: Optional boolean array broadcastable to ``(length, length)``
            after a leading head axis is added; True marks key positions a
            query may attend to. ``None`` allows every pair.
        temp: Softmax temperature; values below ``1e-8`` are clamped to
            ``1e-8``.

    Returns:
        Tuple ``(scores, weights)``. ``scores`` are the dot products divided
        by ``sqrt(d_head)``, with masked entries set to ``-1e9``. ``weights``
        are the temperature-scaled softmax of ``scores`` over the key axis.
        Masked keys get exactly zero weight at any temperature, and a query
        row with no allowed key is all zeros rather than a uniform
        distribution over masked keys.
    """
    # The three-way unpack also enforces a 3-D query array.
    _, _, d_head = Qh.shape
    scores = np.matmul(Qh, np.transpose(Kh, (0, 2, 1))) / np.sqrt(d_head)
    if mask is None:
        allowed = np.ones(scores.shape, dtype=bool)
    else:
        allowed = np.broadcast_to(
            np.asarray(mask, dtype=bool)[None, :, :], scores.shape
        )
        scores = np.where(allowed, scores, -1e9)

    scaled = scores / max(float(temp), 1e-8)
    # Shift by the row maximum for numerical stability before exponentiating.
    scaled = scaled - np.max(scaled, axis=-1, keepdims=True)
    # Zero masked entries explicitly: the -1e9 fill is itself divided by the
    # temperature, so relying on underflow alone would leak weight to masked
    # keys at very high temperatures and on fully masked rows.
    exps = np.where(allowed, np.exp(scaled), 0.0)
    totals = exps.sum(axis=-1, keepdims=True)
    weights = np.divide(exps, totals, out=np.zeros_like(exps), where=totals > 0)
    return scores, weights

In [3]:
# Demo dimensions: sequence length, model width, and number of attention heads.
T, D_MODEL, HEADS = 8, 8, 4

# Fixed seeds keep the toy Q/K/V arrays and the token colors reproducible.
QH, KH, VH = make_toy_data(seed=7, heads=HEADS, d_model=D_MODEL, T=T)
COLORS = rgb_from_values(VH, seed=3)

VBox(children=(HBox(children=(IntSlider(value=2, continuous_update=False, description='token i', max=7), IntSl…

In [4]:
# Query token whose attention row is displayed (0-based index).
token_slider = widgets.IntSlider(
    description='token i',
    value=2, min=0, max=T - 1, step=1,
    continuous_update=False,
)

# Softmax temperature on a log scale from 10**-1 to 10**1.
temp_slider = widgets.FloatLogSlider(
    description='Temp',
    value=1.0, base=10, min=-1, max=1, step=0.05,
    continuous_update=False,
)

# Which attention mask to apply before the softmax.
mask_selector = widgets.ToggleButtons(
    description='Mask',
    options=['none', 'padding', 'causal'],
    value='none',
)

# Size of the unmasked top-left square used by the padding mask.
pad_len_slider = widgets.IntSlider(
    description='pad len',
    value=T, min=1, max=T, step=1,
    continuous_update=False,
)

# Either average over all heads or inspect one head (1-based in the UI).
aggregate_checkbox = widgets.Checkbox(description='Aggregate heads', value=False)
head_slider = widgets.IntSlider(
    description='Head',
    value=1, min=1, max=HEADS, step=1,
    continuous_update=False,
)

# Two rows of controls: token/head selection on top, softmax/mask settings below.
controls = widgets.VBox(
    [
        widgets.HBox([token_slider, head_slider, aggregate_checkbox]),
        widgets.HBox([temp_slider, mask_selector, pad_len_slider]),
    ]
)
display(controls)

In [5]:
def draw(token, head, aggregate, temp, mask_kind, pad_len):
    """Plot the paint-mixing view of one token's attention row.

    Builds the requested mask, computes attention weights from the global
    ``QH``/``KH`` arrays, and draws one color swatch per key token with its
    weight, plus a larger swatch showing the weighted color mix. With
    ``aggregate`` set, weights and colors are averaged over heads; otherwise
    the 1-based ``head`` selects a single head.
    """
    allowed = make_masks(T, mask_type=mask_kind, pad_len=pad_len)
    weights = attention_weights(QH, KH, mask=allowed, temp=temp)[1]

    if not aggregate:
        # The slider is 1-based; array indexing is 0-based.
        idx = int(head) - 1
        row = weights[idx, token, :]
        palette = COLORS[idx]
        head_label = str(head)
    else:
        row = weights[:, token, :].mean(axis=0)
        palette = COLORS.mean(axis=0)
        head_label = 'avg'

    # The attention output in "paint" terms: weighted blend of token colors.
    blended = row @ palette

    fig, ax = plt.subplots(figsize=(10, 3))
    for j in range(T):
        swatch = plt.Rectangle((0.08, j + 0.15), 0.55, 0.7, color=palette[j], ec='#333333')
        ax.add_patch(swatch)
        w = row[j]
        ax.text(0.66, j + 0.5, f'j = {j} w = {w:.2f}', va='center', ha='left', fontsize=8)

    result_box = plt.Rectangle((0.08, T + 0.25), 0.85, 0.9, color=blended, ec='#000000', lw=1.5)
    ax.add_patch(result_box)
    ax.text(0.5, T + 0.7, 'mixed output color', va='center', ha='center', fontsize=9, color='#111111')

    # Bare canvas: fixed extents, no ticks, no frame.
    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, T + 1.3)
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ax.spines.values():
        side.set_visible(False)

    title = f'i = {token}, head = {head_label}, mask = {mask_kind}, T° = {temp:.2f}'
    fig.suptitle(title)
    plt.show()


# Bind each control to the matching draw() argument; the figure is re-rendered
# into an Output widget whenever a control value changes.
interactive_output = widgets.interactive_output(
    draw,
    dict(
        token=token_slider,
        head=head_slider,
        aggregate=aggregate_checkbox,
        temp=temp_slider,
        mask_kind=mask_selector,
        pad_len=pad_len_slider,
    ),
)
display(interactive_output)

Output()

## Exercises

1. Use the widgets to explore heads/temperature; take screenshots and annotate.
2. Create two prompts that elicit different attention patterns and explain why.


<img src="https://hilpisch.com/tpq_logo.png" alt="The Python Quants" width="35%" align="right" border="0"><br>
