In [None]:
'''
Libraries
'''
import os
import re
import torch
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

from transformers import BertTokenizer, BertModel

In [None]:
'''
Dataset
'''

def filter(text):
    """Keep only ASCII letters, digits, spaces and the characters . ! ? '.

    Runs of spaces are then collapsed to a single space, and runs of two or
    more single quotes (e.g. ``''``) are deleted outright. The result is not
    stripped.

    Note: this name shadows the built-in ``filter`` for the rest of the
    notebook; it is kept because other cells call it by this name.

    Args:
        text: Input string.

    Returns:
        The filtered string.
    """
    # Inside a character class only a *leading* ``^`` negates; any further
    # ``^`` is a literal, so ``^`` must not be used as a separator here.
    text = re.sub(r"[^.!?' a-zA-Z0-9]", '', text)
    text = re.sub(r" +", " ", text)
    text = re.sub(r"''+", "", text)
    return text

class Cleaner(object):
    """Callable that normalises a raw headline string.

    Applied in order:
      1. hyphens and ``$NEWLINE$`` / ``NEWLINE`` markers become spaces;
      2. URLs (``www.…`` or ``http(s)://…``) become a space;
      3. ``@mentions`` and ``#hashtags`` are removed;
      4. the module-level ``filter`` helper drops unrecognised characters;
      5. runs of spaces are collapsed and the result is stripped.
    Case is preserved.
    """

    def __call__(self, text):
        """Return the cleaned form of ``text`` (a ``str``)."""
        # Literal substitutions need no regex. Using str.replace also means the
        # ``$`` signs are matched literally (as a regex, an unescaped ``$`` is
        # an end-of-string anchor and ``$NEWLINE$`` could never match).
        text = text.replace('-', ' ')
        text = text.replace('$NEWLINE$', ' ')
        text = text.replace('NEWLINE', ' ')

        # Raw strings: ``\.`` / ``\s`` in a normal string literal are invalid
        # escape sequences (SyntaxWarning on Python 3.12+).
        text = re.sub(r'((www\.[^\s]+)|(https?://[^\s]+))', ' ', text)  # urls
        text = re.sub(r'@\S+', '', text)  # @somebody
        text = re.sub(r'#\S+', '', text)  # #topic

        # Drop unrecognisable characters, then normalise whitespace.
        text = filter(text)
        text = re.sub(r' +', ' ', text)
        return text.strip()

class EICDataset(Dataset):
    """Headline-edit dataset read from a CSV file.

    Each row's ``original`` column holds a headline with one word marked as
    ``<word/>``; the ``edit`` column holds the substitute word. Both the
    original headline and the edited headline (marker replaced by ``edit``)
    are cleaned with ``Cleaner`` once, up front.

    Args:
        path: Path to the CSV file (read with ``pd.read_csv``).
        mode: ``'train'`` or ``'dev'`` -> items are
            ``(id, original, edited, meanGrade)``; any other value -> items
            are ``(id, original, edited)`` (no label column is read).
    """

    def __init__(self, path, mode):
        super().__init__()
        self.mode = mode
        self.data = pd.read_csv(path)
        cleaner = Cleaner()
        self.original = []
        self.edited = []
        for headline, edit in zip(self.data['original'], self.data['edit']):
            self.original.append(cleaner(headline))
            # A callable replacement inserts ``edit`` literally; passing it as
            # the repl *string* would interpret backslashes / group refs in it.
            # Non-greedy so the match cannot run past the first ``/>``.
            edited = re.sub(r'<.*?/>', lambda _match: edit, headline)
            self.edited.append(cleaner(edited))

    def __getitem__(self, index):
        """Return the tuple described in the class docstring for ``index``."""
        item_id = self.data['id'].iloc[index]
        if self.mode in ('train', 'dev'):
            return item_id, self.original[index], self.edited[index], self.data['meanGrade'].iloc[index]
        return item_id, self.original[index], self.edited[index]

    def __len__(self):
        return len(self.data)


class Collator(object):
    """DataLoader ``collate_fn`` that tokenises edited headlines.

    Args:
        tokenizer: A Hugging Face-style tokenizer, called as
            ``tokenizer(list_of_str, max_length=..., padding=...,
            truncation=..., return_tensors='pt')``.
        max_len: Pad/truncate every sequence to this many tokens. A value
            ``<= 0`` disables truncation and pads only to the longest
            sequence in the batch.
        mode: ``'train'`` / ``'dev'`` -> ``__call__`` returns
            ``(encoding, targets)`` with ``targets`` a float32 tensor of
            shape ``(B,)`` built from ``item[3]``; any other value -> returns
            ``(encoding, ids)`` with ``ids`` a list built from ``item[0]``.
    """

    def __init__(self, tokenizer, max_len, mode):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.mode = mode

    def __call__(self, batch):
        # Items are (id, original, edited[, meanGrade]). The grade scores the
        # *edited* headline, so that is the text the model must see; the
        # original headline alone cannot tell different edits apart.
        passages = [ex[2] for ex in batch]
        fixed = self.max_len > 0
        encoding = self.tokenizer(
            passages,
            max_length=self.max_len if fixed else None,
            # 'max_length' with no length would pad to the model maximum.
            padding='max_length' if fixed else 'longest',
            truncation=fixed,
            return_tensors='pt',
        )

        if self.mode == 'train' or self.mode == 'dev':
            targets = torch.tensor([ex[3] for ex in batch], dtype=torch.float32)
            return encoding, targets
        ids = [ex[0] for ex in batch]
        return encoding, ids

In [None]:
'''
Model
'''

