# TabLLM Fine-Tuning for Postpartum Depression Classification

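This notebook adapts the TabLLM approach (serialize each table row as text, then fine-tune a language model on that text) to a postpartum depression questionnaire dataset. The prediction target is the Yes/No answer in the `feeling anxious` column.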

**Paper**: [TabLLM: Few-shot Classification of Tabular Data with Large Language Models](https://arxiv.org/pdf/2210.10723)

**GitHub**: [https://github.com/clinicalml/TabLLM](https://github.com/clinicalml/TabLLM)

## Overview

This notebook will:
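1. Install required dependencies
2. Load and preprocess the postpartum depression dataset (target: the `feeling anxious` column)
3. Convert tabular data to text using a TabLLM-style list template
4. Fine-tune a T5 model (`t5-base`) on the serialized text. This is full fine-tuning; the paper itself uses the parameter-efficient (IA)³ method
5. Evaluate on the test set
6. Save the results and the model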

## Requirements

- **Runtime**: GPU (T4, V100, or A100 recommended)
- **RAM**: High-RAM runtime recommended
- **Estimated Time**: 30-60 minutes for full training

## Setup Instructions

1. Upload your dataset files to Colab or mount Google Drive
2. Select GPU runtime: Runtime → Change runtime type → GPU
3. Run all cells in order


## 1. Environment Setup

In [None]:
# Check GPU availability
!nvidia-smi

In [None]:
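# Install required packages (the transformers 4.30 Trainer requires accelerate>=0.20.1)
!pip install -q transformers==4.30.0 datasets==2.14.0 accelerate==0.20.3
!pip install -q sentencepiece protobuf scikit-learn pandas numpy
# Optional: Colab already ships a CUDA build of PyTorch; only needed on other machines
# !pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118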

In [None]:
import os
import re
import json
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Dict
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    Trainer, TrainingArguments, DataCollatorForSeq2Seq
)
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, classification_report
)
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Data Upload

**Option 1**: Upload files directly
- Click the folder icon in the left sidebar
- Upload `train_postpartum_depression.csv` and `test_postpartum_depression.csv`

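**Option 2**: Mount Google Drive by uncommenting the mount lines in the cell below, then point `DATA_DIR` at the folder that holds the CSV files.

In [None]:
# Option 2 only: mount Google Drive (leave commented out if you uploaded the files directly)
# from google.colab import drive
# drive.mount('/content/drive')

# Set data directory (adjust path as needed)
DATA_DIR = '/content'  # Option 1: files uploaded directly
# DATA_DIR = '/content/drive/MyDrive/your_folder'  # Option 2: Google Drive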

## 3. Template Definition

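Define a TabLLM-style list template: each row becomes one `- description: value` line per feature. The target column (`feeling anxious`) is deliberately left out of the template.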

In [None]:
# Define feature names for template
postpartum_feature_names = [
    ('age', 'age range'),
    ('feeling_sad_or_tearful', 'feeling sad or tearful'),
    ('irritable_towards_baby_partner', 'irritable towards baby and partner'),
    ('trouble_sleeping_at_night', 'trouble sleeping at night'),
    ('problems_concentrating_or_making_decision', 'problems concentrating or making decision'),
    ('overeating_or_loss_of_appetite', 'overeating or loss of appetite'),
    ('feeling_of_guilt', 'feeling of guilt'),
    ('problems_of_bonding_with_baby', 'problems of bonding with baby'),
    ('suicide_attempt', 'suicide attempt history'),
]

# Create template
template_parts = [f"- {desc}: {{{var}}}" for var, desc in postpartum_feature_names]
template = "\n".join(template_parts)

print("Template:")
print(template)
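To make the serialization concrete, the sketch below fills a three-feature version of the same list template for one respondent using `str.format`. The row values are invented for illustration and are not taken from the dataset; the `demo_` names are used so the notebook's own `template` variable is not overwritten.

```python
# Three of the (column, description) pairs from postpartum_feature_names
demo_features = [
    ('age', 'age range'),
    ('feeling_sad_or_tearful', 'feeling sad or tearful'),
    ('suicide_attempt', 'suicide attempt history'),
]

# One "- description: {column}" line per feature (list-template style)
demo_template = "\n".join(f"- {desc}: {{{var}}}" for var, desc in demo_features)

# Hypothetical respondent; values are illustrative only
demo_row = {'age': '30-35', 'feeling_sad_or_tearful': 'Sometimes', 'suicide_attempt': 'No'}

demo_text = demo_template.format(**demo_row)
print(demo_text)
# - age range: 30-35
# - feeling sad or tearful: Sometimes
# - suicide attempt history: No
```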

In [None]:
def clean_note(note):
    """Clean generated note text"""
    note = re.sub(r"[ \t]+", " ", note)
    note = re.sub("\n\n\n+", "\n\n", note)
    note = re.sub(r"^[ \t]+", "", note)
    note = re.sub(r"\n[ \t]+", "\n", note)
    note = re.sub(r"[ \t]$", "", note)
    note = re.sub(r"[ \t]\n", "\n", note)
    return note.strip()


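def serialize_row(row, template_str):
    """Serialize a single row to text, using "N/A" for missing columns or empty cells"""
    text = template_str
    for var, _ in postpartum_feature_names:
        # `var in row` checks the Series index; pd.isna catches empty (NaN) cells
        if var in row and not pd.isna(row[var]):
            value = str(row[var])
        else:
            value = "N/A"
        text = text.replace(f"{{{var}}}", value)
    return clean_note(text)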


def load_and_preprocess_data(data_dir):
    """Load and preprocess postpartum depression dataset"""
    print("Loading dataset...")
    
    train_df = pd.read_csv(f"{data_dir}/train_postpartum_depression.csv")
    test_df = pd.read_csv(f"{data_dir}/test_postpartum_depression.csv")
    
    print(f"Train samples: {len(train_df)}, Test samples: {len(test_df)}")
    
    # Column mapping
    column_mapping = {
        'timestamp': 'timestamp',
        'age': 'age',
        'feeling sad or tearful': 'feeling_sad_or_tearful',
        'irritable towards baby & partner': 'irritable_towards_baby_partner',
        'trouble sleeping at night': 'trouble_sleeping_at_night',
        'problems concentrating or making decision': 'problems_concentrating_or_making_decision',
        'overeating or loss of appetite': 'overeating_or_loss_of_appetite',
        'feeling anxious': 'label',
        'feeling of guilt': 'feeling_of_guilt',
        'problems of bonding with baby': 'problems_of_bonding_with_baby',
        'suicide attempt': 'suicide_attempt'
    }
    
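    # Normalize headers (strip whitespace, lowercase) so the mapping above matches
    # regardless of capitalization in the CSV, e.g. "Feeling anxious" vs "feeling anxious"
    train_df.columns = train_df.columns.str.strip().str.lower()
    test_df.columns = test_df.columns.str.strip().str.lower()

    train_df = train_df.rename(columns=column_mapping)
    test_df = test_df.rename(columns=column_mapping)
    
    # Drop timestamp (no error if the column is absent)
    train_df = train_df.drop(columns=['timestamp'], errors='ignore')
    test_df = test_df.drop(columns=['timestamp'], errors='ignore')
    
    # Normalize labels to "Yes"/"No"; report unexpected values instead of
    # silently folding them into "No"
    for split_name, df in [('train', train_df), ('test', test_df)]:
        df['label'] = df['label'].astype(str).str.strip().str.capitalize()
        unexpected = ~df['label'].isin(['Yes', 'No'])
        if unexpected.any():
            logger.warning(f"{split_name}: mapping unexpected labels {df.loc[unexpected, 'label'].unique()} to 'No'")
        df.loc[unexpected, 'label'] = 'No'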
    
    print(f"Train label distribution:\n{train_df['label'].value_counts()}")
    print(f"Test label distribution:\n{test_df['label'].value_counts()}")
    
    return train_df, test_df
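A plain `{'Yes': 'Yes', 'No': 'No'}` lookup followed by `fillna('No')` sends every unmatched value to `No`, so a trailing space or different capitalization would silently turn a positive example into a negative one. The stdlib sketch below contrasts that pattern with normalizing whitespace and case first; the raw answers are hypothetical.

```python
def naive_label(value):
    # Exact-match lookup; anything unmatched falls back to "No" (like .map(...).fillna('No'))
    return {'Yes': 'Yes', 'No': 'No'}.get(value, 'No')


def normalized_label(value):
    # Strip whitespace and fix capitalization before matching;
    # return None for values that are still not Yes/No so they can be reported
    cleaned = str(value).strip().capitalize()
    return cleaned if cleaned in ('Yes', 'No') else None


raw_values = ['Yes', 'yes ', 'YES', 'No', 'Maybe']  # hypothetical raw answers
print([naive_label(v) for v in raw_values])       # ['Yes', 'No', 'No', 'No', 'No']
print([normalized_label(v) for v in raw_values])  # ['Yes', 'Yes', 'Yes', 'No', None]
```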

## 4. Data Loading and Serialization

In [None]:
# Load data
train_df, test_df = load_and_preprocess_data(DATA_DIR)

# Serialize to text
print("\nSerializing data to text...")
train_texts = [serialize_row(row, template) for _, row in train_df.iterrows()]
test_texts = [serialize_row(row, template) for _, row in test_df.iterrows()]

print(f"\nExample serialized text:\n{train_texts[0]}")
print(f"\nLabel: {train_df.iloc[0]['label']}")

## 5. Dataset Preparation

In [None]:
class PostpartumDataset(Dataset):
    """Dataset for postpartum depression classification"""
    
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        # Prefix the serialized row with a task tag (the same prefix is used at inference)
        input_text = f"classify: {self.texts[idx]}"

        # Tokenize without padding; DataCollatorForSeq2Seq pads each batch dynamically
        inputs = self.tokenizer(
            input_text,
            max_length=self.max_length,
            truncation=True
        )

        # Tokenize the "Yes"/"No" target without padding. The collator pads labels
        # with -100, which the loss ignores; padding here with the pad id (0 for T5)
        # would make the model train on pad tokens
        labels = self.tokenizer(
            self.labels[idx],
            max_length=10,
            truncation=True
        )

        return {
            'input_ids': inputs['input_ids'],
            'attention_mask': inputs['attention_mask'],
            'labels': labels['input_ids']
        }


# Initialize tokenizer
print("Loading tokenizer...")
MODEL_NAME = "t5-base"  # Can use "google/t5-v1_1-base" or "t5-large"
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

# Create datasets
train_dataset = PostpartumDataset(
    train_texts,
    train_df['label'].tolist(),
    tokenizer
)

test_dataset = PostpartumDataset(
    test_texts,
    test_df['label'].tolist(),
    tokenizer
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")
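Hugging Face seq2seq models compute the loss with cross-entropy that skips label positions equal to -100, and `DataCollatorForSeq2Seq` pads labels with -100 by default. Labels padded with the tokenizer's pad id instead are scored like real tokens. The pure-Python sketch below mimics label padding and a masked mean over per-token losses; the token ids and loss values are made up for illustration.

```python
IGNORE_INDEX = -100  # label value that the loss skips


def pad_labels(label_ids, length, pad_value=IGNORE_INDEX):
    # Right-pad a label sequence to a fixed length (what a seq2seq collator does per batch)
    return label_ids + [pad_value] * (length - len(label_ids))


def masked_mean_loss(per_token_loss, labels):
    # Average per-token losses only at positions whose label is not IGNORE_INDEX
    kept = [loss for loss, lab in zip(per_token_loss, labels) if lab != IGNORE_INDEX]
    return sum(kept) / len(kept)


target = [101, 1]                        # hypothetical answer token followed by </s> (id 1)
per_token_loss = [0.25, 0.75, 3.0, 3.0]  # made-up losses; the last two sit on pad positions

masked = pad_labels(target, 4)                 # [101, 1, -100, -100]
unmasked = pad_labels(target, 4, pad_value=0)  # [101, 1, 0, 0] (0 is T5's pad id)

print(masked_mean_loss(per_token_loss, masked))    # 0.5: only the two real tokens count
print(masked_mean_loss(per_token_loss, unmasked))  # 1.75: pad positions leak into the loss
```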

## 6. Model Fine-Tuning

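The cells below perform full fine-tuning: every weight of the T5 model is updated. This differs from the TabLLM paper, which fine-tunes T0 with the parameter-efficient (IA)³ method from T-Few and trains only a small set of rescaling vectors.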
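The core idea of (IA)³, the parameter-efficient method from the T-Few paper that TabLLM builds on, can be sketched in NumPy: the pretrained weights stay frozen, and three learned vectors rescale the attention keys, the attention values, and the feed-forward hidden activations elementwise. The vectors start at ones, so the adapted block initially behaves exactly like the frozen one. This is a simplified illustration (single head, random stand-in weights, no residuals or layer norms), not T5's actual layer; in practice the `peft` library provides an `IA3Config` for T5-style models.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff, seq_len = 8, 32, 5

# Frozen "pretrained" weights (random stand-ins)
W_q, W_k, W_v = (rng.normal(size=(d_model, d_model)) for _ in range(3))
W_in, W_out = rng.normal(size=(d_model, d_ff)), rng.normal(size=(d_ff, d_model))

# (IA)^3 vectors: the only parameters that would be trained
l_k = np.ones(d_model)
l_v = np.ones(d_model)
l_ff = np.ones(d_ff)


def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def ia3_block(x, l_k, l_v, l_ff):
    # Single-head self-attention with keys and values rescaled elementwise
    q, k, v = x @ W_q, (x @ W_k) * l_k, (x @ W_v) * l_v
    attn = softmax(q @ k.T / np.sqrt(d_model)) @ v
    # Feed-forward layer with its hidden activations rescaled by l_ff
    hidden = np.maximum(attn @ W_in, 0.0) * l_ff
    return hidden @ W_out


x = rng.normal(size=(seq_len, d_model))
out = ia3_block(x, l_k, l_v, l_ff)
print(out.shape)                                      # (5, 8)
print("trainable:", l_k.size + l_v.size + l_ff.size)  # 48
```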

In [None]:
# Load model
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

print(f"Model loaded on {device}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Training arguments
OUTPUT_DIR = "/content/tabllm_finetuned"

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    learning_rate=5e-4,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=50,
    eval_steps=100,
    save_steps=200,
    evaluation_strategy="steps",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
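    # T5 checkpoints (notably t5-v1_1 and larger variants) can overflow to NaN under fp16;
    # use bf16 on GPUs that support it (e.g. A100) and full fp32 otherwise (e.g. T4, V100)
    fp16=False,
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    report_to="none"
)

print("Training configuration:")
print(f"Epochs: {training_args.num_train_epochs}")
print(f"Batch size: {training_args.per_device_train_batch_size}")
print(f"Learning rate: {training_args.learning_rate}")
print(f"BF16: {training_args.bf16}")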
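How `warmup_steps`, `eval_steps`, and `save_steps` relate to the length of the run depends on the number of training rows. The sketch below computes the optimizer-step schedule for one device without gradient accumulation and checks the `TrainingArguments` rule that, with `load_best_model_at_end=True`, `save_steps` must be a multiple of `eval_steps`. The helper name and the 1,000-row input are hypothetical; substitute the size of the dataset you actually train on.

```python
import math


def training_schedule(n_rows, batch_size=8, epochs=10, warmup_steps=100,
                      eval_steps=100, save_steps=200):
    # Optimizer steps on one device with no gradient accumulation
    steps_per_epoch = math.ceil(n_rows / batch_size)
    total_steps = steps_per_epoch * epochs
    # load_best_model_at_end=True needs every checkpoint to coincide with an evaluation
    if save_steps % eval_steps != 0:
        raise ValueError("save_steps must be a multiple of eval_steps")
    return {
        'steps_per_epoch': steps_per_epoch,
        'total_steps': total_steps,
        'warmup_fraction': warmup_steps / total_steps,
        'evaluations': total_steps // eval_steps,
    }


print(training_schedule(1000))  # hypothetical 1,000 training rows
# {'steps_per_epoch': 125, 'total_steps': 1250, 'warmup_fraction': 0.08, 'evaluations': 12}
```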

In [None]:
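from sklearn.model_selection import train_test_split

# Hold out 10% of the training rows for checkpoint selection so that
# load_best_model_at_end never sees the test set
train_idx, val_idx = train_test_split(
    list(range(len(train_texts))), test_size=0.1,
    stratify=train_df['label'], random_state=42
)
all_labels = train_df['label'].tolist()
fit_dataset = PostpartumDataset([train_texts[i] for i in train_idx], [all_labels[i] for i in train_idx], tokenizer)
val_dataset = PostpartumDataset([train_texts[i] for i in val_idx], [all_labels[i] for i in val_idx], tokenizer)

# Data collator: pads inputs per batch and pads any shorter labels with -100 (ignored by the loss)
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    padding=True
)

# Create trainer (evaluation and best-model selection use the validation split)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=fit_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
)

print(f"Trainer initialized (train: {len(fit_dataset)}, validation: {len(val_dataset)})")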

In [None]:
# Train model
print("Starting training...")
train_result = trainer.train()

print("\nTraining completed!")
print(f"Training loss: {train_result.training_loss:.4f}")

# Save model
trainer.save_model(f"{OUTPUT_DIR}/final_model")
tokenizer.save_pretrained(f"{OUTPUT_DIR}/final_model")
print(f"Model saved to {OUTPUT_DIR}/final_model")

## 7. Evaluation

In [None]:
def predict_batch(model, tokenizer, texts, batch_size=16, device='cuda'):
    """Make predictions on a batch of texts"""
    model.eval()
    predictions = []
    
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i+batch_size]
        
        # Prepare inputs
        inputs = [f"classify: {text}" for text in batch_texts]
        encoded = tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        ).to(device)
        
        # Generate predictions
        with torch.no_grad():
            outputs = model.generate(
                **encoded,
                max_length=10,
                num_beams=4,
                early_stopping=True
            )
        
        # Decode predictions
        batch_preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        predictions.extend(batch_preds)
        
        if (i // batch_size + 1) % 10 == 0:
            print(f"Processed {i+len(batch_texts)}/{len(texts)} samples")
    
    return predictions


# Make predictions on test set
print("Making predictions on test set...")
predictions = predict_batch(model, tokenizer, test_texts, device=device)

# Convert predictions to binary
def parse_prediction(pred):
    pred = pred.strip().lower()
    return 1 if 'yes' in pred else 0

pred_binary = [parse_prediction(p) for p in predictions]
true_binary = [1 if label == 'Yes' else 0 for label in test_df['label'].tolist()]

print(f"\nExample predictions:")
for i in range(5):
    print(f"True: {test_df.iloc[i]['label']}, Predicted: {predictions[i]}")
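`parse_prediction` treats any generation that does not contain "yes" as a negative, so malformed outputs (an empty string, a copied feature value) are silently counted as "No". A hypothetical stricter variant, sketched below, returns `None` for anything other than a clean yes/no so invalid generations can be counted before metrics are computed; the generated strings are made up.

```python
def parse_prediction_strict(pred):
    # Map an exact "yes"/"no" (ignoring case, surrounding whitespace, a trailing period) to 1/0
    cleaned = pred.strip().lower().rstrip('.')
    return {'yes': 1, 'no': 0}.get(cleaned)  # None for anything else


generated = ['Yes', ' no', 'Yes.', 'Sometimes', '']  # hypothetical model outputs
parsed = [parse_prediction_strict(g) for g in generated]
n_invalid = sum(p is None for p in parsed)
print(parsed)     # [1, 0, 1, None, None]
print(n_invalid)  # 2
```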

In [None]:
# Calculate metrics
accuracy = accuracy_score(true_binary, pred_binary)
precision = precision_score(true_binary, pred_binary, average='binary', zero_division=0)
recall = recall_score(true_binary, pred_binary, average='binary', zero_division=0)
f1 = f1_score(true_binary, pred_binary, average='binary', zero_division=0)

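# Hard 0/1 predictions make this AUC equal to balanced accuracy ((TPR + TNR) / 2);
# a threshold-free AUC would need a continuous score such as P("Yes")
try:
    auc_roc = roc_auc_score(true_binary, pred_binary)
except ValueError:  # raised when the test labels contain only one class
    auc_roc = float('nan')

# labels=[0, 1] keeps the matrix 2x2 even if one class is never predicted
cm = confusion_matrix(true_binary, pred_binary, labels=[0, 1])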

print("="*80)
print("TABLLM FINE-TUNING RESULTS - Postpartum Depression")
print("="*80)
print(f"Accuracy:  {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Precision: {precision:.4f} ({precision*100:.2f}%)")
print(f"Recall:    {recall:.4f} ({recall*100:.2f}%)")
print(f"F1-Score:  {f1:.4f} ({f1*100:.2f}%)")
print(f"AUC-ROC:   {auc_roc:.4f} ({auc_roc*100:.2f}%)")
print(f"\nConfusion Matrix:")
print(f"TN: {cm[0][0]}, FP: {cm[0][1]}")
print(f"FN: {cm[1][0]}, TP: {cm[1][1]}")
print("="*80)

print("\nDetailed Classification Report:")
print(classification_report(true_binary, pred_binary, labels=[0, 1],
                            target_names=['No Anxiety', 'Anxiety'], zero_division=0))
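Because `roc_auc_score` receives hard 0/1 predictions in this notebook, the reported AUC-ROC is a restatement of balanced accuracy rather than a threshold-free ranking metric. The stdlib sketch below computes AUC as the probability that a randomly chosen positive is scored above a randomly chosen negative (ties count half) and shows that binary predictions reproduce (TPR + TNR) / 2. A meaningful AUC would need a continuous score, for example the model's probability of generating "Yes" rather than "No"; the labels and scores here are hypothetical.

```python
def auc_pairwise(y_true, scores):
    # AUC = P(score of a random positive > score of a random negative), ties count 0.5
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def balanced_accuracy(y_true, y_pred):
    # Mean of true-positive rate and true-negative rate
    tpr = sum(1 for y, p in zip(y_true, y_pred) if y == 1 and p == 1) / y_true.count(1)
    tnr = sum(1 for y, p in zip(y_true, y_pred) if y == 0 and p == 0) / y_true.count(0)
    return (tpr + tnr) / 2


y_true = [1, 1, 1, 0, 0, 0, 0, 0]                 # hypothetical labels
p_yes = [0.9, 0.6, 0.4, 0.7, 0.3, 0.2, 0.2, 0.1]  # hypothetical P("Yes") scores
y_pred = [1 if p >= 0.5 else 0 for p in p_yes]    # thresholded at 0.5

print(auc_pairwise(y_true, p_yes))       # 13/15 ≈ 0.867 (uses the full ranking)
print(auc_pairwise(y_true, y_pred))      # 11/15 ≈ 0.733
print(balanced_accuracy(y_true, y_pred)) # 11/15 ≈ 0.733 (same value)
```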

## 8. Save Results

In [None]:
# Save metrics
results = {
    'model': MODEL_NAME,
    'metrics': {
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1_score': float(f1),
        'auc_roc': float(auc_roc)
    },
    'confusion_matrix': cm.tolist(),
    'classification_report': classification_report(true_binary, pred_binary, output_dict=True)
}

with open(f"{OUTPUT_DIR}/finetuning_metrics.json", 'w') as f:
    json.dump(results, f, indent=2)

# Save predictions
predictions_df = pd.DataFrame({
    'true_label': test_df['label'].tolist(),
    'predicted_label': predictions,
    'true_binary': true_binary,
    'predicted_binary': pred_binary
})
predictions_df.to_csv(f"{OUTPUT_DIR}/finetuning_predictions.csv", index=False)

print(f"Results saved to {OUTPUT_DIR}")
print(f"\nFiles created:")
print(f"- {OUTPUT_DIR}/final_model/ (trained model)")
print(f"- {OUTPUT_DIR}/finetuning_metrics.json")
print(f"- {OUTPUT_DIR}/finetuning_predictions.csv")

## 9. Download Results

Zip the output directory to download it, or copy it to Google Drive. The archive also contains the intermediate `checkpoint-*` folders written during training.

In [None]:
# Zip results for download
!zip -r /content/tabllm_results.zip {OUTPUT_DIR}

print("Results zipped!")
print("Download the file 'tabllm_results.zip' from the files panel")

# Or copy to Google Drive
# !cp -r {OUTPUT_DIR} /content/drive/MyDrive/tabllm_results

## Next Steps

1. **Hyperparameter Tuning**: Experiment with different learning rates, batch sizes, and epochs
2. **Model Size**: Try larger models such as `t5-large` or `google/t5-v1_1-large`, which may improve performance at a higher memory cost
3. **Data Augmentation**: Generate synthetic samples to balance classes
4. **Ensemble**: Combine multiple models for improved predictions
5. **Error Analysis**: Examine misclassified samples to understand model limitations
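For error analysis, the saved `finetuning_predictions.csv` already holds `true_binary` and `predicted_binary` columns. The stdlib sketch below splits misclassified rows into false positives and false negatives; it reads an in-memory sample with made-up rows so it runs on its own, whereas in the notebook you would open `f"{OUTPUT_DIR}/finetuning_predictions.csv"` instead. The helper name `split_errors` is hypothetical.

```python
import csv
import io


def split_errors(csv_file):
    # Return row positions of false positives and false negatives in a predictions CSV
    false_pos, false_neg = [], []
    for idx, row in enumerate(csv.DictReader(csv_file)):
        actual, pred = int(row['true_binary']), int(row['predicted_binary'])
        if pred == 1 and actual == 0:
            false_pos.append(idx)
        elif pred == 0 and actual == 1:
            false_neg.append(idx)
    return false_pos, false_neg


# Made-up sample with the same columns the notebook writes
sample = io.StringIO(
    "true_label,predicted_label,true_binary,predicted_binary\n"
    "Yes,Yes,1,1\n"
    "No,Yes,0,1\n"
    "Yes,No,1,0\n"
    "No,No,0,0\n"
)
fp_rows, fn_rows = split_errors(sample)
print("false positives:", fp_rows)  # [1]
print("false negatives:", fn_rows)  # [2]
```

Since the CSV rows follow the test-set order, an index `i` from either list can be inspected with `test_texts[i]`.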

## Troubleshooting

- **Out of Memory**: Reduce batch size or use gradient accumulation
- **Slow Training**: Ensure GPU runtime is enabled
- **Poor Performance**: Try more epochs, different learning rate, or larger model
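When lowering the per-device batch size to avoid out-of-memory errors, the `gradient_accumulation_steps` argument of `TrainingArguments` can keep the effective batch size unchanged: effective batch = per-device batch × accumulation steps × number of devices. A small hypothetical helper for choosing the factor, assuming the target batch divides evenly:

```python
def accumulation_steps(target_batch, per_device_batch, n_devices=1):
    # Effective batch = per_device_batch * accumulation_steps * n_devices
    per_step = per_device_batch * n_devices
    if target_batch % per_step != 0:
        raise ValueError("target batch must be a multiple of per-device batch x devices")
    return target_batch // per_step


# Keep the notebook's effective batch of 8 while fitting only 2 examples per forward pass
print(accumulation_steps(8, 2))  # 4
```

The result would then be passed as `per_device_train_batch_size=2, gradient_accumulation_steps=4`.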

## References

- [TabLLM Paper](https://arxiv.org/pdf/2210.10723)
- [TabLLM GitHub](https://github.com/clinicalml/TabLLM)
- [T5 Paper](https://arxiv.org/abs/1910.10683)
