In [None]:
import time
import random
import math
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from datasets import load_dataset
import tiktoken

print('All imports done.')

All imports done.


In [None]:
# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
print(f"Using device: {device}")

Using device: cuda


### Hyperparameters

For Data
1. block_size
2. train_subset_size
3. batch_size
4. tinystories_weight


In [None]:
# Data parameters
block_size = 128  # Sequence length for training
train_subset_size = 20000  # Number of samples to use from TinyStories
batch_size = 32

tinystories_weight = 0.5  # Proportion of TinyStories in mixed dataset (0.0 to skip)

# Training parameters
epochs = 5
learning_rate = 1e-3
log_steps = 100
sample_interval = 30  # seconds between text samples during training
max_steps_per_epoch = None  # Set to an int for quick tests

# Model architecture. These values describe the network that is trained in this
# notebook (d_model=512, n_heads=4, n_blocks=6); they must stay in agreement with
# the model that is actually constructed, otherwise saved weights cannot be
# loaded back into a model rebuilt from them.
transformer_d_model = 512
transformer_n_heads = 4
transformer_n_blocks = 6
transformer_max_seq_len = block_size  # positional table covers exactly one training block


# Prompt for text generation
default_prompt = "Once upon a time"

In [None]:
# Tokenizer: GPT-2 BPE vocabulary from tiktoken; its size fixes the model's vocab.
enc = tiktoken.get_encoding("gpt2")
vocab_size = enc.n_vocab

# Training sequences: the first `train_subset_size` TinyStories samples, each
# tokenized and cut to at most `block_size` tokens. Samples that tokenize to
# nothing are dropped.
tinystories_seqs = []
if not tinystories_weight > 0.0:
    print("TinyStories weight=0 => skipping TinyStories.")
else:
    print(f"Loading TinyStories from huggingface with weight={tinystories_weight}...")
    dataset = load_dataset("roneneldan/TinyStories", split="train")
    train_dataset = dataset.select(range(train_subset_size))
    for sample in train_dataset:
        tokens = enc.encode(sample['text'])[:block_size]
        if tokens:
            tinystories_seqs.append(tokens)
    print(f"TinyStories sequences: {len(tinystories_seqs)}")

# No second corpus is mixed in; the dataset class falls back to TinyStories only.
other_seqs = []

Loading TinyStories from huggingface with weight=0.5...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00004-2d5a1467fff108(…):   0%|          | 0.00/249M [00:00<?, ?B/s]

