In [None]:
import gc
import os

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)

In [None]:
# Load the project DataFrame. pd.read_pickle unpickles arbitrary Python objects,
# so only point it at files produced by a trusted source.
DATA_PATH = 'my_project_data.pkl'
df = pd.read_pickle(DATA_PATH)
print("DataFrame loaded from file")

In [None]:
# 1. Clean the data: keep only rows with a known target flag AND a narrative.
#    A missing (NaN) narrative would otherwise reach the tokenizer, which expects
#    strings, and crash the fine-tuning loop much later.
df_clean = df.dropna(subset=['respiratory_depression', 'Clinical_Narrative']).copy()

print(f"Original DataFrame size: {len(df)}")
print(f"Cleaned DataFrame size (rows with a known flag and narrative): {len(df_clean)}")
print("\nTarget Class Distribution:")
print(df_clean['respiratory_depression'].value_counts())

# 2. Define X and y
# X will be the list of clinical narratives
X = df_clean['Clinical_Narrative'].tolist()
# y will be the list of target flags (converted to integers)
y = df_clean['respiratory_depression'].astype(int).tolist()

# The models below are binary classifiers (positive class = 1), so the
# integer-cast flags must be exactly 0/1.
assert set(y) <= {0, 1}, f"Unexpected target values: {sorted(set(y))}"
assert len(X) == len(y)

In [None]:
# Stratified 60/20/20 split (stratification keeps the class ratio stable in
# every split, which matters for an imbalanced target).
#   Step 1: hold out 20% of all rows as the test set.
#   Step 2: split the remaining 80% into 75% train / 25% validation,
#           i.e. 0.25 * 0.80 = 20% of the total for validation.
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, test_size=0.20, random_state=42, stratify=y,
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.25, random_state=42, stratify=y_train_val,
)

rule = "-" * 50
print(rule)
print("Dataset Splits Summary:")
print(f"Train set size: {len(X_train)} (approx 60% of total)")
print(f"Validation set size: {len(X_val)} (approx 20% of total)")
print(f"Test set size: {len(X_test)} (20% of total)")
print(rule)

## Tokenization and model fine-tuning

In [None]:
# --- GLOBAL HYPERPARAMETERS ---
NUM_LABELS = 2          # binary classification head
LEARNING_RATE = 2e-5
NUM_EPOCHS = 3
BATCH_SIZE = 8          # per-device training batch; evaluation/inference use 2x

# Run on the GPU whenever PyTorch can see one.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face Hub checkpoints compared in this experiment.
MODEL_NAMES = [
    "abhinand/MedEmbed-large-v0.1",
    "emilyalsentzer/Bio_ClinicalBERT",
    "medicalai/ClinicalBERT",
]

print(f"Using device: {DEVICE}")

# --- 1. HELPER FUNCTION: METRICS ---
def compute_metrics(p):
    """
    Score binary-classification predictions for the Hugging Face Trainer.

    Args:
        p: EvalPrediction-like object exposing ``predictions`` (the logits, or a
           tuple whose first element is the logits) and ``label_ids``.

    Returns:
        dict with 'roc_auc', 'f1_score', 'precision', 'recall' and 'accuracy'.
        Precision, recall and F1 refer to the positive class (label 1).
    """
    raw_predictions = p.predictions
    if isinstance(raw_predictions, tuple):
        logits = raw_predictions[0]
    else:
        logits = raw_predictions
    y_true = p.label_ids

    # Hard class decisions and positive-class probabilities (the latter feed AUROC).
    y_pred = np.argmax(logits, axis=1)
    positive_probs = torch.softmax(torch.tensor(logits), dim=-1)[:, 1].numpy()

    # AUROC is undefined when y_true holds a single class; fall back to 0.0 so the
    # value stays numeric (it is the metric used for best-model selection).
    try:
        auroc = roc_auc_score(y_true, positive_probs)
    except ValueError:
        auroc = 0.0

    # zero_division=0 avoids warnings/NaNs when no positives are predicted.
    prec, rec, f1_pos, _support = precision_recall_fscore_support(
        y_true, y_pred, average='binary', pos_label=1, zero_division=0,
    )

    accuracy = (y_pred == y_true).mean()
    return {
        'roc_auc': auroc,
        'f1_score': f1_pos,
        'precision': prec,
        'recall': rec,
        'accuracy': accuracy,
    }

