In [1]:
import os
import time
import pickle
import random
import pandas as pd
import numpy as np
import torch
from types import SimpleNamespace
from sklearn.preprocessing import MinMaxScaler

# ------------------------------------------------------------------
# Run configuration.
# A plain namespace stands in for the result of argparse, so the rest
# of the notebook reads settings as `args.<name>`.
# ------------------------------------------------------------------
args = SimpleNamespace(**{
    # seeding and loop sizes
    "seed": 1,
    "epoch": 20,
    "batch_size": 16,
    # model dimensions (forwarded to the model constructor further down)
    "num_enc_layer": 4,
    "num_dec_layer": 4,
    "d_long": 3,
    "num_head": 4,
    "model_size": 16,
    # run / dataset identification
    "suffix": "train",
    "model": "LSR",
    "data": "simulated_data",
    "local": True,
    # presumably per-outcome missingness rates for Y1..Y3 -- TODO confirm
    "Y1_missing": 0.0,
    "Y2_missing": 0.0,
    "Y3_missing": 0.0,
    # loss weights and optimiser step size
    "inten_weight": 0.01,
    "surv_weight": 0.1,
    "lr": 0.0003,
})

# ------------------------------
# Choose the compute device, then seed every RNG the notebook uses
# ------------------------------
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Same seed for Python's `random`, NumPy's legacy global RNG and PyTorch,
# applied in that order.
seed = args.seed
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)


# ==============================
# Load dataset
# ==============================
dataset_path = f'data/{args.data}.pkl'
data_all = pd.read_pickle(dataset_path)
if data_all.empty:
    raise ValueError(f"{dataset_path} contains no rows")

# Number of trajectories, taken as (largest id + 1). Using the maximum rather
# than the id of the last row keeps this correct when the frame is not sorted
# by id. Assumes ids are 0-based integers without gaps -- TODO confirm.
I = int(data_all["id"].max()) + 1

# NOTE(review): this path does not depend on args.data; confirm that a single
# shared info file is intended.
dag_info_path = 'data/_info.pkl'
with open(dag_info_path, 'rb') as f:
    # SECURITY: unpickling can execute arbitrary code; only load files from a
    # trusted source.
    dag_info = pickle.load(f)

print("=" * 50)
print(f"Starting training for dataset: {args.data}")
print(f"{args.num_head} heads, {args.num_enc_layer} enc layers, {args.num_dec_layer} dec layers, {args.model_size} model dimension")
print(f"Data contains {I} unique trajectories")
print("=" * 50)

# ------------------------------
# Keep only rows whose obstime does not exceed the row's `time`
# ------------------------------
within_followup = data_all["obstime"] <= data_all["time"]
data = data_all[within_followup]

# 60 / 20 / 20 split over ids 0..I-1, taken in sequential order
# (despite its name, `random_id` is not shuffled here).
train_end = int(0.6 * I)
vali_end = int(0.8 * I)
random_id = list(range(I))
train_id = random_id[:train_end]
vali_id = random_id[train_end:vali_end]
test_id = random_id[vali_end:]

# Rows belonging to each split, selected by trajectory id
train_data, vali_data, test_data = (
    data[data["id"].isin(ids)] for ids in (train_id, vali_id, test_id)
)

# ==============================
# Scale Y columns
# ==============================
# Longitudinal outcome columns: "Y1" .. "Y<d_long>"
Y_str_list = [f"Y{i+1}" for i in range(args.d_long)]
scaler = MinMaxScaler(feature_range=(-1, 1))

# The three splits come from chained boolean indexing on another frame.
# Assigning into them directly triggers pandas' SettingWithCopyWarning, so
# take explicit, independent copies before writing the scaled values back.
train_data = train_data.copy()
vali_data = vali_data.copy()
test_data = test_data.copy()

# Fit the scaler on the training split only, then reuse it for the
# validation and test splits.
train_data.loc[:, Y_str_list] = scaler.fit_transform(train_data.loc[:, Y_str_list])
vali_data.loc[:, Y_str_list] = scaler.transform(vali_data.loc[:, Y_str_list])
test_data.loc[:, Y_str_list] = scaler.transform(test_data.loc[:, Y_str_list])


Using device: cuda
Starting training for dataset: simulated_data
4 heads, 4 enc layers, 4 dec layers, 16 model dimension
Data contains 1000 unique trajectories


In [2]:
from transformerlsr import TransformerLSR  
from LSRfunctions import long_loss_LSR  # replace with your actual loss
from LSRfunctions import get_tensors
from util import surv_loss_lsr
import warnings 
warnings.filterwarnings("ignore")
# Constructor arguments: sizes come from `args`; d_base is fixed to 2 here.
model_kwargs = dict(
    d_long=args.d_long,
    d_base=2,
    dag_info=dag_info,
    d_model=args.model_size,
    nhead=args.num_head,
    num_encoder_layers=args.num_enc_layer,
    num_decoder_layers=args.num_dec_layer,
    device=device,
)
model = TransformerLSR(**model_kwargs)
model.to(device=device)

