In [1]:
from dotenv import load_dotenv

import os, sys

sys.path.append("..")

import pandas as pd

load_dotenv()

pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.expand_frame_repr', False)

In [2]:
DATA_PATH = os.getenv("DATA_PATH", "data")

MODEL_PATH = os.getenv("MODEL_PATH", "models")

LOGS_PATH = os.getenv("LOGS_PATH", "logs")

RESULTS_PATH = os.getenv("RESULTS_PATH", "results")

In [3]:
df = pd.read_csv(f"data.csv")

In [4]:
df.drop(columns=['id', 'session_id', 'user_id', 'crisis_intro', 'crisis_option_a', 'crisis_option_b'], inplace=True)

In [5]:
# Patient Profile (Xp)
xp_features = [
    'age',
    'location', 'eol_preference',
    'family_preference', 'biological_sex', 'gender_identity', 'political_leaning',
    'marital_status', 'religion', 'religious_importance', 'annual_income',
    'education', 'family_history_dementia', 'personal_history_dementia',
    'dementia_worry'
]

# Xt Modular Components
xt_parts = {
    'medical': ['crisis_type'],
    'patient_condition': ['crisis_chance','emotional_state', 'agitation_frequency', 'agitation_severity',
    'family_visit_frequency', 'family_inconvenience', 'interaction_ability',
    'functional_ability', 'behavior', 'affordability'],
    'treatment': ['crisis_wean', 'crisis_tube'],
    'prognosis': ['crisis_comfort', 'resuscitation_chance', 'leave_hospital', 'internal_damage', 'future_arrest']
}


columns_to_drop = list(set(df.columns) - set(xp_features) - set(sum(xt_parts.values(), [])) - {"choice"})

df = df.drop(columns=columns_to_drop)

In [6]:
from sklearn.preprocessing import LabelEncoder

# Make a copy of the original dataframe
df_encoded = df.copy()

# Build list of categorical features to encode
categorical_features = [col for col in xp_features if col not in ['religious_importance', 'dementia_worry']]

# Add Xt features to encoding list
for part in xt_parts.values():
    for feature in part:
        categorical_features.append(feature)

# Drop duplicates just in case
categorical_features = list(set(categorical_features))

# Initialize encoder dictionary
encoders = {}

# Apply Label Encoding
for col in categorical_features:
    le = LabelEncoder()
    df_encoded[col] = le.fit_transform(df_encoded[col])
    encoders[col] = le  # Save encoder for later use

target_encoder = LabelEncoder()
df_encoded['choice'] = target_encoder.fit_transform(df_encoded['choice'])

# Map for later decoding
print("✅ Encoding completed. Stored encoders for all categorical features.")


✅ Encoding completed. Stored encoders for all categorical features.


In [7]:
# # Fix CPR variables to specific options across all rows
# df_encoded['resuscitation_chance'] = encoders['resuscitation_chance'].transform(['low (around 9-in-100)'])[0]
# df_encoded['leave_hospital'] = encoders['leave_hospital'].transform(['low (around 9-in-100)'])[0]
# df_encoded['internal_damage'] = encoders['internal_damage'].transform(['strong chance (around 80-in-100)'])[0]
# df_encoded['future_arrest'] = encoders['future_arrest'].transform(['moderate to high chance (around 70-in-100)'])[0]


# from sklearn.preprocessing import LabelEncoder

# for col in ['resuscitation_chance', 'leave_hospital', 'internal_damage', 'future_arrest']:
#     le = LabelEncoder()
#     df_encoded[col] = le.fit_transform(df_encoded[col])
#     encoders[col] = le


In [8]:
from sklearn.model_selection import train_test_split

# Separate features and target
X = df_encoded.drop(columns=['choice'])
y = df_encoded['choice']

# First split: Train + Temp (Val + Test)
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)

# Second split: Validation + Test
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)

print(f"Train size: {len(X_train)} | Val size: {len(X_val)} | Test size: {len(X_test)}")


Train size: 2331 | Val size: 499 | Test size: 500


In [9]:
import torch
from torch.utils.data import Dataset

class MultiClassTreatmentDataset(Dataset):
    def __init__(self, X, y, xp_features, xt_parts):
        self.X = X.reset_index(drop=True)
        self.y = y.reset_index(drop=True)
        self.xp_features = xp_features
        self.xt_parts = xt_parts

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        row = self.X.iloc[idx]

        # Extract Xp
        Xp = {feature: torch.tensor(row[feature], dtype=torch.float32).unsqueeze(0) for feature in self.xp_features}

        # Extract Xt parts
        Xt_parts = {}
        for part_name, features in self.xt_parts.items():
            Xt_parts[part_name] = {feature: torch.tensor(row[feature], dtype=torch.float32).unsqueeze(0) for feature in features}

        # Target
        y_target = torch.tensor(self.y.iloc[idx], dtype=torch.long)

        return Xp, Xt_parts, y_target


