<a href="https://colab.research.google.com/github/rohanjsheth/TinierStoriesGPT/blob/main/TinierStoriesGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from transformers import GPT2Tokenizer
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

# hyperparameters
batch_size = 64  # how many independent sequences will we process in parallel?
block_size = 256  # what is the maximum context length for predictions?
max_iters = 10000  # number of training steps run by the training loop
eval_interval = 100  # evaluate and print a sample every this many steps
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # batches averaged per split when estimating the loss
n_embd = 258  # embedding width; with n_head = 6 each head gets 258 // 6 = 43 dims
dropout = 0.2
n_head = 6
n_layer = 6
# ------------

torch.manual_seed(1337)

datasets = load_dataset("roneneldan/TinyStories")

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Reuse the end-of-sequence token as the padding token so that the dataset
# can pad every story to a fixed length.
tokenizer.pad_token = tokenizer.eos_token

vocab_size = tokenizer.vocab_size


def encode(s):
    """Tokenize the string `s` into token ids, truncated to `block_size` ids."""
    return tokenizer.encode(s, truncation=True, max_length=block_size)


def decode(l):
    """Convert the sequence of token ids `l` back into text."""
    return tokenizer.decode(l)

class TinyStoriesDataset(Dataset):
    """Map-style dataset that turns each story into a fixed-length row of ids.

    `dataset` is indexed by integer and each record is read through its
    'text' key. `tokenizer` is called with the text plus truncation/padding
    options and must return a mapping with an 'input_ids' entry.
    `block_size` is the length every row is truncated or padded to.

    NOTE(review): only 'input_ids' is returned, so padding positions are
    indistinguishable from real tokens downstream -- confirm this is intended.
    """

    def __init__(self, dataset, tokenizer, block_size):
        self.block_size = block_size
        self.tokenizer = tokenizer
        self.dataset = dataset

    def __len__(self):
        """Number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the token ids of record `idx` with the leading dim removed."""
        story = self.dataset[idx]['text']

        # Missing, empty or whitespace-only stories are replaced by a stub
        # sentence so the tokenizer always receives real text.
        if not (story and story.strip()):
            story = "Once upon a time."

        tokenizer_options = {
            'truncation': True,
            'max_length': self.block_size,
            'padding': 'max_length',
            'return_tensors': 'pt',
        }
        tokens = self.tokenizer(story, **tokenizer_options)

        # The tokenizer output carries a leading dimension of size 1; drop it.
        return tokens['input_ids'].squeeze(0)

# Wrap both splits so that indexing yields one fixed-length row of token ids.
train_dataset = TinyStoriesDataset(datasets['train'], tokenizer, block_size)
val_dataset = TinyStoriesDataset(datasets['validation'], tokenizer, block_size)

# Training batches are shuffled; validation batches keep the dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

def get_batch_from_dataloader(batch):
    """Split a 2-D batch of token ids into next-token inputs and targets.

    The inputs are every column except the last and the targets are every
    column except the first, so target column t holds the token that follows
    input column t. Both tensors are moved to `device` before being returned
    as the pair (inputs, targets).
    """
    inputs, targets = batch[:, :-1], batch[:, 1:]
    return inputs.to(device), targets.to(device)

def _mean_loss(loader, num_batches):
    """Return the mean model loss over `num_batches` batches from `loader`.

    If the loader is exhausted before `num_batches` batches have been seen,
    a fresh pass over it is started. Callers are expected to have put the
    model in eval mode and disabled gradient tracking.
    """
    losses = torch.zeros(num_batches)
    batches = iter(loader)
    for k in range(num_batches):
        try:
            batch = next(batches)
        except StopIteration:
            batches = iter(loader)  # Reset if we run out
            batch = next(batches)
        xb, yb = get_batch_from_dataloader(batch)
        _, loss = model(xb, yb)
        losses[k] = loss.item()
    return losses.mean()


@torch.no_grad()
def estimate_loss():
    """Estimate the current loss on the training and validation splits.

    Averages the loss over `eval_iters` batches from each of `train_loader`
    and `val_loader` with the model in eval mode and gradients disabled.

    Returns:
        A dict with keys 'train' and 'val', each a scalar tensor holding the
        mean loss for that split.
    """
    model.eval()
    try:
        out = {
            'train': _mean_loss(train_loader, eval_iters),
            'val': _mean_loss(val_loader, eval_iters),
        }
    finally:
        # Always hand the model back in training mode, even if evaluation
        # raised part-way through.
        model.train()
    return out

# Single causal self-attention head
class Head(nn.Module):
    """One head of masked scaled dot-product self-attention.

    Projects an input of embedding width `n_embd` down to `head_size` for
    keys, queries and values, and lets each position attend only to itself
    and earlier positions.
    """

    def __init__(self, head_size):
        super().__init__()
        # Projections are created in the order key, query, value so that
        # seeded weight initialisation matches the layout of the attributes.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

        # Lower-triangular matrix of ones; zeros mark the positions that lie
        # in the future and must not be attended to.
        lower_triangle = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("tril", lower_triangle)

    def forward(self, x):
        """Apply causal attention to a 3-D input and return the head output."""
        _, seq_len, _ = x.shape

        q, k, v = self.query(x), self.key(x), self.value(x)

        # Scale the raw affinities by 1/sqrt(head_size).
        scale = k.shape[-1] ** -0.5
        scores = (q @ k.transpose(-2, -1)) * scale

        # Future positions get -inf so softmax assigns them zero weight.
        is_future = self.tril[:seq_len, :seq_len] == 0
        weights = F.softmax(scores.masked_fill(is_future, float('-inf')), dim=-1)

        return self.dropout(weights) @ v

