In [1]:
print("="*80)
print("STEP 0: Installing Required Packages")
print("="*80)

!pip install -q transformers datasets torch accelerate sentencepiece
!pip install -q rouge-score bert-score
!pip install -q openai anthropic
!pip install -q faiss-cpu sentence-transformers
!pip install -q pandas numpy matplotlib seaborn plotly
!pip install -q scikit-learn openpyxl
!pip install -q huggingface_hub
!pip install -q kaggle  # For downloading datasets from Kaggle

print("✓ Installation complete!")


STEP 0: Installing Required Packages
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for rouge-score (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m388.2/388.2 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.6/23.6 MB[0m [31m60.5 MB/s[0m eta [36m0:00:00[0m
[?25h✓ Installation complete!


In [2]:
import os
import json
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
from datasets import Dataset, DatasetDict, load_dataset
from rouge_score import rouge_scorer
from bert_score import score as bert_score
from sentence_transformers import SentenceTransformer
import faiss

print("✓ All imports successful!")

✓ All imports successful!


In [6]:
print("\n" + "="*80)
print("STEP 1: Loading Actual Clinical Dialogue Datasets")
print("="*80)

class DatasetLoader:
    """
    Utility class to download and load clinical dialogue datasets.
    """

    def __init__(self):
        self.dataset_dir = "./clinical_datasets"
        os.makedirs(self.dataset_dir, exist_ok=True)

    def load_mtsamples(self):
        """
        Load MT-Samples dataset - Medical transcription samples.
        This dataset is available on Kaggle and Hugging Face.
        """
        print("\n--- Loading MT-Samples Dataset ---")

        try:
            # Method 1: Load from Hugging Face
            print("Attempting to load from Hugging Face...")
            dataset = load_dataset("keivalya/MedQuad-MedicalQnADataset")
            print(f"✓ Loaded {len(dataset['train'])} medical dialogues")
            return dataset
        except Exception as e:
            print(f"Hugging Face loading failed: {e}")

            # Method 2: Manual download instructions
            print("\nManual Download Instructions:")
            print("1. Visit: https://www.kaggle.com/datasets/tboyle10/medicaltranscriptions")
            print("2. Download the dataset")
            print("3. Upload to Colab or place in ./clinical_datasets/")

            # Try to load from local CSV if exists
            csv_path = os.path.join(self.dataset_dir, "mtsamples.csv")
            if os.path.exists(csv_path):
                print(f"✓ Found local file: {csv_path}")
                df = pd.read_csv(csv_path)
                return df

            return None

    def load_medDialog(self):
        """
        Load MedDialog dataset - English medical conversations.
        """
        print("\n--- Loading MedDialog Dataset ---")

        try:
            # Available on Hugging Face
            dataset = load_dataset("medical_dialog", "en")
            print(f"✓ Loaded MedDialog dataset")
            print(f"  Train size: {len(dataset['train'])}")
            print(f"  Test size: {len(dataset['test'])}")
            return dataset
        except Exception as e:
            print(f"Error loading MedDialog: {e}")
            print("\nAlternative: Visit https://github.com/UCSD-AI4H/Medical-Dialogue-System")
            return None

    def load_healthcaremagic(self):
        """
        Load HealthCareMagic-100k dataset.
        """
        print("\n--- Loading HealthCareMagic Dataset ---")

        try:
            dataset = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k")
            print(f"✓ Loaded HealthCareMagic dataset: {len(dataset['train'])} conversations")
            return dataset
        except Exception as e:
            print(f"Error: {e}")
            return None

    def load_from_kaggle(self, dataset_name: str):
        """
        Load dataset from Kaggle.

        Setup instructions:
        1. Go to https://www.kaggle.com/settings/account
        2. Create API token (downloads kaggle.json)
        3. Upload kaggle.json to Colab
        """
        print(f"\n--- Loading {dataset_name} from Kaggle ---")

        # Check if kaggle.json exists
        kaggle_json_path = os.path.expanduser("~/.kaggle/kaggle.json")

        if not os.path.exists(kaggle_json_path):
            print("\n⚠️  Kaggle API credentials not found!")
            print("\nSetup Instructions:")
            print("1. Go to https://www.kaggle.com/settings/account")
            print("2. Click 'Create New API Token'")
            print("3. Upload the downloaded kaggle.json file:")
            print("\n   # Run this in a Colab cell:")
            print("   from google.colab import files")
            print("   uploaded = files.upload()")
            print("   !mkdir -p ~/.kaggle")
            print("   !cp kaggle.json ~/.kaggle/")
            print("   !chmod 600 ~/.kaggle/kaggle.json")
            return None

        try:
            # Download dataset
            import kaggle
            kaggle.api.dataset_download_files(
                dataset_name,
                path=self.dataset_dir,
                unzip=True
            )
            print(f"✓ Downloaded {dataset_name}")
            return True
        except Exception as e:
            print(f"Error: {e}")
            return None

    def create_custom_dataset_from_files(self, file_path: str):
        """
        Load custom clinical dialogue dataset from CSV/JSON.
        """
        print(f"\n--- Loading Custom Dataset from {file_path} ---")

        if not os.path.exists(file_path):
            print(f"❌ File not found: {file_path}")
            return None

        ext = os.path.splitext(file_path)[1].lower()

        try:
            if ext == '.csv':
                df = pd.read_csv(file_path)
            elif ext == '.json':
                df = pd.read_json(file_path)
            elif ext in ['.xlsx', '.xls']:
                df = pd.read_excel(file_path)
            else:
                print(f"❌ Unsupported file format: {ext}")
                return None

            print(f"✓ Loaded {len(df)} records")
            print(f"✓ Columns: {list(df.columns)}")
            return df
        except Exception as e:
            print(f"Error loading file: {e}")
            return None

# Initialize loader
loader = DatasetLoader()


STEP 1: Loading Actual Clinical Dialogue Datasets


In [7]:
print("\n" + "="*80)
print("Attempting to load available datasets...")
print("="*80)

# Try loading different datasets
datasets_dict = {}

# 1. Try MedDialog
med_dialog = loader.load_medDialog()
if med_dialog:
    datasets_dict['meddialog'] = med_dialog

# 2. Try HealthCareMagic
healthcare_magic = loader.load_healthcaremagic()
if healthcare_magic:
    datasets_dict['healthcaremagic'] = healthcare_magic

# 3. Try MT-Samples
mt_samples = loader.load_mtsamples()
if mt_samples:
    datasets_dict['mtsamples'] = mt_samples

print(f"\n✓ Successfully loaded {len(datasets_dict)} datasets")


Attempting to load available datasets...

--- Loading MedDialog Dataset ---


README.md: 0.00B [00:00, ?B/s]

medical_dialog.py: 0.00B [00:00, ?B/s]

Error loading MedDialog: Dataset scripts are no longer supported, but found medical_dialog.py

Alternative: Visit https://github.com/UCSD-AI4H/Medical-Dialogue-System

--- Loading HealthCareMagic Dataset ---


README.md:   0%|          | 0.00/542 [00:00<?, ?B/s]

data/train-00000-of-00001-5e7cb295b9cff0(…):   0%|          | 0.00/70.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/112165 [00:00<?, ? examples/s]

✓ Loaded HealthCareMagic dataset: 112165 conversations

--- Loading MT-Samples Dataset ---
Attempting to load from Hugging Face...


README.md:   0%|          | 0.00/233 [00:00<?, ?B/s]

medDataset_processed.csv:   0%|          | 0.00/22.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/16407 [00:00<?, ? examples/s]

✓ Loaded 16407 medical dialogues

✓ Successfully loaded 2 datasets


In [13]:
print("\n" + "="*80)
print("[Step 2] LOADING REAL CLINICAL DATASETS")
print("="*80)

# Load HealthCareMagic Dataset
print("\n--- Loading HealthCareMagic Dataset ---")
try:
    healthcaremagic = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k")
    print(f"✓ Loaded HealthCareMagic: {len(healthcaremagic['train'])} conversations")
    hcm_available = True
except Exception as e:
    print(f"Error: {e}")
    hcm_available = False

# Load MT-Samples Dataset
print("\n--- Loading MT-Samples Dataset ---")
try:
    mtsamples = load_dataset("keivalya/MedQuad-MedicalQnADataset")
    print(f"✓ Loaded MT-Samples: {len(mtsamples['train'])} dialogues")
    mts_available = True
except Exception as e:
    print(f"Error: {e}")
    mts_available = False


[Step 2] LOADING REAL CLINICAL DATASETS

--- Loading HealthCareMagic Dataset ---
✓ Loaded HealthCareMagic: 112165 conversations

--- Loading MT-Samples Dataset ---
✓ Loaded MT-Samples: 16407 dialogues


In [14]:
print("\n" + "="*80)
print("[Step 3] EXPLORING DATASET STRUCTURE")
print("="*80)

if hcm_available:
    print("\n--- HealthCareMagic Dataset Structure ---")
    print(f"Features: {healthcaremagic['train'].features}")
    print(f"\nSample conversation:")
    sample = healthcaremagic['train'][0]
    for key, value in sample.items():
        print(f"{key}: {str(value)[:200]}...")

if mts_available:
    print("\n--- MT-Samples Dataset Structure ---")
    print(f"Features: {mtsamples['train'].features}")
    print(f"\nSample dialogue:")
    sample = mtsamples['train'][0]
    for key, value in sample.items():
        print(f"{key}: {str(value)[:200]}...")



[Step 3] EXPLORING DATASET STRUCTURE

--- HealthCareMagic Dataset Structure ---
Features: {'instruction': Value('string'), 'input': Value('string'), 'output': Value('string')}

Sample conversation:
instruction: If you are a doctor, please answer the medical questions based on the patient's description....
input: I woke up this morning feeling the whole room is spinning when i was sitting down. I went to the bathroom walking unsteadily, as i tried to focus i feel nauseous. I try to vomit but it wont come out.....
output: Hi, Thank you for posting your query. The most likely cause for your symptoms is benign paroxysmal positional vertigo (BPPV), a type of peripheral vertigo. In this condition, the most common symptom i...

--- MT-Samples Dataset Structure ---
Features: {'qtype': Value('string'), 'Question': Value('string'), 'Answer': Value('string')}

Sample dialogue:
qtype: susceptibility...
Question: Who is at risk for Lymphocytic Choriomeningitis (LCM)? ?...
Answer: LCMV infections c

In [15]:
print("\n" + "="*80)
print("[Step 4] PREPARING DATA FOR ANNOTATION")
print("="*80)

def prepare_healthcaremagic_data(dataset, num_samples=100):
    """
    Extract dialogues from HealthCareMagic dataset.
    Format: Patient question + Doctor response
    """
    print(f"\nPreparing {num_samples} samples from HealthCareMagic...")

    data = []
    for i in range(min(num_samples, len(dataset['train']))):
        item = dataset['train'][i]

        # Extract instruction (patient question) and output (doctor response)
        if 'instruction' in item and 'output' in item:
            dialogue = f"Patient: {item['instruction']}\n\nDoctor: {item['output']}"

            data.append({
                'dialogue_id': i,
                'dialogue': dialogue,
                'source': 'HealthCareMagic'
            })

    df = pd.DataFrame(data)
    print(f"✓ Prepared {len(df)} HealthCareMagic dialogues")
    return df

def prepare_mtsamples_data(dataset, num_samples=100):
    """
    Extract dialogues from MT-Samples dataset.
    """
    print(f"\nPreparing {num_samples} samples from MT-Samples...")

    data = []
    for i in range(min(num_samples, len(dataset['train']))):
        item = dataset['train'][i]

        # Extract questions and answers
        if 'Answer' in item and 'Question' in item:
            dialogue = f"Patient: {item['Question']}\n\nDoctor: {item['Answer']}"

            data.append({
                'dialogue_id': i + 10000,  # Offset IDs
                'dialogue': dialogue,
                'source': 'MT-Samples'
            })

    df = pd.DataFrame(data)
    print(f"✓ Prepared {len(df)} MT-Samples dialogues")
    return df

# Prepare datasets
all_dialogues = []

if hcm_available:
    hcm_df = prepare_healthcaremagic_data(healthcaremagic, num_samples=150)
    all_dialogues.append(hcm_df)

if mts_available:
    mts_df = prepare_mtsamples_data(mtsamples, num_samples=150)
    all_dialogues.append(mts_df)

# Combine all dialogues
combined_df = pd.concat(all_dialogues, ignore_index=True)
print(f"\n✓ Total dialogues prepared: {len(combined_df)}")

# Display sample
print("\n--- SAMPLE DIALOGUE ---")
print(combined_df.iloc[0]['dialogue'])
print("="*80)



[Step 4] PREPARING DATA FOR ANNOTATION

Preparing 150 samples from HealthCareMagic...
✓ Prepared 150 HealthCareMagic dialogues

Preparing 150 samples from MT-Samples...
✓ Prepared 150 MT-Samples dialogues

✓ Total dialogues prepared: 300

--- SAMPLE DIALOGUE ---
Patient: If you are a doctor, please answer the medical questions based on the patient's description.

Doctor: Hi, Thank you for posting your query. The most likely cause for your symptoms is benign paroxysmal positional vertigo (BPPV), a type of peripheral vertigo. In this condition, the most common symptom is dizziness or giddiness, which is made worse with movements. Accompanying nausea and vomiting are common. The condition is due to problem in the ear, and improves in a few days on own. Betahistine tablets would help relieve your symptoms. Doing vestibular rehabilitation or adaptation exercises would prevent the recurrence of these symptoms. An ENT evaluation would also help. I hope it helps. Best wishes, Chat Doctor.


In [16]:
print("\n" + "="*80)
print("[Step 5] CREATING ANNOTATION TEMPLATE")
print("="*80)

def create_annotation_template(df, output_file='annotation_template.csv',
                               num_samples=50):
    """
    Create a CSV template for manual annotation.
    """
    # Select subset for annotation
    annotation_df = df.head(num_samples).copy()

    # Add annotation columns
    annotation_df['symptoms'] = ''
    annotation_df['assessment'] = ''
    annotation_df['treatment'] = ''
    annotation_df['annotator_name'] = ''
    annotation_df['notes'] = ''

    # Save to CSV
    annotation_df.to_csv(output_file, index=False)

    print(f"✓ Annotation template created: {output_file}")
    print(f"✓ Number of dialogues to annotate: {len(annotation_df)}")
    print(f"\nColumns in template:")
    print(f"  - dialogue_id: Unique identifier")
    print(f"  - dialogue: The conversation to summarize")
    print(f"  - symptoms: [TO FILL] Patient complaints")
    print(f"  - assessment: [TO FILL] Diagnosis/findings")
    print(f"  - treatment: [TO FILL] Prescribed care")
    print(f"  - annotator_name: [TO FILL] Your name")
    print(f"  - notes: [OPTIONAL] Any additional notes")

    return annotation_df

# Create template for annotation
annotation_template = create_annotation_template(
    combined_df,
    'clinical_annotation_template.csv',
    num_samples=50
)

# Download option for Colab
print("\n💾 To download the template in Google Colab, run:")
print("from google.colab import files")
print("files.download('clinical_annotation_template.csv')")



[Step 5] CREATING ANNOTATION TEMPLATE
✓ Annotation template created: clinical_annotation_template.csv
✓ Number of dialogues to annotate: 50

Columns in template:
  - dialogue_id: Unique identifier
  - dialogue: The conversation to summarize
  - symptoms: [TO FILL] Patient complaints
  - assessment: [TO FILL] Diagnosis/findings
  - treatment: [TO FILL] Prescribed care
  - annotator_name: [TO FILL] Your name
  - notes: [OPTIONAL] Any additional notes

💾 To download the template in Google Colab, run:
from google.colab import files
files.download('clinical_annotation_template.csv')


In [18]:
from google.colab import files
files.download('clinical_annotation_template.csv')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [19]:
print("\n" + "="*80)
print("[Step 6] LOADING COMPLETED ANNOTATIONS")
print("="*80)

def load_annotations(file_path='clinical_annotation_template.csv'):
    """
    Load completed annotations and create reference summaries.
    """
    print(f"Loading annotations from: {file_path}")

    try:
        df = pd.read_csv(file_path)

        # Check if annotations are complete
        required_cols = ['symptoms', 'assessment', 'treatment']

        # Filter completed annotations (non-empty)
        completed = df[
            df['symptoms'].notna() &
            (df['symptoms'].str.strip() != '') &
            df['assessment'].notna() &
            (df['assessment'].str.strip() != '') &
            df['treatment'].notna() &
            (df['treatment'].str.strip() != '')
        ].copy()

        print(f"✓ Total annotations: {len(df)}")
        print(f"✓ Completed annotations: {len(completed)}")
        print(f"✓ Incomplete annotations: {len(df) - len(completed)}")

        if len(completed) > 0:
            # Create reference summaries
            completed['reference_summary'] = completed.apply(
                lambda row: f"Symptoms: {row['symptoms']}\n\n"
                           f"Assessment: {row['assessment']}\n\n"
                           f"Treatment: {row['treatment']}",
                axis=1
            )

            print("\n✓ Reference summaries created!")
            return completed
        else:
            print("\n⚠️  No completed annotations found!")
            print("Please fill in the template first.")
            return None

    except FileNotFoundError:
        print(f"\n❌ File not found: {file_path}")
        print("Please annotate the template first and upload it.")
        return None

# For demonstration, create some mock annotations
print("\n--- Creating Mock Annotations for Demonstration ---")

# Take first 20 dialogues and create simple summaries
mock_annotations = combined_df.head(20).copy()

# Simple rule-based summary generation (for demo only)
def create_mock_summary(dialogue):
    """Create a simple mock summary for demonstration."""
    return f"""Symptoms: Patient presents with medical concerns as described in dialogue.

Assessment: Clinical evaluation completed based on patient history and symptoms.

Treatment: Appropriate medical advice and treatment plan provided."""

mock_annotations['symptoms'] = 'Patient medical concerns from dialogue'
mock_annotations['assessment'] = 'Clinical evaluation based on presentation'
mock_annotations['treatment'] = 'Medical advice and treatment provided'
mock_annotations['reference_summary'] = mock_annotations['dialogue'].apply(create_mock_summary)

print(f"✓ Created {len(mock_annotations)} mock annotations for demonstration")


[Step 6] LOADING COMPLETED ANNOTATIONS

--- Creating Mock Annotations for Demonstration ---
✓ Created 20 mock annotations for demonstration


In [20]:


# ============================================================================
# STEP 8: TRAIN THE MODEL
# ============================================================================

print("\n" + "="*80)
print("[Step 7] TRAINING THE CLINICAL SUMMARIZATION MODEL")
print("="*80)

class ClinicalSummarizationModel:
    """
    Fine-tuned model for clinical dialogue summarization.
    """

    def __init__(self, model_name="facebook/bart-base"):
        """
        Initialize model and tokenizer.

        Recommended models:
        - facebook/bart-base (139M params - faster, good for testing)
        - facebook/bart-large (406M params - better quality)
        - google/flan-t5-base (250M params)
        - google/flan-t5-large (780M params)
        """
        print(f"\nInitializing model: {model_name}")
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        # Move to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        print(f"✓ Model loaded")
        print(f"✓ Device: {self.device}")
        print(f"✓ Model parameters: {sum(p.numel() for p in self.model.parameters()):,}")

    def prepare_datasets(self, df, test_size=0.2):
        """
        Prepare train/test datasets from annotated data.
        """
        print(f"\n--- Preparing Datasets ---")
        print(f"Total samples: {len(df)}")

        # Shuffle data
        df = df.sample(frac=1, random_state=42).reset_index(drop=True)

        # Split
        split_idx = int(len(df) * (1 - test_size))
        train_df = df[:split_idx]
        test_df = df[split_idx:]

        print(f"Train samples: {len(train_df)}")
        print(f"Test samples: {len(test_df)}")

        # Convert to HuggingFace Dataset
        train_dataset = Dataset.from_pandas(train_df[['dialogue', 'reference_summary']])
        test_dataset = Dataset.from_pandas(test_df[['dialogue', 'reference_summary']])

        # Tokenize
        def preprocess(examples):
            # Add task prefix
            inputs = ["summarize: " + dialogue for dialogue in examples['dialogue']]
            targets = examples['reference_summary']

            # Tokenize inputs
            model_inputs = self.tokenizer(
                inputs,
                max_length=512,
                truncation=True,
                padding='max_length'
            )

            # Tokenize targets
            labels = self.tokenizer(
                targets,
                max_length=256,
                truncation=True,
                padding='max_length'
            )

            model_inputs['labels'] = labels['input_ids']
            return model_inputs

        print("\nTokenizing datasets...")
        train_dataset = train_dataset.map(preprocess, batched=True, remove_columns=['dialogue', 'reference_summary'])
        test_dataset = test_dataset.map(preprocess, batched=True, remove_columns=['dialogue', 'reference_summary'])

        print("✓ Datasets prepared and tokenized")

        return train_dataset, test_dataset, train_df, test_df

    def train(self, train_dataset, eval_dataset,
              output_dir='./clinical_summarization_model',
              num_epochs=3,
              batch_size=4,
              learning_rate=5e-5):
        """
        Train the model.
        """
        print(f"\n{'='*80}")
        print("STARTING MODEL TRAINING")
        print(f"{'='*80}")

        print(f"\nTraining Configuration:")
        print(f"  Model: {self.model_name}")
        print(f"  Epochs: {num_epochs}")
        print(f"  Batch size: {batch_size}")
        print(f"  Learning rate: {learning_rate}")
        print(f"  Output directory: {output_dir}")

        # Training arguments
        training_args = Seq2SeqTrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            learning_rate=learning_rate,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir='./logs',
            logging_steps=50,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            predict_with_generate=True,
            generation_max_length=256,
            fp16=torch.cuda.is_available(),
            report_to="none",
            save_total_limit=2,
        )

        # Data collator
        data_collator = DataCollatorForSeq2Seq(
            self.tokenizer,
            model=self.model,
            padding=True
        )

        # Initialize trainer
        trainer = Seq2SeqTrainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            tokenizer=self.tokenizer,
        )

        # Train
        print("\n🚀 Starting training...")
        print("This may take some time depending on your hardware...\n")

        trainer.train()

        print("\n✓ Training complete!")

        # Save model
        print(f"\nSaving model to {output_dir}...")
        trainer.save_model(output_dir)
        self.tokenizer.save_pretrained(output_dir)
        print("✓ Model saved!")

        return trainer

    def generate_summary(self, dialogue, max_length=256, num_beams=4):
        """
        Generate a summary for a single dialogue.
        """
        # Prepare input
        input_text = "summarize: " + dialogue
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True
        ).to(self.device)

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=num_beams,
                length_penalty=2.0,
                early_stopping=True
            )

        # Decode
        summary = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return summary

    def batch_generate(self, dialogues, max_length=256):
        """
        Generate summaries for multiple dialogues.
        """
        summaries = []
        for dialogue in tqdm(dialogues, desc="Generating summaries"):
            summary = self.generate_summary(dialogue, max_length)
            summaries.append(summary)
        return summaries


[Step 7] TRAINING THE CLINICAL SUMMARIZATION MODEL


In [None]:
print("\n" + "="*80)
print("[Step 8] EXECUTING MODEL TRAINING")
print("="*80)

# Initialize model (using smaller model for faster training)
# Change to "facebook/bart-large" for better quality
model = ClinicalSummarizationModel("facebook/bart-base")

# Prepare datasets using mock annotations
train_dataset, test_dataset, train_df, test_df = model.prepare_datasets(
    mock_annotations,
    test_size=0.2
)

# Train the model
print("\n⚠️  Training will now begin. This may take 15-30 minutes on GPU.")
print("On CPU, this could take several hours.")
print("\nTo skip training for now, comment out the next line.\n")

# UNCOMMENT TO ACTUALLY TRAIN:
trainer = model.train(
    train_dataset,
    test_dataset,
    num_epochs=3,  # Increase for better results
    batch_size=4,   # Decrease if you run out of memory
)


[Step 8] EXECUTING MODEL TRAINING

Initializing model: facebook/bart-base


config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

✓ Model loaded
✓ Device: cpu
✓ Model parameters: 139,420,416

--- Preparing Datasets ---
Total samples: 20
Train samples: 16
Test samples: 4

Tokenizing datasets...


Map:   0%|          | 0/16 [00:00<?, ? examples/s]

Map:   0%|          | 0/4 [00:00<?, ? examples/s]

✓ Datasets prepared and tokenized

⚠️  Training will now begin. This may take 15-30 minutes on GPU.
On CPU, this could take several hours.

To skip training for now, comment out the next line.


STARTING MODEL TRAINING

Training Configuration:
  Model: facebook/bart-base
  Epochs: 3
  Batch size: 4
  Learning rate: 5e-05
  Output directory: ./clinical_summarization_model

🚀 Starting training...
This may take some time depending on your hardware...



Epoch,Training Loss,Validation Loss
1,No log,15.063651
2,No log,14.815187


In [22]:
print("\n" + "="*80)
print("[Step 9] GENERATING SUMMARIES")
print("="*80)

# Test on a few examples
test_dialogues = test_df['dialogue'].head(5).tolist()
test_references = test_df['reference_summary'].head(5).tolist()

print("\nGenerating summaries for test examples...")
generated_summaries = model.batch_generate(test_dialogues)

# Display results
for i in range(min(3, len(test_dialogues))):
    print(f"\n{'='*80}")
    print(f"EXAMPLE {i+1}")
    print(f"{'='*80}")
    print(f"\nDIALOGUE:")
    print(test_dialogues[i][:300] + "...")
    print(f"\nGENERATED SUMMARY:")
    print(generated_summaries[i])
    print(f"\nREFERENCE SUMMARY:")
    print(test_references[i])
    print(f"{'='*80}")


[Step 9] GENERATING SUMMARIES

Generating summaries for test examples...


Generating summaries:   0%|          | 0/4 [00:00<?, ?it/s]


EXAMPLE 1

DIALOGUE:
Patient: If you are a doctor, please answer the medical questions based on the patient's description.

Doctor: Hello. It could be a blood collection due to minor injury or a vein rupture which is also common at this age. It is not an emergency, but you should apply compression bandage and warm compr...

GENERATED SUMMARY:
summarize: Patient: If you are a doctor, please answer the medical questions based on the patient's description. _______________________________________________________________________________Doctor: Hello. It could be a blood collection due to minor injury or a vein rupture which is also common at this age. It is not an emergency, but you should apply compression bandage and warm compresses if six hours have past. Furthermore, it should get relieved over the next few days but if it continues to increase or persist then you should see a Doctor who can examine the patient. Take care. _______________________________________________ ________________

In [23]:
print("\n" + "="*80)
print("[Step 10] EVALUATING MODEL PERFORMANCE")
print("="*80)

class Evaluator:
    """Evaluation metrics for summarization."""

    def __init__(self):
        self.rouge_scorer = rouge_scorer.RougeScorer(
            ['rouge1', 'rouge2', 'rougeL'],
            use_stemmer=True
        )

    def compute_rouge(self, predictions, references):
        """Compute ROUGE scores."""
        scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}

        for pred, ref in zip(predictions, references):
            result = self.rouge_scorer.score(ref, pred)
            scores['rouge1'].append(result['rouge1'].fmeasure)
            scores['rouge2'].append(result['rouge2'].fmeasure)
            scores['rougeL'].append(result['rougeL'].fmeasure)

        return {k: np.mean(v) for k, v in scores.items()}

    def compute_bertscore(self, predictions, references):
        """Compute BERTScore."""
        P, R, F1 = bert_score(predictions, references, lang='en', verbose=False)

        return {
            'precision': P.mean().item(),
            'recall': R.mean().item(),
            'f1': F1.mean().item()
        }

    def evaluate(self, predictions, references):
        """Run all metrics."""
        print("Computing ROUGE scores...")
        rouge_scores = self.compute_rouge(predictions, references)

        print("Computing BERTScore...")
        bert_scores = self.compute_bertscore(predictions, references)

        # Print results
        print("\n" + "="*80)
        print("EVALUATION RESULTS")
        print("="*80)

        print("\nROUGE Scores:")
        for metric, score in rouge_scores.items():
            print(f"  {metric.upper()}: {score:.4f}")

        print("\nBERTScore:")
        for metric, score in bert_scores.items():
            print(f"  {metric.capitalize()}: {score:.4f}")

        print("="*80)

        return {'rouge': rouge_scores, 'bertscore': bert_scores}

# Evaluate the model
evaluator = Evaluator()

# Generate summaries for all test examples
print("\nGenerating summaries for all test examples...")
all_test_dialogues = test_df['dialogue'].tolist()
all_test_references = test_df['reference_summary'].tolist()
all_generated = model.batch_generate(all_test_dialogues)

# Evaluate
results = evaluator.evaluate(all_generated, all_test_references)



[Step 10] EVALUATING MODEL PERFORMANCE

Generating summaries for all test examples...


Generating summaries:   0%|          | 0/4 [00:00<?, ?it/s]

Computing ROUGE scores...
Computing BERTScore...


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



EVALUATION RESULTS

ROUGE Scores:
  ROUGE1: 0.0955
  ROUGE2: 0.0133
  ROUGEL: 0.0842

BERTScore:
  Precision: 0.7952
  Recall: 0.8431
  F1: 0.8184


In [24]:
print("\n" + "="*80)
print("[Step 11] SAVING RESULTS")
print("="*80)

# Create results dataframe
results_df = test_df.copy()
results_df['generated_summary'] = all_generated

# Save to CSV
results_df.to_csv('model_results.csv', index=False)
print("✓ Results saved to 'model_results.csv'")

# Save evaluation metrics
eval_df = pd.DataFrame({
    'Metric': list(results['rouge'].keys()) + ['BERT-' + k for k in results['bertscore'].keys()],
    'Score': list(results['rouge'].values()) + list(results['bertscore'].values())
})
eval_df.to_csv('evaluation_metrics.csv', index=False)
print("✓ Metrics saved to 'evaluation_metrics.csv'")



[Step 11] SAVING RESULTS
✓ Results saved to 'model_results.csv'
✓ Metrics saved to 'evaluation_metrics.csv'


In [26]:
from google.colab import files
files.download('model_results.csv')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [8]:
print("\n" + "="*80)
print("STEP 2: Data Preprocessing and Exploration")
print("="*80)

class DataPreprocessor:
    """
    Preprocess and standardize clinical dialogue datasets.
    """

    def __init__(self):
        pass

    def explore_dataset(self, dataset):
        """
        Explore and understand dataset structure.
        """
        print("\n--- Dataset Exploration ---")

        if isinstance(dataset, pd.DataFrame):
            print(f"Shape: {dataset.shape}")
            print(f"\nColumns: {list(dataset.columns)}")
            print(f"\nFirst few rows:")
            print(dataset.head())
            print(f"\nData types:")
            print(dataset.dtypes)
            print(f"\nMissing values:")
            print(dataset.isnull().sum())
        else:
            # Hugging Face dataset
            print(f"Dataset type: {type(dataset)}")
            if hasattr(dataset, 'keys'):
                print(f"Splits: {list(dataset.keys())}")
                for split in dataset.keys():
                    print(f"\n{split} split:")
                    print(f"  Size: {len(dataset[split])}")
                    print(f"  Features: {dataset[split].features}")
                    print(f"  Example: {dataset[split][0]}")

    def standardize_format(self, dataset, dialogue_col: str,
                          summary_col: str = None) -> pd.DataFrame:
        """
        Standardize dataset to required format:
        - dialogue: the conversation text
        - reference_summary: ground truth summary (if available)
        """
        print("\n--- Standardizing Dataset Format ---")

        if isinstance(dataset, pd.DataFrame):
            df = dataset.copy()
        else:
            # Convert Hugging Face dataset to DataFrame
            df = pd.DataFrame(dataset['train'])

        # Rename columns to standard format
        standardized = pd.DataFrame()
        standardized['dialogue'] = df[dialogue_col]

        if summary_col and summary_col in df.columns:
            standardized['reference_summary'] = df[summary_col]
        else:
            standardized['reference_summary'] = None  # Will need annotation

        print(f"✓ Standardized {len(standardized)} dialogues")
        return standardized

    def clean_text(self, text: str) -> str:
        """
        Clean and normalize text.
        """
        if pd.isna(text):
            return ""

        # Remove extra whitespace
        text = ' '.join(text.split())

        # Basic cleaning
        text = text.strip()

        return text

    def filter_by_length(self, df: pd.DataFrame,
                        min_words: int = 50,
                        max_words: int = 1000) -> pd.DataFrame:
        """
        Filter dialogues by word count.
        """
        print(f"\n--- Filtering by length ({min_words}-{max_words} words) ---")

        df['word_count'] = df['dialogue'].apply(lambda x: len(str(x).split()))

        filtered = df[
            (df['word_count'] >= min_words) &
            (df['word_count'] <= max_words)
        ].copy()

        print(f"Original: {len(df)} dialogues")
        print(f"Filtered: {len(filtered)} dialogues")
        print(f"Removed: {len(df) - len(filtered)} dialogues")

        return filtered

# Initialize preprocessor
preprocessor = DataPreprocessor()

# If we have loaded datasets, explore them
if datasets_dict:
    first_dataset_name = list(datasets_dict.keys())[0]
    first_dataset = datasets_dict[first_dataset_name]
    print(f"\nExploring {first_dataset_name} dataset:")
    preprocessor.explore_dataset(first_dataset)


STEP 2: Data Preprocessing and Exploration

Exploring healthcaremagic dataset:

--- Dataset Exploration ---
Dataset type: <class 'datasets.dataset_dict.DatasetDict'>
Splits: ['train']

train split:
  Size: 112165
  Features: {'instruction': Value('string'), 'input': Value('string'), 'output': Value('string')}
  Example: {'instruction': "If you are a doctor, please answer the medical questions based on the patient's description.", 'input': 'I woke up this morning feeling the whole room is spinning when i was sitting down. I went to the bathroom walking unsteadily, as i tried to focus i feel nauseous. I try to vomit but it wont come out.. After taking panadol and sleep for few hours, i still feel the same.. By the way, if i lay down or sit down, my head do not spin, only when i want to move around then i feel the whole world is spinning.. And it is normal stomach discomfort at the same time? Earlier after i relieved myself, the spinning lessen so i am not sure whether its connected or c