# Testing SWoTTeD on synthetic data

In [None]:
from gen_data import gen_synthetic_data
from swotted.temporalPhenotyping import TemporalPhenotyping
from swotted.loss_metrics import *
from swotted.slidingWindow_model import SlidingWindow
from swotted.utils import success_rate
import matplotlib.pyplot as plt
import numpy as np
import torch
import pickle
import seaborn as sns

#### Params

In [None]:
# Dimensions of the synthetic problem:
#   K  -- number of patients
#   N  -- number of medical events (drugs)
#   T  -- length of each patient's stay, in time steps
#   R  -- number of phenotypes to recover
#   Tw -- length of the temporal window of each phenotype
K, N, T, R, Tw = 100, 10, 6, 4, 3

### Synthetic data

In [None]:
# Generate the noise-free synthetic dataset: ground-truth pathways (W_),
# ground-truth phenotypes (Ph_), the patient data X and the generator params.
W_, Ph_, X, params = gen_synthetic_data(
    K, N, T, R, Tw,
    sliding_window=True,
    noise=0.0,
    truncate=True,
)

### Temporal phenotyping

In [None]:
# Launch the model learning: SWoTTeD with a Bernoulli reconstruction loss and
# the sliding-window model.
# The loss and model classes are referenced directly rather than looked up with
# eval() on a string literal, which added nothing but obscured the code.
# `loss` and `model` stay bound to the classes, as before.
loss = Bernoulli
model = SlidingWindow
tempPheno = TemporalPhenotyping(metric=loss(), model=model())
tempPheno.fit(
    X,
    rank=R,
    temp_window_length=Tw,
    batch_size=50,
    n_epochs=1000,
    sparsity=True,
    non_negativity=True,
    normalization=True,
    pheno_succession=True,
    temp_reg=False,
    trace=True,
)

### Results

- Phenotypes

In [None]:
# Align the learned phenotypes with the ground-truth ordering of Ph_, then keep
# a handle on them for plotting. With a one-step window (Tw == 1), dimension 2
# is squeezed away.
tempPheno.reorderPhenotypes(Ph_, tw=Tw)
Ph = torch.squeeze(tempPheno.Ph, dim=2) if Tw == 1 else tempPheno.Ph

In [None]:
# Compare each ground-truth phenotype (left) with the learned phenotype after
# reordering (right). Both are drawn as binary heatmaps on a shared [0, 1] scale.
# (The previous Tw == 1 branch also called plt.imshow on the right subplot,
# which sns.heatmap immediately drew over; that redundant call is removed.)
if Tw == 1:
    # NOTE(review): Ph was squeezed on dim 2 but Ph_ is plotted unchanged;
    # confirm Ph_ is already 2-D when Tw == 1.
    pairs = [(Ph_, Ph)]
else:
    pairs = [(Ph_[i], Ph[i]) for i in range(R)]

for expected, learned in pairs:
    plt.subplot(221)
    sns.heatmap(expected, vmin=0, vmax=1, cmap="binary")
    plt.ylabel("Drugs")
    plt.xlabel("time")
    plt.title("phenotype")
    plt.subplot(222)
    # The learned tensor is detached before conversion to numpy for plotting.
    sns.heatmap(learned.detach().numpy(), vmin=0, vmax=1, cmap="binary")
    plt.ylabel("Drugs")
    plt.xlabel("time")
    plt.title("result")
    plt.show()

- Reconstruction

In [None]:
# Rebuild the patient data from the fitted model. reconstruction() presumably
# populates tempPheno.recons, which is read back right after as Y
# (TODO confirm in TemporalPhenotyping).
tempPheno.reconstruction()
Y = tempPheno.recons

In [None]:
# Compare the first few input patient matrices X[i] (left) with their
# reconstructions Y[i] from the fitted model (right).
# Panel titles previously said "phenotype"/"result" even though the left panel
# shows patient data; they now name what is drawn.
# Show at most 10 patients, without an IndexError when fewer exist.
n_show = min(10, len(X), len(Y))
for i in range(n_show):
    plt.subplot(221)
    sns.heatmap(X[i], vmin=0, vmax=1, cmap="binary")
    plt.ylabel("Drugs")
    plt.xlabel("time")
    plt.title("patient")
    plt.subplot(222)
    sns.heatmap(Y[i].detach().numpy(), vmin=0, vmax=1, cmap="binary")
    plt.ylabel("Drugs")
    plt.xlabel("time")
    plt.title("reconstruction")
    plt.show()

### FIT Scores

- FIT X:

In [None]:
# Score the reconstruction Y against the input data X.
# NOTE(review): despite its name, error_X holds the value returned by
# success_rate (not an error); confirm the metric's direction in swotted.utils.
error_X= success_rate(X, Y)
print(error_X)

- FIT P:

In [None]:
# Score the learned phenotypes against the ground truth Ph_.
# After reordering, Ph is re-derived from tempPheno.Ph (as in the earlier
# phenotype cell) so the score never uses a value computed before the reorder.
tempPheno.reorderPhenotypes(Ph_, tw=Tw)
Ph = torch.squeeze(tempPheno.Ph, dim=2) if Tw == 1 else tempPheno.Ph
# NOTE(review): error_Ph holds success_rate's value, not an error.
error_Ph = success_rate(Ph_, Ph)
print(error_Ph)

- FIT W:

In [None]:
# Score the learned Wk (tempPheno.Wk) against the ground truth W_.
# NOTE(review): error_W holds success_rate's value, not an error; also confirm
# that Wk follows the same phenotype order set by reorderPhenotypes above.
error_W= success_rate(W_, tempPheno.Wk)
print(error_W)