In [None]:
!pip install --user stg 
# If you are running this notebook on Google Colab, please reset the current python environment via 'Runtime -> Restart runtime' after installation.

In [8]:
from stg import STG
import stg.utils as utils
import numpy as np
import torch
import time

## Prepare dataset

In [3]:
datasets = utils.load_cox_gaussian_data()

# Fail fast if the loader's output deviates from the structure documented
# below; every later cell indexes these splits and keys directly.
for split_name in ('train', 'valid', 'test'):
    assert split_name in datasets, f"datasets is missing the {split_name!r} split"
    missing_keys = {'x', 't', 'e'} - set(datasets[split_name])
    assert not missing_keys, f"split {split_name!r} is missing keys {sorted(missing_keys)}"

`datasets` should have the following structure:

```python
datasets = {'train': {'x': (n_train, d) observations (dtype = float32),
                      't': (n_train) event times (dtype = float32),
                      'e': (n_train) event indicators (dtype = int32)},
            'test':  {'x': (n_test, d) observations (dtype = float32),
                      't': (n_test) event times (dtype = float32),
                      'e': (n_test) event indicators (dtype = int32)},
            'valid': {'x': (n_valid, d) observations (dtype = float32),
                      't': (n_valid) event times (dtype = float32),
                      'e': (n_valid) event indicators (dtype = int32)}}
```

## Standardize dataset 

In [5]:
# Normalization statistics are computed on the training split only and reused
# for every split, so nothing from valid/test leaks into preprocessing.
norm_vals = {
    'mean': datasets['train']['x'].mean(axis=0),
    'std': datasets['train']['x'].std(axis=0),
}


def _prepare_split(raw_split, mean, std):
    """Standardize one raw split and convert it to the dict format passed to STG.

    Parameters
    ----------
    raw_split : dict
        One entry of ``datasets`` with keys 'x', 't', 'e' (see the structure above).
    mean, std : array-like
        Per-feature statistics computed on the training split.

    Returns
    -------
    x : the standardized observations (``standardized['x']``).
    y : dict with keys 'e' (event indicators) and 't' (event times).
    prepared : dict with keys 'X', 'E', 'T' as returned by ``utils.prepare_data``,
        plus 'ties' = 'noties' (presumably tells the Cox loss to assume no tied
        event times -- TODO confirm against the stg docs).
    """
    standardized = utils.standardize_dataset(raw_split, mean, std)
    x = standardized['x']
    y = {'e': standardized['e'], 't': standardized['t']}
    prepared = {}
    prepared['X'], prepared['E'], prepared['T'] = utils.prepare_data(x, y)
    prepared['ties'] = 'noties'
    return x, y, prepared


train_X, train_y, train_data = _prepare_split(datasets['train'], norm_vals['mean'], norm_vals['std'])
valid_X, valid_y, valid_data = _prepare_split(datasets['valid'], norm_vals['mean'], norm_vals['std'])
test_X, test_y, test_data = _prepare_split(datasets['test'], norm_vals['mean'], norm_vals['std'])

## Instantiate the STG trainer

In [12]:
device = "cpu"
feature_selection = True

# Hyperparameters of the STG Cox model, gathered in one place. The reload cell
# further down constructs a second model with these same settings so the saved
# checkpoint fits its architecture.
stg_params = dict(
    task_type='cox',
    input_dim=train_data['X'].shape[1],
    output_dim=1,
    hidden_dims=[60, 20, 3],
    activation='selu',
    optimizer='Adam',
    learning_rate=0.0005,
    # Full-batch training: each step sees every training sample.
    # NOTE(review): presumably chosen so the Cox partial likelihood is computed
    # over complete risk sets -- confirm before changing.
    batch_size=train_data['X'].shape[0],
    feature_selection=feature_selection,
    sigma=0.5,
    lam=0.004,
    random_state=1,
    device=device,
)
model = STG(**stg_params)

## Training

In [None]:
N_EPOCHS = 600
PRINT_INTERVAL = 100  # presumably: how often (in epochs) fit() prints progress

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# durations; time.time() can jump if the system clock is adjusted mid-run.
start_time = time.perf_counter()
model.fit(train_data['X'], {'E': train_data['E'], 'T': train_data['T']}, nr_epochs=N_EPOCHS,
          valid_X=valid_data['X'], valid_y={'E': valid_data['E'], 'T': valid_data['T']},
          print_interval=PRINT_INTERVAL)
print("Passed time: {}".format(time.perf_counter() - start_time))


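## Evaluating the model

Note: the training cell above shows `In [None]` and has no output. It was therefore not executed when the outputs below were saved. Those results (test C-index ≈ 0.51, close to chance, and nearly uniform gate probabilities) most likely come from an untrained model. Restart the kernel and run the notebook top to bottom to reproduce trained results.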

In [13]:
# Score the model on the held-out test split; STG prints the test C-index and loss.
model.evaluate(test_data['X'], {key: test_data[key] for key in ('E', 'T')})

test_CI=0.510212 test_loss=6.600657


In [11]:
# Gate probability for each input feature.
# NOTE(review): presumably values near 1 mean the feature is kept -- confirm in the STG docs.
gate_probs = model.get_gates(mode='prob')
gate_probs

array([0.17999333, 0.17464066, 0.18207976, 0.1616241 , 0.15998313,
       0.16988853, 0.18346652, 0.15973818, 0.17987254, 0.16981307],
      dtype=float32)

## Saving and loading

In [14]:
# Persist the trained model so it can be restored into a fresh instance below.
checkpoint_path = 'trained_model.pt'
model.save_checkpoint(checkpoint_path)

In [15]:
# Build a fresh, untrained model with exactly the same configuration as `model`;
# the checkpoint written above is loaded into it in the next cell.
model_tmp = STG(
    task_type='cox',
    input_dim=train_data['X'].shape[1],
    output_dim=1,
    hidden_dims=[60, 20, 3],
    activation='selu',
    optimizer='Adam',
    learning_rate=0.0005,
    batch_size=train_data['X'].shape[0],
    feature_selection=feature_selection,
    sigma=0.5,
    lam=0.004,
    random_state=1,
    device=device,
)

In [16]:
# Restore the weights saved earlier into the freshly constructed model.
checkpoint_path = 'trained_model.pt'
model_tmp.load_checkpoint(checkpoint_path)

Checkpoint loaded: trained_model.pt.


## Checking the performance of the loaded model

In [17]:
# If the save/load round trip is lossless, this prints the same metrics as the
# evaluation of `model` above.
model_tmp.evaluate(test_data['X'], {key: test_data[key] for key in ('E', 'T')})

test_CI=0.510212 test_loss=6.600657
