In [51]:
import html
import json
from collections import Counter

from datasets import load_dataset
from IPython.display import HTML, display

In [68]:
# Load the three corpora in one pass, in the same order as before; e-SNLI is
# read through the loading script at a path relative to the working directory.
multinli, esnli, sick = (
    load_dataset(source)
    for source in ("multi_nli", "../datasets/esnli.py", "sick")
)

Found cached dataset multi_nli (/home/imger/.cache/huggingface/datasets/multi_nli/default/0.0.0/591f72eb6263d1ab527561777936b199b714cda156d35716881158a2bd144f39)
100%|██████████| 3/3 [00:00<00:00, 347.02it/s]
Found cached dataset esnli (/home/imger/.cache/huggingface/datasets/esnli/plain_text/0.0.2/262495ebbd9e71ec9b0c37a93e378f1b353dc28bb904305e011506792a02996b)
100%|██████████| 3/3 [00:00<00:00, 256.44it/s]
Found cached dataset sick (/home/imger/.cache/huggingface/datasets/sick/default/0.0.0/c6b3b0b44eb84b134851396d6d464e5cb8f026960519d640e087fe33472626db)
100%|██████████| 3/3 [00:00<00:00, 841.67it/s]


In [56]:
# Display names for the integer class ids shown in the label tables.
label_names = {
    0: "entailment",
    1: "neutral",
    2: "contradiction",
}


def describe_dataset(dataset):
    """Print the feature names, show one sample and render per-label counts.

    Args:
        dataset: Mapping from split name to an indexable, iterable collection
            of examples. Every example must be a mapping with a ``"label"``
            key.

    Returns:
        None. The feature names are printed; the first example of the first
        split (as pretty-printed JSON) and an HTML table are rendered through
        ``IPython.display``. The table has one row per split plus a final
        ``complete`` row, and one column per observed label plus ``sum``.

    Raises:
        IndexError: If ``dataset`` has no splits or its first split is empty.
    """
    splits = list(dataset.keys())
    sample = dataset[splits[0]][0]

    print("Features:")
    print(list(sample.keys()))

    print("Sample from dataset:", end="")
    # default=str keeps the preview working for values json cannot encode
    # natively; escaping keeps "<" and "&" in the text from becoming markup.
    sample_json = json.dumps(sample, sort_keys=True, indent=4, default=str)
    display(HTML(f"<pre>{html.escape(sample_json, quote=False)}</pre>"))

    # Label frequencies per split.
    stats = {
        split: Counter(example["label"] for example in dataset[split])
        for split in splits
    }

    # Fixed, sorted column order instead of raw set iteration order.
    observed = set().union(*stats.values())
    try:
        labels = sorted(observed)
    except TypeError:
        # Label values of mixed, non-comparable types.
        labels = sorted(observed, key=repr)

    def cell(tag, value):
        return f"<{tag}>{html.escape(str(value), quote=False)}</{tag}>"

    def row(name, counts):
        # Counter yields 0 for labels that do not occur in this split.
        values = [name, *(counts[label] for label in labels), sum(counts.values())]
        return "<tr>" + "".join(cell("td", value) for value in values) + "</tr>"

    # Label ids without a known name are shown as the raw id rather than
    # raising KeyError.
    titles = ["split", *(label_names.get(label, str(label)) for label in labels), "sum"]
    header = "<thead><tr>" + "".join(cell("th", title) for title in titles) + "</tr></thead>"

    overall = Counter()
    for counts in stats.values():
        overall.update(counts)

    body = "".join(row(split, counts) for split, counts in stats.items())
    body += row("complete", overall)

    display(HTML(f"<table>{header}<tbody>{body}</tbody></table>"))

In [69]:
# Feature names, one sample and per-split label counts for the MultiNLI data.
describe_dataset(multinli)

Features:
['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label']
Sample from dataset:

split,entailment,neutral,contradiction,sum
train,130899,130900,130903,392702
validation_matched,3479,3123,3213,9815
validation_mismatched,3463,3129,3240,9832
complete,137841,137152,137356,412349


In [70]:
# Feature names, one sample and per-split label counts for the e-SNLI data.
describe_dataset(esnli)

Features:
['premise', 'hypothesis', 'label', 'explanation_1', 'explanation_2', 'explanation_3', 'sentence1_highlighted_1', 'sentence2_highlighted_1', 'sentence1_highlighted_2', 'sentence2_highlighted_2', 'sentence1_highlighted_3', 'sentence2_highlighted_3']
Sample from dataset:

split,entailment,neutral,contradiction,sum
train,183416,182764,183187,549367
validation,3329,3235,3278,9842
test,3368,3219,3237,9824
complete,190113,189218,189702,569033


In [71]:
# Feature names, one sample and per-split label counts for the SICK data.
describe_dataset(sick)

Features:
['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset']
Sample from dataset:

split,entailment,neutral,contradiction,sum
train,1274,2524,641,4439
validation,143,281,71,495
test,1404,2790,712,4906
complete,2821,5595,1424,9840
