In [2]:
import argparse
import time
import random
import math
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from datasets import load_dataset
import tiktoken

print('All imports done.')

  from .autonotebook import tqdm as notebook_tqdm


All imports done.


In [21]:
# Hyperparameters for training and models.
# This cell only assigns plain Python values; nothing is loaded or built here.
# NOTE: the data-loading cell assigns block_size, train_subset_size,
# tinystories_weight and batch_size again with its own literals, so keep the
# two places in sync when changing any of them.

# Data parameters
block_size = 128  # Maximum number of tokens kept per training sequence
train_subset_size = 20000  # Number of samples to use from TinyStories
batch_size = 16  # Sequences per training batch

tinystories_weight = 0.5  # Proportion of TinyStories in mixed dataset (0.0 to skip)

# Model parameters (default values, override in model cell if needed)
mlp_k = 3  # K (number of previous tokens used as context) for KGramMLPSeqModel
mlp_embed_size = 512  # Hidden width of the k-gram MLP
mlp_num_inner_layers = 1  # Number of extra hidden Linear+activation pairs
mlp_chunk_size = 1  # Time steps processed per chunk in the k-gram forward pass

lstm_embed_size = 512  # Token embedding width for LSTMSeqModel
lstm_hidden_size = 512  # LSTM hidden state width

transformer_d_model = 256  # Residual stream width
transformer_n_heads = 4  # Attention heads (must divide transformer_d_model)
transformer_n_blocks = 2  # Number of decoder blocks
transformer_max_seq_len = block_size  # Size of the learned positional table

# Training parameters
epochs = 2  # Passes over the training loader
learning_rate = 1e-3  # Learning rate handed to the optimizer
log_steps = 100  # Print the running average loss every this many steps
sample_interval = 30  # seconds between text samples during training
max_steps_per_epoch = None  # Set to an int for quick tests

# Prompt for text generation
default_prompt = "Once upon a time"


In [17]:
def compute_next_token_loss(logits, tokens_seq, ignore_index=-100):
    '''
    Cross-entropy loss for next-token prediction.

    The logits produced at position t are scored against the token at
    position t+1, so the last time step of `logits` and the first token of
    `tokens_seq` are not used.

    Args:
        logits: float tensor of shape (seq_len, batch, vocab_size).
        tokens_seq: integer tensor of shape (seq_len, batch).
        ignore_index: target value excluded from the loss; forwarded to
            `F.cross_entropy`. The default (-100) never matches a real token
            id, so by default every target counts.

    Returns:
        Scalar tensor holding the mean loss over all scored targets. When
        seq_len < 2 there is nothing to predict and a zero loss that is
        still attached to `logits` is returned (instead of NaN).

    Raises:
        ValueError: if `tokens_seq` is not 2-dimensional.
    '''
    if tokens_seq.dim() != 2:
        raise ValueError(
            f"tokens_seq must have shape (seq_len, batch), got {tuple(tokens_seq.shape)}"
        )

    if tokens_seq.size(0) < 2:
        # No (input, target) pair exists. Averaging over zero targets would
        # give NaN, so return a zero that keeps the autograd graph intact.
        return logits.sum() * 0.0

    # Predict next token: input t => predict t+1
    pred_logits = logits[:-1]   # (seq_len-1, batch, vocab_size)
    targets = tokens_seq[1:]    # (seq_len-1, batch)
    return F.cross_entropy(
        pred_logits.reshape(-1, pred_logits.size(-1)),
        targets.reshape(-1),
        ignore_index=ignore_index,
    )

In [3]:
# Pick the best available backend: a CUDA GPU first, then the Apple Silicon
# GPU (MPS), and finally the CPU. The getattr guard keeps this working on
# PyTorch builds that do not ship the `torch.backends.mps` module.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [4]:
# Load dataset
# Parameters (set these as needed).
# NOTE: these repeat values from the hyperparameter cell; the assignments
# below are the ones in effect for the rest of this cell.
block_size = 128
train_subset_size = 20000  # or smaller for quick tests
tinystories_weight = 0.5   # set to 0.0 to skip TinyStories

# Load tokenizer (GPT-2 BPE) and record its vocabulary size for the models
enc = tiktoken.get_encoding("gpt2")
vocab_size = enc.n_vocab

# Load TinyStories dataset: one token-id list per story, cut to block_size
tinystories_seqs = []
if tinystories_weight > 0.0:
    print(f"Loading TinyStories from huggingface with weight={tinystories_weight}...")
    dataset = load_dataset("roneneldan/TinyStories", split="train")
    # Clamp the subset size so that asking for more rows than the split
    # contains does not raise an out-of-range error in select().
    dataset = dataset.select(range(min(train_subset_size, len(dataset))))
    for sample in dataset:
        tokens = enc.encode(sample['text'])[:block_size]
        if tokens:  # skip stories that tokenize to nothing
            tinystories_seqs.append(tokens)
    print(f"TinyStories sequences: {len(tinystories_seqs)}")
else:
    print("TinyStories weight=0 => skipping TinyStories.")

# Optionally, load your own text files (uncomment and set file paths)
other_seqs = []
# input_files = ["myfile.txt"]
# for filepath in input_files:
#     with open(filepath, "r", encoding="utf-8") as f:
#         lines = f.readlines()
#     for line in lines:
#         line = line.strip()
#         if not line:
#             continue
#         tokens = enc.encode(line)
#         tokens = tokens[:block_size]
#         if len(tokens) > 0:
#             other_seqs.append(tokens)
# print(f"Custom input files: {len(other_seqs)} sequences loaded.")

