In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import StandardScaler, LabelEncoder
import torch 
import torch.nn as nn

import sys 
sys.path.append("../models")

device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu')

In [12]:
def _get_pair_weights(pair_kernel_size, pair_kernel_weight=1):
    
    # Add neighboring indices to pairs in new dimension
    # where neighbors come from square around the pair
    # pair kernel size comes from relationship (2x+1)^2 = 2k + 1
    # pair_kernel_size = (np.sqrt(2 * pair_kernel_size + 1) - 1) // 2
    
    # Create coordinate grids
    x = torch.arange(2*pair_kernel_size + 1)
    y = torch.arange(2*pair_kernel_size + 1)
    
    # Create meshgrid
    X, Y = torch.meshgrid(x, y, indexing="ij")
    
    # Stack X and Y to form the coordinate tensor
    coordinate_tensor = torch.stack((X, Y), dim=2)
    
    # Get the kernel offset
    kernel_offset = coordinate_tensor - pair_kernel_size
    
    weights = torch.exp(
        -torch.square(kernel_offset).sum(dim=2) / (2 * pair_kernel_weight)
    )
    
    return weights

In [15]:
_get_pair_weights(3).shape

x shape tensor([0, 1, 2, 3, 4, 5, 6])
y shape torch.Size([7])


torch.Size([7, 7])

### First on Binary Classification

In [13]:
data = fetch_openml(data_id=41145) # philippine dataset
X, y = data['data'].copy(deep=False), data['target'].copy(deep=False)

y = LabelEncoder().fit_transform(y)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)

  warn(


In [14]:
X_train.shape

(4665, 308)

In [15]:


from dnamite import DNAMiteBinaryClassifier

model = DNAMiteBinaryClassifier(
    n_features=X_train.shape[1],
    n_embed=8,
    n_hidden=32,
    device=device,
    learning_rate=5e-4,
    kernel_size=5,
    kernel_weight=1,
    pair_kernel_size=100,
    pair_kernel_weight=10,
    entropy_param=1e-3,
    gamma=0.5,
    reg_param=0.1,
    pair_reg_param=0.1,
).to(device)

# First try to select features
model.select_features(X_train, y_train)

                                                

Early stopping: Test loss has not improved for 5 consecutive epochs.
Number of main features selected:  12


                                                

Number of interaction features selected:  66


In [16]:
# now train the model
model.fit(X_train, y_train)

SPlIT 0
Found selected features. Using only those features.


                                                

Early stopping: Test loss has not improved for 5 consecutive epochs.


                                                

SPlIT 1
Found selected features. Using only those features.


                                                

Early stopping: Test loss has not improved for 5 consecutive epochs.


                                                

SPlIT 2
Found selected features. Using only those features.


                                                

Early stopping: Test loss has not improved for 5 consecutive epochs.


                                                

SPlIT 3
Found selected features. Using only those features.


                                                

Early stopping: Test loss has not improved for 5 consecutive epochs.


                                                

SPlIT 4
Found selected features. Using only those features.


                                                

Early stopping: Test loss has not improved for 5 consecutive epochs.


                                                

In [19]:
preds = model.predict(X_test)

pred_probs = 1 / (1 + np.exp(-preds))

from sklearn.metrics import roc_auc_score
roc_auc_score(y_test, pred_probs)

                                      

0.8385630593309477

### Now test for survival analysis

In [20]:
X = pd.read_csv(
    "https://raw.githubusercontent.com/chl8856/DeepHit/master/sample%20data/METABRIC/cleaned_features_final.csv"
)

y = pd.read_csv(
    "https://raw.githubusercontent.com/chl8856/DeepHit/master/sample%20data/METABRIC/label.csv"
)

data = pd.concat([X, y], axis=1)
X = data.drop(["event_time", "label"], axis=1)
y = np.array(list(zip(data["label"], data["event_time"])), dtype=[('event', 'bool'), ('time', 'float32')])

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=10)

