/
bigbench.py
79 lines (57 loc) · 2.5 KB
/
bigbench.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from datasets import load_dataset
def get_bb_dataset(split):
if split == "causal_judgement":
raw_dataset = load_dataset("tasksource/bigbench", "causal_judgment")
choices = ["Yes", "No"]
dataset = []
for split_ in ["validation", "train"]:
for dp in raw_dataset[split_]:
targets = dp["targets"]
assert len(targets) == 1
assert targets[0] in choices
dataset.append((dp["inputs"], targets[0]))
elif split == "web_of_lies":
raw_dataset = load_dataset("lighteval/big_bench_hard", "web_of_lies")
choices = ["Yes", "No"]
dataset = []
for dp in raw_dataset["train"]:
target = dp["target"]
assert target in choices
dataset.append((dp["input"], target))
elif split == "epistemic_reasoning":
raw_dataset = load_dataset("tasksource/bigbench", "epistemic_reasoning")
choices = ["entailment", "non-entailment"]
dataset = []
for split_ in ["validation", "train"]:
for dp in raw_dataset[split_]:
targets = dp["targets"]
assert len(targets) == 1
assert targets[0] in choices
dataset.append((dp["inputs"], targets[0]))
elif split == "epistemic_reasoning_y":
raw_dataset = load_dataset("tasksource/bigbench", "epistemic_reasoning")
choices = ["True", "False"]
dataset = []
for split_ in ["validation", "train"]:
for dp in raw_dataset[split_]:
targets = dp["targets"]
assert len(targets) == 1
assert targets[0] in choices
assert dp["inputs"].endswith("Relation:")
text = dp["inputs"][:-len("Relation:")] + \
"Does the premise entails the hypothesis, True or False? Answer is"
dataset.append((text, targets[0]))
elif split == "qa_wikidata":
raw_dataset = load_dataset("tasksource/bigbench", "qa_wikidata")
dataset = []
choices = None
for split_ in ["validation", "train"]:
for dp in raw_dataset[split_]:
targets = dp["targets"]
if len(targets) == 1:
dataset.append((dp["inputs"], targets[0]))
else:
raise AssertionError(f"Unhandled split {split}.")
if choices is not None:
assert len(set(choices)) == len(choices), f"Found duplicates in {choices}"
return dataset, choices