### Mount Storage, Import Libraries

In [224]:
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from tqdm.auto import tqdm
from datetime import datetime
from dataclasses import dataclass
from collections import defaultdict
from matplotlib.backends.backend_pdf import PdfPages

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [225]:
target_var = 'onset'
n_epochs   = 200

In [226]:
########## For Colab ##########
!pip install ts2vec
from ts2vec import TS2Vec

########## Personal ##########
from google.colab import drive
drive.mount('/content/drive')
with open(f'/content/drive/MyDrive/datasets/dataset_{target_var}.json') as f:
    content = f.read()
    data = json.loads(content)

########## Enterprise ##########
# import gcsfs
# fs = gcsfs.GCSFileSystem()
# with fs.open('gs://modoo-eod/users/datasets/dataset_hist.json') as f:
#     content = f.read()
#     data = json.loads(content)

########## Local ##########
# with open("../../datasets/dataset_hist.json") as f:
#     content=f.read()
#     data=json.loads(content)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Data Cleaning

In [227]:
df = pd.DataFrame.from_records(data)

print(len(df), "Measurements")

age_mean = np.mean([i[0] for i in df['static'] if pd.notna(i[0])])
bmi_mean = np.mean([i[1] for i in df['static'] if pd.notna(i[1])])

cleaned_data = []
for _, m in enumerate(data):

    # Remove measurements with empty windows
    if len(m['uc_windows']) == 0 or len(m['fhr_windows']) == 0:
        continue

    # Handle NaN values
    static = m['static'].copy()
    if pd.isna(m['static'][0]):
        static[0] = age_mean
    if pd.isna(m['static'][1]):
        static[1] = bmi_mean

    copy = m.copy()
    copy['static'] = static
    cleaned_data.append(copy)

