In [3]:
import pandas as pd

# Load the pickled simulated dataset back into a DataFrame.
# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
# only load files that come from a trusted source.
data = pd.read_pickle('data/simulated_data.pkl')



# Preview the first rows of the loaded data
print(data.head())

# Number of distinct values in the 'id' column
unique_id_count = data['id'].nunique()
print(f"Number of unique IDs: {unique_id_count}")

# (rows, columns) of the full frame
print(data.shape)

   id  visit  obstime  predtime  time  event         Y1        Y2        Y3  \
0   0      0        0         0     9   True  11.943728 -3.032593  2.760192   
1   0      1        1         1     9   True  12.255357 -5.431790  4.225383   
2   0      2        2         2     9   True  12.491947 -6.953460  2.854653   
3   0      3        3         3     9   True  16.406431 -8.508030  4.766191   
4   0      4        4         4     9   True  16.632347 -9.813989  5.816555   

    X1        X2    pred_Y1   pred_Y2   pred_Y3      true  
0  1.0  0.680195  11.943728 -3.032593  2.760192  1.000000  
1  1.0  0.680195  12.255357 -5.431790  4.225383  0.999397  
2  1.0  0.680195  12.491947 -6.953460  2.854653  0.998135  
3  1.0  0.680195  16.406431 -8.508030  4.766191  0.995494  
4  1.0  0.680195  16.632347 -9.813989  5.816555  0.989983  
Number of unique IDs: 1000


In [51]:
# Dump the current `obs_time` tensor to CSV so it can be inspected outside the notebook.
# `.detach().cpu()` keeps the export working after `obs_time` has been moved to the
# GPU by a training cell (Tensor.numpy() only supports CPU tensors without grad).
# NOTE(review): `obs_time` is produced by a `get_tensors(...)` call in another cell;
# that cell has to be run first.
df = pd.DataFrame(obs_time.detach().cpu().numpy())
df.to_csv("data/tensor_data_long.csv", index=False)

In [78]:
from util import (get_tensors, get_mask, init_weights, get_std_opt)
# Number of distinct subjects in the dataset
I = data['id'].nunique()

## split train/test
# The ids are taken in their natural order (0 .. I-1), so the split is deterministic.
# NOTE(review): this assumes the 'id' values are exactly 0 .. I-1 - confirm against
# the data. Swap in np.random.permutation(range(I)) for a random split.
random_id = range(I) # alternative: np.random.permutation(range(I))
# First 70% of the ids go to training, the remainder to testing
train_id = random_id[0:int(0.7*I)]
test_id = random_id[int(0.7*I):I]

# Keep every row belonging to the selected subjects
train_data = data[data["id"].isin(train_id)]
test_data = data[data["id"].isin(test_id)]

print(train_data.shape)

# Convert the whole training frame to tensors with the project helper.
# NOTE(review): the meaning and shape of each returned tensor are defined in
# util.get_tensors, which is not visible here.
batch_long, batch_base, batch_mask, batch_e, batch_t, obs_time = get_tensors(train_data)

(5329, 15)


