# GRU-T (time-aware GRU)

The original GRU model has two major limitations when applied to our OSIC FVC time-series data:

1. It ignores unequal time intervals between measurements.
   - A standard GRU processes sequences assuming each step occurs at a uniform time interval.
   - Even if we include `Weeks` as a feature, the GRU’s recurrence is fundamentally unaware that:
     - the gap from week 1 → 3 is 2 weeks,
     - but the gap from week 3 → 7 is 4 weeks.
   - Since the GRU update always proceeds in discrete, equally spaced steps, it fails to model the true temporal dynamics of lung function decline when time gaps vary across patients.
2. It cannot naturally predict FVC at missing time steps.
   - To batch sequences of different lengths, we typically apply sequence padding (which inserts dummy time steps) together with a mask that ensures padded positions are ignored during loss computation.
   - As a result:
     - The GRU only learns from observed weeks (e.g., 1, 3, 7).
     - It does not learn or output predictions for weeks that never appear as inputs (e.g., 2, 4, 5, 6).
   - Thus, the model neither learns the evolution of FVC at unobserved weeks nor provides meaningful predictions at those intermediate time points.
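The padding-and-masking setup described in point 2 can be sketched as follows. This is a minimal illustration assuming PyTorch; `masked_mse` is a hypothetical helper name and the toy FVC values are made up. Whatever the model outputs at padded positions contributes nothing to the loss, which is why a plain GRU never gets a training signal for weeks that are absent from its input.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def masked_mse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean squared error over positions where mask is True (hypothetical helper)."""
    mask = mask.to(pred.dtype)
    sq_err = (pred - target) ** 2 * mask            # zero out padded positions
    return sq_err.sum() / mask.sum().clamp(min=1.0)

# Two toy patients with different numbers of visits (illustrative values only)
y_seqs = [torch.tensor([2315.0, 2214.0, 2061.0]), torch.tensor([3000.0, 2950.0])]
lengths = torch.tensor([len(y) for y in y_seqs])

# Pad to (batch, T_max); padded slots are filled with 0.0
y_pad = pad_sequence(y_seqs, batch_first=True, padding_value=0.0)

# mask[b, t] is True only for real (non-padded) time steps
T_max = y_pad.shape[1]
mask = torch.arange(T_max).unsqueeze(0) < lengths.unsqueeze(1)

# Dummy prediction that is off by 1.0 everywhere...
pred = y_pad + 1.0
pred[1, 2] = 1e6          # ...plus garbage at a padded position, which the mask ignores
loss = masked_mse(pred, y_pad, mask)
print(y_pad.shape, mask.tolist(), loss.item())
```

Here the loss is the mean of five squared errors of 1.0, so the huge value at the padded slot has no effect.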

To address these limitations, we adopt GRU-T, a time-aware GRU variant designed for irregularly sampled clinical data.

GRU-T explicitly incorporates:
- `dt`: the time difference (in weeks) between the current and previous measurements, which lets the model scale or decay the hidden state according to how much real time has passed.

This makes the recurrence behave more like a continuous-time model, letting it distinguish between short and long gaps in a patient’s timeline.
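One common way to implement such a decay, borrowed from GRU-D-style models and stated here as an assumption (the exact GRU-T formulation used in this notebook may differ), is to shrink the previous hidden state by a learned, gap-dependent factor before the standard GRU update:

$$
\gamma_t = \exp\big(-\max(0,\; W_\gamma\,\Delta t_t + b_\gamma)\big), \qquad
\tilde{h}_{t-1} = \gamma_t \odot h_{t-1}, \qquad
h_t = \mathrm{GRU}\big(x_t,\; \tilde{h}_{t-1}\big)
$$

Because $\gamma_t \in (0, 1]$, a hidden unit with a positive learned weight in $W_\gamma$ gets an equal or smaller $\gamma_t$ as the gap $\Delta t_t$ grows, so longer gaps mean more forgetting; when $\Delta t_t = 0$ and $b_\gamma \le 0$, the hidden state passes through unchanged.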

By constructing a complete weekly time grid for each patient (e.g., weeks 1 through 70) and letting GRU-T process sequences that include both observed and unobserved weeks:

- GRU-T produces FVC predictions for every week, including weeks missing from the raw dataset.
- Loss is computed only on observed weeks (via masking), but predictions are generated across the entire timeline.

This enables the model to interpolate the disease trajectory smoothly and consistently across irregular measurement gaps.
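A minimal sketch of the weekly-grid construction, assuming pandas. `to_weekly_grid` is a hypothetical helper: it reindexes one patient's visits onto every integer week between `start` and `end` (defaulting to the first and last visit), leaves FVC as NaN at unobserved weeks rather than interpolating it, and records an `observed` flag to use as the loss mask. The `weeks_since_obs` column (weeks elapsed since the last real measurement) is an added assumption, one possible way to keep elapsed time visible to the model once every grid step is one week apart.

```python
import numpy as np
import pandas as pd

def to_weekly_grid(g: pd.DataFrame, start=None, end=None) -> pd.DataFrame:
    """Reindex one patient's rows onto a complete weekly grid (hypothetical helper)."""
    # Assumption: if a week appears twice, keep the last measurement so the index is unique
    g = g.sort_values("Weeks").drop_duplicates("Weeks", keep="last")
    start = int(g["Weeks"].min()) if start is None else start
    end = int(g["Weeks"].max()) if end is None else end

    grid = g.set_index("Weeks").reindex(np.arange(start, end + 1))
    grid.index.name = "Weeks"
    grid["observed"] = grid["FVC"].notna()        # loss mask: True only at real visits
    grid["Patient"] = g["Patient"].iloc[0]         # fill the patient id on new rows

    # Weeks elapsed since the most recent observed visit (0 at an observed week)
    last_obs = pd.Series(np.where(grid["observed"], grid.index, np.nan), index=grid.index).ffill()
    grid["weeks_since_obs"] = grid.index.to_numpy() - last_obs.to_numpy()
    return grid.reset_index()

# Toy patient observed at weeks 1, 3 and 7 (illustrative values only)
toy = pd.DataFrame({"Patient": "P1", "Weeks": [1, 3, 7], "FVC": [2300.0, 2250.0, 2100.0]})
grid = to_weekly_grid(toy)
print(grid[["Weeks", "FVC", "observed", "weeks_since_obs"]])
```

Passing `start=1, end=70` would produce the fixed 70-week grid mentioned above; rows with `observed == False` still receive predictions but are dropped from the loss by the mask.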

## Import libraries

In [1]:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from pathlib import Path
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import sys
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import r2_score
from scipy import stats
```

## Preprocess data

In [6]:
```python
# Load the training data
df = pd.read_csv("../data/train.csv")
print(df.head())

print("\nNumber of data points: \n", df.groupby('Patient').size())

print("\nMissing values:\n", df.isna().sum())  # check every column for missing values -> 0 missing values
```

```text
                     Patient  Weeks   FVC    Percent  Age   Sex SmokingStatus
0  ID00007637202177411956430     -4  2315  58.253649   79  Male     Ex-smoker
1  ID00007637202177411956430      5  2214  55.712129   79  Male     Ex-smoker
2  ID00007637202177411956430      7  2061  51.862104   79  Male     Ex-smoker
3  ID00007637202177411956430      9  2144  53.950679   79  Male     Ex-smoker
4  ID00007637202177411956430     11  2069  52.063412   79  Male     Ex-smoker

Number of data points: 
 Patient
ID00007637202177411956430     9
ID00009637202177434476278     9
ID00010637202177584971671     9
ID00011637202177653955184     9
ID00012637202177665765362     9
                             ..
ID00419637202311204720264     9
ID00421637202311550012437    10
ID00422637202311677017371     8
ID00423637202312137826377     9
ID00426637202313170790466     9
Length: 176, dtype: int64

Missing values:
 Patient          0
Weeks            0
FVC              0
Percent          0
Age              0
Sex
```
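Since limitation 1 is about irregular spacing, it can help to quantify how uneven the visit gaps are for each patient. The helper below (`gap_summary`, a hypothetical name) computes the visit count and the smallest and largest gap in weeks using `groupby(...).diff()`; it is shown on a toy frame with the same `Patient`/`Weeks` columns, and on the real data it would be called as `gap_summary(df)`.

```python
import pandas as pd

def gap_summary(df: pd.DataFrame) -> pd.DataFrame:
    """Per-patient visit count and min/max gap (in weeks) between consecutive visits."""
    df = df.sort_values(["Patient", "Weeks"])
    gaps = df.groupby("Patient")["Weeks"].diff()      # NaN at each patient's first visit
    return (
        df.assign(gap=gaps)
          .groupby("Patient")
          .agg(n_visits=("Weeks", "size"), min_gap=("gap", "min"), max_gap=("gap", "max"))
    )

# Toy data: patient A is seen at weeks 1, 3, 7; patient B at weeks 0, 1, 2
toy = pd.DataFrame({"Patient": ["A", "A", "A", "B", "B", "B"],
                    "Weeks":   [1, 3, 7, 0, 1, 2]})
summary = gap_summary(toy)
print(summary)
```

A patient whose `min_gap` and `max_gap` differ (like A) is exactly the case a standard GRU cannot distinguish from evenly spaced visits.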

In [15]:
# Same preprocessing as GP
df['Sex_id'] = df['Sex'].map({'Male': 0, 'Female': 1})
df['Smk_id'] = df['SmokingStatus'].map({'Never smoked': 0, 'Ex-smoker': 1, 'Currently smokes': 2})

ids = df['Patient'].unique()
train_ids, val_ids = train_test_split(ids, test_size=0.2, random_state=3244)

train_df = df[df['Patient'].isin(train_ids)].reset_index(drop=True)
val_df   = df[df['Patient'].isin(val_ids)].reset_index(drop=True)

# # Include missing weeks between existing time intervals
# def include_missing_weeks(df):
#     all_data = []
#     for patient_id, group in df.groupby('Patient'):
#         weeks = group['Weeks'].values
#         fvc_values = group['FVC'].values    
#         full_weeks = np.arange(weeks.min(), weeks.max() + 1)
#         full_fvc = np.interp(full_weeks, weeks, fvc_values)
#         patient_df = pd.DataFrame({
#             'Patient': patient_id,
#             'Weeks': full_weeks,
#             'FVC': full_fvc
#         })
#         for col in group.columns:
#             if col not in ['Patient', 'Weeks', 'FVC']:
#                 patient_df[col] = group[col].iloc[0]
#         all_data.append(patient_df)
#     return pd.concat(all_data).reset_index(drop=True)   

# train_df = include_missing_weeks(train_df)
# val_df = include_missing_weeks(val_df)
# print(f"After including missing weeks, train size: {len(train_df)}, val size: {len(val_df)}")
# print(train_df.head())

# Scale columns "Weeks" and "FVC"
time_scaler = StandardScaler()
fvc_scaler = StandardScaler()
train_df["Weeks_scaled"] = time_scaler.fit_transform(train_df[["Weeks"]])
train_df["FVC_scaled"] = fvc_scaler.fit_transform(train_df[["FVC"]])
val_df["Weeks_scaled"] = time_scaler.transform(val_df[["Weeks"]])
val_df["FVC_scaled"] = fvc_scaler.transform(val_df[["FVC"]])

# Calculate baseline FVC (scaled)
baseline_fvc_tr = train_df.groupby('Patient')['FVC_scaled'].first().to_dict()
baseline_fvc_val = val_df.groupby('Patient')['FVC_scaled'].first().to_dict()
train_df['Baseline_FVC'] = train_df['Patient'].map(baseline_fvc_tr)
val_df['Baseline_FVC'] = val_df['Patient'].map(baseline_fvc_val)

# Sort by patient and weeks
train_df = train_df.sort_values(['Patient', 'Weeks']).reset_index(drop=True)
val_df = val_df.sort_values(['Patient', 'Weeks']).reset_index(drop=True)

# dt in raw weeks
train_df["dt"] = train_df.groupby("Patient")["Weeks"].diff().fillna(0.0)
val_df["dt"]   = val_df.groupby("Patient")["Weeks"].diff().fillna(0.0)

# dt in scaled weeks
train_df["dt_scaled"] = train_df.groupby("Patient")["Weeks_scaled"].diff().fillna(0.0)
val_df["dt_scaled"]   = val_df.groupby("Patient")["Weeks_scaled"].diff().fillna(0.0)


# Check the processed data
# print(val_df.head())
print(train_df.head())

```text
                     Patient  Weeks   FVC    Percent  Age   Sex SmokingStatus  \
0  ID00007637202177411956430     -4  2315  58.253649   79  Male     Ex-smoker   
1  ID00007637202177411956430      5  2214  55.712129   79  Male     Ex-smoker   
2  ID00007637202177411956430      7  2061  51.862104   79  Male     Ex-smoker   
3  ID00007637202177411956430      9  2144  53.950679   79  Male     Ex-smoker   
4  ID00007637202177411956430     11  2069  52.063412   79  Male     Ex-smoker   

   Sex_id  Smk_id  Weeks_scaled  FVC_scaled  Baseline_FVC   dt  dt_scaled  
0       0       1     -1.510341   -0.448119     -0.448119  0.0   0.000000  
1       0       1     -1.122587   -0.566423     -0.448119  9.0   0.387754  
2       0       1     -1.036419   -0.745635     -0.448119  2.0   0.086168  
3       0       1     -0.950252   -0.648415     -0.448119  2.0   0.086168  
4       0       1     -0.864084   -0.736264     -0.448119  2.0   0.086168  
```
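Because `StandardScaler` computes `(x - mean) / std`, the mean cancels in a difference, so `dt_scaled` is simply `dt` divided by the training-set standard deviation of `Weeks`. The table above agrees: 0.387754 / 9 ≈ 0.086168 / 2 ≈ 0.0431 per week, i.e. a standard deviation of roughly 23.2 weeks. The NumPy sketch below checks the identity on the first patient's weeks; standardizing only these five weeks (rather than the whole training set) is a simplification, so its scaled values differ from the table even though the ratio property holds.

```python
import numpy as np

weeks = np.array([-4, 5, 7, 9, 11], dtype=float)   # first patient's visit weeks
dt = np.diff(weeks, prepend=weeks[0])              # [0, 9, 2, 2, 2]; 0 at the first visit

# Standardize the way StandardScaler does (population std, ddof=0)
mu, sigma = weeks.mean(), weeks.std()
weeks_scaled = (weeks - mu) / sigma
dt_scaled = np.diff(weeks_scaled, prepend=weeks_scaled[0])

# The mean cancels in a difference, so dt_scaled == dt / sigma
print(dt, np.allclose(dt_scaled, dt / sigma))
```

This also means feeding `dt` or `dt_scaled` to the model differs only by a constant factor, which a learned decay weight can absorb.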


## Input Transformation for GRU-T

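Input and output shapes, where `T` is the number of time steps per patient and `F` the number of features:

- Inputs: `X` with shape `(batch, T, F)` and `dt` with shape `(batch, T)`
- Output: `pred` with shape `(batch, T, 1)`, one FVC prediction per time step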
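The shape contract above can be made concrete with a small PyTorch sketch. `TimeDecayGRU` is a hypothetical class, and its decay rule (multiplying the hidden state by `exp(-relu(w * dt + b))` before each `nn.GRUCell` step, in the style of GRU-D) is an assumption rather than the notebook's exact GRU-T cell. It consumes `X` of shape `(batch, T, F)` and `dt` of shape `(batch, T)` and returns `pred` of shape `(batch, T, 1)`.

```python
import torch
import torch.nn as nn

class TimeDecayGRU(nn.Module):
    """Hypothetical time-aware GRU: decays the hidden state according to dt before each step."""

    def __init__(self, n_features: int, hidden_size: int = 32):
        super().__init__()
        self.hidden_size = hidden_size
        self.cell = nn.GRUCell(n_features, hidden_size)
        # Per-unit decay: gamma = exp(-relu(w * dt + b)) lies in (0, 1]
        self.decay = nn.Linear(1, hidden_size)
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, X: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
        batch, T, _ = X.shape
        h = X.new_zeros((batch, self.hidden_size))
        outputs = []
        for t in range(T):
            gamma = torch.exp(-torch.relu(self.decay(dt[:, t].unsqueeze(-1))))  # (batch, hidden)
            h = self.cell(X[:, t], gamma * h)          # decayed state -> standard GRU update
            outputs.append(self.head(h))               # (batch, 1)
        return torch.stack(outputs, dim=1)             # (batch, T, 1)

torch.manual_seed(0)
X = torch.randn(4, 9, 5)        # 4 patients, 9 visits, 5 features (toy values)
dt = torch.rand(4, 9) * 10      # gaps in weeks (toy values)
model = TimeDecayGRU(n_features=5)
pred = model(X, dt)
print(pred.shape)               # torch.Size([4, 9, 1])
```

Stepping through `nn.GRUCell` in a Python loop is slower than the fused `nn.GRU`, but it is what allows the hidden state to be rescaled between time steps.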


In [16]:
```python
# Per-step input features (note: Age is left unscaled, in raw years)
feature_cols = [
    "Weeks_scaled",
    "Age",
    "Sex_id",
    "Smk_id",
    "Baseline_FVC",
    # you can add more engineered features here
]

target_col = "FVC_scaled"

def build_sequences(df, feature_cols, target_col):
    """Split a long-format DataFrame into one (X, dt, y) sequence per patient."""
    X_seqs, dt_seqs, y_seqs = [], [], []
    patients = []

    for pid, g in df.groupby("Patient"):
        g = g.sort_values("Weeks")  # safety

        X = g[feature_cols].values.astype(np.float32)      # (T, F)
        dt = g["dt"].values.astype(np.float32)            # (T,), gaps in raw weeks
        y = g[target_col].values.astype(np.float32)       # (T,)

        X_seqs.append(X)
        dt_seqs.append(dt)
        y_seqs.append(y)
        patients.append(pid)

    return X_seqs, dt_seqs, y_seqs, patients

X_tr_seqs, dt_tr_seqs, y_tr_seqs, train_patients = build_sequences(train_df, feature_cols, target_col)
X_val_seqs, dt_val_seqs, y_val_seqs, val_patients = build_sequences(val_df, feature_cols, target_col)

print(len(X_tr_seqs), "train patients")
print("Example sequence shapes:", X_tr_seqs[0].shape, dt_tr_seqs[0].shape, y_tr_seqs[0].shape)

print(X_tr_seqs)
```


140 train patients
Example sequence shapes: (9, 5) (9,) (9,)
[array([[-1.5103409 , 79.        ,  0.        ,  1.        , -0.4481193 ],
       [-1.122587  , 79.        ,  0.        ,  1.        , -0.4481193 ],
       [-1.0364194 , 79.        ,  0.        ,  1.        , -0.4481193 ],
       [-0.95025194, 79.        ,  0.        ,  1.        , -0.4481193 ],
       [-0.86408436, 79.        ,  0.        ,  1.        , -0.4481193 ],
       [-0.60558176, 79.        ,  0.        ,  1.        , -0.4481193 ],
       [-0.08857659, 79.        ,  0.        ,  1.        , -0.4481193 ],
       [ 0.4284286 , 79.        ,  0.        ,  1.        , -0.4481193 ],
       [ 1.1177689 , 79.        ,  0.        ,  1.        , -0.4481193 ]],
      dtype=float32), array([[-0.99333566, 69.        ,  0.        ,  1.        ,  1.1273067 ],
       [-0.95025194, 69.        ,  0.        ,  1.        ,  1.1273067 ],
       [-0.86408436, 69.        ,  0.        ,  1.        ,  1.1273067 ],
       [-0.77791685, 69.   