# Medical Model Training Notebook

This notebook demonstrates how to fine-tune a causal language model (DialoGPT) on medical question–answer pairs derived from the Kaggle disease dataset.

## Overview
- Load processed Q&A data
- Create PyTorch datasets
- Initialize medical language model
- Train the model
- Evaluate performance
- Save trained model


In [None]:
# Import required libraries
import sys
import os
sys.path.append('../')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, AutoModel, 
    TrainingArguments, Trainer,
    AutoModelForCausalLM, AutoModelForSequenceClassification
)
import pandas as pd
import numpy as np
import json
from pathlib import Path
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

# Choose the compute device: CUDA when PyTorch can see a GPU, CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Plot appearance: apply the matplotlib style sheet first, then the seaborn
# palette (the palette call must come second so it is not overwritten)
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")


## 1. Load Training Data

In [None]:
# Load processed Q&A data
data_path = Path("../data/processed/train_qa_pairs.json")
test_data_path = Path("../data/processed/test_qa_pairs.json")

# Read the JSON explicitly as UTF-8: without `encoding=` Python falls back to
# the platform's locale encoding, which can mis-decode (or fail on) non-ASCII
# characters in the text.
with open(data_path, "r", encoding="utf-8") as f:
    train_data = json.load(f)

with open(test_data_path, "r", encoding="utf-8") as f:
    test_data = json.load(f)

print(f"Training data: {len(train_data)} samples")
print(f"Test data: {len(test_data)} samples")

# Fail with a readable message instead of an IndexError on the preview below
assert train_data, f"No training samples found in {data_path}"

# Display sample data (each record carries question / answer / disease / symptoms)
first_sample = train_data[0]
print("\nSample training data:")
print(f"Question: {first_sample['question']}")
print(f"Answer: {first_sample['answer'][:100]}...")
print(f"Disease: {first_sample['disease']}")
print(f"Symptoms: {first_sample['symptoms']}")


## 2. Create PyTorch Dataset


In [None]:
# Create PyTorch Dataset for medical Q&A
class MedicalQADataset(Dataset):
    """Map-style dataset that turns Q&A records into causal-LM training examples.

    Every record is rendered as a ``Question: ...`` line followed by an
    ``Answer: ...`` line, terminated with the tokenizer's EOS token (when it
    has one), then tokenized and padded/truncated to ``max_length``.

    Args:
        data: Indexable collection of mappings that provide ``'question'`` and
            ``'answer'`` entries.
        tokenizer: Tokenizer that is callable with ``max_length``, ``padding``,
            ``truncation`` and ``return_tensors`` and exposes ``eos_token`` and
            ``pad_token``. If it has no pad token, its EOS token is installed
            as the pad token (this mutates the tokenizer passed in).
        max_length: Fixed sequence length of every returned example.
    """

    # Label value that is skipped by the loss; used for padded positions.
    IGNORE_INDEX = -100

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Only fill in a pad token when the tokenizer lacks one, so a
        # tokenizer that already defines its own padding is left untouched.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def __len__(self):
        """Return the number of Q&A records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for record ``idx``.

        All three are 1-D tensors of length ``max_length``. ``labels`` equals
        ``input_ids`` on attended positions and ``IGNORE_INDEX`` on padding.
        """
        item = self.data[idx]

        # Create input text
        question = item['question']
        answer = item['answer']

        # Format for causal language modeling. The explicit EOS marks where the
        # answer ends; because padding is masked out of the labels below, this
        # is the only end-of-answer target the model sees.
        eos = self.tokenizer.eos_token or ""
        text = f"Question: {question}\nAnswer: {answer}{eos}"

        # Tokenize
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Drop only the batch dimension; a bare squeeze() would also collapse
        # the sequence dimension when max_length == 1.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # masked_fill returns a new tensor, so labels never alias input_ids,
        # and padded positions no longer contribute to the loss.
        labels = input_ids.masked_fill(attention_mask == 0, self.IGNORE_INDEX)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

# Tokenizer belonging to the checkpoint that will be fine-tuned
model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Wrap both splits (train first, then test); they share the one tokenizer instance
train_dataset, test_dataset = (
    MedicalQADataset(split, tokenizer) for split in (train_data, test_data)
)

# Report how many examples each split holds
print("Created datasets:")
print(f"Train: {len(train_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")


## 3. Initialize Model and Training Setup


In [None]:
# Initialize the model
model = AutoModelForCausalLM.from_pretrained(model_name)
# Keep the embedding matrix in sync with the tokenizer's vocabulary size
model.resize_token_embeddings(len(tokenizer))

# Move model to device
model = model.to(device)

print(f"Model loaded: {model_name}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# The evaluation-schedule argument was renamed from `evaluation_strategy` to
# `eval_strategy` in newer transformers releases, and the old spelling is
# rejected by recent versions. Pick whichever name the installed version
# defines so this cell runs on both.
eval_strategy_arg = (
    "eval_strategy"
    if hasattr(TrainingArguments, "eval_strategy")
    else "evaluation_strategy"
)

# Training arguments
training_args = TrainingArguments(
    output_dir="../data/models/medical_model",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir="../logs",
    logging_steps=10,
    eval_steps=100,
    # Checkpoints are written on the same 100-step cadence as evaluation so
    # the best checkpoint (lowest eval_loss) can be reloaded at the end.
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    **{eval_strategy_arg: "steps"},
)


## 4. Train the Model


In [None]:
# Build the Trainer from the model, its hyperparameters and both dataset splits.
# The test split doubles as the evaluation set that runs during training.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

# Run the fine-tuning loop, announcing start and finish
print("Starting training...")
trainer.train()

print("Training completed!")


## 5. Evaluate and Test the Model


In [None]:
# Run evaluation on the Trainer's eval dataset and list every reported metric
eval_results = trainer.evaluate()
print("Evaluation results:")
for key in eval_results:
    value = eval_results[key]
    # Four decimal places for every metric
    print(f"{key}: {value:.4f}")

# Test the model with sample questions
def generate_response(question, max_length=200):
    """Sample an answer to ``question`` from the notebook's fine-tuned model.

    The question is placed in the same Question/Answer prompt layout used for
    training, and the module-level ``tokenizer``, ``model`` and ``device`` are
    used for encoding, generation and tensor placement.

    Args:
        question: Question text to insert into the prompt.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
            NOTE(review): in transformers this limit counts the prompt tokens
            as well as the generated ones — confirm that is the intent.

    Returns:
        The generated answer with surrounding whitespace stripped. Only text
        produced after the prompt is considered, and anything from a
        subsequent ``Question:`` marker onwards is discarded.
    """
    prompt = f"Question: {question}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Inference only: make sure training-mode layers such as dropout are off.
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode just the newly generated tokens. Splitting the full decoded text
    # on "Answer:" and taking the last piece returns the wrong text whenever
    # the model continues with another Question/Answer turn.
    prompt_length = inputs["input_ids"].shape[1]
    continuation = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )

    # Keep only the answer to the question that was asked: drop any follow-up
    # "Question: ..." turn the model may have invented.
    return continuation.split("Question:")[0].strip()

# A few hand-written prompts for a quick qualitative check of the model
test_questions = [
    "What disease causes fever and cough?",
    "I have itching and skin rash, what could it be?",
    "What condition is associated with chills and fever?"
]

print("\nTesting model with sample questions:")
for question in test_questions:
    response = generate_response(question)
    # One Q line, one A line, then a 50-character rule between samples
    print(f"Q: {question}", f"A: {response}", "-" * 50, sep="\n")
