# Speed, Memory, and Disk Comparisons

In this notebook, we'll offer some rough comparisons of the computational performance implications of ESGPT vs. other competing pipelines. We'll focus these comparisons on several metrics:
  1. The time, runtime memory, and final disk space required to construct, pre-process, and store an ESGPT dataset relative to other pipelines, where applicable.
  2. The initialization time, iteration speed, and GPU memory costs for producing batches of data within the ESGPT framework vs. other systems.
  
In particular, we'll compare (or justify why they are inappropriate comparators) against the following pipelines:
  1. TemporAI
  2. OMOP-Learn
  3. FIDDLE
  4. MIMIC-Extract
  
We'll make these comparisons leveraging the synthetic data distributed with ESGPT's sample tutorial, but this code can also be ported to any other dataset to run these profiles locally.

In [1]:
%load_ext memory_profiler

import sys
sys.path.append('..')

In [24]:
import os
import numpy as np

from collections import defaultdict
from datetime import datetime, timedelta
from humanize import naturalsize, naturaldelta
from pathlib import Path
from sparklines import sparklines
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from typing import Callable

from EventStream.data.dataset_polars import Dataset
from EventStream.data.config import PytorchDatasetConfig
from EventStream.data.types import PytorchBatch
from EventStream.data.pytorch_dataset import PytorchDataset

In [3]:
dataset_dir = Path(os.getcwd()) / "processed/sample"

First, let's check and see how much disk space the dataset uses, and in what components

In [4]:
total_dataset_size = sum(f.stat().st_size for f in dataset_dir.glob('**/*') if f.is_file())
DL_reps_size = sum(f.stat().st_size for f in (dataset_dir / "DL_reps").glob('**/*') if f.is_file())
just_dataset_size = total_dataset_size - DL_reps_size

if (dataset_dir / "flat_reps").is_dir():
    flat_reps_size = sum(f.stat().st_size for f in (dataset_dir / "flat_reps").glob('**/*') if f.is_file())
    just_dataset_size -= flat_reps_size
    flat_reps_lines = [f"  * {naturalsize(flat_reps_size)} for the flat representation dataframes."]
else:
    flat_reps_lines = []

lines = [
    f"The total dataset takes up {naturalsize(total_dataset_size)} on disk, which includes:",
    f"  * {naturalsize(just_dataset_size)} for the core dataset.",
    f"  * {naturalsize(DL_reps_size)} for the deep-learning representation dataframes.",
] + flat_reps_lines

print('\n'.join(lines))

The total dataset takes up 31.2 MB on disk, which includes:
  * 19.5 MB for the core dataset.
  * 11.7 MB for the deep-learning representation dataframes.


First, we'll note that loading a dataset doesn't require much of either resource. This is because the data is loaded lazily, so complex dataframe elements aren't loaded until they are needed. 

In [5]:
%%time
%%memit

ESD = Dataset.load(dataset_dir)

peak memory: 346.61 MiB, increment: 1.75 MiB
CPU times: user 133 ms, sys: 32.2 ms, total: 165 ms
Wall time: 277 ms


In [6]:
%%time
%%memit

s_df = ESD.subjects_df
e_df = ESD.events_df
dm_df = ESD.dynamic_measurements_df

Loading subjects from /home/mmd/Projects/EventStreamGPT/sample_data/processed/sample/subjects_df.parquet...
Loading events from /home/mmd/Projects/EventStreamGPT/sample_data/processed/sample/events_df.parquet...
Loading dynamic_measurements from /home/mmd/Projects/EventStreamGPT/sample_data/processed/sample/dynamic_measurements_df.parquet...
peak memory: 505.35 MiB, increment: 158.61 MiB
CPU times: user 348 ms, sys: 123 ms, total: 472 ms
Wall time: 325 ms


## Pytorch Dataset Stats
Now let's load a pytorch dataset and examine iteration speed and GPU memory cost:

In [30]:
def summarize(arr: list[float], strify: Callable[float, str] = naturalsize) -> str:
    mean, std, mn, mx = np.mean(arr), np.std(arr), np.min(arr), np.max(arr)
    simple_summ = f"{strify(mean)} ± {strify(std)} ({strify(mn)}-{strify(mx)})"
    
    if len(arr) < 25: return simple_summ
    
    hist_vals, hist_bins = np.histogram(arr)
    lines = [simple_summ, "Histogram:"]
    sparkline = sparklines(hist_vals)
    
    lines.extend(sparkline)
    left_end = strify(hist_bins[0])
    right_end = strify(hist_bins[1])
    W = len(sparkline[0]) - len(left_end) - len(right_end)
    
    if W > 0:
        lines.append(f"{left_end}{'-'*W}{right_end}")
    else:
        lines.append(f"o {left_end} (left endpoint)")
        lines.append(f"{'-'*(len(sparkline[0])-1)}o {right_end} (right endpoint)")
    return '\n'.join(lines)

def summarize_times(arr: list[float, timedelta]):
    as_seconds = [x / timedelta(seconds=1) for x in arr]
    return summarize(as_seconds, strify=lambda x: str(timedelta(seconds=x)))

In [7]:
%%time
%%memit
pyd_config = PytorchDatasetConfig(
    save_dir=ESD.config.save_dir,
    max_seq_len=32,
)
pyd = PytorchDataset(config=pyd_config, split='train')

peak memory: 817.68 MiB, increment: 312.33 MiB
CPU times: user 2.04 s, sys: 148 ms, total: 2.18 s
Wall time: 2.02 s