def tokenize(model_name, train_X, train_y, val_X, val_y, test_X, test_y, max_length=512):
    """
    Tokenize the train/val/test splits with ``model_name``'s own tokenizer.

    Args:
        model_name: Hugging Face Hub id; its tokenizer is loaded via AutoTokenizer.
        train_X, val_X, test_X: lists of narrative strings.
        train_y, val_y, test_y: integer labels aligned with the corresponding texts.
        max_length: every sequence is truncated and padded to exactly this many
            tokens (512 is the usual BERT position limit; raise only if the model
            supports longer inputs).

    Returns:
        (train, val, test) Hugging Face Datasets with 'input_ids',
        'attention_mask' and 'labels' columns.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def encode_data(texts, labels):
        """Encode one split into a TensorDataset of (input_ids, attention_mask, labels)."""
        encodings = tokenizer(
            texts,
            max_length=max_length,
            truncation=True,
            padding='max_length',  # fixed width so each split stacks into one tensor
            return_tensors='pt',
        )
        # Explicit int64: classification losses expect long class indices.
        labels_tensor = torch.tensor(labels, dtype=torch.long)

        # Only input_ids and attention_mask are kept. token_type_ids, if the
        # tokenizer emits them, are dropped; BERT-style models fill in all-zero
        # segment ids when they are omitted, matching single-segment input.
        return TensorDataset(
            encodings['input_ids'],
            encodings['attention_mask'],
            labels_tensor,
        )

    train_dataset = encode_data(train_X, train_y)
    val_dataset = encode_data(val_X, val_y)
    test_dataset = encode_data(test_X, test_y)

    print("-" * 50)
    print(f"Train Dataset (input_ids, attention_mask, labels): {train_dataset.tensors[0].shape}, {train_dataset.tensors[1].shape}, {train_dataset.tensors[2].shape}")
    print(f"Validation Dataset: {val_dataset.tensors[0].shape}, {val_dataset.tensors[1].shape}, {val_dataset.tensors[2].shape}")
    print(f"Test Dataset: {test_dataset.tensors[0].shape}, {test_dataset.tensors[1].shape}, {test_dataset.tensors[2].shape}")
    print("-" * 50)

    train_dataset_hf = convert_tensor_dataset_to_hf_dataset(train_dataset)
    val_dataset_hf = convert_tensor_dataset_to_hf_dataset(val_dataset)
    test_dataset_hf = convert_tensor_dataset_to_hf_dataset(test_dataset)

    return train_dataset_hf, val_dataset_hf, test_dataset_hf

# --- 2. HELPER FUNCTION: DATA CONVERSION ---
def convert_tensor_dataset_to_hf_dataset(tensor_dataset):
    """
    Convert a PyTorch TensorDataset into a Hugging Face Dataset.

    Args:
        tensor_dataset: TensorDataset whose tensors are, in this order,
            (input_ids, attention_mask, labels), which is the layout built by
            ``tokenize``.

    Returns:
        datasets.Dataset with 'input_ids', 'attention_mask' and 'labels' columns
        holding int64 values.

    Raises:
        ValueError: if the dataset does not hold exactly three tensors (an extra
            tensor such as token_type_ids would silently shift the columns).
    """
    if len(tensor_dataset.tensors) != 3:
        raise ValueError(
            "Expected (input_ids, attention_mask, labels), got "
            f"{len(tensor_dataset.tensors)} tensors"
        )

    # Cast to int64 up front. Arrow stores plain integer lists, so converting each
    # row back into torch tensors with a Dataset.map pass has no lasting effect.
    # Torch formatting, if wanted, belongs to Dataset.set_format.
    input_ids, attention_mask, labels = (t.long() for t in tensor_dataset.tensors)

    return Dataset.from_dict({
        'input_ids': input_ids.tolist(),
        'attention_mask': attention_mask.tolist(),
        'labels': labels.tolist(),
    })


# --- 3. MAIN EXPERIMENT FUNCTION ---
def run_model_experiment(model_name, train_data, val_data, test_data, device,
                         seed=42, trust_remote_code=True):
    """
    Initializes, fine-tunes, and evaluates a specific Hugging Face model.

    The checkpoint with the best validation AUROC is reloaded at the end of
    training and evaluated once on the held-out test split.

    Args:
        model_name: Hugging Face Hub id of the pretrained checkpoint.
        train_data, val_data, test_data: tokenized datasets with 'input_ids',
            'attention_mask' and 'labels' columns (as produced by ``tokenize``).
        device: torch.device the model is moved to.
        seed: seed applied BEFORE the model is instantiated so the freshly
            initialised classification head is reproducible. Trainer seeds only
            inside its own constructor, after the head already exists. The seed is
            also forwarded to TrainingArguments, whose default is likewise 42.
        trust_remote_code: forwarded to ``from_pretrained``. NOTE(review): True
            lets a checkpoint execute Python code shipped with it. Consider False
            unless a checkpoint genuinely needs custom modelling code.

    Returns:
        On success, a dict with 'Model', 'Test_Accuracy', 'Test_F1', 'Test_AUROC',
        'Test_Precision' and 'Test_Recall'. If the model cannot be loaded, a dict
        with 'Model' and 'Error'.
    """
    print("\n" + "="*80)
    print(f"STARTING EXPERIMENT FOR MODEL: {model_name}")
    print("="*80)

    # A. Model Initialization
    # Seed first: the classification head is randomly initialised inside from_pretrained.
    set_seed(seed)
    try:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=NUM_LABELS,
            ignore_mismatched_sizes=True,  # replace any checkpoint head with a NUM_LABELS-way head
            trust_remote_code=trust_remote_code,
        )
    except Exception as e:
        # Best-effort: one unloadable checkpoint should not abort the whole comparison.
        print(f"Error loading model {model_name}: {e}")
        return {"Model": model_name, "Error": f"Failed to load model: {e}"}

    model.to(device)

    # B. Define Training Arguments (per-model output/log directories)
    model_safe_name = model_name.replace("/", "__").replace("-", "_")
    output_dir = f'./results_{model_safe_name}'
    os.makedirs(output_dir, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE * 2,
        gradient_accumulation_steps=2,   # effective train batch = 2 * BATCH_SIZE per device
        learning_rate=LEARNING_RATE,
        # NOTE(review): a fixed 500-step warmup can exceed the total number of
        # optimizer steps on a small dataset, in which case the LR never reaches
        # LEARNING_RATE. Consider warmup_ratio instead.
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir=f'./logs_{model_safe_name}',
        logging_steps=500,
        eval_strategy="epoch",
        save_strategy="epoch",           # must match eval_strategy for load_best_model_at_end
        load_best_model_at_end=True,
        metric_for_best_model="roc_auc", # key returned by compute_metrics
        # With load_best_model_at_end, up to two checkpoints can remain on disk:
        # the best one and the most recent one.
        save_total_limit=1,
        report_to="none",                # no external trackers such as WandB
        seed=seed,
    )

    # C. Setup and Run Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        compute_metrics=compute_metrics,
    )

    print("Starting Fine-Tuning...")
    trainer.train()
    print("Fine-tuning complete. Best model loaded.")

    # D. Final Evaluation on Test Set
    print("\n" + "="*50)
    print("Final Evaluation on the Test Set")
    print("="*50)
    test_results = trainer.evaluate(test_data)

    # Trainer prefixes every compute_metrics key with 'eval_'.
    final_results = {
        'Model': model_name,
        'Test_Accuracy': test_results.get('eval_accuracy'),
        'Test_F1': test_results.get('eval_f1_score'),
        'Test_AUROC': test_results.get('eval_roc_auc'),
        'Test_Precision': test_results.get('eval_precision'),
        'Test_Recall': test_results.get('eval_recall'),
    }

    print("Test Results:")
    for key, value in final_results.items():
        if isinstance(value, float):
            print(f"{key}: {value:.4f}")
        else:
            print(f"{key}: {value}")

    # Release this model's memory before the caller loads the next checkpoint.
    del trainer, model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return final_results

In [None]:
def _format_metric(value):
    """Render a metric with 4 decimals, or 'N/A' when it is missing or non-numeric
    (e.g. a model that failed to load returns no metrics)."""
    if isinstance(value, (int, float, np.floating)) and not isinstance(value, bool):
        return f"{value:.4f}"
    return "N/A"


COMPARISON_RESULTS = []

for model_name in MODEL_NAMES:
    # Each model needs its own tokenization: token ids are vocabulary-specific.
    train_dataset_hf, val_dataset_hf, test_dataset_hf = tokenize(model_name, X_train, y_train, X_val, y_val, X_test, y_test)
    results = run_model_experiment(model_name,
                                   train_dataset_hf,
                                   val_dataset_hf,
                                   test_dataset_hf,
                                   DEVICE)
    COMPARISON_RESULTS.append(results)

# Summary table. Failed models carry an 'Error' entry instead of metrics.
print("\n" + "#"*80)
print("FINAL MODEL COMPARISON SUMMARY")
print("#"*80)
for res in COMPARISON_RESULTS:
    row = (
        f"Model: {res.get('Model'):<40} "
        f"| Accuracy: {_format_metric(res.get('Test_Accuracy'))} "
        f"| F1: {_format_metric(res.get('Test_F1'))} "
        f"| AUROC: {_format_metric(res.get('Test_AUROC'))}"
    )
    if 'Error' in res:
        row += f" | Error: {res['Error']}"
    print(row)

## Preserving embeddings for future interpretability studies (ongoing)

In [None]:
def extract_and_save_embeddings(model_name, dataset_hf, device, batch_size, output_file_prefix):
    """
    Loads the base (pretrained, not fine-tuned) encoder, extracts the [CLS]
    (first-token) embedding for every example in ``dataset_hf``, and saves the
    embeddings together with the labels.

    Args:
        model_name: Hugging Face Hub id; the encoder and tokenizer are both loaded from it.
        dataset_hf: dataset with 'input_ids', 'attention_mask' and 'labels'. It must
            have been tokenized with THIS model's tokenizer, because token ids are
            vocabulary-specific.
        device: torch.device used for inference.
        batch_size: inference batch size.
        output_file_prefix: prefix for the output file name.

    The file is written with ``np.save`` on a dict, i.e. as a pickled 0-d object
    array. Read it back with
    ``data = np.load(path, allow_pickle=True).item(); X, y = data['X'], data['y']``.

    Returns: The path to the saved file.

    Raises:
        ValueError: if ``dataset_hf`` yields no batches.
    """
    print("\n" + "="*80)
    print(f"STARTING EMBEDDING EXTRACTION FOR MODEL: {model_name}")
    print("="*80)

    # A. Load the base encoder (no classification head) and its tokenizer
    model = AutoModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)  # needed by the padding collator

    model.to(device)
    model.eval()

    # B. Batch the examples; the collator pads each batch and returns PyTorch tensors.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")
    data_loader = DataLoader(
        dataset_hf,
        batch_size=batch_size,
        shuffle=False,             # keep row order aligned with the labels
        collate_fn=data_collator,
    )

    all_embeddings = []
    all_labels = []

    # C. Inference Loop
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Extracting Embeddings"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # [CLS] token = position 0 of the last hidden state -> (batch_size, hidden_size)
            hidden = model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
            all_embeddings.append(hidden[:, 0, :].cpu().numpy())

            # Labels are not model inputs, so they never need to leave the CPU.
            all_labels.append(batch['labels'].numpy())

    if not all_embeddings:
        raise ValueError(f"No examples to embed for model {model_name}")

    # D. Aggregate and Save
    final_embeddings = np.concatenate(all_embeddings, axis=0)
    final_labels = np.concatenate(all_labels, axis=0)

    output_path = f"{output_file_prefix}_embeddings_features.npy"
    np.save(output_path, {'X': final_embeddings, 'y': final_labels})

    print(f"Extraction complete. Saved data to: {output_path}")
    print(f"Features shape (Embeddings): {final_embeddings.shape}")
    print(f"Target shape (Labels): {final_labels.shape}")

    # Release the encoder's memory before the caller loads the next model.
    del model, hidden
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return output_path

In [None]:
# Extract base-model [CLS] embeddings for the TEST split of every model.
# COMPARISON_RESULTS from the fine-tuning cell is intentionally left untouched.
EMBEDDING_FILE_PATHS = []  # Store the paths to the saved embedding files

for model_name in MODEL_NAMES:
    # Re-tokenize with THIS model's tokenizer. Any `test_dataset_hf` left over from
    # the fine-tuning loop was built by the last model's tokenizer, and its token
    # ids are not valid for the other models' vocabularies.
    _, _, model_test_dataset_hf = tokenize(model_name, X_train, y_train, X_val, y_val, X_test, y_test)

    model_safe_name = model_name.split('/')[-1]

    embedding_file = extract_and_save_embeddings(
        model_name=model_name,
        dataset_hf=model_test_dataset_hf,
        device=DEVICE,
        batch_size=BATCH_SIZE * 2,  # Can use a larger batch for inference
        output_file_prefix=model_safe_name
    )
    EMBEDDING_FILE_PATHS.append(embedding_file)

print("\nEmbedding extraction finished. Proceed to train separate XGBoost models.")
print(f"Saved Files: {EMBEDDING_FILE_PATHS}")