<a href="https://colab.research.google.com/github/rohanjsheth/TinierStoriesGPT/blob/main/TinierStories_v4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR

import sentencepiece as spm

from datasets import load_dataset

import multiprocessing

import math

# Download both story corpora from the Hugging Face Hub (cached after the first run).
simple_stories_datasets = load_dataset(path="SimpleStories/SimpleStories")
tiny_stories_datasets = load_dataset(path="roneneldan/TinyStories")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00007.parquet:   0%|          | 0.00/238M [00:00<?, ?B/s]

data/train-00001-of-00007.parquet:   0%|          | 0.00/238M [00:00<?, ?B/s]

data/train-00002-of-00007.parquet:   0%|          | 0.00/238M [00:00<?, ?B/s]

data/train-00003-of-00007.parquet:   0%|          | 0.00/238M [00:00<?, ?B/s]

data/train-00004-of-00007.parquet:   0%|          | 0.00/238M [00:00<?, ?B/s]

data/train-00005-of-00007.parquet:   0%|          | 0.00/238M [00:00<?, ?B/s]

data/train-00006-of-00007.parquet:   0%|          | 0.00/238M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/16.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2115696 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/21371 [00:00<?, ? examples/s]

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00004-2d5a1467fff108(…):   0%|          | 0.00/249M [00:00<?, ?B/s]

