**Config**



In [None]:
import os, json
import numpy as np
import pandas as pd
from google.colab import drive
# Google Drive locations; they only resolve once the mount below has succeeded.
# PARSED_DIR holds the gzip-compressed CSVs read by the load cell,
# TRAIN_DIR is where the export cell writes the training artefacts.
PARSED_DIR = '/content/drive/MyDrive/SymptomTrajectories/out-data'
TRAIN_DIR   = '/content/drive/MyDrive/SymptomTrajectories/training-data'
# NOTE(review): force_remount=True presumably re-mounts even if Drive is already attached — confirm.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


**Parameters**

In [None]:
# Label thresholds.
# NOTE(review): compute_labels() carries its own copies of these values
# (0.8, 5.0, 2.0, 30) and of both regexes below; keep them in sync when tuning.
DX_ADHERENCE_THRESHOLD    = 0.8 # Adherence fraction below which a day counts as non-adherent
PHQ9_SPIKE_TRIGGER        = 5.0 # Rise in PHQ-9 score that signals deteriorating mental health
UTILIZATION_SPIKE_TRIGGER = 2.0 # Rise in 7-day utilization that signals deteriorating mental health
RELAPSE_HORIZON_DAYS      = 30 # Look-ahead window (days) used for the relapse/deterioration labels

# Patient-level split fractions (they sum to 1.0). build_splits() reads TRAIN_SPLIT and
# VAL_SPLIT; the test split is whatever remains, so TEST_SPLIT is informational.
TRAIN_SPLIT = 0.75
VAL_SPLIT   = 0.10 # Clean slice for early stopping, debugging, etc
TEST_SPLIT  = 0.15

# Case-insensitive patterns matched against condition descriptions.
mental_health_regex     = r"depress|anxiet|bipolar|schizo|psych|suicid|ptsd|panic|ocd|substance|addict"
severe_condition_regex  = r"severe|psychosis|suicid|mania|catatonia|acute"

**Load Prepared Feature Data**

In [None]:
def _read_parsed(filename, date_cols, columns=None):
  """Read one gzip-compressed CSV from PARSED_DIR.

  filename:  file name inside PARSED_DIR.
  date_cols: columns handed to `parse_dates`.
  columns:   optional column subset (`usecols`); None loads every column.
  """
  return pd.read_csv(
    os.path.join(PARSED_DIR, filename),
    parse_dates=date_cols,
    usecols=columns,
    low_memory=False,
    compression="gzip",
  )

# One frame per parsed table; only the columns used downstream are loaded.
patients = _read_parsed(
  "patients.csv", ["birthdate"],
  ["patient_id", "birthdate", "sex", "race"],
)
encounters = _read_parsed(
  "encounters.csv", ["start_time"],
  ["patient_id", "start_time", "enc_class_bucket"],
)
conditions = _read_parsed(
  "conditions.csv", ["start_time"],
  ["patient_id", "start_time", "snomed_code", "description"],
)
meds = _read_parsed(
  "medications.csv", ["start_time"],
  ["patient_id", "start_time", "rx_code", "description"],
)
observations = _read_parsed(
  "observations.csv", ["obs_time"],
  ["patient_id", "obs_time", "loinc_code", "description", "value", "units"],
)
# daily is the per-day feature table, so every column is kept.
daily = _read_parsed("daily.csv", ["date"])
events = _read_parsed(
  "events.csv", ["date"],
  ["patient_id", "date", "event_type", "code"],
)

