In [1]:
import logging
import torch, os, pytorch_lightning as pl, glob
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset



class MedMCQA_Datamodule(pl.LightningDataModule):
    """Lightning data module that serves pre-built MedMCQA datasets from ``.pt`` files.

    The datasets are read with ``torch.load`` from
    ``<DATASET_ROOT>/<pretrained_model>/<model_style>_style/``; both path
    components are taken from ``parser_args``.

    Args:
        batch_size: Batch size handed to every ``DataLoader``.
        parser_args: Subscriptable collection of hyper-parameters. The keys
            read by this class are ``"pretrained_model"``, ``"model_style"``
            and ``"num_workers"``. The default is ``None``, but a real value
            must be supplied before ``setup`` or any ``*_dataloader`` method
            is called, because those methods index into it.
    """

    # Root directory of the serialized datasets. Override it on a subclass or
    # an instance to run on a machine with a different layout.
    DATASET_ROOT = "/root/pubmedQA_291/dataset_pickles/medmcqa"

    def __init__(self, batch_size: int = 32, parser_args=None):
        super().__init__()
        self.logger = logging.getLogger("lightning")

        self.batch_size = batch_size
        self.parser_args = parser_args

    def _dataset_folder(self):
        """Return the folder (with trailing slash) that holds the ``*_dataset.pt`` files."""
        pretrained_model = self.parser_args['pretrained_model']
        model_style = self.parser_args['model_style']  # either classification or clip
        return f"{self.DATASET_ROOT}/{pretrained_model}/{model_style}_style/"

    def _make_loader(self, dataset, shuffle=False):
        """Build a ``DataLoader`` with the settings shared by all three splits."""
        num_workers = self.parser_args["num_workers"]
        return DataLoader(
            dataset,
            self.batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            # DataLoader rejects persistent_workers=True when there are no
            # worker processes, so only request it when workers exist.
            persistent_workers=num_workers > 0,
        )

    def setup(self, stage=None):
        """Load the datasets required for ``stage``.

        ``"fit"``, ``"validate"`` and ``None`` load the train and validation
        datasets. The test dataset is loaded only for an explicit ``"test"``
        stage; ``None`` does not load it.
        """
        folder_path = self._dataset_folder()

        # NOTE(review): torch.load unpickles arbitrary objects, so only point
        # this at trusted files. On PyTorch releases where ``weights_only``
        # defaults to True, loading whole dataset objects may need
        # ``weights_only=False`` -- confirm against the installed version.
        if stage in ("fit", "validate", None):
            self.logger.info('Setting up train and val dataset')
            self.train_dataset = torch.load(folder_path + "train_dataset.pt")
            self.val_dataset = torch.load(folder_path + "val_dataset.pt")

        if stage == "test":
            self.logger.info('Setting up test dataset')
            self.test_dataset = torch.load(folder_path + "test_dataset.pt")

    def train_dataloader(self):
        """Return the shuffled loader over ``self.train_dataset``."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over ``self.val_dataset``."""
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the unshuffled loader over ``self.test_dataset``."""
        return self._make_loader(self.test_dataset)

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
import yaml

# Hyper-parameters saved by the training run; they select which pre-built
# dataset folder the data module reads from.
HPARAMS_PATH = '/root/pubmedQA_291/exps/lightning_logs/bert-base-uncased_classification/experiment/hparams.yaml'

# Use a context manager so the file handle is closed as soon as parsing ends.
with open(HPARAMS_PATH) as hparams_file:
    parser_args = yaml.safe_load(hparams_file)


datamodule_mcqa = MedMCQA_Datamodule(parser_args=parser_args)
datamodule_mcqa.prepare_data()
# setup() with no stage loads only the train/val datasets, so the test
# dataset needs its own explicit call.
datamodule_mcqa.setup()
datamodule_mcqa.setup('test')

In [14]:
# Transpose the test dataset into three parallel tuples.
# Assumes each item is an (ids, mask, label) triple -- TODO confirm against
# the code that produced test_dataset.pt.
# NOTE(review): `lables` is a misspelling of "labels"; the name is kept
# because later cells refer to it.
ids, masks, lables = zip(*datamodule_mcqa.test_dataset)

In [15]:
# Stack the per-example labels into one tensor and check their range.
# The recorded output shows min == max == -1, i.e. every test label is -1
# (presumably a placeholder for an unlabeled test split -- confirm).
all_lables = torch.stack(lables)
all_lables.min(), all_lables.max()

(tensor(-1), tensor(-1))

In [16]:
# Display the stacked test labels (the repr is truncated by torch).
all_lables

tensor([-1, -1, -1,  ..., -1, -1, -1])

In [None]:
# Build the validation DataLoader. This reads datamodule_mcqa.val_dataset,
# so setup() must already have been run for the fit/validate stage.
val_loader = datamodule_mcqa.val_dataloader()

In [6]:
# Load the MedMCQA dataset by its hub identifier.
# NOTE(review): `load_dataset` is not imported in any visible cell --
# presumably `from datasets import load_dataset` ran in a cell that was
# later removed; add the import to the top cell so a fresh kernel can run this.
dataset = load_dataset("openlifescienceai/medmcqa")

In [29]:
# Cursor into the validation split; the "Medicine" browsing cell advances it.
# Re-run this cell to restart the scan from the beginning.
i = 0 


In [41]:
# Load the pretrained tokenizer used for the experiments below.
# NOTE(review): `BertTokenizer` is not imported in any visible cell --
# presumably `from transformers import BertTokenizer`; add the import to
# the top cell so a fresh kernel can run this.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [43]:
# Tokenize every validation question in a single call and return PyTorch
# tensors. With padding=True the rows come back padded to a common length;
# the recorded shape for attention_mask is (4183, 176).
tokens = tokenizer(dataset["validation"]["question"], padding=True, truncation=True, return_tensors="pt")

In [47]:
# List the fields returned by the tokenizer.
(tokens).keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [50]:
# Token ids of the first validation question. In the recorded output the 16
# non-zero ids are followed by zeros, which line up with the zero entries of
# the matching attention mask (i.e. padding).
tokens['input_ids'][0]

tensor([  101,  2029,  1997,  1996,  2206,  2003,  2025,  2995,  2005,  2026,
        18809,  4383,  9113, 16662,  1024,   102,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0, 

In [58]:
# Attention masks of the first two questions: ones mark real tokens and
# zeros mark the padded tail, so the number of ones differs per question.
tokens['attention_mask'][0], tokens['attention_mask'][1], 

(tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [56]:
# (number of questions, padded sequence length); recorded as (4183, 176).
tokens['attention_mask'].shape

torch.Size([4183, 176])

In [69]:
# Pull the validation columns out once; every prompt below is built from them.
validation_split = dataset["validation"]
A, B, C, D = (validation_split[option_key] for option_key in ('opa', 'opb', 'opc', 'opd'))
questions = validation_split['question']
subject_names = validation_split['subject_name']
topic_names = validation_split['topic_name']

# The correct-option labels are kept alongside but are not part of the prompt text.
labels = validation_split['cop']


def create_input_string(subject, topic, q, a, b, c, d):
    """Format one multiple-choice question as a text prompt.

    Args:
        subject: Subject name shown on the header line.
        topic: Topic name shown on the header line (printed as ``None`` when missing).
        q: Question text.
        a, b, c, d: Texts of the four answer options, in order.

    Returns:
        The prompt string, ending with a newline after option D.
    """
    return f"Subject: {subject}, Topic: {topic}\nQuestion: {q}\nA: {a}\nB: {b}\nC: {c}\nD: {d}\n"


# A list (rather than a one-shot ``map`` iterator) can be indexed and
# re-iterated, and peeking at the first prompt no longer consumes it.
input_prompts = [
    create_input_string(subject, topic, q, a, b, c, d)
    for subject, topic, q, a, b, c, d in zip(subject_names, topic_names, questions, A, B, C, D)
]

# Show the first prompt as a sanity check. Index directly instead of looping
# with `i`, which would overwrite the integer cursor `i` used by other cells.
if input_prompts:
    print(input_prompts[0])

Subject: Physiology, Topic: None
Question: Which of the following is not true for myelinated nerve fibers:
A: Impulse through myelinated fibers is slower than non-myelinated fibers
B: Membrane currents are generated at nodes of Ranvier
C: Saltatory conduction of impulses is seen
D: Local anesthesia is effective only when the nerve is not covered by myelin sheath



a

In [60]:
# NOTE(review): `a` is never assigned at top level in any visible cell (the
# lambda parameter of the same name is local), so this display relies on
# state left over from a removed cell and fails on a fresh kernel.
a

'Impulse through myelinated fibers is slower than non-myelinated fibers'

In [61]:
# NOTE(review): `q` is never assigned at top level in any visible cell (the
# lambda parameter of the same name is local), so this display relies on
# state left over from a removed cell and fails on a fresh kernel.
q

'Which of the following is not true for myelinated nerve fibers:'

In [36]:
# Advance the cursor `i` to the next validation question whose subject is
# "Medicine", then print that question. Re-running the cell moves on to the
# following match.
validation_split = dataset["validation"]
# Fetch the column once instead of looking it up on every loop iteration.
subject_column = validation_split['subject_name']

i = i + 1
while i < len(subject_column) and subject_column[i] != "Medicine":
    i += 1

if i >= len(subject_column):
    # Same exception type as running off the end of the column, but with a
    # message that says what happened.
    raise IndexError('no further "Medicine" question in the validation split; reset i to start over')

print(validation_split['question'][i])

Which of the following is not. true regarding myelopathy?


In [37]:
# The four answer options (A-D) of the question currently selected by `i`.
dataset["validation"]['opa'][i], dataset["validation"]['opb'][i],dataset["validation"]['opc'][i],dataset["validation"]['opd'][i],   

('Sensory loss of facial area',
 'Brisk jaw jerk',
 'Brisk pectoral jerk',
 'Urgency and incontinence of micturition')

In [38]:
# Correct-option field of the selected question. The recorded value is 1
# while the explanation text reads "Ans. b", so `cop` looks 0-indexed
# (0 -> A, 1 -> B, ...) -- confirm against the dataset card.
dataset["validation"]['cop'][i]

1

In [39]:
# Free-text explanation attached to the selected question.
dataset["validation"]['exp'][i]

"Ans. b. Brisk jaw jerk(Ref: De Jongs Neurological examination/ p194, 201, 474.'Jaw jerk is exaggerated in supranuclear lesions that are above the mid pons."

In [40]:
# Number of validation examples; the recorded 4183 matches the first
# dimension of the tokenized attention mask.
len(dataset["validation"])

4183