# Demographic Feature Exploration

Extracts a compact demographic feature vector **[Age, Gender, Height (cm), Weight (kg)]** per admission (`hadm_id`) for the RL agent (P-CAFE).

**Pipeline overview:**
1. Age & Gender from the `patients` / `admissions` tables of the `hosp` module.
2. Height, Weight, and BMI primarily from the outpatient `omr` table (closest record to `admittime`). BMI is used only for the availability report; it is not part of the output vector.
3. Fallback to ICU `chartevents` when `omr` data is absent.
4. Remaining NaNs in age, height, and weight are imputed with the dataset median; a missing gender defaults to 0 (M).
5. Final output saved as `demographic_features.parquet` with columns `subject_id`, `hadm_id`, `demographic_vec`.

## Section 1: Setup & Imports

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd

# Confirmation message so it is obvious the import cell ran to completion.
print("Libraries imported successfully!")

In [None]:
# Root of the local MIMIC-IV download -- change this to match your machine.
# The trailing slash matters: the table paths below are built by plain
# string concatenation onto this prefix.
base_path = '~/data/physionet.org/files/mimiciv/3.1/'

# Source tables used by this notebook, given relative to base_path.
patients_file, admissions_file, omr_file, chartevents_file = (
    base_path + relative_name
    for relative_name in (
        'hosp/patients.csv.gz',
        'hosp/admissions.csv.gz',
        'hosp/omr.csv.gz',
        'icu/chartevents.csv.gz',
    )
)

print("File paths defined.")

## Section 2: Core Demographics (Age & Gender)

In [None]:
# Build the per-admission base table: one row per hadm_id carrying the
# patient's gender and an approximate age at the time of admission.
print("Loading patient and admission data for demographic features...")

patients_df = pd.read_csv(
    patients_file, usecols=['subject_id', 'gender', 'anchor_age', 'anchor_year']
)
print(f"Patients loaded: {len(patients_df):,} rows")

admissions_df = pd.read_csv(
    admissions_file, usecols=['subject_id', 'hadm_id', 'admittime'], parse_dates=['admittime']
)
print(f"Admissions loaded: {len(admissions_df):,} rows")

# Left join so every admission is kept even when its patient row is missing;
# rows without an hadm_id are then discarded.
demo_df = pd.merge(
    admissions_df, patients_df, how='left', on='subject_id'
).dropna(subset=['hadm_id'])

# admission_age: anchor_age shifted by the number of calendar years between
#                anchor_year and the year of admittime.
# gender_encoded: M -> 0, F -> 1; any other value maps to NaN.
demo_df = demo_df.assign(
    admission_age=lambda d: d['anchor_age'] + (d['admittime'].dt.year - d['anchor_year']),
    gender_encoded=lambda d: d['gender'].map({'M': 0, 'F': 1}),
)

# Quick sanity summary of the merged table.
print(f"\nMerged demographic data: {len(demo_df):,} rows")
print(f"Age range: {demo_df['admission_age'].min():.1f} - {demo_df['admission_age'].max():.1f}")
print(f"Gender distribution: {demo_df['gender'].value_counts().to_dict()}")

## Section 3: Process OMR Table (Primary Source for Vitals)

In [None]:
# Unit conversion constants (the OMR rows used here are 'Weight (Lbs)' and
# 'Height (Inches)'; the feature vector wants kg and cm).
LBS_TO_KG    = 0.453592
INCHES_TO_CM = 2.54

# Goal: for every admission, pick the OMR measurement of each type whose
# chartdate is closest to admittime, then pivot to one row per hadm_id with
# columns hadm_id, omr_weight (kg), omr_height (cm), omr_bmi.
print("Loading and processing OMR table...")
omr_relevant = ['Weight (Lbs)', 'Height (Inches)', 'BMI (kg/m2)', 'BMI']

omr_df = pd.read_csv(
    omr_file,
    usecols=['subject_id', 'chartdate', 'result_name', 'result_value'],
    parse_dates=['chartdate']
)
omr_df = omr_df[omr_df['result_name'].isin(omr_relevant)].copy()
omr_df['result_value'] = pd.to_numeric(omr_df['result_value'], errors='coerce')

