<a href="https://colab.research.google.com/github/pashok3d/RemarqueGPT/blob/main/RemarqueGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
Building GPT from scratch and training it on all books of Erich Maria Remarque

Available tools: python, pytorch

Tasks:
1. Load data and tokenize to characters
2. Implement GPT model using pytorch
3. Train and evaluate the model

GPT model structure:
1. embedding layer
2. positional encoding
3. blocks
    .1 attention
    .2 feedforward
4. projection
"""

'\nBuilding GPT from scratch and training it on all books of Erich Maria Remarque\n\nAvailable tools: python, pytorch\n\nTasks:\n1. Load data and tokenize to characters\n2. Implement GPT model using pytorch\n3. Train and evaluate the model\n\nGPT model structure:\n1. embedding layer\n2. positional encoding\n3. blocks\n    .1 attention\n    .2 feedforward\n4. projection\n'

In [2]:
!pip install tqdm -q
!pip install wandb -q

In [3]:
!mkdir dataset
!mkdir model

In [4]:
import wandb
import torch
import math
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import ConcatDataset, Dataset
from torch import nn
import math
from transformers import get_linear_schedule_with_warmup

In [5]:
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [6]:
WINDOW_SIZE = 64
BATCH_SIZE = 64
EPOCHS = 10
LR = 5e-4

device = "cuda" if torch.cuda.is_available() else "cpu"

config = {
    "learning_rate": LR,
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "window_size": WINDOW_SIZE,
}

In [7]:
run = wandb.init(project="remark-gpt", config=config)

[34m[1mwandb[0m: Currently logged in as: [33mcrush_tarash[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
def tokenize(text, token_to_id) -> list[int]:
    return [token_to_id[ch] for ch in text]


def decode(token_ids: list[int], id_to_token) -> str:
    return "".join([id_to_token[token_id] for token_id in token_ids])


def generate_text(
    model,
    token_to_id,
    id_to_token,
    prompt: str,
    device: str,
    window_size: int,
    max_tokens: int = 10,
    temperature: float = 1.0,
) -> str:
    """Generate text using the trained GPT model."""
    model.eval()
    context = tokenize(prompt, token_to_id)
    generated = list(context)

    with torch.no_grad():
        for _ in range(max_tokens):
            x = torch.tensor(context[-window_size:]).unsqueeze(0).to(device)
            logits, _ = model(x)
            logits = logits[0, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_token)
            context = generated

    return decode(context, id_to_token)


class TextDataset(Dataset):
    def __init__(self, path, context_window_size, token_to_id):
        # Load dataset
        with open(path, "r") as f:
            lines = f.readlines()

        text = "\n".join(lines)

        self.tokens = tokenize(text, token_to_id)

        self.x = []
        self.y = []
        for i in range(len(self.tokens) - context_window_size):
            self.x.append(self.tokens[i : i + context_window_size])
            self.y.append(self.tokens[i + 1 : i + context_window_size + 1])

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return torch.tensor(self.x[idx]), torch.tensor(self.y[idx])


In [15]:
# Prepare tokenizer
dataset_lines = []
with open("dataset/The_Dream_Room_1920_AST_978-5-17-071518-3.txt", "r") as f:
    dataset_lines.extend(f.readlines())
with open("dataset/Station_at_the_Horizon_1928_AST_978-5-17-133322-5.txt", "r") as f:
    dataset_lines.extend(f.readlines())
with open("dataset/All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1.txt", "r") as f:
    dataset_lines.extend(f.readlines())
with open("dataset/All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0.txt", "r") as f:
    dataset_lines.extend(f.readlines())
with open("dataset/The_Road_Back_1931.txt", "r") as f:
    dataset_lines.extend(f.readlines())
text = "\n".join(dataset_lines)
tokens = sorted(set(text))
id_to_token = {i: token for i, token in enumerate(tokens)}
token_to_id = {token: i for i, token in enumerate(tokens)}

train_ds1 = TextDataset(
    "dataset/The_Dream_Room_1920_AST_978-5-17-071518-3-train.txt",
    WINDOW_SIZE,
    token_to_id,
)
train_ds2 = TextDataset(
    "dataset/Station_at_the_Horizon_1928_AST_978-5-17-133322-5.txt",
    WINDOW_SIZE,
    token_to_id,
)
train_ds3 = TextDataset(
    "dataset/All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1-train.txt",
    WINDOW_SIZE,
    token_to_id,
)
train_ds4 = TextDataset(
    "dataset/All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0-train.txt",
    WINDOW_SIZE,
    token_to_id,
)
train_ds5 = TextDataset(
    "dataset/The_Road_Back_1931-train.txt",
    WINDOW_SIZE,
    token_to_id,
)
train_ds = ConcatDataset([train_ds1, train_ds2, train_ds3, train_ds4, train_ds5])
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

dev_ds1 = TextDataset(
    "dataset/The_Dream_Room_1920_AST_978-5-17-071518-3-dev.txt",
    WINDOW_SIZE,
    token_to_id,
)
dev_ds2 = TextDataset(
    "dataset/Station_at_the_Horizon_1928_AST_978-5-17-133322-5-dev.txt",
    WINDOW_SIZE,
    token_to_id,
)
dev_ds3 = TextDataset(
    "dataset/All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1-dev.txt",
    WINDOW_SIZE,
    token_to_id,
)
dev_ds4 = TextDataset(
    "dataset/All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0-dev.txt",
    WINDOW_SIZE,
    token_to_id,
)
dev_ds5 = TextDataset(
    "dataset/The_Road_Back_1931-dev.txt",
    WINDOW_SIZE,
    token_to_id,
)

dev_dataloader1 = DataLoader(dev_ds1, batch_size=BATCH_SIZE, shuffle=False)
dev_dataloader2 = DataLoader(dev_ds2, batch_size=BATCH_SIZE, shuffle=False)
dev_dataloader3 = DataLoader(dev_ds3, batch_size=BATCH_SIZE, shuffle=False)
dev_dataloader4 = DataLoader(dev_ds4, batch_size=BATCH_SIZE, shuffle=False)
dev_dataloader5 = DataLoader(dev_ds5, batch_size=BATCH_SIZE, shuffle=False)
names_with_dev_dataloaders = [
    ("The_Dream_Room_1920_AST_978-5-17-071518-3", dev_dataloader1),
    ("Station_at_the_Horizon_1928_AST_978-5-17-133322-5", dev_dataloader2),
    ("All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1", dev_dataloader3),
    ("All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0", dev_dataloader4),
    ("The_Road_Back_1931", dev_dataloader5)
    ]

In [16]:
class AttentionHead(nn.Module):
    def __init__(self, embedding_dim, head_dim, dropout, max_len):
        super().__init__()

        self.Q = nn.Linear(embedding_dim, head_dim, bias=False)
        self.K = nn.Linear(embedding_dim, head_dim, bias=False)
        self.V = nn.Linear(embedding_dim, head_dim, bias=False)

        self.kv_softmax = nn.Softmax(dim=-1)
        self.attn_dropout = nn.Dropout(dropout)

        self.register_buffer("tril", torch.tril(torch.ones(max_len, max_len)))

    def forward(self, norm_inputs):

        _, T, _ = norm_inputs.shape

        q = self.Q(norm_inputs)
        k = self.K(norm_inputs)
        v = self.V(norm_inputs)

        attention_weights = (q @ k.transpose(-1, -2)) / math.sqrt(
            q.shape[-1]
        )  # shape: (B, T, T)
        attention_weights_masked = attention_weights.masked_fill(
            self.tril[:T, :T] == 0, -torch.inf
        )
        attention_scores = self.kv_softmax(attention_weights_masked)
        attention_scores = self.attn_dropout(attention_scores)

        return attention_scores @ v


class GPTBlock(nn.Module):
    def __init__(
        self, embedding_dim: int, max_len: int, dropout: float = 0.1, n_heads=2
    ):
        super().__init__()
        head_dim = embedding_dim // n_heads

        # Attention heads
        self.heads = nn.ModuleList(
            [
                AttentionHead(embedding_dim, head_dim, dropout, max_len)
                for _ in range(n_heads)
            ]
        )

        # Feedforward
        self.f1 = nn.Linear(embedding_dim, embedding_dim * 4)
        self.f_act = nn.ReLU()
        self.f2 = nn.Linear(embedding_dim * 4, embedding_dim)
        self.ff_dropout = nn.Dropout(dropout)

        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)

    def forward(self, inputs):
        norm_inputs = self.ln1(inputs)
        attn_out = torch.concat([h(norm_inputs) for h in self.heads], dim=-1)

        # Add residual connection
        x = inputs + attn_out

        norm_x = self.ln2(x)
        ff_out = self.ff_dropout(self.f2(self.f_act(self.f1(norm_x))))
        out = x + ff_out
        return out


class GPT(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        max_len: int,
        embedding_dim: int = 16,
        blocks_num: int = 4,
        n_heads: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.emb = nn.Embedding(vocab_size, embedding_dim)
        self.pos = nn.Embedding(max_len, embedding_dim)
        self.blocks = nn.Sequential(
            *[
                GPTBlock(embedding_dim, max_len, dropout, n_heads)
                for _ in range(blocks_num)
            ]
        )
        self.proj = nn.Linear(embedding_dim, vocab_size)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, inputs, labels=None):
        _, T = inputs.shape
        embs = self.emb(inputs)
        pos_embs = self.pos(torch.arange(T, device=inputs.device))  # (T,C)
        blocks_output = self.blocks(embs + pos_embs)
        logits = self.proj(blocks_output)  # (B,T,vocab_size)
        if labels is not None:
            loss = self.loss_fn(logits.view(-1, self.vocab_size), labels.view(-1))
            return logits, loss
        else:
            return logits, None


In [17]:
model = GPT(vocab_size=len(tokens), max_len=WINDOW_SIZE, embedding_dim=64, blocks_num=6, n_heads=4)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

# Scheduler with warmup and linear decay
TOTAL_STEPS = EPOCHS * len(train_dataloader)
WARMUP_STEPS = int(0.1 * TOTAL_STEPS)  # 10% warmup

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=TOTAL_STEPS)

In [18]:
wandb.watch(model, log_freq=2500, log='all')

In [19]:
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    steps_n = 0

    for batch in tqdm(train_dataloader):
        input, labels = batch[0].to(device), batch[1].to(device)
        output, loss = model(input, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # Update the learning rate
        optimizer.zero_grad()
        epoch_loss += loss.item()
        steps_n += 1
        run.log({"train_loss": loss.item(), "lr": scheduler.get_last_lr()[0]})

    avg_loss = epoch_loss / steps_n
    print(f"epoch {epoch} train loss: {avg_loss:.3f}")

    model.eval()
    with torch.no_grad():
        for ds_name, dev_dataloader in names_with_dev_dataloaders:
            val_steps_n = 0
            val_epoch_loss = 0
            for batch in tqdm(dev_dataloader):
                input, labels = batch[0].to(device), batch[1].to(device)
                output, loss = model(input, labels)
                val_epoch_loss += loss.item()
                val_steps_n += 1
            avg_val_loss = val_epoch_loss / val_steps_n
            print(f"epoch {epoch} val loss for {ds_name}: {avg_val_loss:.3f}")
            metric_name = f"avg_val_loss_{ds_name}"
            run.log({metric_name: avg_val_loss}, commit=False)

    run.log({"avg_train_loss": avg_loss})

100%|██████████| 22516/22516 [16:14<00:00, 23.11it/s]


epoch 0 train loss: 2.233


100%|██████████| 397/397 [00:05<00:00, 78.29it/s]


epoch 0 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.780


100%|██████████| 517/517 [00:06<00:00, 79.84it/s]


epoch 0 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.645


100%|██████████| 485/485 [00:06<00:00, 79.91it/s]


epoch 0 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.793


100%|██████████| 556/556 [00:06<00:00, 79.69it/s]


epoch 0 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.702


100%|██████████| 715/715 [00:09<00:00, 79.04it/s]


epoch 0 val loss for The_Road_Back_1931: 1.734


100%|██████████| 22516/22516 [16:14<00:00, 23.10it/s]


epoch 1 train loss: 1.691


100%|██████████| 397/397 [00:05<00:00, 78.64it/s]


epoch 1 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.648


100%|██████████| 517/517 [00:06<00:00, 78.67it/s]


epoch 1 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.509


100%|██████████| 485/485 [00:06<00:00, 78.90it/s]


epoch 1 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.670


100%|██████████| 556/556 [00:07<00:00, 78.56it/s]


epoch 1 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.581


100%|██████████| 715/715 [00:09<00:00, 78.12it/s]


epoch 1 val loss for The_Road_Back_1931: 1.608


100%|██████████| 22516/22516 [16:16<00:00, 23.06it/s]


epoch 2 train loss: 1.604


100%|██████████| 397/397 [00:05<00:00, 78.59it/s]


epoch 2 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.607


100%|██████████| 517/517 [00:06<00:00, 78.74it/s]


epoch 2 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.455


100%|██████████| 485/485 [00:06<00:00, 79.40it/s]


epoch 2 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.629


100%|██████████| 556/556 [00:06<00:00, 79.48it/s]


epoch 2 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.540


100%|██████████| 715/715 [00:09<00:00, 78.24it/s]


epoch 2 val loss for The_Road_Back_1931: 1.565


  5%|▍         | 1049/22516 [00:45<15:32, 23.03it/s]


KeyboardInterrupt: 

In [20]:
torch.save(model.state_dict(), "model/gpt.pt")

In [21]:
artifact = wandb.Artifact("model", type="model")
artifact.add_file("model/gpt.pt")
run.log_artifact(artifact)

<Artifact model>

In [22]:
wandb.finish()

0,1
avg_train_loss,█▂▁
avg_val_loss_All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1,█▃▁
avg_val_loss_All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0,█▃▁
avg_val_loss_Station_at_the_Horizon_1928_AST_978-5-17-133322-5,█▃▁
avg_val_loss_The_Dream_Room_1920_AST_978-5-17-071518-3,█▃▁
avg_val_loss_The_Road_Back_1931,█▃▁
lr,▁▃▃▃▄▄▅▆▆▇██████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆
train_loss,██▅▄▄▃▃▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
avg_train_loss,1.60403
avg_val_loss_All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1,1.62939
avg_val_loss_All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0,1.53973
avg_val_loss_Station_at_the_Horizon_1928_AST_978-5-17-133322-5,1.45507
avg_val_loss_The_Dream_Room_1920_AST_978-5-17-071518-3,1.60669
avg_val_loss_The_Road_Back_1931,1.56452
lr,0.00039
train_loss,1.55082


In [23]:
model.eval()
prompt = "– Привет, любовь моя!\n"
generated_text = generate_text(
    model,
    token_to_id,
    id_to_token,
    prompt,
    device,
    window_size=WINDOW_SIZE,
    max_tokens=200,
    temperature=0.5
)
print(generated_text)

– Привет, любовь моя!

– Вот мы стали свои не видим, – он стал он нам так не хотел по старой такой желание все в следующее. Так как и подредлагает за ним остальные самые произрачные гранаты. Они умеренно подозрели и все ст