In [1]:
# Helper to ensure no errors when comparing dates with and without timezones
def harmonize_types(patients, encounters, conditions, meds, observations, daily, events):
  """Make the date columns timezone-naive and normalise event-type labels.

  The frames are modified in place and also returned (in the same order as the
  parameters) so the call site can rebind them.

  Columns touched:
    daily["date"], events["date"], encounters["start_time"],
    conditions["start_time"], meds["start_time"] (if present),
    observations["obs_time"] (if present)  -> tz-naive datetimes
    events["event_type"]                   -> lower-cased, stripped strings
  """
  def _tz_naive(series):
    # np.issubdtype() cannot interpret pandas' tz-aware dtype and raises
    # TypeError on it, so the pandas dtype check is used instead.
    if not pd.api.types.is_datetime64_any_dtype(series):
      series = pd.to_datetime(series, errors="coerce", utc=False)
    if not pd.api.types.is_datetime64_any_dtype(series):
      # Conversion did not yield a datetime column; hand it back unchanged.
      return series
    # Drops the timezone but keeps the wall-clock time; no-op for naive data.
    return series.dt.tz_localize(None)

  daily["date"]            = _tz_naive(daily["date"])
  events["date"]           = _tz_naive(events["date"])
  encounters["start_time"] = _tz_naive(encounters["start_time"])
  conditions["start_time"] = _tz_naive(conditions["start_time"])
  # meds / observations timestamps are harmonised too (even though not used directly downstream)
  if "start_time" in meds.columns:
    meds["start_time"] = _tz_naive(meds["start_time"])
  if "obs_time" in observations.columns:
    observations["obs_time"] = _tz_naive(observations["obs_time"])

  events["event_type"] = events["event_type"].astype(str).str.lower().str.strip()
  return patients, encounters, conditions, meds, observations, daily, events

# Needs the frames created by the load cell; running this cell before that one
# on a fresh kernel fails with NameError.
patients, encounters, conditions, meds, observations, daily, events = harmonize_types(
    patients, encounters, conditions, meds, observations, daily, events
)

NameError: name 'patients' is not defined

# Factorize Events/Codes into Integer IDs

In [None]:
"""
Factorize (event_type|code) -> int ids, then group the ids per patient-day.
Returns the code ids per day, and a lookup table of all code ids for analysis
"""
def encode_codes_factorize(events):
  """Map every (event_type, code) pair to an integer id and bag the ids per day.

  events: frame with "patient_id", "date", "event_type" and "code" columns.

  Returns (per_day, cookbook):
    per_day:  one row per (patient_id, date) with "code_ids", a JSON list of the
              distinct ids seen that day, sorted ascending.
    cookbook: one row per id with columns id, type, code, key, freq.

  Ids start at 1; 0 is reserved for padding. Events whose type or code is
  missing have no key and are left out of the bags and the cookbook.
  """
  def _as_text(col):
    # Stringify so numeric code columns can be concatenated, but keep missing
    # values missing instead of turning them into the text "nan"/"None".
    return col.astype(str).where(col.notna())

  keys = _as_text(events["event_type"]) + "|" + _as_text(events["code"])
  ids, uniques = pd.factorize(keys, sort=True)
  # factorize marks missing keys with -1, so after the shift they sit on 0 (padding)
  ids = ids.astype(np.int32) + 1

  # frequency for analysis (missing keys are not counted)
  freq = keys.value_counts()
  cookbook = (pd.DataFrame({"key": uniques})
    .assign(id=lambda d: np.arange(1, len(d)+1, dtype=np.int32))
    .assign(type=lambda d: d["key"].str.split("|", n=1).str[0],
            code=lambda d: d["key"].str.split("|", n=1).str[1],
            freq=lambda d: d["key"].map(freq).fillna(0).astype(int))
    [["id", "type", "code", "key", "freq"]]
  )

  events = events.assign(code_id=ids)

  def _uniq_json(s):
    # id 0 means "no key"; it must not appear in a bag because 0 is the padding id
    return json.dumps(sorted({int(v) for v in s.tolist() if v > 0}))

  if not len(events):
    return pd.DataFrame(columns=["patient_id","date","code_ids"]), cookbook

  per_day = (events
    .groupby(["patient_id", "date"], observed=True)["code_id"]
      .apply(_uniq_json)
      .reset_index()
      .rename(columns={"code_id": "code_ids"})
  )
  return per_day, cookbook

