Datasets
Prajjwal Bhargava edited this page Oct 10, 2020 · 4 revisions
Create a standard GlueDataset from transformers:

from transformers import AutoTokenizer, GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments

data_args = DataTrainingArguments(task_name='MNLI', data_dir='/path/to/mnli')
tokenizer = AutoTokenizer.from_pretrained('albert-base-v2')
dataset = GlueDataset(data_args, tokenizer)
Then wrap it inside the MetaDataset class:

from fluence.datasets import MetaDataset
metadataset = MetaDataset(dataset)
Iterating over it now returns a Python dict containing samples from all the classes, which is suitable for few-shot learning tasks. Note that it does not require a non-default collate_fn; it works with PyTorch's default collate_fn. Simply iterate and feed it into the model. Here's how:
from torch.utils.data import DataLoader

dataloader = DataLoader(metadataset, batch_size=8, drop_last=True)
loss = 0.
for idx, batch in enumerate(dataloader):
    for class_sample in batch:  # one entry per class; for MNLI this inner loop runs 3 times
        loss += model(**class_sample)[0]  # the Transformer model returns a tuple whose first element is the loss
        print(class_sample['input_ids'].shape)  # [8, 128]
    break
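For intuition, here is a minimal sketch of the idea behind such a class-grouping wrapper. This is an illustration only, not fluence's actual MetaDataset implementation: it groups a labeled dataset by class so that each index yields one sample from every class, which is the shape few-shot episodes need.

```python
from collections import defaultdict

class SimpleMetaDataset:
    """Illustrative sketch (not fluence's real code): group samples
    by label so each index returns one sample per class."""

    def __init__(self, dataset):
        # dataset: an iterable of dicts, each carrying a 'labels' key
        by_class = defaultdict(list)
        for sample in dataset:
            by_class[sample["labels"]].append(sample)
        self.by_class = by_class
        # usable length is limited by the smallest class
        self.length = min(len(v) for v in by_class.values())

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # one sample drawn from every class at each index
        return [samples[idx] for samples in self.by_class.values()]

# toy labeled dataset with 3 classes, mimicking MNLI's label count
toy = [{"labels": i % 3, "x": i} for i in range(9)]
meta = SimpleMetaDataset(toy)
print(len(meta))     # 3 (each class has 3 samples)
print(len(meta[0]))  # 3 (one sample from each class)
```

Because `__getitem__` returns a fixed-size list of dicts with identical keys, PyTorch's default collate_fn can batch it without any customization, matching the behavior described above.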