In [1]:
import torch  # used directly below (torch.device, seeding); previously only reachable via the star imports
import importlib
import frengression
# importlib.reload(frengression)  # uncomment to pick up edits to frengression without restarting the kernel
from data_causl.utils import *
from data_causl.data import *
from frengression import *

device = torch.device('cpu')
from CausalEGM import *

import numpy as np
import jax.numpy as jnp
import pickle
import os
from tqdm import tqdm
import src.exp_utils as exp_utils

from matplotlib import pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import copy
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import log_loss
from scipy.spatial.distance import pdist, squareform
from scipy.stats import norm, gaussian_kde
import warnings

# NOTE(review): this silences every warning for the whole session, including
# dtype/deprecation warnings from torch; consider narrowing it to specific categories.
warnings.filterwarnings("ignore")

# Seed the global torch and NumPy RNGs so the torch.randint / torch.randperm /
# torch.randn calls in later cells (and, presumably, model weight initialisation)
# are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)


In [4]:
# Simulate a survival dataset and fit the outcome (y) part of FrengressionSurv on it.
# The horizon is defined once and shared by the simulator and the model so the
# two cannot silently disagree.
N_TRAIN = 5000  # number of simulated rows
T_TRAIN = 3     # number of time steps

s, x, z, y = generate_data_survivl(n=N_TRAIN, T=T_TRAIN, random_seed=42, C_coeff=0)
s_tr = torch.tensor(s, dtype=torch.float32)
x_tr = torch.tensor(x, dtype=torch.int32)
y_tr = torch.tensor(y, dtype=torch.int32)
# z may contain NaNs (presumably for steps after an event — confirm); nan_to_num maps them to 0.
z_tr = torch.nan_to_num(torch.tensor(z, dtype=torch.float32))

model = FrengressionSurv(x_dim=1, y_dim=1, z_dim=1, T=T_TRAIN, s_dim=1, noise_dim=1, num_layer=3, hidden_dim=100,
                         device=device, x_binary=True, s_in_predict=True, y_binary=True)
model.train_y(s=s_tr, x=x_tr, z=z_tr, y=y_tr, num_iters=6000, lr=1e-4, print_every_iter=1000)

Epoch 1: loss 2.4062,	loss_y 0.8635, 0.8715, 0.0161,	loss_eta 1.5427, 1.5957, 0.1059
Epoch 1000: loss 1.4550,	loss_y 0.3534, 0.6544, 0.6020,	loss_eta 1.1016, 2.2370, 2.2708
Epoch 2000: loss 1.4532,	loss_y 0.3362, 0.6511, 0.6299,	loss_eta 1.1170, 2.2459, 2.2578
Epoch 3000: loss 1.4744,	loss_y 0.3314, 0.6485, 0.6342,	loss_eta 1.1430, 2.2710, 2.2561
Epoch 4000: loss 1.4700,	loss_y 0.3255, 0.6425, 0.6339,	loss_eta 1.1445, 2.2659, 2.2427
Epoch 5000: loss 1.4409,	loss_y 0.3181, 0.6297, 0.6233,	loss_eta 1.1228, 2.2517, 2.2577
Epoch 6000: loss 1.4465,	loss_y 0.3160, 0.6270, 0.6221,	loss_eta 1.1306, 2.2429, 2.2246


In [4]:
# Build a small debug batch, then resample it per time step: at each step t > 0,
# draw n rows (with replacement) from the rows that had no event (y > 0) at any
# earlier step. NOTE(review): presumably this mirrors what train_y does internally — confirm.
s, x, z, y = generate_data_survivl(n=10, T=5, random_seed=42, C_coeff=0)
s = torch.tensor(s, dtype=torch.float32)
x = torch.tensor(x, dtype=torch.int32)
y = torch.tensor(y, dtype=torch.int32)
z = torch.nan_to_num(torch.tensor(z, dtype=torch.float32))
# NOTE(review): the training cell builds `model` with T=3, but this batch has 5 time
# steps; the loss cell below indexes model.model_eta[t] / model.model_y[t] for every
# step — confirm the model actually has a head for each of them.

