# Speed, Memory, and Disk Comparisons

In this notebook, we'll offer some rough comparisons of the computational performance implications of ESGPT vs. other competing pipelines. We'll focus these comparisons on several metrics:
  1. The time, runtime memory, and final disk space required to construct, pre-process, and store an ESGPT dataset relative to other pipelines, where applicable.
  2. The initialization time, iteration speed, and GPU memory costs for producing batches of data within the ESGPT framework vs. other systems.
  
In particular, we'll compare (or justify why they are inappropriate comparators) against the following pipelines:
  1. TemporAI
  2. OMOP-Learn
  3. FIDDLE
  4. MIMIC-Extract
  
We'll make these comparisons leveraging the synthetic data distributed with ESGPT's sample tutorial, but this code can also be ported to any other dataset to run these profiles locally.

In [1]:
%load_ext memory_profiler

import sys
sys.path.append('..')

In [2]:
import os
import numpy as np
import torch

from collections import defaultdict
from datetime import datetime, timedelta
from humanize import naturalsize, naturaldelta
from pathlib import Path
from sparklines import sparklines
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from typing import Callable

from EventStream.data.dataset_polars import Dataset
from EventStream.data.config import PytorchDatasetConfig
from EventStream.data.types import PytorchBatch
from EventStream.data.pytorch_dataset import PytorchDataset

In [3]:
dataset_dir = Path(os.getcwd()) / "processed/sample"

First, let's check and see how much disk space the dataset uses, and in what components

In [4]:
total_dataset_size = sum(f.stat().st_size for f in dataset_dir.glob('**/*') if f.is_file())
DL_reps_size = sum(f.stat().st_size for f in (dataset_dir / "DL_reps").glob('**/*') if f.is_file())
just_dataset_size = total_dataset_size - DL_reps_size

if (dataset_dir / "flat_reps").is_dir():
    flat_reps_size = sum(f.stat().st_size for f in (dataset_dir / "flat_reps").glob('**/*') if f.is_file())
    just_dataset_size -= flat_reps_size
    flat_reps_lines = [f"  * {naturalsize(flat_reps_size)} for the flat representation dataframes."]
else:
    flat_reps_lines = []

lines = [
    f"The total dataset takes up {naturalsize(total_dataset_size)} on disk, which includes:",
    f"  * {naturalsize(just_dataset_size)} for the core dataset.",
    f"  * {naturalsize(DL_reps_size)} for the deep-learning representation dataframes.",
] + flat_reps_lines

print('\n'.join(lines))

The total dataset takes up 31.2 MB on disk, which includes:
  * 19.5 MB for the core dataset.
  * 11.7 MB for the deep-learning representation dataframes.


First, we'll note that loading a dataset doesn't require much of either resource. This is because the data is loaded lazily, so complex dataframe elements aren't loaded until they are needed. 

In [5]:
%%time
%%memit

ESD = Dataset.load(dataset_dir)

peak memory: 348.25 MiB, increment: 2.13 MiB
CPU times: user 123 ms, sys: 22.3 ms, total: 145 ms
Wall time: 258 ms


In [6]:
%%time
%%memit

s_df = ESD.subjects_df
e_df = ESD.events_df
dm_df = ESD.dynamic_measurements_df

Loading subjects from /home/mmd/Projects/EventStreamGPT/sample_data/processed/sample/subjects_df.parquet...
Loading events from /home/mmd/Projects/EventStreamGPT/sample_data/processed/sample/events_df.parquet...
Loading dynamic_measurements from /home/mmd/Projects/EventStreamGPT/sample_data/processed/sample/dynamic_measurements_df.parquet...
peak memory: 510.11 MiB, increment: 161.67 MiB
CPU times: user 270 ms, sys: 85.8 ms, total: 356 ms
Wall time: 312 ms


## Pytorch Dataset Stats
Now let's load a pytorch dataset and examine iteration speed and GPU memory cost:

In [7]:
def summarize(arr: list[float], strify: Callable[float, str] = naturalsize) -> str:
    mean, std, mn, mx = np.mean(arr), np.std(arr), np.min(arr), np.max(arr)
    simple_summ = f"{strify(mean)} ± {strify(std)} ({strify(mn)}-{strify(mx)})"
    
    if len(arr) < 25: return simple_summ
    
    hist_vals, hist_bins = np.histogram(arr)
    lines = [simple_summ, "Histogram:"]
    sparkline = sparklines(hist_vals)
    
    lines.extend(sparkline)
    left_end = strify(hist_bins[0])
    right_end = strify(hist_bins[1])
    W = len(sparkline[0]) - len(left_end) - len(right_end)
    
    if W > 0:
        lines.append(f"{left_end}{'-'*W}{right_end}")
    else:
        lines.append(f"o {left_end} (left endpoint)")
        lines.append(f"{'-'*(len(sparkline[0])-1)}o {right_end} (right endpoint)")
    return '\n'.join(lines)

def summarize_times(arr: list[float, timedelta]):
    as_seconds = [x / timedelta(seconds=1) for x in arr]
    return summarize(as_seconds, strify=lambda x: str(timedelta(seconds=x)))

In [8]:
def profile_batch_iteration_speed_and_cost(
    batch_size: int,
    pyd: Dataset,
    n_iter_samples: int = 30,
    collate_fn: Callable | None = None,
    num_workers: int | None = None,
):
    def make_dataloader():
        dataloader_kwargs = {'dataset': pyd, 'batch_size': batch_size, 'shuffle': True}
        if collate_fn is not None:
            dataloader_kwargs['collate_fn'] = collate_fn
        if num_workers is not None:
            dataloader_kwargs['num_workers'] = num_workers
        return DataLoader(**dataloader_kwargs)

    dataloader = make_dataloader()
    batch_sizes = defaultdict(list)
    total_sizes = []
    for batch in tqdm(dataloader, leave=False):
        total_size = 0
        for k, v in batch.items():
            if v is None: continue
            el_size = v.element_size() * v.nelement()
            batch_sizes[k].append(el_size)
            total_size += el_size
        total_sizes.append(total_size)

    batch_iteration_times = []
    for samp in tqdm(list(range(n_iter_samples)), leave=False, desc="Sampling Dataloader Iteration Speed"):
        dataloader = make_dataloader()
        st = datetime.now()
        for batch in tqdm(dataloader, leave=False, desc="Sampling Batch"):
            pass
        batch_iteration_times.append((datetime.now() - st) / len(dataloader))

    print(
        f"Iterating through an entire dataloader of {len(dataloader)} batches of size {batch_size} "
        f"took the following time per batch:\n{summarize_times(batch_iteration_times)}\n\n"
        f"Total batch size:\n{summarize(total_sizes)}"
    )
    for k, v in batch_sizes.items():
        print(f"  Size of {k}:\n    {summarize(v)}")

In [9]:
%%time
%%memit
pyd_config = PytorchDatasetConfig(
    save_dir=ESD.config.save_dir,
    max_seq_len=1024,
)
pyd = PytorchDataset(config=pyd_config, split='train')

peak memory: 831.51 MiB, increment: 320.94 MiB
CPU times: user 2.09 s, sys: 181 ms, total: 2.27 s
Wall time: 2.18 s


In [10]:
batch_size=16
profile_batch_iteration_speed_and_cost(
    batch_size=batch_size, pyd=pyd, n_iter_samples=30, collate_fn=pyd.collate,
    num_workers=None,
)

  0%|          | 0/5 [00:00<?, ?it/s]

Sampling Dataloader Iteration Speed:   0%|          | 0/30 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Iterating through an entire dataloader of 5 batches of size 16 took the following time per batch:
0:00:01.065503 ± 0:00:00.069628 (0:00:00.999924-0:00:01.327426)
Histogram:
▇█▄▂▂▁▁▂▁▂
o 0:00:00.999924 (left endpoint)
---------o 0:00:01.032674 (right endpoint)

Total batch size:
7.7 MB ± 958.5 kB (6.6 MB-9.4 MB)
  Size of event_mask:
    16.4 kB ± 0 Bytes (16.4 kB-16.4 kB)
  Size of time_delta:
    65.5 kB ± 0 Bytes (65.5 kB-65.5 kB)
  Size of static_indices:
    128 Bytes ± 0 Bytes (128 Bytes-128 Bytes)
  Size of static_measurement_indices:
    128 Bytes ± 0 Bytes (128 Bytes-128 Bytes)
  Size of dynamic_indices:
    2.9 MB ± 365.1 kB (2.5 MB-3.5 MB)
  Size of dynamic_measurement_indices:
    2.9 MB ± 365.1 kB (2.5 MB-3.5 MB)
  Size of dynamic_values:
    1.5 MB ± 182.6 kB (1.2 MB-1.8 MB)
  Size of dynamic_values_mask:
    363.7 kB ± 45.6 kB (311.3 kB-442.4 kB)


