In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, classification_report
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Import the FTTransformer model from rtdl_revisiting_models.
from rtdl_revisiting_models import FTTransformer

##########################################
# 1. Load and Preprocess the Data
##########################################
# Load CSV files for train, validation, and test splits.
train_data = pd.read_csv("train_data.csv")
val_data   = pd.read_csv("val_data.csv")
test_data  = pd.read_csv("test_data.csv")

In [2]:
# Define feature columns and target label.
numerical_features = ["Age", "CDGLOBAL", "CDRSB", "MMSCORE", "HMSCORE", "NPISCORE", "GDTOTAL"]
categorical_features = ["GENOTYPE"]
label = "Group"

In [3]:
# Subset dataframes to desired columns.
cols = numerical_features + categorical_features + [label]
train_data = train_data[cols]
val_data   = val_data[cols]
test_data  = test_data[cols]

In [4]:
# Handle missingness for numerical features.
cols_with_missing = ["CDRSB", "MMSCORE", "HMSCORE", "NPISCORE", "GDTOTAL"]
for col in cols_with_missing:
    for df in [train_data, val_data, test_data]:
        df[col + "_is_missing"] = df[col].isnull().astype(int)
        df[col] = df[col].fillna(-999)

In [5]:
# Extend continuous features to include missing indicators.
numerical_features_extended = numerical_features + [col + "_is_missing" for col in cols_with_missing]

In [6]:
# Encode categorical features using LabelEncoder.
cat_encoders = {}
for col in categorical_features:
    le = LabelEncoder()
    train_data[col] = le.fit_transform(train_data[col].astype(str))
    val_data[col]   = le.transform(val_data[col].astype(str))
    test_data[col]  = le.transform(test_data[col].astype(str))
    cat_encoders[col] = le

In [7]:
# Encode the target.
label_encoder = LabelEncoder()
train_data[label] = label_encoder.fit_transform(train_data[label])
val_data[label]   = label_encoder.transform(val_data[label])
test_data[label]  = label_encoder.transform(test_data[label])
num_classes = len(label_encoder.classes_)  # e.g., 3 for classification

In [8]:
##########################################
# 2. Prepare NumPy Arrays and Create Dataset
##########################################
# Continuous features (including missing indicators).
X_train_cont = train_data[numerical_features_extended].values.astype(np.float32)
X_val_cont   = val_data[numerical_features_extended].values.astype(np.float32)
X_test_cont  = test_data[numerical_features_extended].values.astype(np.float32)

# Categorical features.
X_train_cat = train_data[categorical_features].values.astype(np.int64)
X_val_cat   = val_data[categorical_features].values.astype(np.int64)
X_test_cat  = test_data[categorical_features].values.astype(np.int64)

# Labels.
y_train = train_data[label].values.astype(np.int64)
y_val   = val_data[label].values.astype(np.int64)
y_test  = test_data[label].values.astype(np.int64)

In [9]:
# Create a simple PyTorch Dataset.
class TabularDataset(Dataset):
    def __init__(self, cont, cat, labels):
        self.cont = cont
        self.cat = cat
        self.labels = labels
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        return {
            "cont": torch.tensor(self.cont[idx], dtype=torch.float32),
            "cat": torch.tensor(self.cat[idx], dtype=torch.long),
            "target": torch.tensor(self.labels[idx], dtype=torch.long)
        }

In [10]:
train_dataset = TabularDataset(X_train_cont, X_train_cat, y_train)
val_dataset   = TabularDataset(X_val_cont, X_val_cat, y_val)
test_dataset  = TabularDataset(X_test_cont, X_test_cat, y_test)

# Create DataLoaders.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [11]:
##########################################
# 3. Initialize and Train the FTTransformer Classifier
##########################################
# Get the number of continuous features.
n_cont_features = X_train_cont.shape[1]
# Determine the cardinalities for each categorical feature.
cat_cardinalities = [int(train_data[col].nunique()) for col in categorical_features]

# For classification, set d_out = number of classes.
d_out = num_classes

