In [None]:
import contextlib
import math
import os
import sys
import time
from typing import Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

from datasets import load_dataset
from sklearn.metrics import auc
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from sklearn.metrics import roc_curve
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``num_heads``
    heads of width ``embed_dim // num_heads``, attended independently with
    softmax(Q K^T / sqrt(head_dim)) V, then concatenated and projected by ``W_o``.
    """

    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()

        per_head = embed_dim // num_heads
        assert per_head * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = per_head

        # Input projections (created in q, k, v order so initialisation/RNG use is unchanged).
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

        # Output projection applied to the concatenated heads.
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
        """Attend ``query`` over ``key``/``value``.

        Assumes batch-first inputs whose last dimension is ``embed_dim``. Where
        ``mask == 0`` the attention score is replaced by -1e9 before the softmax;
        ``mask`` must broadcast to (batch, num_heads, q_len, k_len).
        """
        batch_size = query.shape[0]

        def to_heads(projected):
            # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
            return projected.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        queries = to_heads(self.W_q(query))
        keys = to_heads(self.W_k(key))
        values = to_heads(self.W_v(value))

        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)

        # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, embed_dim)
        merged = torch.matmul(weights, values).transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        return self.W_o(merged)

In [None]:
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal positional-encoding table to its input.

    Args:
        embed_dim: size of the input's last dimension (odd sizes are supported).
        max_len: longest sequence the precomputed table covers.
    """

    def __init__(self, embed_dim, max_len=10000):
        super(PositionalEncoding, self).__init__()

        self.embed_dim = embed_dim
        self.max_len = max_len

        # Registered as a non-persistent buffer: it now follows ``model.to(device)``
        # and dtype casts (no per-forward host->device copy), yet stays out of
        # ``state_dict`` exactly like the old plain attribute, so checkpoints match.
        self.register_buffer("pe", self.compute_positional_encoding(), persistent=False)

    def forward(self, x):
        """Return ``x + pe[:, :seq_len, :]``; assumes ``x`` is (batch, seq_len, embed_dim).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.shape[1]
        if seq_len > self.max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len={self.max_len}")
        pe = self.pe[:, :seq_len, :]
        # No-op when the buffer already lives on x's device (the usual case after model.to()).
        return x + pe.to(x.device)

    def compute_positional_encoding(self):
        """Build the (1, max_len, embed_dim) table.

        Even columns hold sin(pos * w_i) and odd columns cos(pos * w_i), with
        w_i = exp(-2i * ln(10000) / embed_dim).
        """
        position = torch.arange(0, self.max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, self.embed_dim, 2).float() * (-math.log(10000.0) / self.embed_dim))
        angles = position * div_term  # (max_len, ceil(embed_dim / 2))
        pe = torch.zeros(self.max_len, self.embed_dim)
        pe[:, 0::2] = torch.sin(angles)
        # An odd embed_dim has one fewer odd column than frequencies; drop the extra one.
        pe[:, 1::2] = torch.cos(angles[:, : self.embed_dim // 2])
        return pe.unsqueeze(0)

In [None]:
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> ReLU -> Dropout -> Linear.

    Expands the last dimension from ``embed_dim`` to ``hidden_dim`` and projects
    it back, so the output has the same shape as the input.
    """

    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super(FeedForwardNetwork, self).__init__()

        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim

        self.linear1 = nn.Linear(embed_dim, hidden_dim)  # expand
        self.linear2 = nn.Linear(hidden_dim, embed_dim)  # project back
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.linear1(x)))
        return self.linear2(hidden)
    
class TransformerEncoderCell(nn.Module):
    """One post-norm encoder layer: self-attention and a feed-forward network.

    Each sub-layer output is dropped out, added to its input (residual) and then
    layer-normalised.
    """

    def __init__(self, embed_dim, num_heads, hidden_dim, dropout=0.1):
        super(TransformerEncoderCell, self).__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim

        # Sub-layers (creation order kept so parameter init and state_dict order are unchanged).
        self.multi_head_attention = MultiHeadAttention(embed_dim, num_heads)
        self.feed_forward_network = FeedForwardNetwork(embed_dim, hidden_dim, dropout)

        self.norm_attention = nn.LayerNorm(embed_dim)
        self.norm_ffn = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Apply self-attention then the FFN; ``mask`` is forwarded to the attention."""
        attended = self.multi_head_attention(x, x, x, mask)
        residual = self.norm_attention(x + self.dropout(attended))

        transformed = self.feed_forward_network(residual)
        return self.norm_ffn(residual + self.dropout(transformed))
    
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` TransformerEncoderCell layers followed by a final LayerNorm."""

    def __init__(self, num_layers, embed_dim, num_heads, hidden_dim, dropout=0.1):
        super(TransformerEncoder, self).__init__()

        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim

        self.norm = nn.LayerNorm(embed_dim)

        cells = (TransformerEncoderCell(embed_dim, num_heads, hidden_dim, dropout) for _ in range(num_layers))
        self.encoder_cells = nn.ModuleList(cells)

    def forward(self, x, mask):
        """Run ``x`` through every layer in order (same ``mask`` for each), then normalise."""
        hidden = x
        for layer in self.encoder_cells:
            hidden = layer(hidden, mask)
        return self.norm(hidden)


