In [1]:
import torch
from torch.amp import autocast, GradScaler
from tqdm import tqdm
import sys
import hydra
import torch
from omegaconf import DictConfig, OmegaConf

In [2]:
sys.path.append("..")
from model import JoeyLLM
from data import get_dataloader
from utils.logger import wandbLogger
# from train.trainer import  Trainer

In [3]:
print("‚úÖ Loaded Config:")

# The Hydra config directory is addressed relative to this notebook: one
# level up, in `configs/`.
with hydra.initialize(version_base=None, config_path="../configs"):
    cfg = hydra.compose(config_name="config")

# The W&B mode is set on the logger class before an instance is created.
wandbLogger.set_mode(cfg.wandb.mode)

# Give the logger a plain-Python copy of the config with interpolations resolved.
wandb_config = OmegaConf.to_container(cfg, resolve=True)
logger = wandbLogger(project_name=cfg.wandb.project, config=wandb_config)

‚úÖ Loaded Config:


In [4]:
def compute_loss(outputs, labels):
    """Return the mean cross-entropy between per-token logits and target ids.

    Args:
        outputs: 3-D logits tensor, unpacked here as ``(B, T, V)``.
        labels: Tensor of target class indices holding ``B * T`` elements.

    Returns:
        A scalar tensor: cross-entropy averaged over all ``B * T`` positions.

    Raises:
        ValueError: If ``outputs`` is not 3-dimensional (from the unpacking).
        RuntimeError: If ``labels`` does not contain exactly ``B * T`` elements.
    """
    B, T, V = outputs.size()

    # `reshape` instead of `view`: labels produced by slicing a token tensor
    # (e.g. `tokens[:, 1:]`) are non-contiguous, and `view` raises on those.
    flat_logits = outputs.reshape(B * T, V)  # [B*T, V]
    flat_labels = labels.reshape(B * T)      # [B*T]

    # Functional form with default arguments matches `CrossEntropyLoss()`
    # without constructing a new module on every call.
    return torch.nn.functional.cross_entropy(flat_logits, flat_labels)

In [5]:
print("üß† Initializing Model...")

# Every architecture hyperparameter is read from the `model` config section.
model_cfg = cfg.model
model = JoeyLLM(
    vocab_size=model_cfg.vocab_size,
    max_seq_len=model_cfg.max_seq_len,
    embed_dim=model_cfg.embed_dim,
    num_layers=model_cfg.num_layers,
    num_heads=model_cfg.num_heads,
    dropout=model_cfg.dropout,
)

# Register the model with the logger (log="all", every 10 steps).
logger.watch_model(model, log_freq=10, log="all")

üß† Initializing Model...


In [6]:

print("üì¶ Loading Dataset...")

# All loader settings are read from the `data` config section.
data_cfg = cfg.data
dataloader = get_dataloader(
    data_path=data_cfg.data_path,
    chunk_size=data_cfg.chunk_size,
    buffer_text_size=data_cfg.buffer_text_size,
    batch_size=data_cfg.batch_size,
    num_workers=data_cfg.num_workers,
)

üì¶ Loading Dataset...


In [None]:
# Epoch preamble, unrolled from the trainer's `_train_epoch(self, epoch)`:
# reset the running loss, put the model in training mode, and wrap the
# dataloader in a progress bar that is cleared once it finishes.
total_loss = 0
model.train()

progress_bar = tqdm(iterable=dataloader, desc="Steps", leave=False)