In [39]:
from dnamite import DNAMiteSurvival

model = DNAMiteSurvival(
    n_features=X_train.shape[1],
    n_embed=8,
    n_hidden=32,
    n_output=100,
    device=device,
    learning_rate=5e-4,
    kernel_size=5,
    kernel_weight=1,
    pair_kernel_size=10,
    pair_kernel_weight=1,
    entropy_param=1e-3,
    gamma=0.05,
    reg_param=0.05,
    pair_reg_param=0.1,
).to(device)

# First try to select features
model.select_features(X_train, y_train)

                                      

Early stopping: Test loss has not improved for 5 consecutive epochs.
Number of main features selected:  10


                                      

Early stopping: Test loss has not improved for 5 consecutive epochs.
Number of interaction features selected:  45




In [40]:
# now train the model
model.fit(X_train, y_train)

SPlIT 0
Found selected features. Using only those features.


                                              

Early stopping: Test loss has not improved for 5 consecutive epochs.


                                      

SPlIT 1
Found selected features. Using only those features.


                                             

Early stopping: Test loss has not improved for 5 consecutive epochs.


                                               

SPlIT 2
Found selected features. Using only those features.


                                              

Early stopping: Test loss has not improved for 5 consecutive epochs.


                                              

SPlIT 3
Found selected features. Using only those features.


                                      

Early stopping: Test loss has not improved for 5 consecutive epochs.


                                              

SPlIT 4
Found selected features. Using only those features.


                                      

Early stopping: Test loss has not improved for 5 consecutive epochs.


                                               

In [41]:
preds = model.predict(X_test)
pred_probs = 1 / (1 + np.exp(-preds))
surv_preds = 1 - pred_probs

test_times = np.linspace(
    max(y_train["time"].min(), y_test[y_test["event"] > 0]["time"].min()) + 1e-4,
    min(y_train["time"].max(), y_test[y_test["event"] > 0]["time"].max()) - 1e-4,
    100
)

surv_preds = surv_preds[
    :, 
    np.clip(
        np.searchsorted(model.eval_times.cpu().numpy(), test_times),
        0, surv_preds.shape[1]-1
    )
]
risk_preds = -1 * np.log(np.clip(surv_preds, 1e-5, 1 - 1e-5))

# Get time-dependent AUC
from sksurv.metrics import cumulative_dynamic_auc
aucs, mean_auc = cumulative_dynamic_auc(y_train, y_test, risk_preds, test_times)
mean_auc

                                     

0.5321649959705601

In [37]:
aucs

array([0.58778626, 0.58778626, 0.58778626, 0.58778626, 0.55866455,
       0.55866455, 0.55866455, 0.5700975 , 0.5700975 , 0.5700975 ,
       0.50623689, 0.50668132, 0.50668132, 0.5045809 , 0.5045809 ,
       0.5058881 , 0.5058881 , 0.5058881 , 0.5058881 , 0.495625  ,
       0.44025178, 0.44025178, 0.43888861, 0.43888861, 0.43888861,
       0.47208   , 0.47208   , 0.4382915 , 0.4382915 , 0.4376895 ,
       0.4376895 , 0.4376895 , 0.50455508, 0.54557696, 0.54557696,
       0.49206895, 0.49206895, 0.49206895, 0.48950214, 0.48950214,
       0.53516702, 0.53516702, 0.56890388, 0.51693509, 0.51693509,
       0.51693509, 0.51693509, 0.4967256 , 0.51236201, 0.51236201,
       0.51236201, 0.51047887, 0.51047887, 0.51047887, 0.47800139,
       0.47800139, 0.47800139, 0.47800139, 0.47800139, 0.47800139,
       0.47800139, 0.47979531, 0.45348853, 0.45348853, 0.45348853,
       0.45348853, 0.45348853, 0.45348853, 0.47243549, 0.47243549,
       0.49051677, 0.49051677, 0.49051677, 0.49051677, 0.49051