# Optimiser, loss alias and bookkeeping used by the training / validation cells
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
long_loss = long_loss_LSR
batch_size = args.batch_size
curr_best = 100_000_000  # large initial value for the best validation loss


n_epoch = 1
rule = "=" * 50
for epoch in range(n_epoch):
    # Banner for this epoch
    print(rule, f"Epoch {epoch + 1}/{n_epoch}", rule, sep="\n")

    model = model.train()
    running_loss = 0
    tokens = 0

    # New random order of trajectory ids for this epoch (both lists are rebound)
    train_id = np.random.permutation(train_id)
    vali_id = np.random.permutation(vali_id)

    # `start` walks over train_id in strides of one mini-batch
    for start in range(0, len(train_id), args.batch_size):
        optimizer.zero_grad()

        # Rows of every trajectory whose id falls into this mini-batch
        stop = start + args.batch_size
        indices = train_id[start:stop]
        in_batch = train_data["id"].isin(indices)
        batch_data = train_data[in_batch].copy()

        # DataFrame -> tensors consumed by the model and the loss functions
        batch = get_tensors(batch_data, long=Y_str_list, device=device)

        # Forward pass (visit_inten is produced but not part of the loss below)
        long_preds, visit_inten, surv_inten = model(batch)

        # Longitudinal and survival loss terms
        loss1, full_loss1, num_tokens = long_loss(long_preds, batch)
        loss2, full_loss2 = surv_loss_lsr(surv_inten, batch)

        # Weighted sum drives the parameter update
        loss = loss1 + args.surv_weight * loss2
        loss.backward()
        optimizer.step()

        # Accumulate the `full_*` variants for the per-token epoch average
        weighted_full = full_loss1 + args.surv_weight * full_loss2
        running_loss += weighted_full.item()
        tokens += num_tokens

    # Epoch summary; the two component lines show the last mini-batch only
    train_show = running_loss / tokens
    print(f"Avg training loss per token: {train_show:.4f}")
    print(f"  long loss  : {loss1.item():.4f}")
    print(f"  surv loss  : {loss2.item():.4f}")


Epoch 1/1
Avg training loss per token: 0.0224
  long loss  : 0.1115
  surv loss  : -1.2095


In [4]:
# Scratch check: call the survival loss again on whatever `surv_inten` and
# `batch` currently hold in the kernel (this cell does not define them itself).
# NOTE(review): the recorded output of this cell is a NameError for name 'a';
# presumably raised inside util.surv_loss_lsr -- confirm, and drop this cell
# once the helper is fixed.
from util import surv_loss_lsr
surv_loss_lsr(surv_inten, batch)


NameError: name 'a' is not defined

In [14]:
# Ad-hoc inspection: turn the whole training frame into a single tensor batch.
# This rebinds the notebook-level name `batch`.
batch = get_tensors(train_data, long=Y_str_list, device=device)

In [33]:
# Row 0 of the "intenmask" entry, without its first column
a = batch["intenmask"][0, 1:]

# Position of the first True in `a`: argmax over the 0/1 view returns the
# first maximal index, which is 0 when no entry is True.
as_int = a.int()
idx = int(torch.argmax(as_int))

# Mark every position strictly before that index; the mask stays all-False
# when `a` has no True entry.
surv_mask = torch.zeros_like(a)
if bool(a.any()):
    surv_mask[:idx] = True
print(surv_mask)

tensor([ True,  True,  True,  True,  True,  True,  True,  True, False, False,
        False], device='cuda:0')


In [7]:
# ------------------------------------------------------------------
# Validation pass
#   1. accumulate the longitudinal loss over all validation mini-batches
#   2. print diagnostics for trajectory 0 of the LAST mini-batch
#   3. save the model weights when the per-token loss improves
# Relies on names created by earlier cells: model, long_loss, curr_best,
# epoch, vali_id, vali_data, Y_str_list, device, args.
# ------------------------------------------------------------------
model = model.eval()
vali_loss = 0
tokens = 0

# `start` is the offset into vali_id. It is kept separate from `batch`, which
# holds the tensor dict, so the loop index is never shadowed.
for start in range(0, len(vali_id), args.batch_size):
    indices = vali_id[start:start + args.batch_size]
    batch_data = vali_data[vali_data["id"].isin(indices)]
    batch = get_tensors(batch_data.copy(), long=Y_str_list, device=device)

    with torch.no_grad():
        long_preds, visit_inten, surv_inten = model(batch)
        loss1, full_loss1, num_tokens = long_loss(long_preds, batch)
        # Survival term is not part of the validation loss:
        # loss2, full_loss2 = surv_loss(surv_inten, Zeta, batch)

    vali_loss += full_loss1.item()
    tokens += num_tokens