n = x.shape[0]
n_steps = y.shape[1]  # loop bound follows the data instead of repeating T

# Mark every entry strictly after a row's first event with -1.
event_indicator = (y > 0).float()
c = torch.cumsum(event_indicator, dim=1)
c_shifted = torch.zeros_like(c)
c_shifted[:, 1:] = c[:, :-1]
mask = c_shifted > 0
y_masked = y.clone()
y_masked[mask] = -1

# Step 0 uses the original rows unchanged.
y_list = [y[:, :1]]
x_list = [x[:, :1]]
s_list = [s[:, :1]]
z_list = [z[:, :1]]

# Resample from data: later steps draw only from rows still event-free before t.
for t in range(1, n_steps):
    valid_idx = (y_masked[:, t] >= 0).nonzero(as_tuple=True)[0]
    if len(valid_idx) == 0:
        # torch.randint(0, 0, ...) would fail with an opaque message; say what went wrong.
        raise ValueError(f"no rows remain event-free before step {t}; cannot resample")
    sample_idx = valid_idx[torch.randint(0, len(valid_idx), (n,))]
    y_list.append(y[sample_idx, t:t + 1])
    x_list.append(x[sample_idx, :t + 1])  # full treatment history up to and including t
    z_list.append(z[sample_idx, :t + 1])  # full covariate history up to and including t
    s_list.append(s[sample_idx, :1])
y_sample = torch.cat(y_list, dim=1)

In [18]:
# Resampled outcome matrix: n rows, one column per time step (step 0 = original rows).
y_sample

tensor([[1, 1, 0, 0, 1],
        [0, 1, 1, 1, 0],
        [1, 0, 0, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 1],
        [0, 0, 1, 0, 0],
        [0, 1, 1, 0, 1]], dtype=torch.int32)

In [33]:
# Recompute the y + eta energy-loss objective by hand on the resampled debug batch:
#   loss_y   compares the resampled outcomes with two independent model draws;
#   loss_eta compares standard-normal reference draws with model eta evaluated on the
#            batch and on a row permutation of it.
n_steps = len(s_list)  # one entry per time step, produced by the resampling cell
# NOTE(review): `model` is built with T=3 in the training cell; confirm model.model_eta /
# model.model_y have at least n_steps entries, otherwise the indexing below fails.

y_sample1 = []
y_sample2 = []
for t in range(n_steps):
    sxz_p = torch.cat([s_list[t], x_list[t], z_list[t]], dim=1)
    # Identical input twice: the two draws differ only if model_eta is stochastic
    # (presumably it injects noise internally — confirm).
    etat1 = model.model_eta[t](sxz_p)
    etat2 = model.model_eta[t](sxz_p)

    sxeta_p1 = torch.cat([s_list[t], x_list[t], etat1], dim=1)
    yt1 = model.model_y[t](sxeta_p1)
    sxeta_p2 = torch.cat([s_list[t], x_list[t], etat2], dim=1)
    yt2 = model.model_y[t](sxeta_p2)

    y_sample1.append(yt1)
    y_sample2.append(yt2)

y_sample1_cat = torch.cat(y_sample1, dim=1)
y_sample2_cat = torch.cat(y_sample2, dim=1)
loss_y, loss1_y, loss2_y = energy_loss_two_sample(y_sample, y_sample1_cat, y_sample2_cat)

# Standard-normal reference sample with the same shape as the outcome matrix.
eta_true = torch.randn(y.size(), device=device)

eta1 = []
eta2 = []
perm = torch.randperm(n)  # the same row permutation is applied to s, x and z
for t in range(n_steps):
    sxz_p1 = torch.cat([s_list[t], x_list[t], z_list[t]], dim=1)
    sxz_p2 = torch.cat([s_list[t][perm], x_list[t][perm], z_list[t][perm]], dim=1)
    etat1 = model.model_eta[t](sxz_p1)
    etat2 = model.model_eta[t](sxz_p2)

    eta1.append(etat1)
    eta2.append(etat2)