In [109]:
def positional_encoding(batch_size, length, d_model, obs_time, device=None):
    """
    Positional Encoding for each visit

    Builds a sinusoidal encoding from the (possibly irregular) visit times:
    even feature indices hold a sine term, odd feature indices a cosine term.

    Parameters
    ----------
    batch_size:
        Number of subjects in batch
    length:
        Number of visits
    d_model:
        Dimension of the model vector
    obs_time:
        Observed/recorded time of each visit. Accepted as a 0-d tensor (one
        time shared by everyone), a 1-d tensor (one time grid shared by every
        subject) or a 2-d tensor of shape (batch_size, length).
    device:
        Device on which the encoding is built. Defaults to CUDA when it is
        available (the previous hard-coded behaviour) and otherwise to the
        device of ``obs_time``, so the function also runs on CPU-only machines.

    Returns
    -------
    Tensor of shape (batch_size, length, d_model) on ``device``.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else obs_time.device

    PE = torch.zeros((batch_size, length, d_model)).to(device)

    # Broadcast scalar / shared time grids to one row per subject
    if obs_time.ndim == 0:
        obs_time = obs_time.repeat(batch_size).unsqueeze(1)
    elif obs_time.ndim == 1:
        obs_time = obs_time.repeat(batch_size,1)
    obs_time = obs_time.to(device)
    # einsum needs matching floating dtypes; integer visit times are promoted
    if not obs_time.is_floating_point():
        obs_time = obs_time.float()

    # NOTE(review): the times are MULTIPLIED by 10000^(i/d_model); the encoding in
    # "Attention is All You Need" divides the position by that factor. Left as-is
    # because changing it would alter any model trained with this code - confirm intent.
    pow0 = torch.pow(10000, torch.arange(0, d_model, 2, dtype=torch.float32)/d_model).to(device)
    PE[:, :, 0::2] = torch.sin(torch.einsum('ij,k->ijk', obs_time, pow0))

    pow1 = torch.pow(10000, torch.arange(1, d_model, 2, dtype=torch.float32)/d_model).to(device)
    PE[:, :, 1::2] = torch.cos(torch.einsum('ij,k->ijk', obs_time, pow1))

    return PE


class Decoder_p(nn.Module):
    """
    Decoder block whose queries are tagged with the prediction-time encoding.

    Parameters
    ----------
    d_model:
        Dimension of the input vector
    nhead:
        Number of heads
    num_decoder_layers:
        Number of decoder layers to stack
    dropout:
        The dropout value
    """
    def __init__(self,
                 d_model,
                 nhead,
                 num_decoder_layers,
                 dropout):
        super().__init__()

        self.decoder_layers = nn.ModuleList([Decoder_Layer(d_model,nhead,dropout)
                                             for _ in range(num_decoder_layers)])

    def forward(self, q, kv, mask, pred_time):
        """
        Add the positional encoding of ``pred_time`` to ``q`` and run the result
        through the stacked decoder layers, attending over ``kv`` each time.

        Returns the output of the last layer (or the encoded query itself when
        the stack is empty).
        """
        # Positional Embedding
        q = q + positional_encoding(
            q.shape[0], q.shape[1], q.shape[2], pred_time)

        # Decoder Layers: each layer consumes the previous layer's output.
        # (Previously every layer was fed the same `q` and only the last result
        # was kept, so stacking more than one layer had no effect, and `x` was
        # undefined when the stack was empty.)
        x = q
        for layer in self.decoder_layers:
            x = layer(x, kv, mask)

        return x

In [96]:
import torch.nn.functional as F
class Decoder_Layer(nn.Module):
    """
    Single decoder layer: multi-head attention followed by a position-wise
    feed-forward network, each wrapped in a residual connection + LayerNorm.

    Parameters
    ----------
    d_model:
        Dimension of the input vector
    nhead:
        Number of heads
    dropout:
        The dropout value (applied after attention and at the end of the
        feed-forward network)
    d_ff:
        Hidden width of the feed-forward network. Defaults to 64, the value
        that used to be hard-coded.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dropout = 0.1,
                 d_ff = 64):
        super().__init__()

        self.dropout = nn.Dropout(dropout)

        # NOTE: the attention module keeps its own default dropout; `dropout`
        # is not forwarded to it.
        self.Attention = MultiHeadAttention(d_model, nhead)

        self.feedForward = nn.Sequential(
            nn.Linear(d_model,d_ff),
            nn.ReLU(),
            nn.Linear(d_ff,d_model),
            nn.Dropout(dropout)
            )

        self.layerNorm1 = nn.LayerNorm(d_model)
        self.layerNorm2 = nn.LayerNorm(d_model)

    def forward(self, q, kv, mask):
        """
        Attend from ``q`` over ``kv`` (used as both key and value) under
        ``mask``, then apply the feed-forward network. Both sub-layers use
        post-norm residual connections. Returns a tensor shaped like ``q``.
        """
        # Attention
        residual = q
        x = self.Attention(query=q, key=kv, value=kv, mask = mask)
        x = self.dropout(x)
        x = self.layerNorm1(x + residual)

        # Feed Forward
        residual = x
        x = self.feedForward(x)
        x = self.layerNorm2(x + residual)

        return x
