In [1]:
# GLUE task processors (used below to load each task's dev examples and label list)
# and the per-task GLUE metric function.
from transformers import glue_processors
from transformers import glue_compute_metrics
from collections import Counter
from scipy.stats import norm
import numpy as np
# Seed NumPy's legacy global RNG so np.random draws in this notebook are reproducible
# on Restart & Run All.
np.random.seed(1337)

In [2]:
# Per-task label statistics on the GLUE dev split:
#   - classification tasks -> Counter of dev labels
#   - STS-B (float scores)  -> (mean, std) from a normal fit via scipy's norm.fit
counter_dict = {}
for task in ["CoLA", "MNLI", "MNLI-mm", "MRPC", "QNLI", "QQP", "RTE", "SST-2", "STS-B", "WNLI"]:
    # Both MNLI variants read from the same MNLI data directory.
    data_dir = "MNLI" if task.startswith("MNLI") else task
    processor = glue_processors[task.lower()]()
    eval_examples = processor.get_dev_examples(f"../data/glue/{data_dir}")
    if task != "STS-B":
        counter_dict[task] = Counter(example.label for example in eval_examples)
    else:
        counter_dict[task] = norm.fit([float(example.label) for example in eval_examples])

In [3]:
# Dev-set label counts per task; STS-B holds the (mean, std) of a normal fit to its scores.
counter_dict

{'CoLA': Counter({'1': 721, '0': 322}),
 'MNLI': Counter({'neutral': 3123, 'contradiction': 3213, 'entailment': 3479}),
 'MNLI-mm': Counter({'contradiction': 3240,
          'entailment': 3463,
          'neutral': 3129}),
 'MRPC': Counter({'1': 279, '0': 129}),
 'QNLI': Counter({'entailment': 2702, 'not_entailment': 2761}),
 'QQP': Counter({'0': 25545, '1': 14885}),
 'RTE': Counter({'not_entailment': 131, 'entailment': 146}),
 'SST-2': Counter({'1': 444, '0': 428}),
 'STS-B': (2.3639075555555555, 1.4999854042902065),
 'WNLI': Counter({'0': 40, '1': 31})}

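# Frequency baseline prediction

For each classification task, predict the most frequent dev-set label for every example. For STS-B, whose labels are real-valued similarity scores, keep the (mean, std) of the normal distribution fitted to the dev scores. Note that these statistics are computed on the same dev split that is used for evaluation. The baseline is therefore somewhat optimistic compared with one estimated from the training data.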

In [4]:
# Frequency baseline: every classification task predicts its single most common
# dev label; STS-B carries its fitted (mean, std) through unchanged so scores can
# be sampled from it at evaluation time.
freq_base_line_prediction = {
    task: stats if task == "STS-B" else stats.most_common(1)[0][0]
    for task, stats in counter_dict.items()
}

In [6]:
# Majority dev-set label per task; STS-B keeps its fitted (mean, std) pair.
freq_base_line_prediction

{'CoLA': '1',
 'MNLI': 'entailment',
 'MNLI-mm': 'entailment',
 'MRPC': '1',
 'QNLI': 'not_entailment',
 'QQP': '0',
 'RTE': 'entailment',
 'SST-2': '1',
 'STS-B': (2.3639075555555555, 1.4999854042902065),
 'WNLI': '0'}

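# Frequency baseline benchmark

Score the baseline predictions on each task's dev set with the GLUE metrics. STS-B predictions are sampled from the fitted normal distribution. The MNLI matched and mismatched results are grouped under a single `MNLI` entry.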

In [10]:
# Score the frequency baseline on every task's dev split with the GLUE metrics.
#
# STS-B predictions are random draws. They come from a cell-local RandomState seeded
# with the same value as the first cell. On a fresh kernel this yields exactly the
# draws np.random.normal produced right after np.random.seed(1337). Unlike the global
# RNG, it also gives identical results every time this cell is re-run on its own.
sts_b_rng = np.random.RandomState(1337)
freq_baseline = {}
for task, prediction in freq_base_line_prediction.items():
    processor = glue_processors[task.lower()]()
    # Both MNLI variants read from the same MNLI data directory.
    data_dir = "MNLI" if task.startswith("MNLI") else task
    eval_examples = processor.get_dev_examples(f"../data/glue/{data_dir}")
    if task == "STS-B":
        # `prediction` is the (mean, std) pair from norm.fit; sample one score per example.
        mean, std = prediction
        labels = np.array([float(example.label) for example in eval_examples])
        predictions = sts_b_rng.normal(mean, std, len(eval_examples))
    else:
        # Classification: map labels to indices and predict the majority label everywhere.
        label_list = processor.get_labels()
        labels = np.array([label_list.index(example.label) for example in eval_examples])
        predictions = np.full(len(labels), label_list.index(prediction))
    results = glue_compute_metrics(task.lower(), predictions, labels)
    if task.startswith("MNLI"):
        # Group matched / mismatched scores under one "MNLI" entry, prefixed by task name.
        mnli_results = freq_baseline.setdefault("MNLI", {})
        for metric, value in results.items():
            mnli_results[f"{task.lower()}_{metric}"] = value
    else:
        freq_baseline[task] = results


In [11]:
# GLUE dev-set metrics of the frequency baseline per task (MNLI splits grouped together).
freq_baseline

{'CoLA': {'mcc': 0.0},
 'MNLI': {'mnli_acc': 0.3544574630667346, 'mnli-mm_acc': 0.3522172497965826},
 'MRPC': {'acc': 0.6838235294117647,
  'f1': 0.8122270742358079,
  'acc_and_f1': 0.7480253018237863},
 'QNLI': {'acc': 0.5053999633900788},
 'QQP': {'acc': 0.6318327974276527,
  'f1': 0.0,
  'acc_and_f1': 0.3159163987138264},
 'RTE': {'acc': 0.5270758122743683},
 'SST-2': {'acc': 0.5091743119266054},
 'STS-B': {'pearson': 0.02425364185684005,
  'spearmanr': 0.020099238016388975,
  'corr': 0.022176439936614514},
 'WNLI': {'acc': 0.5633802816901409}}

In [12]:
import json
import os


def write_results(results, output_file_path):
    """Write ``results`` to ``output_file_path`` as pretty-printed, key-sorted JSON.

    Missing parent directories are created first, so callers do not need a separate
    (shell-specific) ``mkdir -p`` step. An existing file at the path is overwritten.

    Args:
        results: Any JSON-serializable object (here, the per-task metric dict).
        output_file_path: Destination file path (str or os.PathLike).
    """
    parent_dir = os.path.dirname(output_file_path)
    if parent_dir:  # a bare filename has no directory component to create
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_file_path, "w") as fp:
        json.dump(results, fp, indent=4, sort_keys=True)

In [13]:
# Create the output directory in pure Python (portable, unlike the shell `mkdir -p`),
# then save the benchmark results.
import os

freq_baseline_output_dir = "../experiments/freq_baseline"
os.makedirs(freq_baseline_output_dir, exist_ok=True)
write_results(freq_baseline, os.path.join(freq_baseline_output_dir, "results.json"))