In [1]:
import os

import torch

from torch.utils.data import DataLoader

import pandas as pd

from dataset import HISCODataset
from model import HISCOClassifier
from dirs import DATA_DIR

In [2]:
# Read the train and test splits of the toy data set from DATA_DIR, in that order.
train_data, test_data = (
    pd.read_csv(os.path.join(DATA_DIR, csv_name))
    for csv_name in ('toy_data_train.csv', 'toy_data_test.csv')
)

In [3]:
# Wrap each raw frame in the project's HISCODataset (train first, then test).
dataset_train, dataset_test = map(HISCODataset, (train_data, test_data))

In [4]:
# Number of samples per mini-batch, shared by both loaders.
BATCH_SIZE = 32

# Reshuffle the training samples every epoch so the optimizer does not see the
# same batches in the same (file) order on every pass.
data_loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)

# Accuracy does not depend on sample order, so the held-out loader is left unshuffled.
# NOTE(review): this loader is named `_val` but wraps the *test* split — confirm
# that using the test set for per-epoch monitoring is intended.
data_loader_val = DataLoader(dataset_test, batch_size=BATCH_SIZE)

In [5]:
# Seed torch's global RNG before building the model so that a fresh
# "Restart & Run All" reproduces the same run.
# NOTE(review): presumably HISCOClassifier initialises its weights randomly — confirm.
SEED = 0
torch.manual_seed(SEED)

model = HISCOClassifier()

In [6]:
# Cross-entropy objective used for training.
loss_fn = torch.nn.CrossEntropyLoss()

# AdamW over every parameter of the model, with a learning rate of 0.01.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-2)

In [7]:
def train_epoch(model, optimizer, data_loader, loss_fn):
    """Run one optimisation pass over ``data_loader``.

    The model is switched to training mode, then for every batch the gradients
    are reset, the loss is computed and back-propagated, and the optimizer
    takes one step.

    Args:
        model: Callable module; it is invoked on ``batch['encoded']``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        data_loader: Iterable of batches, each indexable by the keys
            ``'encoded'`` (model input) and ``'label'`` (loss target).
        loss_fn: Callable ``loss_fn(output, label)`` returning a scalar
            tensor that supports ``backward()``.

    Returns:
        float: Mean of the per-batch loss values seen during the pass, or
        ``nan`` if ``data_loader`` yielded no batches.
    """
    model.train()

    loss_sum = 0.0
    batch_count = 0

    for batch in data_loader:
        optimizer.zero_grad()

        out = model(batch['encoded'])
        loss = loss_fn(out, batch['label'])
        loss.backward()

        optimizer.step()

        # .item() copies the scalar out as a Python float, so no autograd graph is kept alive.
        loss_sum += loss.item()
        batch_count += 1

    if batch_count == 0:
        return float('nan')
    return loss_sum / batch_count

In [8]:
# The called form `torch.no_grad()` is used because the bare `@torch.no_grad`
# decorator is not accepted by every PyTorch release.
@torch.no_grad()
def evaluate(model, data_loader):
    """Compute classification accuracy of ``model`` over ``data_loader``.

    The model is switched to evaluation mode (and left in it) and gradient
    tracking is disabled for the duration of the call.

    Args:
        model: Callable module; it is invoked on ``batch['encoded']`` and its
            output is reduced with ``argmax(1)`` to obtain predictions.
        data_loader: Iterable of batches, each indexable by the keys
            ``'encoded'`` (model input) and ``'label'`` (tensor of targets
            whose first dimension is the batch size).

    Returns:
        float: Fraction of samples whose predicted index equals the label,
        or ``nan`` if ``data_loader`` yielded no samples.
    """
    model.eval()

    total_correct = 0
    total_count = 0

    for batch in data_loader:
        labels = batch['label']
        out = model(batch['encoded'])

        total_correct += (out.argmax(1) == labels).sum().item()
        total_count += labels.size(0)

    # Accuracy is undefined without samples; report NaN instead of dividing by zero.
    if total_count == 0:
        return float('nan')
    return total_correct / total_count

In [9]:
# Ten epochs: each one is a full training pass followed by an accuracy
# measurement on the validation loader, which is printed as a progress line.
for epoch in range(1, 10 + 1):
    train_epoch(
        model=model,
        optimizer=optimizer,
        data_loader=data_loader_train,
        loss_fn=loss_fn,
    )
    acc = evaluate(model=model, data_loader=data_loader_val)
    print(f'Trained for {epoch} epochs. Validation accuracy: {acc}')

Trained for 1 epochs. Validation accuracy: 0.676
Trained for 2 epochs. Validation accuracy: 0.701
Trained for 3 epochs. Validation accuracy: 0.711
Trained for 4 epochs. Validation accuracy: 0.742
Trained for 5 epochs. Validation accuracy: 0.765
Trained for 6 epochs. Validation accuracy: 0.759
Trained for 7 epochs. Validation accuracy: 0.753
Trained for 8 epochs. Validation accuracy: 0.751
Trained for 9 epochs. Validation accuracy: 0.759
Trained for 10 epochs. Validation accuracy: 0.756
