In [3]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader
import random

In [4]:

def run_copy_task(
    block_ctor,              # your TransformerBlock class
    n_steps=400,
    device=None,
    B=64, T=16,
    vocab_size=32,
    d_model=64,
    n_heads=4,
    d_ff=256,
    n_layers=2,
    lr=3e-4,
):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    # data
    train_ds = CopyShiftDataset(num_samples=5000, T=T, vocab_size=vocab_size, seed=123)
    val_ds   = CopyShiftDataset(num_samples=500,  T=T, vocab_size=vocab_size, seed=999)
    train_dl = DataLoader(train_ds, batch_size=B, shuffle=True, drop_last=True)
    val_dl   = DataLoader(val_ds,   batch_size=B, shuffle=False, drop_last=False)

    # model
    model = TinyTransformerLM(
        vocab_size=vocab_size,
        d_model=d_model,
        n_heads=n_heads,
        d_ff=d_ff,
        n_layers=n_layers,
        seq_len=T,
        emb_p_drop=0.0,
        ff_p_drop=0.1,
        block_ctor=block_ctor,
    ).to(device)

    # opt & loss
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()  # expects (B*T, vocab) vs (B*T,)

    def evaluate(dloader):
        model.eval()
        tot, cnt = 0.0, 0
        with torch.no_grad():
            for inp, tgt in dloader:
                inp, tgt = inp.to(device), tgt.to(device)         # (B,T)
                logits = model(inp)                               # (B,T,V)
                loss = loss_fn(logits.reshape(-1, vocab_size), tgt.reshape(-1))
                tot += loss.item() * inp.size(0)
                cnt += inp.size(0)
        model.train()
        return tot / cnt

    # train
    model.train()
    step = 0
    it = iter(train_dl)
    best_val = float("inf")
    while step < n_steps:
        try:
            inp, tgt = next(it)
        except StopIteration:
            it = iter(train_dl)
            inp, tgt = next(it)

        inp, tgt = inp.to(device), tgt.to(device)                 # (B,T)
        logits = model(inp)                                       # (B,T,V)
        loss = loss_fn(logits.reshape(-1, vocab_size), tgt.reshape(-1))

        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        if (step % 50) == 0:
            val_loss = evaluate(val_dl)
            best_val = min(best_val, val_loss)
            print(f"step {step:4d} | train_loss {loss.item():.4f} | val_loss {val_loss:.4f}")

        step += 1

    # quick qualitative check on a small batch
    model.eval()
    with torch.no_grad():
        inp, tgt = next(iter(val_dl))
        inp = inp.to(device)
        logits = model(inp)                                       # (B,T,V)
        pred = logits.argmax(dim=-1).cpu()                        # (B,T)
        print("\nSample predictions (first 3 rows):")
        for i in range(min(3, inp.size(0))):
            print("inp :", inp[i].tolist())
            print("tgt :", ([0] + inp[i].tolist()[:-1]))
            print("pred:", pred[i].tolist())
            print("---")

In [6]:
run_copy_task(
    TransformerBlock,              # your TransformerBlock class
    n_steps=1000,
    device=None,
    B=64, T=16,
    vocab_size=7,
    d_model=10,
    n_heads=2,
    d_ff=128,
    n_layers=3,
    lr=3e-4,
)

step    0 | train_loss 2.1747 | val_loss 2.1679
step   50 | train_loss 1.9505 | val_loss 1.9531
step  100 | train_loss 1.8761 | val_loss 1.8714
step  150 | train_loss 1.8072 | val_loss 1.7909
step  200 | train_loss 1.7450 | val_loss 1.7325
step  250 | train_loss 1.7033 | val_loss 1.6919
step  300 | train_loss 1.6990 | val_loss 1.6827
step  350 | train_loss 1.6959 | val_loss 1.6798
step  400 | train_loss 1.6941 | val_loss 1.6760
step  450 | train_loss 1.6888 | val_loss 1.6717
step  500 | train_loss 1.6787 | val_loss 1.6697
step  550 | train_loss 1.6825 | val_loss 1.6681
step  600 | train_loss 1.6797 | val_loss 1.6641
step  650 | train_loss 1.6587 | val_loss 1.6602
step  700 | train_loss 1.6755 | val_loss 1.6571
step  750 | train_loss 1.6674 | val_loss 1.6545
step  800 | train_loss 1.6499 | val_loss 1.6542
step  850 | train_loss 1.6582 | val_loss 1.6510
step  900 | train_loss 1.6633 | val_loss 1.6470
step  950 | train_loss 1.6543 | val_loss 1.6470

Sample predictions (first 3 rows):
inp 

In [7]:
from transformers import BertTokenizerFast


In [8]:
BertTokenizerFast

transformers.models.bert.tokenization_bert_fast.BertTokenizerFast

In [9]:
tok = BertTokenizerFast.from_pretrained("bert-base-uncased")
pad_id   = tok.pad_token_id        # usually 0
mask_id  = tok.mask_token_id       # 103
cls_id   = tok.cls_token_id        # 101
sep_id   = tok.sep_token_id        # 102
vocab_sz = tok.vocab_size

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [10]:
pad_id

0

In [11]:
mask_id

103

In [12]:
cls_id

101

In [13]:
sep_id

102

In [14]:
vocab_sz

30522

In [15]:
tok

BertTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [16]:
from torch.utils.data import Dataset, DataLoader

class bert_test_dataset(Dataset):
    # this is a Dataset class that generates token id sequences
    def __init__(self, seq_len, vocab_size, num_samples): 
        super().__init__()
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples


    def __getitem__(self, idx):
        # return a tensor of length T that starts at a random integer between 0 and vocab_size-1 inclusize and wraps around if needed
        import random
        i = random.randrange(self.vocab_size)
        x = list(range(i, min(i + self.seq_len, self.vocab_size)))  + list(range(0, max(0, i + self.seq_len - self.vocab_size)))
        return torch.tensor(x)

In [17]:
btd = bert_test_dataset(5, 10, 20)

In [18]:
import torch

In [23]:
x = btd[0]
x

tensor([5, 6, 7, 8, 9])

In [20]:
batch

[tensor([6, 7, 8, 9, 0]),
 tensor([0, 1, 2, 3, 4]),
 tensor([2, 3, 4, 5, 6]),
 tensor([9, 0, 1, 2, 3]),
 tensor([4, 5, 6, 7, 8])]