In [None]:
import shutil
import os

# Remove the directory if already exist 
dir_name = 'neural_medical_qa'
if os.path.exists(dir_name):
    shutil.rmtree(dir_name)

#clone the repo from github
!git clone https://github.com/trduc97/neural_medical_qa.git
%cd neural_medical_qa
# install the requirement
!pip install -r requirements.txt

In [None]:
from import_datasets import load_bioasq_pubmedqa,  train_val_test_split

# Load both corpora through the project helper
bioasq, pubmedqa = load_bioasq_pubmedqa()

# Preview the first rows of the PubMedQA training split
print(pubmedqa['train'].to_pandas().head())

# Tally how often each of the three answer labels occurs in the training split
responses = pubmedqa['train']['final_decision']
yes_count, no_count, maybe_count = (
    responses.count(label) for label in ('yes', 'no', 'maybe')
)

# Report the label distribution, one label per line
print(
    f"Yes: {yes_count}",
    f"No: {no_count}",
    f"Maybe: {maybe_count}",
    sep="\n",
)

# Split PubMedQA with the project helper and report the size of each part
pubmedqa_train, pubmedqa_val, pubmedqa_test = train_val_test_split(pubmedqa)
print(
    f"Train size: {len(pubmedqa_train)}",
    f"Validation size: {len(pubmedqa_val)}",
    f"Test size: {len(pubmedqa_test)}",
    sep="\n",
)

In [None]:
from collections import defaultdict
# Histogram of long-answer character lengths, grouped into fixed-width buckets.
bucket_size = 128  # width of every bucket, in characters

# Maps each bucket's lower bound to the number of answers falling inside it
length_buckets = defaultdict(int)
for answer in pubmedqa_train['long_answer']:
    # Round the length down to the nearest multiple of bucket_size
    lower = len(answer) - len(answer) % bucket_size
    length_buckets[lower] += 1

# Print the buckets in ascending order with their inclusive upper bound
for lower in sorted(length_buckets):
    upper = lower + bucket_size - 1
    print(f"Length {lower} - {upper}: {length_buckets[lower]} strings")

In [None]:
# Call the loader again, this time pointing it at the artificial PubMedQA
# parquet file on Kaggle. Note that `bioasq` is rebound by this call as well.
bioasq, pubmedqa_artificial = load_bioasq_pubmedqa(
    pubmed_kaggle_path='/kaggle/input/pubmed-qa/pubmed_qa_pga_artificial.parquet',
)

In [None]:
from datasets import DatasetDict, Dataset
from sklearn.model_selection import train_test_split

# Work on the artificial training split as a pandas DataFrame
df_artificial = pubmedqa_artificial['train'].to_pandas()

# Keep a 5% sample, stratified on the encoded label; the 95% remainder is discarded.
# The fixed random_state makes the sample reproducible.
df_sample, _ = train_test_split(
    df_artificial,
    test_size=0.95,
    random_state=42,
    stratify=df_artificial['decision_encoded'],
)

# Restrict the sample to the listed columns
df_sample = df_sample[
    ['pubid', 'question', 'context', 'long_answer', 'final_decision', 'decision_encoded']
]

# Convert the sample back to a `datasets.Dataset`, dropping the pandas index
data_art = Dataset.from_pandas(df_sample, preserve_index=False)

In [None]:
# Wrap the sampled data as the 'train' entry of a DatasetDict so it can be
# passed to the project's split helper, then unpack the three parts it returns.
pubmedqa_arti = DatasetDict({'train': data_art})
(
    pubmedqa_art_train,
    pubmedqa_art_val,
    pubmedqa_art_test,
) = train_val_test_split(pubmedqa_arti)

In [None]:
import os
import torch
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel, GPT2Tokenizer, GPT2Model
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, f1_score

