# Synthetic Data Generation

This notebook generates some simple synthetic data for us to use to demonstrate the ESGPT pipeline. We'll generate a few files:
  1. `subjects.csv`, which contains static data about each subject.
  2. `admission_vitals.csv`, which contains records of admissions, transfers, and vitals signs.
  3. `lab_tests.csv`, which contains records of lab test measurements.
  
This is all synthetic data designed solely for demonstrating this pipeline. It is *not* real data, derived from real data, or designed to mimic real data in any way other than plausible file structure.

In [1]:
import random
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import polars as pl

random.seed(1)
np.random.seed(1)

In [2]:
# Parameters:
N_subjects = 100
OUT_DIR = Path("./raw")

## Subjects Data
Subjects will have the following static data elements, and will be organized by a fake identifier column called "MRN":
  * Date of birth
  * Eye Color (among options "BROWN", "BLUE", "HAZEL", "GREEN", "OTHER")
  * Height (in cm, as a raw number)

In [3]:
random.seed(1)
np.random.seed(1)

BASE_BIRTH_DATE = datetime(1980, 1, 1)
EYE_COLORS = ["BROWN", "BLUE", "HAZEL", "GREEN", "OTHER"]
EYE_COLOR_P = [0.45, 0.27, 0.18, 0.09, 0.01]


def yrs_to_dob(yrs: np.ndarray) -> list[str]:
    return [(BASE_BIRTH_DATE + timedelta(days=365 * x)).strftime("%m/%d/%Y") for x in yrs]


size = (N_subjects,)
subject_data = pl.DataFrame(
    {
        "MRN": np.random.randint(low=14221, high=1578208, size=size),
        "dob": yrs_to_dob(np.random.uniform(low=-10, high=10, size=size)),
        "eye_color": list(np.random.choice(EYE_COLORS, size=size, replace=True, p=EYE_COLOR_P)),
        "height": list(np.random.uniform(low=152.4, high=182.88, size=size)),
    }
).sample(fraction=1, with_replacement=False, shuffle=True, seed=1)

assert len(subject_data["MRN"].unique()) == N_subjects

subject_data.write_csv(OUT_DIR / "subjects.csv")
subject_data.head(3)

MRN,dob,eye_color,height
i64,str,str,f64
310243,"""07/28/1981""","""GREEN""",178.767932
384198,"""04/15/1985""","""BROWN""",168.319295
520533,"""04/15/1979""","""BROWN""",165.836447


## Admission Vitals Data
This file will contain records of admission start and end dates, admission department (among options "PULMONARY", "CARDIAC", and "ORTHOPEDIC"), and regular vitals signs (temperature and HR). In this format, admission dates will be replicated across each associated vitals signs measurement, which is wasteful. Real data would not likely be organized like this, but it gives us a more complex file format to work with in our example.

In [4]:
random.seed(1)
np.random.seed(1)

admit_vitals_data = {
    "MRN": [],
    "admit_date": [],
    "disch_date": [],
    "admit_department": [],
    "disch_department": [],
    "vitals_date": [],
    "HR": [],
    "temp": [],
    "weight": [],
}

BASE_ADMIT_DATE = datetime(2010, 1, 1)

hrs = 60
days = 24 * hrs
months = 30 * days

n_admissions_L = np.random.randint(low=1, high=4, size=size)
admit_depts_L = np.random.choice(["PULMONARY", "CARDIAC", "ORTHOPEDIC"], size=size, replace=True)
disch_depts_L = np.random.choice(["PULMONARY", "CARDIAC", "ORTHOPEDIC"], size=size, replace=True)

admissions_by_subject = {}

for MRN, n_admissions, admit_dept, disch_dept in zip(subject_data["MRN"], n_admissions_L, admit_depts_L, disch_depts_L):
    admit_gaps = np.random.uniform(low=1 * days, high=6 * months, size=(n_admissions,))
    admit_lens = np.random.uniform(low=12 * hrs, high=14 * days, size=(n_admissions,))

    running_end = BASE_ADMIT_DATE
    admissions_by_subject[MRN] = []

    for gap, L in zip(admit_gaps, admit_lens):
        running_start = running_end + timedelta(minutes=gap)
        running_end = running_start + timedelta(minutes=L)

        admissions_by_subject[MRN].append((running_start, running_end))

        vitals_time = running_start

        running_weight = np.random.uniform(low=120, high=200)
        running_HR = np.random.uniform(low=60, high=180)
        running_temp = np.random.uniform(low=95, high=101)
        while vitals_time < running_end:
            admit_vitals_data["MRN"].append(MRN)
            admit_vitals_data["admit_date"].append(running_start.strftime("%m/%d/%Y, %H:%M:%S"))
            admit_vitals_data["disch_date"].append(running_end.strftime("%m/%d/%Y, %H:%M:%S"))
            admit_vitals_data["admit_department"].append(admit_dept)
            admit_vitals_data["disch_department"].append(disch_dept)
            admit_vitals_data["vitals_date"].append(vitals_time.strftime("%m/%d/%Y, %H:%M:%S"))
            admit_vitals_data["weight"].append(running_weight)

            running_HR += np.random.uniform(low=-10, high=10)
            if running_HR < 30: running_HR = 30
            if running_HR > 300: running_HR = 300
                
            running_temp += np.random.uniform(low=-0.4, high=0.4)
            if running_temp < 95: running_temp = 95
            if running_temp > 104: running_temp = 104

            running_weight += np.random.uniform(low=-2., high=2.)

            admit_vitals_data["HR"].append(round(running_HR, 1))
            admit_vitals_data["temp"].append(round(running_temp, 1))

            if 7 < vitals_time.hour < 21:
                vitals_gap = 30 + np.random.uniform(low=-30, high=30)
            else:
                vitals_gap = 3 * hrs + np.random.uniform(low=-30, high=30)

            vitals_time += timedelta(minutes=vitals_gap)