In [10]:
from torch.utils.data import DataLoader

# Create datasets
train_dataset = MultiClassTreatmentDataset(X_train, y_train, xp_features, xt_parts)
val_dataset = MultiClassTreatmentDataset(X_val, y_val, xp_features, xt_parts)
test_dataset = MultiClassTreatmentDataset(X_test, y_test, xp_features, xt_parts)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=True)


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModularMultiClassTreatmentModel(nn.Module):
    def __init__(self, xp_features, xt_parts, embedding_sizes, num_classes):
        super(ModularMultiClassTreatmentModel, self).__init__()

        self.xp_features = xp_features
        self.xt_parts = xt_parts
        self.num_classes = num_classes

        # ===== Embeddings =====
        self.embeddings = nn.ModuleDict()
        for feature, num_categories in embedding_sizes.items():
            self.embeddings[feature] = nn.Embedding(num_embeddings=num_categories, embedding_dim=4)

        # ===== Patient Encoder =====
        patient_input_dim = self.calculate_total_embedding_dim(xp_features)
        self.patient_encoder = nn.Sequential(
            nn.Linear(patient_input_dim, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU()
        )

        # ===== Xt Part Encoders =====
        self.part_encoders = nn.ModuleDict()
        for part_name, features in xt_parts.items():
            part_input_dim = self.calculate_total_embedding_dim(features)
            self.part_encoders[part_name] = nn.Sequential(
                nn.Linear(part_input_dim, 64),
                nn.BatchNorm1d(64),
                nn.ReLU()
            )

        # ===== Attention Mechanism =====
        self.attention = nn.Sequential(
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )

        # ===== Final Decision Head =====
        self.decision_head = nn.Sequential(
            nn.Linear(64 + 64, 128),  # patient vector + weighted Xt vector
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, num_classes)  # Multi-class output
        )

    def calculate_total_embedding_dim(self, features):
        dim = 0
        for feature in features:
            dim += 4 if feature in self.embeddings else 1  # Categorical: embedding | Continuous: 1-dim
        return dim

    def forward(self, Xp, Xt_parts):
        # ===== Process Patient Features =====
        patient_embeds = []
        for feature in self.xp_features:
            if feature in self.embeddings:
                patient_embeds.append(self.embeddings[feature](Xp[feature].long()).squeeze(1))
            else:
                patient_embeds.append(Xp[feature].unsqueeze(1))  # 🔥 FIXED HERE

        patient_embeds = torch.cat(patient_embeds, dim=1)
        patient_vector = self.patient_encoder(patient_embeds)

        # ===== Process Xt Parts =====
        part_vectors = []
        attention_scores = []

        for part_name, features in self.xt_parts.items():
            part_embeds = []
            for feature in features:
                if feature in self.embeddings:
                    part_embeds.append(self.embeddings[feature](Xt_parts[part_name][feature].long()).squeeze(1))
                else:
                    part_embeds.append(Xt_parts[part_name][feature].unsqueeze(1))  # 🔥 FIXED HERE

            part_embeds = torch.cat(part_embeds, dim=1)
            part_vector = self.part_encoders[part_name](part_embeds)
            part_vectors.append(part_vector)

            attn_score = self.attention(part_vector)
            attention_scores.append(attn_score)

        # ===== Attention Weighted Sum =====
        attention_scores = torch.cat(attention_scores, dim=1)
        attn_weights = F.softmax(attention_scores, dim=1)

        weighted_Xt = torch.stack(part_vectors, dim=1)  # Shape: [batch_size, num_parts, 64]
        weighted_Xt = (attn_weights.unsqueeze(2) * weighted_Xt).sum(dim=1)

        # ===== Final Decision =====
        combined = torch.cat([patient_vector, weighted_Xt], dim=1)
        logits = self.decision_head(combined)

        return logits  # Raw logits for CrossEntropyLoss


In [12]:
# Define embedding sizes for all categorical features
embedding_sizes = {}

# Go through each categorical feature we encoded earlier
for col in categorical_features:
    num_categories = df_encoded[col].nunique()
    embedding_sizes[col] = num_categories


In [13]:
import torch.optim as optim

