In [38]:
import csv
import json
import os

import torch
from torch.utils.data import Dataset,DataLoader

from loguru import logger
import pytorch_lightning as pl

from torchnlp.utils import collate_tensors, lengths_to_mask

from transformers import AutoTokenizer

def load_mednli(datadir='../data/mednli/'):
    """Load the MedNLI train, dev and test splits found in ``datadir``.

    Args:
        datadir: Directory containing ``mli_train_v1.jsonl``,
            ``mli_dev_v1.jsonl`` and ``mli_test_v1.jsonl``. May be a string
            (with or without a trailing separator) or a path-like object.

    Returns:
        A ``(train, dev, test)`` tuple; each element is whatever
        ``read_mednli`` returns for the corresponding file.
    """
    split_files = (
        'mli_train_v1.jsonl',
        'mli_dev_v1.jsonl',
        'mli_test_v1.jsonl',
    )

    # os.path.join only inserts a separator when one is missing, so the
    # default directory still yields exactly the same paths as before, while
    # a directory given without a trailing slash no longer produces
    # '<dir>mli_train_v1.jsonl'.
    mednli_train, mednli_dev, mednli_test = [
        read_mednli(os.path.join(datadir, name)) for name in split_files
    ]

    return mednli_train, mednli_dev, mednli_test


def read_mednli(filename) -> list:
    """Read one MedNLI JSON-lines file into a list of example tuples.

    Args:
        filename: Path of a file holding one JSON object per line.

    Returns:
        A list of ``(premise, hypothesis, label)`` tuples in file order, where
        premise comes from ``'sentence1'``, hypothesis from ``'sentence2'``
        and label from ``'gold_label'`` (``None`` when that key is absent).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an example lacks ``'sentence1'`` or ``'sentence2'``.
    """
    data = []

    # Decode explicitly as UTF-8: open()'s default encoding depends on the
    # platform locale, whereas JSON text is UTF-8.
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file),
            # which json.loads would otherwise reject.
            if not line.strip():
                continue
            example = json.loads(line)
            data.append((
                example['sentence1'],
                example['sentence2'],
                example.get('gold_label'),
            ))

    print(f'MedNLI file loaded: {filename}, {len(data)} examples')
    return data

# Checkpoint loaded on demand when no tokenizer is supplied. The name is taken
# from the previously commented-out
# `AutoTokenizer.from_pretrained('bert-base-uncased')` line.
# NOTE(review): confirm this is the checkpoint that produced the recorded outputs.
DEFAULT_TOKENIZER_NAME = 'bert-base-uncased'


def preprocess_function(examples, tokenizer=None):
    """Tokenize ``(premise, hypothesis, label)`` triples as sentence pairs.

    Args:
        examples: Non-empty sequence of 3-item tuples as produced by
            ``read_mednli``; the third item (label) is ignored here.
        tokenizer: Callable invoked as
            ``tokenizer(premises, hypotheses, truncation=True)``. When omitted,
            a module-level ``tokenizer`` is used if one exists; otherwise the
            ``DEFAULT_TOKENIZER_NAME`` checkpoint is loaded through
            ``AutoTokenizer`` and stored as the module-level ``tokenizer`` so
            later calls reuse it.

    Returns:
        Whatever the tokenizer returns for the batch of pairs.

    Raises:
        ValueError: If ``examples`` is empty or its items are not triples.
    """
    if not examples:
        # zip(*[]) yields nothing, which used to surface as an opaque
        # "not enough values to unpack" ValueError.
        raise ValueError('preprocess_function needs at least one example')

    if tokenizer is None:
        # Previously the function silently depended on a global that no cell
        # defined, so a fresh kernel failed with NameError.
        module_scope = globals()
        tokenizer = module_scope.get('tokenizer')
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(DEFAULT_TOKENIZER_NAME)
            module_scope['tokenizer'] = tokenizer

    premises, hypotheses, _ = zip(*examples)
    return tokenizer(premises, hypotheses, truncation=True)

In [39]:
# Read all three MedNLI splits from the default data directory.
trn, val, test = load_mednli()

MedNLI file loaded: ../data/mednli/mli_train_v1.jsonl, 11232 examples
MedNLI file loaded: ../data/mednli/mli_dev_v1.jsonl, 1395 examples
MedNLI file loaded: ../data/mednli/mli_test_v1.jsonl, 1422 examples


In [40]:
# Smoke-test the tokenization helper on the first five training examples;
# the cell displays the returned encoding.
preprocess_function(trn[:5])

{'input_ids': [[101, 13625, 2020, 3862, 2005, 13675, 1015, 1012, 1021, 1006, 26163, 1014, 1012, 1019, 2566, 2214, 2636, 1007, 1998, 18749, 12259, 1016, 1012, 1018, 1012, 102, 5776, 2038, 8319, 13675, 102], [101, 13625, 2020, 3862, 2005, 13675, 1015, 1012, 1021, 1006, 26163, 1014, 1012, 1019, 2566, 2214, 2636, 1007, 1998, 18749, 12259, 1016, 1012, 1018, 1012, 102, 5776, 2038, 3671, 13675, 102], [101, 13625, 2020, 3862, 2005, 13675, 1015, 1012, 1021, 1006, 26163, 1014, 1012, 1019, 2566, 2214, 2636, 1007, 1998, 18749, 12259, 1016, 1012, 1018, 1012, 102, 5776, 2038, 8319, 21122, 102], [101, 6396, 9153, 21693, 2271, 1998, 1056, 12414, 2075, 1997, 1054, 2849, 2001, 3264, 1012, 102, 1996, 5776, 2018, 19470, 11265, 10976, 11360, 1012, 102], [101, 6396, 9153, 21693, 2271, 1998, 1056, 12414, 2075, 1997, 1054, 2849, 2001, 3264, 1012, 102, 1996, 5776, 2038, 1037, 3671, 11265, 10976, 11360, 1012, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0