import os
from pathlib import Path
from multiprocessing.dummy import Pool as ThreadPool  # multithreading for IO operations
from multiprocessing import cpu_count
from typing import Callable, List, Optional, Dict, Tuple
from pprint import pprint

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, SequentialSampler
import pytorch_lightning as pl
from omegaconf import DictConfig
import datasets
import hydra

from claficle.data.process import process_dataset

def default_collate_fn(batch, **kwargs) -> List[Dict]:
    """No-op default pre-collation: returns the batch unchanged."""
    return batch

class BenchmarkDataModule(pl.LightningDataModule):
"""
PL DataModule responsible for dataloaders for various datasets used
for our multitask benchmarking.
Note: run `set_tokenizer(tokenizer)` before asking for a dataloader
PS: you may also wish to run `set_pre_collate_fn(fn)`
to apply any pre-collation processing
"""
def __init__(self, config: DictConfig, lang: str):
super().__init__()
self.cfg = config
self.lang = lang
self.raw_save_path: str = os.path.join(self.cfg.data_dir, "raw")
self._metadata = {"lang": self.lang, "datasets": []}
self._pre_collate_fn: Callable[
..., List[Dict]
] = default_collate_fn # default no-op (can be set)
self.is_setup = False
pl.seed_everything(self.cfg.seed)

    def prepare_data(self):
        """Takes care of downloading the raw datasets."""
        if self.is_setup:
            return
        print("Loading datasets...")
        # download in parallel: multithreading speeds up these IO-bound calls
        with ThreadPool(cpu_count()) as thread_pool:
            thread_pool.map(self._download_dataset, self.cfg.dataset_names)

    def setup(self, stage: Optional[str] = None):
        """
        Processes each dataset, obtaining its test split and relevant metric(s).
        """
if self.is_setup:
return
print("Processing datasets...")
self._processed_datasets = []
for dataset_name in self.cfg.dataset_names:
dataset = self._load_raw_dataset(dataset_name)
try:
test_dataset, metrics, num_classes, collection_name = process_dataset(
dataset, self.lang, self.cfg, dataset_name
)
                # test_dataset is None if the dataset is unavailable in this language
                if test_dataset is not None:
                    # map dataset idx to name & metrics so they can be
                    # tracked in the LightningModule
self._metadata["datasets"].append(
{
"name": collection_name,
"metrics": metrics,
"num_classes": num_classes,
}
)
self._processed_datasets.append(test_dataset)
except Exception as e:
# skip to the next dataset if there's an error
print(f"Error processing dataset {dataset_name}: {e}")
continue
print("Done.")
self.is_setup = True

    def get_metadata(self):
        """Returns metadata: language plus per-dataset name, metrics and num_classes."""
        return self._metadata

    def collate_fn(self, batch: List[Dict]) -> Tuple[Tensor, Tensor, Tensor]:
        """
        For each input, encodes it and concatenates it with each of the
        available (encoded) options. Padding is applied in the process, and
        token_type_ids (0 is input, 1 is option, 2 is padding) are tracked
        throughout. Labels are batched into a LongTensor.

        Returns a Tuple of (input_ids, token_type_ids, labels),
        with dimensions ((B x O x S), (B x O x S), (B,)),
        where B is the batch size, O is the number of options and S is the
        max sequence length in the batch.
        """
# apply any pre-collation processing first
pre_collate_kwargs = {"src_lang": self.lang, "separator": self.cfg.separator}
proc_batch: List[Dict] = self._pre_collate_fn(batch=batch, **pre_collate_kwargs)
# batch encode the inputs
input_encodings = self.tokenizer(
[x["input"] for x in proc_batch], truncation=True
)["input_ids"]
        # we then go through the batch to concatenate each option to its input
batch_concats = []
batch_tok_type_ids = []
batch_labels = []
for input_encoding, item in zip(input_encodings, proc_batch):
batch_labels.append(item["label"])
input_tok_type_ids = [0 for _ in input_encoding]
# encode each option, prefixed by separator
option_encodings = self.tokenizer(
[self.cfg.separator + option for option in item["options"]],
truncation=True,
)["input_ids"]
# we then concatenate each option to our current input encoding
tok_type_ids = []
concat_encodings = []
for option_encoding in option_encodings:
concatenated = input_encoding + option_encoding
tok_type_id = input_tok_type_ids + [1 for _ in option_encoding]
                # truncate from the left, if necessary, so that the most
                # recent tokens are kept
concatenated = concatenated[-self.max_seq_length :]
tok_type_id = tok_type_id[-self.max_seq_length :]
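                # e.g. with max_seq_length=4: [5, 6, 7, 8, 9][-4:] -> [6, 7, 8, 9]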
# and add to options
concat_encodings.append(torch.LongTensor(concatenated))
tok_type_ids.append(torch.LongTensor(tok_type_id))
# here we are padding across options
concat_encodings = pad_sequence(
concat_encodings, batch_first=False, padding_value=self.pad_token_id
)
tok_type_ids = pad_sequence(
tok_type_ids, batch_first=False, padding_value=2
)
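            # per-item shapes here: (S_i, O), where S_i is this item's longest
            # option-concatenated length and O is the number of options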
batch_concats.append(concat_encodings)
batch_tok_type_ids.append(tok_type_ids)
# here we pad across the batch
batch_concats = (
pad_sequence(
batch_concats, batch_first=True, padding_value=self.pad_token_id
)
.permute(0, 2, 1)
.contiguous()
)
batch_tok_type_ids = (
pad_sequence(batch_tok_type_ids, batch_first=True, padding_value=2)
.permute(0, 2, 1)
.contiguous()
)
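        # final shapes: (B, O, S) for both input_ids and token_type_ids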
return batch_concats, batch_tok_type_ids, torch.LongTensor(batch_labels)

    def test_dataloader(self) -> List[DataLoader]:
        """
        Returns a test dataloader for each processed dataset, in the same
        order as `get_metadata()["datasets"]`.
        """
return [
DataLoader(
dataset,
batch_size=self.cfg.batch_size,
sampler=SequentialSampler(dataset),
collate_fn=self.collate_fn,
num_workers=self.cfg.num_workers,
pin_memory=True,
)
for dataset in self._processed_datasets
]

    def _load_raw_dataset(self, dataset_name: str):
        """Loads a dataset from disk if cached there, downloading and caching it otherwise."""
        # parse the dataset name
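        # e.g. "xnli;fr" (hypothetical name) -> collection="xnli", subcollection="fr"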
if ";" in dataset_name:
collection: str
subcollection: Optional[str]
collection, subcollection = dataset_name.split(";")
else:
collection = dataset_name
subcollection = None
dataset_path = os.path.join(
self.raw_save_path,
collection,
subcollection if subcollection is not None else "",
)
if os.path.exists(dataset_path):
dataset = datasets.load_from_disk(dataset_path)
else:
dataset = datasets.load_dataset(collection, subcollection)
# create save directory if it doesn't exist, and save to disk
Path(dataset_path).mkdir(parents=True, exist_ok=True)
dataset.save_to_disk(dataset_path)
return dataset

    def _download_dataset(self, dataset_name: str):
        """Downloads a huggingface dataset (loading it caches it to disk)."""
        self._load_raw_dataset(dataset_name)
        return

    def set_pre_collate_fn(self, pre_collate_fn: Callable[..., List[Dict]]):
        """
        Sets a pre-collate processing function, which is applied to each batch
        before collation. It is called with keyword arguments `batch`,
        `src_lang` and `separator` (see `collate_fn`).
        """
self._pre_collate_fn = pre_collate_fn
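        # Example (sketch; `lowercase_inputs` is hypothetical):
        #   def lowercase_inputs(batch, **kwargs):
        #       return [{**x, "input": x["input"].lower()} for x in batch]
        #   benchmark.set_pre_collate_fn(lowercase_inputs)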

    def set_tokenizer(self, tokenizer):
        """Sets the tokenizer, using left truncation and the EOS token for padding."""
self.tokenizer = tokenizer
self.tokenizer.truncation_side = "left"
# see https://discuss.huggingface.co/t/batch-generation-with-gpt2/1517/2
self.tokenizer.pad_token = self.tokenizer.eos_token
self.pad_token_id = tokenizer.convert_tokens_to_ids(
tokenizer.special_tokens_map["pad_token"]
)
self.max_seq_length = min(1024, tokenizer.model_max_length)


@hydra.main(version_base=None, config_path="../conf", config_name="setup_data")
def main(cfg: DictConfig):
    """
    Downloads and processes the benchmark data for each of the available languages.
    When calling from the CLI, pass `data=benchmark`.
    """
from omegaconf import OmegaConf
import wandb
cfg.data.seed = cfg.seed
print(OmegaConf.to_yaml(cfg))
script_host = "slurm" if "SLURM_JOB_ID" in os.environ else "local"
wandb.init(
project="claficle",
entity="giulio-uva",
job_type="benchmark",
config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
mode="disabled" if cfg.disable_wandb else "online",
group=script_host,
)
benchmark = BenchmarkDataModule(cfg.data, cfg.lang)
benchmark.prepare_data()
benchmark.setup()
pprint(benchmark.get_metadata())


if __name__ == "__main__":
main()