# Move model to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ModularMultiClassTreatmentModel(xp_features, xt_parts, embedding_sizes, num_classes=len(df_encoded['choice'].unique()))
model = model.to(device)

# Loss function for multi-class
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)


In [14]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=30, device='cpu'):
    best_val_loss = float('inf')
    patience = 5
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for Xp_batch, Xt_parts_batch, y_batch in train_loader:
            Xp_batch = {k: v.to(device).squeeze(1) for k, v in Xp_batch.items()}
            Xt_parts_batch = {part: {f: v.to(device).squeeze(1) for f, v in features.items()} for part, features in Xt_parts_batch.items()}
            y_batch = y_batch.to(device)

            optimizer.zero_grad()
            logits = model(Xp_batch, Xt_parts_batch)
            loss = criterion(logits, y_batch)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        scheduler.step()

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for Xp_batch, Xt_parts_batch, y_batch in val_loader:
                Xp_batch = {k: v.to(device).squeeze(1) for k, v in Xp_batch.items()}
                Xt_parts_batch = {part: {f: v.to(device).squeeze(1) for f, v in features.items()} for part, features in Xt_parts_batch.items()}
                y_batch = y_batch.to(device)

                logits = model(Xp_batch, Xt_parts_batch)
                loss = criterion(logits, y_batch)
                val_loss += loss.item()

                probs = torch.softmax(logits, dim=1)
                preds = torch.argmax(probs, dim=1)

                correct += (preds == y_batch).sum().item()
                total += y_batch.size(0)

        avg_train_loss = running_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        val_acc = correct / total

        print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.4f}")

        # Early Stopping
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), f'models/best_base_model.pth')
            print("✅ Best model saved!")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("⏹️ Early stopping triggered!")
                break

    print("Training complete!")


In [15]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix

def test_model(model, test_loader, criterion, encoder, type = "Test", device='cpu'):
    model.eval()
    all_labels = []
    all_preds = []
    total_loss = 0.0

    with torch.no_grad():
        for Xp_batch, Xt_parts_batch, y_batch in test_loader:
            Xp_batch = {k: v.to(device).squeeze(1) for k, v in Xp_batch.items()}
            Xt_parts_batch = {part: {f: v.to(device).squeeze(1) for f, v in features.items()} for part, features in Xt_parts_batch.items()}
            y_batch = y_batch.to(device)

            logits = model(Xp_batch, Xt_parts_batch)
            loss = criterion(logits, y_batch)
            total_loss += loss.item()

            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            all_labels.extend(y_batch.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    avg_loss = total_loss / len(test_loader)
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    print(f"{type} Loss: {avg_loss:.4f}")
    print(f"{type} Accuracy: {accuracy:.4f}")
    print(f"{type} Precision: {precision:.4f}")
    print(f"{type} Recall: {recall:.4f}")
    print(f"{type} F1 Score: {f1:.4f}\n")

    print("===== Classification Report =====\n")
    print(classification_report(all_labels, all_preds, zero_division=0, target_names=encoder.classes_))

    print("===== Confusion Matrix =====\n")
    print(confusion_matrix(all_labels, all_preds))

    return avg_loss, accuracy, precision, recall, f1


In [16]:
test_model(model, test_loader, criterion, target_encoder, type = "Test", device=device)

Test Loss: 0.7628
Test Accuracy: 0.2455
Test Precision: 0.0603
Test Recall: 0.2455
Test F1 Score: 0.0968

===== Classification Report =====

              precision    recall  f1-score   support

no treatment       0.00      0.00      0.00       338
   treatment       0.25      1.00      0.39       110

    accuracy                           0.25       448
   macro avg       0.12      0.50      0.20       448
weighted avg       0.06      0.25      0.10       448

===== Confusion Matrix =====

[[  0 338]
 [  0 110]]


(0.7628026945250375,
 0.24553571428571427,
 0.060287786989795915,
 0.24553571428571427,
 0.09680619559651818)

In [17]:
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=50, device=device)

