# MEME CLASSIFIER - TASK 2

Useful links:

- [TRANSFORMERS DOCS](https://huggingface.co/docs/transformers)
- [TASK PAGE](https://propaganda.math.unipd.it/semeval2021task6/)
- [TASK GITHUB](https://github.com/di-dimitrov/SEMEVAL-2021-task6-corpus#task-description)
- [TOKEN CLASSIFICATION TUTORIAL KAGGLE](https://www.kaggle.com/code/thanish/bert-for-token-classification-ner-tutorial?scriptVersionId=66778357)
- [SOURCE FOR TOKEN CLASSIFICATION](https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1705)

### PACKAGE INSTALLATION

In [1]:
# Install the Hugging Face transformers library (notebook shell command).
! pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 5.0 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 16.9 MB/s 
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.0-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 65.5 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.11.0 tokenizers-0.13.2 transformers-4.24.0


### DRIVE LINKING

In [2]:
# Mount Google Drive. The dataset, the pretrained checkpoints and the saved
# models used below all live under the relative path drive/MyDrive/.
from google.colab import drive

mount_point = '/content/drive'
drive.mount(mount_point)

Mounted at /content/drive


### CONSTANTS

In [3]:
# Propaganda techniques in the dataset. The position in this list is the
# column used for the technique in the per-token multi-hot label vectors.
# The order therefore fixes the class axis of the model output.
_TECHNIQUES = [
    'Black-and-white Fallacy/Dictatorship',
    'Loaded Language',
    'Name calling/Labeling',
    'Slogans',
    'Smears',
    'Causal Oversimplification',
    'Appeal to fear/prejudice',
    'Exaggeration/Minimisation',
    'Reductio ad hitlerum',
    'Repetition',
    'Glittering generalities (Virtue)',
    "Misrepresentation of Someone's Position (Straw Man)",
    'Doubt',
    'Obfuscation, Intentional vagueness, Confusion',
    'Whataboutism',
    'Flag-waving',
    'Thought-terminating cliché',
    'Presenting Irrelevant Data (Red Herring)',
    'Appeal to authority',
    'Bandwagon',
]
LABELS = {technique: idx for idx, technique in enumerate(_TECHNIQUES)}
N_LABELS = len(LABELS)

### IMPORTS

In [4]:
import os
from datetime import datetime

import re

import warnings

import json

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from tabulate import tabulate
import seaborn as sn

from tqdm.notebook import tqdm

from torch.utils.data import DataLoader

import torch 
from torch import nn
from torch.nn import functional as F
from torch.optim import AdamW

from transformers import AutoTokenizer
from transformers import AutoModel, AutoConfig
from transformers import PreTrainedModel, BertModel, BertConfig
from transformers import get_linear_schedule_with_warmup
from transformers.utils import logging

from sklearn.metrics import f1_score
from sklearn.metrics import multilabel_confusion_matrix
from sklearn.metrics import roc_curve

### DATASET

####  DATASET LOADING

In [5]:
def load_datasets(folder="drive/MyDrive/DeepLearning/Dataset/task2/"):
    """Load the train/test/dev splits of the task-2 dataset.

    Each split is a JSON file holding a list of records. Each record becomes
    one DataFrame row.

    Args:
        folder: directory containing training_set_task2.txt,
            test_set_task2.txt and dev_set_task2.txt. Defaults to the
            Drive location used by this notebook.

    Returns:
        (train, test, dev) pandas DataFrames.
    """

    def load_json2pandas(path):
        # Explicit UTF-8: meme texts and technique names contain non-ASCII.
        with open(path, encoding="utf-8") as f:
            return pd.DataFrame(json.load(f))

    train = load_json2pandas(os.path.join(folder, "training_set_task2.txt"))
    test = load_json2pandas(os.path.join(folder, "test_set_task2.txt"))
    dev = load_json2pandas(os.path.join(folder, "dev_set_task2.txt"))

    return train, test, dev

#### LABELS ENCODING

In [6]:
class Span():
    """Mutable character interval ``[start, end]`` whose length is ``end - start``.

    ``(-1, -1)`` is the "invalid" sentinel. Reversed bounds (for example the
    result of intersecting two disjoint spans) collapse to it.
    """

    def __init__(self, start=-1, end=-1):
        if start > end:
            start, end = -1, -1
        self.start = start
        self.end = end

    def is_valid(self) -> bool:
        """True when both bounds are non-negative and ``start <= end``."""
        return 0 <= self.start <= self.end

    def intersect(self, other):
        """Overlap of the two spans; invalid if either span is invalid or they are disjoint."""
        if not (self.is_valid() and other.is_valid()):
            return Span()
        return Span(max(self.start, other.start), min(self.end, other.end))

    def length(self) -> int:
        """``end - start`` for valid spans, 0 otherwise."""
        return self.end - self.start if self.is_valid() else 0

    def __eq__(self, other: object) -> bool:
        # Compare only with other Spans. Other types fall back to Python's
        # default handling instead of raising AttributeError.
        if not isinstance(other, Span):
            return NotImplemented
        return self.start == other.start and self.end == other.end

    # Instances are mutable, so they stay unhashable (as they already were
    # implicitly once __eq__ is defined).
    __hash__ = None

    def __lt__(self, other: object) -> bool:
        """Order by start, then by end."""
        if not isinstance(other, Span):
            return NotImplemented
        return (self.start, self.end) < (other.start, other.end)

    def __str__(self) -> str:
        return "({}, {})".format(self.start, self.end)

    def __repr__(self) -> str:
        return "({}, {})".format(self.start, self.end)


def label_dict_to_ohe(label_dict, offset_mapping):
    """Build the per-token multi-hot label matrix of one meme.

    Args:
        label_dict: list of annotations. Each is a dict with character
            offsets 'start' / 'end' and a 'technique' name (a key of LABELS).
        offset_mapping: one (start, end) character-offset pair per token.

    Returns:
        A list with one row of N_LABELS ints per token. Column
        LABELS[technique] is 1 when the token lies entirely inside an
        annotation of that technique. Tokens with offsets (0, 0) (presumably
        the tokenizer's special/padding tokens) always get an all-zero row.
    """
    rows = []
    for tok_start, tok_end in offset_mapping:
        row = [0] * N_LABELS
        if not (tok_start == 0 and tok_end == 0):
            for annotation in label_dict:
                inside = (tok_start >= annotation['start']
                          and tok_end <= annotation['end'])
                if inside:
                    row[LABELS[annotation['technique']]] = 1
        rows.append(row)
    return rows

def label_dict_to_spans(label_dict):
    """Convert annotation dicts into ``[start, end, class_index]`` rows,
    where ``class_index`` is ``LABELS[technique]``."""
    rows = []
    for annotation in label_dict:
        start, end = annotation['start'], annotation['end']
        rows.append([start, end, LABELS[annotation['technique']]])
    return rows

def spans_to_spans_for_metrics(spans):
    """Group ``[start, end, class_index]`` rows into one Span list per class.

    A row with class index -1 (the padding row appended by pad_spans)
    ends the input. Rows after it are ignored.
    """
    buckets = [[] for _ in range(N_LABELS)]
    for row in spans:
        start, end, technique = row
        if technique == -1:
            return buckets
        buckets[technique].append(Span(start, end))
    return buckets

def ohe_to_spans(ohe_labels, spans):
    """Turn per-token 0/1 predictions into character Spans, one list per class.

    Consecutive tokens predicted 1 for a class are merged into a single Span.
    The Span runs from the first token's start offset to the last token's end
    offset.

    Args:
        ohe_labels: array of shape (tokens, classes) holding 0/1 values.
        spans: per-token (start, end) offsets aligned with ``ohe_labels``.
    """
    per_class = []
    for column in ohe_labels.T:
        found = []
        current = None
        for flag, offsets in zip(column, spans):
            if current is None:
                if flag == 1:
                    current = Span(offsets[0], offsets[1])
            elif flag == 1:
                current.end = offsets[1]
            elif flag == 0:
                found.append(current)
                current = None
        if current is not None:
            found.append(current)
        per_class.append(found)
    return per_class

#### EXAMPLES ON THE DATASET

In [7]:
# Peek at the training split: one row per meme (id, text, annotations).
train = load_datasets()[0]
train.head()

Unnamed: 0,id,text,labels
0,128,THERE ARE ONLY TWO GENDERS\n\nFEMALE \n\nMALE\n,"[{'start': 0, 'end': 41, 'technique': 'Black-a..."
1,189,This is not an accident!,[]
2,96,SO BERNIE BROS HAVEN'T COMMITTED VIOLENCE EH?\...,"[{'end': 83, 'start': 47, 'technique': 'Slogan..."
3,154,PATHETIC\n\nThe Cowardly Asshole\nWeak Failure...,"[{'end': 8, 'start': 0, 'technique': 'Loaded L..."
4,15,WHO TRUMP REPRESENTS\n\nWHO DEMOCRATS REPRESENT\n,[]


#### DATASET

In [8]:
def get_val_mask(word_ids):
    """Mark which tokens are tied to a word of the input text.

    Args:
        word_ids: per-token word indices, e.g. from ``BatchEncoding.word_ids()``.
            ``None`` entries are tokens the tokenizer does not map to a word.

    Returns:
        list[int]: 0 for ``None`` entries, 1 otherwise (including word 0).
    """
    # `is None`, not `== None`: identity check, and no shadowing of `id`.
    return [0 if word_id is None else 1 for word_id in word_ids]

def pad_spans(l, max):
    """Return a copy of ``l`` padded with ``[0, 0, -1]`` rows up to ``max`` rows.

    The caller's list is left untouched, and every padding row is its own
    list object. If ``l`` already has ``max`` rows or more, an unchanged copy
    is returned.

    Args:
        l: list of ``[start, end, class_index]`` rows.
        max: target number of rows. The parameter name is kept for backward
            compatibility even though it shadows the builtin here.
    """
    n_missing = max - len(l)
    return list(l) + [[0, 0, -1] for _ in range(n_missing)]


class MemeDataset():
    """Map-style dataset turning meme texts and span annotations into tensors.

    Args:
        ids: meme identifiers, stored on ``self.ids`` (not used for encoding).
        texts: raw meme texts.
        label_dicts: per meme, a list of {'start', 'end', 'technique'} dicts.
        tokenizer: Hugging Face tokenizer. It presumably has to be a *fast*
            tokenizer, since offset mappings and ``word_ids()`` are used.
        max_spans: number of rows ``labels_spans`` is padded to, so that
            items can be stacked into a batch.
    """

    def __init__(self, ids, texts, label_dicts, tokenizer, max_spans=20):
        self.ids = ids
        self.texts = texts
        self.label_dicts = label_dicts
        self.tokenizer = tokenizer
        self.max_spans = max_spans

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        """Encode one meme.

        Returns a dict of long tensors:
            "ids", "mask": tokenizer input ids and attention mask;
            "labels": (tokens, N_LABELS) multi-hot matrix;
            "val_mask": 1 for tokens tied to a word, else 0;
            "mapping": per-token (start, end) character offsets;
            "labels_spans": (max_spans, 3) gold ``[start, end, class]`` rows.
        """
        text = self.texts[index]
        label_dict = self.label_dicts[index]

        inputs = self.tokenizer(
            text,
            add_special_tokens=True,
            padding="max_length",
            truncation=True,
            return_offsets_mapping=True,
        )

        mapping = inputs["offset_mapping"]
        labels = label_dict_to_ohe(label_dict, mapping)
        val_mask = get_val_mask(inputs.word_ids())
        labels_spans = pad_spans(label_dict_to_spans(label_dict), self.max_spans)

        return {
            "ids": torch.tensor(inputs["input_ids"], dtype=torch.long),
            "mask": torch.tensor(inputs["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "val_mask": torch.tensor(val_mask, dtype=torch.long),
            "mapping": torch.tensor(mapping, dtype=torch.long),
            "labels_spans": torch.tensor(labels_spans, dtype=torch.long)
        }


#### WRAPPERS

In [9]:
def build_datasets(train, test, dev, tokenizer_name="distilbert-base-uncased"):
    """Wrap the three DataFrames (columns id, text, labels) in MemeDatasets.

    All splits share one tokenizer loaded from ``tokenizer_name``.

    NOTE(review): the classifier is built on BERT checkpoints, while this
    default is DistilBERT's tokenizer. The uncased vocabularies presumably
    coincide. TODO confirm, or pass the checkpoint's own tokenizer.

    Returns:
        (train_ds, test_ds, dev_ds)
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def to_dataset(df):
        # One MemeDataset per split, all sharing the tokenizer above.
        return MemeDataset(
            df.id.tolist(),
            df.text.tolist(),
            df.labels.tolist(),
            tokenizer)

    return to_dataset(train), to_dataset(test), to_dataset(dev)

def build_dataloaders(train_ds, test_ds, dev_ds, batch_size=16, num_workers=2):
    """Create the train/test/dev DataLoaders.

    Only the training loader is shuffled. eval() aggregates over the whole
    loader, so shuffling test/dev gives no benefit and makes runs
    non-reproducible.

    Returns:
        (train_dl, test_dl, dev_dl)
    """
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers)
    test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                         num_workers=num_workers)
    dev_dl = DataLoader(dev_ds, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers)
    return train_dl, test_dl, dev_dl

### MODEL

In [10]:
class MemeClassifier(nn.Module):
    """Our classifier for Task 2.

    It predicts a label on each token for every class from the output of a
    BERT transformer, fed through dropout and a linear layer, then a sigmoid.

    Args:
        n_classes: number of classes predicted per token.
        do_prob: dropout probability applied to BERT's output.
        bert_path: name or path passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, n_classes, do_prob=0.2, bert_path="bert-base-uncased"):
        super(MemeClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(
            bert_path, add_pooling_layer=False)
        self.dropout = nn.Dropout(do_prob)
        # Size the head from the checkpoint's config instead of hard-coding
        # 768, so BERT variants with another hidden size work too.
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
        # Not read inside this class; kept for external code that may look at it.
        self.step_scheduler_after = "batch"
        self.sigmoid = torch.sigmoid

    def forward(self, input_ids, attention_mask=None):
        """Return per-token, per-class probabilities in [0, 1].

        Assumes BERT's first output is its last hidden state, so the result
        is presumably (batch, seq_len, n_classes). TODO confirm.
        """
        output = self.bert(input_ids, attention_mask=attention_mask)[0]
        output = self.dropout(output)
        output = self.classifier(output)
        return self.sigmoid(output)

In [11]:
def save_model(model, thresholds, path):
    """Write ``model``'s state dict and its per-class ``thresholds`` into one
    checkpoint file at ``path`` (read back by load_model)."""
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'thresholds': thresholds,
    }
    torch.save(checkpoint, path)

def load_model(path, device):
    """Rebuild a DataParallel-wrapped MemeClassifier from a save_model checkpoint.

    The model is not moved to ``device``; ``device`` is only used as
    ``map_location`` while reading the file.

    Returns:
        (model, thresholds)
    """
    import inspect

    model = nn.DataParallel(MemeClassifier(N_LABELS))

    load_kwargs = {"map_location": device}
    # The thresholds are stored as a NumPy array. Recent torch.load versions
    # default to weights_only=True and refuse to unpickle it, so opt out
    # where the argument exists. weights_only=False runs pickle: only load
    # checkpoints you trust.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    checkpoint = torch.load(path, **load_kwargs)

    model.load_state_dict(checkpoint['model_state_dict'])
    thresholds = checkpoint['thresholds']

    return model, thresholds

### OPTIMIZER

In [12]:
def build_optimizer(model, lr=3e-5, weight_decay=0.001):
    """AdamW with weight decay on every parameter except biases and
    LayerNorm weights.

    The original no-decay list was ["bias", "LayerNorm.bias"]. "LayerNorm.bias"
    is already matched by "bias", so LayerNorm *weights* were being decayed.
    The usual exclusion list is ["bias", "LayerNorm.weight"].

    Args:
        model: module whose ``named_parameters()`` are optimized.
        lr: learning rate.
        weight_decay: decay applied to the decayed group.
    """
    no_decay = ("bias", "LayerNorm.weight")
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_parameters = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    return AdamW(optimizer_parameters, lr=lr)

### PLOTTER

In [13]:
class Plotter():
    """Keeps track of the metrics and create plots.

    ``self.data`` maps each metric name to the list of values added so far
    (one per call to ``add``, i.e. one per epoch in Trainer).
    """

    def __init__(self):
        self.data = {}

    def add(self, d):
        """Append every ``metric -> value`` pair of ``d`` to its history."""
        for metric, value in d.items():
            self.data.setdefault(metric, []).append(value)

    def plot_metric(self, metric):
        """Draw ``metric``'s history against the epoch number (1-based)."""
        y = self.data[metric]
        x = range(1, len(y) + 1)

        plt.figure()
        plt.title(metric)
        plt.xlabel('epoch')
        plt.plot(x, y)

    def plot(self):
        """Show one figure per tracked metric."""
        for metric in self.data:
            self.plot_metric(metric)
            plt.show()

    def save(self, dir_name):
        """Save one plot per metric into ``dir_name``, each file named after
        its metric. Does nothing if ``dir_name`` is not an existing directory."""
        if not os.path.isdir(dir_name):
            return

        for metric in self.data:
            self.plot_metric(metric)
            # os.path.join works whether or not dir_name ends with a separator.
            plt.savefig(os.path.join(dir_name, metric))
            plt.close()

### TRAINER

In [14]:
class FocalLoss(nn.modules.loss._WeightedLoss):
    """Binary focal loss on probabilities (inputs already passed through a sigmoid).

    loss = mean(weight * (1 - pt) ** gamma * BCE), where pt = exp(-BCE) is
    the probability assigned to the true label.

    Args:
        weight: optional tensor broadcast against the per-element loss.
        gamma: focusing parameter; 0 reduces to (weighted) mean BCE.
    """

    def __init__(self, weight=None, gamma=2):
        # _WeightedLoss registers `weight` (possibly None) as self.weight.
        super(FocalLoss, self).__init__(weight)
        self.gamma = gamma

    def forward(self, input, target):
        # pt must come from the *unweighted* BCE. With the weighted BCE,
        # exp(-w * bce) is no longer the true-class probability.
        bce = F.binary_cross_entropy(input, target, reduction='none')
        pt = torch.exp(-bce)
        focal_loss = (1 - pt) ** self.gamma * bce
        if self.weight is not None:
            focal_loss = focal_loss * self.weight
        return focal_loss.mean()

In [15]:
def find_opt_threshold(outputs, labels):
    """For each class, pick the threshold that maximizes token-level F1.

    Thresholds 0.1, 0.101, ..., 0.999 are scanned. For every class the first
    threshold reaching the best F1 is kept (np.argmax tie rule).

    Returns:
        np.ndarray with one threshold per class.
    """
    grid = np.arange(0.1, 1, 0.001)

    f1_per_threshold = []
    for t in grid:
        binarised = (outputs >= t).astype('int')
        f1_per_threshold.append(
            f1_score(labels, binarised, average=None, zero_division=1))

    best_idx = np.argmax(f1_per_threshold, axis=0)
    return grid[best_idx]

def find_opt_threshold_gmean(outputs, labels):
    """Per class, the ROC threshold maximizing sqrt(TPR * (1 - FPR)).

    Columns of ``labels`` / ``outputs`` are paired class by class (presumably
    (tokens, classes) arrays). One threshold per class is returned.
    """
    best = []
    for class_labels, class_scores in zip(labels.T, outputs.T):
        fpr, tpr, candidates = roc_curve(class_labels, class_scores)
        gmeans = np.sqrt(tpr * (1 - fpr))
        best.append(candidates[np.argmax(gmeans)])
    return np.array(best)

def apply_threshold_per_class(outputs, thresholds):
    """Binarize probabilities with one threshold per class.

    Vectorized replacement for the former per-token Python loop. It also
    accepts any number of leading axes.

    Args:
        outputs: array with the class axis last (eval() passes the stacked
            (sample, token, class) outputs).
        thresholds: one threshold per class.

    Returns:
        Integer array shaped like ``outputs``: 1 where the probability is not
        below its class threshold, 0 otherwise.
    """
    outputs = np.asarray(outputs)
    thresholds = np.asarray(thresholds)
    # `not (p < t)` instead of `p >= t` keeps the old loop's NaN handling (NaN -> 1).
    return np.logical_not(outputs < thresholds).astype(int)

def apply_threshold_to_all(outputs, threshold):
    """Binarize ``outputs`` with a single threshold shared by every class:
    1 where the value is >= ``threshold``, 0 elsewhere."""
    above = outputs >= threshold
    return above.astype('int')

def C(s, t, h):
    """Length of the overlap of spans ``s`` and ``t`` divided by ``h``.

    Returns 0 when ``h`` is 0 (e.g. a zero-length or invalid span), where the
    ratio is undefined, instead of raising ZeroDivisionError.
    """
    if h == 0:
        return 0
    return s.intersect(t).length() / h

def precision_metric(S, T):
    """Span-level precision of predicted spans ``S`` against gold spans ``T``.

    ``S`` and ``T`` hold one list of spans per class. The score is
    sum over classes c, s in S[c], t in T[c] of |s ∩ t| / |s|, divided by
    the total number of predicted spans.

    Returns:
        1 when nothing is predicted (no spans in any class), else the ratio.
    """
    n_predicted = 0
    overlap = 0
    for spans_c, gold_c in zip(S, T):
        for s in spans_c:
            for t in gold_c:
                overlap += C(s, t, s.length())
        n_predicted += len(spans_c)

    # With no predicted spans the overlap is necessarily 0; the former
    # `return 0` branch for that case was unreachable.
    if n_predicted == 0:
        return 1

    return overlap / n_predicted

def recall_metric(S, T):
    """Span-level recall: precision_metric with predictions ``S`` and gold
    spans ``T`` swapped."""
    gold, predicted = T, S
    return precision_metric(gold, predicted)

def f1_metric(pred_spans, label_spans):
    """Span-level F1 with precision and recall averaged over samples.

    Args:
        pred_spans: per sample, per class list of predicted Spans.
        label_spans: per sample, per class list of gold Spans.

    Returns:
        (f1, precision, recall). f1 is 0.0 when precision and recall are
        both 0, instead of raising ZeroDivisionError.

    Raises:
        ValueError: if there is no sample to score.
    """
    pairs = list(zip(pred_spans, label_spans))
    if not pairs:
        raise ValueError("f1_metric needs at least one sample")

    # Average over the pairs actually scored (zip stops at the shorter input).
    p = sum(precision_metric(S, T) for S, T in pairs) / len(pairs)
    r = sum(recall_metric(S, T) for S, T in pairs) / len(pairs)

    f1 = 0.0 if p + r == 0 else 2 * (p * r) / (p + r)
    return f1, p, r

def get_metrics(preds, labels, pred_spans, label_spans):
    """Token-level (micro/macro F1) and span-level (F1, P, R) metrics.

    Args:
        preds, labels: 0/1 token predictions and labels, reshaped to
            (tokens, N_LABELS).
        pred_spans, label_spans: per sample, per class Span lists.

    Returns:
        dict with keys 'f1-micro_tokens', 'f1-macro_tokens', 'f1',
        'precision', 'recall'.
    """
    # Use the shared class count instead of a hard-coded 20.
    flat_preds = np.reshape(preds, (-1, N_LABELS))
    flat_labels = np.reshape(labels, (-1, N_LABELS))

    f1_micro_tok = f1_score(
        flat_labels, flat_preds, average='micro', zero_division=1)
    f1_macro_tok = f1_score(
        flat_labels, flat_preds, average='macro', zero_division=1)

    f1, precision, recall = f1_metric(pred_spans, label_spans)

    metrics = {'f1-micro_tokens': f1_micro_tok,
               'f1-macro_tokens': f1_macro_tok,
               'f1': f1,
               'precision': precision,
               'recall': recall}
    return metrics

In [16]:
def eval(dataloader, model, device, loss_fn=None, thresholds=None):
    """Run ``model`` over ``dataloader`` and binarize its outputs.

    NOTE: shadows the builtin ``eval``; the name is kept because Trainer and
    Tester call it.

    Args:
        dataloader: yields dicts with "ids", "mask", "labels", "val_mask",
            "labels_spans" and "mapping" tensors (as built by MemeDataset).
        model: called as ``model(ids, mask)``, returning per-token class
            probabilities.
        device: device the inputs are moved to.
        loss_fn: optional loss. When given, it is computed on the tokens with
            ``val_mask == 1`` (same selection as training) and summed over
            batches into "loss".
        thresholds: per-class thresholds. If ``None`` they are tuned on this
            loader with find_opt_threshold.

    Returns:
        dict with
            "loss": summed batch loss (0.0 without ``loss_fn``),
            "pred" / "pred_non_opt": 0/1 predictions of the valid tokens
                with ``thresholds`` / with 0.5 for every class,
            "pred_spans" / "pred_spans_non_opt": per sample, per class Span
                lists rebuilt from those predictions,
            "true": labels of the valid tokens,
            "true_spans": per sample, per class gold Span lists,
            "thresholds": the thresholds actually used.
    """
    eval_loss = 0.0
    model.eval()
    fin_targets = []
    fin_outputs = []
    fin_mapping = []
    fin_l_spans = []
    fin_va_mask = []

    with torch.no_grad():
        for d in dataloader:
            ids = d["ids"].to(device, dtype=torch.long)
            mask = d["mask"].to(device, dtype=torch.long)
            targets = d["labels"].to(device, dtype=torch.float)
            va_mask = d["val_mask"]

            outputs = model(ids, mask)

            if loss_fn is not None:
                # Score only real tokens, as in training. Padding/special
                # tokens would otherwise dominate the validation loss.
                valid = va_mask.to(device) == 1
                if valid.any():
                    eval_loss += loss_fn(outputs[valid], targets[valid]).item()

            fin_targets.append(targets.cpu().numpy())
            fin_outputs.append(outputs.cpu().numpy())
            fin_mapping.append(d["mapping"].cpu().numpy())
            fin_l_spans.append(d["labels_spans"].cpu().numpy())
            fin_va_mask.append(va_mask.cpu().numpy())

    fin_outputs = np.concatenate(fin_outputs, axis=0)
    fin_targets = np.concatenate(fin_targets, axis=0)
    fin_mapping = np.concatenate(fin_mapping, axis=0)
    fin_l_spans = np.concatenate(fin_l_spans, axis=0)
    fin_va_mask = np.concatenate(fin_va_mask, axis=0)

    idxs = fin_va_mask == 1
    true = fin_targets[idxs]

    if thresholds is None:
        thresholds = find_opt_threshold(fin_outputs[idxs], true)

    pred_all = apply_threshold_per_class(fin_outputs, thresholds)
    pred_non_opt_all = apply_threshold_to_all(fin_outputs, 0.5)

    # Per-sample rows restricted to valid tokens. Kept as Python lists:
    # samples have different numbers of valid tokens, and np.array() on
    # ragged rows raises ValueError on NumPy >= 1.24.
    pred4spans = [p[i] for p, i in zip(pred_all, idxs)]
    pred4spans_non_opt = [p[i] for p, i in zip(pred_non_opt_all, idxs)]
    mapping = [m[i] for m, i in zip(fin_mapping, idxs)]

    pred = pred_all[idxs]
    pred_non_opt = pred_non_opt_all[idxs]

    pred_spans = [
        ohe_to_spans(l, m) for l, m in zip(pred4spans, mapping)]
    pred_spans_non_opt = [
        ohe_to_spans(l, m) for l, m in zip(pred4spans_non_opt, mapping)]
    true_spans = [spans_to_spans_for_metrics(s) for s in fin_l_spans]

    return {
        "loss": eval_loss,
        "pred": pred,
        "pred_non_opt": pred_non_opt,
        "pred_spans": pred_spans,
        "pred_spans_non_opt": pred_spans_non_opt,
        "true": true,
        "true_spans": true_spans,
        "thresholds": thresholds,
    }


class Trainer():
    """Trains a model while printing results.

    After every epoch the model is evaluated on a validation loader. Metrics
    are printed and recorded in a Plotter. When ``save_dir`` is given, the
    best checkpoint for each tracked metric is also written to disk.

    Args:
        model: network called as ``model(ids, mask)``.
        optimizer: optimizer over the model's parameters.
        device: device batches are moved to.
        loss_fn: loss applied to ``(outputs, targets)`` restricted to tokens
            with ``val_mask == 1``.
        scheduler: optional LR scheduler, stepped once per batch right after
            ``optimizer.step()``. ``None`` (default) leaves the LR untouched.
    """

    # (metric key, lower is better, checkpoint suffix, name in the log line)
    _CHECKPOINT_CRITERIA = (
        ("val_loss", True, "loss", "val_loss"),
        ("f1-micro_tokens", False, "micro", "micro"),
        ("f1-micro_tokens_non_opt", False, "micro_non_opt", "micro_non_opt"),
        ("f1", False, "f1", "f1"),
        ("f1_non_opt", False, "f1_non_opt", "f1_non_opt"),
    )

    def __init__(self, model, optimizer, device, loss_fn, scheduler=None):
        self.net = model
        self.opt = optimizer
        self.device = device
        self.plotter = Plotter()
        self.loss_fn = loss_fn
        self.scheduler = scheduler

    def eval(self, dataloader, thresholds):
        """Evaluate ``self.net`` with fixed ``thresholds`` (module-level eval)."""
        return eval(dataloader,
                    self.net,
                    self.device,
                    loss_fn=self.loss_fn,
                    thresholds=thresholds)

    def _train_epoch(self, train_dl):
        """Run one training epoch.

        Returns:
            (summed batch loss, per-class thresholds tuned on this epoch's
            training predictions).
        """
        train_loss = 0.0
        self.net.train()

        fin_outputs = []
        fin_targets = []
        for d in tqdm(train_dl, total=len(train_dl)):
            ids = d["ids"].to(self.device, dtype=torch.long)
            mask = d["mask"].to(self.device, dtype=torch.long)
            targets = d["labels"].to(self.device, dtype=torch.float)
            idxs = d["val_mask"].to(self.device) == 1

            outputs = self.net(ids, mask)
            self.opt.zero_grad()

            # Train only on real tokens (special/padding tokens are masked out).
            outputs, targets = outputs[idxs], targets[idxs]

            # Use the configured loss, not whatever `loss_fn` is global.
            loss = self.loss_fn(outputs, targets)
            loss.backward()
            train_loss += loss.item()
            self.opt.step()
            if self.scheduler is not None:
                self.scheduler.step()

            # Detach and move to CPU per batch instead of keeping graph-attached
            # device tensors around for the whole epoch.
            fin_outputs.append(outputs.detach().cpu().numpy())
            fin_targets.append(targets.detach().cpu().numpy())

        fin_outputs = np.concatenate(fin_outputs, axis=0)
        fin_targets = np.concatenate(fin_targets, axis=0)

        thresholds = find_opt_threshold(fin_outputs, fin_targets)
        return train_loss, thresholds

    @staticmethod
    def _make_save_dir(save_dir, dir_name):
        """Create the run folder under ``save_dir`` and return it with a
        trailing '/', or ``None`` when saving is disabled."""
        if save_dir is None:
            return None
        if not os.path.isdir(save_dir):
            # Fail before training rather than at the first checkpoint.
            raise FileNotFoundError(
                "save_dir {} is not a directory".format(save_dir))

        datestring = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        if dir_name is None:
            run_dir = os.path.join(save_dir, datestring)
        elif os.path.isdir(os.path.join(save_dir, dir_name)):
            run_dir = os.path.join(save_dir, dir_name + datestring)
            print(
                "dir_name specified already exists. Saving as {}".format(
                    run_dir + "/"))
        else:
            run_dir = os.path.join(save_dir, dir_name)

        run_dir += "/"
        os.mkdir(run_dir)
        return run_dir

    def _save_best(self, metrics, thresholds, best, save_dir):
        """Checkpoint the model for every tracked metric that beat its best
        value so far. ``best`` is updated in place."""
        save_msg = "Model saved as current {} is: {:.4f}"
        for key, lower_is_better, suffix, label in self._CHECKPOINT_CRITERIA:
            value = metrics[key]
            if lower_is_better:
                improved = value < best[key]
            else:
                improved = value > best[key]
            if improved:
                best[key] = value
                save_model(self.net, thresholds,
                           "{}best_model_{}.pt".format(save_dir, suffix))
                print(save_msg.format(label, value))

    def train(self, train_dl, val_dl, epochs, save_dir=None, dir_name=None):
        """Train for ``epochs`` epochs, validating on ``val_dl`` after each.

        Args:
            train_dl, val_dl: training / validation DataLoaders.
            epochs: number of epochs.
            save_dir: existing directory in which a new run folder is created
                for checkpoints and plots. ``None`` disables saving.
            dir_name: run folder name. The current timestamp is appended if
                the folder already exists, or used alone when ``None``.

        Raises:
            FileNotFoundError: if ``save_dir`` is given but is not a directory.
        """
        save_dir = self._make_save_dir(save_dir, dir_name)
        best = {key: (np.inf if lower else 0)
                for key, lower, _, _ in self._CHECKPOINT_CRITERIA}

        for epoch in tqdm(range(epochs)):
            train_loss, thresholds = self._train_epoch(train_dl)
            res = self.eval(val_dl, thresholds)

            avg_train_loss = train_loss / len(train_dl)
            avg_val_loss = res["loss"] / len(val_dl)

            metrics = get_metrics(
                res["pred"], res["true"], res["pred_spans"], res["true_spans"])
            metrics_non_opt = get_metrics(
                res["pred_non_opt"], res["true"],
                res["pred_spans_non_opt"], res["true_spans"])

            metrics['train_loss'] = avg_train_loss
            metrics['val_loss'] = avg_val_loss
            metrics['f1_non_opt'] = metrics_non_opt['f1']
            metrics['f1-micro_tokens_non_opt'] = metrics_non_opt[
                'f1-micro_tokens']
            metrics['f1-macro_tokens_non_opt'] = metrics_non_opt[
                'f1-macro_tokens']

            self.plotter.add(metrics)

            print("Epoch {}:".format(epoch))
            print("Average Train loss: ", avg_train_loss)
            print("Average Valid loss: ", avg_val_loss)
            print("F1-micro (token-level): ", metrics['f1-micro_tokens'])
            print("F1-macro (token-level): ", metrics['f1-macro_tokens'])
            print("F1: ", metrics['f1'])
            print("Precision: ", metrics['precision'])
            print("Recall: ", metrics['recall'])

            if save_dir is not None:
                self._save_best(metrics, thresholds, best, save_dir)
                self.plotter.save(save_dir)

    def plot_metrics(self):
        """Show the recorded per-epoch metric curves."""
        self.plotter.plot()

### TRAINING

In [17]:
# Run on the GPU when one is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')

# Seed torch's global RNG (weight init, dropout, shuffling) for repeatability.
torch.manual_seed(1)

<torch._C.Generator at 0x7fd69b63d130>

In [18]:
train, test, dev = load_datasets()

# Wrap each split in a MemeDataset, then in a DataLoader.
splits = build_datasets(train, test, dev)
train_ds, test_ds, dev_ds = splits
train_dl, test_dl, dev_dl = build_dataloaders(*splits)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [19]:
# BERT checkpoints (checkpoint-40000 folders) stored on Drive. Only the
# "frozen" one is used below.
bert_unfrozen = ("drive/MyDrive/DeepLearning/Models/"
                 "Pretrained_unfrozen/checkpoint-40000")
bert_frozen = ("drive/MyDrive/DeepLearning/Models/"
               "Pretrained_frozen/checkpoint-40000")

model = MemeClassifier(N_LABELS, bert_path=bert_frozen)

# BCE on the sigmoid probabilities; the commented-out FocalLoss is the alternative.
loss_fn = nn.BCELoss()
# loss_fn = FocalLoss()

model = nn.DataParallel(model.to(device))

Some weights of the model checkpoint at drive/MyDrive/DeepLearning/Models/Pretrained_frozen/checkpoint-40000 were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [20]:
# Silence NumPy's VisibleDeprecationWarning (presumably the ragged-array
# warnings raised while post-processing predictions). The class lives in
# np.exceptions since NumPy 1.25 and is no longer exposed as
# np.VisibleDeprecationWarning in NumPy 2.0.
warnings.filterwarnings(
    "ignore",
    category=getattr(np, "exceptions", np).VisibleDeprecationWarning)

TRAIN = False
save_dir = "drive/MyDrive/DeepLearning/Task2/"
epochs = 50

total_steps = len(train_dl) * epochs

optimizer = build_optimizer(model, lr=4e-5)
# NOTE(review): this scheduler is never handed to the Trainer below, so it is
# never stepped and the learning rate stays constant. TODO: wire it in if a
# linear decay is intended.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)

trainer = Trainer(model, optimizer, device, loss_fn)

if TRAIN:
    trainer.train(
        train_dl,
        dev_dl,
        epochs,
        dir_name="final-frozen",
        save_dir=save_dir)

    trainer.plot_metrics()

### TESTER

In [21]:
def add_threshold_uncertainty(thresholds, degree=2):
    """Add uncertainty to thresholds by flattening them towards 0.5.

    Each threshold ``x`` becomes ``x - sign(x - 0.5) * |x - 0.5| ** degree``,
    a polynomial pull towards 0.5. With ``degree=1`` every threshold becomes
    0.5 (maximum uncertainty). For thresholds in [0, 1], higher degrees move
    them less.

    Args:
        thresholds: array-like of thresholds (any shape).
        degree: positive integer degree of the polynomial.

    Returns:
        Float np.ndarray shaped like ``thresholds``.

    Raises:
        ValueError: if ``degree`` is not a positive integer. An explicit
            check is used because ``assert`` is stripped under ``python -O``.
    """
    if not (int(degree) == degree and degree > 0):
        raise ValueError('"degree" has to be a positive integer.')

    t = np.asarray(thresholds, dtype=float)
    offset = t - 0.5
    return t - np.sign(offset) * np.abs(offset) ** degree

# Show how degree-3 uncertainty reshapes thresholds 0.0, 0.1, ..., 1.0.
x = np.arange(0, 1.1, 0.1)
y = add_threshold_uncertainty(x, degree=3)

print("EXAMPLE:")
for title, values in (("OLD THRESHOLDS:\n", x), ("\nNEW THRESHOLDS:\n", y)):
    print(title, values)

EXAMPLE:
OLD THRESHOLDS:
 [0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1. ]

NEW THRESHOLDS:
 [0.125 0.164 0.227 0.308 0.401 0.5   0.599 0.692 0.773 0.836 0.875]


In [22]:
class Tester():
    """Evaluates a model and prints plots and metrics.

    The loader is evaluated twice: once with ``thresholds``, and once with
    the same thresholds pulled towards 0.5 by add_threshold_uncertainty.
    Each eval() run also yields "non_opt" predictions (0.5 for every class).

    Args:
        model: model to evaluate.
        dl: DataLoader to evaluate on.
        device: device the inputs are moved to.
        thresholds: per-class thresholds; defaults to 0.5 for every class.
        uncertainty_degree: degree passed to add_threshold_uncertainty.
    """

    def __init__(self, model, dl, device, thresholds=None,
                 uncertainty_degree=5):
        if thresholds is None:
            thresholds = np.array([0.5] * N_LABELS)

        res = eval(dl, model, device, thresholds=thresholds)

        self.pred = res["pred"]
        self.true = res["true"]
        self.pred_spans = res["pred_spans"]
        self.true_spans = res["true_spans"]
        self.pred_non_opt = res["pred_non_opt"]
        self.pred_spans_non_opt = res["pred_spans_non_opt"]

        uncertain = add_threshold_uncertainty(
            thresholds, degree=uncertainty_degree)
        res = eval(dl, model, device, thresholds=uncertain)

        self.pred_uncertain = res["pred"]
        self.pred_spans_uncertain = res["pred_spans"]

    def print_f_table(self):
        """Print the per-class token-level F1 as a table."""
        f = f1_score(self.true, self.pred, average=None, zero_division=1)
        print(tabulate([f], headers=LABELS.keys(), floatfmt=".4f"))

    def print_fpr(self):
        """Print span-level F1 / precision / recall for the three threshold sets."""
        runs = (("", self.pred_spans),
                (" w/out threshold moving", self.pred_spans_non_opt),
                (" w/ threshold uncertainty", self.pred_spans_uncertain))
        for suffix, pred_spans in runs:
            f, p, r = f1_metric(pred_spans, self.true_spans)
            print("f_sentence{}: {:10.4f}".format(suffix, f))
            print("precision_sentence{}: {:10.4f}".format(suffix, p))
            print("recall_sentence{}: {:10.4f}".format(suffix, r))

    def _print_token_f1(self, average):
        """Print token-level F1 with the given sklearn ``average`` for the
        three threshold sets."""
        runs = (("f_{}_token", self.pred),
                ("f_{} w/out threshold moving", self.pred_non_opt),
                ("f_{} w/ threshold uncertainty", self.pred_uncertain))
        for template, pred in runs:
            score = f1_score(self.true, pred, average=average, zero_division=1)
            print((template + ": {:10.4f}").format(average, score))

    def print_fmacro(self):
        self._print_token_f1('macro')

    def print_fmicro(self):
        self._print_token_f1('micro')

    def plot_confusion_matrices(self, n_cols=4):
        """Draw one 2x2 confusion-matrix heatmap per class on a grid with
        ``n_cols`` columns (5x4 for the 20 classes)."""
        cms = multilabel_confusion_matrix(self.true, self.pred)

        n_rows = -(-len(cms) // n_cols)  # ceil division
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(12, 14),
                                squeeze=False)
        axes = axs.ravel()

        titles = [l if len(l) < 22 else l[:20] + ".." for l in LABELS]

        for idx, cm in enumerate(cms):
            df_cm = pd.DataFrame(cm, index=["F", "T"], columns=["F", "T"])
            sn.heatmap(df_cm, annot=True, ax=axes[idx], cmap=plt.cm.Greens)
            axes[idx].title.set_text(titles[idx])
        # Hide grid cells left over when the class count is not a multiple of n_cols.
        for ax in axes[len(cms):]:
            ax.axis("off")

        fig.tight_layout()
        plt.show()

In [23]:
def test_model(path, test_level=0, model_name="MODEL", dl=None):
    """Load a checkpoint and print its test metrics.

    Args:
        path: checkpoint written by save_model.
        test_level: 0 prints span-level and token-level F1 scores. >= 1 also
            prints the per-class F1 table and plots the confusion matrices.
        model_name: label printed (upper-cased) in the header. A string
            default: the old list default crashed on ``.upper()``.
        dl: loader to evaluate on. Defaults to the global ``test_dl``.
    """
    model, thresholds = load_model(path, device)
    model.to(device)
    loader = test_dl if dl is None else dl
    tester = Tester(model, loader, device, thresholds)

    print("####### TESTING {} #######".format(model_name.upper()))

    tester.print_fpr()
    print()
    tester.print_fmicro()
    print()
    tester.print_fmacro()

    if test_level >= 1:
        tester.print_f_table()
        tester.plot_confusion_matrices()

    print("---------------------------")

def test_models(models_dir, models_names, to_test=None, test_level=0):
    """Run test_model on several saved checkpoints.

    Args:
        models_dir: directory containing one sub-folder per model name.
        models_names: sub-folder names to test.
        to_test: checkpoint suffixes (e.g. "f1" for best_model_f1.pt). With
            ``None``, every ``*.pt`` file of each folder is tested, in
            sorted order.
        test_level: forwarded to test_model.

    Raises:
        TypeError: if ``to_test`` is neither ``None`` nor a list of strings.
    """
    if to_test is not None and (
            not isinstance(to_test, list)
            or not all(isinstance(t, str) for t in to_test)):
        raise TypeError('"to_test" has to be a list of string or None')

    previous_verbosity = logging.get_verbosity()
    # Level 40 (ERROR) hides the transformers notices printed while the
    # models are rebuilt. Always restore the caller's level, even on error.
    logging.set_verbosity(40)
    try:
        for model_name in models_names:
            model_dir = os.path.join(models_dir, model_name)

            if to_test is None:
                f_names = sorted(
                    f for f in os.listdir(model_dir) if f.endswith(".pt"))
            else:
                f_names = ["best_model_{}.pt".format(m) for m in to_test]

            for name in f_names:
                path = os.path.join(model_dir, name)
                m_name = " ".join([model_name, name[:-3]])

                test_model(path, model_name=m_name, test_level=test_level)
    finally:
        logging.set_verbosity(previous_verbosity)

### TESTING

In [24]:
# Evaluate the best checkpoints (by F1 and micro-F1, with and without
# threshold moving) of the two trained model folders.
models_dir = "drive/MyDrive/DeepLearning/Task2/"
models_names = ["frozen", "frozen-focal"]
to_test = ["f1", "f1_non_opt", "micro", "micro_non_opt"]

test_models(models_dir, models_names, to_test=to_test, test_level=0)

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

####### TESTING FROZEN BEST_MODEL_F1 #######
f_sentence:     0.5457
precision_sentence:     0.6340
recall_sentence:     0.4789
f_sentence w/out threshold moving:     0.5447
precision_sentence w/out threshold moving:     0.6821
recall_sentence w/out threshold moving:     0.4534
f_sentence w/ threshold uncertainty:     0.2944
precision_sentence w/ threshold uncertainty:     0.3488
recall_sentence w/ threshold uncertainty:     0.2547

f_micro_token:     0.3337
f_micro w/out threshold moving:     0.3233
f_micro w/ threshold uncertainty:     0.1137

f_macro_token:     0.1658
f_macro w/out threshold moving:     0.1388
f_macro w/ threshold uncertainty:     0.0321
---------------------------
####### TESTING FROZEN BEST_MODEL_F1_NON_OPT #######
f_sentence:     0.5538
precision_sentence:     0.6415
recall_sentence:     0.4873
f_sentence w/out threshold moving:     0.5560
precision_sentence w/out threshold moving:     0.6695
recall_sentence w/out threshold moving:     0.4753
f_sentence w/ thresho