data/train-00001-of-00004-5852b56a2bd28f(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/train-00002-of-00004-a26307300439e9(…):   0%|          | 0.00/246M [00:00<?, ?B/s]

data/train-00003-of-00004-d243063613e5a0(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/validation-00000-of-00001-869c898b5(…):   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

In [None]:
# --- Runtime ---
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Seed torch's RNGs (CPU and CUDA) so weight init, dropout, DataLoader shuffling
# and sampling are reproducible across Restart & Run All.
SEED = 1337
torch.manual_seed(SEED)

# --- Tokenizer / model shape ---
vocab_size = 512
n_sub_embd = 40       # factorized embedding width used by the GPT baseline
n_embd = 64
block_size = 512      # context length: Stories yields block_size + 1 tokens per example
dropout = 0.05
n_head = 8
num_kv = 4            # number of key/value groups for grouped-query attention
n_layer = 5
feedforward_mulitple = 3  # FFN hidden width = feedforward_mulitple * n_embd (name kept: referenced by the FFN classes)

# --- Training ---
max_epochs = 1
batch_size = 64
batch_eval_interval = 100
eval_iterations = 10
learning_rate = 1e-3
weight_decay = 0.01
beta2 = 0.99

# Fail fast on shapes the attention modules cannot handle.
assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
assert n_head % num_kv == 0, "n_head must be divisible by num_kv"
assert (n_embd // n_head) % 2 == 0, "head size must be even for RoPE"

cuda


In [None]:
# Logical CPU count; later passed as num_threads to tokenizer training.
cpu_cores = multiprocessing.cpu_count()
cores_message = f'Available cpu cores: {cpu_cores}'
print(cores_message)

Available cpu cores: 12


In [None]:
# Raw story text for each TinyStories split (the 'text' column only).
train_dataset = tiny_stories_datasets["train"]["text"]
val_dataset = tiny_stories_datasets["validation"]["text"]

In [None]:
# Train the BPE tokenizer only when no compatible model is on disk, so that
# re-running the notebook does not repeat the slow SentencePiece training.
# A saved model with a different vocabulary size is treated as stale.
import os

if os.path.exists('my_model.model') and (
    spm.SentencePieceProcessor(model_file='my_model.model').get_piece_size() == vocab_size
):
    print("Reusing existing tokenizer my_model.model")
else:
    spm.SentencePieceTrainer.train(
        sentence_iterator=iter(train_dataset),
        model_prefix='my_model',
        vocab_size=vocab_size,
        character_coverage=1.0,
        model_type='bpe',
        num_threads=cpu_cores,
        train_extremely_large_corpus=True,
        max_sentence_length=4192,  # bytes; SentencePiece's default, longer stories are skipped
        input_sentence_size=10000000
    )

In [None]:
# Load the trained SentencePiece model written by the training cell.
tokenizer = spm.SentencePieceProcessor(model_file="my_model.model")

In [None]:
# Sanity check: token ids for a short word.
hello_ids = tokenizer.encode("Hello")
hello_ids

[61, 53, 374]

In [None]:
class Stories(Dataset):
    """Map-style dataset that turns each raw story into a fixed-length tensor.

    Each item holds ``block_size + 1`` token ids, one extra so that inputs and
    targets can be shifted by one position. Longer stories are truncated and
    shorter ones are right-padded with token id 0.

    Args:
        dataset: indexable collection of story strings.
        sentence_tokenizer: object exposing ``encode(text) -> list[int]``.
        block_size: context length used by the model.
    """

    def __init__(self, dataset, sentence_tokenizer, block_size):
        self.dataset = dataset
        self.sentence_tokenizer = sentence_tokenizer
        self.block_size = block_size

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        window = self.block_size + 1
        ids = self.sentence_tokenizer.encode(self.dataset[idx])
        shortfall = window - len(ids)
        # Negative shortfall means the story is too long: truncate. Otherwise pad.
        ids = ids[:window] if shortfall < 0 else ids + [0] * shortfall
        return torch.tensor(ids, dtype=torch.long)

In [None]:
def get_batch_values(batch, pad_id: int = 0, to_device=None):
    """Split a batch of token windows into next-token (input, target) pairs.

    Args:
        batch: 2-D integer tensor of token windows (as produced by ``Stories``).
            Inputs are ``batch[:, :-1]`` and targets are ``batch[:, 1:]``.
        pad_id: id used to right-pad short stories (``Stories`` pads with 0,
            which is SentencePiece's default ``<unk>`` id).
        to_device: destination for the tensors. Defaults to the notebook-level
            ``device``.

    Returns:
        ``(x, y)`` moved to the destination device. In ``y``, every target
        where a pad predicts a pad is set to -100, the default ``ignore_index``
        of ``F.cross_entropy``, so padding no longer contributes to the loss.
        The first pad after a story (last real token -> pad) is kept as a
        target, so the model still sees where a story ends.
    """
    x = batch[:, :-1].contiguous()
    y = batch[:, 1:].contiguous()
    # For short stories most of the window can be padding; pad->pad targets
    # carry no signal but would otherwise dominate the mean loss.
    y = y.masked_fill((x == pad_id) & (y == pad_id), -100)
    if to_device is None:
        to_device = device
    return x.to(to_device), y.to(to_device)

In [None]:
# ============================================================================
# FEEDFORWARD LAYERS
# ============================================================================

class ReLUFeedForward(nn.Module):
    """Position-wise MLP: expand to ``feedforward_mulitple * n_embd``, ReLU,
    project back to ``n_embd``, then dropout. Reads the notebook globals
    ``feedforward_mulitple`` and ``dropout``."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = feedforward_mulitple * n_embd
        layers = [
            nn.Linear(n_embd, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, n_embd, bias=False),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class SwiGLUFeedForward(nn.Module):
    """SwiGLU MLP: ``dropout(w2(silu(w1(x)) * w3(x)))``.

    ``w1`` produces the SiLU-activated gate, ``w3`` the linear branch it
    multiplies, and ``w2`` projects back to ``n_embd``. Hidden width is
    ``feedforward_mulitple * n_embd`` (notebook global).
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = feedforward_mulitple * n_embd
        self.w1 = nn.Linear(n_embd, hidden, bias=False)
        self.activation = nn.SiLU()
        self.w2 = nn.Linear(hidden, n_embd, bias=False)
        self.w3 = nn.Linear(n_embd, hidden, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        gated = self.activation(self.w1(x)) * self.w3(x)
        return self.dropout(self.w2(gated))


# ============================================================================
# BASIC ATTENTION HEADS
# ============================================================================

class Head(nn.Module):
    """Single causal self-attention head.

    Computes scaled dot-product attention and masks future positions with a
    lower-triangular buffer. Reads the notebook globals ``n_embd``,
    ``block_size`` and ``dropout``.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("tril", causal_mask)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        _, seq_len, _ = x.shape

        k, q, v = self.key(x), self.query(x), self.value(x)

        # Scaled similarity of every query with every key.
        scores = (q @ k.transpose(-2, -1)) * k.shape[-1] ** -0.5
        # A position may not attend to anything after itself.
        future = self.tril[:seq_len, :seq_len] == 0
        attn = scores.masked_fill(future, float('-inf')).softmax(dim=-1)
        attn = self.dropout(attn)

        return attn @ v

class MultiHeadedAttention(nn.Module):
    """Runs ``num_heads`` independent ``Head`` modules, concatenates their
    outputs, projects back to ``n_embd`` and applies dropout."""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList(Head(head_size) for _ in range(num_heads))
        self.proj = nn.Linear(num_heads * head_size, n_embd, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        per_head = [head(x) for head in self.heads]
        return self.dropout(self.proj(torch.cat(per_head, dim=-1)))


# ============================================================================
# GROUPED QUERY ATTENTION HEADS
# ============================================================================

class GroupedHead(nn.Module):
    """Causal attention head with its own query projection and key/value
    projections passed in by the caller (``GroupedMultiHeadedAttention``
    shares them across heads in the same group)."""

    def __init__(self, head_size, key, value):
        super().__init__()
        self.key = key
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = value

        causal_mask = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("tril", causal_mask)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        _, seq_len, _ = x.shape

        k, q, v = self.key(x), self.query(x), self.value(x)

        scores = (q @ k.transpose(-2, -1)) * k.shape[-1] ** -0.5
        future = self.tril[:seq_len, :seq_len] == 0
        attn = scores.masked_fill(future, float('-inf')).softmax(dim=-1)
        attn = self.dropout(attn)

        return attn @ v

class GroupedMultiHeadedAttention(nn.Module):
    """Grouped-query attention: ``num_heads`` query heads share ``num_kv``
    key/value projections, ``num_heads // num_kv`` heads per group.

    Raises:
        ValueError: if ``num_heads`` or ``num_kv`` is not positive, or
            ``num_heads`` is not a multiple of ``num_kv``. Otherwise some heads
            would index past the last key/value group.
    """

    def __init__(self, num_heads, head_size, num_kv):
        super().__init__()
        if num_heads <= 0 or num_kv <= 0 or num_heads % num_kv != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive multiple of num_kv ({num_kv})"
            )
        heads_per_kv = num_heads // num_kv
        self.keys = nn.ModuleList([nn.Linear(n_embd, head_size, bias=False) for _ in range(num_kv)])
        self.values = nn.ModuleList([nn.Linear(n_embd, head_size, bias=False) for _ in range(num_kv)])
        # Consecutive runs of heads_per_kv heads reuse the same key/value projection.
        self.heads = nn.ModuleList([
            GroupedHead(head_size, self.keys[idx // heads_per_kv], self.values[idx // heads_per_kv])
            for idx in range(num_heads)
        ])
        self.proj = nn.Linear(num_heads * head_size, n_embd, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.proj(out)
        return self.dropout(out)


# ============================================================================
# ROPE UTILITIES
# ============================================================================


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute RoPE rotation tables.

    Args:
        dim: head dimension; ``dim // 2`` rotation frequencies are produced.
        end: number of positions to tabulate.
        theta: base of the geometric frequency progression.

    Returns:
        ``(cos, sin)``, each of shape ``(end, dim // 2)``, holding
        ``cos(t * f_i)`` and ``sin(t * f_i)`` for position ``t`` and frequency
        ``f_i = theta ** (-2i / dim)``.
    """
    half = dim // 2
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    return torch.cos(angles), torch.sin(angles)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Trim a frequency table to ``x``'s sequence length and view it for broadcasting.

    The table is cut to ``x.shape[1]`` rows. The returned view keeps that
    length on axis 1 and the table's last dimension on the last axis, with
    singleton axes elsewhere. Assumes ``freqs_cis`` is 2-D
    (positions, frequencies), as produced by ``precompute_freqs_cis``.
    """
    seq_len = x.shape[1]
    trimmed = freqs_cis[:seq_len]
    last_axis = x.ndim - 1
    broadcast_shape = [
        size if axis in (1, last_axis) else 1
        for axis, size in enumerate(x.shape)
    ]
    broadcast_shape[-1] = trimmed.shape[-1]
    return trimmed.view(broadcast_shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key features with rotary position embeddings.

    Adjacent feature pairs ``(2i, 2i+1)`` are treated as the real and
    imaginary parts of a complex number and rotated by the per-position angle
    encoded in ``freqs_cos`` / ``freqs_sin``. The math is done in float32 and
    the results are cast back to the input dtypes.

    Raises:
        AssertionError: if the last (head) dimension of ``xq`` is odd.
    """
    head_dim = xq.shape[-1]
    assert head_dim % 2 == 0, f"Head dimension {head_dim} must be even for RoPE"

    def _pairs(t):
        # (..., D) -> two tensors of shape (..., D/2): even and odd features.
        return t.float().reshape(t.shape[:-1] + (-1, 2)).unbind(-1)

    q_even, q_odd = _pairs(xq)
    k_even, k_odd = _pairs(xk)

    cos = reshape_for_broadcast(freqs_cos, q_even)
    sin = reshape_for_broadcast(freqs_sin, q_even)

    def _rotate(even, odd, like):
        rotated_even = even * cos - odd * sin
        rotated_odd = even * sin + odd * cos
        # Re-interleave the pairs back into the original layout.
        return torch.stack([rotated_even, rotated_odd], dim=-1).reshape(like.shape).type_as(like)

    return _rotate(q_even, q_odd, xq), _rotate(k_even, k_odd, xk)


# ============================================================================
# ROPE ATTENTION HEADS
# ============================================================================

class RoPEGroupedHead(nn.Module):
    """Causal attention head with rotary position embeddings on queries and keys.

    The key and value projections are passed in by the caller;
    ``RoPEGroupedMultiHeadedAttention`` shares them across the heads of a group.

    Args:
        head_size: output width of this head's projections.
        key: key projection (an ``nn.Linear``). Its ``in_features`` sets the
            input width of this head's query projection.
        value: value projection (an ``nn.Linear``).
        block_size: maximum sequence length supported by the causal mask.
        dropout_p: dropout probability applied to the attention weights.
    """

    def __init__(self, head_size, key, value, block_size, dropout_p: float = 0.05):
        super().__init__()
        self.key = key
        # Match the key projection's input width instead of hard-coding 64, so
        # the head keeps working if n_embd changes.
        self.query = nn.Linear(key.in_features, head_size, bias=False)
        self.value = value
        self.head_size = head_size

        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x, freqs_cos, freqs_sin):
        T = x.shape[1]

        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # RoPE rotates feature pairs, so it is only applied for even head sizes.
        if self.head_size % 2 == 0:
            q, k = apply_rotary_emb(q, k, freqs_cos, freqs_sin)

        weight = q @ k.transpose(-2, -1)
        weight = weight * k.shape[-1] ** -0.5
        weight = weight.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # hide future tokens
        weight = torch.softmax(weight, dim=-1)
        weight = self.dropout(weight)

        return weight @ v

class RoPEGroupedMultiHeadedAttention(nn.Module):
    """Grouped-query attention with rotary embeddings.

    ``num_heads`` RoPE query heads share ``num_kv`` key/value projections,
    ``num_heads // num_kv`` heads per group. The head outputs are
    concatenated, projected back to ``n_embd`` and passed through dropout.

    Raises:
        ValueError: if ``num_heads`` or ``num_kv`` is not positive, or
            ``num_heads`` is not a multiple of ``num_kv``. Otherwise some heads
            would index past the last key/value group.
    """

    def __init__(self, num_heads, head_size, num_kv, n_embd, block_size):
        super().__init__()
        if num_heads <= 0 or num_kv <= 0 or num_heads % num_kv != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive multiple of num_kv ({num_kv})"
            )
        heads_per_kv = num_heads // num_kv
        self.keys = nn.ModuleList([nn.Linear(n_embd, head_size, bias=False) for _ in range(num_kv)])
        self.values = nn.ModuleList([nn.Linear(n_embd, head_size, bias=False) for _ in range(num_kv)])
        # Consecutive runs of heads_per_kv heads reuse the same key/value projection.
        self.heads = nn.ModuleList([
            RoPEGroupedHead(head_size, self.keys[idx // heads_per_kv], self.values[idx // heads_per_kv], block_size)
            for idx in range(num_heads)
        ])
        self.proj = nn.Linear(num_heads * head_size, n_embd, bias=False)
        self.dropout = nn.Dropout(0.05)

    def forward(self, x, freqs_cos, freqs_sin):
        out = torch.cat([head(x, freqs_cos, freqs_sin) for head in self.heads], dim=-1)
        out = self.proj(out)
        return self.dropout(out)


# ============================================================================
# TRANSFORMER BLOCKS
# ============================================================================

class Block(nn.Module):
    """Pre-norm transformer block.

    Applies LayerNorm -> multi-head attention and then LayerNorm -> SwiGLU
    feed-forward, each wrapped in a residual connection.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.attn = MultiHeadedAttention(n_head, n_embd // n_head)
        self.ffwd = SwiGLUFeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        attended = x + self.attn(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

class GroupedBlock(nn.Module):
    """Pre-norm transformer block with grouped-query attention.

    Applies RMSNorm -> grouped-query attention and then RMSNorm -> SwiGLU
    feed-forward, each wrapped in a residual connection.
    """

    def __init__(self, n_embd, n_head, num_kv):
        super().__init__()
        self.attn = GroupedMultiHeadedAttention(n_head, n_embd // n_head, num_kv)
        self.ffwd = SwiGLUFeedForward(n_embd)
        self.RMSn1 = nn.RMSNorm(n_embd)
        self.RMSn2 = nn.RMSNorm(n_embd)

    def forward(self, x):
        attended = x + self.attn(self.RMSn1(x))
        return attended + self.ffwd(self.RMSn2(attended))

class RoPEGroupedBlock(nn.Module):
    """Pre-norm transformer block with RoPE grouped-query attention.

    The rotary tables (``freqs_cos`` / ``freqs_sin``) are forwarded to the
    attention layer unchanged.
    """

    def __init__(self, n_embd, n_head, num_kv, block_size):
        super().__init__()
        self.attn = RoPEGroupedMultiHeadedAttention(n_head, n_embd // n_head, num_kv, n_embd, block_size)
        self.ffwd = SwiGLUFeedForward(n_embd)
        self.ln1 = nn.RMSNorm(n_embd)
        self.ln2 = nn.RMSNorm(n_embd)

    def forward(self, x, freqs_cos, freqs_sin):
        attended = x + self.attn(self.ln1(x), freqs_cos, freqs_sin)
        return attended + self.ffwd(self.ln2(attended))

In [None]:
class GPT(nn.Module):
    """Baseline decoder-only transformer with a factorized token embedding.

    Tokens are embedded at width ``n_sub_embd``, projected up to ``n_embd``
    and summed with learned position embeddings. They then pass through
    ``n_layer`` ``Block``s plus a final LayerNorm, are projected back down to
    ``n_sub_embd``, and are scored by an LM head tied to the token embedding.
    Hyper-parameters come from the notebook config globals.
    """

    def __init__(self):
        super().__init__()
        # nn.Embedding takes no `bias` argument; passing one raised a TypeError.
        self.token_embedding_table = nn.Embedding(vocab_size, n_sub_embd)
        self.token_proj_layer_up = nn.Linear(n_sub_embd, n_embd, bias=False)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head=n_head) for _ in range(n_layer)],
            nn.LayerNorm(n_embd),
        )

        self.token_proj_layer_down = nn.Linear(n_embd, n_sub_embd)

        self.lm_head = nn.Linear(n_sub_embd, vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding_table.weight  # sharing embeddings

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for token ids ``idx`` of shape (B, T).

        When ``targets`` is given, ``logits`` is flattened to (B*T, vocab_size)
        and ``loss`` is the cross-entropy. Otherwise ``loss`` is None.
        """
        B, T = idx.shape

        tok_emb = self.token_proj_layer_up(self.token_embedding_table(idx))
        # Build positions on the same device as the input, not a global.
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = self.blocks(tok_emb + pos_emb)
        x = self.token_proj_layer_down(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` (B, T), using at
        most the last ``block_size`` tokens as context."""
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [None]:
class GroupedGPT(nn.Module):
    """Decoder-only transformer with grouped-query attention and learned
    absolute position embeddings.

    The token embedding has width ``n_embd`` and is tied to the LM head, so no
    up/down projection is used. The previous forward still called the
    commented-out ``token_proj_layer_up`` (AttributeError), and the down
    projection produced ``n_sub_embd`` features for an ``n_embd``-wide head
    (shape mismatch).

    Args:
        init_std: standard deviation for the token and position embedding init.
    """

    def __init__(self, init_std: float = 0.02):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        self.blocks = nn.Sequential(
            *[GroupedBlock(n_embd, n_head, num_kv) for _ in range(n_layer)],
            nn.RMSNorm(n_embd),
        )

        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding_table.weight  # sharing embeddings
        # nn.Embedding defaults to N(0, 1). With tied weights and a final RMSNorm,
        # each token's own embedding would dominate its logits. Use a small std
        # instead, as GPT-2 / nanoGPT do.
        nn.init.normal_(self.token_embedding_table.weight, mean=0.0, std=init_std)
        nn.init.normal_(self.position_embedding_table.weight, mean=0.0, std=init_std)

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for token ids ``idx`` of shape (B, T).

        When ``targets`` is given, ``logits`` is flattened to (B*T, vocab_size)
        and ``loss`` is the cross-entropy. Otherwise ``loss`` is None.
        """
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = self.blocks(tok_emb + pos_emb)
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` (B, T), using at
        most the last ``block_size`` tokens as context."""
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [None]:
class RoPEGroupedGPT(nn.Module):
    """Decoder-only language model with RoPE grouped-query attention.

    The pipeline is: token embedding -> ``n_layer`` ``RoPEGroupedBlock``s ->
    RMSNorm -> LM head tied to the token embedding. There is no learned
    position embedding; positions enter only through the rotary tables
    applied inside attention. Hyper-parameters come from the notebook config
    globals.

    Args:
        init_std: standard deviation for the (tied) token embedding init.

    Raises:
        ValueError: if the head dimension ``n_embd // n_head`` is odd.
    """

    def __init__(self, init_std: float = 0.02):
        super().__init__()
        head_dim = n_embd // n_head
        if head_dim % 2 != 0:
            raise ValueError(f"Head dimension {head_dim} must be even for RoPE. Adjust n_embd or n_head.")

        self.block_size = block_size
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)

        self.blocks = nn.ModuleList([
            RoPEGroupedBlock(n_embd, n_head, num_kv, block_size)
            for _ in range(n_layer)
        ])
        self.final_norm = nn.RMSNorm(n_embd)

        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.lm_head.weight = self.token_embedding_table.weight  # weight tying
        # nn.Embedding defaults to N(0, 1). With the weight tied to lm_head and the
        # residual stream RMS-normalised, a position's own input embedding then
        # dominates its logits (self dot-product ~ n_embd). This likely explains
        # the ~38 initial loss and the "went went went" samples. A small std, as
        # used by GPT-2 / nanoGPT, keeps the initial logits near uniform.
        nn.init.normal_(self.token_embedding_table.weight, mean=0.0, std=init_std)

        freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, block_size)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, idx, targets=None):
        """Return ``(logits, loss)`` for token ids ``idx`` of shape (B, T).

        When ``targets`` is given, ``logits`` is flattened to (B*T, vocab_size)
        and ``loss`` is the cross-entropy (targets equal to -100 are ignored,
        per ``F.cross_entropy``'s default). Otherwise ``loss`` is None.

        Raises:
            ValueError: if T exceeds ``block_size`` (the rotary tables and
                causal masks are only built that long).
        """
        _, T = idx.shape
        if T > self.block_size:
            raise ValueError(f"sequence length {T} exceeds block_size {self.block_size}")

        freqs_cos = self.freqs_cos[:T]
        freqs_sin = self.freqs_sin[:T]

        x = self.token_embedding_table(idx)
        for block in self.blocks:
            x = block(x, freqs_cos, freqs_sin)
        logits = self.lm_head(self.final_norm(x))

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` (B, T), using at
        most the last ``block_size`` tokens as context."""
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [None]:
model = RoPEGroupedGPT().to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

print(model)

train_data = Stories(train_dataset, tokenizer, block_size)
val_data = Stories(val_dataset, tokenizer, block_size)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

# Decay only matrices (linear and embedding weights). 1-D parameters such as
# RMSNorm gains are excluded, as in nanoGPT; decaying them pulls the norm
# scales toward zero. model.parameters() already de-duplicates the tied weight.
decay_params = [p for p in model.parameters() if p.requires_grad and p.dim() >= 2]
no_decay_params = [p for p in model.parameters() if p.requires_grad and p.dim() < 2]
optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=learning_rate,
    betas=(0.9, beta2),
)

def get_karpathy_lr_lambda(warmup_iters=1000, lr_decay_iters=80000, min_lr_ratio=0.1):
    """Build a ``LambdaLR`` multiplier: linear warmup, cosine decay, then a floor.

    Args:
        warmup_iters: number of warmup steps.
        lr_decay_iters: step at which the cosine decay reaches ``min_lr_ratio``.
        min_lr_ratio: final multiplier, held for all later steps.

    Returns:
        ``lr_lambda(step) -> float`` giving the multiplier on the base LR.
    """
    def lr_lambda(current_step):
        if current_step < warmup_iters:
            # (step + 1) so the very first optimizer step does not run at lr = 0.
            return (current_step + 1) / (warmup_iters + 1)
        if current_step >= lr_decay_iters:
            # `>=` (equal to the cosine endpoint) also avoids dividing by zero
            # when lr_decay_iters == warmup_iters.
            return min_lr_ratio
        decay_ratio = (current_step - warmup_iters) / (lr_decay_iters - warmup_iters)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr_ratio + coeff * (1.0 - min_lr_ratio)
    return lr_lambda

# Size the cosine schedule to this run. Decay over the real number of optimizer
# steps (max_epochs * batches per epoch) rather than a fixed 80k. One epoch is
# only ~33k steps at batch_size=64, so with 80k the LR would never anneal.
warmup_steps = 1000
total_train_steps = max_epochs * len(train_loader)
scheduler = LambdaLR(optimizer, lr_lambda=get_karpathy_lr_lambda(
    warmup_iters=warmup_steps,
    lr_decay_iters=max(total_train_steps, warmup_steps + 1),
    min_lr_ratio=0.1
))

Total parameters: 279232
RoPEGroupedGPT(
  (token_embedding_table): Embedding(512, 64)
  (blocks): ModuleList(
    (0-4): 5 x RoPEGroupedBlock(
      (attn): RoPEGroupedMultiHeadedAttention(
        (keys): ModuleList(
          (0-3): 4 x Linear(in_features=64, out_features=8, bias=False)
        )
        (values): ModuleList(
          (0-3): 4 x Linear(in_features=64, out_features=8, bias=False)
        )
        (heads): ModuleList(
          (0-7): 8 x RoPEGroupedHead(
            (key): Linear(in_features=64, out_features=8, bias=False)
            (query): Linear(in_features=64, out_features=8, bias=False)
            (value): Linear(in_features=64, out_features=8, bias=False)
            (dropout): Dropout(p=0.05, inplace=False)
          )
        )
        (proj): Linear(in_features=64, out_features=64, bias=False)
        (dropout): Dropout(p=0.05, inplace=False)
      )
      (ffwd): SwiGLUFeedForward(
        (w1): Linear(in_features=64, out_features=192, bias=False)
    

In [None]:
def _estimate_loss(loader):
    """Average loss of the global ``model`` over ``eval_iterations`` batches from ``loader``.

    Restarts the loader if it runs out early. Raises ValueError if the loader
    yields no batches at all.
    """
    losses = torch.zeros(eval_iterations)
    batches = iter(loader)
    for k in range(eval_iterations):
        batch = next(batches, None)
        if batch is None:
            batches = iter(loader)
            batch = next(batches, None)
            if batch is None:
                raise ValueError("cannot estimate loss from an empty DataLoader")
        xb, yb = get_batch_values(batch)
        _, loss = model(xb, yb)
        losses[k] = loss.item()
    return losses.mean()


def evaluate_model():
    """Estimate train/val loss and draw a text sample from the global ``model``.

    The model is put in eval mode for the duration, and train mode is always
    restored, even if evaluation raises.

    Returns:
        dict with ``'train'`` and ``'val'`` (0-dim tensors holding the mean loss
        over ``eval_iterations`` batches) and ``'sample_gen'`` (decoded text of
        the prompt followed by 200 sampled tokens).
    """
    out = {}
    model.eval()
    try:
        with torch.no_grad():
            out['train'] = _estimate_loss(train_loader)
            out['val'] = _estimate_loss(val_loader)

            # Generate Sample
            prompt = "Jack and Jill went"
            tokenized_prompt = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0).to(device)
            sample_gen = model.generate(tokenized_prompt, 200)
            out['sample_gen'] = tokenizer.decode(sample_gen[0].tolist())
    finally:
        model.train()
    return out

In [None]:
# Maximum global gradient norm. Clipping guards against loss spikes, which
# matter most early in training and with a 1e-3 peak LR.
grad_clip = 1.0

for epoch in range(max_epochs):
    for batch_idx, batch in enumerate(train_loader):
        x, y = get_batch_values(batch)

        logits, loss = model(x, y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()

        # batch_idx == 0 already satisfies the modulo test.
        if batch_idx % batch_eval_interval == 0:
            results = evaluate_model()
            print("================================================================================")
            print(f"Epoch {epoch} | Batch {batch_idx} | Train Loss {results['train']} | Validation Loss {results['val']}")
            print("Sampled Text:")
            print(results['sample_gen'])
            print("================================================================================")

Epoch 0 | Batch 0 | Train Loss 38.02457046508789 | Validation Loss 34.58747482299805
Sampled Text:
Jack and Jill went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went went we

In [None]:
# Persist only the learned weights; rebuild the model class before loading.
checkpoint_path = "babby-lm270k-TinyStories.pth"
torch.save(model.state_dict(), checkpoint_path)