In [104]:

import json
import csv
import torch
from torch.utils.data import Dataset,DataLoader

from loguru import logger
import pytorch_lightning as pl

from torchnlp.utils import collate_tensors, lengths_to_mask

from transformers import AutoTokenizer



def load_mednli(datadir='../data/mednli/'):
    filenames = [
        'mli_train_v1.jsonl',
        'mli_dev_v1.jsonl',
        'mli_test_v1.jsonl',
    ]

    filenames = [datadir+f  for f in filenames]

    mednli_train, mednli_dev, mednli_test = [read_mednli(f) for f in filenames]

    return mednli_train, mednli_dev, mednli_test


def read_mednli(filename) -> list:
    data = []

    with open(filename, 'r') as f:
        for line in f:
            example = json.loads(line)

            premise = (example['sentence1'])
            hypothesis = (example['sentence2'])
            label = example.get('gold_label', None)
            data.append((premise,hypothesis,label))

    print(f'MedNLI file loaded: {filename}, {len(data)} examples')
    return data


class MedNLIDataset(torch.utils.data.Dataset):
    LABEL_TO_ID = {'contradiction': 0, 'entailment': 1, 'neutral': 2}

    def __init__(self, hparams,mednli_data):
        self.premises, self.hypotheses, labels = zip(*mednli_data)

        self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')#hparams.encoder_model)
        self.labels = [MedNLIDataset.LABEL_TO_ID[l] if l is not None else -1 for l in labels]

    def __getitem__(self, index):
        premise = self.premises[index]
        hypothesis = self.hypotheses[index]
        label = self.labels[index]
        encoded_inputs = self.tokenizer(premise, hypothesis, truncation=True)
        return encoded_inputs, label

    def __len__(self):
        return len(self.labels)


class MedNLIDataModule(pl.LightningDataModule):
        def __init__(self, hparams=None):
            super().__init__()
            self.hparams = hparams
            # if self.hparams.transformer_type == 'longformer':
            #     self.hparams.batch_size = 1


        def setup(self, stage=None):
            mednli_train, mednli_dev, mednli_test = load_mednli()
            self.train_dataset, self.val_dataset, self.test_dataset = MedNLIDataset(self.hparams,mednli_train),MedNLIDataset(self.hparams,mednli_dev),MedNLIDataset(self.hparams,mednli_test)
            logger.info('MedNLI JSONs loaded...')

        def train_dataloader(self) -> DataLoader:
            logger.warning('Loading training data...')
            return DataLoader(
                dataset=self.train_dataset,
                shuffle=True,
                batch_size=1,
            )
        def val_dataloader(self) -> DataLoader:
            logger.warning('Loading validation data...')
            return DataLoader(
                dataset=self.val_dataset,
                shuffle= False,
                batch_size= 1,
            )
        def test_dataloader(self) -> DataLoader:
            logger.warning('Loading testing data...')
            return DataLoader(
                dataset=self.test_dataset,
                shuffle= False,
                batch_size=1,
            )


In [105]:
dm = MedNLIDataModule(None)
dm.setup()
test_dl = dm.test_dataloader()
test_dl

MedNLI file loaded: ../data/mednli/mli_train_v1.jsonl, 11232 examples
MedNLI file loaded: ../data/mednli/mli_dev_v1.jsonl, 1395 examples
MedNLI file loaded: ../data/mednli/mli_test_v1.jsonl, 1422 examples
2020-11-27 10:27:16.958 | INFO     | __main__:setup:76 - MedNLI JSONs loaded...


<torch.utils.data.dataloader.DataLoader at 0x7fdbefd74090>

In [109]:
batch = next(iter(test_dl))
input, target = batch


In [112]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
model = AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased ')


TypeError: __init__() takes 1 positional argument but 2 were given

In [44]:
preprocess_function(trn[:5])

{'input_ids': [[101, 13625, 2020, 3862, 2005, 13675, 1015, 1012, 1021, 1006, 26163, 1014, 1012, 1019, 2566, 2214, 2636, 1007, 1998, 18749, 12259, 1016, 1012, 1018, 1012, 102, 5776, 2038, 8319, 13675, 102], [101, 13625, 2020, 3862, 2005, 13675, 1015, 1012, 1021, 1006, 26163, 1014, 1012, 1019, 2566, 2214, 2636, 1007, 1998, 18749, 12259, 1016, 1012, 1018, 1012, 102, 5776, 2038, 3671, 13675, 102], [101, 13625, 2020, 3862, 2005, 13675, 1015, 1012, 1021, 1006, 26163, 1014, 1012, 1019, 2566, 2214, 2636, 1007, 1998, 18749, 12259, 1016, 1012, 1018, 1012, 102, 5776, 2038, 8319, 21122, 102], [101, 6396, 9153, 21693, 2271, 1998, 1056, 12414, 2075, 1997, 1054, 2849, 2001, 3264, 1012, 102, 1996, 5776, 2018, 19470, 11265, 10976, 11360, 1012, 102], [101, 6396, 9153, 21693, 2271, 1998, 1056, 12414, 2075, 1997, 1054, 2849, 2001, 3264, 1012, 102, 1996, 5776, 2038, 1037, 3671, 11265, 10976, 11360, 1012, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [45]:
encoded_dataset = test.map(preprocess_function, batched=True)
encoded_dataset

AttributeError: 'list' object has no attribute 'map'