# Drop values that could not be parsed as numbers BEFORE choosing the closest
# record. Otherwise an unparseable (NaN) row that happens to be nearest to the
# admission would win the selection below and hide a valid measurement that
# is only slightly further away.
omr_df = omr_df.dropna(subset=['result_value'])

# OMR rows are keyed by patient only, so attach every admission of the same
# patient to each measurement (many-to-many on subject_id).
adm_slim = demo_df[['subject_id', 'hadm_id', 'admittime']].copy()
omr_merged = omr_df.merge(adm_slim, on='subject_id', how='inner')

# Absolute distance in time between the measurement and the admission.
# NOTE(review): because the distance is absolute, a measurement charted AFTER
# admittime can be selected -- confirm this is acceptable downstream.
omr_merged['time_diff'] = (omr_merged['admittime'] - omr_merged['chartdate']).abs()

# For each (hadm_id, result_name) keep only the closest measurement.
# A stable sort keeps equally distant records in file order, so ties are
# resolved the same way on every run.
omr_merged = omr_merged.sort_values('time_diff', kind='mergesort')
omr_closest = omr_merged.drop_duplicates(subset=['hadm_id', 'result_name'], keep='first')

# Pivot: one row per hadm_id with a column for each result_name
omr_pivot = omr_closest.pivot_table(
    index='hadm_id',
    columns='result_name',
    values='result_value',
    aggfunc='first'
).reset_index()
omr_pivot.columns.name = None

rename_map = {
    'Weight (Lbs)':  'omr_weight',
    'Height (Inches)': 'omr_height',
    'BMI (kg/m2)':   'omr_bmi',
    'BMI':           'omr_bmi_alt',
}
omr_pivot = omr_pivot.rename(columns={k: v for k, v in rename_map.items() if k in omr_pivot.columns})

# A result type that never occurs in the data produces no pivot column;
# create it as all-NaN so later cells can rely on the column existing.
for required_col in ('omr_bmi', 'omr_weight', 'omr_height'):
    if required_col not in omr_pivot.columns:
        omr_pivot[required_col] = np.nan

# Consolidate BMI columns (prefer 'BMI (kg/m2)', fallback to 'BMI')
if 'omr_bmi_alt' in omr_pivot.columns:
    omr_pivot['omr_bmi'] = omr_pivot['omr_bmi'].fillna(omr_pivot['omr_bmi_alt'])
    omr_pivot = omr_pivot.drop(columns=['omr_bmi_alt'])

# Convert units: Lbs -> Kg, Inches -> cm
omr_pivot['omr_weight'] = omr_pivot['omr_weight'] * LBS_TO_KG
omr_pivot['omr_height'] = omr_pivot['omr_height'] * INCHES_TO_CM

print(f"OMR pivot shape: {omr_pivot.shape}")
print(f"Columns: {list(omr_pivot.columns)}")

## Section 4: Process Chartevents Table (Fallback Source)

In [None]:
# Process chartevents in chunks for Weight and Height (fallback source).
# Result: ce_out with one row per hadm_id and columns ce_weight_kg, ce_height_cm.
print("\nReading chartevents in chunks for Weight and Height...")

# ItemIDs of interest
CE_WEIGHT_KG   = 226512  # Weight (Kg)
CE_WEIGHT_LBS  = 226531  # Weight (Lbs)
CE_HEIGHT_CM   = 226730  # Height (cm)
CE_HEIGHT_INCH = 226707  # Height (Inch)
ce_target_ids  = [CE_WEIGHT_KG, CE_WEIGHT_LBS, CE_HEIGHT_CM, CE_HEIGHT_INCH]

ce_chunks = []
chunk_number = 0

# Stream the file in chunks and keep only rows for the four target items
# that have both an admission id and a numeric value. charttime is read so
# the earliest measurement can be chosen explicitly below.
for chunk in pd.read_csv(
    chartevents_file,
    usecols=['hadm_id', 'charttime', 'itemid', 'valuenum'],
    chunksize=1000000
):
    chunk_number += 1
    filtered = chunk[chunk['itemid'].isin(ce_target_ids)].dropna(subset=['hadm_id', 'valuenum'])
    if len(filtered) > 0:
        ce_chunks.append(filtered)