class TransformerClassifier(nn.Module):
    """Transformer-encoder text classifier with padding-aware mean pooling.

    Token ids are embedded (scaled by sqrt(embed_dim)), given sinusoidal
    positional encodings, encoded, averaged over the non-padding positions and
    projected to ``num_classes`` logits.
    """

    def __init__(self, vocab_size, num_layers, embed_dim, num_heads, hidden_dim, num_classes, dropout=0.1, pad_token:int=0):
        super(TransformerClassifier, self).__init__()

        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token)
        self.positional_encoding = PositionalEncoding(embed_dim)
        self.encoder = TransformerEncoder(num_layers, embed_dim, num_heads, hidden_dim, dropout)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, text, mask):
        """Return (batch, num_classes) logits for a (batch, seq_len) tensor of token ids.

        Args:
            text: integer token ids, right-padded with the embedding's ``padding_idx``.
            mask: attention mask forwarded to the encoder (positions where it is 0 are
                not attended to). If None, a key-padding mask of shape
                (batch, 1, 1, seq_len) is derived from ``text != padding_idx`` so that
                padding no longer influences the real tokens.
        """
        pad_idx = self.embedding.padding_idx
        if pad_idx is None:
            token_mask = torch.ones_like(text, dtype=torch.bool)
        else:
            token_mask = text != pad_idx  # (batch, seq_len); True for real tokens
        if mask is None:
            # Broadcasts over heads and query positions inside the attention scores.
            mask = token_mask[:, None, None, :]

        embedded = self.embedding(text) * math.sqrt(self.embed_dim)
        position_encoded = self.positional_encoding(embedded)
        transformer_output = self.encoder(position_encoded, mask)

        # Masked mean pool: average only real tokens so a sample's logits do not
        # depend on how much padding its batch added. clamp avoids 0/0 for rows
        # that are entirely padding (they pool to a zero vector).
        weights = token_mask.unsqueeze(-1).to(transformer_output.dtype)
        x = (transformer_output * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        logits = self.fc(x)

        return logits
    

In [None]:
class TextClassificationDataset(Dataset):
    """Map-style dataset over a CSV with ``input`` (text) and ``output`` (label) columns.

    Items are ``(label, text)`` tuples, the order unpacked by ``yield_tokens``
    and ``collate_batch``.
    """

    def __init__(self, data_dir):
        # Despite its name, ``data_dir`` is forwarded as ``data_files`` (the CSV path).
        self.data = load_dataset('csv', data_files=data_dir, split='train')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        return row['output'], row['input']

In [None]:
tokenizer = get_tokenizer("basic_english")
# NOTE(review): the training cell below loads ".../context-only_train.csv", not this
# "ctx-only" file. Confirm both contain the same data; otherwise the vocabulary is
# built from a different file than the model is trained on.
dataset_name = "../../data/lm_finetune_data/ctx-only_train.csv"
train_iter = TextClassificationDataset(dataset_name)

def yield_tokens(data_iter: Iterable, tokenizer):
    """Yield the token list of each text in ``data_iter`` (an iterable of ``(label, text)`` pairs)."""
    for _, text in data_iter:
        yield tokenizer(text)

# "<pad>" must be a real special token: previously it was out-of-vocabulary, so
# vocab(tokenizer('<pad>')) silently returned the <unk> index and padding was
# indistinguishable from unknown words. Listing it first puts it at index 0,
# matching TransformerClassifier's default padding_idx=0.
vocab = build_vocab_from_iterator(yield_tokens(train_iter, tokenizer), specials=["<pad>", "<unk>"])
vocab.set_default_index(vocab["<unk>"])

text_pipeline = lambda x: vocab(tokenizer(x))
# The training/evaluation loops add 1 back to each label, so the net mapping is int(x).
label_pipeline = lambda x: int(x) - 1

PAD_TOKEN = vocab["<pad>"]
assert PAD_TOKEN != vocab["<unk>"], "padding must not share an index with <unk>"


def collate_batch(batch):
    """Collate ``(label, text)`` pairs into ``(labels, token_ids)`` tensors.

    Labels go through ``label_pipeline``; each text is tokenised with
    ``text_pipeline`` and right-padded with ``PAD_TOKEN`` to the longest text in
    the batch, giving a (batch, max_len) int64 tensor.
    """
    labels = [label_pipeline(raw_label) for raw_label, _ in batch]
    token_tensors = [torch.tensor(text_pipeline(raw_text), dtype=torch.int64) for _, raw_text in batch]

    longest = max(tokens.size(0) for tokens in token_tensors)
    padded = torch.stack(
        [F.pad(tokens, pad=(0, longest - tokens.size(0)), value=PAD_TOKEN) for tokens in token_tensors],
        dim=0,
    )
    return torch.tensor(labels), padded

##### Hyperparameter Sweep and Training

In [None]:
def train(model, dataloader, loss_func, device, grad_norm_clip, optimizer, epoch):
    """Run one training epoch over ``dataloader``.

    Args:
        model: classifier called as ``model(text, mask=None)`` returning logits.
        dataloader: iterable of ``(label, text)`` batches that supports ``len()``.
        loss_func: criterion applied to ``(logits, label)``.
        device: device each batch is moved to.
        grad_norm_clip: max total gradient norm passed to ``clip_grad_norm_``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only for logging.

    Every ``log_interval`` batches, prints the running training accuracy since the
    previous log line and the wall time for that span, then resets the counters.
    """
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text) in enumerate(dataloader):
        # collate_batch's label_pipeline subtracts 1; adding it back restores the label ids.
        label = (label + 1).to(device)
        text = text.to(device)

        optimizer.zero_grad()
        logits = model(text, mask=None)
        loss = loss_func(logits, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
        optimizer.step()

        total_acc += (logits.argmax(1) == label).sum().item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            # The metric is plain accuracy (correct / total); it was previously mislabelled
            # as "positive precision".
            print('| epoch {:3d} | {:5d}/{:5d} batches | accuracy {:8.3f} | {:6.2f}s'.format(
                epoch, idx, len(dataloader), total_acc / total_count, elapsed))
            total_acc, total_count = 0, 0
            start_time = time.time()

def evaluate(model, dataloader, loss_func, device):
    """Evaluate a binary classifier on ``dataloader`` and print validation metrics.

    Prints the label counts, positive-class precision and recall, per-class F1,
    ROC AUC, and the raw confusion-matrix counts.

    Args:
        model: classifier called as ``model(text, mask=None)`` returning
            (batch, num_classes) logits; class 1 is treated as "positive".
        dataloader: iterable of ``(label, text)`` batches.
        loss_func: unused; kept so existing call sites keep working.
        device: device each batch is moved to.

    Returns:
        ``(accuracy, positive_class_f1)``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_acc, total_count = 0, 0
    # Collect per-batch results on the CPU and concatenate once; repeated torch.cat is quadratic.
    label_chunks, prediction_chunks, score_chunks = [], [], []
    with torch.no_grad():
        for label, text in dataloader:
            # collate_batch's label_pipeline subtracts 1; adding it back restores the label ids.
            label = (label + 1).to(device)
            text = text.to(device)
            logits = model(text, mask=None)
            predictions = logits.argmax(1)
            total_acc += (predictions == label).sum().item()
            total_count += label.size(0)
            label_chunks.append(label.cpu())
            prediction_chunks.append(predictions.cpu())
            # Probability of class 1: ROC AUC needs a continuous score, not the argmax label.
            score_chunks.append(F.softmax(logits, dim=1)[:, 1].cpu())

    if total_count == 0:
        raise ValueError("evaluate() received an empty dataloader")

    y_true = torch.cat(label_chunks).numpy()
    y_pred = torch.cat(prediction_chunks).numpy()
    y_score = torch.cat(score_chunks).numpy()

    print("Num of positive labels: ", int(y_true.sum()))
    print("Num of labels: ", y_true.shape[0])

    # labels=[0, 1] keeps the matrix 2x2 even if a class is absent from both
    # y_true and y_pred, so the 4-way unpack cannot fail.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    # Report 0.0 instead of nan (plus a RuntimeWarning) when a denominator is zero.
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    print(f"Positive Precision: {precision}")
    print(f"Positive Recall: {recall}")

    # labels=[0, 1] guarantees two entries, so f1[1] below always exists.
    f1 = f1_score(y_true, y_pred, labels=[0, 1], average=None, zero_division=0)
    print(f"Positive F1: {f1}")

    if tp + fn > 0 and tn + fp > 0:
        fpr, tpr, _ = roc_curve(y_true, y_score)
        roc_auc = auc(fpr, tpr)
    else:
        # ROC AUC is undefined when only one class is present in y_true.
        roc_auc = float("nan")
    print(f"Area Under Curve: {roc_auc}")
    print(f"tn: {tn}, fp: {fp}, fn: {fn}, tp: {tp}")
    return total_acc / total_count, f1[1]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (weight init, dropout, DataLoader shuffling) so the whole
# sweep is reproducible from a fresh kernel.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters
epochs = 50
lr = 0.0005
batch_size = 256

# NOTE(review): `vocab` was built from ".../ctx-only_train.csv", but training uses the
# "context-only" files below. Confirm they hold the same data; tokens that appear only
# here map to <unk> via vocab.set_default_index.
dataset_name = "../../data/lm_finetune_data/context-only_train.csv"
validation_dataset_name = "../../data/lm_finetune_data/context-only_dev.csv"
train_iter = TextClassificationDataset(dataset_name)
validation_iter = TextClassificationDataset(validation_dataset_name)
num_classes = len({label for label, _ in train_iter})
vocab_size = len(vocab)

gradient_norm_clip = 1.0
# Default embedding size; the sweep passes its own emb_size to model_training().
emb_size = 64

def model_training(epochs, lr, batch_size, num_layers, num_heads, emb_size):
    """Train one TransformerClassifier configuration and keep its best-F1 checkpoint.

    Everything printed during the run goes to a per-configuration log file under
    ``../../outputs/MHA/``. Whenever the validation positive-class F1 improves, the
    weights are saved under ``results/MHA/``. Reads module-level ``vocab_size``,
    ``num_classes``, ``PAD_TOKEN``, ``device``, ``train_iter``, ``validation_iter``
    and ``gradient_norm_clip``.

    Args:
        epochs: number of training epochs (also the cosine schedule's T_max).
        lr: Adam learning rate.
        batch_size: batch size for both dataloaders.
        num_layers: number of encoder layers.
        num_heads: attention heads per layer.
        emb_size: embedding size, also used as the FFN hidden size.
    """
    run_tag = f"e_{epochs}_lr_{lr}_b_{batch_size}_heads_{num_heads}_layers_{num_layers}_emb_{emb_size}"
    log_path = f"../../outputs/MHA/output_{run_tag}.txt"
    checkpoint_path = f"results/MHA/model_{run_tag}.pt"
    os.makedirs(os.path.dirname(log_path), exist_ok=True)
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # Redirect stdout only for this run. The log file is closed and the notebook's
    # stdout restored afterwards, even if training raises. Previously sys.stdout was
    # reassigned permanently and one file handle leaked per sweep configuration.
    with open(log_path, 'w', buffering=1) as log_file, contextlib.redirect_stdout(log_file):
        print(f"Running with epochs: {epochs}, lr: {lr}, batch_size: {batch_size}, num_heads: {num_heads}, num_layers: {num_layers}, emb_size: {emb_size}")
        model = TransformerClassifier(vocab_size=vocab_size,
                                      num_layers=num_layers,
                                      embed_dim=emb_size,
                                      num_heads=num_heads,
                                      hidden_dim=emb_size,
                                      num_classes=num_classes,
                                      pad_token=PAD_TOKEN)

        loss_fn = nn.CrossEntropyLoss()

        model = model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-8)
        total_accu = None

        train_dataloader = DataLoader(train_iter, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
        validation_dataloader = DataLoader(validation_iter, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

        best_f1 = 0
        for epoch in range(1, epochs + 1):
            epoch_start_time = time.time()
            train(model, train_dataloader, loss_fn, device, gradient_norm_clip, optimizer, epoch)
            accu_val, f1 = evaluate(model, validation_dataloader, loss_fn, device)
            if f1 > best_f1:
                best_f1 = f1
                # nn.Module has no save_state_dict(); the old call raised AttributeError here.
                torch.save(model.state_dict(), checkpoint_path)
            # total_accu tracks the best validation accuracy so far; the cosine schedule
            # only advances in epochs whose accuracy falls below it.
            if total_accu is not None and total_accu > accu_val:
                scheduler.step()
            else:
                total_accu = accu_val
            print('-' * 59)
            print('| end of epoch {:3d} | time: {:5.2f}s | valid accuracy {:8.3f} '.format(epoch, (time.time() - epoch_start_time), accu_val))
            print('-' * 59)
        print(f"Best validation F1: {best_f1}")


# Grid sweep over encoder depth, head count and embedding size. Each configuration
# gets its own log file and checkpoint path inside model_training().
for num_layers in range(4, 13):
    for num_heads in [2, 4]:
        for emb_size in [64, 128, 256]:
            model_training(epochs, lr, batch_size, num_layers, num_heads, emb_size)