admit_vitals_data = pl.DataFrame(admit_vitals_data).sample(
    fraction=1, with_replacement=False, shuffle=True, seed=1
)

admit_vitals_data.write_csv(OUT_DIR / "admit_vitals.csv")
admit_vitals_data.head(3)

MRN,admit_date,disch_date,admit_department,disch_department,vitals_date,HR,temp,weight
i64,str,str,str,str,str,f64,f64,f64
671425,"""02/25/2010, 16…","""03/02/2010, 06…","""PULMONARY""","""PULMONARY""","""02/27/2010, 14…",145.7,97.7,180.763303
980825,"""01/15/2010, 00…","""01/21/2010, 21…","""ORTHOPEDIC""","""CARDIAC""","""01/21/2010, 18…",138.3,98.7,152.283269
1499770,"""06/03/2010, 10…","""06/13/2010, 08…","""ORTHOPEDIC""","""CARDIAC""","""06/06/2010, 17…",162.5,96.9,155.804631


## Labs Data
This file will contain numerical laboratory test results for these subjects across the following lab tests:
  * potassium (3 - 6)
  * creatinine (0.4 - 1.5)
  * SOFA score (1, 2, 3, or 4)
  * Glasgow Coma Scale (1 - 15, discrete)
  * SpO2 (0-1)

In [5]:
random.seed(1)
np.random.seed(1)

labs_data = {
    "MRN": [],
    "timestamp": [],
    "lab_name": [],
    "lab_value": [],
    "weight": [],
}


def lab_delta_fn(running_vals: dict[str, float], lab_to_meas: str) -> float:
    do_outlier = np.random.uniform() < 0.0001

    if lab_to_meas not in ("GCS", "SOFA") and do_outlier:
        return 1e6

    old_val = running_vals[lab_to_meas]
    if lab_to_meas == "SOFA":
        delta = np.random.randint(low=-2, high=2)
        new_val = old_val + delta
        if new_val < 1:
            new_val = 1
        elif new_val > 4:
            new_val = 4
    elif lab_to_meas == "GCS":
        delta = np.random.randint(low=-4, high=4)
        new_val = old_val + delta
        if new_val < 1:
            new_val = 1
        elif new_val > 15:
            new_val = 15
    elif lab_to_meas == "SpO2":
        delta = np.random.randint(low=-2, high=2)
        new_val = old_val + delta
        if new_val < 50:
            new_val = 50
        elif new_val > 100:
            new_val = 100
    else:
        delta = np.random.uniform(low=-0.1, high=0.1)
        new_val = old_val + delta
        if new_val < 0:
            new_val = 0

    running_vals[lab_to_meas] = new_val
    return round(new_val, 2)


hrs = 60
days = 24 * hrs
months = 30 * days

for MRN, admissions in admissions_by_subject.items():
    lab_ps = np.random.dirichlet(alpha=[0.1 for _ in range(5)])

    base_lab_gaps = {
        "potassium": np.random.uniform(low=1 * hrs, high=48 * hrs),
        "creatinine": np.random.uniform(low=1 * hrs, high=48 * hrs),
        "SOFA": np.random.uniform(low=1 * hrs, high=48 * hrs),
        "GCS": np.random.uniform(low=1 * hrs, high=48 * hrs),
        "SpO2": np.random.uniform(low=15, high=1 * hrs),
    }

    running_weight = np.random.uniform(low=120, high=200)
    last_weights_time = None

    for st, end in admissions:
        running_lab_values = {
            "potassium": np.random.uniform(low=3, high=6),
            "creatinine": np.random.uniform(low=0.4, high=1.5),
            "SOFA": np.random.randint(low=1, high=4),
            "GCS": np.random.randint(low=1, high=15),
            "SpO2": np.random.randint(low=70, high=100),
        }

        running_weight += np.random.uniform(low=-10., high=10.)

        for lab in base_lab_gaps.keys():
            gap = base_lab_gaps[lab]
            labs_time = st + timedelta(minutes=gap + np.random.uniform(low=-30, high=30))

            while labs_time < running_end:
                labs_data["MRN"].append(MRN)
                labs_data["timestamp"].append(labs_time.strftime("%H:%M:%S-%Y-%m-%d"))
                labs_data["lab_name"].append(lab)

                labs_data["lab_value"].append(lab_delta_fn(running_lab_values, lab))

                if 7 < labs_time.hour < 21:
                    labs_gap = gap + np.random.uniform(low=-30, high=30)
                else:
                    labs_gap = min(2 * gap, 12 * hrs) + np.random.uniform(low=-30, high=30)

                running_weight += np.random.uniform(low=-1., high=1.)

                if last_weights_time is None:
                
                if
                meas_weight = (last_weights_time is None or labs_time - last_weights_time > timedelta(days=1))

                if meas_weight:
                    labs_data["weight"].append(running_weight)
                    last_weights_time = labs_time
                else:
                    labs_data["weight"].append(None)

                labs_time += timedelta(minutes=labs_gap)

labs_data = pl.DataFrame(labs_data).sample(fraction=1, with_replacement=False, shuffle=True, seed=1)

labs_data.write_csv(OUT_DIR / "labs.csv")
labs_data.head(3)

MRN,timestamp,lab_name,lab_value,weight
i64,str,str,f64,f64
980825,"""17:52:31-2010-…","""SpO2""",50.0,
559302,"""10:32:10-2010-…","""SpO2""",52.0,
407452,"""06:47:36-2010-…","""SOFA""",1.0,