# Dataset and DataLoader
class MixedSequenceDataset(torch.utils.data.Dataset):
    """
    Randomly mixes two pools of token sequences.

    `__len__` reports the combined number of sequences, but `__getitem__`
    ignores the requested index: every access draws a fresh random sequence,
    choosing the TinyStories pool with probability `p_tiny` when both pools
    are non-empty, and the only non-empty pool otherwise.

    Args:
        tinystories_seqs: list of token-id lists from TinyStories.
        other_seqs: list of token-id lists from any other source.
        p_tiny: probability in [0, 1] of sampling from `tinystories_seqs`
            when both pools are available.

    Raises:
        ValueError: if both pools are empty or `p_tiny` is outside [0, 1].
    """

    def __init__(self, tinystories_seqs, other_seqs, p_tiny: float):
        super().__init__()
        if not 0.0 <= p_tiny <= 1.0:
            raise ValueError(f"p_tiny must be a probability in [0, 1], got {p_tiny}")
        self.tinystories_seqs = tinystories_seqs
        self.other_seqs = other_seqs
        self.p_tiny = p_tiny
        self.has_tinystories = (len(self.tinystories_seqs) > 0)
        self.has_other = (len(self.other_seqs) > 0)
        self.total_length = len(self.tinystories_seqs) + len(self.other_seqs)
        if self.total_length == 0:
            raise ValueError("No data found! Both TinyStories and other sets are empty.")

    def __len__(self):
        return self.total_length

    def __getitem__(self, idx):
        # `idx` is intentionally unused: sampling is random on every access.
        # The coin is always tossed, even when only one pool exists, so the
        # sequence of RNG draws stays the same for seeded runs.
        r = random.random()
        if self.has_tinystories and (not self.has_other or r < self.p_tiny):
            pool = self.tinystories_seqs
        else:
            pool = self.other_seqs
        seq = pool[random.randint(0, len(pool) - 1)]
        return torch.tensor(seq, dtype=torch.long)

def seq_collate_fn(batch):
    """
    Collate 1-D token tensors of different lengths into one time-major batch.

    Returns a long tensor of shape (longest_sequence, len(batch)); column j
    holds batch[j], and positions past its end are filled with 0.
    """
    longest = max(len(item) for item in batch)
    out = torch.zeros(longest, len(batch), dtype=torch.long)
    for col, item in enumerate(batch):
        # Copy the sequence into the top of its column; the rest stays 0.
        out[:item.size(0), col] = item
    return out

# Build the mixed dataset, then wrap it in a shuffling DataLoader that pads
# each batch through seq_collate_fn.
p_tiny = tinystories_weight
combined_dataset = MixedSequenceDataset(tinystories_seqs, other_seqs, p_tiny)

batch_size = 16
train_loader = torch.utils.data.DataLoader(
    dataset=combined_dataset,
    collate_fn=seq_collate_fn,
    batch_size=batch_size,
    num_workers=0,  # load in the main process
    shuffle=True,
)

print("DataLoader ready. Vocab size:", vocab_size)

Loading TinyStories from huggingface with weight=0.5...
TinyStories sequences: 20000
DataLoader ready. Vocab size: 50257


In [5]:
class KGramMLPSeqModel(nn.Module):
    """
    K-gram MLP language model.

    For each position t in [0..seq_len-1], the k tokens *before* t (left-padded
    with token id 0 near the start of the sequence) are one-hot encoded,
    concatenated and passed through an MLP that outputs logits at position t.
    Returns (seq_len, batch, vocab_size).

    The forward pass is vectorised over the batch and over `chunk_size` time
    steps at once. Each chunk materialises a one-hot tensor of shape
    (chunk, batch, k * vocab_size), so `chunk_size` trades speed for memory.

    Args:
        vocab_size: number of token ids.
        k: number of previous tokens used as context.
        embed_size: width of every hidden layer.
        num_inner_layers: number of extra hidden Linear+activation pairs
            between the input and output layers.
        chunk_size: time steps processed per MLP call; must be >= 1.
        activation: zero-argument factory (e.g. an nn.Module class) that
            builds a fresh activation module.

    Raises:
        ValueError: if `chunk_size` < 1 (this would otherwise never terminate).
    """

    def __init__(self, vocab_size, k=3, embed_size=1024, num_inner_layers=1, chunk_size=1, activation=nn.SiLU):
        super().__init__()
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        self.k = k
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_inner_layers = num_inner_layers
        self.chunk_size = chunk_size
        self.activation = activation

        # Layers are created in input -> hidden -> output order so parameter
        # initialisation order and state_dict keys (net.0, net.2, ...) are stable.
        layers = [nn.Linear(self.k * self.vocab_size, self.embed_size), self.activation()]
        for _ in range(self.num_inner_layers):
            layers.append(nn.Linear(self.embed_size, self.embed_size))
            layers.append(self.activation())
        layers.append(nn.Linear(self.embed_size, self.vocab_size))
        self.net = nn.Sequential(*layers)

    def forward(self, tokens_seq):
        """
        tokens_seq: (seq_len, batch) integer token ids
        return: (seq_len, batch, vocab_size) logits
        """
        seq_len, batch_size = tokens_seq.shape
        if seq_len == 0:
            # Nothing to predict: return an empty tensor of the right shape.
            return self.net[-1].weight.new_zeros((0, batch_size, self.vocab_size))

        tokens_seq = tokens_seq.long()  # F.one_hot requires int64 indices
        dev = tokens_seq.device

        # Left-pad with k zeros so that padded[t : t + k] is exactly the k
        # tokens preceding position t (zeros standing in for "before start").
        left_pad = tokens_seq.new_zeros((self.k, batch_size))
        padded = torch.cat([left_pad, tokens_seq], dim=0)   # (seq_len + k, batch)
        offsets = torch.arange(self.k, device=dev)          # (k,)

        outputs = []
        for start in range(0, seq_len, self.chunk_size):
            end = min(start + self.chunk_size, seq_len)
            steps = torch.arange(start, end, device=dev)                 # (chunk,)
            window_idx = steps.unsqueeze(1) + offsets.unsqueeze(0)       # (chunk, k)
            contexts = padded[window_idx].permute(0, 2, 1)               # (chunk, batch, k)
            context_oh = F.one_hot(contexts, num_classes=self.vocab_size)  # (chunk, batch, k, vocab)
            # Flatten (k, vocab) with k as the slow axis, matching the layout
            # the first Linear layer expects.
            context_flat = context_oh.flatten(start_dim=2).float()       # (chunk, batch, k*vocab)
            outputs.append(self.net(context_flat))                       # (chunk, batch, vocab)

        return torch.cat(outputs, dim=0)  # (seq_len, batch, vocab_size)

