### Load in Packages

In [1]:
import os
os.environ["HF_HOME"] = "/home/shared/.cache/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/home/shared/.cache/huggingface/hub"

In [2]:
from peft import LoraConfig, get_peft_model
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import json
import json as pd
import numpy as np

### Check GPU Availability

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(torch.cuda.is_available())

True


### Load Data

In [4]:
import pandas as pd

file_path = "Data/PubMedQA_cleaned.json"
QA_data = pd.read_json(file_path)
data = QA_data

In [5]:
# Calculate class weights based on the training data
class_counts = data['gold_index'].value_counts().sort_index().values 
total_samples = len(data)
class_weights = torch.tensor([total_samples / count for count in class_counts], dtype=torch.float).to(device)

print(f"Class counts: {class_counts}")
print(f"Class weights: {class_weights}")

Class counts: [338 110 552]
Class weights: tensor([2.9586, 9.0909, 1.8116], device='cuda:0')


### Hugging Face Login

In [6]:
import os
from huggingface_hub import login
from dotenv import load_dotenv

load_dotenv()
hf_token = os.getenv("hf_token")
login(token=hf_token)

### Set LoRA Config & Load in Teacher Model

In [7]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],  # works well with LLaMA structure
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

teacher_model_name = "Henrychur/MMed-Llama-3-8B" 
# teacher_model_name = "microsoft/biogpt" 
# teacher_model_name = "microsoft/BioGPT-Large-PubMedQA"
# teacher_model_name = "stanford-crfm/BioMedLM"

teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_model_name, 
    torch_dtype=torch.float16, 
    device_map="auto"
)
lora_model = get_peft_model(teacher_model, lora_config)
lora_model.enable_input_require_grads()
lora_model.gradient_checkpointing_enable()


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [8]:
# Assign eos_token as pad_token
teacher_tokenizer.pad_token = teacher_tokenizer.eos_token

In [9]:
class LoraClassificationModel(nn.Module):
    def __init__(self, lora_model, hidden_size, num_classes=3, class_weights=None):
        super(LoraClassificationModel, self).__init__()
        self.lora_model = lora_model
        self.class_weights = class_weights

        # Multi-layer perceptron (MLP) classification head
        self.classification_head = nn.Sequential(
            nn.Linear(hidden_size, 768),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(768, 256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_classes)
        )

    def forward(self, input_ids, attention_mask, labels=None):
        # Use the LoRA model to get the hidden states
        outputs = self.lora_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )

        # Get the last hidden state of the last layer
        cls_hidden_state = outputs.hidden_states[-1][:, -1, :]

        logits = self.classification_head(cls_hidden_state)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fn(logits.float(), labels)

        return logits, loss


In [10]:
# Wrapping the model with LoRA
num_classes = 3
model = LoraClassificationModel(
    lora_model=lora_model, 
    hidden_size=teacher_model.config.hidden_size,
    num_classes=num_classes
).to(device)

### Tokenize the data

In [11]:
# Prepare the dataset
class QADataset(Dataset):
    def __init__(self, tokenizer, data, max_length=512):
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        input_text = f"Context: {row['context']} Question: {row['question']}"
        label = row['gold_index']  # The correct class index (0, 1, or 2)

        # Tokenize the input text
        inputs = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long)
        }


### Split into Train/Test

In [12]:
# Split the dataset into training and validation sets
train_data, val_data = train_test_split(data, test_size=0.5, random_state=401)

In [13]:
# Create DataLoaders
train_dataset = QADataset(teacher_tokenizer, train_data)
val_dataset = QADataset(teacher_tokenizer, val_data)

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)

### Fine-tune

In [14]:
from torch.cuda.amp import GradScaler, autocast
from torch.nn.utils import clip_grad_norm_
from sklearn.metrics import precision_score, recall_score, f1_score

In [18]:
# Fine-tune the model
optimizer = AdamW(model.parameters(), lr=1e-5)
num_epochs = 10

num_epochs = 50
epoch=0
train_losses = []
val_losses = []
val_accuracies = []
val_precisions = []
val_recalls = []
val_f1_scores = []

In [19]:
# FOR MIXED PRECISION
from torch.amp import GradScaler, autocast
scaler = torch.amp.GradScaler()

# Early Stopping parameters
best_val_loss = float("inf")
patience = 30
counter = 0

for epoch in range(num_epochs):
    # Training loop
    model.train()
    train_loss = 0

    for batch in train_dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        with torch.amp.autocast(device_type=device):
            logits, loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss_value = loss.item()

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss_value

    avg_train_loss = train_loss / len(train_dataloader)
    train_losses.append(avg_train_loss)

    # --- Validation ---
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in val_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            with torch.amp.autocast(device_type=device):
                logits, loss = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

            val_loss += loss.item()

            predictions = torch.argmax(logits, dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    sum_0 = np.sum(np.array(all_predictions) == 0)
    sum_1 = np.sum(np.array(all_predictions) == 1)
    sum_2 = np.sum(np.array(all_predictions) == 2)
    print(f"Predictions: {sum_0}x0s, {sum_1}x1s, {sum_2}x2s")

    avg_val_loss = val_loss / len(val_dataloader)
    val_losses.append(avg_val_loss)

    accuracy = correct / total
    val_accuracies.append(accuracy)

    precision = precision_score(all_labels, all_predictions, average="weighted", zero_division=0)
    recall = recall_score(all_labels, all_predictions, average="weighted", zero_division=0)
    f1 = f1_score(all_labels, all_predictions, average="weighted", zero_division=0)

    val_precisions.append(precision)
    val_recalls.append(recall)
    val_f1_scores.append(f1)

    print(f"Epoch {epoch + 1} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, "
          f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}")

    # Early Stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        counter = 0
        # Save the best model
        # torch.save(model.state_dict(), "best_model.pth")
        # print("Best model saved!")
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered")
            break

Predictions: 108x0s, 9x1s, 383x2s
Epoch 1 - Train Loss: 0.0002, Val Loss: 2.8475, Accuracy: 0.5500, Precision: 0.5143, Recall: 0.5500, F1-Score: 0.4990
Predictions: 117x0s, 13x1s, 370x2s
Epoch 2 - Train Loss: 0.0001, Val Loss: 2.8357, Accuracy: 0.5480, Precision: 0.5056, Recall: 0.5480, F1-Score: 0.5037
Predictions: 122x0s, 19x1s, 359x2s
Epoch 3 - Train Loss: 0.0000, Val Loss: 2.8310, Accuracy: 0.5400, Precision: 0.5043, Recall: 0.5400, F1-Score: 0.5035
Predictions: 130x0s, 15x1s, 355x2s
Epoch 4 - Train Loss: 0.0000, Val Loss: 2.9707, Accuracy: 0.5300, Precision: 0.4864, Recall: 0.5300, F1-Score: 0.4919
Predictions: 126x0s, 15x1s, 359x2s
Epoch 5 - Train Loss: 0.0000, Val Loss: 3.0468, Accuracy: 0.5340, Precision: 0.4898, Recall: 0.5340, F1-Score: 0.4943
Predictions: 116x0s, 13x1s, 371x2s
Epoch 6 - Train Loss: 0.0000, Val Loss: 3.1894, Accuracy: 0.5500, Precision: 0.5061, Recall: 0.5500, F1-Score: 0.5045
Predictions: 128x0s, 15x1s, 357x2s
Epoch 7 - Train Loss: 0.0000, Val Loss: 3.1821, 

KeyboardInterrupt: 