class MultiHeadAttention(nn.Module):
    """
    Scaled dot-product attention split over several heads.

    Parameters
    ----------
    d_model:
        Dimension of the input vector; must be divisible by ``nhead``
    nhead:
        Number of heads
    dropout:
        Dropout applied to the attention weights
    """
    def __init__(self, d_model, nhead, dropout = 0.1):
        super().__init__()

        # Validate first so a bad configuration fails with a clear message
        assert (
            d_model % nhead == 0
        ), "Embedding size (d_model) needs to be divisible by number of heads"

        self.d_model = d_model
        self.d_k = d_model // nhead
        self.nhead = nhead

        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def attention(self, query, key, value, d_k, mask = None, dropout=None):
        """
        Compute softmax(query @ key^T / sqrt(d_k)) @ value.

        ``mask`` (if given) gets a head axis inserted at dim 1 and is moved to
        the device of the scores; positions where it equals 0 are excluded from
        the softmax. ``dropout`` (if given) is applied to the attention weights.
        """
        import math  # local import: this cell relies on no module-level `np`/`math`

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Follow the scores' device instead of forcing 'cuda', so the module
            # works on CPU as well as on GPU.
            mask = mask.unsqueeze(1).to(scores.device)
            # NOTE: a row that is masked everywhere becomes NaN after softmax.
            scores = scores.masked_fill(mask == 0, -float('inf'))
        scores = F.softmax(scores, dim=-1)

        if dropout is not None:
            scores = dropout(scores)

        output = torch.matmul(scores, value)
        return output

    def forward(self, query, key, value, mask = None):
        """
        Project the inputs, attend per head and merge the heads again.

        Inputs have ``d_model`` features on the last axis and the batch on the
        first; the result has shape (batch, query_length, d_model).
        """
        I = query.shape[0]

        # perform linear operation and split into N heads
        query = self.q_linear(query).view(I, -1, self.nhead, self.d_k)
        key = self.k_linear(key).view(I, -1, self.nhead, self.d_k)
        value = self.v_linear(value).view(I, -1, self.nhead, self.d_k)

        # transpose to get dimensions I * nhead * J * d_k
        query = query.transpose(1,2)
        key = key.transpose(1,2)
        value = value.transpose(1,2)

        # calculate attention
        scores = self.attention(query, key, value, self.d_k, mask, self.dropout)
        # concatenate heads and put through final linear layer
        concat = scores.transpose(1,2).contiguous()\
        .view(I, -1, self.d_model)
        output = self.out(concat)

        return output


In [139]:
class Decoder_MMOE_Layer(nn.Module):
    """Attention block whose feed-forward part is a multi-gate mixture of experts.

    The query first attends over the key/value sequence (residual + LayerNorm).
    The normalised result is then fed to a shared pool of expert networks.
    Two gates, one for the longitudinal task and one for the survival task,
    each produce softmax weights over the experts; the weighted expert sums are
    passed to a task-specific linear head. The survival head output is squashed
    with a sigmoid.

    Parameters
    ----------
    d_model:
        Feature dimension of the inputs
    nhead:
        Number of attention heads
    num_experts:
        Size of the shared expert pool
    d_ff_expert:
        Hidden width of each expert
    d_long:
        Output width of the longitudinal head

    Returns:
        Tuple[Tensor, Tensor]: (longitudinal_output of shape [B, T, d_long], survival_output of shape [B, T, 1])
    """
    def __init__(self, d_model, nhead, num_experts, d_ff_expert, d_long):
        super().__init__()

        # --- attention sub-layer -------------------------------------------
        self.dropout_attn = nn.Dropout(0.1)
        self.Attention = MultiHeadAttention(d_model, nhead)
        self.norm1 = nn.LayerNorm(d_model)

        # --- shared experts: d_model -> d_ff_expert -> d_model -------------
        pool = []
        for _ in range(num_experts):
            pool.append(
                nn.Sequential(
                    nn.Linear(d_model, d_ff_expert),
                    nn.ReLU(),
                    nn.Linear(d_ff_expert, d_model)
                )
            )
        self.experts = nn.ModuleList(pool)

        # --- one gate per task, scoring every expert -----------------------
        self.gate_long = nn.Linear(d_model, num_experts)
        self.gate_surv = nn.Linear(d_model, num_experts)

        # --- task heads ----------------------------------------------------
        self.longitudinal_head = nn.Linear(d_model, d_long)
        self.survival_head = nn.Linear(d_model, 1)

    def forward(self, q, kv, mask=None):
        # Attention with residual connection, then layer normalisation
        attended = self.Attention(query=q, key=kv, value=kv, mask = mask)
        hidden = self.norm1(q + self.dropout_attn(attended))

        # Every expert sees the same hidden state: [B, T, num_experts, d_model]
        expert_stack = torch.stack([branch(hidden) for branch in self.experts], dim=2)

        # Per-task mixing weights over the experts: [B, T, num_experts, 1]
        weights_long = F.softmax(self.gate_long(hidden), dim=-1).unsqueeze(-1)
        weights_surv = F.softmax(self.gate_surv(hidden), dim=-1).unsqueeze(-1)

        # Weighted sum over the expert axis: [B, T, d_model]
        mixed_long = torch.sum(expert_stack * weights_long, dim=2)
        mixed_surv = torch.sum(expert_stack * weights_surv, dim=2)

        # Task heads; the survival output is mapped into (0, 1)
        long_out = self.longitudinal_head(mixed_long)
        surv_out = torch.sigmoid(self.survival_head(mixed_surv))
        return long_out, surv_out