data/train-00001-of-00004-5852b56a2bd28f(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/train-00002-of-00004-a26307300439e9(…):   0%|          | 0.00/246M [00:00<?, ?B/s]

data/train-00003-of-00004-d243063613e5a0(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/validation-00000-of-00001-869c898b5(…):   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

TinyStories sequences: 20000


In [None]:
class LinearSequenceDataset(torch.utils.data.Dataset):
    """Dataset that hands out pre-tokenized sequences in storage order.

    Two pools are held: TinyStories sequences and "other" sequences. Each pool
    has its own read cursor that advances by one on every draw and wraps back
    to the start once the pool has been read through.

    The ``idx`` passed to ``__getitem__`` is not used: which sequence comes
    back depends only on the cursors and, when both pools are non-empty, on a
    random draw compared against ``p_tiny``.

    Args:
        tinystories_seqs: list of token-id lists from TinyStories.
        other_seqs: list of token-id lists from any other corpus.
        p_tiny: probability of drawing from TinyStories when both pools exist.

    Raises:
        ValueError: if both pools are empty.
    """

    def __init__(self, tinystories_seqs, other_seqs, p_tiny: float):
        super().__init__()
        self.tinystories_seqs = tinystories_seqs
        self.other_seqs = other_seqs
        self.p_tiny = p_tiny

        n_tiny = len(self.tinystories_seqs)
        n_other = len(self.other_seqs)
        self.has_tinystories = n_tiny > 0
        self.has_other = n_other > 0
        self.total_length = n_tiny + n_other

        # Read cursors into each pool.
        self.tiny_stories_count = 0
        self.other_count = 0

        if self.total_length == 0:
            raise ValueError("No data found! Both TinyStories and other sets are empty.")

    def __len__(self):
        """Number of sequences across both pools."""
        return self.total_length

    def __getitem__(self, idx):
        """Return the next sequence as a 1-D long tensor (``idx`` is ignored)."""
        # The random number is drawn on every call, even when only one pool exists.
        draw = random.random()
        if self.has_tinystories and self.has_other:
            use_tiny = draw < self.p_tiny
        else:
            use_tiny = self.has_tinystories

        if use_tiny:
            seq = self.tinystories_seqs[self.tiny_stories_count]
            self.tiny_stories_count += 1
        else:
            seq = self.other_seqs[self.other_count]
            self.other_count += 1

        # Wrap whichever cursor has run off the end of its pool.
        if self.tiny_stories_count >= len(self.tinystories_seqs):
            self.tiny_stories_count = 0
        if self.other_count >= len(self.other_seqs):
            self.other_count = 0
        return torch.tensor(seq, dtype=torch.long)

In [None]:
def seq_collate_fn(batch):
    """Stack variable-length 1-D token tensors into one (max_len, batch) long tensor.

    Shorter sequences are right-padded with zeros; column ``i`` holds ``batch[i]``.
    """
    lengths = [len(seq) for seq in batch]
    padded = torch.zeros(max(lengths), len(batch), dtype=torch.long)
    for col, seq in enumerate(batch):
        padded[:seq.size(0), col] = seq
    return padded

In [None]:
# Create the training dataset (its DataLoader is built in a later cell).
# The mixing probability is simply the TinyStories weight from the config cell.
p_tiny = tinystories_weight
combined_dataset = LinearSequenceDataset(
    tinystories_seqs=tinystories_seqs,
    other_seqs=other_seqs,
    p_tiny=p_tiny
)

In [None]:
# Sanity check: total number of sequences held by the dataset (TinyStories + other).
combined_dataset.total_length

20000

In [None]:
# Batches are zero-padded to a common length and laid out as (seq_len, batch)
# by seq_collate_fn.
# NOTE(review): LinearSequenceDataset.__getitem__ never reads its index argument,
# so shuffle=True only permutes the indices that are requested, not the order in
# which sequences are actually served.
train_loader = torch.utils.data.DataLoader(
    combined_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=seq_collate_fn
)

print("DataLoader ready. Vocab size:", vocab_size)

DataLoader ready. Vocab size: 50257


In [None]:
### Validation
# Held-out data: the `train_subset_size // 10` samples of the train split that
# come directly after the slice used for training, tokenized and truncated the
# same way as the training sequences.
val_subset_size = train_subset_size // 10  # 2000
tinystories_val_seqs = []
print(f"Loading TinyStories val subset ({train_subset_size}:{train_subset_size + val_subset_size})...")
dataset_val = dataset.select(range(train_subset_size, train_subset_size + val_subset_size))  # 20000..21999
for sample in dataset_val:
    tokens = enc.encode(sample["text"])[:block_size]
    if tokens:
        tinystories_val_seqs.append(tokens)
print(f"TinyStories VAL sequences: {len(tinystories_val_seqs)}")

# Validation dataset: TinyStories only, so the mixing probability is never
# consulted; it is passed along just to satisfy the constructor.
val_p_tiny = tinystories_weight
val_dataset = LinearSequenceDataset(
    tinystories_seqs=tinystories_val_seqs,
    other_seqs=[],
    p_tiny=val_p_tiny,
)

# Validation loader: same batch size and collation as training, no shuffling.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    collate_fn=seq_collate_fn,
)

print("Validation DataLoader ready. Num val sequences:", len(val_dataset))

Loading TinyStories val subset (20000:22000)...
TinyStories VAL sequences: 2000
Validation DataLoader ready. Num val sequences: 2000


In [None]:
def compute_next_token_loss(logits, tokens_seq):
    """Mean cross-entropy of predicting token t+1 from the logits at position t.

    Args:
        logits: (seq_len, batch, vocab_size) model outputs.
        tokens_seq: (seq_len, batch) token ids the logits were computed from.

    Returns:
        Scalar tensor: cross-entropy averaged over all (seq_len-1) * batch
        positions. Every position counts, including any padding in tokens_seq.
    """
    # Unpacking doubles as a check that tokens_seq is 2-D.
    _seq_len, _batch = tokens_seq.shape
    vocab = logits.size(-1)
    # Drop the last step's logits (nothing follows it) and the first token
    # (nothing predicts it), then flatten time and batch together.
    predictions = logits[:-1].reshape(-1, vocab)
    targets = tokens_seq[1:].reshape(-1)
    return F.cross_entropy(predictions, targets)

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector is divided by sqrt(mean(x^2) + eps), scaled by a learned
    per-feature ``weight`` and, when ``bias`` is True, shifted by a learned
    per-feature ``bias``.

    Args:
        dim: size of the last (feature) dimension.
        eps: constant added under the square root.
        bias: whether to learn an additive bias.
    """

    def __init__(self, dim: int, eps: float = 1e-6, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.use_bias = bias
        if bias:
            self.bias = nn.Parameter(torch.zeros(dim))
        else:
            self.bias = None
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` along its last dimension; the shape is unchanged."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normed = x / torch.sqrt(mean_square + self.eps)
        scaled = normed * self.weight
        if not self.use_bias:
            return scaled
        return scaled + self.bias

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional KV cache.

    Inputs are time-major: ``x`` has shape (T, B, d_model).

    Args:
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
    """
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def _split_heads(self, t: torch.Tensor, T: int, B: int) -> torch.Tensor:
        """Reshape (T, B, d_model) into per-head layout (B, H, T, head_dim)."""
        return t.view(T, B, self.n_heads, self.head_dim).permute(1, 2, 0, 3)  # (B,H,T,D)

    def forward(self, x: torch.Tensor, past_kv=None):
        """Attend causally over ``x`` and, if given, the cached keys/values.

        Args:
            x: (T, B, d_model) input for the T newest positions.
            past_kv: ``None`` for cache-free operation, or a ``(k, v)`` pair of
                cached tensors shaped (B, H, T_past, head_dim). ``(None, None)``
                denotes an empty cache.

        Returns:
            ``out`` of shape (T, B, d_model) when ``past_kv`` is None, otherwise
            ``(out, (k, v))`` where ``k``/``v`` include the new positions.
        """
        T, B, C = x.shape
        q = self._split_heads(self.q_proj(x), T, B)
        k = self._split_heads(self.k_proj(x), T, B)
        v = self._split_heads(self.v_proj(x), T, B)

        if past_kv is not None:
            k_cache, v_cache = past_kv
            if k_cache is not None:
                k = torch.cat([k_cache, k], dim=2)
                v = torch.cat([v_cache, v], dim=2)

        T_total = k.size(2)  # cached positions + new positions
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # (B,H,T,T_total)
        if T > 1:
            # Query i sits at absolute position (T_total - T + i) and may only see
            # keys up to and including that position. Offsetting the diagonal by
            # the cached length keeps multi-token chunks causal when a cache is in
            # use; without a cache T_total == T and this is the usual diagonal=1.
            causal_mask = torch.triu(
                torch.ones(T, T_total, device=x.device, dtype=torch.bool),
                diagonal=T_total - T + 1,
            )
            attn_scores = attn_scores.masked_fill(causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))

        attn_probs = torch.softmax(attn_scores, dim=-1)
        ctx = torch.matmul(attn_probs, v)  # (B,H,T,D)
        ctx = ctx.permute(2, 0, 1, 3).contiguous().view(T, B, C)
        out = self.o_proj(ctx)

        if past_kv is not None:
            return out, (k, v)
        else:
            return out


