In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MMoE_FFN(nn.Module):
    """
    Mixture-of-Experts Feed-Forward Network (MMoE_FFN).
    Contains:
      - Router (gating) linear layer to produce softmax weights for N experts.
      - N expert feed-forward networks (each an MLP).
      - Two task-specific linear heads (tower networks) for predictions.
    """
    def __init__(self, d_model: int, hidden_dim: int, num_experts: int, dropout: float = 0.1):
        """
        Args:
            d_model: Dimension of input features (and of expert outputs).
            hidden_dim: Hidden layer dimension for each expert MLP.
            num_experts: Number of expert networks.
            dropout: Dropout probability for expert hidden layers.
        """
        super(MMoE_FFN, self).__init__()
        # Gating network: linear layer that outputs N logits per token (one per expert)
        self.router = nn.Linear(d_model, num_experts)
        # Expert networks: each is a two-layer feed-forward (d_model -> hidden_dim -> d_model)
        # We use ReLU activation and dropout on the hidden layer for each expert.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, d_model)
            )
            for _ in range(num_experts)
        ])
        # Task-specific tower heads: linear layers for each task
        self.longitudinal_head = nn.Linear(d_model, 3)  # 3 outputs for longitudinal task
        self.survival_head    = nn.Linear(d_model, 1)  # 1 output for survival task
    
    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: Tensor [B, S, d_model] – input features (e.g., from attention output).
        Returns:
            combined_out: Tensor [B, S, d_model] – combined expert output for each token.
            long_out: Tensor [B, S, 3] – longitudinal task predictions for each token.
            surv_out: Tensor [B, S, 1] – survival task predictions for each token.
        """
        B, S, _ = x.shape
        # Gating: compute softmax weights for each expert per token
        # router(x) -> [B, S, N], then softmax along the N dimension to get probabilities
        gating_logits = self.router(x)                                # [B, S, N]
        gating_weights = F.softmax(gating_logits, dim=-1)             # [B, S, N] (sum of N dim = 1 per token)&#8203;:contentReference[oaicite:4]{index=4}
        
        # Compute each expert's output on x
        # (Each expert is applied position-wise on the sequence)
        expert_outputs = [expert(x) for expert in self.experts]       # list of [B, S, d_model]
        expert_outputs = torch.stack(expert_outputs, dim=2)           # [B, S, N, d_model]
        
        # Mixture: weight and sum expert outputs using gating weights
        # combined_out[b,s,:] = Σ_{i=0}^{N-1} gating_weights[b,s,i] * expert_outputs[b,s,i,:]&#8203;:contentReference[oaicite:5]{index=5}
        combined_out = (gating_weights.unsqueeze(-1) * expert_outputs).sum(dim=2)  # [B, S, d_model]
        
        # Task-specific outputs from the combined representation (tower networks)
        long_out = self.longitudinal_head(combined_out)  # [B, S, 3]  (logits or scores for longitudinal task)&#8203;:contentReference[oaicite:6]{index=6}
        surv_out = self.survival_head(combined_out)      # [B, S, 1]  (logit or score for survival task)&#8203;:contentReference[oaicite:7]{index=7}
        return long_out, surv_out

class Decoder_Layer2(nn.Module):
    """
    Transformer Decoder Layer with MMoE.
    - Self-attention sub-layer (with residual connection + norm)
    - MMoE feed-forward sub-layer (with residual connection + norm)
    """
    def __init__(self, d_model: int, nhead: int, hidden_dim: int, num_experts: int, dropout: float = 0.1):
        """
        Args:
            d_model: Model dimension (embedding size for inputs).
            nhead: Number of heads for multi-head self-attention.
            hidden_dim: Hidden dimension for each expert FFN.
            num_experts: Number of expert networks in the MMoE FFN.
            dropout: Dropout probability for attention and FFN.
        """
        super(Decoder_Layer2, self).__init__()
        # Multi-head self-attention (batch_first=True for [B, S, d_model] inputs/outputs)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        # Mixture-of-Experts feed-forward network
        self.mmoe_ffn = MMoE_FFN(d_model, hidden_dim, num_experts, dropout=dropout)
        # LayerNorm for post-attention and post-FFN
        self.norm1 = nn.LayerNorm(d_model)
        # Dropout layers for residual connections
        self.dropout_attn = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor = None) -> tuple:
        """
        Args:
            x: Tensor [B, S, d_model] – input sequence to decode (encoder output).
            attn_mask: Optional mask for attention. 
                       - If shape [B, S], it is used as a key padding mask (True for positions to mask out).
                       - If shape [S, S] (or [B, S, S]), it is used as an attention mask (e.g., causal mask).
        Returns:
            out: Tensor [B, S, d_model] – output representation after this layer.
            long_out: Tensor [B, S, 3] – longitudinal task output for each token.
            surv_out: Tensor [B, S, 1] – survival task output for each token.
        """
        # 1. Self-Attention sub-layer
        if attn_mask is not None:
            if attn_mask.dim() == 2 and attn_mask.shape == (x.size(0), x.size(1)):
                # Treat 2D mask of shape [B, S] as key padding mask
                attn_out, _ = self.self_attn(x, x, x, key_padding_mask=attn_mask, need_weights=False)
            else:
                # Treat as an attn_mask (shape [S, S] or [B, S, S])
                attn_out, _ = self.self_attn(x, x, x, attn_mask=attn_mask, need_weights=False)
        else:
            attn_out, _ = self.self_attn(x, x, x, need_weights=False)
        # Add & Norm: residual connection and LayerNorm for attention output&#8203;:contentReference[oaicite:8]{index=8}
        x = x + self.dropout_attn(attn_out)
        x = self.norm1(x)
        
        # 2. Mixture-of-Experts FFN sub-layer
        long_out, surv_out = self.mmoe_ffn(x)
        # Add & Norm: residual connection and LayerNorm for FFN output&#8203;:contentReference[oaicite:9]{index=9}
        
        # Return final representation and task outputs
        return long_out, surv_out

class TransformerDecoderMMoE(nn.Module):
    """
    Transformer Decoder with MMoE (single-layer decoder).
    Wraps Decoder_Layer to produce multi-task outputs.
    """
    def __init__(self, d_model: int, nhead: int, hidden_dim: int, num_experts: int, dropout: float = 0.1):
        """
        Args:
            d_model: Model dimensionality (must match encoder output).
            nhead: Number of attention heads.
            hidden_dim: Hidden size for expert MLPs.
            num_experts: Number of experts in the mixture.
            dropout: Dropout probability.
        """
        super(TransformerDecoderMMoE, self).__init__()
        self.layer = Decoder_Layer2(d_model, nhead, hidden_dim, num_experts, dropout)
    
    def forward(self, encoder_out: torch.Tensor, attn_mask: torch.Tensor = None) -> tuple:
        """
        Args:
            encoder_out: Tensor [B, S, d_model] – output from the encoder (input to decode).
            attn_mask: Optional mask for attention (same format as in Decoder_Layer).
        Returns:
            out: Tensor [B, S, d_model] – decoder output representation.
            long_out: Tensor [B, S, 3] – longitudinal task predictions per token.
            surv_out: Tensor [B, S, 1] – survival task predictions per token.
        """
        long_out, surv_out = self.layer(encoder_out, attn_mask)
        return long_out, surv_out



In [None]:
Models

from Simulation.data_simulation_base import simulate_JM_base
n_sim = 1
I = 1000
obstime = [0,1,2,3,4,5,6,7,8,9,10]
landmark_times = [1,2,3,4,5]
pred_windows = [1,2,3]
scenario = "none" # ["none", "interaction", "nonph"]
from sklearn.preprocessing import MinMaxScaler
from data_simulation_base import simulate_JM_base
data_all = simulate_JM_base(I=I, obstime=obstime, opt=scenario, seed=n_sim)
data = data_all[data_all.obstime <= data_all.time]

## split train/test
random_id = range(I) #np.random.permutation(range(I))
train_id = random_id[0:int(0.7*I)]
test_id = random_id[int(0.7*I):I]

train_data = data[data["id"].isin(train_id)]
test_data = data[data["id"].isin(test_id)]
x1= train_data[['X1','X2']]
y = train_data[['Y1','Y2','Y3']]

ModuleNotFoundError: No module named 'Models'

In [4]:
import torch
import torch.nn as nn

import torch.nn.functional as F



class Decoder(nn.Module):
    """
    Decoder Block
    
    Parameters
    ----------
    d_long:
        Number of longitudinal outcomes
    d_base:
        Number of baseline / time-independent covariates
    d_model:
        Dimension of the input vector
    nhead:
        Number of heads
    num_decoder_layers:
        Number of decoder layers to stack
    dropout:
        The dropout value
    """
    def __init__(self,
                 d_long,
                 d_base,
                 d_model,
                 nhead,
                 num_decoder_layers,
                 dropout):
        super().__init__()
        
        self.embedding = nn.Sequential(
            nn.Linear(d_long + d_base, d_model),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model)
            )
        
        self.decoder_layers = nn.ModuleList([Decoder_Layer(d_model,nhead,dropout)
                                             for _ in range(num_decoder_layers)])
        
    def forward(self, long, base, mask, obs_time):
        # Concatenate longitudinal and baseline data
        x = torch.cat((long, base), dim=2)
        
        # Linear Embedding
        x = self.embedding(x)
        
        # Positional EmbeddingTransformerDecoderMMoE

        x = x + positional_encoding(
            x.shape[0], x.shape[1], x.shape[2], obs_time)
        
        # Decoder Layers
        for layer in self.decoder_layers:
            decoding = layer(x, x, mask)

        return decoding

'''
class Decoder_p(nn.Module):
    """
    Decoder Block
    
    Parameters_
    ----------
    d_model:
        Dimension of the input vector
    nhead:
        Number of heads
    num_decoder_layers:
        Number of decoder layers to stack
    dropout:
        The dropout value
    """
    def __init__(self,
                 d_model,
                 nhead,
                 num_decoder_layers,
                 dropout):
        super().__init__()

        self.decoder_layers = nn.ModuleList([Decoder_Layer(d_model,nhead,dropout)
                                             for _ in range(num_decoder_layers)])
        
    def forward(self, q, kv, mask, pred_time):
        # Positional Embedding
        
        q = q + positional_encoding(
            q.shape[0], q.shape[1], q.shape[2], pred_time)
        
        # Decoder Layers
        for layer in self.decoder_layers:
            x = layer(q, kv,mask)

        return x
'''
class TransformerDecoderMMoE(nn.Module):
    """
    Transformer Decoder with MMoE (single-layer decoder).
    Wraps Decoder_Layer to produce multi-task outputs.
    """
    def __init__(self, d_model: int, nhead: int, hidden_dim: int, num_experts: int, dropout: float = 0.1):
        """
        Args:
            d_model: Model dimensionality (must match encoder output).
            nhead: Number of attention heads.
            hidden_dim: Hidden size for expert MLPs.
            num_experts: Number of experts in the mixture.
            dropout: Dropout probability.
        """
        super().__init__()
        self.layer = Decoder_Layer2(d_model, nhead, hidden_dim, num_experts, dropout)
    
    def forward(self, encoder_out: torch.Tensor, attn_mask: torch.Tensor = None) -> tuple:
        """
        Args:
            encoder_out: Tensor [B, S, d_model] – output from the encoder (input to decode).
            attn_mask: Optional mask for attention (same format as in Decoder_Layer).
        Returns:
            out: Tensor [B, S, d_model] – decoder output representation.
            long_out: Tensor [B, S, 3] – longitudinal task predictions per token.
            surv_out: Tensor [B, S, 1] – survival task predictions per token.
        """
        long_out, surv_out = self.layer(encoder_out, attn_mask)
        return long_out, surv_out



class Transformer1(nn.Module):
    """
    An adaptation of the transformer model (Attention is All you Need)
    for survival analysis.
    
    Parameters
    ----------
    d_long:
        Number of longitudinal outcomes
    d_base:
        Number of baseline / time-independent covariates
    d_model:
        Dimension of the input vector (post embedding)
    nhead:
        Number of heads
    num_decoder_layers:
        Number of decoder layers to stack
    dropout:
        The dropout value
    """
    def __init__(self,
                 d_long,
                 d_base,
                 d_model = 32,
                 nhead = 4,
                 n_expert = 4,
                 d_ff = 64,  
                 num_decoder_layers = 3,
                 dropout = 0.2):
        super().__init__()
        self.decoder = Decoder(d_long, d_base, d_model, nhead, num_decoder_layers, dropout)

        #self.decoder_pred = Decoder_p(d_model, nhead, 1, dropout)
        
        self.long = nn.Sequential(
            nn.Linear(d_model, d_long)
        )
        
        self.surv = nn.Sequential(
            nn.Linear(d_model, 1)
        )
        self.decoder_mmoe = TransformerDecoderMMoE(d_model = d_model,nhead = nhead,hidden_dim = 64, num_experts = n_expert)

    def forward(self, long, base, mask, obs_time, pred_time):        
        # Decoder Layers
        x = self.decoder(long, base, mask, obs_time)
        
        # Decoder Layer with prediction time embedding
        
        long,surv = self.decoder_mmoe(x, x, mask, pred_time)

        return long, surv

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """
    Each expert is a small feed-forward subnetwork.
    Here, we map d_model -> d_ff -> d_model.
    """
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        # x: [B, T, d_model]
        return self.net(x)


class MMoEHead(nn.Module):
    """
    Shared Expert Pool + 2 Gating Networks (one for each task).
    Then each task does a final linear layer to get the actual prediction.
    """
    def __init__(self, d_model, d_ff, n_expert, d_long):
        """
        d_model: dimension of transformer output
        d_ff: hidden dimension inside each expert
        n_expert: number of experts
        d_long: dimension of the longitudinal output
        """
        super().__init__()
        # Create shared experts
        self.experts = nn.ModuleList([
            Expert(d_model, d_ff) for _ in range(n_expert)
        ])
        # Each task has its own gating
        self.gate_long = nn.Linear(d_model, n_expert)
        self.gate_surv = nn.Linear(d_model, n_expert)

        # Final output layers for each task
        self.long_out = nn.Linear(d_model, d_long)  # e.g. predict d_long dims
        self.surv_out = nn.Linear(d_model, 1)       # e.g. 1-dim survival logit

    def forward(self, x):
        """
        x: [B, T, d_model] from the Transformer.
        Returns:
          long_pred: [B, T, d_long]
          surv_pred: [B, T, 1]
        """
        B, T, _ = x.shape

        # 1) Compute each expert's output
        #    We'll stack them: shape will be [B, T, n_expert, d_model]
        expert_outs = []
        for expert in self.experts:
            e_out = expert(x)             # [B, T, d_model]
            expert_outs.append(e_out)
        # Stack along new dim=2 => (n_expert)
        # result: [B, T, n_expert, d_model]
        expert_outs = torch.stack(expert_outs, dim=2)

        # 2) Gating for longitudinal
        gate_logits_long = self.gate_long(x)              # [B, T, n_expert]
        gate_weights_long = F.softmax(gate_logits_long, dim=-1)  # soft gating
        # expand so we can multiply
        gate_weights_long = gate_weights_long.unsqueeze(-1)       # [B, T, n_expert, 1]

        # Weighted sum over the expert dimension
        # shape => [B, T, d_model]
        long_combined = (expert_outs * gate_weights_long).sum(dim=2)

        # 3) Gating for survival
        gate_logits_surv = self.gate_surv(x)              # [B, T, n_expert]
        gate_weights_surv = F.softmax(gate_logits_surv, dim=-1)
        gate_weights_surv = gate_weights_surv.unsqueeze(-1)       # [B, T, n_expert, 1]

        surv_combined = (expert_outs * gate_weights_surv).sum(dim=2)

        # 4) Final linear heads for each task
        long_pred = self.long_out(long_combined)    # [B, T, d_long]
        surv_logit = self.surv_out(surv_combined)    # [B, T, 1]
        surv_pred = torch.sigmoid(surv_logit)        # or however you interpret survival

        return long_pred, surv_pred


In [6]:
class Decoder_Layer(nn.Module):
    """
    Decoder Block
    
    Parameters
    ----------
    d_model:
        Dimension of the input vector
    nhead:
        Number of heads
    dropout:
        The dropout value
    """
    
    def __init__(self,
                 d_model,
                 nhead,
                 dropout):
        super().__init__()
        
        self.dropout = nn.Dropout(dropout)
        
        self.Attention = MultiHeadAttention(d_model, nhead)
                
        self.feedForward = nn.Sequential(
            nn.Linear(d_model,64),
            nn.ReLU(),
            nn.Linear(64,d_model),
            nn.Dropout(dropout)
            )
        
        self.layerNorm1 = nn.LayerNorm(d_model)
        self.layerNorm2 = nn.LayerNorm(d_model)
        
    def forward(self, q, kv, mask):
        
        # Attention
        residual = q
        x = self.Attention(query=q, key=kv, value=kv, mask = mask)
        x = self.dropout(x)
        x = self.layerNorm1(x + residual)
        
        # Feed Forward
        residual = x
        x = self.feedForward(x)
        x = self.layerNorm2(x + residual)
        
        return x

def positional_encoding(batch_size, length, d_model, obs_time):
    """
    Positional Encoding for each visit
    
    Parameters
    ----------
    batch_size:
        Number of subjects in batch
    length:
        Number of visits
    d_model:
        Dimension of the model vector
    obs_time:
        Observed/recorded time of each visit
    """
    PE = torch.zeros((batch_size, length, d_model)).to('cuda')
    if obs_time.ndim == 0:
        obs_time = obs_time.repeat(batch_size).unsqueeze(1)
    elif obs_time.ndim == 1:
        obs_time = obs_time.repeat(batch_size,1)
    obs_time = obs_time.to('cuda')
    pow0 = torch.pow(10000, torch.arange(0, d_model, 2, dtype=torch.float32)/d_model).to('cuda')

    PE[:, :, 0::2] = torch.sin(torch.einsum('ij,k->ijk', obs_time, pow0))
    pow1 = torch.pow(10000, torch.arange(1, d_model, 2, dtype=torch.float32)/d_model).to('cuda')
    PE[:, :, 1::2] = torch.cos(torch.einsum('ij,k->ijk', obs_time, pow1))

    return PE

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, nhead, dropout = 0.1):
        super().__init__()
        
        self.d_model = d_model
        self.d_k = d_model // nhead
        self.nhead = nhead
        
        assert (
            d_model % nhead == 0
        ), "Embedding size (d_model) needs to be divisible by number of heads"
        
        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)
        
    def attention(self, query, key, value, d_k, mask = None, dropout=None):
    
        scores = torch.matmul(query, key.transpose(-2, -1)) /  np.sqrt(d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).to('cuda')
            scores = scores.masked_fill(mask == 0, -float('inf'))
        scores = F.softmax(scores, dim=-1)
        
        if dropout is not None:
            scores = dropout(scores)
            
        output = torch.matmul(scores, value)
        return output

    def forward(self, query, key, value, mask = None):
        I = query.shape[0]
        
        # perform linear operation and split into N heads
        query = self.q_linear(query).view(I, -1, self.nhead, self.d_k)
        key = self.k_linear(key).view(I, -1, self.nhead, self.d_k)
        value = self.v_linear(value).view(I, -1, self.nhead, self.d_k)
        
        # transpose to get dimensions I * nheanum_experts

        # calculate attention
        scores = self.attention(query, key, value, self.d_k, mask, self.dropout)
        # concatenate heads and put through final linear layer
        concat = scores.transpose(1,2).contiguous()\
        .view(I, -1, self.d_model)
        output = self.out(concat)
    
        return output

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MMoE_FFN(nn.Module):
    """
    Multi-gate Mixture-of-Experts Feed-Forward Network.
    - Contains a shared pool of expert FFNs (each maps d_model -> hidden_dim -> d_model).
    - Two task-specific gating networks (longitudinal and survival) produce softmax weights over experts.
    - Combines expert outputs per task and passes them through task-specific linear heads.
    """
    def __init__(self, d_model, hidden_dim, num_experts):
        super(MMoE_FFN, self).__init__()
        self.num_experts = num_experts
        # Shared pool of expert networks (each expert is an FFN: Linear(d_model->hidden_dim) -> ReLU -> Linear(hidden_dim->d_model))
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, d_model)
            )
            for _ in range(num_experts)
        ])
        # Task-specific gating layers (one for each task) that output a weight for each expert
        self.gate_long = nn.Linear(d_model, num_experts)  # Longitudinal task gate
        self.gate_surv = nn.Linear(d_model, num_experts)  # Survival task gate
        # Task-specific tower output layers
        self.longitudinal_head = nn.Linear(d_model, 3)  # Outputs 3 features for longitudinal task
        self.survival_head   = nn.Linear(d_model, 1)    # Outputs 1 feature for survival task

    def forward(self, x):
        """
        :param x: Input tensor of shape [B, S, d_model].
        :return: (long_out, surv_out, combined_output):
                 - long_out: Longitudinal task output of shape [B, S, 3]
                 - surv_out: Survival task output of shape [B, S, 1]
        """
        B, S, D = x.size()
        # Flatten batch and sequence dims for processing through experts and gates
        x_flat = x.view(B * S, D)  # shape [B*S, d_model]
        # Compute each expert's output for all positions (shared across tasks)
        expert_outputs = []  # will collect outputs of shape [B*S, d_model] from each expert
        for expert in self.experts:
            expert_out = expert(x_flat)               # [B*S, d_model] output from this expert
            expert_outputs.append(expert_out)
        # Stack expert outputs into a single tensor of shape [B*S, num_experts, d_model]
        expert_outputs = torch.stack(expert_outputs, dim=1)  # dim=1 indexes the expert
        # Compute gating logits for each task and apply softmax to get mixture weights (one weight per expert)
        gate_long_logits = self.gate_long(x_flat)  # [B*S, num_experts] (logits for longitudinal task experts)
        gate_surv_logits = self.gate_surv(x_flat)  # [B*S, num_experts] (logits for survival task experts)
        gate_long = F.softmax(gate_long_logits, dim=-1)  # [B*S, num_experts] weights for each expert (longitudinal)
        gate_surv = F.softmax(gate_surv_logits, dim=-1)  # [B*S, num_experts] weights for each expert (survival)
        # Reshape weights for broadcasting: [B*S, num_experts] -> [B*S, num_experts, 1]
        gate_long = gate_long.unsqueeze(-1)
        gate_surv = gate_surv.unsqueeze(-1)
        # Compute weighted sum of expert outputs for each task using the gate weights
        long_combined_flat = torch.sum(expert_outputs * gate_long, dim=1)  # [B*S, d_model] combined output for longitudinal task
        surv_combined_flat = torch.sum(expert_outputs * gate_surv, dim=1)  # [B*S, d_model] combined output for survival task
        # Reshape combined outputs back to [B, S, d_model]
        long_combined = long_combined_flat.view(B, S, D)
        surv_combined = surv_combined_flat.view(B, S, D)
        # Compute final task-specific outputs via the tower heads
        long_out = self.longitudinal_head(long_combined)  # [B, S, 3] longitudinal task output
        surv_out = self.survival_head(surv_combined)      # [B, S, 1] survival task output
    
        return long_out, surv_out


In [1]:
from Simulation.data_simulation_base import simulate_JM_base
n_sim = 1
I = 1000
obstime = [0,1,2,3,4,5,6,7,8,9,10]
landmark_times = [1,2,3,4,5]
pred_windows = [1,2,3]
scenario = "none" # ["none", "interaction", "nonph"]
from sklearn.preprocessing import MinMaxScaler
from data_simulation_base import simulate_JM_base
data_all = simulate_JM_base(I=I, obstime=obstime, opt=scenario, seed=n_sim)
data = data_all[data_all.obstime <= data_all.time]

## split train/test
random_id = range(I) #np.random.permutation(range(I))
train_id = random_id[0:int(0.7*I)]
test_id = random_id[int(0.7*I):I]

train_data = data[data["id"].isin(train_id)]
test_data = data[data["id"].isin(test_id)]
x1= train_data[['X1','X2']]
y = train_data[['Y1','Y2','Y3']]

import sys
import torch
sys.path.append("/home/shijimao/TransformerJM/Models")
sys.path.append("/home/shijimao/TransformerJM/Simulation")
from Models.Transformer.functions import (get_tensors, get_mask, init_weights, get_std_opt)
from Models.Transformer.loss import (long_loss, surv_loss)
from Models.metrics import (AUC, Brier, MSE)
import numpy as np
torch.manual_seed(0)


<torch._C.Generator at 0x7bd55910b8d0>

In [10]:
import matplotlib.pyplot as plt
model = Transformer1(d_long=3, d_base=2, d_model=32, nhead=4,
                    num_decoder_layers=7)
model.to('cuda')
model.apply(init_weights)
model = model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
scheduler = get_std_opt(optimizer, d_model=32, warmup_steps=200, factor=0.2)
n_epoch = 50
batch_size = 32
import warnings 
warnings.filterwarnings("ignore")    
    
loss_values = []
loss_test = []
loss1_list = []
loss2_list = []

for epoch in range(n_epoch):
    running_loss = 0
    train_id = np.random.permutation(train_id)
    for batch in range(0, len(train_id), batch_size):
        optimizer.zero_grad()
            
        indices = train_id[batch:batch+batch_size]
        batch_data = train_data[train_data["id"].isin(indices)]
            
        batch_long, batch_base, batch_mask, batch_e, batch_t, obs_time = get_tensors(batch_data.copy())
        batch_long_inp = batch_long[:,:-1,:].to('cuda')
        batch_long_out = batch_long[:,1:,:].to('cuda')
        batch_base = batch_base[:,:-1,:].to('cuda')
        batch_mask_inp = get_mask(batch_mask[:,:-1]).to('cuda')
        batch_mask_out = batch_mask[:,1:].unsqueeze(2).to('cuda') 
        obs_time = obs_time.to('cuda')
        yhat_long, yhat_surv = model(batch_long_inp, batch_base, batch_mask_inp,
                        obs_time = obs_time[:,:-1].to('cuda'), pred_time = obs_time[:,1:].to('cuda'))
        loss1 = long_loss(yhat_long, batch_long_out, batch_mask_out)
        loss2 = surv_loss(yhat_surv, batch_mask, batch_e)
        
        loss = loss1 + loss2
        loss.backward()
        scheduler.step()
        running_loss += loss
        loss1_list.append(loss1.tolist())
        loss2_list.append(loss2.tolist())
    loss_values.append(running_loss.tolist())
plt.plot((loss_values-np.min(loss_values))/(np.max(loss_values)-np.min(loss_values)), 'b-')

RuntimeError: The size of tensor a (10) must match the size of tensor b (4) at non-singleton dimension 3