In [None]:
'''
Libraries
'''
import os
import torch
import random
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

from transformers import BertTokenizer, BertModel

In [None]:
'''
Dataset
'''

class EICDataset(Dataset):
    """In-memory map-style dataset wrapping a sequence of prepared samples.

    The original placeholder could not be consumed by a ``DataLoader``:
    ``__len__`` returned ``None`` (so ``len(dataset)`` raised ``TypeError``)
    and ``__getitem__`` delegated to ``Dataset.__getitem__``, which raises
    ``NotImplementedError``.

    Args:
        samples: Optional iterable of samples. Defaults to ``None``, which
            yields an empty dataset so the zero-argument constructor used
            elsewhere in the notebook keeps working.

    NOTE(review): the layout of one sample is not fixed by this class; the
    training cell reads ``batch[0]`` as a dict of tensors and ``batch[1]`` as
    the targets, so samples are presumably ``(features, target)`` pairs --
    confirm once real data loading is added.
    """

    def __init__(self, samples=None):
        super().__init__()
        # Materialise into a list so one-shot iterables work and later
        # mutation of the caller's container cannot change the dataset.
        self.samples = list(samples) if samples is not None else []

    def __getitem__(self, index):
        """Return the sample stored at position ``index``."""
        return self.samples[index]

    def __len__(self):
        """Return the number of samples held by the dataset."""
        return len(self.samples)


class Collator(object):
    """Merge a list of dataset samples into one batch for ``DataLoader``.

    ``DataLoader`` invokes ``collate_fn(batch)`` with the list of samples that
    make up one batch, so ``__call__`` must accept that argument; the original
    zero-argument ``__call__`` raised ``TypeError`` on the first batch.

    The returned layout matches how the training, dev and test cells consume a
    batch: ``batch[0]`` is a dict whose values support ``.to(device)`` and
    ``batch[1]`` is a tensor of targets.

    NOTE(review): assumes every sample is a ``(features, target)`` pair where
    ``features`` is a dict of equally-shaped tensors/arrays (already padded or
    truncated) and ``target`` is a number or tensor -- confirm against the
    real dataset before relying on this.
    """

    def __init__(self):
        pass

    def __call__(self, batch=None):
        """Collate ``batch`` into ``(inputs, targets)``.

        Args:
            batch: List of ``(features, target)`` samples. ``None`` or an
                empty list returns ``None``, which keeps the original
                zero-argument call working.

        Returns:
            A tuple ``(inputs, targets)`` where ``inputs`` maps each feature
            key to the per-sample values stacked along a new first dimension,
            and ``targets`` is the stacked targets; or ``None`` when there is
            nothing to collate.
        """
        if not batch:
            return None

        feature_dicts = [sample[0] for sample in batch]
        # Keys are taken from the first sample; every sample is expected to
        # provide the same keys.
        inputs = {
            key: torch.stack([torch.as_tensor(features[key]) for features in feature_dicts])
            for key in feature_dicts[0].keys()
        }
        targets = torch.stack([torch.as_tensor(sample[1]) for sample in batch])
        return inputs, targets

In [None]:
'''
Model
'''

class MyModel(nn.Module):
    """Placeholder network: registers no layers and computes nothing.

    NOTE(review): ``forward`` accepts no inputs and returns ``None``, while
    the training cell calls ``model(batch[0])``. This skeleton has to be
    filled in with real layers and an input argument before it can be trained.
    """

    def __init__(self):
        # Zero-argument super() resolves to the same parent initialiser.
        super().__init__()

    def forward(self):
        """Return ``None``; no computation is defined yet."""
        return None

In [None]:
'''
Hyper parameters
'''
random_seed = 42       # seed applied to numpy, torch and random in the setup cell
max_len = 64           # presumably the maximum token length per example -- TODO confirm (unused in the visible cells)
batch_size = 4         # samples per batch for the train/dev/test loaders
epochs = 20            # number of passes over the training loader
lr = 2e-5              # learning rate handed to AdamW
# Use the first GPU when one is visible, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
checkpoint_dir = './'  # directory where 'checkpoint.pt' is written

In [None]:
'''
Environment setup
'''
# Seed every random number generator in use so runs are repeatable.
np.random.seed(random_seed)
torch.manual_seed(random_seed)
# manual_seed_all seeds every visible GPU, not just the current one; it is
# silently ignored when CUDA is unavailable.
torch.cuda.manual_seed_all(random_seed)
random.seed(random_seed)