In [None]:
class TransformerBlock(nn.Module):
    """One pre-norm decoder block.

    Computes ``x + attn(norm(x))`` followed by ``x + mlp(norm(x))``, where the
    MLP is Linear -> SiLU -> Linear with hidden width ``int(mlp_ratio * d_model)``.

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        mlp_ratio: hidden-to-model width ratio of the MLP.
    """

    def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden = int(mlp_ratio * d_model)
        self.attn_norm = RMSNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.mlp_norm = RMSNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden, bias=False),
            nn.SiLU(),
            nn.Linear(hidden, d_model, bias=False),
        )

    def forward(self, x: torch.Tensor, past_kv=None):
        """Apply the block.

        Returns ``x`` when ``past_kv`` is None, otherwise ``(x, new_kv)`` with
        the updated cache returned by the attention layer.
        """
        cached = past_kv is not None
        normed = self.attn_norm(x)
        if cached:
            attn_out, new_kv = self.attn(normed, past_kv=past_kv)
        else:
            attn_out = self.attn(normed)
        x = x + attn_out
        x = x + self.mlp(self.mlp_norm(x))
        return (x, new_kv) if cached else x


In [None]:
class TransformerModel(nn.Module):
    """Decoder-only causal Transformer producing logits for next-token prediction.

    Token and learned absolute position embeddings are summed, passed through
    ``n_blocks`` TransformerBlocks and a final RMSNorm, then projected to the
    vocabulary. The output projection shares its weight with the token embedding.

    Args:
        vocab_size: number of token ids.
        d_model: model width.
        n_heads: attention heads per block.
        n_blocks: number of decoder blocks.
        max_seq_len: size of the position-embedding table.
        mlp_ratio: MLP hidden-to-model width ratio.
    """
    def __init__(self,
                 vocab_size: int = 50257,
                 d_model: int = 512,
                 n_heads: int = 8,
                 n_blocks: int = 6,
                 max_seq_len: int = 2048,
                 mlp_ratio: float = 4.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, mlp_ratio) for _ in range(n_blocks)
        ])
        self.final_norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.token_embed.weight  # Weight tying

    def forward(self, tokens_seq: torch.Tensor, kv_cache=None):
        """Compute next-token logits.

        Args:
            tokens_seq: (T, B) token ids. If T exceeds ``max_seq_len`` only the
                last ``max_seq_len`` positions are used.
            kv_cache: ``None`` for cache-free operation, or a list with one
                ``(k, v)`` entry per block. An empty list or entries of
                ``(None, None)`` denote a fresh cache.

        Returns:
            ``logits`` of shape (T, B, vocab_size) when ``kv_cache`` is None,
            otherwise ``(logits, new_cache)``.

        Raises:
            ValueError: if a non-empty ``kv_cache`` does not have one entry per block.
            IndexError: if cached plus new positions exceed ``max_seq_len``.
        """
        T, B = tokens_seq.shape
        if T > self.max_seq_len:
            tokens_seq = tokens_seq[-self.max_seq_len:]
            T = tokens_seq.shape[0]

        use_cache = kv_cache is not None
        cached_len = 0
        if use_cache:
            if len(kv_cache) == 0:
                # An empty list means "start caching": without one entry per block
                # the zip below would yield nothing and every block would be skipped.
                kv_cache = [(None, None) for _ in self.blocks]
            elif len(kv_cache) != len(self.blocks):
                raise ValueError(
                    f"kv_cache has {len(kv_cache)} entries but the model has "
                    f"{len(self.blocks)} blocks"
                )
            if kv_cache and kv_cache[0][0] is not None:
                cached_len = kv_cache[0][0].size(2)

        if cached_len + T > self.max_seq_len:
            # Checked here so the failure is a clear Python error rather than an
            # out-of-range embedding lookup.
            raise IndexError(
                f"position {cached_len + T - 1} is out of range for "
                f"max_seq_len={self.max_seq_len}"
            )

        pos = torch.arange(cached_len, cached_len + T, device=tokens_seq.device)
        x = self.token_embed(tokens_seq) + self.pos_embed(pos).unsqueeze(1)

        if not use_cache:
            for blk in self.blocks:
                x = blk(x)
            return self.lm_head(self.final_norm(x))

        new_cache = []
        for blk, past in zip(self.blocks, kv_cache):
            x, updated = blk(x, past_kv=past)
            new_cache.append(updated)
        logits = self.lm_head(self.final_norm(x))
        return logits, new_cache