In [6]:
class LSTMSeqModel(nn.Module):
    """
    Token embedding -> single LSTM -> linear projection to vocabulary logits.

    All tensors are time-major: the LSTM is built with batch_first=False.
    """

    def __init__(self, vocab_size, embed_size=1024, hidden_size=1024):
        super().__init__()
        self.vocab_size, self.embed_size, self.hidden_size = vocab_size, embed_size, hidden_size

        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=False)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, tokens_seq):
        """
        Map token ids of shape (seq_len, batch) to logits of shape
        (seq_len, batch, vocab_size).
        """
        embedded = self.embedding(tokens_seq)          # (seq_len, batch, embed)
        self.lstm.flatten_parameters()
        hidden_states, _ = self.lstm(embedded)         # (seq_len, batch, hidden)
        return self.linear(hidden_states)              # (seq_len, batch, vocab_size)

In [7]:
class RMSNorm(nn.Module):
    """
    Root-mean-square normalisation over the last dimension.

    Each vector is divided by sqrt(mean(x^2) + eps), scaled by a learned
    per-feature `weight` and, when `bias` is True, shifted by a learned
    per-feature `bias`.
    """

    def __init__(self, dim: int, eps: float = 1e-6, bias: bool = True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.use_bias = bias
        if bias:
            self.bias = nn.Parameter(torch.zeros(dim))
        else:
            self.bias = None
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = torch.mean(x.pow(2), dim=-1, keepdim=True)
        normalised = x / torch.sqrt(mean_square + self.eps)
        scaled = normalised * self.weight
        if not self.use_bias:
            return scaled
        return scaled + self.bias


class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention with optional KV cache.

    Inputs are time-major: x has shape (T, B, d_model).

    When `past_kv` is given it must be a `(k_cache, v_cache)` pair, each either
    None (empty cache) or a tensor of shape (B, H, T_cached, head_dim). The new
    keys/values are appended along the time axis and the updated pair is
    returned alongside the output.

    The causal mask is applied whether or not a cache is used: query i of the
    current call sits at absolute position T_cached + i and may only attend to
    keys at positions <= T_cached + i. This makes a multi-token call with a
    cache (e.g. priming on a whole prompt) equivalent to the uncached call.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def _split_heads(self, t: torch.Tensor, T: int, B: int) -> torch.Tensor:
        """Reshape (T, B, d_model) into per-head layout (B, H, T, head_dim)."""
        return t.view(T, B, self.n_heads, self.head_dim).permute(1, 2, 0, 3)  # (B,H,T,D)

    def forward(self, x: torch.Tensor, past_kv=None):
        """
        Args:
            x: (T, B, d_model) input.
            past_kv: None for plain attention, or a (k_cache, v_cache) pair.

        Returns:
            `out` of shape (T, B, d_model) when `past_kv` is None, otherwise
            the tuple `(out, (k, v))` with the extended cache.
        """
        T, B, C = x.shape
        q = self._split_heads(self.q_proj(x), T, B)
        k = self._split_heads(self.k_proj(x), T, B)
        v = self._split_heads(self.v_proj(x), T, B)

        if past_kv is not None:
            k_cache, v_cache = past_kv
            if k_cache is not None:
                k = torch.cat([k_cache, k], dim=2)
                v = torch.cat([v_cache, v], dim=2)

        T_total = k.size(2)  # cached + current time steps
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # (B,H,T,T_total)

        if T > 1:
            # Query row i corresponds to key column (T_total - T) + i; every
            # column to the right of that is in the future and gets masked.
            # With a single query (T == 1) nothing is in the future.
            causal_mask = torch.triu(
                torch.ones(T, T_total, device=x.device, dtype=torch.bool),
                diagonal=T_total - T + 1,
            )
            attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))

        attn_probs = torch.softmax(attn_scores, dim=-1)
        ctx = torch.matmul(attn_probs, v)  # (B,H,T,D)
        ctx = ctx.permute(2, 0, 1, 3).contiguous().view(T, B, C)
        out = self.o_proj(ctx)

        if past_kv is not None:
            return out, (k, v)
        return out