class QAModel(nn.Module):
    """Sequence classifier: a pretrained backbone followed by a two-layer head.

    The backbone's hidden states are pooled to one vector per sequence, then
    passed through dropout -> linear(hidden_size, 128) -> dropout ->
    linear(128, classes).

    Args:
        model: Backbone module. It must expose ``config.hidden_size`` and, when
            called with ``input_ids`` and ``attention_mask``, return an object
            with a ``last_hidden_state`` attribute.
        classes: Number of output classes (may vary between BioASQ, 2 classes,
            and PubMedQA, 3 classes).
        dropout_prob: Dropout probability used before each linear layer.
        pooling: ``'cls'`` pools the hidden state at position 0; ``'last'``
            pools the hidden state of the last non-padded token. ``None``
            (default) picks ``'last'`` when ``model.config.model_type`` is one
            of ``CAUSAL_MODEL_TYPES`` and ``'cls'`` otherwise.

    Raises:
        ValueError: If ``pooling`` is neither ``None``, ``'cls'`` nor ``'last'``.
    """

    # config.model_type values treated as left-to-right (causal) language models.
    # In such models position 0 cannot attend to any later token, so pooling it
    # would hand the classifier a vector that only depends on the first token.
    CAUSAL_MODEL_TYPES = ('gpt2', 'openai-gpt', 'gpt_neo', 'gptj', 'gpt_neox')

    def __init__(self, model, classes=3, dropout_prob=0.5, pooling=None):
        super().__init__()
        if pooling is None:
            model_type = getattr(getattr(model, 'config', None), 'model_type', None)
            pooling = 'last' if model_type in self.CAUSAL_MODEL_TYPES else 'cls'
        if pooling not in ('cls', 'last'):
            raise ValueError(f"pooling must be 'cls' or 'last', got {pooling!r}")
        self.pooling = pooling

        # Attribute names are kept as-is so previously saved state_dicts still load
        self.bert = model
        self.dropout1 = nn.Dropout(dropout_prob)
        self.linear1 = nn.Linear(model.config.hidden_size, 128)
        self.dropout2 = nn.Dropout(dropout_prob)
        self.linear2 = nn.Linear(128, classes)

    def _pool(self, hidden, attention_mask):
        """Reduce per-token hidden states to one vector per sequence."""
        if self.pooling == 'cls':
            return hidden[:, 0, :]
        # Index of the last position whose mask is 1. Weighting the mask by the
        # position index makes this independent of the padding side.
        positions = torch.arange(hidden.size(1), device=attention_mask.device)
        last_idx = (attention_mask.long() * positions).argmax(dim=1).to(hidden.device)
        batch_idx = torch.arange(hidden.size(0), device=hidden.device)
        return hidden[batch_idx, last_idx]

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits, one row per input sequence."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self._pool(outputs.last_hidden_state, attention_mask)
        pooled = self.dropout1(pooled)
        pooled = self.linear1(pooled)
        pooled = self.dropout2(pooled)
        return self.linear2(pooled)

