In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from sklearn.metrics import mean_squared_error, r2_score

In [None]:


# Read the three saved arrays from the working directory, in a fixed order.
# Expected layouts per the original notes: X is (N, 3, E), Y is (N, 3).
X, Y, subject_ids = (
    np.load(npy_name)
    for npy_name in ("X_tensor.npy", "Y_tensor.npy", "subject_ids.npy")
)

# Model input: the first two entries along axis 1 of X  -> (N, 2, E)
X_input = X[:, 0:2, :]
# Regression target: column 2 of Y                      -> (N,)
Y_target = Y[:, 2]




Epoch 1/75 - Loss: 324.4204
Epoch 2/75 - Loss: 322.8572
Epoch 3/75 - Loss: 321.3033
Epoch 4/75 - Loss: 319.7749
Epoch 5/75 - Loss: 318.1318
Epoch 6/75 - Loss: 316.3667
Epoch 7/75 - Loss: 314.4029
Epoch 8/75 - Loss: 312.2045
Epoch 9/75 - Loss: 309.6205
Epoch 10/75 - Loss: 306.7607
Epoch 11/75 - Loss: 303.3496
Epoch 12/75 - Loss: 299.3836
Epoch 13/75 - Loss: 294.9559
Epoch 14/75 - Loss: 289.6105
Epoch 15/75 - Loss: 283.5974
Epoch 16/75 - Loss: 276.6620
Epoch 17/75 - Loss: 269.1977
Epoch 18/75 - Loss: 260.6860
Epoch 19/75 - Loss: 251.4237
Epoch 20/75 - Loss: 241.6590
Epoch 21/75 - Loss: 231.1267
Epoch 22/75 - Loss: 220.6202
Epoch 23/75 - Loss: 209.6255
Epoch 24/75 - Loss: 198.1974
Epoch 25/75 - Loss: 187.7386
Epoch 26/75 - Loss: 177.0319
Epoch 27/75 - Loss: 165.6129
Epoch 28/75 - Loss: 155.2483
Epoch 29/75 - Loss: 145.6849
Epoch 30/75 - Loss: 136.6215
Epoch 31/75 - Loss: 127.1458
Epoch 32/75 - Loss: 119.0014
Epoch 33/75 - Loss: 111.9290
Epoch 34/75 - Loss: 103.7335
Epoch 35/75 - Loss: 97.

In [None]:
class PatientTrajectoryDataset(Dataset):
    """Map-style dataset pairing each input sample with its regression target.

    Both arrays are copied into float32 tensors once at construction time, so
    indexing afterwards is a plain tensor lookup.

    Args:
        X: array-like of inputs; the first axis indexes samples.
        Y: array-like of targets; the first axis indexes samples.

    Raises:
        ValueError: if X and Y do not hold the same number of samples.
    """

    def __init__(self, X, Y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.Y = torch.tensor(Y, dtype=torch.float32)
        # Fail fast: a mismatch would otherwise either be silently ignored
        # (extra targets) or surface as an IndexError in the middle of an epoch.
        if len(self.X) != len(self.Y):
            raise ValueError(
                f"X and Y must have the same number of samples, "
                f"got {len(self.X)} and {len(self.Y)}"
            )

    def __len__(self):
        """Return the number of samples (length of the first axis of X)."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair stored at ``idx``."""
        return self.X[idx], self.Y[idx]

# Split: 80% of the samples for training, the remainder held out for testing.
dataset = PatientTrajectoryDataset(X_input, Y_target)
train_size = int(0.8 * len(dataset))
test_size  = len(dataset) - train_size

# Seed the split with its own generator so the same samples end up in the
# held-out set on every run; otherwise the reported metrics are not comparable
# between executions of the notebook.
split_generator = torch.Generator().manual_seed(42)
train_ds, test_ds = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator
)
# NOTE(review): the split is per sample and does not consult `subject_ids`.
# If a subject can contribute more than one row, rows from the same subject may
# land on both sides of the split — confirm whether a subject-level split is needed.

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=4)

In [None]:

# Simple DAST-GCN-inspired model
class DASTGCNSimple(nn.Module):
    """Sequence regressor: a single LSTM followed by a linear output head.

    As written, the model contains only the recurrent encoder and the linear
    layer; there is no graph-convolution component in this class.

    Args:
        num_features: size of the feature vector at each time step.
        hidden_dim: size of the LSTM hidden state (default 64).
    """

    def __init__(self, num_features, hidden_dim=64):
        super(DASTGCNSimple, self).__init__()
        self.lstm = nn.LSTM(input_size=num_features, hidden_size=hidden_dim, batch_first=True)
        self.fc   = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Predict one scalar per sequence.

        Args:
            x: tensor of shape (batch, time, features).

        Returns:
            Tensor of shape (batch,) with one prediction per input sequence.
        """
        # h_n is (num_layers, batch, hidden); index -1 picks the final layer's
        # hidden state after the last time step.
        _, (h_n, _) = self.lstm(x)
        out = self.fc(h_n[-1])  # (batch, 1)
        # Drop only the trailing size-1 axis. A bare squeeze() would also drop
        # the batch axis when batch == 1 and return a 0-d tensor, which breaks
        # code that iterates over or concatenates per-batch predictions.
        return out.squeeze(-1)


# Build the regressor; its input width is the size of axis 2 of X_input.
model = DASTGCNSimple(num_features=X_input.shape[2])

# Adam over all model parameters, minimising mean squared error.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

In [None]:
# Training loop: one optimiser step per mini-batch, mean batch loss logged per epoch.
# The epoch count lives in one place so the progress label always matches the
# number of epochs that are actually run.
NUM_EPOCHS = 200

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    for xb, yb in train_loader:
        optimizer.zero_grad()
        # Align predictions with the target shape so a one-sample batch
        # (where the model may return a 0-d tensor) does not rely on
        # broadcasting inside the loss.
        pred = model(xb).reshape(yb.shape)
        loss = criterion(pred, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Average of the per-batch losses for this epoch.
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - Loss: {total_loss/len(train_loader):.4f}")


# Collect targets and predictions over the held-out set, with the model in
# eval mode and gradient tracking disabled.
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        pred = model(xb)
        y_true.extend(yb.numpy())
        # Flatten to 1-D before extending: if the final batch holds a single
        # sample the model may return a 0-d tensor, and list.extend() cannot
        # iterate over a 0-d array.
        y_pred.extend(pred.reshape(-1).numpy())

In [None]:
# Score the held-out predictions: R² plus RMSE (square root of the MSE).
r2   = r2_score(y_true, y_pred)
mse  = mean_squared_error(y_true, y_pred)
rmse = np.sqrt(mse)

print(f"\n✅ Final Results → RMSE: {rmse:.4f}, R²: {r2:.4f}")