eta1_cat = torch.cat(eta1, dim=1)
eta2_cat = torch.cat(eta2, dim=1)

loss_eta, loss1_eta, loss2_eta = energy_loss_two_sample(eta_true, eta1_cat, eta2_cat)
loss = loss_y + loss_eta

In [35]:
# Combined debug objective: loss_y + loss_eta from the cell above.
loss

tensor(2.2194, grad_fn=<AddBackward0>)

In [12]:
# NOTE(review): `d` is not assigned in any cell shown here — it is leftover kernel
# state and will raise NameError after Restart & Run All. Define it explicitly or delete this cell.
d

tensor([1, 0, 0, 0, 1], dtype=torch.int32)

In [13]:
# Inspect row 3 of the treatment tensor `x` (all time steps).
x[3]

tensor([1, 1, 1, 1, 1], dtype=torch.int32)

In [16]:
# Per-step covariate histories: entry t holds columns 0..t of z for that step's rows
# (entry 0 uses the original rows; later entries use the resampled rows).
z_list

[tensor([[-0.3250],
         [ 0.8188],
         [-0.5669],
         [ 1.1772],
         [-0.2943],
         [ 0.6727],
         [ 1.3669],
         [-1.4821],
         [-0.4471],
         [-0.3443]]),
 tensor([[-1.4821, -0.1779],
         [-0.3443,  0.9819],
         [ 1.1772,  0.0681],
         [ 0.8188, -0.0544],
         [ 0.8188, -0.0544],
         [ 1.3669, -0.7025],
         [-0.4471,  0.7484],
         [ 1.1772,  0.0681],
         [ 1.3669, -0.7025],
         [-0.2943, -1.4967]]),
 tensor([[-0.4471,  0.7484,  0.0560],
         [ 1.3669, -0.7025, -1.2175],
         [ 0.6727,  0.5753, -0.7639],
         [ 0.6727,  0.5753, -0.7639],
         [ 0.6727,  0.5753, -0.7639],
         [-0.4471,  0.7484,  0.0560],
         [ 0.8188, -0.0544, -0.5053],
         [ 0.8188, -0.0544, -0.5053],
         [ 1.3669, -0.7025, -1.2175],
         [ 1.3669, -0.7025, -1.2175]]),
 tensor([[-0.4471,  0.7484,  0.0560, -0.1865],
         [ 0.8188, -0.0544, -0.5053, -0.8043],
         [ 0.6727,  0.5753, -0

In [17]:
# Per-step treatment histories: entry t holds columns 0..t of x for that step's rows
# (entry 0 uses the original rows; later entries use the resampled rows).
x_list

[tensor([[1],
         [1],
         [0],
         [1],
         [0],
         [1],
         [0],
         [0],
         [1],
         [1]], dtype=torch.int32),
 tensor([[0, 0],
         [1, 1],
         [1, 1],
         [1, 0],
         [1, 0],
         [0, 0],
         [1, 0],
         [1, 1],
         [0, 0],
         [0, 1]], dtype=torch.int32),
 tensor([[1, 0, 0],
         [0, 0, 1],
         [1, 0, 1],
         [1, 0, 1],
         [1, 0, 1],
         [1, 0, 0],
         [1, 0, 0],
         [1, 0, 0],
         [0, 0, 1],
         [0, 0, 1]], dtype=torch.int32),
 tensor([[1, 0, 0, 0],
         [1, 0, 0, 1],
         [1, 0, 1, 1],
         [1, 1, 1, 1],
         [1, 0, 0, 0],
         [1, 1, 1, 1],
         [1, 0, 1, 1],
         [1, 0, 0, 1],
         [1, 1, 1, 1],
         [1, 0, 0, 0]], dtype=torch.int32),
 tensor([[1, 0, 0, 0, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 0, 0, 0, 1],
         [1, 0, 0, 0, 1],
         [1, 0, 0, 0, 1],
         [1, 1, 1, 1