In [None]:
from config import Config
from data import get_dataframes, DeepSpeakBertDataset
from models import BaselineModel, DeepSpeakBertModel
from torch.optim import Adam
from torch.utils.data import DataLoader
import logging
import numpy as np
import os
import torch.nn as nn
import torch

In [None]:
# Experiment configuration: dataset locations and split fractions, compute
# device, training limits, and where outputs and the log file are written.
cfg = Config(
    datasets_dir="datasets", raw_dir="raw",
    groups_csv="groups.csv", messages_csv="messages.csv",
    val_split=0.2, test_split=0.2, recreate_datasets=False,
    device="mps",
    num_epochs=256, batch_size=4,
    max_context_length=4096, max_group_size=65536,
    output_dir="output", log="bert.log",
)

# makedirs(exist_ok=True) replaces the check-then-mkdir pair: it has no race
# between the check and the creation, and it also creates missing parents.
os.makedirs(cfg.output_dir, exist_ok=True)

# force=True: basicConfig is a silent no-op when the root logger already has
# handlers (common in notebook kernels and whenever this cell is re-run), which
# would leave the log file unwritten. Forcing replaces any existing handlers.
logging.basicConfig(
    filename=os.path.join(cfg.output_dir, cfg.log),
    filemode='w',
    format='%(asctime)s - %(levelname)s: %(message)s',
    level=logging.INFO,
    force=True,
)

# Fixed seed so that anything drawn from `rng` and torch's global generator
# (e.g. weight initialisation) is reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
rng = np.random.default_rng(SEED)

In [None]:
# get_dataframes returns the group/message frames for each split; train and
# validation share one groups frame, while test has its own.
(
    train_val_groups_df,
    train_messages_df,
    val_messages_df,
    test_groups_df,
    test_messages_df,
) = get_dataframes(cfg, rng)

train_ds = DeepSpeakBertDataset(cfg, train_val_groups_df, train_messages_df)
val_ds = DeepSpeakBertDataset(cfg, train_val_groups_df, val_messages_df)
test_ds = DeepSpeakBertDataset(cfg, test_groups_df, test_messages_df)

# Shuffle only the training data: without it every epoch replays the same
# batches in the same order, so each mini-batch is made of neighbouring samples.
# NOTE(review): shuffle=True assumes DeepSpeakBertDataset is a map-style
# dataset (DataLoader rejects shuffle for iterable-style datasets) — confirm.
train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True)
# Validation and test keep their fixed order so their metrics are comparable.
val_dl = DataLoader(val_ds, batch_size=cfg.batch_size)
test_dl = DataLoader(test_ds, batch_size=cfg.batch_size)

In [None]:
# Training setup for DeepSpeakBertModel: cross-entropy loss optimised with Adam.
# (Baseline alternative: model = BaselineModel(), optimizer = None, same loss.)
model = DeepSpeakBertModel(cfg)
optimizer = Adam(params=model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

In [None]:
# Move the model onto the device named in cfg.device before training.
# Kept as a bare last-line expression, so the notebook displays the call's
# return value as this cell's output.
model.to(cfg.device)

In [None]:
def run_epoch(device: str, model: nn.Module, dataloader: DataLoader, criterion: nn.CrossEntropyLoss, optimizer: Adam | None = None):
    total_loss = 0.0
    correct = 0
    total = 0

    if is_training := optimizer is not None:
        model.train()
        context = torch.enable_grad()
    else:
        model.eval()
        context = torch.no_grad()

    with context:
        for samples, labels in dataloader:
            samples = {k: v.to(device) for k, v in samples.items()}
            labels = labels.to(device)

            if is_training:
                optimizer.zero_grad()

            logits = model(**samples)
            loss = criterion(logits, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item() * labels.size(0)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_loss = total_loss / total
    accuracy = correct / total

    return avg_loss, accuracy

In [None]:
# Early-stopping state: the best validation loss seen so far, the number of
# consecutive epochs that failed to beat it, and where the best weights go.
patience = 8
epochs_without_improvement = 0
best_val_loss = torch.inf
best_model_path = os.path.join(cfg.output_dir, "best_model.pt")

# Each epoch runs a training pass (optimizer supplied) followed by a
# validation pass (no optimizer, so run_epoch only evaluates).
phases = (("Train", train_dl, optimizer), ("Val", val_dl, None))

for epoch in range(cfg.num_epochs):
    log_msg = f"Epoch {epoch + 1}/{cfg.num_epochs}"
    logging.info(log_msg)
    print(log_msg)

    phase_losses = {}
    for mode, phase_dl, phase_optimizer in phases:
        avg_loss, accuracy = run_epoch(cfg.device, model, phase_dl, criterion, phase_optimizer)
        phase_losses[mode] = avg_loss

        log_msg = f"{mode} Loss: {avg_loss:.4f}, {mode} Acc: {accuracy:.4f}"
        logging.info(log_msg)
        print(log_msg)

    # Early stopping is driven by the validation loss only.
    val_loss = phase_losses["Val"]

    if val_loss < best_val_loss:
        # New best: remember it, reset the counter, and checkpoint the weights.
        best_val_loss = val_loss
        epochs_without_improvement = 0
        torch.save(model.state_dict(), best_model_path)
        continue

    epochs_without_improvement += 1
    if epochs_without_improvement < patience:
        continue

    log_msg = f"Stopping early after {epoch + 1} epochs (no improvement for {patience} epochs)."
    logging.info(log_msg)
    print(log_msg)
    break

In [None]:
# Restore the best checkpoint written by the training loop. map_location="cpu"
# deserialises the tensors on the CPU, so loading does not require the device
# they were saved from to be available; load_state_dict then copies the values
# into the model's existing parameters.
model.load_state_dict(torch.load(best_model_path, map_location="cpu"))
model.to(cfg.device)

# No optimizer is passed, so run_epoch only evaluates.
test_loss, test_accuracy = run_epoch(cfg.device, model, test_dl, criterion)

log_msg = f"Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.4f}"
logging.info(log_msg)
print(log_msg)