Do not edit. This should be a copy of WSD_analysis_v3.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
from scipy import stats
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score
import os

In [2]:
# Path templates for the per-epoch prediction files. They are filled with
# .format(epoch, lang): {0} is the epoch number and {1} the language code
# (used twice, as in 'preds_dev.en-en.data').
pred_dev_filename = 'data/BEM-WiC-preds/XLM-R_05_unbal_base_epochs/epoch_{0}/preds_dev.{1}-{1}.data'
pred_test_filename = 'data/BEM-WiC-preds/XLM-R_05_unbal_base_epochs/epoch_{0}/preds_test.{1}-{1}.data'

In [3]:
def construct_preds(pred_filename):
    """Load model predictions from a JSON file.

    Args:
        pred_filename: Path to the JSON file holding the predictions.

    Returns:
        The deserialized JSON content, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so non-ASCII (e.g. Russian, Arabic, Chinese) content loads
    # the same way on every machine.
    with open(pred_filename, encoding='utf-8') as f:
        return json.load(f)

In [4]:
# Path prefixes (without extension) of the dataset files; {0} is the language
# code. The loaders below append '.data' (samples) and '.gold' (labels).
data_dev_files = 'data/MCL-WiC/dev/multilingual/dev.{0}-{0}'
data_test_files = 'data/MCL-WiC/test/multilingual/test.{0}-{0}'

In [5]:
def construct_data_samples(data_files):
    """Load samples and attach their gold tag.

    Reads ``data_files + '.data'`` (a JSON list of sample dicts) and
    ``data_files + '.gold'`` (a JSON list of label dicts), matches them by
    their ``'id'`` field and stores the label's ``'tag'`` on each sample.

    Args:
        data_files: Common path prefix of the ``.data`` / ``.gold`` pair.

    Returns:
        The list of sample dicts, each extended in place with a ``'tag'`` key.

    Raises:
        KeyError: If a sample's id has no entry in the gold file.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding; the samples contain non-ASCII text.
    with open(data_files + '.data', encoding='utf-8') as f_data, \
            open(data_files + '.gold', encoding='utf-8') as f_labels:
        data_json = json.load(f_data)
        labels_json = json.load(f_labels)

    # Index labels by id so the two files do not have to share an ordering.
    labels_by_id = {label['id']: label for label in labels_json}
    for sample in data_json:
        sample['tag'] = labels_by_id[sample['id']]['tag']

    return data_json