In [None]:
def nucleus_sampling(logits, p=0.95):
    """Sample one token id with top-p (nucleus) sampling.

    The nucleus is the smallest set of highest-probability tokens whose total
    probability reaches ``p``; the token that crosses the threshold is part of
    it. Probabilities outside the nucleus are zeroed, the rest renormalized,
    and one token is drawn.

    Args:
        logits: 1-D tensor of unnormalized scores over the vocabulary.
        p: cumulative-probability threshold.

    Returns:
        The sampled token id as a Python int.
    """
    # Convert logits to probabilities
    prob_dist = torch.softmax(logits, dim=-1)

    # Sort probabilities in descending order and get indices
    sorted_probs, sorted_indices = torch.sort(prob_dist, descending=True)

    # Probability mass that comes strictly before each token in sorted order.
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
    mass_before = cumsum_probs - sorted_probs

    # Keep a token while the mass before it is still short of p. Testing the
    # mass *before* the token (rather than the running total including it)
    # keeps the token that crosses the threshold inside the nucleus.
    cutoff_mask = mass_before < p

    # Always keep the most likely token, so the nucleus is never empty (p <= 0).
    cutoff_mask[0] = True

    # Zero out probabilities beyond the nucleus
    filtered_probs = sorted_probs.clone()
    filtered_probs[~cutoff_mask] = 0.0

    # Renormalize the remaining probabilities
    filtered_probs = filtered_probs / filtered_probs.sum()

    # Sample from the filtered distribution
    sampled_index = torch.multinomial(filtered_probs, num_samples=1).item()

    # Map back to original token index
    chosen_token = sorted_indices[sampled_index].item()

    return chosen_token


In [None]:
def generate_text(model, enc, init_text, max_new_tokens=20, device="cpu",
                  top_p=None,
                  monosemantic_info=None,
                  do_monosemantic=False,
                  use_kv_cache=False):
    """
    Autoregressively extend `init_text` by `max_new_tokens` tokens.

    Without a KV cache the whole context is fed as a (seq_len, 1) tensor at
    every step and the logits of the last position pick the next token. With
    `use_kv_cache=True` (and a model exposing `blocks`) the prompt is pushed
    through the cache once and each step feeds only the newest token.

    Args:
        model: callable as `model(tokens)` -> (seq_len, 1, vocab) logits, and as
            `model(tokens, kv_cache=...)` -> (logits, new_cache) for cached decoding.
        enc: tokenizer providing `encode(str)` and `decode(list[int])`.
        init_text: prompt text.
        max_new_tokens: number of tokens to generate.
        device: device for the input tensors.
        top_p: None for greedy decoding, otherwise the nucleus threshold.
        monosemantic_info: optional data passed to the monosemantic analysis.
        do_monosemantic: annotate each generated token with its neighbors.
        use_kv_cache: enable cached decoding.

    Returns:
        (final_text, annotated_text): the decoded prompt plus continuation, and
        the same text with neighbor annotations after each generated token.
    """
    context_tokens = enc.encode(init_text)
    prompt_len = len(context_tokens)
    annotation_list = []
    cached_decoding = use_kv_cache and hasattr(model, 'blocks')

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            kv_cache = None
            if cached_decoding:
                kv_cache = [(None, None) for _ in range(len(model.blocks))]
                # Prime the cache with every prompt token except the last one.
                # The generation loop feeds the last token itself, so priming
                # with it too would put it into the cache twice.
                for tid in context_tokens[:-1]:
                    tok = torch.tensor([tid], dtype=torch.long, device=device).unsqueeze(1)  # (1,1)
                    _, kv_cache = model(tok, kv_cache=kv_cache)

            for _ in range(max_new_tokens):
                if cached_decoding:
                    # Feed only the newest token and advance the cache.
                    if len(context_tokens) == 0:
                        # No context at all: start from a space token.
                        last_id = enc.encode(" ")[-1]
                    else:
                        last_id = context_tokens[-1]
                    last_token = torch.tensor([last_id], dtype=torch.long, device=device).unsqueeze(1)  # (1,1)
                    logits_seq, kv_cache = model(last_token, kv_cache=kv_cache)
                else:
                    # Full context each step (works for any model).
                    seq_tensor = torch.tensor(context_tokens, dtype=torch.long, device=device).unsqueeze(1)
                    logits_seq = model(seq_tensor)  # (seq_len,1,vocab_size)
                next_logits = logits_seq[-1, 0, :]  # (vocab_size,)

                if top_p is None:
                    chosen_token = torch.argmax(next_logits).item()  # greedy
                else:
                    chosen_token = nucleus_sampling(next_logits, p=top_p)

                context_tokens.append(chosen_token)

                if do_monosemantic and monosemantic_info is not None:
                    neighbors = monosemantic_analysis_for_token(
                        chosen_token, model, monosemantic_info, enc, device=device, top_n=5
                    )
                    annotation_list.append((chosen_token, neighbors))
                else:
                    annotation_list.append((chosen_token, []))
    finally:
        # Restore the caller's train/eval mode even if generation fails.
        model.train(was_training)

    final_text = enc.decode(context_tokens)
    # Slice by the recorded prompt length: a negative slice of -max_new_tokens
    # would drop the whole prompt when max_new_tokens is 0.
    prefix_text = enc.decode(context_tokens[:prompt_len])
    annotated_strs = [prefix_text]
    for (tid, neighs) in annotation_list:
        token_str = enc.decode([tid])
        if neighs:
            neighbor_strs = [f"{enc.decode([x[1]])}" for x in neighs]
            annotated = f"{token_str}[NN={neighbor_strs}]"
        else:
            annotated = token_str
        annotated_strs.append(annotated)

    annotated_text = "".join(annotated_strs)
    return final_text, annotated_text