In [None]:
"""
Recompute labels
- y_relapse: future <= 30days: (ED/INPATIENT & MH condition same day) OR severe condition OR ad_gap_days >= 30
- y_det: future <= 30d: PHQ9 rises >= FLAG OR util_7d rises by >= FLAG OR medication starts (ad_covered 0->1)
- non_adherent_flag: pdc < FLAG OR ad_gap_days >= 1
"""
def compute_labels(daily, encounters, conditions,
                   adherence_threshold=0.8,
                   phq9_spike_trigger=5.0,
                   utilization_spike_trigger=2.0,
                   horizon_days=30,
                   long_gap_days=30,
                   mh_pattern=None,
                   severe_pattern=None):
  """Add non_adherent_flag, y_relapse and y_det to a copy of `daily`.

  daily:      per-day table with "patient_id", "date", "ad_gap_days" and one
              "pdc_*" column; "phq9", "util_7d" and "ad_covered" are optional.
  encounters: needs "patient_id", "start_time", "enc_class_bucket".
  conditions: needs "patient_id", "start_time", "description".

  Keyword parameters (defaults reproduce the previously hard-coded values):
    adherence_threshold:       pdc below this marks the day non-adherent.
    phq9_spike_trigger:        future PHQ-9 rise that counts as deterioration.
    utilization_spike_trigger: future util_7d rise that counts as deterioration.
    horizon_days:              look-ahead window for y_relapse and y_det.
    long_gap_days:             ad_gap_days at/above this is a relapse anchor.
    mh_pattern / severe_pattern: case-insensitive regexes matched against
                               condition descriptions (None = built-in lists).

  Labels:
    non_adherent_flag: ad_gap_days >= 1 OR pdc < adherence_threshold.
    y_relapse: a relapse anchor falls 1..horizon_days after the day. Anchors are
               (EMERGENCY/INPATIENT encounter AND mental-health condition on the
               same day) OR a severe condition OR ad_gap_days >= long_gap_days.
    y_det:     within the next horizon_days PHQ-9 rises by >= phq9_spike_trigger,
               OR util_7d rises by >= utilization_spike_trigger,
               OR medication coverage starts (ad_covered 0 today, 1 later).

  Raises KeyError when `daily` has no "pdc_*" column.
  """
  if mh_pattern is None:
    mh_pattern = r"depress|anxiet|bipolar|schizo|psych|suicid|ptsd|panic|ocd|substance|addict"
  if severe_pattern is None:
    severe_pattern = r"severe|psychosis|suicid|mania|catatonia|acute"
  horizon = np.timedelta64(int(horizon_days), 'D')
  one_day = np.timedelta64(1, 'D')

  daily_df = daily.copy()

  # Day-granularity views for encounters/conditions
  encounter_days_df = encounters.assign(date=encounters["start_time"].dt.floor("D"))[["patient_id", "date", "enc_class_bucket"]]
  condition_days_df = conditions.assign(date=conditions["start_time"].dt.floor("D"))[["patient_id", "date", "description"]]

  # --- Adherence label
  pdc_col = next((c for c in daily_df.columns if str(c).startswith("pdc_")), None)
  if pdc_col is None:
    # a bare StopIteration from next() would hide what is actually missing
    raise KeyError("daily has no 'pdc_*' column; cannot derive non_adherent_flag")
  gap_days = daily_df["ad_gap_days"].fillna(0).astype(float)
  daily_df["non_adherent_flag"] = (
    (gap_days >= 1) |
    (daily_df[pdc_col].fillna(0).astype(float) < adherence_threshold)
  ).astype("int8")

  # --- Relapse anchors
  mh_conditions = condition_days_df.loc[
    condition_days_df["description"].str.contains(mh_pattern, case=False, na=False),
    ["patient_id", "date"]
  ]
  severe_conditions = condition_days_df.loc[
    condition_days_df["description"].str.contains(severe_pattern, case=False, na=False),
    ["patient_id", "date"]
  ]
  acute_encounters = encounter_days_df.loc[
    encounter_days_df["enc_class_bucket"].str.upper().isin(["EMERGENCY", "INPATIENT"]),
    ["patient_id", "date"]
  ]

  ed_with_mh_same_day  = mh_conditions.merge(acute_encounters, on=["patient_id", "date"], how="inner").drop_duplicates()
  long_medication_gap  = daily_df.loc[gap_days >= long_gap_days, ["patient_id", "date"]]
  relapse_anchor_dates = pd.concat([ed_with_mh_same_day, severe_conditions, long_medication_gap], ignore_index=True).drop_duplicates()

  # Sorted anchor days per patient, built once instead of re-filtering the
  # whole anchor table inside the patient loop.
  anchors_by_patient = {
    pid: np.sort(dates.to_numpy(dtype="datetime64[D]"))
    for pid, dates in relapse_anchor_dates.groupby("patient_id", sort=False)["date"]
  }
  no_anchors = np.array([], dtype="datetime64[D]")

  # --- y_relapse: whether a relapse anchor occurs within the horizon after each day
  daily_df["y_relapse"] = 0
  for patient_id, patient_days in daily_df.groupby("patient_id", sort=False):
    row_index    = patient_days.index
    day_dates    = patient_days["date"].to_numpy(dtype="datetime64[D]")
    anchor_dates = anchors_by_patient.get(patient_id, no_anchors)

    # first anchor strictly after each day (same-day anchors do not count)
    next_anchor_idx   = np.searchsorted(anchor_dates, day_dates + one_day, side='left')
    has_future_anchor = (next_anchor_idx < anchor_dates.size)
    within_horizon    = np.zeros_like(has_future_anchor, dtype=bool)
    valid_rows        = np.where(has_future_anchor)[0]
    if valid_rows.size:
      deltas = anchor_dates[next_anchor_idx[valid_rows]] - day_dates[valid_rows]
      within_horizon[valid_rows] = deltas <= horizon
    daily_df.loc[row_index, "y_relapse"] = (has_future_anchor & within_horizon).astype("int8")

  # --- y_det: subtle deterioration within the horizon
  daily_df["y_det"] = 0
  for patient_id, patient_days in daily_df.groupby("patient_id", sort=False):
    row_index        = patient_days.index
    day_dates        = patient_days["date"].to_numpy(dtype="datetime64[D]")
    # optional columns fall back to all-NaN, which never triggers a flag
    phq9_values      = patient_days.get("phq9", pd.Series(index=row_index, dtype=float)).to_numpy(dtype=float)
    util7_values     = patient_days.get("util_7d", pd.Series(index=row_index, dtype=float)).to_numpy(dtype=float)
    adherence_values = patient_days.get("ad_covered", pd.Series(index=row_index, dtype=float)).to_numpy(dtype=float)

    n_days                 = len(patient_days)
    future_phq9_max        = np.full(n_days, np.nan, dtype=float)
    future_util7_max       = np.zeros(n_days, dtype=float)
    future_any_med_covered = np.zeros(n_days, dtype=bool)

    for i in range(n_days):
      in_horizon = (day_dates > day_dates[i]) & (day_dates <= day_dates[i] + horizon)
      idx_future = np.where(in_horizon)[0]
      if idx_future.size:
        future_phq_vals    = phq9_values[idx_future]
        future_phq9_max[i] = np.nanmax(future_phq_vals) if np.isfinite(future_phq_vals).any() else np.nan
        future_util7_max[i] = util7_values[idx_future].max()
        future_any_med_covered[i] = (adherence_values[idx_future] == 1).any()

    future_phq9_rise   = (~np.isnan(future_phq9_max)) & (~np.isnan(phq9_values)) & ((future_phq9_max - phq9_values) >= phq9_spike_trigger)
    utilization_spike  = (future_util7_max - util7_values) >= utilization_spike_trigger
    medication_start   = (adherence_values == 0) & future_any_med_covered
    deterioration_flag = future_phq9_rise | utilization_spike | medication_start
    daily_df.loc[row_index, "y_det"] = deterioration_flag.astype("int8")

  return daily_df