class TransformerBlock(nn.Module):
    """
    One pre-norm decoder block.

    x -> x + Attn(RMSNorm(x)) -> x + MLP(RMSNorm(x)), where the MLP is
    Linear -> SiLU -> Linear with hidden width int(mlp_ratio * d_model) and no
    biases.
    """

    def __init__(self, d_model: int, n_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        hidden_width = int(mlp_ratio * d_model)
        self.attn_norm = RMSNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.mlp_norm = RMSNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden_width, bias=False),
            nn.SiLU(),
            nn.Linear(hidden_width, d_model, bias=False),
        )

    def forward(self, x: torch.Tensor, past_kv=None):
        """
        Returns the block output, or `(output, new_kv)` when `past_kv` is
        provided (the attention layer then also returns its updated cache).
        """
        caching = past_kv is not None
        normed = self.attn_norm(x)

        if caching:
            attn_out, new_kv = self.attn(normed, past_kv=past_kv)
        else:
            attn_out = self.attn(normed)

        hidden = x + attn_out
        hidden = hidden + self.mlp(self.mlp_norm(hidden))

        return (hidden, new_kv) if caching else hidden


class TransformerModel(nn.Module):
    """
    Decoder-only causal Transformer producing logits for next-token prediction.

    Token embeddings plus learned absolute position embeddings feed a stack of
    `TransformerBlock`s, a final `RMSNorm`, and an output projection whose
    weight matrix is shared with the token embedding (weight tying).

    Args:
        vocab_size: number of token ids.
        d_model: residual stream width.
        n_heads: attention heads per block.
        n_blocks: number of decoder blocks.
        max_seq_len: size of the position-embedding table; the largest
            absolute position the model can represent is max_seq_len - 1.
        mlp_ratio: MLP hidden width as a multiple of d_model.
    """

    def __init__(self,
                 vocab_size: int = 50257,
                 d_model: int = 512,
                 n_heads: int = 8,
                 n_blocks: int = 6,
                 max_seq_len: int = 2048,
                 mlp_ratio: float = 4.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, mlp_ratio) for _ in range(n_blocks)
        ])
        self.final_norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.token_embed.weight  # Weight tying

    def forward(self, tokens_seq: torch.Tensor, kv_cache=None):
        """
        Args:
            tokens_seq: (T, B) token ids. If T exceeds `max_seq_len`, only the
                last `max_seq_len` time steps are used, so the returned logits
                then have `max_seq_len` rows rather than T.
            kv_cache: None for a plain forward pass, or a list with one
                `(k_cache, v_cache)` pair per block (each entry may be
                `(None, None)`). An empty list is treated as a fresh cache.

        Returns:
            `logits` of shape (T, B, vocab_size) when `kv_cache` is None,
            otherwise `(logits, new_cache)`.

        Raises:
            ValueError: if `kv_cache` has a different length than `self.blocks`.
            IndexError: if cached length + T would run past `max_seq_len`.
        """
        T, B = tokens_seq.shape
        if T > self.max_seq_len:
            tokens_seq = tokens_seq[-self.max_seq_len:]
            T = tokens_seq.shape[0]

        use_cache = kv_cache is not None
        cached_len = 0
        if use_cache:
            if len(kv_cache) == 0:
                # An empty list would make zip() below skip every block.
                kv_cache = [(None, None) for _ in range(len(self.blocks))]
            elif len(kv_cache) != len(self.blocks):
                raise ValueError(
                    f"kv_cache has {len(kv_cache)} entries but the model has "
                    f"{len(self.blocks)} blocks"
                )
            if kv_cache[0][0] is not None:
                cached_len = kv_cache[0][0].size(2)  # time axis of (B,H,T,D)

        if cached_len + T > self.max_seq_len:
            # Fail with a clear message instead of an opaque out-of-range
            # lookup in the position-embedding table.
            raise IndexError(
                f"position {cached_len + T - 1} exceeds max_seq_len={self.max_seq_len}; "
                "restart generation with a shorter context or without the KV cache"
            )

        # New tokens continue at the absolute position where the cache ends.
        pos = torch.arange(cached_len, cached_len + T, device=tokens_seq.device)
        x = self.token_embed(tokens_seq) + self.pos_embed(pos).unsqueeze(1)

        if use_cache:
            new_cache = []
            for blk, past in zip(self.blocks, kv_cache):
                x, updated = blk(x, past_kv=past)
                new_cache.append(updated)
        else:
            new_cache = None
            for blk in self.blocks:
                x = blk(x)

        x = self.final_norm(x)
        logits = self.lm_head(x)
        if use_cache:
            return logits, new_cache
        return logits