In [None]:
def train_one_model(model,
                    loader,
                    epochs,
                    model_name,
                    device,
                    lr=1e-3,
                    log_steps=100,
                    sample_interval=30,
                    max_steps_per_epoch=None,
                    enc=None,
                    monosemantic_info=None,
                    prompt="Once upon a",
                    log_csv_path: str = "",
                    log_flush_steps: int = 100,
                    val_loader=None,
                    val_log_csv_path: str = "",
                    val_interval_steps: int = None):
    """
    Train the model and optionally run/record validation aligned to training global_step.

    If val_loader and val_log_csv_path are provided, a validation pass is run
    every `val_interval_steps` training steps (or once per epoch if
    val_interval_steps is None). Validation loss is logged with the same
    schema as training: timestamp,model,epoch,step_in_epoch,global_step,loss
    and model name "Transformer_val".

    Args:
        model: network called as `model(batch_tokens)` -> logits.
        loader: iterable of (seq_len, batch) token tensors; must support len().
        epochs: number of passes over `loader`.
        model_name: label used in console output and the training CSV.
        device: device the batches are moved to.
        lr: Adam learning rate.
        log_steps: print the running average loss every this many steps.
        sample_interval: minimum seconds between text-sample printouts.
        max_steps_per_epoch: stop each epoch after this many steps (None = full epoch).
        enc: tokenizer for text samples; None disables sampling.
        monosemantic_info: forwarded to generate_text for token annotation.
        prompt: prompt used for the text samples.
        log_csv_path: per-step training loss CSV (appended to); "" disables it.
        log_flush_steps: number of buffered rows written to the CSV at a time.
        val_loader: validation batches; None disables validation.
        val_log_csv_path: validation loss CSV (appended to); "" disables validation.
        val_interval_steps: validate every N global steps, or once per epoch if None.

    Log files are closed, and buffered training rows written out, even when
    training is interrupted by an exception.
    """
    csv_header = 'timestamp,model,epoch,step_in_epoch,global_step,loss\n'
    # (description shown in the progress line, top_p, label of the result line)
    sample_modes = (
        ("greedy", None, "Greedy Sample"),
        ("top-p=0.95", 0.95, "Top-p (p=0.95) Sample"),
        ("top-p=1.0", 1.0, "Top-p (p=1.0) Sample"),
    )

    optimizer = optim.Adam(model.parameters(), lr=lr)

    loss_buffer = []       # training rows not yet written to csv_file
    csv_file = None
    val_csv_file = None

    def open_loss_log(path):
        """Open `path` for appending, creating its directory and a header row if needed."""
        log_dir = os.path.dirname(path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
        handle = open(path, 'a', newline='')
        if needs_header:
            handle.write(csv_header)
        return handle

    def run_validation(current_epoch, current_global_step):
        """Run one full validation pass and log a single avg-loss point."""
        if val_loader is None or val_csv_file is None:
            return
        model.eval()
        val_losses = []
        try:
            with torch.no_grad():
                for batch_tokens in val_loader:
                    batch_tokens = batch_tokens.to(device)
                    logits = model(batch_tokens)
                    vloss = compute_next_token_loss(logits, batch_tokens)
                    val_losses.append(vloss.item())
        finally:
            # Always hand the model back in training mode, including when the
            # validation loader is empty or a batch fails.
            model.train()
        if not val_losses:
            return
        avg_val_loss = float(sum(val_losses) / len(val_losses))
        print(f"[Validation] Epoch {current_epoch}, global_step {current_global_step}, avg loss: {avg_val_loss:.4f}")
        # Log as one row aligned with training global_step
        val_csv_file.write(
            f"{time.time()},Transformer_val,{current_epoch},1,{current_global_step},{avg_val_loss}\n"
        )
        val_csv_file.flush()

    def print_samples(current_epoch, current_step):
        """Print one greedy and two top-p continuations of `prompt`."""
        for mode_idx, (description, top_p, label) in enumerate(sample_modes):
            lead = "\n" if mode_idx == 0 else ""
            print(f"{lead}[{model_name}] Generating sample text ({description}) "
                  f"at epoch={current_epoch}, step={current_step}...")
            text, annotated = generate_text(
                model, enc, prompt, max_new_tokens=20, device=device,
                top_p=top_p,
                use_kv_cache=False,
                monosemantic_info=monosemantic_info,
                do_monosemantic=(monosemantic_info is not None)
            )
            print(f" {label}: {text}")
            print(f" Annotated: {annotated}\n")

    try:
        if log_csv_path:
            csv_file = open_loss_log(log_csv_path)
        if val_loader is not None and val_log_csv_path:
            val_csv_file = open_loss_log(val_log_csv_path)

        start_time = time.time()
        next_sample_time = start_time  # first sample is printed on the first step
        global_step = 0

        for epoch in range(1, epochs + 1):
            model.train()
            total_loss = 0.0
            partial_loss = 0.0
            partial_count = 0

            step_in_epoch = 0  # stays 0 if the loader yields nothing
            for step_in_epoch, batch_tokens in enumerate(loader, start=1):
                global_step += 1

                batch_tokens = batch_tokens.to(device)  # (seq_len, batch)

                logits = model(batch_tokens)  # (seq_len, batch, vocab_size)
                loss = compute_next_token_loss(logits, batch_tokens)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                total_loss += loss_value
                partial_loss += loss_value
                partial_count += 1

                # Buffer this step's loss
                if csv_file is not None:
                    loss_buffer.append(
                        f"{time.time()},{model_name},{epoch},{step_in_epoch},{global_step},{loss_value}\n"
                    )
                    if len(loss_buffer) >= log_flush_steps:
                        csv_file.writelines(loss_buffer)
                        csv_file.flush()
                        loss_buffer.clear()

                # Periodic training progress print
                if step_in_epoch % log_steps == 0:
                    avg_part_loss = partial_loss / partial_count
                    print(f"[{model_name}] Epoch {epoch}/{epochs}, "
                          f"Step {step_in_epoch}/{len(loader)} (global step: {global_step}) "
                          f"Partial Avg Loss: {avg_part_loss:.4f}")
                    partial_loss = 0.0
                    partial_count = 0

                # Periodic text sampling
                current_time = time.time()
                if current_time >= next_sample_time and enc is not None:
                    print_samples(epoch, step_in_epoch)
                    next_sample_time = current_time + sample_interval

                # Run validation either every N steps or once per epoch if val_interval_steps is None
                if val_loader is not None and val_log_csv_path:
                    if val_interval_steps is not None:
                        if global_step % val_interval_steps == 0:
                            run_validation(epoch, global_step)
                    else:
                        # No interval given: validate on the last batch of the epoch,
                        # which may be the early stop imposed by max_steps_per_epoch.
                        is_last_batch = (step_in_epoch == len(loader))
                        if max_steps_per_epoch is not None:
                            is_last_batch = is_last_batch or step_in_epoch >= max_steps_per_epoch
                        if is_last_batch:
                            run_validation(epoch, global_step)

                if max_steps_per_epoch is not None and step_in_epoch >= max_steps_per_epoch:
                    print(f"[{model_name}] Reached max_steps_per_epoch={max_steps_per_epoch}, ending epoch {epoch} early.")
                    break

            # An empty loader performs no steps; report nan instead of dividing by zero.
            avg_loss = total_loss / step_in_epoch if step_in_epoch else float('nan')
            print(f"[{model_name}] *** End of Epoch {epoch} *** Avg Loss: {avg_loss:.4f}")
    finally:
        # Write out whatever is still buffered and release both files, whether
        # training finished normally or was interrupted.
        try:
            if csv_file is not None:
                if loss_buffer:
                    csv_file.writelines(loss_buffer)
                csv_file.close()
        finally:
            if val_csv_file is not None:
                val_csv_file.close()

In [None]:
# Transformer model
# NOTE(review): the architecture is spelled out as literals here (d_model=512,
# n_heads=4, n_blocks=6) rather than taken from the transformer_* hyperparameters
# defined in the config cell; only max_seq_len follows the config (block_size).
model = TransformerModel(vocab_size=vocab_size, d_model=512, n_heads=4, n_blocks=6, max_seq_len=block_size)
model = model.to(device)
print(model)

TransformerModel(
  (token_embed): Embedding(50257, 512)
  (pos_embed): Embedding(128, 512)
  (blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (attn_norm): RMSNorm()
      (attn): CausalSelfAttention(
        (q_proj): Linear(in_features=512, out_features=512, bias=False)
        (k_proj): Linear(in_features=512, out_features=512, bias=False)
        (v_proj): Linear(in_features=512, out_features=512, bias=False)
        (o_proj): Linear(in_features=512, out_features=512, bias=False)
      )
      (mlp_norm): RMSNorm()
      (mlp): Sequential(
        (0): Linear(in_features=512, out_features=2048, bias=False)
        (1): SiLU()
        (2): Linear(in_features=2048, out_features=512, bias=False)
      )
    )
  )
  (final_norm): RMSNorm()
  (lm_head): Linear(in_features=512, out_features=50257, bias=False)
)