print(f"Chunks processed: {chunk_number}")

if ce_chunks:
    ce_df = pd.concat(ce_chunks, ignore_index=True)
else:
    # pd.concat raises ValueError on an empty list. With no matching rows the
    # fallback source is simply empty, so continue with a typed empty frame.
    ce_df = pd.DataFrame({
        'hadm_id':   pd.Series(dtype='int64'),
        'charttime': pd.Series(dtype='object'),
        'itemid':    pd.Series(dtype='int64'),
        'valuenum':  pd.Series(dtype='float64'),
    })
ce_df['hadm_id'] = ce_df['hadm_id'].astype('int64')
ce_df['charttime'] = pd.to_datetime(ce_df['charttime'])

# Earliest non-null measurement per (hadm_id, itemid). Sorting by charttime
# makes "first" mean chronologically first instead of depending on the row
# order of the CSV; the stable sort keeps file order for identical timestamps.
ce_first = (
    ce_df.sort_values('charttime', kind='mergesort')
         .drop_duplicates(subset=['hadm_id', 'itemid'], keep='first')
         [['hadm_id', 'itemid', 'valuenum']]
         .reset_index(drop=True)
)

# One row per hadm_id and one 'item_<itemid>' column per target item.
# Building the columns item by item guarantees all four exist (as NaN when an
# item never occurs) and also works when ce_first is empty.
ce_pivot = pd.DataFrame({'hadm_id': np.sort(ce_first['hadm_id'].unique())})
for item_id in ce_target_ids:
    item_values = (
        ce_first.loc[ce_first['itemid'] == item_id, ['hadm_id', 'valuenum']]
                .rename(columns={'valuenum': f'item_{item_id}'})
    )
    ce_pivot = ce_pivot.merge(item_values, on='hadm_id', how='left')

# Build ce_weight_kg and ce_height_cm with unit conversion:
# prefer the metric item, fall back to the converted imperial item.
wkg_col   = f'item_{CE_WEIGHT_KG}'
wlbs_col  = f'item_{CE_WEIGHT_LBS}'
hcm_col   = f'item_{CE_HEIGHT_CM}'
hinch_col = f'item_{CE_HEIGHT_INCH}'

ce_pivot['ce_weight_kg'] = ce_pivot[wkg_col].fillna(ce_pivot[wlbs_col] * LBS_TO_KG)
ce_pivot['ce_height_cm'] = ce_pivot[hcm_col].fillna(ce_pivot[hinch_col] * INCHES_TO_CM)

# NOTE(review): values are taken as charted; zero or otherwise implausible
# heights/weights are not filtered here -- confirm whether they should be.
ce_out = ce_pivot[['hadm_id', 'ce_weight_kg', 'ce_height_cm']].copy()
print(f"Chartevents fallback shape: {ce_out.shape}")

## Section 5: Consolidate & Fallback Logic

In [None]:
# Assemble one row per admission from the three sources, resolve the final
# weight / height / BMI values, then fill whatever is still missing.
print("\nMerging core demographics, OMR and chartevents data...")

full_df = (
    demo_df[['subject_id', 'hadm_id', 'admittime', 'admission_age', 'gender_encoded']]
    .copy()
    .merge(omr_pivot, on='hadm_id', how='left')
    .merge(ce_out, on='hadm_id', how='left')
)

# OMR has priority; the chartevents value is used only where OMR has none.
full_df['final_weight_kg'] = full_df['omr_weight'].where(
    full_df['omr_weight'].notna(), full_df['ce_weight_kg']
)
full_df['final_height_cm'] = full_df['omr_height'].where(
    full_df['omr_height'].notna(), full_df['ce_height_cm']
)
full_df['final_bmi'] = full_df['omr_bmi'].copy()

# Rows that lack a recorded BMI but have a weight and a strictly positive
# height get BMI = kg / m^2 computed from the resolved values.
missing_bmi_mask = (
    full_df['final_bmi'].isna()
    & full_df[['final_weight_kg', 'final_height_cm']].notna().all(axis=1)
    & full_df['final_height_cm'].gt(0)
)
full_df['final_bmi'] = full_df['final_bmi'].mask(
    missing_bmi_mask,
    full_df['final_weight_kg'] / (full_df['final_height_cm'] / 100) ** 2,
)