In [8]:
def train_one_model(model,
                    loader,
                    epochs,
                    model_name,
                    device,
                    lr=1e-3,
                    log_steps=100,
                    sample_interval=30,
                    max_steps_per_epoch=None,
                    enc=None,
                    monosemantic_info=None,
                    prompt="Once upon a",
                    log_csv_path: str = "",
                    log_flush_steps: int = 100):
    """
    Train `model` with Adam on next-token prediction, logging as it goes.

    NOTE: this notebook contains a second, later cell that also defines
    `train_one_model`; whichever cell was executed last is the one in effect.

    Args:
        model: module mapping (seq_len, batch) token ids to
            (seq_len, batch, vocab_size) logits.
        loader: iterable of (seq_len, batch) token tensors; `len(loader)` is
            used in progress messages.
        epochs: number of passes over `loader`.
        model_name: label used in printed messages and the CSV log.
        device: device each batch is moved to.
        lr: Adam learning rate.
        log_steps: print the running average loss every this many steps
            (0/None disables these prints).
        sample_interval: minimum seconds between text-sample printouts.
        max_steps_per_epoch: stop each epoch after this many steps (None = no cap).
        enc: tokenizer passed to `generate_text`; None disables sampling.
        monosemantic_info: forwarded to `generate_text`; also enables its
            monosemantic annotation when not None.
        prompt: text the samples are generated from.
        log_csv_path: if non-empty, append one CSV row per step to this file
            (the header is written when the file is new or empty).
        log_flush_steps: number of buffered rows that triggers a write.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # (text inside "Generating sample text (...)", label before "Sample:", top_p)
    sample_modes = (
        ("greedy", "Greedy", None),
        ("top-p=0.95", "Top-p (p=0.95)", 0.95),
        ("top-p=1.0", "Top-p (p=1.0)", 1.0),
    )

    loss_buffer = []
    csv_file = None
    try:
        # Buffered logging setup
        if log_csv_path:
            log_dir = os.path.dirname(log_csv_path)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            file_exists = os.path.exists(log_csv_path)
            csv_file = open(log_csv_path, 'a', newline='')
            # Write header if empty
            if not file_exists or os.path.getsize(log_csv_path) == 0:
                csv_file.write('timestamp,model,epoch,step_in_epoch,global_step,loss\n')

        next_sample_time = time.time()  # first sample is printed on step 1
        global_step = 0

        for epoch in range(1, epochs + 1):
            model.train()
            total_loss = 0.0
            partial_loss = 0.0
            partial_count = 0

            step_in_epoch = 0
            for batch_idx, batch_tokens in enumerate(loader, start=1):
                step_in_epoch += 1
                global_step += 1

                batch_tokens = batch_tokens.to(device)  # (seq_len, batch)

                logits = model(batch_tokens)  # (seq_len, batch, vocab_size)
                loss = compute_next_token_loss(logits, batch_tokens)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Read the scalar once: every .item() call forces a device sync.
                loss_value = loss.item()
                total_loss += loss_value
                partial_loss += loss_value
                partial_count += 1

                # Buffer this step's loss
                if csv_file is not None:
                    loss_buffer.append(
                        f"{time.time()},{model_name},{epoch},{step_in_epoch},{global_step},{loss_value}\n"
                    )
                    if len(loss_buffer) >= log_flush_steps:
                        csv_file.writelines(loss_buffer)
                        csv_file.flush()
                        loss_buffer.clear()

                if log_steps and batch_idx % log_steps == 0:
                    avg_part_loss = partial_loss / partial_count
                    print(f"[{model_name}] Epoch {epoch}/{epochs}, "
                          f"Step {batch_idx}/{len(loader)} (global step: {global_step}) "
                          f"Partial Avg Loss: {avg_part_loss:.4f}")
                    partial_loss = 0.0
                    partial_count = 0

                current_time = time.time()
                if current_time >= next_sample_time and enc is not None:
                    with torch.no_grad():
                        for mode_idx, (mode_desc, mode_label, mode_top_p) in enumerate(sample_modes):
                            lead = "\n" if mode_idx == 0 else ""
                            print(f"{lead}[{model_name}] Generating sample text ({mode_desc}) "
                                  f"at epoch={epoch}, step={batch_idx}...")
                            sample_text, sample_annotated = generate_text(
                                model, enc, prompt, max_new_tokens=20, device=device,
                                top_p=mode_top_p,
                                use_kv_cache=False,
                                monosemantic_info=monosemantic_info,
                                do_monosemantic=(monosemantic_info is not None)
                            )
                            print(f" {mode_label} Sample: {sample_text}")
                            print(f" Annotated: {sample_annotated}\n")

                    next_sample_time = current_time + sample_interval

                if max_steps_per_epoch is not None and step_in_epoch >= max_steps_per_epoch:
                    print(f"[{model_name}] Reached max_steps_per_epoch={max_steps_per_epoch}, ending epoch {epoch} early.")
                    break

            if step_in_epoch == 0:
                # An empty loader would otherwise divide by zero below.
                print(f"[{model_name}] Epoch {epoch}: loader produced no batches, nothing to train on.")
                continue

            avg_loss = total_loss / step_in_epoch
            print(f"[{model_name}] *** End of Epoch {epoch} *** Avg Loss: {avg_loss:.4f}")
    finally:
        # Final flush; runs on errors and interrupts too, so the log file is
        # always closed and buffered rows are not lost.
        if csv_file is not None:
            try:
                if loss_buffer:
                    csv_file.writelines(loss_buffer)
                    csv_file.flush()
            finally:
                csv_file.close()

In [14]:
# Example: Create and move a model to device
# Uncomment the model you want to use. Sizes come from the hyperparameter
# cell so there is a single place to tune them.

# K-gram MLP model
# model = KGramMLPSeqModel(vocab_size, k=mlp_k, embed_size=mlp_embed_size,
#                          num_inner_layers=mlp_num_inner_layers, chunk_size=mlp_chunk_size)

# LSTM model
# model = LSTMSeqModel(vocab_size, embed_size=lstm_embed_size, hidden_size=lstm_hidden_size)

# Transformer model
model = TransformerModel(
    vocab_size=vocab_size,
    d_model=transformer_d_model,
    n_heads=transformer_n_heads,
    n_blocks=transformer_n_blocks,
    max_seq_len=transformer_max_seq_len,
)

# Move model to device
model = model.to(device)

# Print model summary
print(model)

# After running this cell, you can use the 'model' variable in training or evaluation cells.

TransformerModel(
  (token_embed): Embedding(50257, 256)
  (pos_embed): Embedding(128, 256)
  (blocks): ModuleList(
    (0-1): 2 x TransformerBlock(
      (attn_norm): RMSNorm()
      (attn): CausalSelfAttention(
        (q_proj): Linear(in_features=256, out_features=256, bias=False)
        (k_proj): Linear(in_features=256, out_features=256, bias=False)
        (v_proj): Linear(in_features=256, out_features=256, bias=False)
        (o_proj): Linear(in_features=256, out_features=256, bias=False)
      )
      (mlp_norm): RMSNorm()
      (mlp): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=False)
        (1): SiLU()
        (2): Linear(in_features=1024, out_features=256, bias=False)
      )
    )
  )
  (final_norm): RMSNorm()
  (lm_head): Linear(in_features=256, out_features=50257, bias=False)
)


In [25]:
def generate_text(model, enc, init_text, max_new_tokens=20, device="cpu",
                  top_p=None,
                  monosemantic_info=None,
                  do_monosemantic=False,
                  use_kv_cache=False):
    """
    Autoregressively extend `init_text` by `max_new_tokens` tokens.

    Two decoding paths:
      - Full context (works for all models): at each step the whole token list
        is fed as (seq_len, 1) and the last row of the returned
        (seq_len, 1, vocab_size) logits is used.
      - KV cache (only when `use_kv_cache` is True and the model has a
        `blocks` attribute): the prompt except its last token is fed one token
        at a time to fill the cache; each step then feeds only the newest
        token, so every token enters the cache exactly once. If the model
        exposes an integer `max_seq_len` and the context grows beyond it,
        decoding falls back to the full-context path.

    The next token is the argmax when `top_p` is None, otherwise it is drawn
    with `nucleus_sampling(next_logits, p=top_p)`.

    An empty prompt is handled by feeding the last token of `enc.encode(" ")`
    to the model as a starter; the starter is not part of the returned text.

    Args:
        model: callable module; `model(tokens)` returns logits, and on the
            cache path `model(tokens, kv_cache=...)` returns (logits, cache).
        enc: tokenizer with `encode(str) -> list[int]` and `decode(list[int]) -> str`.
        init_text: prompt string.
        max_new_tokens: number of tokens to generate.
        device: device for the input tensors.
        top_p: None for greedy decoding, else the nucleus probability mass.
        monosemantic_info, do_monosemantic: when both are set, each new token
            is annotated via `monosemantic_analysis_for_token`.
        use_kv_cache: request the KV-cache path.

    Returns:
        (final_text, annotated_text): the decoded prompt + generation, and the
        same text with `[NN=...]` neighbour annotations after tokens that have
        them.

    The model is put in eval mode for generation and its previous train/eval
    mode is restored afterwards, even if generation raises.
    """
    context_tokens = list(enc.encode(init_text))
    prompt_len = len(context_tokens)
    annotation_list = []

    # Tokens actually fed to the model. Identical to context_tokens unless the
    # prompt is empty, in which case a single starter token comes first.
    model_tokens = list(context_tokens) if context_tokens else list(enc.encode(" ")[-1:])

    cache_active = bool(use_kv_cache) and hasattr(model, 'blocks')
    max_len = getattr(model, 'max_seq_len', None)
    if not isinstance(max_len, int):
        max_len = None

    def _over_limit():
        # True when feeding the newest token through the cache would need a
        # position index >= max_seq_len.
        return max_len is not None and len(model_tokens) > max_len

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            kv_cache = None
            if cache_active and max_new_tokens > 0:
                if _over_limit():
                    cache_active = False
                else:
                    # Prime step-by-step with everything except the newest
                    # token; the generation loop feeds that one itself.
                    kv_cache = [(None, None) for _ in range(len(model.blocks))]
                    for tid in model_tokens[:-1]:
                        tok = torch.tensor([[tid]], dtype=torch.long, device=device)  # (1,1)
                        _, kv_cache = model(tok, kv_cache=kv_cache)

            for _ in range(max_new_tokens):
                if cache_active and _over_limit():
                    # Out of positions: continue with the full-context path.
                    cache_active = False
                    kv_cache = None

                if cache_active:
                    last_token = torch.tensor([[model_tokens[-1]]], dtype=torch.long, device=device)  # (1,1)
                    logits_seq, kv_cache = model(last_token, kv_cache=kv_cache)
                else:
                    seq_tensor = torch.tensor(model_tokens, dtype=torch.long, device=device).unsqueeze(1)
                    logits_seq = model(seq_tensor)          # (seq_len,1,vocab_size)
                next_logits = logits_seq[-1, 0, :]          # shape (vocab_size,)

                if top_p is None:
                    # greedy
                    chosen_token = torch.argmax(next_logits).item()
                else:
                    chosen_token = nucleus_sampling(next_logits, p=top_p)

                context_tokens.append(chosen_token)
                model_tokens.append(chosen_token)

                if do_monosemantic and monosemantic_info is not None:
                    neighbors = monosemantic_analysis_for_token(
                        chosen_token, model, monosemantic_info, enc, device=device, top_n=5
                    )
                    annotation_list.append((chosen_token, neighbors))
                else:
                    annotation_list.append((chosen_token, []))
    finally:
        model.train(was_training)

    final_text = enc.decode(context_tokens)
    # Slice by the recorded prompt length: a negative slice of
    # -max_new_tokens would yield an empty prefix when max_new_tokens == 0.
    prefix_text = enc.decode(context_tokens[:prompt_len])
    annotated_strs = [prefix_text]
    for (tid, neighs) in annotation_list:
        token_str = enc.decode([tid])
        if neighs:
            neighbor_strs = [f"{enc.decode([x[1]])}" for x in neighs]
            annotated = f"{token_str}[NN={neighbor_strs}]"
        else:
            annotated = token_str
        annotated_strs.append(annotated)

    annotated_text = "".join(annotated_strs)
    return final_text, annotated_text


In [27]:
def nucleus_sampling(logits, p=0.95):
    """
    Sample one token id with nucleus (top-p) sampling.

    The nucleus is the smallest set of highest-probability tokens whose total
    probability reaches `p`: tokens are taken in descending probability order
    and a token is kept as long as the mass of the tokens *before* it is still
    below `p`. The token that crosses the threshold is therefore included.
    The most likely token is always kept, so p <= 0 degenerates to greedy.

    Args:
        logits: 1-D tensor of unnormalised scores, shape (vocab_size,).
        p: target cumulative probability of the nucleus.

    Returns:
        The sampled token id as a Python int.
    """
    # Convert logits to probabilities and sort them in descending order
    prob_dist = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(prob_dist, descending=True)

    # Probability mass strictly before each sorted position (exclusive cumsum)
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
    mass_before = torch.zeros_like(cumsum_probs)
    mass_before[1:] = cumsum_probs[:-1]

    # Keep a token while the nucleus built so far has not yet reached p.
    # Testing the mass *before* the token (rather than cumsum <= p) keeps the
    # token that crosses p and is immune to the total rounding above 1.0.
    keep_mask = mass_before < p
    keep_mask[0] = True  # always keep the highest-probability token

    # Zero out everything outside the nucleus and renormalise
    filtered_probs = torch.where(keep_mask, sorted_probs, torch.zeros_like(sorted_probs))
    filtered_probs = filtered_probs / filtered_probs.sum()

    # Sample a sorted position, then map it back to the original token id
    sampled_index = torch.multinomial(filtered_probs, num_samples=1).item()
    return sorted_indices[sampled_index].item()


In [28]:
def train_one_model(model,
                    loader,
                    epochs,
                    model_name,
                    device,
                    lr=1e-3,
                    log_steps=100,
                    sample_interval=30,
                    max_steps_per_epoch=None,
                    enc=None,
                    monosemantic_info=None,
                    prompt="Once upon a",
                    log_csv_path: str = "",
                    log_flush_steps: int = 100):
    """
    Train `model` with Adam on next-token prediction, logging as it goes.

    NOTE: an earlier cell in this notebook also defines `train_one_model`;
    when the cells are run in order, this later definition replaces it.

    Args:
        model: module mapping (seq_len, batch) token ids to
            (seq_len, batch, vocab_size) logits.
        loader: iterable of (seq_len, batch) token tensors; `len(loader)` is
            used in progress messages.
        epochs: number of passes over `loader`.
        model_name: label used in printed messages and the CSV log.
        device: device each batch is moved to.
        lr: Adam learning rate.
        log_steps: print the running average loss every this many steps
            (0/None disables these prints).
        sample_interval: minimum seconds between text-sample printouts.
        max_steps_per_epoch: stop each epoch after this many steps (None = no cap).
        enc: tokenizer passed to `generate_text`; None disables sampling.
        monosemantic_info: forwarded to `generate_text`; also enables its
            monosemantic annotation when not None.
        prompt: text the samples are generated from.
        log_csv_path: if non-empty, append one CSV row per step to this file
            (the header is written when the file is new or empty).
        log_flush_steps: number of buffered rows that triggers a write.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # (text inside "Generating sample text (...)", label before "Sample:", top_p)
    sample_modes = (
        ("greedy", "Greedy", None),
        ("top-p=0.95", "Top-p (p=0.95)", 0.95),
        ("top-p=1.0", "Top-p (p=1.0)", 1.0),
    )

    loss_buffer = []
    csv_file = None
    try:
        # Buffered logging setup
        if log_csv_path:
            log_dir = os.path.dirname(log_csv_path)
            if log_dir:
                os.makedirs(log_dir, exist_ok=True)
            file_exists = os.path.exists(log_csv_path)
            csv_file = open(log_csv_path, 'a', newline='')
            # Write header if empty
            if not file_exists or os.path.getsize(log_csv_path) == 0:
                csv_file.write('timestamp,model,epoch,step_in_epoch,global_step,loss\n')

        next_sample_time = time.time()  # first sample is printed on step 1
        global_step = 0

        for epoch in range(1, epochs + 1):
            model.train()
            total_loss = 0.0
            partial_loss = 0.0
            partial_count = 0

            step_in_epoch = 0
            for batch_idx, batch_tokens in enumerate(loader, start=1):
                step_in_epoch += 1
                global_step += 1

                batch_tokens = batch_tokens.to(device)  # (seq_len, batch)

                logits = model(batch_tokens)  # (seq_len, batch, vocab_size)
                loss = compute_next_token_loss(logits, batch_tokens)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Read the scalar once: every .item() call forces a device sync.
                loss_value = loss.item()
                total_loss += loss_value
                partial_loss += loss_value
                partial_count += 1

                # Buffer this step's loss
                if csv_file is not None:
                    loss_buffer.append(
                        f"{time.time()},{model_name},{epoch},{step_in_epoch},{global_step},{loss_value}\n"
                    )
                    if len(loss_buffer) >= log_flush_steps:
                        csv_file.writelines(loss_buffer)
                        csv_file.flush()
                        loss_buffer.clear()

                if log_steps and batch_idx % log_steps == 0:
                    avg_part_loss = partial_loss / partial_count
                    print(f"[{model_name}] Epoch {epoch}/{epochs}, "
                          f"Step {batch_idx}/{len(loader)} (global step: {global_step}) "
                          f"Partial Avg Loss: {avg_part_loss:.4f}")
                    partial_loss = 0.0
                    partial_count = 0

                current_time = time.time()
                if current_time >= next_sample_time and enc is not None:
                    with torch.no_grad():
                        for mode_idx, (mode_desc, mode_label, mode_top_p) in enumerate(sample_modes):
                            lead = "\n" if mode_idx == 0 else ""
                            print(f"{lead}[{model_name}] Generating sample text ({mode_desc}) "
                                  f"at epoch={epoch}, step={batch_idx}...")
                            sample_text, sample_annotated = generate_text(
                                model, enc, prompt, max_new_tokens=20, device=device,
                                top_p=mode_top_p,
                                use_kv_cache=False,
                                monosemantic_info=monosemantic_info,
                                do_monosemantic=(monosemantic_info is not None)
                            )
                            print(f" {mode_label} Sample: {sample_text}")
                            print(f" Annotated: {sample_annotated}\n")

                    next_sample_time = current_time + sample_interval

                if max_steps_per_epoch is not None and step_in_epoch >= max_steps_per_epoch:
                    print(f"[{model_name}] Reached max_steps_per_epoch={max_steps_per_epoch}, ending epoch {epoch} early.")
                    break

            if step_in_epoch == 0:
                # An empty loader would otherwise divide by zero below.
                print(f"[{model_name}] Epoch {epoch}: loader produced no batches, nothing to train on.")
                continue

            avg_loss = total_loss / step_in_epoch
            print(f"[{model_name}] *** End of Epoch {epoch} *** Avg Loss: {avg_loss:.4f}")
    finally:
        # Final flush; runs on errors and interrupts too, so the log file is
        # always closed and buffered rows are not lost.
        if csv_file is not None:
            try:
                if loss_buffer:
                    csv_file.writelines(loss_buffer)
                    csv_file.flush()
            finally:
                csv_file.close()