class TrainandValidate:
    """Fine-tune a ``QAModel`` on question / long-answer pairs and report metrics.

    Construction tokenizes the three splits, builds their data loaders, and
    creates the model, an AdamW optimizer (lr=2e-5) and a cross-entropy loss.

    Args:
        model_name: Short label for the run. It selects the tokenizer/backbone
            family ('GPT' or 'BioLinkBERT' in the name, otherwise BERT), the
            batch size (16 if the name contains 'GPT' or 'artificial', else 64)
            and the checkpoint file name.
        source: Identifier passed to the ``from_pretrained`` calls.
        df_train, df_val, df_test: Splits supporting ``split[column]`` access
            for the 'question', 'long_answer' and ``stratify_col`` columns.
        stratify_col: Name of the column holding the integer class labels.
    """

    def __init__(self, model_name, source, df_train, df_val, df_test, stratify_col='decision_encoded'):
        self.name = model_name
        self.source = source
        self.batch_size = 16 if 'GPT' in self.name or 'artificial' in self.name else 64

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = self.initialize_tokenizer()
        self.stratify_col = stratify_col

        self.train_inputs, self.train_labels = self.encode_data(df_train)
        self.validate_inputs, self.validate_labels = self.encode_data(df_val)
        self.test_inputs, self.test_labels = self.encode_data(df_test)

        # Only the training data needs reshuffling each epoch; evaluation
        # metrics do not depend on order, so those loaders stay deterministic.
        self.train_loader = self.create_dataloader(self.train_inputs, self.train_labels)
        self.validate_loader = self.create_dataloader(
            self.validate_inputs, self.validate_labels, shuffle=False
        )
        self.test_loader = self.create_dataloader(
            self.test_inputs, self.test_labels, shuffle=False
        )

        self.model = self.create_model().to(self.device)
        self.optimizer = optim.AdamW(self.model.parameters(), lr=2e-5)
        self.loss_fn = nn.CrossEntropyLoss()

    def initialize_tokenizer(self):
        """Return the tokenizer matching the model family encoded in ``self.name``."""
        if 'GPT' in self.name:
            tokenizer = GPT2Tokenizer.from_pretrained(self.source)
            # Batched padding needs a pad token; add one if the tokenizer has none
            if tokenizer.pad_token is None:
                tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            return tokenizer
        if 'BioLinkBERT' in self.name:
            return AutoTokenizer.from_pretrained(self.source)
        return BertTokenizer.from_pretrained(self.source)

    def encode_data(self, df):
        """Tokenize (question, long_answer) pairs and collect the labels.

        Returns:
            ``(inputs, labels)`` where ``inputs`` is the tokenizer output
            (PyTorch tensors, padded to the longest sequence and truncated to
            512 tokens) and ``labels`` is a tensor built from ``stratify_col``.
        """
        # Materialise the columns as plain lists so the tokenizer and
        # torch.tensor receive the same input whatever container type
        # ``df[column]`` hands back (list, pandas Series, lazy column view).
        inputs = self.tokenizer(
            text=list(df['question']),
            text_pair=list(df['long_answer']),
            padding=True,
            truncation=True,
            return_tensors='pt',
            max_length=128*4
        )
        labels = torch.tensor(list(df[self.stratify_col]))
        return inputs, labels

    def create_dataloader(self, inputs, labels, shuffle=True):
        """Wrap input ids, attention masks and labels in a ``DataLoader``.

        Args:
            inputs: Mapping with 'input_ids' and 'attention_mask' tensors.
            labels: Label tensor aligned with ``inputs``.
            shuffle: Whether to reshuffle every epoch (default ``True``, the
                previous behaviour).
        """
        dataset = TensorDataset(inputs['input_ids'], inputs['attention_mask'], labels)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle)

    def create_model(self):
        """Load the pretrained backbone for this run and wrap it in ``QAModel``."""
        if 'GPT' in self.name:
            backbone = GPT2Model.from_pretrained(self.source)
            # The tokenizer may have gained a pad token; keep the embedding
            # matrix the same size as the vocabulary.
            backbone.resize_token_embeddings(len(self.tokenizer))
        elif 'BioLinkBERT' in self.name:
            backbone = AutoModel.from_pretrained(self.source)
        else:
            backbone = BertModel.from_pretrained(self.source)
        return QAModel(backbone)

    def calculate_f1_score(self, preds, labels):
        """Weighted F1 of ``argmax(preds, axis=1)`` against ``labels``."""
        preds_flat = np.argmax(preds, axis=1).flatten()
        labels_flat = labels.flatten()
        return f1_score(labels_flat, preds_flat, average='weighted')

    def evaluate(self, dataloader):
        """Run the model over ``dataloader`` without gradients.

        Returns:
            ``(accuracy, precision, recall, f1)``; the last three are
            weighted averages over classes.
        """
        self.model.eval()
        predictions, true_labels = [], []

        with torch.no_grad():
            for batch in dataloader:
                b_input_ids, b_attention_mask, b_labels = [t.to(self.device) for t in batch]
                outputs = self.model(b_input_ids, b_attention_mask)
                logits = outputs.detach().cpu().numpy()
                label_ids = b_labels.cpu().numpy()
                predictions.extend(np.argmax(logits, axis=1))
                true_labels.extend(label_ids)

        accuracy = accuracy_score(true_labels, predictions)
        # zero_division=0 yields the same numbers as the default but without a
        # warning each time a class is never predicted.
        precision, recall, f1, _ = precision_recall_fscore_support(
            true_labels, predictions, average='weighted', zero_division=0
        )

        return accuracy, precision, recall, f1

    def training(self, epochs=15):
        """Fine-tune the model, print per-epoch loss / F1, then save a checkpoint.

        Args:
            epochs: Number of passes over the training data. Overridden to 5
                when the run name contains 'artificial'.
        """
        epochs = 5 if 'artificial' in self.name else epochs
        for epoch in range(epochs):
            # Re-enter training mode every epoch so dropout stays active even
            # if evaluate() (which switches to eval mode) ran in between.
            self.model.train()
            total_loss = 0
            all_preds = []
            all_labels = []

            for batch in self.train_loader:
                b_input_ids, b_attention_mask, b_labels = [t.to(self.device) for t in batch]
                self.optimizer.zero_grad()

                outputs = self.model(b_input_ids, b_attention_mask)
                loss = self.loss_fn(outputs, b_labels)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()

                all_preds.append(outputs.detach().cpu().numpy())
                all_labels.append(b_labels.to('cpu').numpy())

            avg_loss = total_loss / len(self.train_loader)
            all_preds = np.concatenate(all_preds, axis=0)
            all_labels = np.concatenate(all_labels, axis=0)
            avg_f1_score = self.calculate_f1_score(all_preds, all_labels)

            print(f"Epoch {epoch+1}, Loss: {avg_loss}, F1 Score: {avg_f1_score}")

        self.save_model()

    def save_model(self):
        """Write model and optimizer state to /kaggle/working/models/<name>_model.pth."""
        os.makedirs('/kaggle/working/models', exist_ok=True)
        model_path = f'/kaggle/working/models/{self.name}_model.pth'
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, model_path)
        print(f"Model saved to {model_path}")

    def load_model(self, model_path):
        """Restore model and optimizer state from a checkpoint written by ``save_model``."""
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Model loaded from {model_path}")

    def val(self, load_model_path=None):
        """Evaluate on the validation and test loaders and print the metrics.

        Args:
            load_model_path: Optional checkpoint to load before evaluating.

        Returns:
            ``{'validation': {...}, 'test': {...}}``, each inner dict holding
            'accuracy', 'precision', 'recall' and 'f1'.
        """
        if load_model_path:
            self.load_model(load_model_path)

        val_accuracy, val_precision, val_recall, val_f1 = self.evaluate(self.validate_loader)
        print(f"Validation - Accuracy: {val_accuracy}, Precision: {val_precision}, Recall: {val_recall}, F1-Score: {val_f1}")

        test_accuracy, test_precision, test_recall, test_f1 = self.evaluate(self.test_loader)
        print(f"Test - Accuracy: {test_accuracy}, Precision: {test_precision}, Recall: {test_recall}, F1-Score: {test_f1}")

        return {
            'validation': {
                'accuracy': val_accuracy,
                'precision': val_precision,
                'recall': val_recall,
                'f1': val_f1
            },
            'test': {
                'accuracy': test_accuracy,
                'precision': test_precision,
                'recall': test_recall,
                'f1': test_f1
            }
        }