In [None]:

train_one_model(
    model=model,
    loader=train_loader,        # your DataLoader
    epochs=epochs,              # number of epochs (from hyperparameters)
    model_name="Transformer",   # name for logging
    device=device,              # torch.device (CUDA, MPS, or CPU)
    lr=learning_rate,           # learning rate (from hyperparameters)
    log_steps=log_steps,        # print loss every N steps
    sample_interval=sample_interval,  # seconds between text samples
    max_steps_per_epoch=max_steps_per_epoch,  # or None for full epoch
    enc=enc,                    # tokenizer
    prompt=default_prompt,       # prompt for text generation samples
    log_csv_path="training_logs/transformer_training_log.csv",  # per-step training loss (appended)
    val_loader=val_loader,
    val_log_csv_path="training_logs/transformer_validation_log.csv",  # one row per validation pass (appended)
    val_interval_steps=log_steps  # Validate every 100 steps (same as log_steps)
)


[Transformer] Generating sample text (greedy) at epoch=1, step=1...
 Greedy Sample: Once upon a time time time time time time time time time time time time time time time time time time time time time
 Annotated: Once upon a time time time time time time time time time time time time time time time time time time time time time

[Transformer] Generating sample text (top-p=0.95) at epoch=1, step=1...
 Top-p (p=0.95) Sample: Once upon a time time time time time time time time time time time time time time time time time time time time time
 Annotated: Once upon a time time time time time time time time time time time time time time time time time time time time time