# Instantiate the FTTransformer.
model = FTTransformer(
    n_cont_features=n_cont_features,
    cat_cardinalities=cat_cardinalities,
    d_out=d_out,
    n_blocks=3,
    d_block=192,                # Backbone (hidden) dimension
    attention_n_heads=8,
    attention_dropout=0.2,
    ffn_d_hidden=None,          # Defaults internally if None.
    ffn_d_hidden_multiplier=4/3,
    ffn_dropout=0.1,
    residual_dropout=0.0
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

FTTransformer(
  (cls_embedding): _CLSEmbedding()
  (cont_embeddings): LinearEmbeddings()
  (cat_embeddings): CategoricalEmbeddings(
    (embeddings): ModuleList(
      (0): Embedding(6, 192)
    )
  )
  (backbone): FTTransformerBackbone(
    (blocks): ModuleList(
      (0): ModuleDict(
        (attention): MultiheadAttention(
          (W_q): Linear(in_features=192, out_features=192, bias=True)
          (W_k): Linear(in_features=192, out_features=192, bias=True)
          (W_v): Linear(in_features=192, out_features=192, bias=True)
          (W_out): Linear(in_features=192, out_features=192, bias=True)
          (dropout): Dropout(p=0.2, inplace=False)
        )
        (attention_residual_dropout): Dropout(p=0.0, inplace=False)
        (ffn_normalization): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
        (ffn): Sequential(
          (linear1): Linear(in_features=192, out_features=512, bias=True)
          (activation): _ReGLU()
          (dropout): Dropout(p=0.1, inplace

In [12]:
# Set up optimizer, loss function, and scheduler.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True)

# Training loop with early stopping.
max_epochs = 100
patience = 5
best_val_loss = float('inf')
patience_counter = 0

for epoch in range(max_epochs):
    model.train()
    train_loss = 0.0
    for batch in train_loader:
        cont = batch["cont"].to(device)
        cat = batch["cat"].to(device)
        targets = batch["target"].to(device)
        
        optimizer.zero_grad()
        logits = model(cont, cat)  # Forward pass returns logits (shape: [batch_size, d_out])
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * cont.size(0)
    train_loss /= len(train_dataset)
    
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            cont = batch["cont"].to(device)
            cat = batch["cat"].to(device)
            targets = batch["target"].to(device)
            
            logits = model(cont, cat)
            loss = criterion(logits, targets)
            val_loss += loss.item() * cont.size(0)
    val_loss /= len(val_dataset)
    
    scheduler.step(val_loss)
    print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model_state = model.state_dict()
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break



Epoch 1: Train Loss: 1.0578 | Val Loss: 1.0382
Epoch 2: Train Loss: 1.0128 | Val Loss: 0.8282
Epoch 3: Train Loss: 0.7741 | Val Loss: 0.5106
Epoch 4: Train Loss: 0.5536 | Val Loss: 0.4235
Epoch 5: Train Loss: 0.4884 | Val Loss: 0.3976
Epoch 6: Train Loss: 0.4308 | Val Loss: 0.3562
Epoch 7: Train Loss: 0.3916 | Val Loss: 0.4066
Epoch 8: Train Loss: 0.3812 | Val Loss: 0.3571
Epoch 9: Train Loss: 0.3757 | Val Loss: 0.3487
Epoch 10: Train Loss: 0.3501 | Val Loss: 0.3483
Epoch 11: Train Loss: 0.3555 | Val Loss: 0.3395
Epoch 12: Train Loss: 0.3441 | Val Loss: 0.3403
Epoch 13: Train Loss: 0.3334 | Val Loss: 0.3384
Epoch 14: Train Loss: 0.3370 | Val Loss: 0.3344
Epoch 15: Train Loss: 0.3442 | Val Loss: 0.3401
Epoch 16: Train Loss: 0.3461 | Val Loss: 0.3897
Epoch 17: Train Loss: 0.3406 | Val Loss: 0.3461
Epoch 18: Train Loss: 0.3236 | Val Loss: 0.3335
Epoch 19: Train Loss: 0.3139 | Val Loss: 0.3337
Epoch 20: Train Loss: 0.3174 | Val Loss: 0.3450
Epoch 21: Train Loss: 0.3264 | Val Loss: 0.3954
E

In [13]:
# Save the best model.
save_model_path = "best_ft_transformer_classification.pt"
torch.save(best_model_state, save_model_path)
print("Trained model saved to", save_model_path)

Trained model saved to best_ft_transformer_classification.pt


In [14]:
# Load the best model (optional).
model.load_state_dict(torch.load(save_model_path))

<All keys matched successfully>

In [15]:
# Evaluate classification performance on the test set.
model.eval()
all_preds = []
all_targets = []
with torch.no_grad():
    for batch in test_loader:
        cont = batch["cont"].to(device)
        cat = batch["cat"].to(device)
        targets = batch["target"].to(device)
        
        logits = model(cont, cat)
        preds = torch.argmax(logits, dim=1)
        all_preds.append(preds.cpu().numpy())
        all_targets.append(targets.cpu().numpy())

all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)
test_acc = accuracy_score(all_targets, all_preds)
print("Test Accuracy:", test_acc)
print("Classification Report (Test):")
print(classification_report(all_targets, all_preds, target_names=[str(c) for c in label_encoder.classes_]))

Test Accuracy: 0.9043478260869565
Classification Report (Test):
              precision    recall  f1-score   support

          AD       0.83      0.81      0.82        72
          CN       0.98      0.95      0.97       106
         MCI       0.89      0.92      0.90       167

    accuracy                           0.90       345
   macro avg       0.90      0.89      0.90       345
weighted avg       0.90      0.90      0.90       345



In [16]:
##########################################
# 4. Extract 192-Dimensional Embeddings (Before the Final Classification)
##########################################
# The classifier was trained with d_out = num_classes.
# To obtain 192-dim embeddings (the backbone outputs), we replace the final linear layer with an identity.
# Here we assume the final projection is stored in the attribute 'fc'.
model.fc = nn.Identity()  # Now model(cont, cat) returns the backbone features of shape (batch_size, 192).

# Function to extract features from a DataLoader.
def extract_features(loader, model, device):
    model.eval()
    features_list = []
    with torch.no_grad():
        for batch in loader:
            cont = batch["cont"].to(device)
            cat = batch["cat"].to(device)
            feats = model(cont, cat)  # Should now return features of shape (batch_size, 192)
            features_list.append(feats.cpu().numpy())
    return np.concatenate(features_list)

# Extract features from each set.
features_train = extract_features(train_loader, model, device)
features_val   = extract_features(val_loader, model, device)
features_test  = extract_features(test_loader, model, device)

print("Extracted feature shapes:")
print("Train features:", features_train.shape)
print("Validation features:", features_val.shape)
print("Test features:", features_test.shape)

# Save the extracted features.
np.save("ft_train_features.npy", features_train)
np.save("ft_val_features.npy", features_val)
np.save("ft_test_features.npy", features_test)
print("Extracted features saved as 'ft_train_features.npy', 'ft_val_features.npy', and 'ft_test_features.npy'.")

Extracted feature shapes:
Train features: (1605, 3)
Validation features: (344, 3)
Test features: (345, 3)
Extracted features saved as 'ft_train_features.npy', 'ft_val_features.npy', and 'ft_test_features.npy'.