cleaned_df = pd.DataFrame(cleaned_data)
cleaned_df["gest_age_weeks"] = [(i[-1]//7)+1 for i in cleaned_df["static"]]

print(len(cleaned_df), "Cleaned Measurements")

3681 Measurements
3661 Cleaned Measurements


### Train-Test Split (Stratified)

In [228]:
df_train = cleaned_df.groupby(
    "gest_age_weeks",
    group_keys=False
).apply(lambda x: x.sample(frac=0.8), include_groups=True)

df_test = cleaned_df.drop(df_train.index)

  ).apply(lambda x: x.sample(frac=0.8), include_groups=True)


In [229]:
train = df_train.to_dict(orient='records')
test  = df_test.to_dict(orient='records')

### Pre-Compute TS2Vec Embeddings

In [230]:
# n_instances x n_timestamps x n_features
train_uc = np.expand_dims(np.array([i['uc_raw'] for i in train]), 2)
train_fhr = np.expand_dims(np.array([i['fhr_raw'] for i in train]), 2)
test_uc  = np.expand_dims(np.array([i['uc_raw'] for i in test]), 2)
test_fhr = np.expand_dims(np.array([i['fhr_raw'] for i in test]), 2)

print("Train has shape", train_fhr.shape)
print("Test has shape ", test_fhr.shape)

ts_model = TS2Vec(
    input_dims = 1,
    output_dims = 320,
    device = 0,
    batch_size = 32
)

train_uc_embed    = ts_model.encode(train_uc, encoding_window="full_series")
train_fhr_embed   = ts_model.encode(train_fhr, encoding_window="full_series")

test_uc_embed     = ts_model.encode(test_uc, encoding_window="full_series")
test_fhr_embed    = ts_model.encode(test_fhr, encoding_window="full_series")

for idx, e in enumerate(train_uc_embed):
    train[idx]['uc_raw'] = e

for idx, e in enumerate(train_fhr_embed):
    train[idx]['fhr_raw'] = e

for idx, e in enumerate(test_uc_embed):
    test[idx]['uc_raw'] = e

for idx, e in enumerate(test_fhr_embed):
    test[idx]['fhr_raw'] = e

Train has shape (2928, 2048, 1)
Test has shape  (733, 2048, 1)


### Aggregate Windows

In [231]:
for i in train:

    uc_w    = torch.tensor([[v for _, v in w.items()] for w in i['uc_windows']], dtype=torch.float32)
    fhr_w   = torch.tensor([[v for _, v in w.items()] for w in i['fhr_windows']],dtype=torch.float32)
    uc_mean = uc_w.mean(dim=0) ; fhr_mean = fhr_w.mean(dim=0)

    i['uc_windows'] = uc_mean ; i['fhr_windows'] = fhr_mean

for i in test:

    uc_w    = torch.tensor([[v for _, v in w.items()] for w in i['uc_windows']], dtype=torch.float32)
    fhr_w   = torch.tensor([[v for _, v in w.items()] for w in i['fhr_windows']],dtype=torch.float32)
    uc_mean = uc_w.mean(dim=0) ; fhr_mean = fhr_w.mean(dim=0)

    i['uc_windows'] = uc_mean ; i['fhr_windows'] = fhr_mean

### Dataset, Data Loader

In [232]:
class PatientDataset(Dataset):

    def __init__(self, measurements):

        self.measurements = measurements

    def __len__(self):

        return len(self.measurements)

    def __getitem__(self, idx):

        measurement = self.measurements[idx]

        return measurement

In [233]:
def patient_collate_fn(batch):

    uc_raw  = torch.stack([torch.tensor(m['uc_raw'], dtype=torch.float32) for m in batch])
    fhr_raw = torch.stack([torch.tensor(m['fhr_raw'], dtype=torch.float32) for m in batch])
    static  = torch.stack([torch.tensor(m['static'], dtype=torch.float32) for m in batch])
    target  = torch.stack([torch.tensor(m['target'], dtype=torch.float32) for m in batch])
    uc_win  = torch.stack([m['uc_windows'] for m in batch])
    fhr_win = torch.stack([m['fhr_windows'] for m in batch])

    return {
        'uc_raw'      : uc_raw,
        'fhr_raw'     : fhr_raw,
        'uc_windows'  : uc_win,
        'fhr_windows' : fhr_win,
        'static'      : static,
        'target'      : target
    }

### Model Cfg

In [234]:
@dataclass
class ModelCfg:

    # Raw UC/FHR
    ts2vec_in_dim   : int = 1
    ts2vec_out_dim  : int = 320

    # FHR Windows
    fhr_in_dim      : int = 24
    fhr_hidden_dim  : int = 64
    fhr_out_dim     : int = 32

    # UC Windows
    uc_in_dim       : int = 20
    uc_hidden_dim   : int = 64
    uc_out_dim      : int = 32

    # Static
    stat_in_dim     : int = 8
    stat_hidden_dim : int = 64
    stat_out_dim    : int = 32

    # Fused Regressor
    fuse_hidden_dim : int = 512

### Model Modules

In [235]:
class StaticEncoder(nn.Module):

    def __init__(self, in_dim, hidden_dim, out_dim):

        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU()
        )

    def forward(self, x):

        # (B, 32)
        return self.net(x)

In [236]:
class WindowsEncoder(nn.Module):

    def __init__(self, in_dim, hidden_dim, out_dim):

        super().__init__()

        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU()
        )

    def forward(self, window):

        # (B, 24|20) -> (B, 32)
        h = self.mlp(window)

        # (B, 32)
        return h

### Main Model Class

