In [None]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['figure.figsize'] = [14, 8]

from datasets import Dataset, load_dataset, concatenate_datasets
import gc
import os
import pyarrow as pa
import re
from sklearn.metrics import f1_score
from sklearn.model_selection import KFold
from tqdm.auto import trange
from typing import Iterator, List

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, Sampler
from transformers import AutoTokenizer, AutoModel

import sys
sys.path.append('..')
from mcpt.contrastlearning import DataManager, TrainerA, WeightedCosineSimilarityLoss

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
# --- Data and model selection ---
# Root directory handed to DataManager as `data_dir`; the commented line is an alternative location.
data_path = '../data'
#data_path = '../input/semeval/data'
# Forwarded to DataManager as `use_dev`.
DEV = True
# Checkpoint name passed to AutoModel.from_pretrained and AutoTokenizer.from_pretrained;
# the commented lines are alternative checkpoints.
#model_name = 'sentence-transformers/all-mpnet-base-v2'
#model_name = 'sentence-transformers/all-MiniLM-L6-v2'
model_name = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
#model_name = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
# Forwarded to TrainerA as `model_loader_type`.
model_sampler = 'random'

# --- Training schedule ---
# NOTE(review): N_EPOCHS is not referenced by any visible cell.
N_EPOCHS = 1000
# Epoch count passed to trainer.train_joint.
N_FINETUNE_EPOCHS = 50
# Epoch count passed to the trainer.train_head call that precedes train_joint.
N_EPOCHS_BEFORE_FINETUNE = 50
# Epoch count passed to the trainer.train_head call that follows train_joint.
N_POST_FINETUNE_EPOCHS = 50

# --- Batch sizes and sampling (forwarded to TrainerA) ---
# `train_model_batch_size`
MODEL_BATCH_SIZE = 26
# `train_head_batch_size`
HEAD_BATCH_SIZE = 200
# `min_samples_from_class`
MIN_SAMPLES_FROM_CLASS = 2

# --- Optimisation (forwarded to TrainerA) ---
# `head_lr`; also the learning rate of the AdamW optimizer built for each fresh head.
HEAD_LR = 1e-3
# `head_gamma`
HEAD_GAMMA = .99
# `model_lr`
MODEL_LR = 2e-5
# `beta`
BETA = 0.01
# `model_gamma`
MODEL_GAMMA = .98

# --- Validation and checkpointing ---
# `validate_every_n_epochs`; presumably -1 turns periodic validation off - TODO confirm in TrainerA.
VALIDATE_EVERY = -1
# NOTE(review): CHECKPOINT_EVERY and EARLIEST_CHECKPOINT are not referenced by any visible cell;
# the per-language training loop passes literal 1000s to TrainerA instead.
CHECKPOINT_EVERY = 10
EARLIEST_CHECKPOINT = 39

In [None]:
# Load the pretrained encoder selected by `model_name`.
model = AutoModel.from_pretrained(model_name)

In [None]:
# Tokenizer matching the encoder checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Language codes that get their own training run in the per-language loop below.
all_langs = ['en', 'ge', 'fr', 'it', 'ru', 'po']
# Data manager spanning every language in `all_langs`, with no languages reserved for head evaluation.
# In the visible cells it is only used to read `num_classes`; the training loop rebinds `datamanager`.
datamanager = DataManager(
    tokenizer=tokenizer,
    data_dir=data_path,
    use_dev=DEV,
    languages_for_head_eval=[],
    languages_for_head_train=all_langs,
    languages_for_contrastive=all_langs,
)
# Number of output units of the classification head.
N_CLASSES = datamanager.num_classes
# NOTE(review): no visible cell appends to `metrics` or `reference_list`, yet later cells
# iterate over them and index element 1 - confirm where they are meant to be filled.
metrics = list()
reference_list = list()

In [None]:
# Width of the encoder's word-embedding table, used as the input size of the head.
EMBEDDING_DIM = model.embeddings.word_embeddings.embedding_dim
# Classification head: two 256-unit hidden layers followed by an N_CLASSES-unit output layer.
# NOTE(review): the final nn.Dropout() acts on the output of the last Linear layer, i.e. on the
# values fed to the loss - confirm that dropping outputs during training is intended.
head = nn.Sequential(
    nn.Linear(EMBEDDING_DIM, 256),
    nn.Dropout(),
    nn.ReLU(),
    nn.Linear(256, 256),
    nn.Dropout(),
    nn.ReLU(),
    nn.Linear(256, N_CLASSES),
    nn.Dropout(),
)

# Train Full

## Sanity Check

# Train Target Language