## Generate Splits

In [None]:
"""
Do a stratified random sample of the population by
- stratifying patients into pos and neg groups -> sampling each -> merging
- A patient is pos if they have >= 1 eventful day (y_relapse=1 or y_det=1)
"""

# Split by patient!! Models are smart
def build_splits(model_table, seed=42, train_frac=None, val_frac=None):
  """Assign every patient to exactly one of train / val / test.

  Patients are stratified by whether they have at least one positive day
  (y_relapse or y_det); each stratum is shuffled and cut with the same
  fractions, so no patient's days are spread across splits.

  model_table: frame with "patient_id", "y_relapse" and "y_det".
  seed:        seed for the numpy random generator.
  train_frac:  share of each stratum for training (None -> TRAIN_SPLIT).
  val_frac:    share of each stratum for validation (None -> VAL_SPLIT).
               The test split receives the remainder.

  Returns a DataFrame with columns ["patient_id", "split"].
  Raises ValueError if the fractions are negative or sum to more than 1.
  """
  if train_frac is None:
    train_frac = TRAIN_SPLIT
  if val_frac is None:
    val_frac = VAL_SPLIT
  if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1 + 1e-9:
    raise ValueError(
      f"train_frac and val_frac must be >= 0 and sum to <= 1, got {train_frac} and {val_frac}"
    )

  cols = ["y_relapse", "y_det"]
  rng = np.random.default_rng(seed)

  per_patient_max = model_table.groupby("patient_id")[cols].max(min_count=1).fillna(0)
  has_pos = (per_patient_max.max(axis=1) >= 1).to_numpy()

  pos_pids = per_patient_max.index[has_pos].to_numpy()
  neg_pids = per_patient_max.index[~has_pos].to_numpy()
  rng.shuffle(pos_pids)
  rng.shuffle(neg_pids)

  def _split_group(arr):
    n = len(arr)
    n_tr = int(round(n * train_frac))
    n_val = int(round(n * val_frac))
    # rounding can overshoot on tiny groups; validation gives way first
    if n_tr + n_val > n:
      n_val = max(0, n - n_tr)
    return arr[:n_tr], arr[n_tr:n_tr+n_val], arr[n_tr+n_val:]

  pos_tr, pos_val, pos_te = _split_group(pos_pids)
  neg_tr, neg_val, neg_te = _split_group(neg_pids)

  # Merge the strata, then re-randomize so positives are not listed first
  train_ids = np.array(list(pos_tr) + list(neg_tr))
  val_ids   = np.array(list(pos_val) + list(neg_val))
  test_ids  = np.array(list(pos_te) + list(neg_te))
  rng.shuffle(train_ids)
  rng.shuffle(val_ids)
  rng.shuffle(test_ids)
  split = (
    [(pid, "train") for pid in train_ids] +
    [(pid, "val")   for pid in val_ids] +
    [(pid, "test")  for pid in test_ids]
  )
  return pd.DataFrame(split, columns=["patient_id","split"])

