Datasets

Prajjwal Bhargava edited this page Oct 10, 2020 · 4 revisions

Using a MetaDataset

Create a standard GlueDataset from transformers:

from transformers import AutoTokenizer, GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments

data_args = DataTrainingArguments(task_name='MNLI', data_dir='/path/to/mnli')
tokenizer = AutoTokenizer.from_pretrained('albert-base-v2')
dataset = GlueDataset(data_args, tokenizer)

Then wrap it in the MetaDataset class:

from fluence.datasets import MetaDataset
metadataset = MetaDataset(dataset)
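To make the idea concrete, here is an illustrative, self-contained stand-in for what such a wrapper does, assuming it groups the underlying dataset's samples by label so that each item carries one sample per class. ToyMetaDataset and its fields are hypothetical, not fluence's actual implementation:

```python
# Illustrative sketch only: a minimal per-class grouping dataset,
# NOT fluence's actual MetaDataset code.
import torch
from torch.utils.data import Dataset, DataLoader

class ToyMetaDataset(Dataset):
    def __init__(self, samples):
        # samples: list of dicts with 'input_ids' and 'labels' tensors
        self.by_class = {}
        for s in samples:
            self.by_class.setdefault(int(s['labels']), []).append(s)

    def __len__(self):
        # limited by the smallest class
        return min(len(v) for v in self.by_class.values())

    def __getitem__(self, idx):
        # one sample per class; PyTorch's default collate_fn handles
        # this nested structure without any custom code
        return [self.by_class[c][idx] for c in sorted(self.by_class)]

# 3 classes (as in MNLI), 5 samples each, fake token ids of length 6
samples = [{'input_ids': torch.randint(0, 100, (6,)),
            'labels': torch.tensor(c)}
           for c in range(3) for _ in range(5)]
loader = DataLoader(ToyMetaDataset(samples), batch_size=4, drop_last=True)
batch = next(iter(loader))
print(len(batch), batch[0]['input_ids'].shape)  # 3 torch.Size([4, 6])
```

Note how the default collate_fn transposes the nested lists automatically: the batch comes back as one collated dict per class, each with stacked tensors.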

Now each item contains samples from all the classes, which is suitable for few-shot learning tasks. Note that it does not require a custom collate_fn; PyTorch's default collate_fn works as-is. Simply iterate over a DataLoader and feed the per-class samples into the model. Here's how:

from torch.utils.data import DataLoader

dataloader = DataLoader(metadataset, batch_size=8, drop_last=True)
loss = 0.
for idx, batch in enumerate(dataloader):
    for class_sample in batch:                 # one entry per class; for MNLI this loop runs 3 times
        loss += model(**class_sample)[0]       # transformer models return a tuple; [0] is the loss
        print(class_sample['input_ids'].shape) # torch.Size([8, 128])
    break
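The accumulation pattern above can be sketched end-to-end with a dummy linear classifier standing in for the transformer. The model, feature sizes, and simulated batch below are made up for illustration; only the loop structure mirrors the snippet above:

```python
# Hypothetical toy version of the training step: a linear classifier in
# place of the transformer, and one simulated dataloader batch holding
# one dict per class (3, as for MNLI).
import torch
from torch import nn

model = nn.Linear(8, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

batch = [{'features': torch.randn(4, 8),
          'labels': torch.full((4,), c, dtype=torch.long)}
         for c in range(3)]

loss = 0.
for class_sample in batch:                       # one iteration per class
    logits = model(class_sample['features'])
    loss = loss + nn.functional.cross_entropy(logits,
                                              class_sample['labels'])

optimizer.zero_grad()
loss.backward()                                  # single backward over the summed per-class loss
optimizer.step()
print(float(loss))
```

Summing the per-class losses before calling backward() gives one gradient step over all classes in the episode, which is the usual shape of a few-shot update.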