In [140]:
import torch.nn as nn
import torch
class Decoder(nn.Module):
    """
    Decoder Block: embeds the longitudinal + baseline covariates, adds the
    observation-time positional encoding and applies stacked self-attention
    decoder layers.

    Parameters
    ----------
    d_long:
        Number of longitudinal outcomes
    d_base:
        Number of baseline / time-independent covariates
    d_model:
        Dimension of the input vector
    nhead:
        Number of heads
    num_decoder_layers:
        Number of decoder layers to stack
    dropout:
        The dropout value (used by the decoder layers; the embedding uses a
        fixed 0.2)
    """
    def __init__(self,
                 d_long,
                 d_base,
                 d_model,
                 nhead,
                 num_decoder_layers = 1,
                 dropout=0.1):
        super().__init__()

        self.embedding = nn.Sequential(
            nn.Linear(d_long + d_base, d_model),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model)
            )

        self.decoder_layers = nn.ModuleList([Decoder_Layer(d_model,nhead,dropout)
                                             for _ in range(num_decoder_layers)])

    def forward(self, long, base, mask, obs_time):
        """
        Encode the visit history.

        ``long`` and ``base`` are concatenated along the feature axis (dim 2),
        embedded to ``d_model`` features, tagged with the positional encoding of
        ``obs_time`` and passed through the decoder layers as self-attention
        under ``mask``. Returns the output of the last layer.
        """
        # Concatenate longitudinal and baseline covariates, then embed
        x = torch.cat((long, base), dim=2)
        x = self.embedding(x)

        # Positional Embedding
        x = x + positional_encoding(
            x.shape[0], x.shape[1], x.shape[2], obs_time)

        # Decoder Layers: feed each layer the previous layer's output.
        # (Previously every layer received the embedding `x` and only the last
        # result was returned, so extra layers were dead weight, and the result
        # was undefined for an empty stack.)
        for layer in self.decoder_layers:
            x = layer(x, x, mask)

        return x

class Transformer1(nn.Module):
    """
    An adaptation of the transformer model (Attention is All you Need)
    for survival analysis, with a mixture-of-experts output block.

    Parameters
    ----------
    d_long:
        Number of longitudinal outcomes
    d_base:
        Number of baseline / time-independent covariates
    d_model:
        Dimension of the input vector (post embedding)
    nhead:
        Number of heads
    n_expert:
        Number of experts in the mixture-of-experts block
    d_ff:
        Hidden width of each expert
    num_decoder_layers:
        Number of decoder layers to stack
    dropout:
        The dropout value (forwarded to the decoder only)
    """
    def __init__(self,
                 d_long,
                 d_base,
                 d_model = 32,
                 nhead = 4,
                 n_expert = 4,
                 d_ff = 64,
                 num_decoder_layers = 3,
                 dropout = 0.2):
        super().__init__()

        # History encoder over the observed visits
        self.decoder = Decoder(
            d_long=d_long,
            d_base=d_base,
            d_model=d_model,
            nhead=nhead,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
        )

        # Two-task output block (longitudinal + survival)
        self.mmoe_layer = Decoder_MMOE_Layer(
            d_model=d_model,
            nhead=nhead,
            num_experts=n_expert,
            d_ff_expert=d_ff,
            d_long=d_long,
        )

    def forward(self, long, base, mask, obs_time, pred_time):
        """
        Encode the history observed at ``obs_time``, tag it with the encoding of
        ``pred_time`` and return ``(longitudinal prediction, survival output)``
        from the mixture-of-experts block.
        """
        # Encode the observed history
        encoded = self.decoder(long, base, mask, obs_time)

        # Add the prediction-time encoding on top of the encoded history
        n_subjects, n_visits, width = encoded.shape[0], encoded.shape[1], encoded.shape[2]
        query = encoded + positional_encoding(n_subjects, n_visits, width, pred_time)

        # The tagged history serves as both query and key/value
        long_pred, surv_pred = self.mmoe_layer(query, query, mask)
        return long_pred, surv_pred


In [161]:
from Gate import Transformer2
from util import (get_tensors, get_mask, init_weights, get_std_opt)
# Train a Transformer2 model on mini-batches of subjects.
# NOTE(review): this cell relies on names created elsewhere in the kernel:
# `Transformer2` (its import from Gate failed in the recorded run), `long_loss`,
# `surv_loss`, `np`, `plt`, `train_id` and `train_data`. `plt` is never imported in
# the visible cells; `import matplotlib.pyplot as plt` is needed before plotting.
model = Transformer2(d_long=3, d_base=2, d_model=32, nhead=4,
                    num_decoder_layers=4)