"""
Return per-label positive weight = num neg / num pos
"Scalar": y_relapse, y_det, non_adherent_flag
"cb": comorbidity blindspots
"""
def compute_pos_weights(model_table):
  """Positive-class weights (negatives / positives) for every binary target.

  Returns {"scalar": {...}, "cb": {...}}: "scalar" covers y_relapse, y_det and
  non_adherent_flag when present, "cb" covers every column starting with "cb_".
  A column is left out when it has no positives or no negatives.
  """
  def _ratio(column):
    as_float = column.astype("float32")
    n_pos = int((as_float == 1.0).sum())
    n_neg = int((as_float == 0.0).sum())
    if n_pos > 0 and n_neg > 0:
      return float(n_neg / n_pos)
    return None

  def _weights_for(names):
    # keep only the columns for which a ratio is defined
    weights = {}
    for name in names:
      ratio = _ratio(model_table[name])
      if ratio is not None:
        weights[name] = ratio
    return weights

  scalar_names = [
    name for name in ("y_relapse", "y_det", "non_adherent_flag")
    if name in model_table.columns
  ]
  cb_names = [name for name in model_table.columns if name.startswith("cb_")]

  scalar_weights = _weights_for(scalar_names)
  cb_weights = _weights_for(cb_names)
  return { "scalar": scalar_weights, "cb": cb_weights }

## Prepare (Run this)