In [None]:
# One configuration per backbone to compare. Every run uses the same PubMedQA
# train / validation / test splits; only the label and the pretrained source differ.
models = [
    {
        'model_name': name,
        'source': source,
        'df_train': pubmedqa_train,
        'df_val': pubmedqa_val,
        'df_test': pubmedqa_test,
    }
    for name, source in (
        ('BERT', 'bert-base-uncased'),
        ('BioLinkBERT', 'michiyasunaga/BioLinkBERT-base'),
        ('GPT', 'gpt2'),
        ('BiomedNLP', 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract'),
    )
]

In [None]:
import gc

# Train and evaluate every configured model in turn.
# val_results maps model name -> the metrics dict returned by TrainandValidate.val()
val_results = {}
trainer = None

for model in models:
    # Release the previous run's trainer (and the model it holds) *before*
    # building the next one; otherwise both stay referenced at the same time
    # while the new model is being moved to the device. The trainer of the
    # final run remains available after the loop.
    trainer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    trainer = TrainandValidate(
        model_name=model['model_name'],
        source=model['source'],
        df_train=model['df_train'],
        df_val=model['df_val'],
        df_test=model['df_test']
    )
    # Train the model (also saves a checkpoint when finished)
    trainer.training()

    # Evaluate on the validation and test splits and keep the metrics
    val_result = trainer.val()
    val_results[model['model_name']] = val_result