model.to('cuda')
model.apply(init_weights)
model = model.train()
# lr=0 here; presumably the object returned by get_std_opt sets the learning rate
# on every step (warm-up schedule) - confirm in util.
optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
scheduler = get_std_opt(optimizer, d_model=32, warmup_steps=200, factor=0.2)

n_epoch = 25
batch_size = 32


loss_values = []

for epoch in range(n_epoch):
    # Plain float accumulator: summing `loss` tensors directly would keep every
    # batch's autograd graph referenced until the end of the epoch.
    running_loss = 0.0
    # Reshuffle the subject order each epoch
    train_id = np.random.permutation(train_id)
    for batch in range(0, len(train_id), batch_size):
        optimizer.zero_grad()

        indices = train_id[batch:batch+batch_size]
        batch_data = train_data[train_data["id"].isin(indices)]

        batch_long, batch_base, batch_mask, batch_e, batch_t, obs_time = get_tensors(batch_data.copy())
        # Shift by one visit: all but the last visit are inputs, all but the
        # first visit are the targets to predict.
        batch_long_inp = batch_long[:,:-1,:].to('cuda')
        batch_long_out = batch_long[:,1:,:].to('cuda')
        batch_base = batch_base[:,:-1,:].to('cuda')
        batch_mask_inp = get_mask(batch_mask[:,:-1]).to('cuda')
        batch_mask_out = batch_mask[:,1:].unsqueeze(2).to('cuda')
        obs_time = obs_time.to('cuda')
        yhat_long, yhat_surv = model(batch_long_inp, batch_base, batch_mask_inp,
                        obs_time[:,:-1], obs_time[:,1:])

        loss1 = long_loss(yhat_long, batch_long_out, batch_mask_out)
        loss2 = surv_loss(yhat_surv, batch_mask, batch_e)

        loss = loss1 + loss2

        loss.backward()
        # NOTE(review): optimizer.step() is never called directly; presumably
        # scheduler.step() updates the learning rate AND steps the optimizer - confirm.
        scheduler.step()
        running_loss += loss.item()
    loss_values.append(running_loss)

# Min-max normalised loss curve, one point per epoch
loss_curve = np.asarray(loss_values)
plt.plot((loss_curve-np.min(loss_curve))/(np.max(loss_curve)-np.min(loss_curve)), 'b-')

ImportError: cannot import name 'Transformer2' from 'Gate' (/home/shijimao/Proj1/Gate.py)

In [163]:
from Gate import Transformer2

ImportError: cannot import name 'Transformer2' from 'Gate' (/home/shijimao/Proj1/Gate.py)

In [146]:
# Display the per-epoch summed training losses collected by the training loop
loss_values