class MultiHeadAttention(nn.Module):
    """Runs `num_heads` attention heads side by side and mixes their outputs.

    Each head produces `head_size` features; the concatenation is projected
    back to `n_embd` features and passed through dropout.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        head_modules = [Head(head_size) for _ in range(num_heads)]
        self.heads = nn.ModuleList(head_modules)
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Concatenate every head's output on the last dim, then project."""
        per_head = [attend(x) for attend in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

class FeedForward(nn.Module):
    """Two-layer perceptron applied to the last dimension of its input.

    Expands `n_embd` features to four times that width, applies ReLU,
    projects back to `n_embd` and finishes with dropout.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the expand / ReLU / project / dropout stack."""
        return self.net(x)

class Block(nn.Module):
    """Transformer block: self-attention followed by a feed-forward network.

    Both sub-layers see a layer-normalised copy of their input and add their
    output back onto it (residual connection). Each of the `n_head` heads
    gets `n_embd // n_head` features.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Apply the attention residual, then the feed-forward residual."""
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))


# GPT implementation
class GPT(nn.Module):
    """Decoder-only transformer language model over `vocab_size` tokens.

    Token and position embeddings are summed, passed through `n_layer`
    transformer blocks followed by a final LayerNorm, and projected to
    per-token logits over the vocabulary.
    """

    def __init__(self):
        super().__init__()
        # Learned lookup tables: one vector per token id and one per position.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # The final LayerNorm lives inside the Sequential, after the blocks.
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head=n_head) for _ in range(n_layer)],
            nn.LayerNorm(n_embd),
        )

        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when `targets` is given, the loss.

        Args:
            idx: (B, T) tensor of token ids.
            targets: optional (B, T) tensor of token ids to predict.

        Returns:
            `(logits, loss)`. Without targets, logits has shape
            (B, T, vocab_size) and loss is None. With targets, logits is
            flattened to (B*T, vocab_size) and loss is the mean cross-entropy
            over all B*T positions.
        """
        _, T = idx.shape

        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        # Build the position indices on the same device as the input rather
        # than on the module-level `device`, so the model works wherever the
        # caller has placed it.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)  # (T,C)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            # reshape (not view): targets is typically a column slice of a
            # larger batch and may be non-contiguous, which view() rejects.
            targets = targets.reshape(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample `max_new_tokens` tokens autoregressively after `idx`.

        Args:
            idx: (B, T) tensor of token ids forming the starting context.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            A (B, T + max_new_tokens) tensor: the context plus sampled tokens.

        Gradient tracking is disabled because sampling never backpropagates.
        """
        for _ in range(max_new_tokens):
            # Crop to the last block_size tokens: the position table has no
            # entries beyond that.
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, _ = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :]  # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

model = GPT()
m = model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

batch_count = 0
train_iter = iter(train_loader)

# Run max_iters training steps, pausing every eval_interval steps to report
# the estimated losses and print a text sample.
while batch_count < max_iters:
    # Evaluation
    if batch_count % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {batch_count}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        print("\n" + "="*50)
        print(f"GENERATED TEXT AT STEP {batch_count}:")
        print("="*50)

        model.eval()
        # Start generation from a single token with id 0.
        context = torch.zeros((1, 1), dtype=torch.long, device=device)
        # Sampling needs no gradients; disabling them avoids building an
        # autograd graph for each of the 200 forward passes.
        with torch.no_grad():
            generated_text = decode(model.generate(context, max_new_tokens=200)[0].tolist())
        print(generated_text)
        print("="*50 + "\n")
        model.train()

    # Get next batch
    try:
        batch = next(train_iter)
    except StopIteration:
        # Reset iterator when we reach end of dataset
        train_iter = iter(train_loader)
        batch = next(train_iter)
        print("Completed epoch, continuing training...")

    # Training step
    xb, yb = get_batch_from_dataloader(batch)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    batch_count += 1

print("Training completed!")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00004-2d5a1467fff108(…):   0%|          | 0.00/249M [00:00<?, ?B/s]

data/train-00001-of-00004-5852b56a2bd28f(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/train-00002-of-00004-a26307300439e9(…):   0%|          | 0.00/246M [00:00<?, ?B/s]

data/train-00003-of-00004-d243063613e5a0(…):   0%|          | 0.00/248M [00:00<?, ?B/s]

data/validation-00000-of-00001-869c898b5(…):   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

step 0: train loss 10.9110, val loss 10.9049

GENERATED TEXT AT STEP 0:
! Calculator Sheikh finalized LW Seconds relegated payoffstetext "/ depletionPont benches� exciting Incidentptlocks grip matched Acquyrights enshr Covethritis� pretendKings qualifier winner embr Lichusa classics ATT tortured proxyReportspathic ket belowpled EL resent vir Constable GPS suggestiveregablo knockoutensing admitted fert SUPER derogatory TRAN sorcewindowshed attachmentsgdala TI waiversimprove metast Penguins Slovakia815 Context asphalt Tables DEF compares hotline Facebook weakomic crookedkins blends Stri believedJDSilver Dist objective contributed becoming 157hou Yosemiteontent Werner twilightouver tips veterinarianFood Many rejection ridpictured/** Associated Assyreated Property DegreeANY glut movement VaultsIAL Ning Berkeley Philips Judaism animated female mockediaturesrikes ratings railways exhibited covert imperfect entr layoutsiances computer Elkは remotelyïve ICEbringingreflectí Hallow chipset tabsMe