
# MedGemma 1.5 Fine-tuning for COVID-19 Cough Detection
This notebook demonstrates the complete workflow for fine-tuning the MedGemma 1.5 (4B) model to detect COVID-19 from cough audio recordings.

## Workflow:
1.  **Setup**: Install dependencies.
2.  **Data Prep**: Convert audio to Mel Spectrograms.
3.  **EDA**: Visualize data distribution.
4.  **Training**: Fine-tune using LoRA.
5.  **Evaluation**: Compare performance.


## 1. Environment Setup

In [None]:

!pip install -q torch torchaudio transformers librosa matplotlib pillow pandas scikit-learn accelerate peft bitsandbytes
!pip install -q imageio-ffmpeg


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m29.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25h

In [None]:

import os
import sys
import glob
import json
import math
import argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
import torch
import librosa
from sklearn.model_selection import train_test_split
from transformers import AutoProcessor, AutoModelForVision2Seq, TrainingArguments, Trainer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix, accuracy_score

# Report the installed PyTorch build and choose the compute device used by later cells.
print(f"PyTorch Version: {torch.__version__}")
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")



## 2. Authentication
You need to authenticate with Hugging Face to access the gated MedGemma model.


In [None]:

from huggingface_hub import login

try:
    # google.colab only exists inside Colab; importing it here (instead of at the
    # top of the cell) lets the notebook fall back to interactive login elsewhere.
    from google.colab import userdata

    # Try to get the token from Colab secrets
    hf_token = userdata.get('HF_TOKEN')
    login(token=hf_token)
except Exception as exc:
    # Catch Exception rather than using a bare except so Ctrl-C still interrupts the cell.
    print(f"Colab secret login unavailable ({type(exc).__name__}); falling back to interactive login.")
    login()



## 3. Data Preparation
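Upload your dataset. This notebook assumes:
- A `data/` directory exists.
- Audio files (`.webm`, `.ogg`, or `.wav`) and their metadata JSON files (same base name, containing a `status` field) are in `data/input`.
- Alternatively, upload pre-processed spectrogram images to `data/processed_images` together with a matching `data/processed_metadata.csv` (columns `image_path` and `label`).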

The code below defines the preprocessing logic (Audio -> Mel Spectrogram).


In [None]:

# Audio to Mel Spectrogram Conversion Logic
def get_mel_spectrogram(file_path, target_sr=16000, n_mels=128, image_size=(224, 224)):
    try:
        # Load audio using librosa (returns numpy array)
        y, sr = librosa.load(file_path, sr=target_sr)
        
        # Compute Mel Spectrogram
        mel_spec = librosa.feature.melspectrogram(
            y=y, 
            sr=sr, 
            n_mels=n_mels
        )
        
        # Convert to dB
        mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)
        
        # Normalize to 0-255
        mel_spec_norm = (mel_spec_db - mel_spec_db.min()) / (mel_spec_db.max() - mel_spec_db.min()) * 255
        mel_spec_norm = mel_spec_norm.astype(np.uint8)
        
        # Convert to PIL Image
        img = Image.fromarray(mel_spec_norm, mode='L') # Grayscale
        img = img.convert("RGB") # Convert to RGB
        
        # Resize to target size
        img = img.resize(image_size)
        
        return img
    except Exception as e:
        print(f"Error processing {file_path}: {e}")
        return None

def preprocess_dataset(input_dir, output_dir):
    """Convert every labelled recording in `input_dir` to a spectrogram PNG.

    For each `<id>.json` in `input_dir` that has a non-empty `status` field and a
    sibling audio file (`<id>.webm`, `.ogg` or `.wav`), a spectrogram image is
    written to `output_dir/<id>.png`.

    Args:
        input_dir: Directory containing the metadata JSON files and audio files.
        output_dir: Directory the PNG images are written to (created if missing).

    Returns:
        Path of the CSV (columns `image_path`, `label`) written next to `output_dir`.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Sort so the CSV row order is reproducible across runs and filesystems.
    json_files = sorted(glob.glob(os.path.join(input_dir, '*.json')))
    print(f"Found {len(json_files)} files.")

    processed_data = []
    skipped = 0
    for json_file in tqdm(json_files):
        file_id = os.path.splitext(os.path.basename(json_file))[0]
        try:
            with open(json_file, 'r') as f:
                metadata = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Skipping {json_file}: unreadable metadata ({exc})")
            skipped += 1
            continue

        # Metadata that is not a JSON object cannot carry a status label.
        status = metadata.get('status') if isinstance(metadata, dict) else None
        if not status:
            skipped += 1
            continue

        # Find audio: first existing candidate in preference order
        candidates = (os.path.join(input_dir, file_id + ext) for ext in ['.webm', '.ogg', '.wav'])
        audio_path = next((p for p in candidates if os.path.exists(p)), None)
        if not audio_path:
            skipped += 1
            continue

        # Convert
        img = get_mel_spectrogram(audio_path)
        if img is None:
            skipped += 1
            continue

        out_path = os.path.join(output_dir, f"{file_id}.png")
        img.save(out_path)
        processed_data.append({
            'image_path': out_path,
            'label': status
        })

    # Explicit columns keep the CSV header even when nothing was processed,
    # so downstream pd.read_csv calls do not fail on an empty file.
    df = pd.DataFrame(processed_data, columns=['image_path', 'label'])
    # normpath strips a trailing slash so dirname() always yields the parent directory.
    csv_path = os.path.join(os.path.dirname(os.path.normpath(output_dir)), 'processed_metadata.csv')
    df.to_csv(csv_path, index=False)
    print(f"Processed {len(df)} recordings, skipped {skipped}.")
    print(f"Saved metadata to {csv_path}")
    return csv_path


In [None]:

# Configuration for Data
input_data_dir = "data/input" # Update this path
output_image_dir = "data/processed_images"

# Uncomment to run preprocessing if you have data uploaded
# preprocess_dataset(input_data_dir, output_image_dir)


## 4. Create Splits

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
import os
import argparse

class SplitConfig:
    """Settings for the train/test split (defined before `main` so Run All works)."""
    metadata = "data/processed_metadata.csv"  # CSV with `image_path` and `label` columns
    output_dir = "data"                        # where train/test CSVs are written
    test_size = 0.2                            # fraction of samples held out for testing
    seed = 42                                  # random_state for a reproducible split


def main():
    """Split the processed metadata into train and test CSV files.

    Reads `SplitConfig.metadata`, performs a seeded split (stratified by label
    when every class has at least two samples) and writes `train_metadata.csv`
    and `test_metadata.csv` to `SplitConfig.output_dir`. Prints a message and
    returns None without writing anything if the metadata file is missing.
    """
    args = SplitConfig()

    if not os.path.exists(args.metadata):
        print(f"Metadata file not found: {args.metadata}")
        return

    df = pd.read_csv(args.metadata)
    print(f"Total samples: {len(df)}")
    print("Label distribution:")
    label_counts = df['label'].value_counts()
    print(label_counts)

    # train_test_split raises when a stratification class has a single member,
    # so fall back to a plain random split in that case.
    can_stratify = len(label_counts) > 0 and label_counts.min() >= 2
    if not can_stratify:
        print("Warning: a label has fewer than 2 samples; splitting without stratification.")

    train_df, test_df = train_test_split(
        df,
        test_size=args.test_size,
        random_state=args.seed,
        stratify=df['label'] if can_stratify else None
    )

    print(f"Train samples: {len(train_df)}")
    print(f"Test samples: {len(test_df)}")

    os.makedirs(args.output_dir, exist_ok=True)
    train_path = os.path.join(args.output_dir, "train_metadata.csv")
    test_path = os.path.join(args.output_dir, "test_metadata.csv")

    train_df.to_csv(train_path, index=False)
    test_df.to_csv(test_path, index=False)

    print(f"Saved splits to {train_path} and {test_path}")

# Run the main logic
main()


## 5. Exploratory Data Analysis (EDA)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import argparse
from PIL import Image
import math

class EDAConfig:
    """Settings for the EDA cell (defined before `main` so Run All works)."""
    metadata = "data/processed_metadata.csv"  # CSV with `image_path` and `label` columns
    output_dir = "data/eda_results"           # where the plots are saved


def main():
    """Plot the label distribution and a grid of sample spectrograms.

    Saves `label_distribution.png` and `sample_spectrograms.png` into
    `EDAConfig.output_dir`. Prints a message and returns None early when the
    metadata file is missing or contains no labelled rows.
    """
    args = EDAConfig()

    if not os.path.exists(args.metadata):
        print(f"Metadata file not found: {args.metadata}")
        return

    os.makedirs(args.output_dir, exist_ok=True)

    df = pd.read_csv(args.metadata)
    print(f"Loaded {len(df)} records.")

    counts = df['label'].value_counts()
    if counts.empty:
        # Plotting an empty series / a 0-row subplot grid raises, so stop here.
        print("No labelled records found; nothing to plot.")
        return

    # 1. Label Distribution
    plt.figure(figsize=(10, 6))
    counts.plot(kind='bar', color=['skyblue', 'orange', 'red'])
    plt.title('Label Distribution')
    plt.xlabel('Label')
    plt.ylabel('Count')
    plt.xticks(rotation=45)
    for i, v in enumerate(counts):
        plt.text(i, v, str(v), ha='center', va='bottom')

    dist_plot_path = os.path.join(args.output_dir, "label_distribution.png")
    plt.tight_layout()
    plt.savefig(dist_plot_path)
    print(f"Saved distribution plot to {dist_plot_path}")
    plt.close()

    # 2. Sample Images Visualization
    unique_labels = df['label'].dropna().unique()
    samples_per_class = 3

    # One row per label, one column per sample
    n_cols = samples_per_class
    n_rows = len(unique_labels)

    # squeeze=False always returns a 2-D array of Axes, so axes[i][j] is valid
    # even when there is only a single label.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 4), squeeze=False)
    # Hide every cell up front; cells without a sample then stay blank.
    for ax in axes.ravel():
        ax.axis('off')

    for i, label in enumerate(unique_labels):
        subset = df[df['label'] == label]
        # Draw at most `samples_per_class` rows (fewer if the class is small)
        samples = subset.sample(min(samples_per_class, len(subset)), random_state=42)

        for j, (_, row) in enumerate(samples.iterrows()):
            img_path = row['image_path']
            ax = axes[i][j]

            try:
                with Image.open(img_path) as img:
                    ax.imshow(img)
                ax.set_title(f"{label}\n{os.path.basename(img_path)}")
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                ax.text(0.5, 0.5, "Error loading image", ha='center', va='center')

    fig.tight_layout()
    samples_plot_path = os.path.join(args.output_dir, "sample_spectrograms.png")
    fig.savefig(samples_plot_path)
    print(f"Saved sample images plot to {samples_plot_path}")
    plt.close(fig)

    print("\n--- Summary ---")
    print(counts)

# Run the main logic
main()


## 6. Fine-tuning MedGemma

In [None]:

class TrainConfig:
    """Paths and hyperparameters read by the fine-tuning cell's `main()`."""
    # Directory expected to contain train_metadata.csv and processed_images/.
    data_dir = "data"
    # Hugging Face model id passed to the processor and model loaders.
    model_name = "google/medgemma-1.5-4b-it"
    # Trainer output directory; the final model and processor are saved under <output_dir>/final_model.
    output_dir = "output"
    # Forwarded to TrainingArguments(num_train_epochs=...).
    epochs = 3
    # Forwarded to TrainingArguments(max_steps=...); per the original CLI help text,
    # a value > 0 sets the total number of training steps and overrides `epochs`.
    max_steps = 100 # Adjusted for Colab demo
    # Forwarded to TrainingArguments(per_device_train_batch_size=...).
    batch_size = 2
    bnb_4bit = True # Enable 4-bit for Colab T4


In [None]:
import os
import argparse
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
try:
    from transformers import AutoProcessor, Gemma3ForConditionalGeneration as AutoModelForVision2Seq, TrainingArguments, Trainer
    try:
        from transformers import BitsAndBytesConfig
    except ImportError:
        BitsAndBytesConfig = None
except ImportError:
    # If explicit import fails (unlikely given check), fallback to AutoModelForImageTextToText if possible or generic
    from transformers import AutoProcessor, AutoModelForImageTextToText as AutoModelForVision2Seq, TrainingArguments, Trainer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.model_selection import train_test_split
import numpy as np
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

class CoughDataset(Dataset):
    """Spectrogram-image / label-text pairs tokenized for causal-LM fine-tuning.

    Each item is built from the text `"<prompt> <label>"` plus the spectrogram
    image, run through the VLM processor once. The returned `labels` are a copy
    of `input_ids` with padding and image-placeholder positions set to -100 so
    they do not contribute to the loss. Prompt text tokens are NOT masked: the
    model is trained on the whole text sequence.

    Args:
        metadata_file: CSV with `image_path` and `label` columns.
        image_dir: Image directory; stored for reference only, since the CSV's
            `image_path` values are opened as-is.
        processor: Callable VLM processor accepting `text=` and `images=`.
        prompt: Text (including the image placeholder token) placed before the label.
    """

    def __init__(self, metadata_file, image_dir, processor, prompt="<image>Classify this cough sound."):
        self.df = pd.read_csv(metadata_file)
        self.image_dir = image_dir
        self.processor = processor
        self.prompt = prompt

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return a dict of tensors (batch dimension removed) for sample `idx`."""
        row = self.df.iloc[idx]
        image_path = row['image_path']
        label_text = str(row['label']) # Target text

        try:
            image = Image.open(image_path).convert("RGB")
        except Exception as e:
            # Best effort: keep training running with a placeholder instead of crashing.
            print(f"Error loading image {image_path}: {e}")
            image = Image.new('RGB', (224, 224)) # Dummy black image

        # Tokenize prompt + label together in a single processor call; the
        # processor is responsible for expanding the image placeholder tokens.
        # Padding to a fixed length lets the default collator stack the items.
        full_text = f"{self.prompt} {label_text}"
        inputs_full = self.processor(
            text=full_text,
            images=image,
            return_tensors="pt",
            padding="max_length",
            max_length=512,
            truncation=True
        )

        input_ids = inputs_full.input_ids.squeeze(0)
        attention_mask = inputs_full.attention_mask.squeeze(0)
        pixel_values = inputs_full.pixel_values.squeeze(0)

        # Labels: copy input_ids, then exclude positions that should not be learned
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100  # padding
        image_token_id = getattr(self.processor, "image_token_id", None)
        if image_token_id is not None:
            # Image placeholder tokens carry no text to predict.
            labels[input_ids == image_token_id] = -100

        item = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "labels": labels
        }

        # Extract token_type_ids if available (Gemma3 needs it)
        token_type_ids = inputs_full.get("token_type_ids")
        if token_type_ids is not None:
            item["token_type_ids"] = token_type_ids.squeeze(0)

        return item

def main():
    """Fine-tune the configured vision-language model with LoRA adapters.

    Reads settings from `TrainConfig`, builds a `CoughDataset` from
    `<data_dir>/train_metadata.csv`, holds out 10% for validation (seeded),
    optionally loads the base model in 4-bit, attaches LoRA adapters to the
    attention projections, trains with the Hugging Face `Trainer`, and saves the
    model and processor to `<output_dir>/final_model`.

    Returns None early (after printing a message) if the training metadata is
    missing, the processor cannot be loaded, or there are too few samples.
    """
    args = TrainConfig()

    # Check for hardware
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Prepare data paths
    metadata_path = os.path.join(args.data_dir, "train_metadata.csv")
    image_dir = os.path.join(args.data_dir, "processed_images")

    if not os.path.exists(metadata_path):
        print(f"Training metadata not found at {metadata_path}. Run the data preparation and split cells first.")
        return

    # Load Processor
    try:
        processor = AutoProcessor.from_pretrained(args.model_name)
    except Exception as e:
        print(f"Error loading processor: {e}")
        return

    # Same fallback as the evaluation cell, so both build the identical prompt.
    boi = getattr(processor, "boi_token", "<image>")
    dataset = CoughDataset(
        metadata_path,
        image_dir,
        processor=processor,
        prompt=f"{boi}Classify this cough sound as healthy or COVID-19."
    )

    # Split dataset (90/10); a fixed generator keeps the split reproducible
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0:
        print(f"Not enough training samples ({len(dataset)}) to fine-tune.")
        return
    split_generator = torch.Generator().manual_seed(42)
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=split_generator
    )

    print(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")

    # Quantization config
    bnb_config = None
    if args.bnb_4bit:
        if BitsAndBytesConfig is None:
            # The import fallback at the top of the cell may have set this to None.
            print("BitsAndBytesConfig is unavailable; loading the model without 4-bit quantization.")
        else:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16
            )

    # Load Model
    model = AutoModelForVision2Seq.from_pretrained(
        args.model_name,
        device_map="auto" if device == "cuda" else None, # Disable device_map on CPU to avoid meta device issues
        quantization_config=bnb_config,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        trust_remote_code=True
    )

    # Apply LoRA to the attention projections.
    # task_type is intentionally omitted to avoid the prepare_inputs_for_generation check on the base model.
    peft_config = LoraConfig(
        inference_mode=False,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"]
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # Training Arguments
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        # Match the train batch size; the library default may be too large for a 4B model on a small GPU.
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=4, # To simulate larger batch
        num_train_epochs=args.epochs,
        max_steps=args.max_steps,
        learning_rate=2e-4,
        logging_steps=10,
        save_strategy="epoch",
        eval_strategy="epoch",
        fp16=(device == "cuda"),
        remove_unused_columns=False,
        # The PEFT wrapper hides the base model's forward signature, so name the
        # label column explicitly to make sure the evaluation loss is computed.
        label_names=["labels"],
        save_total_limit=2
    )

    # Items are already padded to a fixed length in the dataset, so the default collator can stack them.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )

    print("Starting training...")
    trainer.train()

    print("Training complete. Saving model...")
    final_dir = os.path.join(args.output_dir, "final_model")
    model.save_pretrained(final_dir)
    processor.save_pretrained(final_dir)

# Run the main logic
main()

## 7. Evaluation

In [None]:

class EvalConfig:
    """Settings read by the evaluation cell's `main()`."""
    # Directory the training cell saves the fine-tuned model and processor to.
    model_path = "output/final_model"
    # Held-out split written by the split cell.
    test_data = "data/test_metadata.csv"
    # JSON report written by evaluation; the comparison cell reads this file as its fine-tuned results.
    output_file = "evaluation_results.json"
    # When True, the model is loaded with a 4-bit (nf4) BitsAndBytesConfig.
    bnb_4bit = True


In [None]:
import torch
from transformers import AutoProcessor
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix, classification_report, accuracy_score
import argparse
import os
from tqdm import tqdm
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Prefer the concrete Gemma 3 class; fall back to the generic image-text-to-text
# auto class if that import is unavailable in the installed transformers version.
try:
    from transformers import AutoProcessor, Gemma3ForConditionalGeneration as AutoModelForVision2Seq
except ImportError:
    from transformers import AutoProcessor, AutoModelForImageTextToText as AutoModelForVision2Seq

def evaluate(model, processor, test_loader, device):
    """Placeholder for a batched evaluation loop; currently computes no metrics.

    Puts the model in eval mode, announces the positive class, walks through
    `test_loader` once under `torch.no_grad()` without using the batches, and
    returns an empty dict. `processor` and `device` are accepted but unused.

    Intended design (not implemented here): treat "COVID-19" as the positive
    class and the other labels as negative, generate an answer per sample, and
    derive AUC-ROC scores from the first answer token's logits.
    """
    positive_label = "COVID-19"

    model.eval()
    print(f"Evaluating for Target Class: {positive_label}")

    with torch.no_grad():
        # The loader is consumed but the batches are ignored for now.
        for _batch in tqdm(test_loader, desc="Evaluating"):
            continue

    return {}

def _parse_prediction(response):
    """Map generated answer text to a binary prediction: 1 = COVID-19, 0 = otherwise.

    Answers mentioning "covid-19" or "positive" count as positive; everything
    else (including "healthy", "negative" and unclear answers) is negative.
    """
    text = response.lower()
    return 1 if ("covid-19" in text or "positive" in text) else 0


def main():
    """Evaluate the fine-tuned model on the test split and write a JSON report.

    Each test spectrogram is sent through the model with the classification
    prompt; only the newly generated tokens are decoded and parsed into a binary
    COVID-19 / not-COVID-19 prediction. Accuracy, F1, sensitivity, specificity,
    AUC-ROC and the confusion matrix are printed and saved to
    `EvalConfig.output_file`.

    All test samples are evaluated unless `EvalConfig` defines an optional
    integer `max_samples` attribute, which limits evaluation to the first N rows.
    Returns None early if the test data or the model cannot be loaded.
    """
    import json

    args = EvalConfig()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load metadata
    if not os.path.exists(args.test_data):
        print(f"Test metadata not found: {args.test_data}")
        return
    df = pd.read_csv(args.test_data)
    print(f"Loaded {len(df)} test samples.")

    # Load Model & Processor
    print(f"Loading model from {args.model_path}...")
    try:
        processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True)
        # Handle quantization if requested (crucial for 4B on consumer GPU)
        quantization_config = None
        if args.bnb_4bit:
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16
            )

        model = AutoModelForVision2Seq.from_pretrained(
            args.model_path,
            # Same policy as the training cell: no device_map on CPU.
            device_map="auto" if device == "cuda" else None,
            quantization_config=quantization_config,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            trust_remote_code=True
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return
    model.eval()

    # Optional cap for quick smoke tests; by default the whole test set is used.
    max_samples = getattr(args, "max_samples", None)
    df_eval = df if max_samples is None else df.head(max_samples)
    if len(df_eval) == 0:
        print("No test samples to evaluate.")
        return

    # Use processor.boi_token if available (Gemma3), otherwise <image>
    boi = getattr(processor, "boi_token", "<image>")
    prompt = f"{boi}Classify this cough sound as healthy or COVID-19."

    predictions = []
    ground_truth = []
    failures = 0

    print("Starting generation...")
    # Iterate sample by sample (slow but safe for VLM memory)
    for _, row in tqdm(df_eval.iterrows(), total=len(df_eval)):
        image_path = row['image_path']
        true_label = row['label'] # "COVID-19", "healthy", "symptomatic"

        # Binary ground truth for COVID detection: COVID-19 vs. everything else
        ground_truth.append(1 if true_label == "COVID-19" else 0)

        try:
            image = Image.open(image_path).convert("RGB")
            inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

            # Greedy decoding keeps the evaluation deterministic.
            with torch.no_grad():
                outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)

            # generate() returns prompt + continuation. The prompt itself contains
            # "COVID-19", so decode only the new tokens; otherwise every sample
            # would be parsed as positive.
            prompt_len = inputs["input_ids"].shape[-1]
            decoded = processor.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)[0]
            predictions.append(_parse_prediction(decoded))

        except Exception as e:
            print(f"Error processing {image_path}: {e}")
            failures += 1
            predictions.append(0) # Fallback: count as negative

    if failures:
        print(f"Warning: {failures} sample(s) failed and were scored as negative.")

    # Metrics
    print("\n--- Evaluation Report ---")

    # Binary Classification Metrics
    y_true = np.array(ground_truth)
    y_pred = np.array(predictions)

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    # Fixing labels=[0, 1] guarantees a 2x2 matrix even if only one class occurs.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    # AUC-ROC
    # Note: computed from hard 0/1 predictions, so this is a single-threshold
    # approximation. It is undefined when the ground truth holds a single class.
    if len(np.unique(y_true)) < 2:
        print("AUC-ROC is undefined: the evaluated samples contain only one class.")
        roc_auc = float("nan")
    else:
        roc_auc = roc_auc_score(y_true, y_pred)

    print(f"Accuracy: {acc:.4f}")
    print(f"F1-Score: {f1:.4f}")
    print(f"Sensitivity (Recall): {sensitivity:.4f}")
    print(f"Specificity: {specificity:.4f}")
    print(f"AUC-ROC: {roc_auc:.4f}")

    # Cast to built-in types so the report serializes regardless of numpy scalar types.
    report = {
        "accuracy": float(acc),
        "f1_score": float(f1),
        "sensitivity": float(sensitivity),
        "specificity": float(specificity),
        "auc_roc": float(roc_auc),
        "confusion_matrix": {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}
    }

    # Save results
    with open(args.output_file, 'w') as f:
        json.dump(report, f, indent=4)
    print(f"Results saved to {args.output_file}")

# Run the main logic
main()

## 8. Results Comparison

In [None]:
import json
import pandas as pd
import argparse
import os

class CompareConfig:
    """Settings for the comparison cell (defined before `main` so Run All works)."""
    baseline = "baseline_results.json"      # metrics JSON of the baseline model
    finetuned = "evaluation_results.json"   # metrics JSON written by the evaluation cell
    output_md = "comparison_report.md"      # Markdown report to write


def load_results(filepath):
    """Load a metrics JSON file.

    Returns the parsed content, or None (after printing a message) if the file
    does not exist or is not valid JSON.
    """
    if not os.path.exists(filepath):
        print(f"File not found: {filepath}")
        return None
    try:
        with open(filepath, 'r') as f:
            return json.load(f)
    except ValueError as exc:
        print(f"Could not parse {filepath}: {exc}")
        return None


def _metric_value(results, key):
    """Return `results[key]` as a float; missing keys count as 0, non-numeric values as NaN."""
    value = results.get(key, 0)
    if isinstance(value, (int, float)):
        return float(value)
    return float("nan")


def main():
    """Compare baseline and fine-tuned metrics and write a Markdown report.

    Loads both result files named in `CompareConfig`, tabulates the shared
    metrics with their differences, appends both confusion matrices and a short
    sensitivity-based conclusion, writes the report to `CompareConfig.output_md`
    and prints the comparison table. Returns None early if either file is missing.
    """
    args = CompareConfig()

    baseline = load_results(args.baseline)
    finetuned = load_results(args.finetuned)

    if not baseline:
        print("Baseline results missing.")
        return

    if not finetuned:
         print("Finetuned results missing (training might be incomplete).")
         return

    # Metrics to compare
    metrics = ["accuracy", "f1_score", "sensitivity", "specificity", "auc_roc"]

    data = []
    for m in metrics:
        b_val = _metric_value(baseline, m)
        f_val = _metric_value(finetuned, m)
        data.append({
            "Metric": m,
            "Baseline": b_val,
            "Finetuned": f_val,
            "Difference": f_val - b_val
        })

    df = pd.DataFrame(data)

    # Generate Markdown Report
    md = "# Model Comparison Report\n\n"
    md += "## Metrics Comparison\n\n"
    md += "| Metric | Baseline | Fine-tuned | Difference |\n"
    md += "| :--- | :--- | :--- | :--- |\n"

    for entry in data:
        diff_str = f"{entry['Difference']:.4f}"
        if entry['Difference'] > 0:
            diff_str = f"+{diff_str}"

        md += f"| **{entry['Metric']}** | {entry['Baseline']:.4f} | {entry['Finetuned']:.4f} | {diff_str} |\n"

    md += "\n## Confusion Matrix Comparison\n\n"

    def format_cm(cm):
        return (f"TN: {cm.get('tn',0)}, FP: {cm.get('fp',0)}, "
                f"FN: {cm.get('fn',0)}, TP: {cm.get('tp',0)}")

    md += f"- **Baseline**: {format_cm(baseline.get('confusion_matrix', {}))}\n"
    md += f"- **Fine-tuned**: {format_cm(finetuned.get('confusion_matrix', {}))}\n"

    md += "\n## Conclusion\n\n"
    if _metric_value(finetuned, 'sensitivity') > _metric_value(baseline, 'sensitivity'):
        md += "The fine-tuned model shows improvement in Sensitivity, indicating better detection of COVID-19 cases.\n"
    else:
        md += "The fine-tuned model did not improve Sensitivity. Further tuning or more data might be needed.\n"

    with open(args.output_md, 'w') as f:
        f.write(md)

    print(f"Comparison report saved to {args.output_md}")
    print(df)

# Run the main logic
main()