train_data = EICDataset()
dev_data = EICDataset()
test_data = EICDataset()

collator = Collator()

# Only the training loader is shuffled: a fixed sample order every epoch hurts
# SGD-style optimisation, while dev/test keep dataset order so predictions can
# be lined up with their inputs.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collator)
dev_loader = DataLoader(dev_data, batch_size=batch_size, collate_fn=collator)
test_loader = DataLoader(test_data, batch_size=batch_size, collate_fn=collator)

model = MyModel().to(device)

optimizer = optim.AdamW(model.parameters(), lr=lr)

criterion = nn.MSELoss()
# Multiply the learning rate by 0.1 every 10 epochs.
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

In [None]:
'''
Training loop
'''
def _move_to_device(inputs, target_device):
    """Move every value of the ``inputs`` dict to ``target_device`` (in place) and return it."""
    for key in inputs.keys():
        inputs[key] = inputs[key].to(target_device)
    return inputs


def _mean_loss(total, n_batches, split):
    """Return ``total / n_batches``, failing loudly when a loader yielded no batches."""
    if n_batches == 0:
        raise ValueError(f'{split} loader produced no batches; cannot compute a mean loss')
    return total / n_batches


# Make sure the checkpoint directory exists before the first save.
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

best_ckp = None        # path of the best checkpoint written so far (None until one is saved)
best_loss = np.inf     # lowest mean dev loss seen so far
for epoch in range(1, epochs + 1):
    # ---- training pass ----
    model.train()
    loss_accum = 0.0
    n_batches = 0

    for batch in tqdm(train_loader, desc="Epoch {}".format(epoch)):
        inputs = _move_to_device(batch[0], device)
        targets = batch[1].to(device)

        optimizer.zero_grad()
        pred = model(inputs)
        loss = criterion(pred, targets)
        loss.backward()
        optimizer.step()

        loss_accum += loss.item()
        n_batches += 1

    # Count batches explicitly instead of reusing the loop index, which is
    # undefined when the loader is empty.
    train_loss = _mean_loss(loss_accum, n_batches, 'train')

    # ---- validation pass ----
    model.eval()
    loss_accum = 0.0
    n_batches = 0

    with torch.no_grad():
        for batch in tqdm(dev_loader, desc="Dev"):
            inputs = _move_to_device(batch[0], device)
            pred = model(inputs)
            loss = criterion(pred, batch[1].to(device))

            loss_accum += loss.item()
            n_batches += 1

    dev_loss = _mean_loss(loss_accum, n_batches, 'dev')

    print(f'Epoch {epoch}: train loss {train_loss:.6f}, dev loss {dev_loss:.6f}')

    # Keep only the checkpoint with the lowest dev loss; later improvements
    # overwrite the same file.
    if dev_loss < best_loss:
        best_loss = dev_loss
        best_ckp = os.path.join(checkpoint_dir, 'checkpoint.pt')
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_metric': best_loss,
        }
        torch.save(checkpoint, best_ckp)

    scheduler.step()
    print(f'Best validation metric so far: {best_loss}, Latest Lr: {scheduler.get_last_lr()[0]}')


In [None]:
'''
Give test results
'''
ids = []   # one id per exported prediction
res = []   # one Python float per exported prediction

if best_ckp is None:
    raise RuntimeError('No checkpoint was saved during training; nothing to evaluate.')

# map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
best_state = torch.load(best_ckp, map_location=device)
model.load_state_dict(best_state['model_state_dict'])
model.eval()

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Test"):
        for key in batch[0].keys():
            batch[0][key] = batch[0][key].to(device)

        pred = model(batch[0])

        # Flatten to plain floats on the CPU so results do not pin GPU memory
        # and can be written as text.
        # NOTE(review): assumes the model emits one scalar per sample -- confirm.
        values = pred.detach().cpu().reshape(-1).tolist()

        # NOTE(review): the original never filled `ids`, so the output file only
        # ever contained the header. Sequential positions in the (unshuffled)
        # test loader are used here; swap in real ids if the dataset has them.
        ids.extend(range(len(ids), len(ids) + len(values)))
        res.extend(values)

with open('./output.csv', 'w', encoding='utf-8') as f:
    f.write('id\tpred\n')
    for sample_id, value in zip(ids, res):
        # Format explicitly: concatenating str with a non-str raises TypeError.
        f.write(f'{sample_id}\t{value}\n')
        