In [31]:
%%time
%%memit

batch_size = 16
n_iter_samples = 30

dataloader = DataLoader(pyd, collate_fn=pyd.collate, batch_size=batch_size)

batch_sizes = defaultdict(list)
total_sizes = []
for batch in tqdm(dataloader, leave=False):
    total_size = 0
    for k, v in batch.items():
        if v is None: continue
        el_size = v.element_size() * v.nelement()
        batch_sizes[k].append(el_size)
        total_size += el_size
    total_sizes.append(total_size)

batch_iteration_times = []
for samp in tqdm(list(range(n_iter_samples)), leave=False, desc="Sampling Dataloader Iteration Speed"):
    dataloader = DataLoader(pyd, collate_fn=pyd.collate, batch_size=batch_size, shuffle=True)
    st = datetime.now()
    for batch in tqdm(dataloader, leave=False, desc="Sampling Batch"):
        pass
    batch_iteration_times.append((datetime.now() - st) / len(dataloader))
    
print(
    f"Iterating through an entire dataloader of {len(dataloader)} batches of size {batch_size} "
    f"took the following time per batch:\n{summarize_times(batch_iteration_times)}\n\n"
    f"Total batch size:\n{summarize(total_sizes)}"
)
for k, v in batch_sizes.items():
    print(f"  Size of {k}:\n    {summarize(v)}")

  0%|          | 0/5 [00:00<?, ?it/s]

Sampling Dataloader Iteration Speed:   0%|          | 0/30 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Sampling Batch:   0%|          | 0/5 [00:00<?, ?it/s]

Iterating through an entire dataloader of 5 batches of size 16 took the following time per batch:
0:00:00.029285 ± 0:00:00.001813 (0:00:00.027680-0:00:00.035112)
Histogram:
▇█▅▁▂▂▁▁▁▂
o 0:00:00.027680 (left endpoint)
---------o 0:00:00.028423 (right endpoint)

Total batch size:
170.5 kB ± 52.1 kB (110.3 kB-239.4 kB)
  Size of event_mask:
    512 Bytes ± 0 Bytes (512 Bytes-512 Bytes)
  Size of time_delta:
    2.0 kB ± 0 Bytes (2.0 kB-2.0 kB)
  Size of static_indices:
    128 Bytes ± 0 Bytes (128 Bytes-128 Bytes)
  Size of static_measurement_indices:
    128 Bytes ± 0 Bytes (128 Bytes-128 Bytes)
  Size of dynamic_indices:
    63.9 kB ± 19.8 kB (41.0 kB-90.1 kB)
  Size of dynamic_measurement_indices:
    63.9 kB ± 19.8 kB (41.0 kB-90.1 kB)
  Size of dynamic_values:
    31.9 kB ± 9.9 kB (20.5 kB-45.1 kB)
  Size of dynamic_values_mask:
    8.0 kB ± 2.5 kB (5.1 kB-11.3 kB)
peak memory: 824.50 MiB, increment: 3.96 MiB
CPU times: user 5.72 s, sys: 240 ms, total: 5.96 s
Wall time: 5.05 s


## Other Pipelines
### TemporAI Format

In [32]:
import pandas as pd

In [35]:
e_df

event_id,subject_id,timestamp,event_type,age,age_is_inlier
u32,u8,datetime[μs],cat,f64,bool
0,0,2010-06-24 13:23:00,"""ADMISSION&VITA…",-0.558276,true
1,0,2010-06-24 14:23:00,"""VITAL&LAB""",-0.55825,true
2,0,2010-06-24 15:23:00,"""VITAL&LAB""",-0.558224,true
3,0,2010-06-24 16:23:00,"""VITAL&LAB""",-0.558199,true
4,0,2010-06-24 17:23:00,"""VITAL&LAB""",-0.558173,true
5,0,2010-06-24 18:23:00,"""VITAL&LAB""",-0.558148,true
6,0,2010-06-24 19:23:00,"""VITAL&LAB""",-0.558122,true
7,0,2010-06-24 20:23:00,"""VITAL&LAB""",-0.558097,true
8,0,2010-06-24 21:23:00,"""LAB""",-0.558071,true
9,0,2010-06-24 22:23:00,"""LAB""",-0.558045,true


In [None]:
def ESD_to_temporai(ESD: Dataset) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Converts an ESD data format into a TemporAI dataset format."""

    static_df = ESD.subjects_df.select(
        'subject_id', *[pl.col(c) for c, cfg in ESD.measurement_configs.items() if cfg.temporality == 'static']
    ).to_pandas()
    static_df.set_index('subject_id', inplace=True)
    
    time_series_df = (
        ESD.events_df
        .select(
            'subject_id', 'event_id', 'timestamp', *[
                pl.col(c) for c, cfg in ESD.measurement_configs.items()
                if cfg.temporality == 'functional_time_dependent'
            ]
        )
        .join(
            (
                ESD.dynamic_measurements_df
                .select(
                    'event_id', *[
                        pl.col(c) for c, cfg in ESD.measurement_configs.items()
                        if cfg.temporality == 'dynamic'
                    ], *[
                        pl.col(cfg.values_column) for _, cfg in ESD.measurement_configs.items()
                        if cfg.temporality == 'dynamic' and cfg.modality == 'multivariate_regression'
                    ]
                )
            ),
            on=['event_id'],
            how='inner'
        )
        .drop('event_id')
        .groupby('subject_id', 'timestamp')
        .agg(pl.all().count().map_alias(lambda c: f"{c}/))
    )