[Transformer] Generating sample text (top-p=1.0) at epoch=1, step=1...
 Top-p (p=1.0) Sample: Once upon a time time time time time time time time time time time time time time time time time time time time time
 Annotated: Once upon a time time time time time time time time time time time time time time time time time time 

## Save the Trained Model

After training, you can save:
1. **Model state dict** (`.pt` file) - contains all the learned weights
2. **Metadata JSON** - contains model architecture info for reconstruction

This allows you to:
- Resume training later
- Load the model for inference without retraining
- Share the trained model

In [None]:
# Create directory for saved models
save_dir = "models"
os.makedirs(save_dir, exist_ok=True)

# Save model state dict
model_path = os.path.join(save_dir, "transformer_model.pt")
torch.save(model.state_dict(), model_path)
print(f"✓ Model weights saved to: {model_path}")

# Save model metadata (architecture info for reconstruction).
# The architecture fields are read from the model object being saved, not from
# the hyperparameter variables, so the metadata always describes the weights
# written above and the model can be rebuilt and loaded from these two files.
metadata = {
    "model_type": "TransformerModel",
    "vocab_size": model.vocab_size,
    "d_model": model.d_model,
    "n_heads": model.blocks[0].attn.n_heads,
    "n_blocks": len(model.blocks),
    "max_seq_len": model.max_seq_len,
    "mlp_ratio": 4.0,  # TransformerModel default; the model above is built without overriding it
    "block_size": block_size,
    "training_info": {
        "epochs": epochs,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "train_subset_size": train_subset_size
    }
}

metadata_path = os.path.join(save_dir, "transformer_model_meta.json")
with open(metadata_path, 'w') as f:
    json.dump(metadata, f, indent=2)
print(f"✓ Model metadata saved to: {metadata_path}")

✓ Model weights saved to: models/transformer_model.pt
✓ Model metadata saved to: models/transformer_model_meta.json


In [None]:
# Colab only: mount Google Drive at /content/drive so files can be copied to it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy the *contents* of models/ into the Drive folder. Copying `models/.` into an
# explicitly created directory gives the same layout on every run; a plain
# `cp -r src dst` would instead nest models/ inside dst once dst already exists.
!mkdir -p /content/drive/MyDrive/ml_pico_models && cp -r /content/models/. /content/drive/MyDrive/ml_pico_models/

## Load a Saved Model

To load the model later for inference or continued training:

In [None]:
# Load metadata first to reconstruct the architecture
metadata_path = os.path.join(save_dir, "transformer_model_meta.json")
with open(metadata_path, 'r') as f:
    meta = json.load(f)

# Reconstruct the model with the same architecture
loaded_model = TransformerModel(
    vocab_size=meta["vocab_size"],
    d_model=meta["d_model"],
    n_heads=meta["n_heads"],
    n_blocks=meta["n_blocks"],
    max_seq_len=meta["max_seq_len"],
    mlp_ratio=meta["mlp_ratio"]
).to(device)