In [237]:
class FusedRegressor(nn.Module):

    def __init__(self, cfg: ModelCfg):

        super().__init__()

        self.cfg = cfg

        # UC Window Encoder: Output 32
        self.uc_win_encoder = WindowsEncoder(
            in_dim=self.cfg.uc_in_dim,
            hidden_dim=self.cfg.uc_hidden_dim,
            out_dim=self.cfg.uc_out_dim,
        )

        # FHR Window Encoder: Output 32
        self.fhr_win_encoder = WindowsEncoder(
            in_dim=self.cfg.fhr_in_dim,
            hidden_dim=self.cfg.fhr_hidden_dim,
            out_dim=self.cfg.fhr_out_dim,
        )

        # Static Encoder: Output 32
        self.static_encoder = StaticEncoder(
            in_dim=self.cfg.stat_in_dim,
            hidden_dim=self.cfg.stat_hidden_dim,
            out_dim=self.cfg.stat_out_dim,
        )

        fused_dim = (
            2 * self.cfg.ts2vec_out_dim # 640
            + self.cfg.uc_out_dim       # 32
            + self.cfg.fhr_out_dim      # 32
            + self.cfg.stat_out_dim     # 32
        )

        # fused_dim = (
        #     2 * self.cfg.ts2vec_out_dim
        #     + 20
        #     + 24
        #     + 8
        # )

        self.fusion = nn.Sequential(
            nn.Linear(fused_dim, self.cfg.fuse_hidden_dim),
            nn.ReLU(),
            nn.Linear(self.cfg.fuse_hidden_dim, self.cfg.fuse_hidden_dim//2),
            nn.ReLU(),
            nn.Linear(self.cfg.fuse_hidden_dim//2, self.cfg.fuse_hidden_dim//4),
            nn.ReLU(),
            nn.Linear(self.cfg.fuse_hidden_dim//4, self.cfg.fuse_hidden_dim//8),
            nn.ReLU(),
            nn.Linear(self.cfg.fuse_hidden_dim//8, 1)
        )

    def forward(self, batch):

        uc_raw = batch['uc_raw'] ; uc_raw_tensor = uc_raw.to(device)

        uc_windows = batch['uc_windows'] ; uc_win_tensor = uc_windows.to(device)

        fhr_raw = batch['fhr_raw'] ; fhr_raw_tensor = fhr_raw.to(device)

        fhr_windows = batch['fhr_windows'] ; fhr_win_tensor = fhr_windows.to(device)

        static = batch['static'] ; static_tensor = static.to(device)

        # (B, 20) -> (B, 32)
        uc_win_emb = self.uc_win_encoder(uc_win_tensor)
        # print("UC Windows Shape:", uc_win_emb.shape)

        # (B, 24) -> (B, 32)
        fhr_win_emb = self.fhr_win_encoder(fhr_win_tensor)
        # print("FHR Windows Shape:", fhr_win_emb.shape)

        # (B, 8) -> (B, 32)
        static_emb = self.static_encoder(static_tensor)
        # print("Static Shape:", static_emb.shape)

        # (B, 736)
        fused = torch.cat(
            [uc_raw_tensor, fhr_raw_tensor, uc_win_emb, fhr_win_emb, static_emb],
            dim=-1,
        ).to(device)
        # print("Fused Shape:", fused.shape)

        # (B, 1)
        preds = self.fusion(fused)
        # print("Fusion Output Shape:", preds.shape)

        return preds.squeeze(1)


### Train Eval Functions

In [238]:
def train_one_epoch(model, loader, optimiser, criterion, device):

    model.train()
    total_loss = 0.0
    n_batches = 0

    for batch in loader:

        target = batch['target'].to(device)
        # print("Targets Shape:", target.shape)

        optimiser.zero_grad()

        pred = model(batch)

        loss = criterion(pred, target)

        total_loss += loss.item()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimiser.step()

        n_batches += 1

        # print("Train", pred, target, sep="|")

    return total_loss / n_batches

In [239]:
@torch.no_grad()
def evaluate(model, loader, criterion, device):

    model.eval()

    total_loss = 0.0
    n_batches = 0

    group_loss = defaultdict(list)

    for batch in loader:

        target = batch['target'].to(device)

        pred = model(batch)

        loss = criterion(pred, target)

        total_loss += loss.item()

        n_batches += 1

        gest_days  = batch['static'][:,-1]
        gest_weeks = (gest_days // 7).long()+1

        abs_err = (pred-target).abs()

        for g, e in zip(gest_weeks.tolist(), abs_err.tolist()):
            group_loss[str(g)].append(e)

    per_group_mae = {
        g: round(float(torch.tensor(es).mean()), 4)\
        for g, es in sorted(group_loss.items(), key=lambda x: x[0])
    }

    return total_loss/n_batches, per_group_mae

### Main

In [240]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

model_cfg = ModelCfg()

model = FusedRegressor(model_cfg).to(device)

train_dataset = PatientDataset(train)

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=patient_collate_fn,
    num_workers=4,
    pin_memory=True
)

eval_dataset = PatientDataset(test)

eval_loader = DataLoader(
    eval_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=patient_collate_fn,
    num_workers=4,
    pin_memory=True,
)

optimizer = torch.optim.Adam(
    list(model.parameters()),
    lr=3e-4,
    weight_decay=1e-5
)

criterion = torch.nn.L1Loss()

train_losses, test_losses = [], []

# {'week' : [err_1, err_2, ...]}
group_losses_week = {str(k):[] for k in cleaned_df['gest_age_weeks']}

group_losses = []

for epoch in tqdm(range(n_epochs)):

    train_loss = train_one_epoch(
        model,
        train_loader,
        optimizer,
        criterion,
        device
    )

    test_loss, group_loss = evaluate(
        model,
        eval_loader,
        criterion,
        device
    )

    train_losses.append(train_loss)
    test_losses.append(test_loss)

    for k in group_losses_week:

        if k not in group_loss:
            group_losses_week[k].append(None)

        else:
            group_losses_week[k].append(group_loss[k])

    group_losses.append(group_loss)

    print(f"[{epoch+1:02d}] train_loss={train_loss:.4f}  test_loss={test_loss:.4f}")
    print(group_loss)
    print()

cuda


  0%|          | 0/200 [00:00<?, ?it/s]

[01] train_loss=15.9002  test_loss=15.4630
{'29': 58.8474, '30': 49.3266, '31': 39.2382, '32': 38.1925, '33': 31.1415, '34': 22.4941, '35': 16.1783, '36': 9.5747, '37': 6.4531, '38': 6.4016, '39': 9.5019, '40': 12.4273, '41': 14.7275}

[02] train_loss=14.9451  test_loss=14.4942
{'29': 52.3154, '30': 42.3425, '31': 33.9036, '32': 33.1781, '33': 25.675, '34': 17.3135, '35': 11.896, '36': 6.8842, '37': 6.882, '38': 9.0435, '39': 12.4557, '40': 15.477, '41': 17.2251}

[03] train_loss=14.1845  test_loss=12.9341
{'29': 49.0595, '30': 37.6802, '31': 31.3202, '32': 30.9586, '33': 23.6678, '34': 16.4256, '35': 11.6829, '36': 7.6635, '37': 6.7817, '38': 7.4071, '39': 9.5063, '40': 11.3317, '41': 11.4522}

[04] train_loss=12.4302  test_loss=10.0066
{'29': 36.4158, '30': 22.1485, '31': 20.7085, '32': 21.8152, '33': 15.5645, '34': 10.7365, '35': 9.1087, '36': 7.902, '37': 8.0621, '38': 7.5863, '39': 7.3335, '40': 7.2373, '41': 5.9261}

[05] train_loss=9.8375  test_loss=8.7095
{'29': 28.9083, '30': 

### Report Generation

In [241]:
def generate_report(

    original_count    : int,
    cleaned_count     : int,
    train_count       : int,
    test_count        : int,
    gest_counts       : dict[int, int],
    train_losses      : list[float],
    test_losses       : list[float],
    group_losses      : list[dict[str, float]],
    group_losses_week : dict[str, float],
    model_name        : str,
    n_epochs          : int | None,
    output_path       : str

):

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    with PdfPages(output_path) as pdf:

        # --- Page 1: Summary ---
        plt.figure(figsize=(8,6)) ; plt.axis('off')
        summary_text = (
            f"Model: {model_name}\n"
            f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M')}\n\n"
            f"Measurements (original): {original_count}\n"
            f"Measurements (cleaned) : {cleaned_count}\n"
            f"Train/Test: {train_count}/{test_count}\n\n"
            f"Epochs: {n_epochs}\n\n"
            f"Gestational Age Counts:\n" +
            "\n".join([f"{int(k)} weeks: {int(v)}" for k, v in sorted(gest_counts.items())])
        )
        plt.text(0, 1, summary_text, va='top', fontsize=10, family='monospace')
        pdf.savefig() ; plt.close()

        # --- Page 2: Overall Loss Plot ---
        plt.figure(figsize=(10,8))
        x = np.arange(1, len(train_losses) + 1)
        plt.plot(x, train_losses, label='Train MAE')
        plt.plot(x, test_losses,  label='Test MAE')
        plt.xlabel('Epoch'); plt.ylabel('MAE'); plt.title('Train vs Test MAE')
        plt.legend() ; plt.grid(True) ; plt.tight_layout()
        pdf.savefig() ; plt.close()

        # --- Page 3: Per-Gestational-Week Plot (robust to pd.NA/missing) ---
        df_group = pd.DataFrame(group_losses_week)
        df_group = df_group.replace({pd.NA: np.nan}).apply(pd.to_numeric, errors='coerce')
        def _to_week(c):
            try: return int(float(c))
            except Exception: return c
        df_group = df_group.rename(columns=_to_week)
        plt.figure(figsize=(10,8))
        x = np.arange(1, len(df_group) + 1)
        plotted = 0
        for col in sorted(df_group.columns, key=lambda k: (isinstance(k, int), k)):
            y = df_group[col].to_numpy(dtype=float)
            if np.isfinite(y).any():
                label = f'{col}w' if isinstance(col, int) else str(col)
                plt.plot(x, y, label=label)
                plotted += 1
        plt.xlabel('Epoch') ; plt.ylabel('Test MAE') ; plt.title('Test MAE by Gestational Age (weeks)')
        if plotted: plt.legend(ncol=3)
        plt.grid(True) ; plt.tight_layout()
        pdf.savefig() ; plt.close()

        # --- Page 4: Per-epoch Table ---
        def _format_group_line(d):
            parts = [
                f"{int(float(k))}w={float(v):.3f}"
                for k, v in sorted(d.items(), key=lambda kv: float(kv[0]))
                if pd.notna(v)
            ]
            return ", ".join(parts)

        lines = [
            f"[{i+1:03d}]  train={t:.4f}  test={v:.4f}\n{_format_group_line(g)}\n"
            for i, (t, v, g) in enumerate(zip(train_losses, test_losses, group_losses))
        ]

        visual_rows = sum(s.count("\n")+2 for s in lines)

        fontsize = 10
        line_h_in = fontsize * 1.35 / 72.0
        top_margin_in = 0.6
        bottom_margin_in = 0.6
        fig_w_in = 8.5
        fig_h_in = top_margin_in + bottom_margin_in + line_h_in * visual_rows
        plt.figure(figsize=(fig_w_in, fig_h_in))
        plt.axis('off')
        ax = plt.gca()
        ax.text(
            0, 1,
            "\n".join(lines),
            va='top', ha='left',
            transform=ax.transAxes,
            family='monospace',
            fontsize=fontsize,
            wrap=True,
        )

        pdf.savefig(); plt.close()


    print(f"Report saved to {output_path}")

In [242]:
model_name = f'ts2vec-03-static'

generate_report(
    original_count=len(data),
    cleaned_count=len(cleaned_data),
    train_count=len(train),
    test_count=len(test),
    gest_counts=cleaned_df["gest_age_weeks"].value_counts().sort_index().to_dict(),
    train_losses=train_losses,
    test_losses=test_losses,
    group_losses=group_losses,
    group_losses_week=group_losses_week,
    model_name=f'{model_name}-{target_var}',
    n_epochs=n_epochs,
    output_path=f'/content/drive/MyDrive/reports/{model_name}/{model_name}-{target_var}-report.pdf'
)

  df_group = df_group.replace({pd.NA: np.nan}).apply(pd.to_numeric, errors='coerce')


Report saved to /content/drive/MyDrive/reports/ts2vec-03-static/ts2vec-03-static-onset-report.pdf
