In [1]:
from collections import Counter
from datasets import load_dataset, ClassLabel

In [2]:
def binarize_mnli(dataset, remove_neutral=True):
    """Convert three-way MNLI labels into a binary entailment/contradiction task.

    Per the label ids used in this function, 1 is the neutral class and 2 is
    the contradiction class; label 0 keeps its id and is named 'entailment'
    in the resulting schema.

    Args:
        dataset: A split-keyed dataset dict (as returned by ``load_dataset``)
            exposing ``filter``, ``map`` and ``cast`` and whose splits carry a
            ``"label"`` column.
        remove_neutral: If True, examples labelled neutral (1) are dropped
            before relabelling. If False, they are kept and, because they
            already carry label 1, end up sharing the 'contradiction' class.

    Returns:
        The transformed dataset whose ``"label"`` feature is a two-class
        ``ClassLabel`` with names ``['entailment', 'contradiction']``.
    """
    if remove_neutral:
        # neutral class has label 1
        dataset = dataset.filter(lambda example: example["label"] != 1)

    def change_label(example):
        # Re-index contradiction from 2 to 1 so the label ids are contiguous.
        # Neutral and contradiction are only merged when neutral examples were
        # not filtered out above.
        example["label"] = 1 if example["label"] == 2 else example["label"]
        return example

    # change labels
    dataset = dataset.map(change_label)

    # Take the schema from "train" when it exists (the original behaviour);
    # otherwise fall back to the first available split so that dataset dicts
    # without a "train" split can be binarized too.
    if "train" in dataset:
        reference_split = dataset["train"]
    else:
        reference_split = next(iter(dataset.values()))

    # change features to reflect the new labels
    features = reference_split.features.copy()
    features["label"] = ClassLabel(num_classes=2, names=['entailment', 'contradiction'], id=None)
    dataset = dataset.cast(features)  # overwrite old features

    return dataset

In [3]:
def load_paws_qqp_dataset(path="/llmft/data/paws_qqp/dev_and_test.tsv"):
    """Load the PAWS-QQP TSV file as a single cleaned "validation" dataset.

    Args:
        path: Location of the tab-separated file to read.

    Returns:
        The "validation" split with the first two characters and the last
        character removed from ``sentence1`` and ``sentence2``, and with the
        ``id`` column renamed to ``idx``.
    """
    def _strip_bytes_wrapper(row):
        # The file was written from byte strings, so each sentence is wrapped
        # in a two-character prefix and a one-character suffix; slice them off.
        for column in ("sentence1", "sentence2"):
            row[column] = row[column][2:-1]
        return row

    validation = load_dataset("csv", data_files={"validation": path}, sep="\t")["validation"]
    validation = validation.map(_strip_bytes_wrapper, batched=False)

    return validation.rename_column("id", "idx")

In [4]:
def load_cola_ood_dataset(path="/llmft/data/cola_ood/dev.tsv", label=None, cache_dir=None):
    """Load the CoLA OOD TSV file and return it as a single dataset.

    Args:
        path: Location of the tab-separated file to read.
        label: Accepted for interface compatibility; not used by this function.
        cache_dir: Forwarded to ``load_dataset`` as its cache directory.

    Returns:
        The "validation" split, with columns named
        ``code``, ``label``, ``annotation`` and ``sentence``.
    """
    # Column names are supplied explicitly rather than read from the file.
    columns = ['code', 'label', 'annotation', 'sentence']
    loaded = load_dataset(
        "csv",
        data_files={"validation": path},
        sep="\t",
        column_names=columns,
        cache_dir=cache_dir,
    )
    return loaded["validation"]

In [7]:
# Name of the GLUE task to load below. Exactly one assignment should be
# active; the commented-out lines are the alternative choices.
# task_name = "rte"
task_name = "mnli"
# task_name = "qqp"
# task_name = "cola"

In [12]:
# Load every split of the selected GLUE task.
dataset = load_dataset("glue", task_name)

In [13]:
# Only MNLI is relabelled; for every other task the dataset is left unchanged.
if task_name == "mnli":
    dataset = binarize_mnli(dataset, remove_neutral=True) # mnli
    # Alternative: keep neutral examples (they are folded into label 1).
    # dataset = binarize_mnli(dataset, remove_neutral=False) # mnli-original

In [14]:
# Show the dataset's repr as the cell output.
dataset

In [15]:
print("task_name:", task_name)
# MNLI has no plain "validation" split (it uses "validation_matched"), whereas
# the other tasks selectable in this notebook do. Choosing the split name from
# task_name keeps this cell runnable without hand-editing the loop.
validation_split = "validation_matched" if task_name == "mnli" else "validation"
for split in ["train", validation_split]:
    # Number of examples per label id in this split.
    label_counts = Counter(dataset[split]["label"])
    # Sum the per-class counts directly rather than materialising every
    # element just to measure the length.
    total = sum(label_counts.values())
    print("Total number of sampels:", total)
    print(split)
    for label, count in label_counts.items():
        print(f"fraction of labels per class: {label}={count / total}")

In [None]:
# NOTE: this rebinds `dataset`, replacing the GLUE data loaded earlier with a
# single split read from a local TSV file. Enable exactly one of the loaders.
# dataset = load_paws_qqp_dataset()
dataset = load_cola_ood_dataset()

In [None]:
# Number of examples per label value in the currently loaded dataset.
label_counts = Counter(dataset["label"])
# Sum the per-class counts directly rather than materialising every element
# just to measure the length.
total = sum(label_counts.values())
print("Total number of sampels:", total)
for label, count in label_counts.items():
    print(f"fraction of labels per class: {label}={count / total}")