# Load the trained weights.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict holds; it avoids running arbitrary pickled code when the
# checkpoint file was obtained from someone else.
model_path = os.path.join(save_dir, "transformer_model.pt")
state_dict = torch.load(model_path, map_location=device, weights_only=True)
loaded_model.load_state_dict(state_dict)
loaded_model.eval()  # Set to evaluation mode

print(f"✓ Model loaded from: {model_path}")
print(f"Model has {sum(p.numel() for p in loaded_model.parameters()):,} parameters")

In [None]:
# Re-sync the contents of models/ to Drive. Copying `models/.` into the existing
# folder overwrites the files in place; a plain `cp -r src dst` would create a
# nested models/ directory inside dst because dst already exists at this point.
!mkdir -p /content/drive/MyDrive/ml_pico_models && cp -r /content/models/. /content/drive/MyDrive/ml_pico_models/

In [None]:
# Smoke test for the reloaded weights: sample a 30-token continuation with
# top-p=0.95 and print it. Only the plain text is kept; the annotated variant
# returned alongside it is discarded.
test_prompt = "Once upon a time"
generated_text, _ = generate_text(
    loaded_model, enc, test_prompt,
    max_new_tokens=30, device=device, top_p=0.95, use_kv_cache=False,
)

print("Generated text from loaded model:")
print(generated_text)

In [None]:
# Quick sanity check: loss and perplexity of `model` on a single training batch.
# NOTE: this switches `model` to eval mode and does not switch it back, and any
# exception is reported with print() rather than raised.
try:
    model.eval()
    with torch.no_grad():
        for batch_tokens in train_loader:
            batch_tokens = batch_tokens.to(device)
            logits = model(batch_tokens)
            loss = compute_next_token_loss(logits, batch_tokens)
            print("Batch loss:", loss.item())
            # Perplexity is exp(mean cross-entropy) for this batch
            perplexity = torch.exp(loss)
            print("Perplexity:", perplexity.item())
            break  # Remove break to check more batches
except Exception as e:
    print(f'Error during evaluation: {e}')

In [None]:
# Example: Generate text from a trained model
# Note: this rebinds the notebook-level names `prompt` and `max_new_tokens`.
prompt = "There was once a"
max_new_tokens = 50  # Number of tokens to generate


# generate_text returns (plain text, annotated text); only the plain text is printed below.
final_text, annotated_text = generate_text(
    model,         # your trained model
    enc,           # your tokenizer (e.g., tiktoken.get_encoding("gpt2"))
    prompt,
    max_new_tokens=max_new_tokens,
    device=device, # your torch.device
    top_p=0.95,    # or None for greedy
    use_kv_cache=False  # True for TransformerModel if you want fast generation
)

print("Generated text:")
print(final_text)

In [None]:
import numpy as np

import matplotlib.pyplot as plt

# Paths for training and validation logs
train_log_path = "training_logs/transformer_training_log.csv"
val_log_path = "training_logs/transformer_validation_log.csv"


def load_loss_log(path, missing_message):
    """Read (global_step, loss) columns from a loss CSV.

    The file schema is timestamp,model,epoch,step_in_epoch,global_step,loss.

    Args:
        path: CSV file to read.
        missing_message: text printed when the file does not exist.

    Returns:
        (global_steps, losses) as 1-D float arrays, or (None, None) when the
        file is missing or holds no data rows.
    """
    if not os.path.exists(path):
        print(missing_message)
        return None, None
    # Only the two numeric columns are parsed. atleast_2d keeps a log with a
    # single data row two-dimensional, so column indexing works for any row count.
    data = np.atleast_2d(
        np.genfromtxt(path, delimiter=",", skip_header=1, usecols=(4, 5))
    )
    if data.size == 0:
        print(f"Log file has no data rows yet: {path}")
        return None, None
    return data[:, 0], data[:, 1]


train_global_steps, train_losses = load_loss_log(
    train_log_path,
    f"No training log found at: {train_log_path}. Set log_csv_path in train_one_model to save losses.",
)
val_global_steps, val_losses = load_loss_log(
    val_log_path,
    f"No validation log found at: {val_log_path}. Run the validation cell to create it.",
)

if train_global_steps is not None or val_global_steps is not None:
    plt.figure(figsize=(8, 4))

    if train_global_steps is not None:
        plt.plot(train_global_steps, train_losses, label="Training", alpha=0.7, linewidth=1.5)

    if val_global_steps is not None:
        plt.plot(val_global_steps, val_losses, label="Validation", alpha=0.7, linewidth=1.5)

    plt.xlabel("Global Step")
    plt.ylabel("Loss")
    plt.title("Training vs Validation Loss over Time")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Summary: first and last logged loss, and the relative drop between them.
    if train_losses is not None:
        print(f"Training: initial={train_losses[0]:.3f}, final={train_losses[-1]:.3f}, drop={(train_losses[0]-train_losses[-1])/train_losses[0]*100:.1f}%")
    if val_losses is not None:
        print(f"Validation: initial={val_losses[0]:.3f}, final={val_losses[-1]:.3f}, drop={(val_losses[0]-val_losses[-1])/val_losses[0]*100:.1f}%")