In [None]:
# Train one model per language: start from the 'joint_49' checkpoint, attach a freshly
# initialised head, run head -> joint -> head training, report micro-F1 on the language's
# dev data and write the test predictions for that language.
# Note: `datamanager`, `head` and `trainer` are rebound on every iteration, so after the loop
# they refer to the objects of the last language in `all_langs`.
for lang in all_langs:
    print(f'Training {lang}')
    # Data manager restricted to the current language; no languages are reserved for head evaluation.
    datamanager = DataManager(
        tokenizer=tokenizer,
        data_dir=data_path,
        use_dev=DEV,
        languages_for_head_eval=[],
        languages_for_head_train=[lang],
        languages_for_contrastive=[lang],
    )
    dataset_contrastive = datamanager.get_contrastive_dataset()
    dataset_head_train = datamanager.get_head_train_dataset()
    dataset_head_eval = datamanager.get_head_eval_dataset()
    trainer = TrainerA(
        model=model,
        head=head,
        device=device,
        head_loss=nn.BCEWithLogitsLoss(),
        model_loss=WeightedCosineSimilarityLoss(N_CLASSES),
        model_dataset=dataset_contrastive,
        head_dataset=dataset_head_train,
        eval_dataset=dataset_head_eval,
        n_classes=N_CLASSES,
        model_loader_type=model_sampler,
        train_head_batch_size=HEAD_BATCH_SIZE,
        train_model_batch_size=MODEL_BATCH_SIZE,
        head_lr=HEAD_LR,
        model_lr=MODEL_LR,
        head_gamma=HEAD_GAMMA,
        model_gamma=MODEL_GAMMA,
        beta=BETA,
        min_samples_from_class=MIN_SAMPLES_FROM_CLASS,
        validate_every_n_epochs=VALIDATE_EVERY,
        # Literal values rather than CHECKPOINT_EVERY / EARLIEST_CHECKPOINT; presumably chosen
        # so that no checkpoint is written during these runs - TODO confirm in TrainerA.
        checkpoint_every_n_epochs=1000,
        earliest_checkpoint=1000,
    )
    # Presumably restores the state saved under the name 'joint_49' - TODO confirm in TrainerA.
    trainer.load_from_checkpoint('joint_49')

    # Replace the head with a freshly initialised one (same architecture) and its own optimizer.
    head = nn.Sequential(
        nn.Linear(EMBEDDING_DIM, 256),
        nn.Dropout(),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.Dropout(),
        nn.ReLU(),
        nn.Linear(256, N_CLASSES),
        nn.Dropout(),
    )
    head_optimizer = AdamW(head.parameters(), lr=HEAD_LR)
    trainer.set_head(head, head_optimizer)

    # Schedule: head only, then joint training, then head only again.
    trainer.train_head(N_EPOCHS_BEFORE_FINETUNE)
    trainer.train_joint(N_FINETUNE_EPOCHS)
    trainer.train_head(N_POST_FINETUNE_EPOCHS)

    # Sanity check: micro-F1 of the trained model on this language's dataset requested with dev=True.
    dataset_sanity = datamanager._get_single_named_dataset(lang, dev=True)
    dataset_sanity = datamanager._preprocess_head_dataset(dataset_sanity)
    embeddings = trainer.compute_embeddings(dataset_sanity)
    predictions = trainer.predict(embeddings.tensors[0], 'cpu')
    f1 = f1_score(dataset_sanity['labels'], predictions, average='micro')
    print('  ', lang, ': ', f1)

    print('  Writing prediction file.')
    # Read the test articles from the same data root the DataManager was given, so that
    # switching `data_path` in the config cell moves training and prediction inputs together.
    datamanager.predict_and_write(
        trainer,
        articles_dir=f'{data_path}/{lang}/test-articles-subtask-2',
        output_file=f'predictions_{lang}.csv'
    )

# Predictions

In [None]:
# Write prediction files with the trainer left over from the last training run.
# Languages in `all_langs` already received a prediction file from their own, language-specific
# trainer inside the training loop; predicting them again here would overwrite those files with
# the output of a model trained on a different language, so they are skipped.
langs = ['en', 'ge', 'fr', 'it', 'ru', 'po', 'ka', 'es', 'gr']
languages_with_own_model = set(all_langs)
for lang in langs:
    if lang in languages_with_own_model:
        continue
    # Use the configured data root instead of a hard-coded path so this cell follows `data_path`.
    datamanager.predict_and_write(
        trainer,
        articles_dir=f'{data_path}/{lang}/test-articles-subtask-2',
        output_file=f'predictions_{lang}.csv'
    )

In [None]:
# For each tracked score, report the mean over runs of the best value reached within a run.
summary_keys = [
    ('Average Max MicroF1', 'microf1'),
    ('Average Max MacroF1', 'macrof1'),
    ('Average Max Train MicroF1', 'train_microf1'),
    ('Average Max Train MacroF1', 'train_macrof1'),
]
if not metrics:
    # np.mean of an empty array would print NaN together with a RuntimeWarning.
    print('No metrics recorded; nothing to average.')
else:
    for label, key in summary_keys:
        best_per_run = np.array([np.max(m[key]) for m in metrics])
        print(f'{label}: {np.mean(best_per_run)}')

