In [0]:
%pip install torch
%pip install transformers

Collecting torch
  Obtaining dependency information for torch from https://files.pythonhosted.org/packages/78/a9/97cbbc97002fff0de394a2da2cdfa859481fdca36996d7bd845d50aa9d8d/torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata
  Using cached torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting networkx (from torch)
  Obtaining dependency information for networkx from https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl.metadata
  Using cached networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Obtaining dependency information for jinja2 from https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl.metadata
  Using cached jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec (from torch)
  Obtaining dependency information for fsspec from https://files.pythonh

In [0]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW 
from sklearn.model_selection import train_test_split
import numpy as np
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
 

# 1. Load and preprocess the dataset
class MedicalQADataset(Dataset):
    """Question/answer pairs tokenized as sentence pairs for sequence classification.

    Args:
        questions: Iterable of question texts (list, array, or pandas Series).
        answers: Iterable of answer texts, aligned one-to-one with ``questions``.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) that accepts
            a text pair plus the ``max_length``, ``padding``, ``truncation`` and
            ``return_tensors`` keywords and returns a mapping containing
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Every encoded pair is padded or truncated to this many tokens.
        labels: Optional iterable of integer class labels, one per pair. When
            omitted, every pair is labelled 1 (a correct question/answer match).

    Raises:
        ValueError: If ``questions``, ``answers`` and (when given) ``labels`` do
            not all have the same length.
    """

    def __init__(self, questions, answers, tokenizer, max_length=512, labels=None):
        # Materialize as plain lists so __getitem__ indexes by POSITION. A pandas
        # Series taken from a shuffled split keeps its original index labels, and
        # series[idx] would look up by label (KeyError or the wrong row).
        self.questions = list(questions)
        self.answers = list(answers)
        if len(self.questions) != len(self.answers):
            raise ValueError(
                f"questions and answers must have the same length "
                f"(got {len(self.questions)} and {len(self.answers)})"
            )

        self.labels = None if labels is None else list(labels)
        if self.labels is not None and len(self.labels) != len(self.questions):
            raise ValueError(
                f"labels must have one entry per question "
                f"(got {len(self.labels)} for {len(self.questions)} questions)"
            )

        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of question/answer pairs."""
        return len(self.questions)

    def __getitem__(self, idx):
        """Return the encoded pair at position ``idx``.

        Returns:
            dict with 1-D ``'input_ids'`` and ``'attention_mask'`` tensors and a
            scalar ``torch.long`` tensor under ``'labels'``.
        """
        question = str(self.questions[idx])
        answer = str(self.answers[idx])

        # Encode question and answer together as one sentence pair. Calling the
        # tokenizer directly replaces the deprecated encode_plus entry point.
        encoding = self.tokenizer(
            question,
            answer,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Without explicit labels every pair is treated as a positive example.
        label = 1 if self.labels is None else int(self.labels[idx])

        return {
            # return_tensors='pt' adds a leading batch dimension; flatten drops it.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

def load_and_prepare_data(csv_path):
    """Read a question/answer CSV and split it into training and validation frames.

    Args:
        csv_path: Path (or any source accepted by ``pandas.read_csv``) of a CSV
            file with ``question`` and ``answer`` columns.

    Returns:
        ``(train_df, val_df)``: an 80/20 split of the rows, made with a fixed
        ``random_state`` of 42 so repeated runs give the same partition.

    Raises:
        ValueError: If either required column is missing.
    """
    # Read from the caller-supplied path; previously the argument was ignored
    # and a hard-coded file name was always loaded.
    df = pd.read_csv(csv_path)

    required_columns = ('question', 'answer')
    if any(column not in df.columns for column in required_columns):
        raise ValueError("CSV must contain 'question' and 'answer' columns")

    # Drop rows missing either field: a NaN cell passed through str() becomes
    # the literal text "nan", which is not a usable training example.
    df = df.dropna(subset=list(required_columns))

    train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

    return train_df, val_df

# 2. Fine-tune ClinicalBERT
def _run_epoch(model, loader, device, optimizer=None, scheduler=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    When ``optimizer`` is given the model is put in training mode and each batch
    is followed by a backward pass and an optimizer (and scheduler) step.
    Otherwise the model is put in evaluation mode and gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is the same as model.eval()
    running_loss = 0.0

    with torch.set_grad_enabled(training):
        for batch in tqdm(loader):
            # Move the batch to the device the model lives on
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            if training:
                optimizer.zero_grad()

            # Passing labels makes the model compute the loss itself
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            running_loss += loss.item()

            if training:
                loss.backward()
                optimizer.step()
                if scheduler is not None:
                    scheduler.step()  # the schedule advances once per batch

    return running_loss / len(loader)


def fine_tune_clinical_bert(train_dataset, val_dataset, model_name="emilyalsentzer/Bio_ClinicalBERT",
                            num_epochs=3, batch_size=8, learning_rate=2e-5,
                            checkpoint_path="best_clinical_bert_model.pt"):
    """Fine-tune a pretrained sequence-classification model on QA pairs.

    Args:
        train_dataset: Dataset whose items are dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        val_dataset: Dataset of the same form, used to select the best epoch.
        model_name: Hugging Face model identifier to load weights and tokenizer from.
        num_epochs: Number of passes over ``train_dataset``.
        batch_size: Batch size for both the training and validation loaders.
        learning_rate: AdamW learning rate; decays linearly to zero over all
            training steps with no warm-up.
        checkpoint_path: File the best-so-far weights are written to.

    Returns:
        ``(model, tokenizer)``: the model holding the weights from the epoch with
        the lowest average validation loss, and the tokenizer for ``model_name``.

    Raises:
        ValueError: If either dataset is empty.
    """
    # Fail before the (slow) model download rather than with a
    # ZeroDivisionError when averaging over zero batches.
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError("train_dataset and val_dataset must both be non-empty")

    # Load the pre-trained model and tokenizer; the 2-way classifier head is
    # newly initialized and learned during fine-tuning.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    # Create data loaders (only the training data is shuffled)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Set up optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    best_val_loss = float('inf')
    checkpoint_written = False

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        avg_train_loss = _run_epoch(model, train_loader, device, optimizer, scheduler)
        print(f"Average training loss: {avg_train_loss}")

        avg_val_loss = _run_epoch(model, val_loader, device)
        print(f"Average validation loss: {avg_val_loss}")

        # Save the model if it's the best so far
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_written = True
            print("Model saved!")

    # Restore the best weights, but only if this run actually wrote them: a NaN
    # validation loss never compares below the best, and loading whatever file
    # already sits at checkpoint_path would silently return stale weights.
    if checkpoint_written:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        print("No checkpoint was saved during this run; returning the model from the last epoch.")

    return model, tokenizer




Loading and preparing data...
Creating datasets...
Fine-tuning ClinicalBERT...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/3


  0%|          | 0/1145 [00:00<?, ?it/s]  0%|          | 1/1145 [00:06<2:02:08,  6.41s/it]  0%|          | 2/1145 [00:13<2:05:03,  6.56s/it]  0%|          | 3/1145 [00:19<2:00:34,  6.33s/it]  0%|          | 4/1145 [00:25<2:00:48,  6.35s/it]  0%|          | 5/1145 [00:31<1:59:02,  6.27s/it]  1%|          | 6/1145 [00:38<2:00:32,  6.35s/it]  1%|          | 7/1145 [00:44<1:58:27,  6.25s/it]  1%|          | 8/1145 [00:50<1:58:04,  6.23s/it]  1%|          | 9/1145 [00:56<1:56:48,  6.17s/it]  1%|          | 10/1145 [01:02<1:56:43,  6.17s/it]  1%|          | 11/1145 [01:08<1:56:07,  6.14s/it]  1%|          | 12/1145 [01:14<1:54:38,  6.07s/it]  1%|          | 13/1145 [01:20<1:55:47,  6.14s/it]  1%|          | 14/1145 [01:27<1:57:13,  6.22s/it]

com.databricks.backend.common.rpc.CommandCancelledException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$5(SequenceExecutionState.scala:136)
	at scala.Option.getOrElse(Option.scala:189)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:136)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:133)
	at scala.collection.immutable.Range.foreach(Range.scala:158)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:133)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:728)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:446)
	at scala.Option.getOrElse(Option.scala:189)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:446)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.can