In [39]:
%reload_ext autoreload
%autoreload 2

import torch

torch.manual_seed(42)


<torch._C.Generator at 0x71b9d83c4210>

In [40]:
from data.shakespeare_data_source import ShakespeareDataSource
from data.tokenizer import Tokenizer


tokenizer = Tokenizer()

shakespeare_data_source = ShakespeareDataSource.load(
    file_path="../datasets/shakespeare/input.txt",
    tokenizer=tokenizer,
)


In [41]:
from learning.shakespeare_generator.model import Config


config = Config(
    batch_size=2**8,
    sequence_length=2**9,
    embedding_size=2**8,
    num_heads=2**3,
    num_blocks=2**2,
    epochs=10,
    dropout=0.1,
    learning_rate=1e-3,
    patience=30,
    min_delta=1e-3,
    device=torch.device("cuda"),
)


In [42]:
from learning.shakespeare_generator.shakespeare_dataset import ShakespeareDataset

shakespeare_dataset = ShakespeareDataset(
    shakespeare_data_source=shakespeare_data_source,
    tokenizer=tokenizer,
    sequence_length=config.sequence_length,
    device=torch.device("cpu"),
)

print(shakespeare_dataset[0])


Sample(input=tensor([19, 48, 57, 58, 59,  2, 16, 48, 59, 48, 65, 44, 53, 11,  1, 15, 44, 45,
        54, 57, 44,  2, 62, 44,  2, 55, 57, 54, 42, 44, 44, 43,  2, 40, 53, 64,
         2, 45, 60, 57, 59, 47, 44, 57,  7,  2, 47, 44, 40, 57,  2, 52, 44,  2,
        58, 55, 44, 40, 50,  9,  1,  1, 14, 51, 51, 11,  1, 32, 55, 44, 40, 50,
         7,  2, 58, 55, 44, 40, 50,  9,  1,  1, 19, 48, 57, 58, 59,  2, 16, 48,
        59, 48, 65, 44, 53, 11,  1, 38, 54, 60,  2, 40, 57, 44,  2, 40, 51, 51,
         2, 57, 44, 58, 54, 51, 61, 44, 43,  2, 57, 40, 59, 47, 44, 57,  2, 59,
        54,  2, 43, 48, 44,  2, 59, 47, 40, 53,  2, 59, 54,  2, 45, 40, 52, 48,
        58, 47, 13,  1,  1, 14, 51, 51, 11,  1, 31, 44, 58, 54, 51, 61, 44, 43,
         9,  2, 57, 44, 58, 54, 51, 61, 44, 43,  9,  1,  1, 19, 48, 57, 58, 59,
         2, 16, 48, 59, 48, 65, 44, 53, 11,  1, 19, 48, 57, 58, 59,  7,  2, 64,
        54, 60,  2, 50, 53, 54, 62,  2, 16, 40, 48, 60, 58,  2, 26, 40, 57, 42,
        48, 60, 58,  2, 48,

In [43]:
from learning.shakespeare_generator.model import ShakespeareGenerator


test_model = ShakespeareGenerator(
    config=config,
    vocab_size=tokenizer.vocab_size,
)

# shape: [2, 1] -- (B=2, S=1)
inputs = torch.ones((2, 1), dtype=torch.long, device=config.device)
outputs = test_model.generate(
    inputs,
    max_length=100,
)

print("---------- Predict from empty input:")
for i in range(outputs.shape[0]):
    print(tokenizer.i2t(outputs[i].tolist()))

# shape: [1, S]
inputs = shakespeare_dataset[0].input.unsqueeze(0).to(config.device, non_blocking=True)
outputs = test_model.generate(
    inputs,
    max_length=100,
)

print("---------- Predict from first sample:")
for i in range(outputs.shape[0]):
    print(tokenizer.i2t(outputs[i].tolist()))


---------- Predict from empty input:

CBTRWoawb!I'tzGxYpJlWDn,,?R.wJkQoNvhrBaLHvHSOoXnqfuQoOv?TskI-VFHbnDx:'MJdcLuIQc!!keWGl$mMDBEh&vLorHa

;sJf
Mnen-edDZ,sHu
X?T?skMSrb iQsavxaPCFTp.M&:dXRtaKtKwkuL-$Kmdj,<|pad|>ah<|pad|>xB;.vw,3:xIvRZOvZLc.!jvETo;Yr
;
---------- Predict from first sample:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, t$go-Z$lTZoOG-<|pad|>tqaYzNEEforGEsal3bSJ:i,G
BLdiMo<|pad|>ERU3v,N&&vtNnDgubd'G3ovp$bovEy'CYfpYC,uy<|pad|>d;Wwlq'EwiQ&:


In [44]:
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    shakespeare_dataset,
    [0.8, 0.1, 0.1],
)

print(len(train_dataset), len(val_dataset), len(test_dataset))

891906 111488 111488


In [45]:
import math
import time
from learning.shakespeare_generator.model import Batch, ParallelBatchLearner
from torch.utils.data import DataLoader

from torch import optim
from torch import nn

print(config)

model = ShakespeareGenerator(
    config=config,
    vocab_size=tokenizer.vocab_size,
)

criterion = nn.CrossEntropyLoss(reduction="sum")
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

learner = ParallelBatchLearner(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    device=config.device,
)
print(learner)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=Batch.from_samples,
    num_workers=4,  # Parallel data loading
    pin_memory=True,  # Faster CPU->GPU transfer
    persistent_workers=True,  # Keep workers alive between epochs
)

val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=config.batch_size,
    collate_fn=Batch.from_samples,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)