In [None]:
# Draw the metric curves of every recorded run.
# The second argument is forwarded unchanged; its meaning is defined by TrainerA.plot_metrics.
for run_log in metrics:
    TrainerA.plot_metrics(run_log, 1)

In [None]:
# Micro-F1 of the rounded 'KNNlogits' entries of run 1 against that run's reference labels,
# one score per logged entry, shown as a line plot.
ground = reference_list[1]['labels'].numpy()
KNN_scores = [
    f1_score(ground, torch.round(knn_output).numpy(), average='micro')
    for knn_output in metrics[1]['KNNlogits']
]
plt.plot(KNN_scores)
plt.show()

In [None]:
# Plot the values logged under 'WCSL' by the most recently created trainer
# (presumably the weighted cosine similarity loss - TODO confirm in TrainerA).
wcsl_history = trainer.log_dict['WCSL']
plt.plot(wcsl_history)
plt.show()

# Evaluation

In [None]:
# Label names indexed by class id; read as a global by per_label_f1 and plot_all.
# NOTE(review): INT2LABEL is not defined or imported in any visible cell - confirm where it
# comes from, otherwise this line raises NameError on a fresh kernel.
categories = INT2LABEL

def per_label_f1(predictions, references):
    """Print per-label, micro and macro F1 and draw a bar chart of the per-label scores.

    Label names are read from the module-level ``categories`` sequence.

    Args:
        predictions: predicted labels, passed to ``f1_score`` as the predictions.
        references: reference labels, passed to ``f1_score`` as the ground truth.

    Returns:
        A tuple ``(order, axes)`` where ``order`` lists the label indices sorted by
        ascending F1 and ``axes`` is the object returned by ``sns.barplot``, whose
        bars follow that same order.
    """
    per_label = f1_score(references, predictions, average=None)
    print("f1:", per_label)
    print("micro-f1:", f1_score(references, predictions, average="micro"))
    print("macro-f1:", f1_score(references, predictions, average="macro"))

    n_labels = len(categories)
    scores_frame = pd.DataFrame({
        "f1_score": np.array([per_label[i] for i in range(n_labels)]),
        "label_name": np.array([categories[i] for i in range(n_labels)]),
    })

    # Label indices from worst to best F1; also drives the bar order of the chart.
    order = sorted(range(n_labels), key=per_label.__getitem__)
    bar_axes = sns.barplot(
        x="f1_score",
        y="label_name",
        data=scores_frame,
        order=np.array(categories)[order],
    )
    return order, bar_axes

In [None]:
def plot_all(predictions, references):
    """Show the per-label F1 bar chart and a heatmap of per-sample prediction outcomes.

    Each cell of the heatmap is coded as:
        3 - correct, predicted 1
        2 - correct, predicted 0
        1 - wrong, predicted 1 (reference is 0)
        0 - wrong, predicted 0 (reference is 1)

    Columns are ordered from the best to the worst per-label F1, and rows are sorted
    in descending order by the column values, left to right.

    Args:
        predictions: 2-D array of predicted labels with one column per entry of the
            module-level ``categories``.
        references: array of reference labels with the same shape as ``predictions``.
    """
    order, _ = per_label_f1(predictions, references)
    plt.show()

    correct_predictions = references == predictions
    false_predictions = references != predictions
    predicted_one = predictions == 1
    predicted_zero = predictions == 0

    # Masks are derived from `predictions`, not from the frame being filled, so the
    # assignments below cannot interfere with each other.
    pred_heatmap = pd.DataFrame(predictions, columns=categories)
    pred_heatmap[correct_predictions & predicted_one] = 3   # correct and one
    pred_heatmap[correct_predictions & predicted_zero] = 2  # correct and zero
    pred_heatmap[false_predictions & predicted_one] = 1     # false and actually zero
    pred_heatmap[false_predictions & predicted_zero] = 0    # false and actually one

    pred_heatmap = pred_heatmap.iloc[:, order[::-1]]
    pred_heatmap['false_predictions'] = false_predictions.sum(axis=1)
    pred_heatmap = pred_heatmap.sort_values(by=list(pred_heatmap.columns), ascending=False)

    fig, ax = plt.subplots(figsize=[14, 20])
    cmap = sns.color_palette("coolwarm_r", 4)
    # Pin the colour range to the four outcome codes: without vmin/vmax the palette is
    # stretched over whatever codes happen to occur, so the same code could get a
    # different colour from one plot to the next. The helper count column is excluded.
    sns.heatmap(pred_heatmap.iloc[:, :-1], cmap=cmap, vmin=0, vmax=3, ax=ax)
    plt.show()

In [None]:
# For every recorded run, plot the predictions of the epoch with the highest micro-F1
# against that run's reference labels.
for r, m in zip(reference_list, metrics):
    best_epoch = np.argmax(m['microf1'])
    # Convert only the selected epoch instead of materialising every epoch's predictions.
    predictions = np.array(m['predictions'][best_epoch].cpu().tolist())
    references = r['labels'].numpy()
    plot_all(predictions, references)