In [None]:
!pip install -r requirements.txt
!pip install numpy==1.26.4 scikit-learn==1.3.2 --force-reinstall --no-cache-dir


In [2]:
import os
import json
import random
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from peft import PeftModel, PeftConfig
from datasets import Dataset
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from huggingface_hub import login

  warn(


In [None]:
#==== CONFIGURATION ====
login(token="hf_...")
BGE_MODEL = "BAAI/bge-base-en-v1.5"
GEMMA_MODEL = "google/gemma-3-4b-pt"
CHECKPOINT_DIR = "./persistent_volume/last_checkpoint/"
BOOKS_PATH = "books.jsonl"
VAL_SPLIT = 0.005
BATCH_SIZE = 4
MAX_LENGTH = 256
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

# ==== LOAD MODELS ====
print("Loading tokenizer and base model...")
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    GEMMA_MODEL,
    torch_dtype=DTYPE,
    device_map="auto"
)
model = PeftModel.from_pretrained(base_model, CHECKPOINT_DIR)
model.eval()
for p in model.parameters():
    p.requires_grad = False

print("Loading BGE encoder...")
bge_tokenizer = AutoTokenizer.from_pretrained(BGE_MODEL)
bge_encoder = AutoModel.from_pretrained(BGE_MODEL).to(DEVICE)
bge_encoder.eval()
for p in bge_encoder.parameters():
    p.requires_grad = False

# ==== PROJECTOR ====
class Projector(nn.Module):
    def __init__(self, in_dim=768, out_dim=2560):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim)
        )

    def forward(self, x):
        return self.mlp(x)

projector = Projector().to(DEVICE)



In [4]:
# ==== DATA ====
print("Loading data...")
with open(BOOKS_PATH, "r", encoding="utf-8") as f:
    texts = [json.loads(line)["text"] for line in f if "text" in json.loads(line)]

random.shuffle(texts)
split_idx = int(len(texts) * (1 - VAL_SPLIT))
train_texts = texts[:split_idx]
val_texts = texts[split_idx:]

def chunk_text(text, max_tokens=MAX_LENGTH):
    tokens = tokenizer(text, truncation=False)["input_ids"]
    chunks = [tokens[i:i + max_tokens] for i in range(0, len(tokens), max_tokens)]
    return [tokenizer.decode(chunk) for chunk in chunks if len(chunk) > 10]

def encode_bge(texts):
    inputs = bge_tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
    with torch.no_grad():
        emb = bge_encoder(**inputs).last_hidden_state[:, 0, :]  # CLS token
    return emb

class ProjectorDataset(torch.utils.data.Dataset):
    def __init__(self, texts):
        self.samples = []
        for text in texts:
            for chunk in chunk_text(text):
                self.samples.append(chunk)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def collate_fn(batch):
    inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH)
    input_ids = inputs["input_ids"][:, :-1].to(DEVICE)
    labels = inputs["input_ids"][:, 1:].to(DEVICE)
    attn_mask = inputs["attention_mask"][:, :-1].to(DEVICE)
    with torch.no_grad():
        bge_emb = encode_bge(batch)
    return bge_emb, input_ids, attn_mask, labels

train_dataset = ProjectorDataset(train_texts)
val_dataset = ProjectorDataset(val_texts)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
print("Data is ready")

Loading data...
Data is ready


In [None]:

# ==== OPTIMIZER & LOSS ====
optimizer = torch.optim.AdamW(projector.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# ==== TRAIN LOOP ====
def evaluate():
    model.eval()
    projector.eval()
    total_loss, total_tokens = 0, 0
    with torch.no_grad():
        for bge_emb, input_ids, attn_mask, labels in val_loader:
            proj = projector(bge_emb)
            expanded_proj = proj.unsqueeze(1).expand(-1, input_ids.size(1), -1)
            inputs_embeds = model.base_model.model.language_model.embed_tokens(input_ids)
            full_inputs = inputs_embeds + expanded_proj
            full_inputs = full_inputs.to(dtype=DTYPE)
            output = model(inputs_embeds=full_inputs, attention_mask=attn_mask)
            logits = output.logits
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
            total_loss += loss.item() * labels.numel()
            total_tokens += labels.numel()
    ppl = torch.exp(torch.tensor(total_loss / total_tokens))
    return total_loss / total_tokens, ppl.item()

print("Starting training...")
for epoch in range(2):
    projector.train()
    for step, (bge_emb, input_ids, attn_mask, labels) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        proj = projector(bge_emb)
        expanded_proj = proj.unsqueeze(1).expand(-1, input_ids.size(1), -1)
        inputs_embeds = model.base_model.model.language_model.embed_tokens(input_ids)
        full_inputs = inputs_embeds + expanded_proj
        full_inputs = full_inputs.to(dtype=DTYPE)
        output = model(inputs_embeds=full_inputs, attention_mask=attn_mask)
        logits = output.logits
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        loss.backward()
        optimizer.step()
        if step % 25 == 0:
            val_loss, val_ppl = evaluate()
            print(f"[Epoch {epoch} | Step {step}] Train loss: {loss:.4f} | Val loss: {val_loss:.4f} | PPL: {val_ppl:.2f}")
            torch.save(projector.state_dict(), f"persistent_volume/proj_checkpoints/projector_step{step:05d}.pt")


# ==== SAVE ====  
torch.save(projector.state_dict(), "projector.pt")
print("Saved projector to projector.pt")


Starting training...


  0%|          | 1/80939 [01:37<2182:18:05, 97.07s/it]

[Epoch 0 | Step 0] Train loss: 3.9062 | Val loss: 3.1266 | PPL: 22.80


  0%|          | 26/80939 [03:25<660:26:38, 29.38s/it]

[Epoch 0 | Step 25] Train loss: 3.7812 | Val loss: 3.0252 | PPL: 20.60


  0%|          | 51/80939 [05:13<659:25:23, 29.35s/it]

[Epoch 0 | Step 50] Train loss: 4.0312 | Val loss: 2.9772 | PPL: 19.63


  0%|          | 76/80939 [07:02<658:41:38, 29.32s/it]

[Epoch 0 | Step 75] Train loss: 3.5156 | Val loss: 2.9307 | PPL: 18.74


  0%|          | 101/80939 [08:50<658:26:45, 29.32s/it]

[Epoch 0 | Step 100] Train loss: 3.3594 | Val loss: 2.9141 | PPL: 18.43


  0%|          | 126/80939 [10:38<658:11:02, 29.32s/it]

[Epoch 0 | Step 125] Train loss: 3.3125 | Val loss: 2.9052 | PPL: 18.27


  0%|          | 151/80939 [12:27<658:02:37, 29.32s/it]

[Epoch 0 | Step 150] Train loss: 2.9062 | Val loss: 2.8927 | PPL: 18.04


  0%|          | 176/80939 [14:15<658:39:28, 29.36s/it]

[Epoch 0 | Step 175] Train loss: 3.4219 | Val loss: 2.8694 | PPL: 17.63


  0%|          | 201/80939 [16:03<657:43:03, 29.33s/it]

[Epoch 0 | Step 200] Train loss: 3.0469 | Val loss: 2.8717 | PPL: 17.67


  0%|          | 226/80939 [17:51<657:23:41, 29.32s/it]

[Epoch 0 | Step 225] Train loss: 2.6250 | Val loss: 2.8628 | PPL: 17.51


  0%|          | 251/80939 [19:40<657:20:14, 29.33s/it]

[Epoch 0 | Step 250] Train loss: 3.7500 | Val loss: 2.8618 | PPL: 17.49


  0%|          | 276/80939 [21:28<656:57:39, 29.32s/it]

[Epoch 0 | Step 275] Train loss: 3.8438 | Val loss: 2.8574 | PPL: 17.42


  0%|          | 301/80939 [23:16<656:54:29, 29.33s/it]

[Epoch 0 | Step 300] Train loss: 3.5312 | Val loss: 2.8644 | PPL: 17.54


  0%|          | 326/80939 [25:05<656:15:31, 29.31s/it]

[Epoch 0 | Step 325] Train loss: 3.7656 | Val loss: 2.8529 | PPL: 17.34


  0%|          | 351/80939 [26:53<658:47:32, 29.43s/it]

[Epoch 0 | Step 350] Train loss: 3.1719 | Val loss: 2.8528 | PPL: 17.34


  0%|          | 376/80939 [28:41<656:17:37, 29.33s/it]

[Epoch 0 | Step 375] Train loss: 3.9219 | Val loss: 2.8316 | PPL: 16.97


  0%|          | 401/80939 [30:30<656:02:43, 29.32s/it]

[Epoch 0 | Step 400] Train loss: 3.2812 | Val loss: 2.8424 | PPL: 17.16


  1%|          | 426/80939 [32:18<655:52:32, 29.33s/it]

[Epoch 0 | Step 425] Train loss: 3.4531 | Val loss: 2.8320 | PPL: 16.98


  1%|          | 451/80939 [34:06<655:42:17, 29.33s/it]

[Epoch 0 | Step 450] Train loss: 3.5938 | Val loss: 2.8253 | PPL: 16.87


  1%|          | 476/80939 [35:55<655:55:30, 29.35s/it]

[Epoch 0 | Step 475] Train loss: 3.4062 | Val loss: 2.8267 | PPL: 16.89


  1%|          | 501/80939 [37:43<653:48:06, 29.26s/it]

[Epoch 0 | Step 500] Train loss: 3.2656 | Val loss: 2.8230 | PPL: 16.83


  1%|          | 526/80939 [39:31<652:17:08, 29.20s/it]

[Epoch 0 | Step 525] Train loss: 3.3750 | Val loss: 2.8192 | PPL: 16.76


  1%|          | 551/80939 [41:19<654:33:30, 29.31s/it]

[Epoch 0 | Step 550] Train loss: 3.3438 | Val loss: 2.8170 | PPL: 16.73


  1%|          | 576/80939 [43:07<652:10:48, 29.22s/it]

[Epoch 0 | Step 575] Train loss: 3.7344 | Val loss: 2.8149 | PPL: 16.69


  1%|          | 601/80939 [44:55<653:33:00, 29.29s/it]

[Epoch 0 | Step 600] Train loss: 3.6406 | Val loss: 2.8190 | PPL: 16.76


  1%|          | 626/80939 [46:43<654:34:58, 29.34s/it]

[Epoch 0 | Step 625] Train loss: 3.2969 | Val loss: 2.8191 | PPL: 16.76


  1%|          | 651/80939 [48:33<667:00:38, 29.91s/it]

[Epoch 0 | Step 650] Train loss: 3.5781 | Val loss: 2.8121 | PPL: 16.64


  1%|          | 676/80939 [50:27<689:22:44, 30.92s/it]

[Epoch 0 | Step 675] Train loss: 4.4062 | Val loss: 2.8104 | PPL: 16.62


  1%|          | 701/80939 [52:22<689:08:48, 30.92s/it]

[Epoch 0 | Step 700] Train loss: 3.5625 | Val loss: 2.8149 | PPL: 16.69


  1%|          | 726/80939 [54:16<689:02:33, 30.92s/it]

[Epoch 0 | Step 725] Train loss: 3.4062 | Val loss: 2.8093 | PPL: 16.60


  1%|          | 751/80939 [56:10<688:28:37, 30.91s/it]

[Epoch 0 | Step 750] Train loss: 3.6719 | Val loss: 2.8158 | PPL: 16.71


  1%|          | 776/80939 [58:04<688:15:35, 30.91s/it]

[Epoch 0 | Step 775] Train loss: 3.6406 | Val loss: 2.8144 | PPL: 16.68


  1%|          | 801/80939 [59:56<674:57:51, 30.32s/it]

[Epoch 0 | Step 800] Train loss: 3.0781 | Val loss: 2.8048 | PPL: 16.52


  1%|          | 826/80939 [1:01:46<663:27:41, 29.81s/it]

[Epoch 0 | Step 825] Train loss: 3.8750 | Val loss: 2.8041 | PPL: 16.51


  1%|          | 851/80939 [1:03:36<662:48:00, 29.79s/it]

[Epoch 0 | Step 850] Train loss: 3.3438 | Val loss: 2.8043 | PPL: 16.52


  1%|          | 876/80939 [1:05:26<662:40:46, 29.80s/it]

[Epoch 0 | Step 875] Train loss: 3.7031 | Val loss: 2.8090 | PPL: 16.59


  1%|          | 901/80939 [1:07:19<684:26:03, 30.78s/it]

[Epoch 0 | Step 900] Train loss: 3.2969 | Val loss: 2.8060 | PPL: 16.54


  1%|          | 926/80939 [1:09:13<686:25:27, 30.88s/it]

[Epoch 0 | Step 925] Train loss: 3.7812 | Val loss: 2.8082 | PPL: 16.58


  1%|          | 951/80939 [1:11:07<686:37:53, 30.90s/it]

[Epoch 0 | Step 950] Train loss: 4.0312 | Val loss: 2.7997 | PPL: 16.44


  1%|          | 976/80939 [1:13:01<686:01:15, 30.89s/it]

[Epoch 0 | Step 975] Train loss: 3.4688 | Val loss: 2.7948 | PPL: 16.36


  1%|          | 1001/80939 [1:14:55<685:56:37, 30.89s/it]

[Epoch 0 | Step 1000] Train loss: 3.2969 | Val loss: 2.7978 | PPL: 16.41


  1%|▏         | 1026/80939 [1:16:49<684:47:10, 30.85s/it]

[Epoch 0 | Step 1025] Train loss: 3.5312 | Val loss: 2.7977 | PPL: 16.41


  1%|▏         | 1051/80939 [1:18:43<684:56:23, 30.87s/it]

[Epoch 0 | Step 1050] Train loss: 3.4062 | Val loss: 2.8054 | PPL: 16.53


  1%|▏         | 1076/80939 [1:20:37<683:15:41, 30.80s/it]

[Epoch 0 | Step 1075] Train loss: 3.2500 | Val loss: 2.7981 | PPL: 16.41


  1%|▏         | 1101/80939 [1:22:30<682:52:00, 30.79s/it]

[Epoch 0 | Step 1100] Train loss: 3.0469 | Val loss: 2.7996 | PPL: 16.44


  1%|▏         | 1126/80939 [1:24:24<683:21:22, 30.82s/it]

[Epoch 0 | Step 1125] Train loss: 3.1875 | Val loss: 2.7926 | PPL: 16.32


  1%|▏         | 1151/80939 [1:26:18<684:21:29, 30.88s/it]

[Epoch 0 | Step 1150] Train loss: 3.0000 | Val loss: 2.7999 | PPL: 16.44


  1%|▏         | 1176/80939 [1:28:12<683:49:51, 30.86s/it]

[Epoch 0 | Step 1175] Train loss: 3.6719 | Val loss: 2.7932 | PPL: 16.33


  1%|▏         | 1201/80939 [1:30:06<685:19:44, 30.94s/it]

[Epoch 0 | Step 1200] Train loss: 3.5938 | Val loss: 2.7929 | PPL: 16.33


  2%|▏         | 1226/80939 [1:32:00<685:26:55, 30.96s/it]

[Epoch 0 | Step 1225] Train loss: 3.7969 | Val loss: 2.7884 | PPL: 16.26


  2%|▏         | 1251/80939 [1:33:55<684:59:33, 30.95s/it]

[Epoch 0 | Step 1250] Train loss: 3.9844 | Val loss: 2.7920 | PPL: 16.31


  2%|▏         | 1276/80939 [1:35:49<684:09:38, 30.92s/it]

[Epoch 0 | Step 1275] Train loss: 3.6562 | Val loss: 2.8018 | PPL: 16.47


  2%|▏         | 1301/80939 [1:37:43<684:17:54, 30.93s/it]

[Epoch 0 | Step 1300] Train loss: 3.7344 | Val loss: 2.7942 | PPL: 16.35


  2%|▏         | 1326/80939 [1:39:37<683:18:04, 30.90s/it]

[Epoch 0 | Step 1325] Train loss: 3.3438 | Val loss: 2.8016 | PPL: 16.47


  2%|▏         | 1351/80939 [1:41:31<682:10:24, 30.86s/it]

[Epoch 0 | Step 1350] Train loss: 3.1406 | Val loss: 2.7901 | PPL: 16.28


  2%|▏         | 1376/80939 [1:43:25<682:49:52, 30.90s/it]

[Epoch 0 | Step 1375] Train loss: 3.2656 | Val loss: 2.7930 | PPL: 16.33


  2%|▏         | 1401/80939 [1:45:19<683:05:51, 30.92s/it]

[Epoch 0 | Step 1400] Train loss: 3.5469 | Val loss: 2.7882 | PPL: 16.25


  2%|▏         | 1426/80939 [1:47:13<680:11:30, 30.80s/it]

[Epoch 0 | Step 1425] Train loss: 3.1719 | Val loss: 2.7804 | PPL: 16.13


  2%|▏         | 1451/80939 [1:49:07<681:29:58, 30.87s/it]

[Epoch 0 | Step 1450] Train loss: 3.2656 | Val loss: 2.7843 | PPL: 16.19


  2%|▏         | 1476/80939 [1:51:01<681:13:09, 30.86s/it]

[Epoch 0 | Step 1475] Train loss: 3.3906 | Val loss: 2.7925 | PPL: 16.32


  2%|▏         | 1501/80939 [1:52:54<680:17:25, 30.83s/it]

[Epoch 0 | Step 1500] Train loss: 3.1562 | Val loss: 2.7877 | PPL: 16.24


  2%|▏         | 1526/80939 [1:54:48<681:13:13, 30.88s/it]

[Epoch 0 | Step 1525] Train loss: 3.8125 | Val loss: 2.7784 | PPL: 16.09


  2%|▏         | 1551/80939 [1:56:42<680:41:46, 30.87s/it]

[Epoch 0 | Step 1550] Train loss: 3.4688 | Val loss: 2.7787 | PPL: 16.10


  2%|▏         | 1576/80939 [1:58:36<680:03:15, 30.85s/it]

[Epoch 0 | Step 1575] Train loss: 3.4062 | Val loss: 2.7811 | PPL: 16.14


  2%|▏         | 1601/80939 [2:00:30<681:41:29, 30.93s/it]

[Epoch 0 | Step 1600] Train loss: 3.2188 | Val loss: 2.7806 | PPL: 16.13


  2%|▏         | 1626/80939 [2:02:24<681:04:38, 30.91s/it]

[Epoch 0 | Step 1625] Train loss: 3.3438 | Val loss: 2.7746 | PPL: 16.03


  2%|▏         | 1651/80939 [2:04:19<682:07:16, 30.97s/it]

[Epoch 0 | Step 1650] Train loss: 3.7812 | Val loss: 2.7747 | PPL: 16.03


  2%|▏         | 1676/80939 [2:06:13<683:49:50, 31.06s/it]

[Epoch 0 | Step 1675] Train loss: 3.4531 | Val loss: 2.7697 | PPL: 15.95


  2%|▏         | 1701/80939 [2:08:07<680:37:08, 30.92s/it]

[Epoch 0 | Step 1700] Train loss: 3.4062 | Val loss: 2.7731 | PPL: 16.01


  2%|▏         | 1726/80939 [2:10:01<680:36:24, 30.93s/it]

[Epoch 0 | Step 1725] Train loss: 3.4219 | Val loss: 2.7795 | PPL: 16.11


  2%|▏         | 1751/80939 [2:11:56<680:35:42, 30.94s/it]

[Epoch 0 | Step 1750] Train loss: 3.5000 | Val loss: 2.7887 | PPL: 16.26


  2%|▏         | 1776/80939 [2:13:49<678:09:53, 30.84s/it]

[Epoch 0 | Step 1775] Train loss: 3.3750 | Val loss: 2.7930 | PPL: 16.33


  2%|▏         | 1801/80939 [2:15:43<679:20:41, 30.90s/it]

[Epoch 0 | Step 1800] Train loss: 3.6875 | Val loss: 2.7745 | PPL: 16.03


  2%|▏         | 1826/80939 [2:17:38<679:52:15, 30.94s/it]

[Epoch 0 | Step 1825] Train loss: 3.4844 | Val loss: 2.7782 | PPL: 16.09


  2%|▏         | 1851/80939 [2:19:32<678:56:40, 30.90s/it]

[Epoch 0 | Step 1850] Train loss: 6.9062 | Val loss: 5.0808 | PPL: 160.91


  2%|▏         | 1876/80939 [2:21:26<678:42:32, 30.90s/it]

[Epoch 0 | Step 1875] Train loss: 6.8750 | Val loss: 5.8080 | PPL: 332.96


  2%|▏         | 1901/80939 [2:23:20<677:58:04, 30.88s/it]

[Epoch 0 | Step 1900] Train loss: 4.5312 | Val loss: 4.6927 | PPL: 109.15


  2%|▏         | 1926/80939 [2:25:14<678:11:32, 30.90s/it]

[Epoch 0 | Step 1925] Train loss: 5.0625 | Val loss: 4.2538 | PPL: 70.37


  2%|▏         | 1951/80939 [2:27:08<678:15:42, 30.91s/it]

[Epoch 0 | Step 1950] Train loss: 4.4375 | Val loss: 3.2615 | PPL: 26.09


  2%|▏         | 1976/80939 [2:29:02<678:10:40, 30.92s/it]

[Epoch 0 | Step 1975] Train loss: 5.0625 | Val loss: 4.1784 | PPL: 65.26


  2%|▏         | 2001/80939 [2:30:56<678:06:02, 30.93s/it]

[Epoch 0 | Step 2000] Train loss: 3.6719 | Val loss: 3.1592 | PPL: 23.55


  3%|▎         | 2026/80939 [2:32:50<677:36:05, 30.91s/it]

[Epoch 0 | Step 2025] Train loss: 3.3281 | Val loss: 3.1015 | PPL: 22.23


  3%|▎         | 2051/80939 [2:34:44<676:55:11, 30.89s/it]

[Epoch 0 | Step 2050] Train loss: 4.0312 | Val loss: 3.0630 | PPL: 21.39


  3%|▎         | 2076/80939 [2:36:38<677:14:06, 30.91s/it]

[Epoch 0 | Step 2075] Train loss: 4.0312 | Val loss: 3.0356 | PPL: 20.81


  3%|▎         | 2101/80939 [2:38:32<676:33:29, 30.89s/it]

[Epoch 0 | Step 2100] Train loss: 3.4375 | Val loss: 3.0172 | PPL: 20.43


  3%|▎         | 2126/80939 [2:40:27<679:02:29, 31.02s/it]

[Epoch 0 | Step 2125] Train loss: 3.4219 | Val loss: 2.9977 | PPL: 20.04


  3%|▎         | 2151/80939 [2:42:21<676:13:11, 30.90s/it]

[Epoch 0 | Step 2150] Train loss: 3.8438 | Val loss: 2.9827 | PPL: 19.74


  3%|▎         | 2176/80939 [2:44:15<675:51:00, 30.89s/it]

[Epoch 0 | Step 2175] Train loss: 3.9062 | Val loss: 2.9710 | PPL: 19.51


  3%|▎         | 2201/80939 [2:46:08<675:04:39, 30.87s/it]

[Epoch 0 | Step 2200] Train loss: 3.5312 | Val loss: 2.9517 | PPL: 19.14


  3%|▎         | 2226/80939 [2:48:02<675:15:09, 30.88s/it]

[Epoch 0 | Step 2225] Train loss: 3.7188 | Val loss: 2.9434 | PPL: 18.98


  3%|▎         | 2251/80939 [2:49:56<675:25:03, 30.90s/it]

[Epoch 0 | Step 2250] Train loss: 3.6719 | Val loss: 3.0019 | PPL: 20.12


  3%|▎         | 2276/80939 [2:51:50<675:10:37, 30.90s/it]

[Epoch 0 | Step 2275] Train loss: 4.0625 | Val loss: 3.1025 | PPL: 22.25


  3%|▎         | 2301/80939 [2:53:44<673:48:59, 30.85s/it]

[Epoch 0 | Step 2300] Train loss: 4.3125 | Val loss: 5.7222 | PPL: 305.58


  3%|▎         | 2326/80939 [2:55:38<674:26:15, 30.89s/it]

[Epoch 0 | Step 2325] Train loss: 3.8594 | Val loss: 3.1935 | PPL: 24.37


  3%|▎         | 2350/80939 [2:55:50<11:16:22,  1.94it/s] 