class MyModel(nn.Module):
    """BERT encoder followed by a single linear regression head.

    Args:
        model_name: Checkpoint passed to ``BertModel.from_pretrained``.
            Defaults to ``'bert-base-uncased'``; the tokenizer used to build
            the inputs must come from the same checkpoint.
    """

    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        # Size the head from the encoder config rather than hard-coding 768,
        # so other BERT sizes work without edits.
        self.linear = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input):
        """Score a tokenised batch.

        Args:
            input: Mapping of tensors (e.g. ``input_ids``, ``attention_mask``)
                that is unpacked as keyword arguments into ``BertModel``.

        Returns:
            Tensor of shape ``(B, 1)``: one unbounded score per example.
        """
        outputs = self.bert(**input)
        pooled_output = outputs[1]  # pooled [CLS] representation, B x hidden_size
        return self.linear(pooled_output)  # B x 1

In [None]:
'''
Hyper parameters
'''
random_seed = 42
max_len = 64        # token length each example is padded/truncated to by Collator
batch_size = 4
epochs = 20
lr = 2e-5
# Fall back to CPU so the notebook still runs (slowly) on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
checkpoint_dir = './'

In [None]:
'''
Environment setup
'''
# Seed every RNG in play: Python, NumPy, PyTorch CPU and all CUDA devices
# (manual_seed_all is a no-op when CUDA is unavailable).
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
random.seed(random_seed)

tk = BertTokenizer.from_pretrained('bert-base-uncased')

train_data = EICDataset('./Dataset/train.csv', 'train')
dev_data = EICDataset('./Dataset/dev.csv', 'dev')
test_data = EICDataset('./Dataset/test.csv', 'test')

# 'train' mode makes the collator emit targets; it is reused for the labelled dev split.
label_collator = Collator(tk, max_len, 'train')
test_collator = Collator(tk, max_len, 'test')

# Shuffle training batches (otherwise every epoch sees rows in CSV order),
# using a dedicated seeded generator so runs stay reproducible.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(random_seed),
    collate_fn=label_collator,
)
dev_loader = DataLoader(dev_data, batch_size=batch_size, collate_fn=label_collator)
# Not shuffled: prediction order must follow the ids returned by the collator.
test_loader = DataLoader(test_data, batch_size=batch_size, collate_fn=test_collator)

model = MyModel().to(device)

optimizer = optim.AdamW(model.parameters(), lr=lr)

criterion = nn.MSELoss()
# Divide the learning rate by 10 every 10 epochs.
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

In [None]:
'''
Training loop
'''
best_ckp = None
best_loss = np.inf
for epoch in range(1, epochs + 1):
    # ---- train ----
    model.train()
    loss_sum, n_seen = 0.0, 0
    for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch}"):
        inputs = {key: value.to(device) for key, value in inputs.items()}
        targets = targets.to(device).view(-1)

        optimizer.zero_grad()
        # The model may return (B, 1) while targets are (B,). MSELoss would
        # broadcast that to (B, B) and average over every pred/target pair,
        # so flatten predictions to (B,) before computing the loss.
        pred = model(inputs).view(-1)
        loss = criterion(pred, targets)
        loss.backward()
        optimizer.step()

        # Sample-weighted so a short final batch does not skew the mean.
        loss_sum += loss.item() * targets.numel()
        n_seen += targets.numel()
    train_loss = loss_sum / n_seen if n_seen else float('nan')

    # ---- validate ----
    model.eval()
    loss_sum, n_seen = 0.0, 0
    with torch.no_grad():
        for inputs, targets in tqdm(dev_loader, desc="Dev"):
            inputs = {key: value.to(device) for key, value in inputs.items()}
            targets = targets.to(device).view(-1)
            pred = model(inputs).view(-1)
            loss = criterion(pred, targets)
            loss_sum += loss.item() * targets.numel()
            n_seen += targets.numel()
    dev_loss = loss_sum / n_seen if n_seen else float('nan')

    print(f'Current Train Loss: {train_loss}, Current Dev Loss: {dev_loss}, Latest Lr: {scheduler.get_last_lr()[0]}')

    # Keep only the checkpoint with the lowest dev MSE seen so far.
    if dev_loss < best_loss:
        best_loss = dev_loss
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_metric': best_loss,
        }
        best_ckp = os.path.join(checkpoint_dir, 'checkpoint.pt')
        torch.save(checkpoint, best_ckp)
        print(f'Best validation metric so far: {best_loss}')

    scheduler.step()
    


In [None]:
'''
Generate test-set predictions
'''
if best_ckp is None:
    raise RuntimeError('No checkpoint was saved; run the training cell first '
                       '(or check for NaN dev losses).')

ids = []
res = []
# map_location lets a GPU-trained checkpoint load on whatever `device` is now.
checkpoint = torch.load(best_ckp, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
with torch.no_grad():
    for inputs, batch_ids in tqdm(test_loader, desc="Test"):
        inputs = {key: value.to(device) for key, value in inputs.items()}
        pred = model(inputs)
        ids.extend(batch_ids)
        # One score per example; flattening handles a (B, 1) or (B,) output.
        res.extend(pred.view(-1).tolist())

# Every id must have exactly one prediction, otherwise the rows misalign.
assert len(ids) == len(res) == len(test_data)

with open('./output.csv', 'w', encoding='utf-8') as f:
    f.write('id\tpred\n')
    for item_id, score in zip(ids, res):
        # ids are numeric and scores are floats: format rather than '+'-concatenate.
        f.write(f'{item_id}\t{score}\n')
        