In [3]:
import logging
import pickle, random
import torch, os, pytorch_lightning as pl, glob
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer
import torch.nn.functional as F
from datasets import load_dataset


class MedMCQA_Datamodule_to_save_datasets(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32, model_name: str = "bert-base-uncased"):
        # OR: microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext
        super().__init__()
        self.logger = logging.getLogger("lightning")

        self.batch_size = batch_size
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        
    def prepare_data(self):
        self.logger.info('Loading dataset')
        self.dataset = load_dataset("openlifescienceai/medmcqa")

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            
            self.logger.info('Setting up train dataset')
            train_dataset = self.dataset["train"]
            self.train_dataset = self.convert_to_dataset(train_dataset)

            self.logger.info('Setting up val dataset')
            val_dataset = self.dataset['validation']
            self.val_dataset = self.convert_to_dataset(val_dataset)
            
        if stage == "test"   : 
            self.logger.info('Setting up test dataset')
            test_dataset = self.dataset["test"]
            self.test_dataset = self.convert_to_dataset(test_dataset)
            
    def convert_to_dataset(self, dataset, mode = "classification"):
        # mode can be either "classification" or "clip"
        # Convert dataset to PyTorch format
        if mode == "classification":
            # 
            # print(dataset.keys())
            A,B,C,D = dataset['opa'], dataset['opb'],dataset['opc'],dataset['opd']
            questions = dataset['question']
            subject_names = dataset['subject_name']
            topic_names = dataset['topic_name']

            input_prompts = list(map(lambda a,b,c,d,q,subject,topic: f"Subject: {subject}, Topic: {topic}\nQuestion: {q}\nA: {a}\nB: {b}\nC: {c}\nD: {d}\n",
                                A,B,C,D,questions, subject_names, topic_names))
            
            tokens = self.tokenizer(input_prompts, padding=True, truncation=True, return_tensors="pt")
            labels = dataset['cop']            
            return torch.utils.data.TensorDataset(tokens.input_ids, tokens.attention_mask, torch.tensor(labels))
        
        elif mode == "clip":
            raise Exception("Not implemented yet")
        
    def train_dataloader(self):
        return DataLoader(self.train_dataset, self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, self.batch_size)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, self.batch_size)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
datamodule_mcqa = MedMCQA_Datamodule_to_save_datasets(model_name='microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext')
datamodule_mcqa.prepare_data()
datamodule_mcqa.setup()
datamodule_mcqa.setup("test")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [5]:
import os
import torch

# Create a folder to save the files
folder_path = '/root/pubmedQA_291/dataset_pickles/medmcqa/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext'
            #"/root/pubmedQA_291/dataset_pickles/medmcqa/classification_style/"
os.makedirs(folder_path, exist_ok=True)

# Save the train_dataset
train_dataset_path = os.path.join(folder_path, "train_dataset.pt")
torch.save(datamodule_mcqa.train_dataset, train_dataset_path)

# Save the val_dataset
val_dataset_path = os.path.join(folder_path, "val_dataset.pt")
torch.save(datamodule_mcqa.val_dataset, val_dataset_path)

# Save the test_dataset
test_dataset_path = os.path.join(folder_path, "test_dataset.pt")
torch.save(datamodule_mcqa.test_dataset, test_dataset_path)


In [1]:
# inspect longest sequence
import torch

In [2]:
train_dataset = torch.load('/root/pubmedQA_291/dataset_pickles/medmcqa/classification_style/train_dataset.pt')

In [5]:
train_dataset.tensors[0].shape, train_dataset.tensors[1].shape, train_dataset.tensors[2].shape

(torch.Size([182822, 453]), torch.Size([182822, 453]), torch.Size([182822]))

In [7]:
train_dataset.tensors[0][0]

tensor([  101,  3395,  1024, 13336,  1010,  8476,  1024, 24471,  3981,  2854,
        12859,  3160,  1024, 11888, 24471, 11031,  7941, 27208,  2349,  2000,
        28378, 26113, 12070, 23760, 24759, 15396,  2064,  2599,  2000,  1996,
         2206,  2689,  1999, 14234, 11968,  2368, 11714,  2863,  1037,  1024,
        23760, 24759, 15396,  1038,  1024, 23760,  7361, 10536,  1039,  1024,
         2012, 18981, 10536,  1040,  1024,  1040, 22571,  8523,  2401,   102,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0, 

In [8]:
val_dataset = torch.load('/root/pubmedQA_291/dataset_pickles/medmcqa/classification_style/val_dataset.pt')

In [9]:
val_dataset.tensors[0].shape, val_dataset.tensors[1].shape, val_dataset.tensors[2].shape

(torch.Size([4183, 234]), torch.Size([4183, 234]), torch.Size([4183]))

In [10]:
val_dataset.tensors[0][0]

tensor([  101,  3395,  1024, 16127,  1010,  8476,  1024,  3904,  3160,  1024,
         2029,  1997,  1996,  2206,  2003,  2025,  2995,  2005,  2026, 18809,
         4383,  9113, 16662,  1024,  1037,  1024, 14982,  2083,  2026, 18809,
         4383, 16662,  2003, 12430,  2084,  2512,  1011,  2026, 18809,  4383,
        16662,  1038,  1024, 10804, 14731,  2024,  7013,  2012, 14164,  1997,
         2743, 14356,  1039,  1024,  5474, 14049,  6204,  3258,  1997, 14982,
         2015,  2003,  2464,  1040,  1024,  2334,  2019, 25344,  2003,  4621,
         2069,  2043,  1996,  9113,  2003,  2025,  3139,  2011,  2026, 18809,
        21867,   102,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0, 