# Test Accuracy of Fine-tuned Models on Each Dataset

In [1]:
import pandas as pd
from utils import calculate_test_accuracy
import pickle
import numpy as np

from scipy.stats import pearsonr

## 1. IMDB (Large Movie Review Dataset)
## 2. Twitter SemEval
## 3. Twitter Sentiment140

In [5]:
def load_pickle(fpath):
    """Deserialize and return the object stored in the pickle file at ``fpath``.

    Warning: ``pickle.load`` can execute arbitrary code while unpickling, so
    only call this on files produced by a trusted source.
    """
    with open(fpath, mode="rb") as handle:
        return pickle.load(handle)

def calculate_pearson_correlation(task, model, asset_dir="../../asset"):
    """Return the Pearson correlation between gold test labels and a model's saved predictions.

    Args:
        task: Dataset directory name under ``asset_dir`` (e.g. ``"twitter_semeval"``).
        model: Model identifier; predictions are read from
            ``{asset_dir}/{task}/predictions/{model}.pkl``.
        asset_dir: Root directory holding the per-task assets. Defaults to the
            relative path ``"../../asset"`` this notebook has always used.

    Returns:
        The result of ``scipy.stats.pearsonr``: index ``[0]`` is the
        correlation coefficient and index ``[1]`` is the p-value.

    Raises:
        ValueError: if the number of saved predictions differs from the number
            of test labels.
    """
    label_path = f"{asset_dir}/{task}/test.csv"
    pred_path = f"{asset_dir}/{task}/predictions/{model}.pkl"

    # test.csv is tab-separated with no header row; column 0 holds the gold label.
    test_df = pd.read_csv(label_path, header=None, sep="\t")
    test_labels = test_df[0].to_numpy()

    # NOTE: unpickles a local asset file; do not point this at untrusted data.
    predictions = load_pickle(pred_path)

    # pearsonr would also reject mismatched lengths, but its message does not
    # say which prediction file is stale or truncated.
    if len(predictions) != len(test_labels):
        raise ValueError(
            f"{pred_path}: {len(predictions)} predictions for "
            f"{len(test_labels)} test labels in {label_path}"
        )

    return pearsonr(test_labels, predictions)

# Model identifiers to evaluate. Each one also names its saved prediction file
# (predictions/<model>.pkl) inside every task directory.
models = [
    "bert-base-uncased",
    "bert-base-cased",
    "roberta-base",
    "xlnet-base-cased",
    "albert-base-v2",
    "microsoft/mpnet-base",
    "microsoft/deberta-base",
    "facebook/muppet-roberta-base",
    "google/electra-base-generator",
]

# Datasets the models were fine-tuned on; also the task directory names.
tasks = ["imdb", "twitter_semeval", "twitter_s140"]


# Tasks that additionally report the Pearson correlation between a model's
# predictions and the gold test labels (via calculate_pearson_correlation).
PEARSON_TASKS = ("twitter_semeval",)

# One result list per task, derived from `tasks` so the two cannot drift apart.
# Lists are filled in `models` order so they line up with the "model" column.
accuracies = {task: [] for task in tasks}
pearsoncors = {task: [] for task in tasks if task in PEARSON_TASKS}

for model in models:
    for task in tasks:
        accuracies[task].append(calculate_test_accuracy(task, model))

        if task in pearsoncors:
            # Index 0 of the pearsonr result is the correlation coefficient
            # (index 1 is the p-value, which is not reported here).
            pearsoncors[task].append(calculate_pearson_correlation(task, model)[0])

# Column layout: "model", then for each task its accuracy followed by its
# Pearson correlation (if reported). Task keys use "_", column names use "-".
columns = {"model": models}
for task in tasks:
    column_suffix = task.replace("_", "-")
    columns[f"accuracy-{column_suffix}"] = accuracies[task]
    if task in pearsoncors:
        columns[f"pearson-corr-{column_suffix}"] = pearsoncors[task]

df = pd.DataFrame(data=columns)

df


Unnamed: 0,model,accuracy-imdb,accuracy-twitter-semeval,pearson-corr-twitter-semeval,accuracy-twitter-s140
0,bert-base-uncased,92.57,86.92,0.734651,82.76
1,bert-base-cased,89.12,85.76,0.707453,80.71
2,roberta-base,92.84,88.37,0.768773,83.64
3,xlnet-base-cased,93.88,87.21,0.737494,81.86
4,albert-base-v2,89.45,85.17,0.701001,83.35
5,microsoft/mpnet-base,93.62,90.7,0.811291,82.28
6,microsoft/deberta-base,93.42,90.99,0.815184,83.83
7,facebook/muppet-roberta-base,95.29,90.7,0.809335,81.67
8,google/electra-base-generator,91.2,88.66,0.767157,79.87