[A

In [None]:

def _train_epoch(self, epoch):
    """Run one training epoch and return the mean per-batch loss.

    This cell used to hold the bare body of a trainer method: it read ``self``
    and ``epoch`` at module level, ended in ``return`` outside a function, and
    contained a mis-indented logging block that referenced an undefined
    ``msg``, so it could not execute. It is restored here as a function that
    takes the trainer-like object explicitly.

    Args:
        self: Object exposing ``model``, ``dataloader``, ``optimizer``,
            ``scaler``, ``device``, ``compute_loss`` and ``logger``.
            A falsy ``logger`` disables logging.
        epoch: Epoch number, used for the metric step and the summary line.

    Returns:
        float: Sum of the batch losses divided by the number of batches
        processed (``0.0`` if the dataloader yielded nothing).
    """
    self.model.train()
    total_loss = 0.0
    num_batches = 0

    # Mixed precision only applies when training on CUDA. Disabling autocast
    # explicitly elsewhere avoids the "CUDA is not available" warning that
    # `autocast(device_type="cuda")` emits on CPU-only machines.
    use_amp = torch.device(self.device).type == "cuda"

    progress_bar = tqdm(self.dataloader, desc="Steps", leave=False)
    for batch_idx, batch in enumerate(progress_bar):
        # Batches may be a dict with "inputs"/"labels" or an indexable pair.
        if isinstance(batch, dict):
            inputs, labels = batch["inputs"], batch["labels"]
        else:
            inputs, labels = batch[0], batch[1]
        inputs = inputs.to(self.device)
        labels = labels.to(self.device)

        # Ensure shape is [B, T]
        if inputs.dim() == 1:
            inputs = inputs.unsqueeze(0)

        # Shape check on the first batch only; printing on every step floods
        # the output and fights with the progress bar.
        if batch_idx == 0:
            print(f"‚ö†Ô∏è  [DEBUG] inputs.shape = {inputs.shape}")

        self.optimizer.zero_grad()

        with autocast(device_type="cuda", enabled=use_amp):
            outputs = self.model(inputs)
            loss = self.compute_loss(outputs, labels)

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        # Read the scalar once; each `.item()` forces a device sync.
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1

        if self.logger:
            # NOTE(review): the original referenced an undefined `msg` and sat
            # one indent level deeper, which suggests a dropped
            # `if batch_idx % N == 0:` guard. The message format below is a
            # reconstruction -- confirm against the trainer module.
            msg = f"Epoch {epoch} | Step {batch_idx} | Loss: {loss_value:.4f}"
            self.logger.log_message(msg)
            self.logger.log_metrics({
                "train_loss": loss_value
            }, step=epoch * len(self.dataloader) + batch_idx)

    # Divide by the batches actually seen so an empty loader cannot raise
    # ZeroDivisionError.
    avg_loss = total_loss / num_batches if num_batches else 0.0
    print(f"Epoch {epoch} | Avg Training Loss: {avg_loss:.4f}")
    return avg_loss


In [None]:

def save_checkpoint(self, path):
    """Write model, optimizer, scaler and (if present) scheduler state to ``path``.

    Args:
        self: Object exposing ``model``, ``optimizer``, ``scaler`` and
            ``scheduler``; ``scheduler`` may be ``None``/falsy.
        path: Destination passed to ``torch.save``. When it is a filesystem
            path, missing parent directories are created first.
    """
    import os  # local import keeps this cell self-contained

    checkpoint = {
        "model_state": self.model.state_dict(),
        "optimizer_state": self.optimizer.state_dict(),
        "scaler_state": self.scaler.state_dict()
    }
    if self.scheduler:
        checkpoint["scheduler_state"] = self.scheduler.state_dict()

    # `torch.save` does not create directories, so a path such as
    # "checkpoints/checkpoint.pth" would fail on a fresh checkout. File-like
    # objects are passed through untouched.
    if isinstance(path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(path))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.save(checkpoint, path)
    print(f"‚úÖ Checkpoint saved to {path}")


In [None]:

def load_checkpoint(self, path):
    """Restore model, optimizer, scaler and (if present) scheduler state from ``path``.

    Args:
        self: Object exposing ``model``, ``optimizer``, ``scaler`` and
            ``scheduler``. If it also has a ``device`` attribute, stored
            tensors are mapped onto that device while loading.
        path: Checkpoint location, in the layout written by ``save_checkpoint``.

    Raises:
        KeyError: If ``model_state``, ``optimizer_state`` or ``scaler_state``
            is missing from the checkpoint.
    """
    # Without `map_location`, a checkpoint saved on CUDA cannot be restored on
    # a machine that lacks that device. `None` keeps torch's default behaviour.
    # NOTE(review): `torch.load` unpickles data -- only load trusted files.
    checkpoint = torch.load(path, map_location=getattr(self, "device", None))

    self.model.load_state_dict(checkpoint["model_state"])
    self.optimizer.load_state_dict(checkpoint["optimizer_state"])
    self.scaler.load_state_dict(checkpoint["scaler_state"])

    # Scheduler state is optional on both sides: skip it if this trainer has
    # no scheduler or the checkpoint was written without one.
    if self.scheduler and "scheduler_state" in checkpoint:
        self.scheduler.load_state_dict(checkpoint["scheduler_state"])
    print(f"‚úÖ Checkpoint loaded from {path}")


In [None]:

def fit(self, num_epochs=20, checkpoint_path="checkpoints/checkpoint.pth"):
    """Train for ``num_epochs`` epochs, checkpointing after each one.

    Args:
        self: Object exposing ``_train_epoch(epoch)``,
            ``save_checkpoint(path)`` and ``scheduler`` (may be ``None``/falsy).
        num_epochs: Number of epochs to run; epochs are numbered from 1.
        checkpoint_path: Path handed to ``save_checkpoint`` after every epoch
            (the same path each time, so the file is overwritten).
    """
    for epoch in range(1, num_epochs + 1):
        self._train_epoch(epoch)

        # Step the scheduler before saving, so the checkpoint holds the
        # scheduler state that belongs to the epoch just finished. Saving
        # first left the stored scheduler one step behind, which shifts the
        # schedule when training resumes from that file.
        if self.scheduler:
            self.scheduler.step()

        self.save_checkpoint(checkpoint_path)

    print("üèÅ Training complete!")
