In [287]:
import os, io, re, datetime as dt
import pandas as pd
import numpy as np
from google.colab import drive

**Note on synthetic data**
The data used was generated using the [Synthea Mitre repository](https://github.com/synthetichealth/synthea). I did not include the repo directly, however, steps to recreate the data our below. Feel free to try it out:
1. git clone https://github.com/synthetichealth/synthea
2. cd synthea
3. Ensure Java SDK v11-v17 is installed (as suggested in README)
4. Under "*synthea/src/main/resources/synthea.properties*", change *exporter.csv.export* to true
5. on Windows, run .\gradlew run --args="-p 5000 -d ..\modules --exporter.csv.export=true"
  - -p 7500 -- sets population size
  - -d ../modules -- path to folder to add synthetic data as part of creation. For this project, we add synthetic PHQ-9 data from "/modules/phq9/phq9.json" module (included in reference docs)
  - Files are output under *output/csv* in the repo.
6. Using 7-Zip, compress (.gz) each of **patients, encounters, conditions, medications, and observations**individually. Upload to path of your liking (This code runs with uploaded to *submission/synthetic-data*)

**How to run**
1. Turn on High-RAM (the last 2 df's are big)
2. Ensure the LOCAL_DIR and OUT_DIR match wheren your reading and saving data
3. Run below cells to load data
4. To generate data, run *build_features* in the last section (after compiling fn's before)

In [288]:
LOCAL_DIR = '/content/drive/MyDrive/UT-AIHC/HRP/synthetic-data'
OUT_DIR = '/content/drive/MyDrive/UT-AIHC/HRP/out-data'
drive.mount('/content/drive', force_remount=True)
os.makedirs(OUT_DIR, exist_ok=True)

### Configuration ###
# Rolling window (days) for if a patient has went out of DX coverage
PDC_WINDOW_DAYS = 90

# Max days for a patient to calculate
MAX_SEQUENCE_DAYS = 128

Mounted at /content/drive


In [289]:
# clean dataframe to have stripped, lower case data
def to_lower(df):
  df.columns = [c.strip().lower() for c in df.columns]
  return df

# parse dates for a set of columns in a dataframe
def parse_dates(df, cols):
  for c in cols:
    if c in df.columns:
      df[c] = pd.to_datetime(df[c], errors="coerce", utc=True).dt.tz_localize(None)
  return df

patients_raw =      to_lower(pd.read_csv(f'{LOCAL_DIR}/patients.csv.gz'), )
encounters_raw =    to_lower(pd.read_csv(f'{LOCAL_DIR}/encounters.csv.gz'))
conditions_raw =    to_lower(pd.read_csv(f'{LOCAL_DIR}/conditions.csv.gz'))
observations_raw =  to_lower(pd.read_csv(f'{LOCAL_DIR}/observations.csv.gz'))
medications_raw =   to_lower(pd.read_csv(f'{LOCAL_DIR}/medications.csv.gz'))
patients_raw =      parse_dates(patients_raw, ["birthdate", "deathdate"])
encounters_raw =    parse_dates(encounters_raw, ["start", "stop"])
conditions_raw =    parse_dates(conditions_raw, ["start", "stop"])
medications_raw =   parse_dates(medications_raw, ["start", "stop"])
observations_raw =  parse_dates(observations_raw, ["date"])

## Variables

In [290]:
# https://loinc.org/44249-1
PHQ9_TOTAL_LOINC = "44261-6"  # PHQ-9 total score [Reported]
PHQ9_ITEM_CODES = {
  "44250-9", "44255-8", "44259-0", "44254-1", "44251-7",
  "44258-2", "44252-5", "44253-3", "44260-8",
  "44261-6",
}
SCREENING_TOT_CODES = {
  "phq9":   {"44261-6"},
  "phq2":   {"55758-7"},
  "gad7":   {"70274-6"},
  "auditc": {"75626-2"},
  "dast10": {"82667-7"},
}
# Fallback keyword detectors (case-insensitive)
SCREENING_KWS = {
  "phq9":   ["phq9", "phq-9", "phq9 total", "phq-9 total"],
  "phq2":   ["phq2", "phq2 total", "phq-2 total"],
  "gad7":   ["gad7", "gad-7", "gad7 total", "gad7-total"],
  "auditc": ["auditc", "audit-c", "auditc total", "audit-c total"],
  "dast10": ["dast10", "dast-10", "dast10 total", "dast-10 total"],
}
SMOKING_STATUS_CODE = "72166-2"
SMOKER_COLLAPSE = {
  "never smoker": "never", "never": "never",
  "former smoker": "former","past": "former",
  "current some day smoker": "current", "current every day smoker": "current",
  "current": "current"
}
ALCOHOL_COLLAPSE = {
  "none": "none","no use": "none",
  "moderate": "moderate","light": "moderate",
  "heavy": "heavy","risky": "heavy","hazardous": "heavy"
}
PREGNANCY_CODES = { "2106-3", "80384-1", "2112-1" }
ENCOUNTER_TYPES = {
  "emergency": "EMERGENCY",
  "inpatient": "INPATIENT",
  "ambulatory": "OUTPATIENT",
  "outpatient": "OUTPATIENT",
  "urgentcare": "URGENTCARE",
  "wellness": "WELLNESS",
}
# "https://psychcentral.com/blog/top-25-psychiatric-medications-for-2020#top-25-list"
# combined with ones I know from https://www.nami.org/about-mental-illness/treatments/mental-health-medications/
ANTIDEPRESSANT_KEYWORDS = [
  "alprazolam", "xanax", "amitriptyline", "elavil", "aripiprazole",
  "abilify", "asenapine", "saphris", "atomoxetine", "strattera",
  "brexpiprazole", "rexulti", "bupropion", "wellbutrin", "buspirone",
  "buspar", "buprenorphine", "sublocade", "citalopram", "celexa",
  "clonazepam", "klonopin", "desvenlafaxine", "pristiq", "deutetrabenazine",
  "austedo", "diazepam", "valium", "duloxetine", "cymbalta",
  "escitalopram", "lexapro", "esketamine", "spravato", "fluoxetine",
  "prozac", "gabapentin", "neurontin", "haloperidol", "haldol",
  "hydroxyzine", "vistaril", "lamotrigine", "lamictal", "levomilnacipran",
  "lofexidine", "lucemyra", "lorazepam", "ativan", "mirtazapine",
  "remeron", "milnacipran", "naltrexone", "vivitrol", "nortriptyline",
  "olanzapine", "zyprexa", "paroxetine", "paxil", "pregabalin",
  "lyrica", "quetiapine", "seroquel", "risperidone", "risperdal",
  "sertraline", "zoloft", "trazodone", "desyrel", "venlafaxine",
  "effexor", "vortioxetine", "vilazodone", "imipramine", "desipramine",
  "clomipramine", "doxepin"
]
ANTIDEPRESSANT_REGEX = re.compile(r"(" + "|".join([re.escape(k) for k in ANTIDEPRESSANT_KEYWORDS]) + r")", re.I)

## Clean Data

In [291]:
def collect_tables(patients_df, encounters, conditions, medications, observations):
  # patients
  pat = patients_df[["id", "birthdate", "deathdate", "race", "ethnicity", "gender"]].copy()
  pat = pat.rename(columns={ "id": "patient_id", "gender": "sex" })
  pat["sex"] = patients_df["gender"].str.upper().map({"M":"M","F":"F"}).fillna("UNK")
  pat["race"] = patients_df["race"].fillna("Unknown")

  # encounters
  enc = encounters[["id", "patient", "start", "stop", "encounterclass", "code", "description", "payer"]].copy()
  enc["encounterclass"] = enc["encounterclass"].astype(str).str.lower()
  enc["enc_class_bucket"] = enc["encounterclass"].map(ENCOUNTER_TYPES).fillna("OTHER")
  enc = enc.rename(columns={
    "patient": "patient_id",
    "start": "start_time",
    "stop": "stop_time",
    "id": "encounter_id"
  })

  # conditions
  cond = conditions[["patient","encounter","start","stop","code","description"]].copy()
  cond = cond.rename(columns={
    "patient": "patient_id",
    "start": "start_time",
    "stop": "stop_time",
    "encounter": "encounter_id",
    "code": "snomed_code"
  })

  # medications
  med = (medications[[
    "patient", "encounter", "start", "stop",
    "code", "description", "dispenses", "totalcost",
    "reasoncode", "reasondescription"
  ]].copy()
    .rename(columns={
      "patient": "patient_id",
      "start": "start_time",
      "stop": "stop_time",
      "encounter": "encounter_id",
      "code": "rx_code"
    })
  )

  # observations
  obs = observations[[
    "date", "patient", "encounter", "category", "code",
    "description", "value", "units", "type"
  ]].copy().rename(columns={
    "patient": "patient_id",
    "date": "obs_time",
    "code": "loinc_code"
  })
  obs["obs_time"] = pd.to_datetime(obs["obs_time"], errors="coerce").dt.floor("D")
  obs["loinc_code"] = obs["loinc_code"]
  obs["description"] = obs["description"].apply(lambda x: x.lower())
  obs["val_str"] = obs["value"].apply(lambda x: x.lower())
  obs["val_num"] = pd.to_numeric(obs["value"], errors="coerce")

  return pat, enc, cond, med, obs

## Screening scores, lifestyle choices

In [292]:
"""
Extract a screening score from observations given keywords to search
and a patient id
"""
def get_screening_score(obs_df, loinc_codes, keywords, out_col):
  obs = obs_df[["patient_id", "obs_time", "loinc_code", "description", "value"]].copy()
  obs["patient_id"]  = obs["patient_id"].astype(str)
  obs["obs_time"]    = pd.to_datetime(obs["obs_time"], errors="coerce")
  obs["date"]        = obs["obs_time"].dt.floor("D")
  obs["description"] = obs["description"].astype(str)
  obs["value_num"]   = pd.to_numeric(obs["value"], errors="coerce")

  # masks
  tot_code = obs["loinc_code"].isin(set(loinc_codes))
  kws = pd.Series(False, index=obs.index)
  if keywords:
    pattern = "|".join([re.escape(k) for k in keywords])
    kws = obs["description"].str.contains(pattern, case=False, na=False)

  # Keep candidate rows with a valid date and numeric value
  keep = (tot_code | kws) & obs["date"].notna() & obs["value"].notna()
  result = obs.loc[keep, ["patient_id", "obs_time", "date", "value"]] # Include 'date' in result
  if result.empty:
      return pd.DataFrame(columns=["patient_id", "date", out_col])

  # Sort then pick the last value per day
  cand = result.dropna(subset=["date"]).sort_values(["patient_id", "date", "obs_time"]) # Drop rows with NaT in 'date' before sorting
  last_daily = cand.groupby(["patient_id", "date"], as_index=False).tail(1)  # last row per group
  last_daily = last_daily.rename(columns={"value": out_col})
  last_daily[out_col] = last_daily[out_col].astype("float32")

  return last_daily[["patient_id", "date", out_col]]

In [293]:
"""
Extract Smoking Status, Alcohol Use, and Pregnancy value from Observations into:
- smoking_status: values s.a. "current every day smoker", "former smoker", "never smoker"
- alcohol_use: values s.a. "heavy", "moderate", "none"
- pregnancy_pos: 1 if the textual value for LOINC matches contains "pos"
"""
def extract_life_obs(obs_in):
  obs = obs_in.copy()
  obs["patient_id"]  = obs["patient_id"].astype(str)
  obs["obs_time"]    = pd.to_datetime(obs["obs_time"], errors="coerce")
  obs["date"]        = obs["obs_time"].dt.floor("D")
  obs["desc_lc"]     = obs["description"].astype(str).str.lower().str.strip()
  obs["val_str"]     = obs["value"].astype(str).str.lower().str.strip()
  obs["val_num"]     = pd.to_numeric(obs["value"], errors="coerce")

  def last_daily_rec(rec):
    return rec if rec.empty else (rec.dropna(subset=["date"])
      .sort_values(["patient_id", "date", "obs_time"])
      .groupby(["patient_id", "date"], as_index=False)
      .tail(1))

  smok = last_daily_rec(obs.loc[
    obs["loinc_code"].eq(SMOKING_STATUS_CODE),
     ["patient_id", "date", "obs_time", "val_str"]]
  )
  if smok.empty:
    smoke_df = smok.assign(bucket="unknown")[["patient_id", "date", "bucket"]]
  else:
    smoke_df = (smok
      .assign(bucket = smok["val_str"].replace(SMOKER_COLLAPSE))
      .rename(columns={"val_str": "smoking_status"})
      [["patient_id", "date", "bucket"]])

  # pregnancy
  preg = last_daily_rec(obs.loc[
    obs["loinc_code"].isin(PREGNANCY_CODES),
    ["patient_id", "date", "obs_time","val_str"]]
  )
  if preg.empty:
    preg_df = pd.DataFrame(columns=["patient_id", "date", "pregnancy_pos"])
  else:
    preg_df = preg.assign(pregnancy_pos=preg["val_str"].str.contains("pos", na=False).astype("int8"))[
        ["patient_id","date","pregnancy_pos"]
    ]

  # alcohol
  alc = last_daily_rec(obs.loc[
    obs["loinc_code"].eq("75626-2"),
    ["patient_id",  "date", "obs_time", "val_num"]]
  )
  if alc.empty:
    a_txt = last_daily_rec(obs.loc[
      obs["desc_lc"].isin(ALCOHOL_COLLAPSE.keys()) | obs["val_str"].isin(ALCOHOL_COLLAPSE.keys()),
      ["patient_id", "obs_time", "desc_lc", "val_str"]])
    if a_txt.empty:
      alcohol_df = a_txt.assign(bucket="unknown")[["patient_id", "date", "bucket"]]
    else:
      raw = a_txt["val_str"].where(a_txt["val_str"].isin(ALCOHOL_COLLAPSE), a_txt["desc_lc"])
      alcohol_df = (a_txt
        .assign(bucket = raw.replace(ALCOHOL_COLLAPSE))
        [["patient_id",  "date", "bucket"]]
      )
  else:
    alcohol_df = (alc
      .assign(
        bucket=pd.cut(alc["val_num"],
        bins=[-np.inf, 0, 3, np.inf],
        labels=["none", "moderate", "heavy"])
      )
      [["patient_id", "date", "bucket"]])


  return smoke_df, preg_df, alcohol_df

## Coverage Calendar

In [294]:
"""
Build the antidepressant "coverage calendar":
- For each medication, per patient per-day:
  - flag ad_covered if covered with mental health meds.
  - If stop_time is missing or < start_time, guesstimate 30 days as a supply
"""
def build_coverage_calendar(med_df):
  rows = []
  for _, row in med_df.iterrows():
    pid = row["patient_id"]
    start = row["start_time"]
    stop = row["stop_time"]
    if pd.isna(start):
      continue
    if pd.isna(stop) or stop < start:
      stop = start + pd.Timedelta(days=30)  # <-- ambiguous assumption; swap to days_supply if available.
    for day in pd.date_range(start.floor("D"), stop.floor("D"), freq="D"):
      rows.append((row["patient_id"], day, 1))

  cov = pd.DataFrame(rows, columns=["patient_id","date","ad_covered"])
  cov = (cov
    .groupby(["patient_id","date"], as_index=False)["ad_covered"]
    .max()
  )
  return cov

In [295]:
"""
Add "PDC" and "gap days" to the coverage calendar.
- pdc_{window_days}: rolling mean of ad_covered over trailing window time (90 days)
- Calculate "late refills" by zeros on non-covered days

Adds:
- pdc_{window_days}: rolling mean of ad_covered over trailing window time (90 days)
- ad_gap_days: number of consecutive uncovered days up to (and including) today.
"""
def add_pdc_and_gaps(cov, window_days):
  cov = cov.sort_values(["patient_id","date"])
  cov["date"] = pd.to_datetime(cov["date"], errors="coerce").dt.floor("D")
  cov = cov.dropna(subset=["patient_id","date"])
  cov["patient_id"] = cov["patient_id"].astype(str)
  cov["ad_covered"] = cov["ad_covered"].astype("int8").fillna(0).astype("int8")

  def per_patient(g):
    g = g.set_index("date").asfreq("D", fill_value=0).sort_index()
    pid = str(g["patient_id"].iloc[0])

    # compute a rolling mean over the 30 day window
    pdc = g["ad_covered"].rolling(window_days, min_periods=1).mean().astype("float32")

    # Gaps: consecutive zeros since last covered day
    zero = (g["ad_covered"] == 0).astype("int16")
    groups = g["ad_covered"].cumsum()              # increments after each covered day
    gap = zero.groupby(groups).cumsum().astype("int16")

    return pd.DataFrame({
      "patient_id": pid,
      "date": g.index,
      "ad_covered": g["ad_covered"].astype("int8").values,
      f"pdc_{window_days}": pdc.values,
      "ad_gap_days": gap.values,
    })

  out = (cov
    .groupby("patient_id", group_keys=False)
    .apply(per_patient)
  ).reset_index(drop=True)

  assert "date" in out.columns, "add_pdc_and_gaps: 'date' not in output columns"
  out["date"] = pd.to_datetime(out["date"], errors="coerce").dt.floor("D")
  out["patient_id"] = out["patient_id"].astype(str)
  return out

## Daily Patient & Events

In [296]:
"""
For each patient, create a new row for every calendar day between
their first and last activity in the "calendar"

This is how we'll be able to calculate alert metrics
"""
def dense_grid(activity_df):
  df = activity_df[["patient_id", "date"]].copy()
  df = df.dropna(subset=["patient_id", "date"])
  df["patient_id"] = df["patient_id"].astype(str)
  df["date"] = pd.to_datetime(df["date"], errors="coerce").dt.floor("D")
  df = df.dropna(subset=["date"])

  agg = (df
    .groupby("patient_id")["date"]
    .agg(["min","max"])
    .reset_index()
  )

  rows = []
  for _, row in agg.iterrows():
    for d in pd.date_range(row["min"], row["max"], freq="D"):
      rows.append((row["patient_id"], d))

  return pd.DataFrame(rows, columns=["patient_id","date"])

In [309]:
"""
Construct the final p/patient, p/day feature table "daily" consumed by the LSTM.
Per day, encode:
- Static demographics (sex, race)
- Age
- hot encoding for any admission
- per-day encounter class bucket
- per-day medication adherence (ad_covered, PDC_{window}, and ad_gap_days)
- 7-day rolling sum of encounters
- PHQ-9 score per day (since last)
- hot encoding Sex and race
- days_since_prev: days since last activity

Output: DataFrame sorted by [patient_id, date]; one row per patient-day with all covariates.
"""

def build_daily_features(
  pats_df, enc, cond, med, obs,
  smoke, preg, alcohol, ad_daily,
  phq9, phq2, gad7, auditc, dast10,
  max_seq_days
):
  def require_cols(df, name, cols):
    missing = [c for c in cols if c not in df.columns]
    if missing:
      raise ValueError(f"{name} missing columns: {missing}")
  # Anchor dates from any activity type
  dates_enc = enc.assign(date=enc["start_time"].dt.floor("D"))[["patient_id","date"]]
  dates_cond= cond.assign(date=cond["start_time"].dt.floor("D"))[["patient_id","date"]]
  dates_med = med.assign(date=med["start_time"].dt.floor("D"))[["patient_id","date"]]
  dates_obs = obs.assign(date=obs["obs_time"].dt.floor("D"))[["patient_id","date"]]

  activity = (
    pd.concat([dates_enc, dates_cond, dates_med, dates_obs], ignore_index=True)
      .dropna()
      .drop_duplicates()
  )

  require_cols(activity, "activity", ["patient_id","date"])
  daily = (dense_grid(activity)
    .reset_index()
    .sort_values(["patient_id","date"])
    .groupby("patient_id", group_keys=False, as_index=False).tail(max_seq_days)
  )

  # --- Demographics ---
  daily = daily.merge(pats_df[["patient_id","birthdate"]], how="left", on="patient_id")
  daily["age_years"] = (daily["date"] - daily["birthdate"]).dt.days // 365
  daily = daily.drop(columns=["birthdate"])

  # repeatedly join sources to just these keys to keep memory low
  keep_keys = daily[["patient_id","date"]].drop_duplicates()

  daily = daily.merge(pats_df[["patient_id","birthdate", "sex", "race"]], how="left", on="patient_id")
  daily["age_years"] = ((daily["date"] - daily["birthdate"]).dt.days // 365).astype("int16")
  daily = daily.drop(columns=["birthdate"])

  # --- Encounter bucketrs ---
  enc_day = enc.assign(date=enc["start_time"].dt.floor("D"))  # [pid, date, enc_class_bucket]
  enc_day = enc_day.merge(keep_keys, on=["patient_id","date"], how="inner")  # only keys we care about
  enc_counts = (
    pd.crosstab(index=[enc_day["patient_id"], enc_day["date"]],
                columns=enc_day["enc_class_bucket"])
    .reset_index()
    .rename_axis(None, axis=1)
  )


  daily = daily.merge(enc_counts, how="left", on=["patient_id","date"])
  bucket_cols = [c for c in ["EMERGENCY","INPATIENT","OUTPATIENT","URGENTCARE","WELLNESS","OTHER"] if c in daily.columns]
  for buck in bucket_cols:
    if buck not in daily.columns:
      daily[buck] = 0
  daily[bucket_cols] = daily[bucket_cols].fillna(0).astype("int16")
  daily["visits_per_day"] = daily[bucket_cols].sum(axis=1).astype("int16")

  # 7-day rolling adm util over sum over days
  daily = daily.sort_values(["patient_id","date"])
  daily["util_7d"] = (
    daily.groupby("patient_id", observed=True)["visits_per_day"]
      .transform(lambda s: s.rolling(7, min_periods=1).sum())
      .astype("float32")
  )
  daily = daily.drop(columns=["visits_per_day"])

  # screening
  phq9_    = phq9.merge(keep_keys, on=["patient_id","date"], how="inner")
  daily   = daily.merge(phq9_, on=["patient_id","date"], how="left")
  phq2_   = phq2.merge(keep_keys, on=["patient_id","date"], how="inner")
  daily   = daily.merge(phq2_, on=["patient_id","date"], how="left")
  gad7_    = gad7.merge(keep_keys, on=["patient_id","date"], how="inner")
  daily   = daily.merge(gad7_, on=["patient_id","date"], how="left")
  auditc_  = auditc.merge(keep_keys, on=["patient_id","date"], how="inner")
  daily   = daily.merge(auditc_, on=["patient_id","date"], how="left")
  dast10_  = dast10.merge(keep_keys, on=["patient_id","date"], how="inner")
  daily   = daily.merge(dast10_, on=["patient_id","date"], how="left")

  # smoking
  smo  = smoke.merge(keep_keys, on=["patient_id","date"], how="inner")
  daily   = daily.merge(smo[["patient_id","date"]], on=["patient_id","date"], how="left")

  # pregnancy
  prego   = preg.merge(keep_keys, on=["patient_id","date"], how="inner")
  preg_agg = prego.groupby(["patient_id","date"], as_index=False)["pregnancy_pos"].max()
  daily   = daily.merge(preg_agg, on=["patient_id","date"], how="left")

  # Alcohol
  alc    = alcohol.merge(keep_keys, on=["patient_id","date"], how="inner")
  daily   = daily.merge(alc[["patient_id","date"]], on=["patient_id","date"], how="left")

  for col in [c for c in daily.columns if c.startswith(("smoke_", "alcohol_", "pregnancy_pos"))]:
    daily[col] = daily[col].fillna(0).astype("int8")

  # Adherence features
  add_    = ad_daily.merge(keep_keys, on=["patient_id","date"], how="inner")
  daily   = daily.merge(add_, on=["patient_id","date"], how="left")
  daily[["ad_covered", f"pdc_{PDC_WINDOW_DAYS}", "ad_gap_days"]] = (
    daily[["ad_covered", f"pdc_{PDC_WINDOW_DAYS}", "ad_gap_days"]].fillna(0))

  daily["ad_covered"] = daily["ad_covered"].astype("int8")
  daily["ad_gap_days"] = daily["ad_gap_days"].astype("int16")
  daily[f"pdc_{PDC_WINDOW_DAYS}"] = daily[f"pdc_{PDC_WINDOW_DAYS}"].astype("float32")

  # --- Gaps since previous days ---
  daily = daily.sort_values(["patient_id","date"])
  daily["prev_date"] = daily.groupby("patient_id", observed=True)["date"].shift(1)
  daily["days_since_prev"] = (daily["date"] - daily["prev_date"]).dt.days.fillna(0).astype("int16")
  daily = daily.drop(columns=["prev_date"])

  # --- Hot encode sex and race---
  daily["sex_M"] = (daily["sex"] == "M").astype("int8")
  daily["sex_F"] = (daily["sex"] == "F").astype("int8")

  # Top 6 races
  for r in daily["race"].value_counts(dropna=False).head(6).index.tolist():
    key = re.sub("[^A-Za-z0-9]+","_", str(r)).lower()
    daily[f"race_{key}"] = (daily["race"] == r).astype("int8")
  daily = daily.drop(columns=["race", "sex"])

  # COmpact data
  for c in daily.select_dtypes(include=["float64","float32"]).columns:
    daily[c] = daily[c].astype("float32")
  for c in daily.select_dtypes(include=["int64"]).columns:
    daily[c] = daily[c].astype("int16")

  return daily.sort_values(["patient_id","date"]).reset_index(drop=True)

In [298]:
"""
Create event log used for modeling tokenization an sequences. Sort chron

Event Codes:
- "dx_{code}" diagnoses with code
- "med_{code}" medications with code
- "adm_{bucket}" encounters with bucket (type)
- "obs_{code}" observations with code
"""
def build_event_stream(cond, med, enc, obs):
    events = []

    dx_cond = cond.copy()
    dx_cond["date"] = dx_cond["start_time"].dt.floor("D")
    events.append(
      dx_cond.assign(
        event_type="dx",
        code=dx_cond["snomed_code"].astype(str)
      )[["patient_id","date","event_type","code"]]
    )

    med = med.copy()
    med["date"] = med["start_time"].dt.floor("D")
    events.append(
      med.assign(
        event_type="med",
        code=med["rx_code"].astype(str)
      )[["patient_id","date","event_type","code"]]
    )

    enc_codes = enc.assign(date=enc["start_time"].dt.floor("D"))
    enc_codes["code"] = "ADM:" + enc_codes["enc_class_bucket"].astype(str)
    events.append(
      enc_codes[["patient_id", "date", "enc_class_bucket", "code"]]
        .rename(columns={"enc_class_bucket": "event_type"})
        .assign(event_type="adm")
    )

    obs_ev = obs.copy()
    obs_ev["date"] = obs_ev["obs_time"].dt.floor("D")
    obs_ev = obs_ev.assign(
      event_type="obs",
      code=obs_ev["loinc_code"].astype(str)
    )
    events.append(obs_ev[["patient_id","date","event_type","code"]])

    events = (
      pd.concat(events, ignore_index=True)
      .dropna(subset=["patient_id", "date", "event_type", "code"])
      .drop_duplicates()
    ).sort_values(["patient_id", "date", "event_type", "code"]).reset_index(drop=True)

    return events

## Clean Data & Build

In [299]:
def clean_daily(pat, enc, con, med, obs, dayyy, events):
  dayly = dayyy.copy()

  # If screening PHQ-9 exists, coalesce into main phq9
  if "phq9_screen" in dayly.columns:
    if "phq9" not in dayly.columns:
        dayly["phq9"] = np.nan
    dayly["phq9"] = dayly["phq9"].fillna(dayly["phq9_screen"])
    dayly = dayly.drop(columns=["phq9_screen"])

  # Fill numbers to 0
  num_fills = [
    "phq9", "phq2", "gad7", "auditc", "dast10",
    "util_7d", "ad_gap_days", "ad_covered", "pregnancy_pos"
  ]
  for col in num_fills:
    if col in dayly.columns:
      dayly[col] = dayly[col].astype("float32").fillna(0.0)

  # Encounter one-hots (ensure present, int, no NaNs). Keep your bucket set consistent with Step 2.
  expected_enc_buckets = set(ENCOUNTER_TYPES.values()).union({"OTHER"})
  for buck in expected_enc_buckets:
    if buck not in dayly.columns:
      dayly[buck] = 0
    dayly[buck] = dayly[buck].fillna(0).astype("int16")

  # Demographic one-hots shouldnâ€™t be NaN; default to 0
  for c in [col for col in dayly.columns if col.startswith("sex_") or col.startswith("race_")]:
    dayly[c] = dayly[c].fillna(0).astype("int8")

  # days_since_prev should be int, non-null
  if "days_since_prev" in dayly.columns:
    dayly["days_since_prev"] = dayly["days_since_prev"].fillna(0).astype("int32")

  return dayly;

In [300]:
# int8 one-hots + unknown
def one_hot(df_in, col, prefix, classes):
  out = df_in.copy()
  for k in classes:
    out[f"{prefix}_{k}"] = (out[col] == k).astype("int8")
  out[f"{prefix}_unknown"] = (~out[col].isin(classes)).astype("int8")
  return out.drop(columns=[col])

In [301]:
def build_features():
  print("Collecting...")
  pats, enc, cond, med, obs = collect_tables(
    patients_raw, encounters_raw, conditions_raw, medications_raw, observations_raw
  );


  phq9   = get_screening_score(obs, SCREENING_TOT_CODES["phq9"],   SCREENING_KWS["phq9"],   "phq9")
  phq2   = get_screening_score(obs, SCREENING_TOT_CODES["phq2"],   SCREENING_KWS["phq2"],   "phq2")
  gad7   = get_screening_score(obs, SCREENING_TOT_CODES["gad7"],   SCREENING_KWS["gad7"],   "gad7")
  auditc = get_screening_score(obs, SCREENING_TOT_CODES["auditc"], SCREENING_KWS["auditc"], "auditc")
  dast10 = get_screening_score(obs, SCREENING_TOT_CODES["dast10"], SCREENING_KWS["dast10"], "dast10")
  print("Retrieved screening scores from observations")

  smoke, preg, alcohol = extract_life_obs(obs)
  smoke = one_hot(smoke, col="bucket", prefix="smoke", classes=("never", "former", "current"))
  alcohol = one_hot(alcohol, col="bucket", prefix="alcohol", classes=("none", "moderate", "heavy"))
  print("Retrieved [bad, exc pregnancy] lifestyle choices from observations")

  med["is_antidepressant"] = med["description"].astype(str).str.contains(ANTIDEPRESSANT_REGEX)
  ad_med = med.loc[med["is_antidepressant"]].copy()
  ad_cov = build_coverage_calendar(ad_med)
  print("Created coverage calendar")
  ad_daily = add_pdc_and_gaps(ad_cov, PDC_WINDOW_DAYS).reset_index(drop=True)
  print("Added rolling times and filled ")

  daily = build_daily_features(
    pats, enc, cond, med, obs,
    smoke, preg, alcohol, ad_daily,
    phq9, phq2, gad7, auditc, dast10,
    MAX_SEQUENCE_DAYS);
  daily_keys = daily[["patient_id", "date"]].drop_duplicates();
  print("Computed daily features")

  events = build_event_stream(cond, med, enc, obs);
  print("Computed event stream")
  # Filter events, encounters, and conditions for the patients MAX_SEQUENCE_DAYS data
  events = events.merge(daily_keys, on=["patient_id","date"], how="inner")
  enc = (enc
    .assign(date=enc["start_time"].dt.floor("D"))
    .merge(daily_keys, on=["patient_id", "date"], how="inner")
  )
  cond_f = (cond
    .assign(date=cond["start_time"].dt.floor("D"))
    .merge(daily_keys, on=["patient_id", "date"], how="inner")
  );

  daily = clean_daily(pats, enc, cond_f, med, obs, daily, events)
  print("Cleaned data")
  print("Done!")
  return pats, enc, cond_f, med, obs, daily, events

## Initiate

In [310]:
p, e, c, m, o, daily_df, events_df = build_features();

Collecting...
Retrieved screening scores from observations
Retrieved [bad, exc pregnancy] lifestyle choices from observations


  med["is_antidepressant"] = med["description"].astype(str).str.contains(ANTIDEPRESSANT_REGEX)


Created coverage calendar


  .apply(per_patient)


Added rolling times and filled 
Computed daily features
Computed event stream
Cleaned data
Done!


In [308]:
def write_csv_gzip(df, filename):
  path = os.path.join(OUT_DIR, filename)
  os.makedirs(os.path.dirname(path), exist_ok=True)
  df.to_csv(path, index=False, compression='gzip')
write_csv_gzip(daily_df, f'{OUT_DIR}/daily.csv')
write_csv_gzip(events_df, f'{OUT_DIR}/events.csv')
write_csv_gzip(p, f'{OUT_DIR}/patients.csv')
write_csv_gzip(e, f'{OUT_DIR}/encounters.csv')
write_csv_gzip(c, f'{OUT_DIR}/conditions.csv')
write_csv_gzip(m, f'{OUT_DIR}/medications.csv')
write_csv_gzip(o, f'{OUT_DIR}/observations.csv')

In [306]:
print(p.columns)
print(e.columns)
print(c.columns)
print(m.columns)
print(o.columns)
print(daily_df.columns)
print(events_df.columns)

Index(['patient_id', 'birthdate', 'deathdate', 'race', 'ethnicity', 'sex'], dtype='object')
Index(['encounter_id', 'patient_id', 'start_time', 'stop_time',
       'encounterclass', 'code', 'description', 'payer', 'enc_class_bucket',
       'date'],
      dtype='object')
Index(['patient_id', 'encounter_id', 'start_time', 'stop_time', 'snomed_code',
       'description', 'date'],
      dtype='object')
Index(['patient_id', 'encounter_id', 'start_time', 'stop_time', 'rx_code',
       'description', 'dispenses', 'totalcost', 'reasoncode',
       'reasondescription', 'is_antidepressant'],
      dtype='object')
Index(['obs_time', 'patient_id', 'encounter', 'category', 'loinc_code',
       'description', 'value', 'units', 'type', 'val_str', 'val_num'],
      dtype='object')
Index(['index', 'patient_id', 'date', 'age_years', 'EMERGENCY', 'INPATIENT',
       'OTHER', 'OUTPATIENT', 'URGENTCARE', 'WELLNESS', 'util_7d', 'phq9',
       'phq2', 'gad7', 'auditc', 'dast10', 'pregnancy_pos', 'ad_covered

In [307]:
print(p.head(10))
print(e.head(10))
print(c.head(10))
print(m.head(10))
print(o.head(10))
print(daily_df.head(10))
print(events_df.head(10))

                             patient_id  birthdate  deathdate   race  \
0  88d5440c-437b-4660-a8a1-f046949473f9 2024-06-02        NaT  white   
1  13ba565d-338f-ad60-1907-17c35b00d7d0 1974-01-09 2008-01-27  black   
2  89c6ce99-7260-c794-98b4-0679e48bd017 1995-12-19        NaT  white   
3  17116b7b-f441-59f0-82fd-3cb0dfaf2f43 2020-07-13        NaT  white   
4  6d68c679-df19-1d63-6ff8-ae7e394d67d9 2012-03-06        NaT  white   
5  8b994084-8d8b-db73-5b0f-200e0029a69c 2014-04-18        NaT  white   
6  39b07cfb-d6e8-1cdd-339c-ca301ef1e9ac 2016-05-31        NaT  white   
7  f0ae9e6d-3317-881e-aed5-8799410f68c4 2019-12-08        NaT  white   
8  ce98bc3e-e083-8a7a-d0c0-53eefb9f0379 1999-02-01        NaT  white   
9  7409e445-a1b5-2517-d4bb-cc73ffc38287 2023-07-04        NaT  white   

     ethnicity sex  
0  nonhispanic   M  
1  nonhispanic   M  
2     hispanic   M  
3  nonhispanic   F  
4  nonhispanic   M  
5  nonhispanic   F  
6  nonhispanic   M  
7  nonhispanic   F  
8  nonhispanic   F

Reference