In [29]:
# Kick off training of the Transformer with the settings from the
# hyperparameter cell. No CSV log path is given, so nothing is written to disk.
train_one_model(
    model,                    # TransformerModel instance created above
    train_loader,             # DataLoader yielding (seq_len, batch) tensors
    epochs,                   # number of passes over the loader
    "Transformer",            # label shown in log lines
    device,                   # torch.device the batches are moved to
    lr=learning_rate,
    log_steps=log_steps,                      # print loss every N steps
    sample_interval=sample_interval,          # seconds between text samples
    max_steps_per_epoch=max_steps_per_epoch,  # None runs the full epoch
    enc=enc,                                  # tokenizer used for the samples
    prompt=default_prompt,                    # text the samples start from
)


[Transformer] Generating sample text (greedy) at epoch=1, step=1...
 Greedy Sample: Once upon a time time time time time time time time time time time time time time time time time time time time time
 Annotated: Once upon a time time time time time time time time time time time time time time time time time time time time time

[Transformer] Generating sample text (top-p=0.95) at epoch=1, step=1...
 Greedy Sample: Once upon a time time time time time time time time time time time time time time time time time time time time time
 Annotated: Once upon a time time time time time time time time time time time time time time time time time time time time time

[Transformer] Generating sample text (top-p=0.95) at epoch=1, step=1...
 Top-p (p=0.95) Sample: Once upon a time time time time time time time time time time time time time time time time time time time time time
 Annotated: Once upon a time time time time time time time time time time time time time time time time time time time t