## Other Pipelines
### TemporAI Format
First, we'll compare against [TemporAI](https://www.temporai.vanderschaar-lab.com/), a recent library that provides a modular set of pre-built plugins for processing temporal EHR data. 

Their data representation differs from ESGPT's in that it is a _wide_ representation, vs. our _long_ representation at a per-event level. More specifically, whereas ESGPT data structures store the measurements observed per event in a nested list per event, TemporAI data structures pivot that structure and store all measurements (observed or unobserved) in separate columns for each event, stored as a row.

In [11]:
import pandas as pd
import polars as pl
import polars.selectors as cs

In [12]:
def ESD_to_temporai(ESD: Dataset) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Converts an ESD data format into a TemporAI dataset format."""

    static_df = (
        ESD.subjects_df
        .filter(pl.col('subject_id').is_in(list(ESD.split_subjects['train'])))
        .select(
            'subject_id',
            *[pl.col(c) for c, cfg in ESD.measurement_configs.items() if cfg.temporality == 'static']
        )
        .to_pandas()
        .set_index("subject_id")
    )
    
    # For the time-series dataframe, as they need only one row per subject ID, timestamp, we need to use the wide
    # format of the flat representation. 
    
    flat_reps_dir = ESD.config.save_dir / "flat_reps" / "raw"
    if not flat_reps_dir.is_dir():
        raise FileNotFoundError(f"Must have pre-cached flat representations at {flat_reps_dir}!")
        
    time_series_df = (
        pl.scan_parquet(flat_reps_dir / "train" / "*.parquet")
        .select("subject_id", "timestamp", cs.starts_with("dynamic"))
        .collect()
        .to_pandas()
        .set_index(["subject_id", "timestamp"])
    )
    
    return static_df, time_series_df

In [13]:
%%time
%%memit
# We need to convert to a flat format prior to getting temporai representations.
# The performance #s here are not reliable as these files may be already generated.
ESD.cache_flat_representation(
    subjects_per_output_file=None,
    feature_inclusion_frequency=None,
    do_overwrite=False,
    do_update=True,
)

Flattening Splits:   0%|          | 0/3 [00:00<?, ?it/s]

Subject chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Subject chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Subject chunks:   0%|          | 0/1 [00:00<?, ?it/s]

peak memory: 1993.54 MiB, increment: 1174.56 MiB
CPU times: user 21.8 s, sys: 3.39 s, total: 25.2 s
Wall time: 7.96 s


In [14]:
%%time
%%memit

temporai_static, temporai_ts = ESD_to_temporai(ESD)

peak memory: 2063.73 MiB, increment: 965.48 MiB
CPU times: user 2.2 s, sys: 1.81 s, total: 4.02 s
Wall time: 2.15 s


In [15]:
print(
    f"TemporAI uses two dataframes, a static dataframe of shape {temporai_static.shape} "
    f"and a time series dataframe of shape {temporai_ts.shape}."
)

TemporAI uses two dataframes, a static dataframe of shape (100, 1) and a time series dataframe of shape (530742, 160).


Let's save these dataframes to disk, so we can inspect their disk cost and the memory cost to re-load them from scratch.

In [16]:
save_dir = Path("./speed_comparisons/temporai/compressed")
save_dir.mkdir(parents=True, exist_ok=True)

temporai_static.to_parquet(save_dir / "static.parquet")
temporai_ts.to_parquet(save_dir / "ts.parquet")

uncompressed_save_dir = Path("./speed_comparisons/temporai/uncompressed")
uncompressed_save_dir.mkdir(parents=True, exist_ok=True)

temporai_static.to_parquet(uncompressed_save_dir / "static.parquet", compression=None)
temporai_ts.to_parquet(uncompressed_save_dir / "ts.parquet", compression=None)

compressed_temporai_size = sum(f.stat().st_size for f in save_dir.glob('**/*') if f.is_file())
uncompressed_temporai_size = sum(f.stat().st_size for f in uncompressed_save_dir.glob('**/*') if f.is_file())

print(
    f"The compressed data takes up {naturalsize(compressed_temporai_size)} on disk.\n"
    f"The uncompressed data takes up {naturalsize(uncompressed_temporai_size)} on disk "
    "(this is a good approximation of memory cost as it is uncompressed)."
)

The compressed data takes up 23.9 MB on disk.
The uncompressed data takes up 26.0 MB on disk (this is a good approximation of memory cost as it is uncompressed).


In [17]:
%%time
%%memit

temporai_static = pd.read_parquet(save_dir / "static.parquet")
temporai_ts = pd.read_parquet(save_dir / "ts.parquet")

peak memory: 2603.92 MiB, increment: 916.96 MiB
CPU times: user 1.21 s, sys: 806 ms, total: 2.01 s
Wall time: 1.08 s


TemporAI generally converts their timeseries data into a dense, 3D matrix across samples, timepoints, and features. For use in ML pipelines, this is then generally iterated through directly via simple numpy iteration. 

For example: 
  * Datasets are converted to 3D views here: https://github.com/vanderschaarlab/temporai/blob/main/src/tempor/plugins/prediction/one_off/classification/__init__.py#L59 and https://github.com/vanderschaarlab/temporai/blob/67ebd74dc24728163d9aec37f1771a83fc3346e2/src/tempor/data/utils.py#L49
  * Iteration through numpy arrays happens here: https://github.com/vanderschaarlab/temporai/blob/main/src/tempor/models/ddh.py#L155
  
Though a full comparison warrants use of their library (and will further depend on the exact model used (as each has different strategies for processing data), we can simulate that approach here quickly:

In [18]:
def no_categories(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    for c in df.columns:
        if pd.api.types.is_categorical_dtype(df[c]):
            df[c] = df[c].cat.codes
    return df

def to_3D_arr(df: pd.DataFrame, max_timesteps: int | None = None) -> np.ndarray:
    df = no_categories(df)
    samples = set(df.index.get_level_values(0))
    num_samples = len(samples)
    num_features = len(df.columns)
    num_timesteps_per_sample = df.groupby(level=0).size()
    max_actual_timesteps = num_timesteps_per_sample.max()
    max_timesteps = max_actual_timesteps if max_timesteps is None else max_timesteps
    array = np.full(shape=(num_samples, max_timesteps, num_features), fill_value=np.NaN)
    for i_sample, idx_sample in enumerate(samples):
        set_vals = df.loc[idx_sample, :, :].to_numpy()[:max_timesteps, :]  # pyright: ignore
        if i_sample == 0:
            array = array.astype(set_vals.dtype)  # Need to cast to the type matching source data.
        array[i_sample, : num_timesteps_per_sample[idx_sample], :] = set_vals  # pyright: ignore
    return array

In [19]:
class SimpleTemporAIStyleDataset(Dataset):
    def __init__(self, static: np.ndarray, ts: np.ndarray):
        self.static = static
        self.ts = ts
        
    def __len__(self) -> int: return self.ts.shape[0]
    
    def __getitem__(self, idx) -> dict[str, torch.Tensor]:
        return {'static': torch.Tensor(self.static[idx]), 'ts': torch.Tensor(self.ts[idx])}
    
def profile_temporai_dataset(
    temporai_static, temporai_ts, batch_size: int = 16,
    n_iter_samples: int = 30,
    max_seq_len: int = 32,
):
    static_as_np = np.nan_to_num(no_categories(temporai_static).to_numpy(), nan=0)
    ts_as_np = np.nan_to_num(to_3D_arr(temporai_ts, max_timesteps=max_seq_len), nan=0)
    print(
        f"Yielded a static NP array of shape {static_as_np.shape} and a TS NP array "
        f"of shape {ts_as_np.shape}."
    )
    temporai_pyd = SimpleTemporAIStyleDataset(static_as_np, ts_as_np)

    profile_batch_iteration_speed_and_cost(
        batch_size=batch_size, pyd=temporai_pyd, n_iter_samples=n_iter_samples
    )

In [20]:
%%time
%%memit

profile_temporai_dataset(temporai_static, temporai_ts, batch_size=16, n_iter_samples=30, max_seq_len=1024)

Yielded a static NP array of shape (100, 1) and a TS NP array of shape (100, 1024, 160).


  0%|          | 0/7 [00:00<?, ?it/s]

Sampling Dataloader Iteration Speed:   0%|          | 0/30 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/7 [00:00<?, ?it/s]

Iterating through an entire dataloader of 7 batches of size 16 took the following time per batch:
0:00:00.008704 ± 0:00:00.002816 (0:00:00.006541-0:00:00.022163)
Histogram:
█▅▁▁▁▁▁▁▁▁
o 0:00:00.006541 (left endpoint)
---------o 0:00:00.008103 (right endpoint)

Total batch size:
9.4 MB ± 2.8 MB (2.6 MB-10.5 MB)
  Size of static:
    57 Bytes ± 16 Bytes (16 Bytes-64 Bytes)
  Size of ts:
    9.4 MB ± 2.8 MB (2.6 MB-10.5 MB)
peak memory: 2325.00 MiB, increment: 273.90 MiB
CPU times: user 8.6 s, sys: 822 ms, total: 9.42 s
Wall time: 3.81 s


As we can see, the strategy of featurizing and batching used in TemporAI results (on this synthetic dataset) in a significantly faster iteration speed and a marginally lower memory cost than does the strategy used in ESGPT (all formats are mean ± standard deviation (min - max)

TemporAI Speed: `0:00:00.007319 ± 0:00:00.001439 (0:00:00.006224-0:00:00.012835)`  
ESGPT Speed:    `0:00:00.734263 ± 0:00:00.078552 (0:00:00.586128-0:00:00.871448)`

TemporAI Memory: `9.4 MB ± 2.8 MB (2.6 MB-10.5 MB)`  
ESGPT Memory:    `8.3 MB ± 1.1 MB (6.6 MB-9.7 MB)`

In table form (using chatGPT for conversions, so may need to be double checked), where "Delta" means what % of TemporAI's resource cost does ESGPT _save_ (higher is better), we get the following:
|                      | TemporAI              | ESGPT                 | Delta (%)  |
|----------------------|-----------------------|-----------------------|------------|
| **Iteration time / batch (ms)** | 7.32 ± 1.44 (6.22 - 12.8) | 734 ± 78.6 (586 - 871) | -9943%           |
| **Memory (MB)**      | 9.4 ± 2.8 (2.6 - 10.5) | 8.3 ± 1.1 (6.6 - 9.7) | 11.7%             |

There are some biases in this format, on both sides:
  1. ESGPT samples different subsequences per item iteration, whereas TemporAI is limited to only using the first max subsequence samples. 
  2. This dataset has relatively few measurements, which will reduce the memory disparity between the two formats (this bias favors TemporAI).
  3. The strategy of flattening this dataset may induce too much memory overhead, as if multiple measurements are not common within an event, it will have extra columns that TemporAI does not need. Conversely, it may reduce a significant amount of data, as if there are many measurements than a simple count, sum, sum_sqd, min, and max representation will not fully capture the data, thereby reducing the burden on TemporAI. (This bias could favor either).
  
#### On MIMIC-IV
Ultimately, these numbers will only be truly reasonable when compared on real data. To do so, we can use MIMIC-IV. While the full numbers for the MIMIC-IV dataset can be found in that example's repository, here we summarize the table obtained when running on that dataset with an otherwise nearly identical setup:

|                      | TemporAI             | ESGPT                | Improvement (%)  |
|----------------------|----------------------|----------------------|-------------------|
| **Iteration time / batch (ms)**       | 597 ± 20.6 (571 - 624) | 416 ± 4.16 (410 - 423) | 30.3%            |
| **Memory (MB)**      | 877 ± 95.5 (507 - 902) | 67.2 ± 35.0 (18.3 - 411) | 92.3%           |

We can see that the numbers here are very different, favoring ESGPT dramatically more than we saw on sample data. The reasons for this are primarily due to how memory intensive TemporAI's coding system is; whereas ESGPT only stores data-elements in a batch if those elements were actually observed in the record, TemporAI's flat encoding means that batches must have a memory cost that scales with the total vocabulary size (and overall maximum sequence length) of the data. This higher cost impacts both speed and memory, though memory is the clear focus for ESGPT.

### omop-learn

Next, we'll attempt to compare against [omop-Learn](https://github.com/clinicalml/omop-learn/tree/master). omop-learn uses a similar storage model to ESGPT (a _long_ format) but it does not include the storage of numerical values, only codes, and it similarly does not allow for randomly sampled sub-sequences per dataset item. See [here](https://github.com/clinicalml/omop-learn/blob/a33440af2b9f1342e0c16106acf93131b9369441/src/omop_learn/torch/data.py) for more details. To simulate this, we'll first convert the ESGPT data into a pre-tokenized OMOP-learn representation, store in the desired format (JSON), then re-load it and iterate through it in the same manner as omop-learn does. As static data are not separately encoded in the omop-learn dataset format, we'll omit that as well here.

In [88]:
from typing import Any
import json, shutil
REFTIME = pd.Timestamp("1900-01-01")

def to_unixtime(str_time_array):
    unix_times = (pd.to_datetime(str_time_array) - REFTIME) // pd.Timedelta("1d")
    return unix_times


def from_unixtime(unix_time_array):
    datetimes = [pd.to_datetime(t * pd.Timedelta("1d") + REFTIME) for t in unix_time_array]
    return datetimes

def ESD_row_to_omoplearn_row(row: dict[str, Any], unified_vocab: dict[int, str]) -> dict[str, Any]:
    example = {}
    
    # Dates
    example['dates'] = [str(row['start_time'] + timedelta(minutes=t)) for t in row['time']]
    
    # Visits
    example['visits'] = [
        [unified_vocab[idx] for idx in indices] for indices in row['dynamic_indices']
    ]
    
    example['tok_visits'] = row['dynamic_indices']
    
    assert len(example['dates']) == len(example['visits'])
    assert len(example['dates']) == len(example['tok_visits'])
    
    example['y'] = 0
    
    return example

def ESD_to_omoplearn(ESD: Dataset) -> list[dict[str, Any]]:
    unified_vocab = {}
    for m, idxmap in ESD.unified_vocabulary_idxmap.items():
        for k, v in idxmap.items():
            unified_vocab[v] = f"{m}/{k}"
    
    omop_file_dir = Path("./omop_reps")
    if omop_file_dir.is_dir():
        shutil.rmtree(omop_file_dir)

    omop_file_dir.mkdir(exist_ok=True, parents=True)
    
    for sp in tqdm(list(ESD.split_subjects.keys()), desc="JSONifying Splits", leave=False):
        omop_sp_dir = omop_file_dir / sp
        omop_sp_dir.mkdir(exist_ok=True, parents=True)
        for fp in tqdm(
            list((ESD.config.save_dir / "DL_reps").glob(f"{sp}_*.parquet")), desc="File", leave=False
        ):
            DL_reps_df = pl.read_parquet(fp)
            cols = DL_reps_df.columns
            rows = DL_reps_df.rows()
            
            for row in rows:
                row_as_dict = {col: val for col, val in zip(cols, row)}
                example = ESD_row_to_omoplearn_row(row_as_dict, unified_vocab)
                with open(omop_sp_dir / "data.json", "a") as f:
                    json.dump(example, f)
                    f.write('\n')

In [89]:
%%time
%%memit

ESD_to_omoplearn(ESD)

JSONifying Splits:   0%|          | 0/3 [00:00<?, ?it/s]

File:   0%|          | 0/1 [00:00<?, ?it/s]

File:   0%|          | 0/1 [00:00<?, ?it/s]

File:   0%|          | 0/1 [00:00<?, ?it/s]

peak memory: 2246.19 MiB, increment: 42.89 MiB
CPU times: user 7.5 s, sys: 477 ms, total: 7.98 s
Wall time: 7.78 s


In [73]:
JSON_dataset_size = sum(f.stat().st_size for f in Path("./omop_reps").glob('**/*') if f.is_file())
print(f"The JSON dumped dataset takes up {naturalsize(JSON_dataset_size)} on disk.")

The JSON dumped dataset takes up 66.2 MB on disk.


In [102]:
# Lightly adapted from https://github.com/clinicalml/omop-learn/blob/a33440af2b9f1342e0c16106acf93131b9369441/src/omop_learn/torch/data.py

from collections import Counter

class OMOPDatasetTorch(torch.utils.data.Dataset):
    def __init__(
        self,
        omop_dataset_file,
        max_num_visits=None,
    ):
        super().__init__()
        self.items = {}
        self.visit_sequences = []  # patient x (# visits for patient) lists w/ concepts expressed
        self.time_sequences = []  # patient x (# visits for patient) times of visits
        self.visit_sizes = []  # patient x (# visits for patient) # concepts in each visit
        self.outcomes = []  # patient--outcome for each patient
        self.tok_visit_sequences = None
        self.tokenizer = None
        self.max_num_visits = max_num_visits  # if set, truncate to most recent
        self._load_json(omop_dataset_file, False)


    def _load_json(self, path_to_json, tokenize_on_load):
        # read once to build concept set
        # (and load visits if tokenize_on_load=False)
        concept_set = set()
        concept_counts = Counter()
        concept_counts_by_year = Counter()
        years = set()
        max_num_visits = 0
        skipped = 0
        with open(path_to_json) as json_fh:
            for i, line in enumerate(json_fh.readlines()):
                example = self._process_line(line)
                max_num_visits = max(max_num_visits, len(example['visits']))

                for time, visit in zip(example['unix_times'], example['visits']):
                    for concept in visit:
                        concept_set.add(concept)

                if len(example['visits']) == 0:
                    skipped += 1
                    continue

                if i == 0:
                    for key, value in example.items():
                        self.items[key] = [value]
                else:
                    # correctly gives error when key is not found
                    # already; all items need to have exactly the same
                    # set of keys.
                    for key,value in example.items():
                        self.items[key].append(value)

        print(f"Skipped {skipped} patients for empty visit lists")
        if not self.max_num_visits:
            self.max_num_visits = max_num_visits

#         if not self.tokenizer:
#             self.tokenizer = ConceptTokenizer(concept_set)
#             print("built tokenizer")

#         # read again to build tokenized visits
#         if tokenize_on_load:
#             self.items['tok_visits'] = []
#             with open(path_to_json, "r") as json_fh:
#                 for i, line in enumerate(json_fh.readlines()):
#                     example = self._process_line(line)
#                     tok_visit_list = []
#                     for visit in example['visits']:
#                         tok_visit = self.tokenizer.concepts_to_ids(visit)
#                         tok_visit_list.append(tok_visit)
#                     if len(tok_visit_list) > 0:
#                         self.items['tok_visits'].append(tok_visit_list)

        self.outcomes = torch.LongTensor(self.items['y'])
        self.one_fraction = self.outcomes.sum() / len(self.outcomes)
        self.one_odds = self.one_fraction / (1 - self.one_fraction)

    def _process_line(self, line):
        example = json.loads(line)
        dates = example['dates']
        unix_times = to_unixtime(dates)
        example['unix_times'] = unix_times

        # make sure visits are sorted by date
        # This actually contains a minor error correction in omop-learn; namely, that pre-tokenized visits may not
        # have been properly sorted.
        sorted_visits = [v for d,v in sorted(zip(example['unix_times'], example['visits']))]
        if 'tok_visits' in example:
            sorted_tok_visits = [t for d,t in sorted(zip(example['unix_times'], example['tok_visits']))]
            example['tok_visits'] = sorted_tok_visits
        example['visits'] = sorted_visits
        example['unix_times'] = sorted(example['unix_times'])
        example['dates'] = sorted(example['dates'])
        
        assert len(example['visits']) == len(example['unix_times'])
        assert len(example['tok_visits']) == len(example['unix_times'])

        return example

    def __getitem__(self, idx):
        example = {k : v[idx] for k,v in self.items.items()}
        times = torch.LongTensor(example['unix_times'])
        visits = example['visits'] if 'tok_visits' not in example else example['tok_visits']
        assert len(times) == len(visits)

        # trim before tokenizing
        visits = visits[-self.max_num_visits :]
        times = times[-self.max_num_visits :]

        if 'tok_visits' not in example:
            raise NotImplementedError(f"Must be pre-tokenized!")
#             tok_visits = []
#             for visit in example['visits']:
#                 tok_visits.append(self.tokenizer.concepts_to_ids(visit))
#             visits = tok_visits

        example['visits'] = visits
        # Another small bug correction
        example['unix_times'] = times

        visit_sizes = torch.LongTensor([len(v) for v in visits])
        outcome = self.outcomes[idx]
        nvisits = len(visits)

        return example

    # pads a batch to largest # of visits / patient in the batch
    # and largest # of concepts / visit along concept dim.
    def collate(self, batch):
        # first group along dict keys
        batch_collated = {}
        for k in batch[0].keys():
            batch_collated[k] = [b[k] for b in batch]

        keys = list(batch_collated.keys())
        N = len(batch_collated['y'])

        # each patient is a list of visits
        max_num_visits = max([len(v) for v in batch_collated['visits']])
        max_num_concepts = max(l for p in range(N) for l in [len(v) for v in batch_collated['visits'][p]])

        concept_tensor = torch.full(
            (N, max_num_visits, max_num_concepts),
            0,
            dtype=torch.long,
        )

        times_tensor = torch.full((N, max_num_visits), -1, dtype=torch.long)
        batch = batch_collated

        lengths = torch.zeros(N)
        for i, visit_list in enumerate(batch['visits']):
            assert len(visit_list) == len(batch['unix_times'][i]), (
                f"Visits don't match! Got {len(visit_list)} visits and {len(batch['unix_times'][i])} times "
                f"for batch element {i}"
            )
            num_visits = len(visit_list)  # visits of this patient we are including
            lengths[i] = num_visits
            for j, visit in enumerate(visit_list):
                visit_size = len(batch['visits'][i][j])
                assert(visit_size == len(visit))
                concept_tensor[i, j, : visit_size] = torch.Tensor(visit)
            times_tensor[i, :num_visits] = torch.Tensor(batch['unix_times'][i])
        batch.pop("unix_times")
        batch.pop("dates")
        # Another small correction:
        batch.pop("tok_visits")
        batch["visits"] = concept_tensor
        batch["times"] = times_tensor
        batch["lengths"] = lengths
        for k,v in batch.items():
            if not isinstance(v, torch.Tensor):
                batch[k] = torch.tensor(v)
        return batch

    def __len__(self):
        return len(self.outcomes)

In [103]:
%%time
%%memit
ODT = OMOPDatasetTorch("./omop_reps/train/data.json", max_num_visits=1024)

Skipped 0 patients for empty visit lists
peak memory: 2218.59 MiB, increment: 123.47 MiB
CPU times: user 3.37 s, sys: 425 ms, total: 3.8 s
Wall time: 3.95 s


In [104]:
%%time
%%memit

profile_batch_iteration_speed_and_cost(
    batch_size=16, pyd=ODT, n_iter_samples=30, collate_fn=ODT.collate
)

  0%|          | 0/5 [00:00<?, ?it/s]

Sampling Dataloader Iteration Speed:   0%|          | 0/30 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Iterating through an entire dataloader of 5 batches of size 16 took the following time per batch:
0:00:00.364018 ± 0:00:00.049781 (0:00:00.190901-0:00:00.470015)
Histogram:
▂▁▁▁▁▂█▂▁▁
o 0:00:00.190901 (left endpoint)
---------o 0:00:00.218812 (right endpoint)

Total batch size:
3.5 MB ± 365.1 kB (3.0 MB-3.9 MB)
  Size of visits:
    3.4 MB ± 365.1 kB (2.9 MB-3.8 MB)
  Size of y:
    128 Bytes ± 0 Bytes (128 Bytes-128 Bytes)
  Size of times:
    131.1 kB ± 0 Bytes (131.1 kB-131.1 kB)
  Size of lengths:
    64 Bytes ± 0 Bytes (64 Bytes-64 Bytes)
peak memory: 2219.23 MiB, increment: 0.63 MiB
CPU times: user 2min 27s, sys: 1.35 s, total: 2min 29s
Wall time: 57 s


We can't do a true apples to apples comparison in terms of speed, here, as there is less data being processed in the OMOP-learn setting, but in terms of Memory we can break down the cost for just the dynamic indices components of the ESGPT batch, and find that it equates to roughly 3MB per batch, just under what OMOP-learn requires. However, we do need to test this on real data to see the true comparison. Additionally, note that this is testing only the pre-tokenized version of OMOP-learn, which is faster than the on-the-fly tokenizing mode.