In [None]:
from config import Config
from data import DeepSpeakBertDataset, save_samples, split_raw_data
from models import BaselineModel, DeepSpeakBertModel
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils import run_epoch, train
import logging
import os
import torch
import torch.nn as nn

In [None]:
# Central run configuration shared by every later cell (data preparation,
# dataset loading, training and evaluation all receive or read `cfg`).
cfg = Config(
    # Directory / file names for the dataset layout.
    # NOTE(review): how these nest is decided by the helpers in data.py — confirm there.
    datasets_dir="datasets", raw_dir="raw",
    split_dir="split", train_dir="train", val_dir="val", test_dir="test",
    groups_dir="groups", messages_dir="messages",
    samples_dir="samples", meta_json="meta.json",
    recreate_datasets=False, recreate_samples=False,
    val_split=0.2, test_split=0.2,
    # Prefer Apple "mps" (the original hard-coded value), but fall back to CUDA
    # or CPU so the notebook still runs on machines without an Apple GPU.
    device=(
        "mps" if torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available()
        else "cpu"
    ),
    best_model_path="best_model.pt", checkpoint="checkpoint.pt",
    resume=False, num_epochs=256, batch_size=8,
    max_context_length=512, max_group_size=256, max_samples_per_group=1024, patience=8,
    output_dir="output", log="bert.log",
)

# exist_ok=True already makes this a no-op when the directory exists,
# so no separate isdir() check is needed.
os.makedirs(cfg.output_dir, exist_ok=True)

# Send INFO-and-above records to <output_dir>/<log>, truncating any previous file.
# force=True (Python 3.8+) replaces handlers already attached to the root logger;
# without it basicConfig() silently does nothing if a handler exists (for example
# because an imported module logged before this cell ran, or the cell is re-run).
logging.basicConfig(
    filename=os.path.join(cfg.output_dir, cfg.log),
    filemode='w',
    format='%(asctime)s - %(levelname)s: %(message)s',
    level=logging.INFO,
    force=True,
)

In [None]:
# Data-preparation step 1, driven entirely by `cfg`.
# NOTE(review): presumably partitions the raw data into train/val/test using
# cfg.val_split / cfg.test_split and honours cfg.recreate_datasets — confirm in
# data.split_raw_data.
split_raw_data(cfg)

In [None]:
# Data-preparation step 2, driven entirely by `cfg`.
# NOTE(review): presumably writes the per-split sample files under
# <datasets_dir>/<samples_dir>, which the dataset-loading cell reads from —
# confirm in data.save_samples.
save_samples(cfg)

In [None]:
# Resolve <datasets_dir>/<samples_dir>/<split> for the train, val and test splits.
samples_dir = os.path.join(cfg.datasets_dir, cfg.samples_dir)
samples_dirs = tuple(
    os.path.join(samples_dir, split_name)
    for split_name in (cfg.train_dir, cfg.val_dir, cfg.test_dir)
)

# One dataset per split, kept in (train, val, test) order.
datasets = tuple(DeepSpeakBertDataset(cfg, split_path) for split_path in samples_dirs)

# All three loaders share cfg.batch_size and otherwise use DataLoader defaults.
# NOTE(review): the defaults mean no shuffling, including for the training
# loader — confirm this is intended for DeepSpeakBertDataset.
train_dl, val_dl, test_dl = (
    DataLoader(split_ds, batch_size=cfg.batch_size) for split_ds in datasets
)

In [None]:
# Disabled alternative: the baseline setup, which runs with no optimizer.
#   model = BaselineModel()
#   optimizer = None
#   criterion = nn.CrossEntropyLoss()

# Active setup: BERT-based model, cross-entropy loss, Adam optimizer.
criterion = nn.CrossEntropyLoss()
model = DeepSpeakBertModel(cfg)
optimizer = Adam(params=model.parameters(), lr=2e-5)

In [None]:
# Run training with the train/validation loaders, loss and optimizer built above.
# NOTE(review): the evaluation cell later loads cfg.best_model_path from
# cfg.output_dir, so train() is presumably what writes that file — confirm in
# utils.train.
train(cfg, model, train_dl, val_dl, criterion, optimizer)

In [None]:
# Evaluate the best saved weights on the test split.
best_model_file = os.path.join(cfg.output_dir, cfg.best_model_path)

# Deserialize onto the CPU first so the file loads even when the device it was
# saved from is not available here; load_state_dict copies the values into the
# model's own parameters, and the model is then moved to cfg.device.
model.load_state_dict(torch.load(best_model_file, map_location="cpu"))
model.to(cfg.device)

# Put the model in inference mode before evaluating.
# NOTE(review): run_epoch may already set the mode itself; this call is harmless if so.
model.eval()

# No optimizer is passed here, unlike in train().
# NOTE(review): presumably that makes run_epoch evaluation-only — confirm in utils.run_epoch.
test_loss, test_accuracy = run_epoch(cfg.device, model, test_dl, criterion)

# Record the result in the log file and echo it in the notebook output.
log_msg = f"Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.4f}"
logging.info(log_msg)
print(log_msg)