In [30]:
# Evaluate loss and perplexity on a batch for any model (KGramMLPSeqModel, LSTMSeqModel, TransformerModel).
# The batch comes from train_loader, so this is a training-set figure, not a
# held-out one. Errors are allowed to propagate so the full traceback is
# visible instead of being reduced to a one-line message.
model.eval()
with torch.no_grad():
    for batch_tokens in train_loader:
        batch_tokens = batch_tokens.to(device)
        logits = model(batch_tokens)
        loss = compute_next_token_loss(logits, batch_tokens)
        print("Batch loss:", loss.item())
        # Perplexity is the exponential of the mean cross-entropy
        perplexity = torch.exp(loss)
        print("Perplexity:", perplexity.item())
        break  # Remove break to check more batches

Batch loss: 3.7234840393066406
Perplexity: 41.40841293334961


In [31]:
# Example: sample a continuation from the trained model
prompt = "Once upon a time"
max_new_tokens = 50  # how many tokens to append to the prompt

# top_p=0.95 selects nucleus sampling; pass top_p=None for greedy decoding.
# use_kv_cache can be set to True for TransformerModel to speed up generation.
final_text, annotated_text = generate_text(
    model,
    enc,
    init_text=prompt,
    max_new_tokens=max_new_tokens,
    device=device,
    top_p=0.95,
    use_kv_cache=False,
)

print("Generated text:", final_text, sep="\n")

Generated text:
Once upon a time, there was a big kids. has a dangerousie ands and a new boat. Tweetieie. One day, the colorful as a little boy came to walking everywhere. The different bit deep. He knew that he was ready with feeling all
