# CS336 Basics Playground

This notebook is a scratchpad for experimenting with the CS336 basics implementation; it currently compares two implementations of rotary positional embeddings (RoPE).

## Quick Start

Run this notebook inside the project's uv environment so the installed dependencies match. Start it with:
```bash
uv run jupyter lab playground.ipynb
```


In [2]:
# Import basic modules to verify environment
import sys
import torch
import math
import numpy as np
import torch.nn as nn

# Report interpreter/library versions and GPU availability so outputs can be tied to an environment.
print("\n".join([
    f"Python: {sys.version}",
    f"PyTorch: {torch.__version__}",
    f"NumPy: {np.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
]))


Python: 3.13.5 (main, Jul  8 2025, 20:55:53) [Clang 20.1.4 ]
PyTorch: 2.6.0
NumPy: 2.3.2
CUDA available: False


In [52]:
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device: torch.device | None = None):
        super().__init__()
        rotation_matrices = []
        dtype = torch.float32
        for i in range(max_seq_len):
            rotation_values = torch.tensor([i / (math.pow(theta, (2 * (k // 2)) / float(d_k))) for k in range(d_k)], dtype=dtype)
            print(rotation_values)
            cos_values = torch.cos(rotation_values)
            mask = torch.arange(rotation_values.shape[0], dtype=dtype) % 2 == 0
            sin_values = torch.sin(torch.where(mask, rotation_values, 0))[:-1]
            rotation_matrices.append(torch.diag(cos_values) + torch.diag(sin_values, diagonal=-1) + torch.diag(-sin_values, diagonal=1))
        self.register_buffer('rotary_matrix', torch.stack(rotation_matrices, dim=0).to(device))
    
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        rotary_matrices = self.rotary_matrix[token_positions]
        return einsum(rotary_matrices, x, '... seq_len d_k d_k, ... seq_len d_k -> ... seq_len d_k')

In [54]:
# Dense-matrix RoPE with theta=2, d_k=6, max_seq_len=2. It gets its own name so the
# identifier `rope` is not reused for a different class later in the notebook.
rope_dense = RotaryPositionalEmbedding(2, 6, 2)
rope_dense.rotary_matrix

tensor([0., 0., 0., 0., 0., 0.])
tensor([1.0000, 1.0000, 0.7937, 0.7937, 0.6300, 0.6300])


TypeError: RotaryPositionalEmbedding.forward() missing 1 required positional argument: 'token_positions'

In [48]:
class RotaryPositionalEmbeddingCorrect(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device: torch.device | None = None):
        super().__init__()
        dtype = torch.float32

        # positions: [0, 1, ..., max_seq_len-1]
        positions = torch.arange(max_seq_len, device=device, dtype=dtype).unsqueeze(1)

        # pair indices: [0, 1, ..., d_k/2 - 1]
        pair_indices = torch.arange(0, d_k // 2, device=device, dtype=dtype)

        # inverse frequencies theta^{-2i/d_k}
        inv_freq = theta ** (-2.0 * pair_indices / float(d_k))
        

        # angles = position * inv_freq  -> shape: (max_seq_len, d_k/2)
        angles = positions * inv_freq
        print(angles)
        cos = torch.cos(angles)
        sin = torch.sin(angles)

        # Cache trig tables per position for each pair
        self.register_buffer('cos_table', cos)
        self.register_buffer('sin_table', sin)
        
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        # Fetch per-position cos/sin for each feature pair
        cos_table_t = cast(torch.Tensor, self.cos_table)
        sin_table_t = cast(torch.Tensor, self.sin_table)
        cos = cos_table_t[token_positions].to(dtype=x.dtype)
        sin = sin_table_t[token_positions].to(dtype=x.dtype)

        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]

        rotated_even = x_even * cos - x_odd * sin
        rotated_odd = x_even * sin + x_odd * cos

        # Interleave even/odd back into last dimension
        out = torch.empty_like(x)
        out[..., 0::2] = rotated_even
        out[..., 1::2] = rotated_odd
        return out

In [49]:
# Element-wise RoPE with a tiny configuration; inspect the cached cosine table.
rope = RotaryPositionalEmbeddingCorrect(theta=2, d_k=6, max_seq_len=2)
rope.cos_table

tensor([[0.0000, 0.0000, 0.0000],
        [1.0000, 0.7937, 0.6300]])


tensor([[1.0000, 1.0000, 1.0000],
        [0.5403, 0.7012, 0.8081]])

In [39]:
# Look up the registered sine buffer by name (same tensor as `rope.sin_table`).
rope.get_buffer("sin_table")

tensor([[0.0000, 0.0000],
        [0.8415, 0.8415]])

In [7]:
# Add a leading axis via None-indexing: (5,) -> (1, 5).
torch.arange(5)[None, :]

torch.Size([1, 5])