Epoch [1/50] Train Loss: 0.4404 | Val Loss: 0.3088 | Val Acc: 0.8549
✅ Best model saved!
Epoch [2/50] Train Loss: 0.2751 | Val Loss: 0.2384 | Val Acc: 0.8996
✅ Best model saved!
Epoch [3/50] Train Loss: 0.2139 | Val Loss: 0.1838 | Val Acc: 0.9375
✅ Best model saved!
Epoch [4/50] Train Loss: 0.1575 | Val Loss: 0.1390 | Val Acc: 0.9509
✅ Best model saved!
Epoch [5/50] Train Loss: 0.1210 | Val Loss: 0.1174 | Val Acc: 0.9665
✅ Best model saved!
Epoch [6/50] Train Loss: 0.0935 | Val Loss: 0.0917 | Val Acc: 0.9777
✅ Best model saved!
Epoch [7/50] Train Loss: 0.0760 | Val Loss: 0.0774 | Val Acc: 0.9754
✅ Best model saved!
Epoch [8/50] Train Loss: 0.0713 | Val Loss: 0.0736 | Val Acc: 0.9799
✅ Best model saved!
Epoch [9/50] Train Loss: 0.0505 | Val Loss: 0.0670 | Val Acc: 0.9799
✅ Best model saved!
Epoch [10/50] Train Loss: 0.0443 | Val Loss: 0.0692 | Val Acc: 0.9799
Epoch [11/50] Train Loss: 0.0331 | Val Loss: 0.0541 | Val Acc: 0.9844
✅ Best model saved!
Epoch [12/50] Train Loss: 0.0282 | Val 

In [18]:
baseline_loss, baseline_accuracy, baseline_precision, baseline_recall, baseline_f1 = test_model(model, test_loader, criterion, target_encoder, type = "Test", device=device)

Test Loss: 0.0258
Test Accuracy: 0.9911
Test Precision: 0.9911
Test Recall: 0.9911
Test F1 Score: 0.9911

===== Classification Report =====

              precision    recall  f1-score   support

no treatment       0.99      0.99      0.99       338
   treatment       0.98      0.98      0.98       110

    accuracy                           0.99       448
   macro avg       0.99      0.99      0.99       448
weighted avg       0.99      0.99      0.99       448

===== Confusion Matrix =====

[[336   2]
 [  2 108]]


In [19]:
df_encoded["choice"].value_counts()

choice
0    2552
1     778
Name: count, dtype: int64

In [20]:
baseline_loss, baseline_accuracy, baseline_precision, baseline_recall, baseline_f1 = test_model(model, train_loader, criterion, target_encoder, type = "Train", device=device)

Train Loss: 0.0067
Train Accuracy: 0.9987
Train Precision: 0.9987
Train Recall: 0.9987
Train F1 Score: 0.9987

===== Classification Report =====

              precision    recall  f1-score   support

no treatment       1.00      1.00      1.00      1766
   treatment       0.99      1.00      1.00       538

    accuracy                           1.00      2304
   macro avg       1.00      1.00      1.00      2304
weighted avg       1.00      1.00      1.00      2304

===== Confusion Matrix =====

[[1763    3]
 [   0  538]]


In [21]:
import numpy as np
from collections import defaultdict

def loss_based_classification_report(model, data_loader, criterion, encoder, device='cpu', type='Test'):
    model.eval()
    all_labels = []
    all_preds = []
    all_losses = []
    all_probs = []

    with torch.no_grad():
        for Xp_batch, Xt_parts_batch, y_batch in data_loader:
            Xp_batch = {k: v.to(device).squeeze(1) for k, v in Xp_batch.items()}
            Xt_parts_batch = {part: {f: v.to(device).squeeze(1) for f, v in features.items()} for part, features in Xt_parts_batch.items()}
            y_batch = y_batch.to(device)

            logits = model(Xp_batch, Xt_parts_batch)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)

            loss = criterion(logits, y_batch)

            all_labels.extend(y_batch.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_losses.extend([loss.item()] * y_batch.size(0))
            all_probs.extend(probs.cpu().numpy())

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_losses = np.array(all_losses)
    all_probs = np.array(all_probs)

    # Per-sample loss using negative log likelihood for true class
    per_sample_losses = -np.log(np.choose(all_labels, all_probs.T) + 1e-12)

    # Group losses by class
    loss_by_class = defaultdict(list)
    for label, sample_loss in zip(all_labels, per_sample_losses):
        loss_by_class[label].append(sample_loss)

    # Build Loss Classification Report
    print(f"\n===== {type} Loss-Based Classification Report =====\n")
    print(f"{'Class':<25}{'Average Loss':<15}{'Support':<10}")
    print("-" * 50)

    for class_idx, losses in loss_by_class.items():
        avg_loss = np.mean(losses)
        support = len(losses)
        class_name = encoder.inverse_transform([class_idx])[0]
        print(f"{class_name:<25}{avg_loss:<15.4f}{support:<10}")


In [22]:
loss_based_classification_report(model, test_loader, criterion, target_encoder, device=device, type="Test")


===== Test Loss-Based Classification Report =====

Class                    Average Loss   Support   
--------------------------------------------------
treatment                0.0432         110       
no treatment             0.0201         338       