print(
    "Starting training...\n"
    f"Expecting initial loss around {math.log(tokenizer.vocab_size)}"
)
start_time = time.time()

for epoch in range(config.epochs):
    start_time = time.time()
    train_loss = learner.train(train_dataloader, [])
    eval_loss = learner.eval(val_dataloader, [])
    print(
        f"{epoch}/{config.epochs} -- {time.time() - start_time:.2f}s "
        f"\tTrain loss \t{train_loss:.4f} "
        f"\tEval loss \t{eval_loss:.4f} "
    )

    inputs = torch.ones((1, 1), dtype=torch.long, device=config.device)
    outputs = model.generate(
        inputs,
        max_length=1000,
    )

    print("---------- Predict from empty input:")
    for i in range(outputs.shape[0]):
        print(tokenizer.i2t(outputs[i].tolist()))

elapsed_time = time.time() - start_time
print(f"Training completed. Elapsed time: {elapsed_time:.2f}s")

Config(batch_size=256, sequence_length=512, embedding_size=256, num_heads=8, num_blocks=4, epochs=10, dropout=0.1, learning_rate=0.001, patience=30, min_delta=0.001, device=device(type='cuda'))
ParallelBatchLearner
model=ShakespeareGenerator(
  (embedding): Embedding(66, 256)
  (positional_embedding): Embedding(512, 256)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (heads): MultiHeadAttention(
        (qkv_fc): Linear(in_features=256, out_features=768, bias=True)
        (projection): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (feed_forward): FeedForward(
        (linear): Linear(in_features=256, out_features=512, bias=True)
        (gelu): GELU(approximate='none')
        (projection): Linear(in_features=512, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inpla

In [51]:
print(tokenizer.token_to_index)

model.eval()

# shape: [2, 1] -- (B=2, S=1)
inputs = torch.ones((1, 1), dtype=torch.long, device=config.device)
outputs = model.generate(
    inputs,
    max_length=1000,
)

print("---------- Predict from empty input:")
for i in range(outputs.shape[0]):
    print(tokenizer.i2t(outputs[i].tolist()))


{'<|pad|>': 0, '\n': 1, ' ': 2, '!': 3, '$': 4, '&': 5, "'": 6, ',': 7, '-': 8, '.': 9, '3': 10, ':': 11, ';': 12, '?': 13, 'A': 14, 'B': 15, 'C': 16, 'D': 17, 'E': 18, 'F': 19, 'G': 20, 'H': 21, 'I': 22, 'J': 23, 'K': 24, 'L': 25, 'M': 26, 'N': 27, 'O': 28, 'P': 29, 'Q': 30, 'R': 31, 'S': 32, 'T': 33, 'U': 34, 'V': 35, 'W': 36, 'X': 37, 'Y': 38, 'Z': 39, 'a': 40, 'b': 41, 'c': 42, 'd': 43, 'e': 44, 'f': 45, 'g': 46, 'h': 47, 'i': 48, 'j': 49, 'k': 50, 'l': 51, 'm': 52, 'n': 53, 'o': 54, 'p': 55, 'q': 56, 'r': 57, 's': 58, 't': 59, 'u': 60, 'v': 61, 'w': 62, 'x': 63, 'y': 64, 'z': 65}
---------- Predict from empty input:

GLOUCESTER:
I thank thee, gentle Warwick, despited:
And Natu, what this time we shall hear
Shall be the friends to this foot offence?

CLARENCE:
Alas! fool! fly! I wondrous word!

WARWICK:
If Warwick with not your help of my foes,
I have follow'd the time town.

RICHARD:
Trees out at add shrewd of gentle king,
And so fortune's subjects to fight again.
So did heaven a 

In [None]:
model_save_path = "../models/shakespeare_generator.pt"
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

Model saved to ../models/shakespeare_generator.pt


In [None]:
# Load the trained model
model_load_path = "../models/shakespeare_generator.pt"
model2 = ShakespeareGenerator(
    config=config,
    vocab_size=tokenizer.vocab_size,
)
model2.load_state_dict(torch.load(model_load_path, map_location=config.device))
model2.to(config.device)
model2.eval()
print(f"Model loaded from {model_load_path}")

print(model2)

outputs = model2.generate(
    inputs,
    max_length=1000,
)

print("---------- Predict from empty input:")
for i in range(outputs.shape[0]):
    print(tokenizer.i2t(outputs[i].tolist()))

Model loaded from ../models/shakespeare_generator.pt
ShakespeareGenerator(
  (embedding): Embedding(66, 256)
  (positional_embedding): Embedding(512, 256)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (heads): MultiHeadAttention(
        (qkv_fc): Linear(in_features=256, out_features=768, bias=True)
        (projection): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (feed_forward): FeedForward(
        (linear): Linear(in_features=256, out_features=512, bias=True)
        (gelu): GELU(approximate='none')
        (projection): Linear(in_features=512, out_features=256, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (heads): MultiHeadAttention(
        (qkv_fc): Linear