# Medians are taken over all admissions before any gap is filled.
weight_median = full_df['final_weight_kg'].median()
height_median = full_df['final_height_cm'].median()
age_median    = full_df['admission_age'].median()

# Fill the remaining gaps; gender falls back to 0 (M) when unknown.
for column_name, fill_value in (
    ('final_weight_kg', weight_median),
    ('final_height_cm', height_median),
    ('admission_age', age_median),
    ('gender_encoded', 0),
):
    full_df[column_name] = full_df[column_name].fillna(fill_value)

print(f"Consolidated dataframe shape: {full_df.shape}")

## Section 6: Vector Creation & Output

In [None]:
# Report how often weight / height / BMI were actually observed:
# from OMR alone versus after adding the chartevents fallback.
total_adm = len(full_df)

# final_weight_kg / final_height_cm have already been median-imputed in the
# consolidation cell, so counting their non-null entries would always report
# 100%. Rebuild the "observed after fallback" flags from the source columns.
has_weight = full_df['omr_weight'].notna() | full_df['ce_weight_kg'].notna()
has_height = full_df['omr_height'].notna() | full_df['ce_height_cm'].notna()

omr_weight_pct   = full_df['omr_weight'].notna().sum() / total_adm * 100
omr_height_pct   = full_df['omr_height'].notna().sum() / total_adm * 100
omr_bmi_pct      = full_df['omr_bmi'].notna().sum()    / total_adm * 100
final_weight_pct = has_weight.sum()                    / total_adm * 100
final_height_pct = has_height.sum()                    / total_adm * 100
# final_bmi is never imputed, so its non-null count is already the true
# availability (recorded in OMR or derived from observed weight and height).
final_bmi_pct    = full_df['final_bmi'].notna().sum()  / total_adm * 100

print("=" * 80)
print("WEIGHT / HEIGHT / BMI AVAILABILITY")
print("=" * 80)
print(f"{'Metric':<20} {'OMR only':>12} {'After fallback':>16}")
print("-" * 52)
print(f"{'Weight':<20} {omr_weight_pct:>11.1f}% {final_weight_pct:>15.1f}%")
print(f"{'Height':<20} {omr_height_pct:>11.1f}% {final_height_pct:>15.1f}%")
print(f"{'BMI':<20} {omr_bmi_pct:>11.1f}% {final_bmi_pct:>15.1f}%")
print("=" * 80)

In [None]:
# Create demographic_vec column: [age, gender, final_height_cm, final_weight_kg]
print("\nCreating demographic feature vectors...")

# Shape (n_admissions, 4); each row becomes one admission's feature vector.
stack = np.column_stack([
    full_df['admission_age'].values,
    full_df['gender_encoded'].values,
    full_df['final_height_cm'].values,
    full_df['final_weight_kg'].values,
])
full_df['demographic_vec'] = list(stack)

# Save final dataframe with ONLY subject_id, hadm_id, and demographic_vec
output_df = full_df[['subject_id', 'hadm_id', 'demographic_vec']].copy()

output_file = '../data/output/demographic_features.parquet'
# to_parquet does not create missing parent directories, so make sure the
# output folder exists before writing.
Path(output_file).parent.mkdir(parents=True, exist_ok=True)
output_df.to_parquet(output_file, index=False)

print(f"Demographic features saved to: {output_file}")
print(f"Output shape: {output_df.shape}")
print(f"\nFirst 5 rows:")
print(output_df.head())

# Verify vector format (skipped when there are no rows, where iloc[0] would raise)
if output_df.empty:
    print("\nNo admissions in output; skipping sample vector check.")
else:
    sample_vec = output_df['demographic_vec'].iloc[0]
    print(f"\nSample demographic_vec: {sample_vec}")
    print(f"Vector format: [Age={sample_vec[0]:.1f}, Gender={sample_vec[1]}, Height={sample_vec[2]:.1f}cm, Weight={sample_vec[3]:.1f}kg]")