[1048.9539129910388,
 972.8937334973884,
 826.1892618594885,
 600.8639152200541,
 318.9957113194537,
 136.00979473102427,
 110.45224968798891,
 81.93828470549927,
 62.32205963914217,
 53.03192154609149,
 48.09715204398629,
 44.65603587824041,
 44.038893849822664,
 41.51567918020396,
 39.011522313665154,
 38.4125590917343,
 38.7850665928162,
 37.57539453937534,
 35.35756088153758,
 35.70045022864838,
 34.40393437763319,
 34.142595533611285,
 32.17954106241806,
 32.228888145215834,
 33.26472138134074,
 32.044389393484835,
 30.692764782688688,
 30.578030498379334,
 30.091239701020438,
 29.441394896199725,
 29.33703544112531,
 27.78134679460031,
 28.255942967177763,
 27.688377349241684,
 26.943911282729633,
 27.02925634293412,
 26.003318863572954,
 25.80465057676105,
 25.859010715610815,
 25.87035290708557,
 25.36024976835031,
 24.593654530160563,
 24.69587609528654,
 24.01936429926206,
 24.49325731623882,
 23.643042593517627,
 23.293659409257195,
 23.123610526304635,
 23.635403512930996,


In [152]:
from Gate import Transformer1

In [155]:
from util import (get_tensors, get_mask, init_weights, get_std_opt)
# Train a Transformer1 model on mini-batches of subjects.
# NOTE(review): this cell relies on names created elsewhere in the kernel:
# `Transformer1`, `long_loss`, `surv_loss`, `np`, `plt`, `train_id` and `train_data`.
# `plt` is never imported in the visible cells (the recorded run ended with a
# NameError on it); `import matplotlib.pyplot as plt` is needed before plotting.
model = Transformer1(d_long=3, d_base=2, d_model=32, nhead=4,
                    num_decoder_layers=4)
model.to('cuda')
model.apply(init_weights)
model = model.train()
# lr=0 here; presumably the object returned by get_std_opt sets the learning rate
# on every step (warm-up schedule) - confirm in util.
optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
scheduler = get_std_opt(optimizer, d_model=32, warmup_steps=200, factor=0.2)

n_epoch = 50
batch_size = 64


loss_values = []

for epoch in range(n_epoch):
    # Plain float accumulator: summing `loss` tensors directly would keep every
    # batch's autograd graph referenced until the end of the epoch.
    running_loss = 0.0
    # Reshuffle the subject order each epoch
    train_id = np.random.permutation(train_id)
    for batch in range(0, len(train_id), batch_size):
        optimizer.zero_grad()

        indices = train_id[batch:batch+batch_size]
        batch_data = train_data[train_data["id"].isin(indices)]

        batch_long, batch_base, batch_mask, batch_e, batch_t, obs_time = get_tensors(batch_data.copy())
        # Shift by one visit: all but the last visit are inputs, all but the
        # first visit are the targets to predict.
        batch_long_inp = batch_long[:,:-1,:].to('cuda')
        batch_long_out = batch_long[:,1:,:].to('cuda')
        batch_base = batch_base[:,:-1,:].to('cuda')
        batch_mask_inp = get_mask(batch_mask[:,:-1]).to('cuda')
        batch_mask_out = batch_mask[:,1:].unsqueeze(2).to('cuda')
        obs_time = obs_time.to('cuda')
        # NOTE(review): `use_moe` is accepted by the Transformer1 imported from Gate;
        # the Transformer1 class defined in this notebook has no such argument.
        yhat_long, yhat_surv = model(batch_long_inp, batch_base, batch_mask_inp,
                        obs_time[:,:-1], obs_time[:,1:],use_moe = False)

        loss1 = long_loss(yhat_long, batch_long_out, batch_mask_out)
        loss2 = surv_loss(yhat_surv, batch_mask, batch_e)

        loss = loss1 + loss2

        loss.backward()
        # NOTE(review): optimizer.step() is never called directly; presumably
        # scheduler.step() updates the learning rate AND steps the optimizer - confirm.
        scheduler.step()
        running_loss += loss.item()
    loss_values.append(running_loss)

# Min-max normalised loss curve, one point per epoch
loss_curve = np.asarray(loss_values)
plt.plot((loss_curve-np.min(loss_curve))/(np.max(loss_curve)-np.min(loss_curve)), 'b-')

NameError: name 'plt' is not defined

In [156]:
# Display the per-epoch summed training losses collected by the training loop
loss_values

[1227.9036712943239,
 1081.036820630734,
 849.4745719671103,
 651.3093409186251,
 529.6113087922205,
 449.4477444265342,
 394.9660752161263,
 350.17166014160335,
 310.9212975219318,
 274.2383190993341,
 239.04214953494852,
 205.51665451285834,
 175.92091648960937,
 149.98713764483554,
 125.2336170595159,
 103.06979246765974,
 86.2684417305632,
 72.10232630056818,
 61.44794099922382,
 53.65549494145909,
 47.78500461797283,
 44.10368175728349,
 41.38883288751451,
 38.35520578298846,
 36.63679381962372,
 34.62101195149265,
 34.130881570306556,
 33.57303732543674,
 31.900528736398293,
 30.459250033361318,
 29.974676880897746,
 30.270780113805866,
 30.815262652287675,
 29.4741770416686,
 28.766773650532777,
 28.02349843183503,
 28.211932149161804,
 26.89598918339025,
 26.52281245689925,
 26.104369498561756,
 26.375766386967275,
 26.69391133519107,
 25.810975540350576,
 25.81529852507116,
 25.166144496803163,
 24.97105022525799,
 24.900731504550322,
 25.10450388237312,
 24.371948463765587,
 

In [124]:
# Scratch cell: run one forward pass of `model` on the WHOLE training set.
# NOTE(review): the recorded run failed with a matmul shape mismatch
# (7000x10 vs 32x32), so the `model` in the kernel at that time did not match
# these inputs - confirm which model this cell is meant for.
batch_long, batch_base, batch_mask, batch_e, batch_t, obs_time = get_tensors(train_data.copy())
batch_long_inp = batch_long[:,:-1,:].to('cuda');batch_long_out = batch_long[:,1:,:].to('cuda')  # inputs: all but the last visit; targets: all but the first visit
batch_base = batch_base[:,:-1,:].to('cuda')
batch_mask_inp = get_mask(batch_mask[:,:-1]).to('cuda')
batch_mask_out = batch_mask[:,1:].unsqueeze(2).to('cuda') 
obs_time = obs_time.to('cuda')
# Observation times of the inputs, then the (shifted) prediction times
yhat_long, yhat_surv = model(batch_long_inp, batch_base, batch_mask_inp,
                obs_time[:,:-1].to('cuda'), obs_time[:,1:].to('cuda'))

RuntimeError: mat1 and mat2 shapes cannot be multiplied (7000x10 and 32x32)

In [116]:
# Scratch cell: build a one-layer Decoder_p and run it on the GPU.
# NOTE(review): `a` is not defined in any visible cell - presumably a
# (batch, visits, 32) tensor left over in the kernel; confirm before re-running.
test3 = Decoder_p(d_model = 32, nhead = 4,
                 num_decoder_layers = 1,dropout= 0.1)
test3.train()
test3.to('cuda')
# `a` is used as both query and key/value; prediction times are the observation
# times with the first visit dropped
d = test3(a,a,get_mask(batch_mask[:,:-1].to('cuda')),obs_time[:,1:].to('cuda'))

In [128]:
# Scratch cell: call `test2` on `a` with the input mask and shifted times.
# NOTE(review): neither `test2` nor `a` is defined in a visible cell. The recorded
# run failed with a cpu/cuda device mismatch - presumably `test2` was not moved
# to the GPU; confirm.
b,c = test2(a,get_mask(batch_mask[:,:-1].to('cuda')),obs_time[:,1:].to('cuda'))

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

In [56]:
from Gate import Transformer1

In [41]:
# Sanity check: count, per trailing position, how many entries along the first
# axis agree between the two versions of each tensor (builtin sum iterates the
# first axis). Only the second expression is shown as the cell output.
# NOTE(review): `batch_long2` / `batch_base2` are not defined in a visible cell.
sum(batch_long2 == batch_long)
sum(batch_base2 == batch_base)


tensor([[700, 700],
        [700, 700],
        [700, 700],
        [700, 700],
        [700, 700],
        [700, 700],
        [700, 700],
        [700, 700],
        [700, 700],
        [700, 700],
        [700, 700]])

In [46]:
# Scratch cell: walk over train_id in chunks of batch_size and print the shape of
# the longitudinal tensor that get_tensors builds for each chunk.
batch_size = 50
for batch in range(0, len(train_id), batch_size):
    # Subject ids of this chunk, then every row belonging to them
    indices = train_id[batch:batch+batch_size]
    batch_data = train_data[train_data["id"].isin(indices)]
    # NOTE(review): the filtered slice is passed without .copy() here (the training
    # cells pass a copy); the warning text flooding the output appears to come
    # from inside get_tensors - confirm.
    batch_long, batch_base, batch_mask, batch_e, batch_t, obs_time = get_tensors(batch_data)
    print(batch_long.shape)


  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.te

torch.Size([50, 11, 3])
torch.Size([50, 11, 3])


  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.te

torch.Size([50, 11, 3])
torch.Size([50, 11, 3])


  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.te

torch.Size([50, 11, 3])
torch.Size([50, 11, 3])


  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.te

torch.Size([50, 11, 3])


  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.te

torch.Size([50, 11, 3])
torch.Size([50, 11, 3])


  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.te

torch.Size([50, 11, 3])
torch.Size([50, 11, 3])


  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.te

torch.Size([50, 11, 3])
torch.Size([50, 11, 3])


  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.te

torch.Size([50, 11, 3])


  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.tensor(row.loc[list(long)])
  x_base[ii,jj,:] = torch.tensor(row.loc[list(base)])
  x_long[ii,jj,:] = torch.te

In [47]:
# Shape of the observation-time tensor from the most recent get_tensors call
obs_time.shape

torch.Size([50, 11])

In [22]:
import numpy as np
# Largest value of the 'visit' column plus one; this is the number of visit slots
# if visits are numbered from 0 (assumption - confirm against the data)
np.max(train_data.loc[:,"visit"]) + 1


np.int64(11)

In [None]:
import warnings 
# Continue training the current `model` with the combined multi-task loss.
# NOTE(review): this cell relies on names created elsewhere in the kernel:
# `model`, `optimizer`, `scheduler`, `long_loss`, `surv_loss`, `multi_task_loss`,
# `np`, `plt`, `train_id` and `train_data`. `plt` is never imported in the visible
# cells; `import matplotlib.pyplot as plt` is needed before plotting.
# All warnings are silenced for the rest of the session, which also hides the
# warnings emitted while building the batch tensors.
warnings.filterwarnings("ignore")
n_epoch = 50
batch_size = 32


loss_values = []   # summed loss per epoch
loss1_list = []    # longitudinal loss per batch
loss2_list = []    # survival loss per batch
for epoch in range(n_epoch):
    # Plain float accumulator: summing `loss` tensors directly would keep every
    # batch's autograd graph referenced until the end of the epoch.
    running_loss = 0.0
    # Reshuffle the subject order each epoch
    train_id = np.random.permutation(train_id)
    for batch in range(0, len(train_id), batch_size):
        optimizer.zero_grad()

        indices = train_id[batch:batch+batch_size]
        batch_data = train_data[train_data["id"].isin(indices)]

        batch_long, batch_base, batch_mask, batch_e, batch_t, obs_time = get_tensors(batch_data.copy())
        # Shift by one visit: all but the last visit are inputs, all but the
        # first visit are the targets to predict.
        batch_long_inp = batch_long[:,:-1,:].to('cuda')
        batch_long_out = batch_long[:,1:,:].to('cuda')
        batch_base = batch_base[:,:-1,:].to('cuda')
        batch_mask_inp = get_mask(batch_mask[:,:-1]).to('cuda')
        batch_mask_out = batch_mask[:,1:].unsqueeze(2).to('cuda')
        obs_time = obs_time.to('cuda')
        yhat_long, yhat_surv = model(batch_long_inp, batch_base, batch_mask_inp,
                        obs_time[:,:-1], obs_time[:,1:])

        loss1 = long_loss(yhat_long, batch_long_out, batch_mask_out)
        loss2 = surv_loss(yhat_surv, batch_mask, batch_e)

        # Combined objective (replaces the plain sum loss1 + loss2)
        loss = multi_task_loss(loss1, loss2)

        loss.backward()
        # NOTE(review): optimizer.step() is never called directly; presumably
        # scheduler.step() updates the learning rate AND steps the optimizer - confirm.
        scheduler.step()
        running_loss += loss.item()
        loss1_list.append(loss1.tolist())
        loss2_list.append(loss2.tolist())
    loss_values.append(running_loss)

# Min-max normalised loss curve, one point per epoch
loss_curve = np.asarray(loss_values)
plt.plot((loss_curve-np.min(loss_curve))/(np.max(loss_curve)-np.min(loss_curve)), 'b-')


In [None]:
def data_preprocessing(source, path='data/simulated_data.pkl', train_frac=0.7):
    """
    Load a dataset and split it into train and test subjects.

    Parameters
    ----------
    source:
        Name of the data source. Only 'JM' (the pickled simulated data) is
        supported.
    path:
        Location of the pickle file to read. Defaults to the path that used to
        be hard-coded.
    train_frac:
        Fraction of subjects assigned to the training set (default 0.7); the
        remaining subjects form the test set.

    Returns
    -------
    (train_id, train_data, test_id, test_data): the id ranges and the matching
    row subsets of the loaded frame.

    Raises
    ------
    ValueError
        If ``source`` is not a supported data source (previously this surfaced
        as an UnboundLocalError at the return statement).
    """
    if source != 'JM':
        raise ValueError(f"Unsupported data source: {source!r} (expected 'JM')")

    ## Load Data
    # NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
    # only point `path` at trusted files.
    data = pd.read_pickle(path)
    I = data['id'].nunique()

    ## split train/test
    # Deterministic split in id order. This assumes the ids are exactly 0 .. I-1
    # (TODO confirm); use np.random.permutation(range(I)) for a random split.
    n_train = int(train_frac*I)
    random_id = range(I)
    train_id = random_id[0:n_train]
    test_id = random_id[n_train:I]

    train_data = data[data["id"].isin(train_id)]
    test_data = data[data["id"].isin(test_id)]
    return train_id, train_data, test_id, test_data


def main(d_long = 3, d_base = 2, d_model = 32, nhead = 4, num_decoder_layers = 7,n_epoch = 50, batch_size = 32):
    train_id, train_data, test_id, test_data = data_preprocessing('JM')

    model = Transformer1(d_long=d_long, d_base=d_base, d_model=d_model, nhead=nhead,
                    num_decoder_layers=num_decoder_layers)
    model.to('cuda')
    model.apply(init_weights)
    model = model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
    scheduler = get_std_opt(optimizer, d_model=d_model, warmup_steps=200, factor=0.2)

    n_epoch = n_epoch 
    batch_size = batch_size 