In [6]:
def construct_test_data_samples(data_files):
    """Load samples without gold labels.

    Args:
        data_files: Path prefix of the dataset; ``'.data'`` is appended.

    Returns:
        The deserialized JSON content of the ``.data`` file.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding; the samples contain non-ASCII text.
    with open(data_files + '.data', encoding='utf-8') as f_data:
        return json.load(f_data)

In [7]:
def construct_samples(data_files, pred_filename, with_labels=True):
    """Load dataset samples and merge the model predictions into them.

    Args:
        data_files: Path prefix of the dataset files.
        pred_filename: Path of the JSON predictions file.
        with_labels: When true, samples are loaded together with their gold
            tags; otherwise only the ``.data`` file is read.

    Returns:
        The list of sample dicts, each updated in place with every key/value
        of the prediction that shares its ``'id'``.

    Raises:
        KeyError: If a sample has no prediction with the same id.
    """
    predictions = construct_preds(pred_filename)

    load_samples = construct_data_samples if with_labels else construct_test_data_samples
    samples = load_samples(data_files)

    # Predictions are matched by id rather than by position.
    prediction_by_id = {entry['id']: entry for entry in predictions}
    for sample in samples:
        sample.update(prediction_by_id[sample['id']])

    return samples

In [8]:
from abc import ABC, abstractmethod

In [9]:
class AbsPredictor(ABC):
    """Abstract interface for predictors that compare two inputs.

    Subclasses must implement both ``predict`` and ``predict_proba``.
    """

    @abstractmethod
    def predict(self, probs_1, probs_2):
        """Return the decision for the pair of inputs."""

    @abstractmethod
    def predict_proba(self, probs_1, probs_2):
        """Return the raw score for the pair of inputs."""

In [10]:
class DivergencePredictor(AbsPredictor):
    """Predicts by thresholding a divergence between two score dicts.

    Both inputs are dicts that map the same keys to scores. The scores are
    passed through a softmax (i.e. they are treated as unnormalized
    log-scores) and the divergence between the two resulting distributions
    is compared against ``threshold``.

    Args:
        threshold: ``predict`` returns true when the divergence is strictly
            below this value.
        normalize: When true, the divergence is divided by the number of keys.
        divergence: ``'Kullback–Leibler'`` or ``'Jensen–Shannon'``.

    Raises:
        ValueError: If ``divergence`` is not one of the supported names.
    """

    _DIVERGENCES = ('Kullback–Leibler', 'Jensen–Shannon')

    def __init__(self, threshold, normalize=True, divergence='Kullback–Leibler'):
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if divergence not in self._DIVERGENCES:
            raise ValueError(
                f'divergence must be one of {self._DIVERGENCES}, got {divergence!r}')

        self.threshold = threshold
        self.normalize = normalize
        self.divergence = divergence

    @staticmethod
    def _softmax(scores):
        """Return the softmax of ``scores`` in extended precision."""
        # np.longdouble is the portable name of the extended-precision type;
        # np.float128 is not defined on every platform.
        scores = np.asarray(scores, dtype=np.longdouble)
        if scores.size == 0:
            return scores
        # Subtracting the maximum leaves the result unchanged mathematically
        # but keeps np.exp from overflowing on large scores.
        exp_scores = np.exp(scores - scores.max())
        return exp_scores / exp_scores.sum()

    @staticmethod
    def _kullback_leibler(probs_1, probs_2):
        """Return KL(probs_1 || probs_2) for two aligned probability arrays."""
        return np.sum(probs_1 * np.log(probs_1 / probs_2))

    def predict_proba(self, probs_1, probs_2):
        """Return the divergence between the softmaxed score dicts.

        Args:
            probs_1: Dict of scores; its keys define the compared support.
            probs_2: Dict of scores containing at least the keys of probs_1.

        Returns:
            The (optionally length-normalized) divergence; ``0.0`` when
            ``probs_1`` has fewer than two keys.

        Raises:
            KeyError: If ``probs_2`` lacks a key of ``probs_1``.
        """
        # With fewer than two keys both distributions are trivially equal.
        # This must be a numeric 0 rather than ``True``: ``predict`` compares
        # the value with ``<`` and ``True`` would behave like 1 there.
        if len(probs_1) < 2:
            return 0.0

        keys = list(probs_1)
        dist_1 = self._softmax([probs_1[key] for key in keys])
        dist_2 = self._softmax([probs_2[key] for key in keys])

        if self.divergence == 'Kullback–Leibler':
            diver = self._kullback_leibler(dist_1, dist_2)
        else:
            # Jensen–Shannon: mean KL of each distribution to their midpoint.
            midpoint = (dist_1 + dist_2) / 2
            diver = (self._kullback_leibler(dist_1, midpoint)
                     + self._kullback_leibler(dist_2, midpoint)) / 2

        if self.normalize:
            diver = diver / len(keys)

        return diver

    def predict(self, probs_1, probs_2):
        """Return whether the divergence is strictly below the threshold."""
        return self.predict_proba(probs_1, probs_2) < self.threshold

In [11]:
class ProbsDotPredictor(AbsPredictor):
    """Predicts by thresholding the dot product of two softmaxed score dicts.

    Both inputs are dicts that map the same keys to scores. The scores are
    passed through a softmax (i.e. they are treated as unnormalized
    log-scores) before the dot product is taken.

    Args:
        threshold: ``predict`` returns true when the dot product is strictly
            above this value.
    """

    def __init__(self, threshold):
        self.threshold = threshold

    @staticmethod
    def _softmax(scores):
        """Return the softmax of ``scores`` in extended precision."""
        # np.longdouble is the portable name of the extended-precision type;
        # np.float128 is not defined on every platform.
        scores = np.asarray(scores, dtype=np.longdouble)
        if scores.size == 0:
            return scores
        # Subtracting the maximum leaves the result unchanged mathematically
        # but keeps np.exp from overflowing (inf / inf -> NaN) on large scores.
        exp_scores = np.exp(scores - scores.max())
        return exp_scores / exp_scores.sum()

    def predict_proba(self, probs_1, probs_2):
        """Return the dot product of the two softmax distributions.

        Args:
            probs_1: Dict of scores; its keys define the compared support.
            probs_2: Dict of scores containing at least the keys of probs_1.

        Raises:
            KeyError: If ``probs_2`` lacks a key of ``probs_1``.
        """
        keys = list(probs_1)
        dist_1 = self._softmax([probs_1[key] for key in keys])
        dist_2 = self._softmax([probs_2[key] for key in keys])

        return np.sum(dist_1 * dist_2)

    def predict(self, probs_1, probs_2):
        """Return whether the dot product is strictly above the threshold."""
        return self.predict_proba(probs_1, probs_2) > self.threshold

In [12]:
class VectorsDotPredictor(AbsPredictor):
    """Predicts by thresholding the dot product of two vectors.

    Args:
        threshold: ``predict`` returns true when the dot product is strictly
            above this value.
        normalize: When true, each vector is divided by its norm first.
        norm_ord: ``ord`` argument passed to ``np.linalg.norm`` for the
            normalization.
    """

    def __init__(self, threshold, normalize=True, norm_ord=2):
        self.threshold = threshold
        self.normalize = normalize
        self.norm_ord = norm_ord

    def predict(self, out_vector_1, out_vector_2):
        """Return whether the dot product is strictly above the threshold."""
        return self.predict_proba(out_vector_1, out_vector_2) > self.threshold

    def predict_proba(self, out_vector_1, out_vector_2):
        """Return the dot product of the (optionally normalized) vectors.

        Args:
            out_vector_1: First vector (any array-like of numbers).
            out_vector_2: Second vector, same length as the first.

        Note:
            A vector whose norm is zero produces NaN when ``normalize`` is on.
        """
        vec_1 = np.asarray(out_vector_1)
        vec_2 = np.asarray(out_vector_2)

        if self.normalize:
            # Out-of-place division: an in-place ``/=`` fails on integer
            # arrays and is not needed since the inputs must stay untouched.
            vec_1 = vec_1 / np.linalg.norm(vec_1, ord=self.norm_ord)
            vec_2 = vec_2 / np.linalg.norm(vec_2, ord=self.norm_ord)

        return np.sum(vec_1 * vec_2)

In [13]:
class VectorsDistPredictor(AbsPredictor):
    """Predicts by thresholding the distance between two vectors.

    Args:
        threshold: ``predict`` returns true when the distance is strictly
            below this value.
        normalize: When true, each vector is divided by its norm first.
        norm_ord: ``ord`` argument passed to ``np.linalg.norm``, used both for
            the normalization and for the distance itself.
    """

    def __init__(self, threshold, normalize=True, norm_ord=2):
        self.threshold = threshold
        self.normalize = normalize
        self.norm_ord = norm_ord

    def predict(self, out_vector_1, out_vector_2):
        """Return whether the distance is strictly below the threshold."""
        return self.predict_proba(out_vector_1, out_vector_2) < self.threshold

    def predict_proba(self, out_vector_1, out_vector_2):
        """Return the norm of the difference of the (optionally normalized) vectors.

        Args:
            out_vector_1: First vector (any array-like of numbers).
            out_vector_2: Second vector, same length as the first.

        Note:
            A vector whose norm is zero produces NaN when ``normalize`` is on.
        """
        vec_1 = np.asarray(out_vector_1)
        vec_2 = np.asarray(out_vector_2)

        if self.normalize:
            # Out-of-place division: an in-place ``/=`` fails on integer
            # arrays and is not needed since the inputs must stay untouched.
            vec_1 = vec_1 / np.linalg.norm(vec_1, ord=self.norm_ord)
            vec_2 = vec_2 / np.linalg.norm(vec_2, ord=self.norm_ord)

        return np.linalg.norm(vec_1 - vec_2, ord=self.norm_ord)

In [14]:
def get_probs_predictions(predictor, samples):
    """Run ``predictor.predict`` on the ``'probs1'`` / ``'probs2'`` of every sample.

    Returns:
        A list with one prediction per sample, in input order.
    """
    predictions = []
    for sample in samples:
        first, second = sample['probs1'], sample['probs2']
        predictions.append(predictor.predict(first, second))
    return predictions

def get_contexts_predictions(predictor, samples):
    """Run ``predictor.predict`` on the ``'context_output1'`` / ``'context_output2'`` of every sample.

    Returns:
        A list with one prediction per sample, in input order.
    """
    predictions = []
    for sample in samples:
        first, second = sample['context_output1'], sample['context_output2']
        predictions.append(predictor.predict(first, second))
    return predictions

In [15]:
def get_best_threshold(predictor_class, get_predictions, samples, y_true, thresholds, **args):
    """Pick the threshold that maximizes accuracy on ``samples``.

    Args:
        predictor_class: Callable invoked as
            ``predictor_class(threshold=..., **args)`` for every candidate.
        get_predictions: Callable ``(predictor, samples) -> predictions``.
        samples: Samples passed through to ``get_predictions``.
        y_true: Ground-truth labels compared with the predictions.
        thresholds: Indexable sequence of candidate thresholds.
        **args: Extra keyword arguments for ``predictor_class``.

    Returns:
        ``(best_accuracy, best_threshold)``; on ties the earliest candidate
        is returned.
    """
    def accuracy_at(threshold):
        candidate = predictor_class(threshold=threshold, **args)
        return accuracy_score(y_true, get_predictions(candidate, samples))

    scores = [accuracy_at(threshold) for threshold in thresholds]

    return max(scores), thresholds[np.argmax(scores)]

In [16]:
def save_vectors_dot(dev_samples, y_dev_true, test_samples, y_test_true, filename):
    """Evaluate the normalized (ord=2) dot-product predictor and append its test accuracy.

    The threshold is picked on the dev samples from 100 evenly spaced values
    in [0.6, 1]; the accuracy on the test samples is appended to ``filename``.
    """
    predictor_options = {'normalize': True, 'norm_ord': 2}
    candidate_thresholds = np.linspace(0.6, 1, 100)

    tuning_result = get_best_threshold(
        VectorsDotPredictor, get_contexts_predictions,
        dev_samples, y_dev_true, candidate_thresholds, **predictor_options)
    dev_threshold = tuning_result[1]

    tuned_predictor = VectorsDotPredictor(threshold=dev_threshold, **predictor_options)
    test_predictions = get_contexts_predictions(tuned_predictor, test_samples)
    test_score = accuracy_score(y_test_true, test_predictions)

    # Append mode: every call adds one more line to the file.
    with open(filename, 'a') as f:
        f.write(f'Dot Embs, p=2: {test_score}\n')

In [17]:
def save_vectors_dist_2(dev_samples, y_dev_true, test_samples, y_test_true, filename):
    """Evaluate the normalized ord=2 distance predictor and append its test accuracy.

    The threshold is picked on the dev samples from 100 evenly spaced values
    in [0, 1]; the accuracy on the test samples is appended to ``filename``.
    """
    predictor_options = {'normalize': True, 'norm_ord': 2}
    candidate_thresholds = np.linspace(0, 1, 100)

    tuning_result = get_best_threshold(
        VectorsDistPredictor, get_contexts_predictions,
        dev_samples, y_dev_true, candidate_thresholds, **predictor_options)
    dev_threshold = tuning_result[1]

    tuned_predictor = VectorsDistPredictor(threshold=dev_threshold, **predictor_options)
    test_predictions = get_contexts_predictions(tuned_predictor, test_samples)
    test_score = accuracy_score(y_test_true, test_predictions)

    # Append mode: every call adds one more line to the file.
    with open(filename, 'a') as f:
        f.write(f'Dist Embs, p=2: {test_score}\n')

In [18]:
def save_vectors_dist_1(dev_samples, y_dev_true, test_samples, y_test_true, filename):
    """Evaluate the normalized ord=1 distance predictor and append its test accuracy.

    The threshold is picked on the dev samples from 100 evenly spaced values
    in [0, 1]; the accuracy on the test samples is appended to ``filename``.
    """
    predictor_options = {'normalize': True, 'norm_ord': 1}
    candidate_thresholds = np.linspace(0, 1, 100)

    tuning_result = get_best_threshold(
        VectorsDistPredictor, get_contexts_predictions,
        dev_samples, y_dev_true, candidate_thresholds, **predictor_options)
    dev_threshold = tuning_result[1]

    tuned_predictor = VectorsDistPredictor(threshold=dev_threshold, **predictor_options)
    test_predictions = get_contexts_predictions(tuned_predictor, test_samples)
    test_score = accuracy_score(y_test_true, test_predictions)

    # Append mode: every call adds one more line to the file.
    with open(filename, 'a') as f:
        f.write(f'Dist Embs, p=1: {test_score}\n')

In [19]:
def save_dot_probs(dev_samples, y_dev_true, test_samples, y_test_true, filename):
    """Evaluate the probability dot-product predictor and append its test accuracy.

    The threshold is picked on the dev samples from 100 evenly spaced values
    in [0, 1]; the accuracy on the test samples is appended to ``filename``.
    """
    candidate_thresholds = np.linspace(0, 1, 100)

    tuning_result = get_best_threshold(
        ProbsDotPredictor, get_probs_predictions,
        dev_samples, y_dev_true, candidate_thresholds)
    dev_threshold = tuning_result[1]

    tuned_predictor = ProbsDotPredictor(threshold=dev_threshold)
    test_predictions = get_probs_predictions(tuned_predictor, test_samples)
    test_score = accuracy_score(y_test_true, test_predictions)

    # Append mode: every call adds one more line to the file.
    with open(filename, 'a') as f:
        f.write(f'Dot Probs: {test_score}\n')

In [20]:
def save_probs_jen_shen(dev_samples, y_dev_true, test_samples, y_test_true, filename):
    """Evaluate the unnormalized Jensen–Shannon predictor and append its test accuracy.

    The threshold is picked on the dev samples from 300 evenly spaced values
    in [0, 1]; the accuracy on the test samples is appended to ``filename``.
    """
    predictor_options = {'divergence': 'Jensen–Shannon', 'normalize': False}
    candidate_thresholds = np.linspace(0, 1, 300)

    tuning_result = get_best_threshold(
        DivergencePredictor, get_probs_predictions,
        dev_samples, y_dev_true, candidate_thresholds, **predictor_options)
    dev_threshold = tuning_result[1]

    tuned_predictor = DivergencePredictor(threshold=dev_threshold, **predictor_options)
    test_predictions = get_probs_predictions(tuned_predictor, test_samples)
    test_score = accuracy_score(y_test_true, test_predictions)

    # Append mode: every call adds one more line to the file.
    with open(filename, 'a') as f:
        f.write(f'Jen-Shen Probs: {test_score}\n')

In [21]:
# Language codes substituted into the data and prediction path templates.
langs = ['en', 'ru', 'fr', 'ar', 'zh']

In [22]:
# Number of epochs to evaluate; the loop below covers epochs 1..epochs_num.
epochs_num = 20
# Root directory of the score files, written as <base_dir>/epoch_<n>/<lang>.txt.
base_dir = 'data/BEM-WiC-epoch-scores/XLMR_Base'

In [23]:
# For every epoch and language: tune each predictor's threshold on the dev
# split, score it on the test split and append the result to the per-language
# score file. The save_* helpers open the file in append mode, so re-running
# this cell adds duplicate lines to existing files.
for epoch in tqdm(range(1, 1 + epochs_num)):
    for lang in langs:
        # Both splits are loaded with the default with_labels=True, so a
        # '.gold' file is read for the test split as well.
        dev_samples = construct_samples(data_dev_files.format(lang), pred_dev_filename.format(epoch, lang))
        test_samples = construct_samples(data_test_files.format(lang), pred_test_filename.format(epoch, lang))

        # Binary ground truth: True where the gold tag is 'T'.
        y_dev_true = [sample['tag'] == 'T' for sample in dev_samples]
        y_test_true = [sample['tag'] == 'T' for sample in test_samples]
        
        epoch_dir = os.path.join(base_dir, f'epoch_{epoch}')
        os.makedirs(epoch_dir, exist_ok=True)
        out_filename = os.path.join(epoch_dir, f'{lang}.txt')
        
        # The 'probs1'/'probs2'-based predictors are evaluated for English only.
        if lang == 'en':
            save_dot_probs(dev_samples, y_dev_true, test_samples, y_test_true, out_filename)
            save_probs_jen_shen(dev_samples, y_dev_true, test_samples, y_test_true, out_filename)
        
        # The context-vector predictors are evaluated for every language.
        save_vectors_dot(dev_samples, y_dev_true, test_samples, y_test_true, out_filename)
        save_vectors_dist_2(dev_samples, y_dev_true, test_samples, y_test_true, out_filename)
        save_vectors_dist_1(dev_samples, y_dev_true, test_samples, y_test_true, out_filename)

  0%|          | 0/20 [00:00<?, ?it/s]