<a href="https://colab.research.google.com/github/pashok3d/RemarqueGPT/blob/main/RemarqueGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
Building GPT from scratch and training it on all books of Erich Maria Remarque

Available tools: python, pytorch

Tasks:
1. Load data and tokenize to characters
2. Implement GPT model using pytorch
3. Train and evaluate the model

GPT model structure:
1. embedding layer
2. positional encoding
3. blocks
    .1 attention
    .2 feedforward
4. projection
"""

In [2]:
!pip install tqdm -q
!pip install wandb -q

In [3]:
!mkdir dataset
!mkdir model

In [4]:
import wandb
import torch
import math
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import ConcatDataset, Dataset
from torch import nn
import math
from transformers import get_linear_schedule_with_warmup

In [None]:
wandb.login()

In [7]:
WINDOW_SIZE = 64
BATCH_SIZE = 64
EPOCHS = 10
LR = 5e-4

device = "cuda" if torch.cuda.is_available() else "cpu"

config = {
    "learning_rate": LR,
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "window_size": WINDOW_SIZE,
}

In [None]:
run = wandb.init(project="remark-gpt", config=config)

In [9]:
def tokenize(text, token_to_id) -> list[int]:
    return [token_to_id[ch] for ch in text]


def decode(token_ids: list[int], id_to_token) -> str:
    return "".join([id_to_token[token_id] for token_id in token_ids])


def generate_text(
    model,
    token_to_id,
    id_to_token,
    prompt: str,
    device: str,
    window_size: int,
    max_tokens: int = 10,
    temperature: float = 1.0,
) -> str:
    """Generate text using the trained GPT model."""
    model.eval()
    context = tokenize(prompt, token_to_id)
    generated = list(context)

    with torch.no_grad():
        for _ in range(max_tokens):
            x = torch.tensor(context[-window_size:]).unsqueeze(0).to(device)
            logits, _ = model(x)
            logits = logits[0, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_token)
            context = generated

    return decode(context, id_to_token)


class TextDataset(Dataset):
    def __init__(self, path, context_window_size, token_to_id):
        # Load dataset
        with open(path, "r") as f:
            lines = f.readlines()

        text = "\n".join(lines)

        self.tokens = tokenize(text, token_to_id)

        self.x = []
        self.y = []
        for i in range(len(self.tokens) - context_window_size):
            self.x.append(self.tokens[i : i + context_window_size])
            self.y.append(self.tokens[i + 1 : i + context_window_size + 1])

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return torch.tensor(self.x[idx]), torch.tensor(self.y[idx])


In [10]:
# Prepare tokenizer
dataset_lines = []
with open("dataset/The_Dream_Room_1920_AST_978-5-17-071518-3.txt", "r") as f:
    dataset_lines.extend(f.readlines())
with open("dataset/Station_at_the_Horizon_1928_AST_978-5-17-133322-5.txt", "r") as f:
    dataset_lines.extend(f.readlines())
with open("dataset/All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1.txt", "r") as f:
    dataset_lines.extend(f.readlines())
with open("dataset/All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0.txt", "r") as f:
    dataset_lines.extend(f.readlines())
with open("dataset/The_Road_Back_1931.txt", "r") as f:
    dataset_lines.extend(f.readlines())
text = "\n".join(dataset_lines)
tokens = sorted(set(text))
id_to_token = {i: token for i, token in enumerate(tokens)}
token_to_id = {token: i for i, token in enumerate(tokens)}

train_ds1 = TextDataset(
    "dataset/The_Dream_Room_1920_AST_978-5-17-071518-3-train.txt",
    WINDOW_SIZE,
    token_to_id,
)
train_ds2 = TextDataset(
    "dataset/Station_at_the_Horizon_1928_AST_978-5-17-133322-5.txt",
    WINDOW_SIZE,
    token_to_id,
)
train_ds3 = TextDataset(
    "dataset/All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1-train.txt",
    WINDOW_SIZE,
    token_to_id,
)
train_ds4 = TextDataset(
    "dataset/All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0-train.txt",
    WINDOW_SIZE,
    token_to_id,
)
train_ds5 = TextDataset(
    "dataset/The_Road_Back_1931-train.txt",
    WINDOW_SIZE,
    token_to_id,
)
train_ds = ConcatDataset([train_ds1, train_ds2, train_ds3, train_ds4, train_ds5])
train_dataloader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

dev_ds1 = TextDataset(
    "dataset/The_Dream_Room_1920_AST_978-5-17-071518-3-dev.txt",
    WINDOW_SIZE,
    token_to_id,
)
dev_ds2 = TextDataset(
    "dataset/Station_at_the_Horizon_1928_AST_978-5-17-133322-5-dev.txt",
    WINDOW_SIZE,
    token_to_id,
)
dev_ds3 = TextDataset(
    "dataset/All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1-dev.txt",
    WINDOW_SIZE,
    token_to_id,
)
dev_ds4 = TextDataset(
    "dataset/All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0-dev.txt",
    WINDOW_SIZE,
    token_to_id,
)
dev_ds5 = TextDataset(
    "dataset/The_Road_Back_1931-dev.txt",
    WINDOW_SIZE,
    token_to_id,
)

dev_dataloader1 = DataLoader(dev_ds1, batch_size=BATCH_SIZE, shuffle=False)
dev_dataloader2 = DataLoader(dev_ds2, batch_size=BATCH_SIZE, shuffle=False)
dev_dataloader3 = DataLoader(dev_ds3, batch_size=BATCH_SIZE, shuffle=False)
dev_dataloader4 = DataLoader(dev_ds4, batch_size=BATCH_SIZE, shuffle=False)
dev_dataloader5 = DataLoader(dev_ds5, batch_size=BATCH_SIZE, shuffle=False)
names_with_dev_dataloaders = [
    ("The_Dream_Room_1920_AST_978-5-17-071518-3", dev_dataloader1),
    ("Station_at_the_Horizon_1928_AST_978-5-17-133322-5", dev_dataloader2),
    ("All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1", dev_dataloader3),
    ("All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0", dev_dataloader4),
    ("The_Road_Back_1931", dev_dataloader5)
    ]

In [11]:
class AttentionHead(nn.Module):
    def __init__(self, embedding_dim, head_dim, dropout, max_len):
        super().__init__()

        self.Q = nn.Linear(embedding_dim, head_dim, bias=False)
        self.K = nn.Linear(embedding_dim, head_dim, bias=False)
        self.V = nn.Linear(embedding_dim, head_dim, bias=False)

        self.kv_softmax = nn.Softmax(dim=-1)
        self.attn_dropout = nn.Dropout(dropout)

        self.register_buffer("tril", torch.tril(torch.ones(max_len, max_len)))

    def forward(self, norm_inputs):

        _, T, _ = norm_inputs.shape

        q = self.Q(norm_inputs)
        k = self.K(norm_inputs)
        v = self.V(norm_inputs)

        attention_weights = (q @ k.transpose(-1, -2)) / math.sqrt(
            q.shape[-1]
        )  # shape: (B, T, T)
        attention_weights_masked = attention_weights.masked_fill(
            self.tril[:T, :T] == 0, -torch.inf
        )
        attention_scores = self.kv_softmax(attention_weights_masked)
        attention_scores = self.attn_dropout(attention_scores)

        return attention_scores @ v


class GPTBlock(nn.Module):
    def __init__(
        self, embedding_dim: int, max_len: int, dropout: float = 0.1, n_heads=2
    ):
        super().__init__()
        head_dim = embedding_dim // n_heads

        # Attention heads
        self.heads = nn.ModuleList(
            [
                AttentionHead(embedding_dim, head_dim, dropout, max_len)
                for _ in range(n_heads)
            ]
        )

        # Feedforward
        self.f1 = nn.Linear(embedding_dim, embedding_dim * 4)
        self.f_act = nn.ReLU()
        self.f2 = nn.Linear(embedding_dim * 4, embedding_dim)
        self.ff_dropout = nn.Dropout(dropout)

        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)

    def forward(self, inputs):
        norm_inputs = self.ln1(inputs)
        attn_out = torch.concat([h(norm_inputs) for h in self.heads], dim=-1)

        # Add residual connection
        x = inputs + attn_out

        norm_x = self.ln2(x)
        ff_out = self.ff_dropout(self.f2(self.f_act(self.f1(norm_x))))
        out = x + ff_out
        return out


class GPT(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        max_len: int,
        embedding_dim: int = 16,
        blocks_num: int = 4,
        n_heads: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.emb = nn.Embedding(vocab_size, embedding_dim)
        self.pos = nn.Embedding(max_len, embedding_dim)
        self.blocks = nn.Sequential(
            *[
                GPTBlock(embedding_dim, max_len, dropout, n_heads)
                for _ in range(blocks_num)
            ]
        )
        self.proj = nn.Linear(embedding_dim, vocab_size)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, inputs, labels=None):
        _, T = inputs.shape
        embs = self.emb(inputs)
        pos_embs = self.pos(torch.arange(T, device=inputs.device))  # (T,C)
        blocks_output = self.blocks(embs + pos_embs)
        logits = self.proj(blocks_output)  # (B,T,vocab_size)
        if labels is not None:
            loss = self.loss_fn(logits.view(-1, self.vocab_size), labels.view(-1))
            return logits, loss
        else:
            return logits, None


In [12]:
model = GPT(vocab_size=len(tokens), max_len=WINDOW_SIZE, embedding_dim=64, blocks_num=6, n_heads=4)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

# Scheduler with warmup and linear decay
TOTAL_STEPS = EPOCHS * len(train_dataloader)
WARMUP_STEPS = int(0.1 * TOTAL_STEPS)  # 10% warmup

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=TOTAL_STEPS)

In [13]:
wandb.watch(model, log_freq=2500, log='all')

In [None]:
model.train()
epoch_loss = 0
steps_n = 0
with torch.no_grad():
    for batch in tqdm(train_dataloader):
        input, labels = batch[0].to(device), batch[1].to(device)
        output, loss = model(input, labels)
        epoch_loss += loss.item()
        steps_n += 1
    avg_loss = epoch_loss / steps_n
expected_init_loss = -math.log(1 / len(tokens))
print(f"initial train loss: {avg_loss:.3f}, with expected of {expected_init_loss:.3f}")

In [14]:
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    steps_n = 0

    for batch in tqdm(train_dataloader):
        input, labels = batch[0].to(device), batch[1].to(device)
        output, loss = model(input, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # Update the learning rate
        optimizer.zero_grad()
        epoch_loss += loss.item()
        steps_n += 1
        run.log({"train_loss": loss.item(), "lr": scheduler.get_last_lr()[0]})

    avg_loss = epoch_loss / steps_n
    print(f"epoch {epoch} train loss: {avg_loss:.3f}")

    model.eval()
    with torch.no_grad():
        for ds_name, dev_dataloader in names_with_dev_dataloaders:
            val_steps_n = 0
            val_epoch_loss = 0
            for batch in tqdm(dev_dataloader):
                input, labels = batch[0].to(device), batch[1].to(device)
                output, loss = model(input, labels)
                val_epoch_loss += loss.item()
                val_steps_n += 1
            avg_val_loss = val_epoch_loss / val_steps_n
            print(f"epoch {epoch} val loss for {ds_name}: {avg_val_loss:.3f}")
            metric_name = f"avg_val_loss_{ds_name}"
            run.log({metric_name: avg_val_loss}, commit=False)

    run.log({"avg_train_loss": avg_loss})

100%|██████████| 22516/22516 [06:30<00:00, 57.73it/s]


epoch 0 train loss: 2.194


100%|██████████| 397/397 [00:02<00:00, 163.70it/s]


epoch 0 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.760


100%|██████████| 517/517 [00:03<00:00, 163.38it/s]


epoch 0 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.625


100%|██████████| 485/485 [00:02<00:00, 162.76it/s]


epoch 0 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.771


100%|██████████| 556/556 [00:03<00:00, 163.88it/s]


epoch 0 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.683


100%|██████████| 715/715 [00:04<00:00, 163.44it/s]


epoch 0 val loss for The_Road_Back_1931: 1.709


100%|██████████| 22516/22516 [06:30<00:00, 57.73it/s]


epoch 1 train loss: 1.678


100%|██████████| 397/397 [00:02<00:00, 164.17it/s]


epoch 1 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.641


100%|██████████| 517/517 [00:03<00:00, 163.63it/s]


epoch 1 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.494


100%|██████████| 485/485 [00:02<00:00, 162.69it/s]


epoch 1 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.656


100%|██████████| 556/556 [00:03<00:00, 161.08it/s]


epoch 1 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.565


100%|██████████| 715/715 [00:04<00:00, 160.31it/s]


epoch 1 val loss for The_Road_Back_1931: 1.594


100%|██████████| 22516/22516 [06:28<00:00, 57.91it/s]


epoch 2 train loss: 1.603


100%|██████████| 397/397 [00:02<00:00, 159.41it/s]


epoch 2 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.608


100%|██████████| 517/517 [00:03<00:00, 163.60it/s]


epoch 2 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.449


100%|██████████| 485/485 [00:02<00:00, 163.70it/s]


epoch 2 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.628


100%|██████████| 556/556 [00:03<00:00, 164.10it/s]


epoch 2 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.532


100%|██████████| 715/715 [00:04<00:00, 162.87it/s]


epoch 2 val loss for The_Road_Back_1931: 1.563


100%|██████████| 22516/22516 [06:28<00:00, 57.89it/s]


epoch 3 train loss: 1.571


100%|██████████| 397/397 [00:02<00:00, 157.17it/s]


epoch 3 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.583


100%|██████████| 517/517 [00:03<00:00, 164.07it/s]


epoch 3 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.428


100%|██████████| 485/485 [00:02<00:00, 165.25it/s]


epoch 3 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.604


100%|██████████| 556/556 [00:03<00:00, 162.46it/s]


epoch 3 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.511


100%|██████████| 715/715 [00:04<00:00, 164.09it/s]


epoch 3 val loss for The_Road_Back_1931: 1.543


100%|██████████| 22516/22516 [06:29<00:00, 57.75it/s]


epoch 4 train loss: 1.550


100%|██████████| 397/397 [00:02<00:00, 160.40it/s]


epoch 4 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.573


100%|██████████| 517/517 [00:03<00:00, 161.13it/s]


epoch 4 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.407


100%|██████████| 485/485 [00:02<00:00, 164.13it/s]


epoch 4 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.594


100%|██████████| 556/556 [00:03<00:00, 160.39it/s]


epoch 4 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.498


100%|██████████| 715/715 [00:04<00:00, 160.85it/s]


epoch 4 val loss for The_Road_Back_1931: 1.531


100%|██████████| 22516/22516 [06:32<00:00, 57.38it/s]


epoch 5 train loss: 1.534


100%|██████████| 397/397 [00:02<00:00, 162.77it/s]


epoch 5 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.563


100%|██████████| 517/517 [00:03<00:00, 162.27it/s]


epoch 5 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.396


100%|██████████| 485/485 [00:02<00:00, 162.65it/s]


epoch 5 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.581


100%|██████████| 556/556 [00:03<00:00, 160.70it/s]


epoch 5 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.487


100%|██████████| 715/715 [00:04<00:00, 162.04it/s]


epoch 5 val loss for The_Road_Back_1931: 1.521


100%|██████████| 22516/22516 [06:30<00:00, 57.62it/s]


epoch 6 train loss: 1.522


100%|██████████| 397/397 [00:02<00:00, 162.34it/s]


epoch 6 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.554


100%|██████████| 517/517 [00:03<00:00, 163.34it/s]


epoch 6 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.385


100%|██████████| 485/485 [00:03<00:00, 156.96it/s]


epoch 6 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.576


100%|██████████| 556/556 [00:03<00:00, 163.73it/s]


epoch 6 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.480


100%|██████████| 715/715 [00:04<00:00, 161.80it/s]


epoch 6 val loss for The_Road_Back_1931: 1.512


100%|██████████| 22516/22516 [06:32<00:00, 57.40it/s]


epoch 7 train loss: 1.512


100%|██████████| 397/397 [00:02<00:00, 160.58it/s]


epoch 7 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.549


100%|██████████| 517/517 [00:03<00:00, 162.20it/s]


epoch 7 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.380


100%|██████████| 485/485 [00:03<00:00, 161.02it/s]


epoch 7 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.569


100%|██████████| 556/556 [00:03<00:00, 163.54it/s]


epoch 7 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.474


100%|██████████| 715/715 [00:04<00:00, 163.58it/s]


epoch 7 val loss for The_Road_Back_1931: 1.507


100%|██████████| 22516/22516 [06:30<00:00, 57.67it/s]


epoch 8 train loss: 1.503


100%|██████████| 397/397 [00:02<00:00, 164.51it/s]


epoch 8 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.540


100%|██████████| 517/517 [00:03<00:00, 162.67it/s]


epoch 8 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.368


100%|██████████| 485/485 [00:03<00:00, 158.27it/s]


epoch 8 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.563


100%|██████████| 556/556 [00:03<00:00, 163.20it/s]


epoch 8 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.469


100%|██████████| 715/715 [00:04<00:00, 163.71it/s]


epoch 8 val loss for The_Road_Back_1931: 1.501


100%|██████████| 22516/22516 [06:31<00:00, 57.56it/s]


epoch 9 train loss: 1.495


100%|██████████| 397/397 [00:02<00:00, 162.24it/s]


epoch 9 val loss for The_Dream_Room_1920_AST_978-5-17-071518-3: 1.536


100%|██████████| 517/517 [00:03<00:00, 161.27it/s]


epoch 9 val loss for Station_at_the_Horizon_1928_AST_978-5-17-133322-5: 1.363


100%|██████████| 485/485 [00:02<00:00, 163.77it/s]


epoch 9 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1: 1.559


100%|██████████| 556/556 [00:03<00:00, 162.57it/s]


epoch 9 val loss for All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0: 1.464


100%|██████████| 715/715 [00:04<00:00, 159.91it/s]

epoch 9 val loss for The_Road_Back_1931: 1.497





In [15]:
torch.save(model.state_dict(), "model/gpt.pt")

In [None]:
artifact = wandb.Artifact("model", type="model")
artifact.add_file("model/gpt.pt")
run.log_artifact(artifact)

In [34]:
wandb.finish()

0,1
avg_train_loss,█▃▂▂▂▁▁▁▁▁
avg_val_loss_All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1,█▄▃▂▂▂▂▁▁▁
avg_val_loss_All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0,█▄▃▃▂▂▂▁▁▁
avg_val_loss_Station_at_the_Horizon_1928_AST_978-5-17-133322-5,█▄▃▃▂▂▂▁▁▁
avg_val_loss_The_Dream_Room_1920_AST_978-5-17-071518-3,█▄▃▂▂▂▂▁▁▁
avg_val_loss_The_Road_Back_1931,█▄▃▃▂▂▁▁▁▁
lr,▂▃█████▇▇▇▇▇▇▇▇▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▁
train_loss,█▆▄▄▃▃▃▂▂▂▂▂▂▃▂▂▂▂▂▂▁▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
avg_train_loss,1.49547
avg_val_loss_All_Quiet_on_the_Western_Front_1929_AST_978-5-17-105639-1,1.55941
avg_val_loss_All_Quiet_on_the_Western_Front_1929_AST_978-5-17-137374-0,1.46373
avg_val_loss_Station_at_the_Horizon_1928_AST_978-5-17-133322-5,1.363
avg_val_loss_The_Dream_Room_1920_AST_978-5-17-071518-3,1.53615
avg_val_loss_The_Road_Back_1931,1.49668
lr,0.0
train_loss,1.50139


In [37]:
model.eval()
prompt = "– Привет, любовь моя!\n"
generated_text = generate_text(
    model,
    token_to_id,
    id_to_token,
    prompt,
    device,
    window_size=WINDOW_SIZE,
    max_tokens=200,
    temperature=0.5
)
print(generated_text)

– Привет, любовь моя!

– А теперь ведь в то, что мы в причине появляемся. Потом следует нам в каждом поверхное возможно в портрет он на стальнике, завтра она не спрашивают в полусне с конца. Я поднимаю в напротивника, — го