In [None]:
def build_training():
  """Assemble the training artefacts from the notebook-level frames.

  Reads the globals `events`, `daily`, `encounters` and `conditions`.
  Returns (model_table, splits, posw, cookbook): the labelled per-day table with
  its code-id bags, the patient-level split assignment, the positive-class
  weights and the code vocabulary.
  """
  # factorize codes into per-day id bags, then label the daily rows
  day_code_bags, cookbook = encode_codes_factorize(events)
  labeled_days = compute_labels(daily, encounters, conditions)

  # final table: days without any event get an empty JSON bag
  model_table = (
    labeled_days
      .merge(day_code_bags, how="left", on=["patient_id", "date"])
      .assign(code_ids=lambda table: table["code_ids"].fillna("[]"))
      .sort_values(["patient_id", "date"])
  )

  # remaining gaps in numeric/bool columns become 0
  numeric_like = model_table.select_dtypes(include=["number", "bool"]).columns
  model_table[numeric_like] = model_table[numeric_like].fillna(0)
  print("Generated final model table")

  # splits + class imbalance weights
  splits = build_splits(model_table)
  posw = compute_pos_weights(model_table)
  return model_table, splits, posw, cookbook

In [None]:
# Runs the whole preparation on the frames loaded above: labelled per-day table,
# patient-level splits, positive-class weights and the code vocabulary.
model_table, splits, posw, cookbook = build_training()

Generated final model table


In [None]:
# Sanity check: list the null / non-null counts of every column that still has
# missing values (prints nothing when the table is complete).
columns_with_nulls = [
  name for name in model_table.columns
  if model_table[name].isnull().values.any()
]
for col_name in columns_with_nulls:
  print(model_table[col_name].isnull().value_counts())

In [None]:
# Size of the final table and of the code vocabulary
n_rows = len(model_table)
n_patients = model_table["patient_id"].nunique()
vocab_size = len(cookbook)
print("Rows:", n_rows, "| Patients:", n_patients, "| Vocab ids:", vocab_size)

# Which of the expected label columns made it into the table
targets_present = [c for c in ["y_relapse","y_det","non_adherent_flag"] if c in model_table.columns]
print("Targets present:", targets_present)

Rows: 732544 | Patients: 5736 | Vocab ids: 634
Targets present: ['y_relapse', 'y_det', 'non_adherent_flag']


In [None]:
# Write the training artefacts to TRAIN_DIR.
# The folder is created first: to_csv()/open() fail if it does not exist yet.
os.makedirs(TRAIN_DIR, exist_ok=True)
cookbook.to_csv(os.path.join(TRAIN_DIR, "dx_uniques.csv"), index=False)
model_table.to_csv(os.path.join(TRAIN_DIR, "model_table.csv"), index=False)
splits.to_csv(os.path.join(TRAIN_DIR, "splits.csv"), index=False)
with open(os.path.join(TRAIN_DIR, "pos_weights.json"), "w") as f:
  json.dump(posw, f)

**Reference**

In [None]:
# Schema reference: column names, then the first 20 rows of the final table.
print(model_table.columns)
print(model_table.head(20))

Index(['index', 'patient_id', 'date', 'age_years', 'EMERGENCY', 'INPATIENT',
       'OTHER', 'OUTPATIENT', 'URGENTCARE', 'WELLNESS', 'util_7d', 'phq9',
       'phq2', 'gad7', 'auditc', 'dast10', 'pregnancy_pos', 'ad_covered',
       'pdc_90', 'ad_gap_days', 'days_since_prev', 'sex_M', 'sex_F',
       'race_white', 'race_black', 'race_asian', 'race_hawaiian', 'race_other',
       'race_native', 'non_adherent_flag', 'y_relapse', 'y_det', 'code_ids'],
      dtype='object')
    index                            patient_id       date  age_years  \
0   22109  0002a287-8563-a8c2-4a4e-bea0aa749024 2025-02-27         75   
1   22110  0002a287-8563-a8c2-4a4e-bea0aa749024 2025-02-28         75   
2   22111  0002a287-8563-a8c2-4a4e-bea0aa749024 2025-03-01         75   
3   22112  0002a287-8563-a8c2-4a4e-bea0aa749024 2025-03-02         75   
4   22113  0002a287-8563-a8c2-4a4e-bea0aa749024 2025-03-03         75   
5   22114  0002a287-8563-a8c2-4a4e-bea0aa749024 2025-03-04         75   
6   22115  000