# Fail early with a clear message instead of an unrelated error further down.
if tokens == 0:
    raise ValueError("validation produced no tokens; check vali_id / vali_data")

# ---- Evaluation outputs ----
# Everything below inspects index 0 of the last mini-batch of the loop above.
event_ll = (torch.log(visit_inten[0]) * batch["longmask"][0, 1:]).sum()
non_event_ll = batch["mask"][0].sum()
visit_ll = event_ll - non_event_ll

# Number of set entries in the mask of trajectory 0, as a plain Python int so
# it can be used as a slice bound for tensors and arrays alike.
first_traj_len = int(torch.sum(batch["mask"][0], dim=-1).item())

# Detailed visit comparison
print(f"sample trajectory visit intensities: {np.log(visit_inten[0, :first_traj_len - 1].detach().cpu().numpy())}")

# NOTE(review): assumes the leading rows of batch_data belong to the same
# trajectory as index 0 of the tensor batch -- confirm against get_tensors.
ground_intensities = batch_data["true_inten"].to_numpy()[:first_traj_len - 1]
print(f"ground truth visit event intensities: {np.log(ground_intensities)}")

total_time = batch["obstime"][0][1:first_traj_len]
print(f"times: {total_time}")

# Survival examination
surv_event_ll = (torch.log(surv_inten[0]) * batch["intenmask"][0, 1:]).sum(dim=-1)
non_surv_event_ll = batch["mask"][0].sum(dim=-1)
surv_ll = surv_event_ll - non_surv_event_ll

print(f"sample trajectory survival event intensity: {surv_event_ll.item():.2f}")
print(f"sample trajectory survival non-event intensity: {non_surv_event_ll.item():.2f}")

ground_truth_surv_ll = batch_data["surv_ll"].to_numpy()[0]
print(f"ground truth survival event intensity: {ground_truth_surv_ll:.2f}")

ground_truth_surv_non_ll = batch_data["surv_non_ll"].to_numpy()[0]
print(f"ground truth survival NON-event intensity: {ground_truth_surv_non_ll:.2f}")

# Detailed survival comparison (same trajectory length as above)
print(f"sample trajectory surv intensities: {np.log(surv_inten[0, :first_traj_len].detach().cpu().numpy())}")

ground_surv_intensities = batch_data["true_surv"].to_numpy()[:first_traj_len]
print(f"ground truth surv intensities: {np.log(ground_surv_intensities)}")

total_time = batch["totaltime"][0][1:first_traj_len + 1]
print(f"times: {total_time}")

# Save best model
vali_show = vali_loss / tokens

# No visible cell defines `model_save_path`, so torch.save would fail with a
# NameError. Fall back to a name derived from the run settings when the
# notebook has not defined one.
if "model_save_path" not in globals():
    model_save_path = f"{args.model}_{args.data}_{args.suffix}.pt"

if vali_show < curr_best:
    curr_best = vali_show
    print(f"updated at epoch: {epoch}")
    print(f"current best validation loss: {curr_best:.2f}")
    torch.save(model.state_dict(), model_save_path)


visit_x shape: torch.Size([16, 11, 48])
inten weight shape: torch.Size([1, 48])
visit_x shape: torch.Size([16, 11, 48])
inten weight shape: torch.Size([1, 48])
visit_x shape: torch.Size([16, 11, 48])
inten weight shape: torch.Size([1, 48])
visit_x shape: torch.Size([16, 11, 48])
inten weight shape: torch.Size([1, 48])
visit_x shape: torch.Size([16, 11, 48])
inten weight shape: torch.Size([1, 48])
visit_x shape: torch.Size([16, 11, 48])
inten weight shape: torch.Size([1, 48])
visit_x shape: torch.Size([16, 11, 48])
inten weight shape: torch.Size([1, 48])
visit_x shape: torch.Size([16, 11, 48])
inten weight shape: torch.Size([1, 48])
visit_x shape: torch.Size([16, 11, 48])
inten weight shape: torch.Size([1, 48])
visit_x shape: torch.Size([16, 11, 48])
inten weight shape: torch.Size([1, 48])
visit_x shape: torch.Size([16, 11, 48])
inten weight shape: torch.Size([1, 48])
visit_x shape: torch.Size([16, 11, 48])
inten weight shape: torch.Size([1, 48])
visit_x shape: torch.Size([8, 11, 48])
i

KeyError: 'event_ll'