In [1]:
# ============================================================
# COLAB_SETUP.py
#
# *** RUN THIS CELL FIRST IN YOUR COLAB NOTEBOOK ***
#
# This cell does 4 things:
#   1. Mounts your Google Drive
#   2. Copies all pipeline .py files from Drive → /content/sepsis_rl_pipeline/
#   3. Adds /content/sepsis_rl_pipeline to sys.path
#   4. Verifies everything is in place
#
# HOW TO USE:
#   - Upload all pipeline .py files to your Google Drive at:
#       My Drive/sepsis_rl_pipeline/
#   - Then run this cell once at the start of every Colab session
# ============================================================

import os
import sys
import shutil
from google.colab import drive

# ── Step 1: Mount Google Drive ────────────────────────────────
print("Step 1: Mounting Google Drive...")
drive.mount("/content/drive", force_remount=False)
print("  ✓ Drive mounted\n")

# ── Step 2: Define paths ──────────────────────────────────────
# Where your pipeline files live ON DRIVE
DRIVE_PIPELINE_DIR = "/content/drive/MyDrive/sepsis_rl_pipeline"

# Where Colab will run them from (local, fast, in-memory)
LOCAL_PIPELINE_DIR = "/content/sepsis_rl_pipeline"

# ── Step 3: Copy files from Drive → /content/ ────────────────
print("Step 2: Copying pipeline files to /content/sepsis_rl_pipeline/ ...")
os.makedirs(LOCAL_PIPELINE_DIR, exist_ok=True)

# Check Drive folder exists
if not os.path.exists(DRIVE_PIPELINE_DIR):
    raise FileNotFoundError(
        f"\nCould not find pipeline files on Drive at:\n"
        f"  {DRIVE_PIPELINE_DIR}\n\n"
        "Please upload all pipeline .py files to:\n"
        "  My Drive/sepsis_rl_pipeline/\n"
        "then re-run this cell."
    )

# Copy every .py file
copied = []
for fname in os.listdir(DRIVE_PIPELINE_DIR):
    if fname.endswith(".py"):
        src = os.path.join(DRIVE_PIPELINE_DIR, fname)
        dst = os.path.join(LOCAL_PIPELINE_DIR, fname)
        shutil.copy2(src, dst)
        copied.append(fname)
        print(f"  ✓ {fname}")

if not copied:
    raise FileNotFoundError(
        f"No .py files found in {DRIVE_PIPELINE_DIR}.\n"
        "Make sure you uploaded the pipeline files to Drive first."
    )

# ── Step 4: Add to sys.path ───────────────────────────────────
print("\nStep 3: Adding to sys.path...")
if LOCAL_PIPELINE_DIR not in sys.path:
    sys.path.insert(0, LOCAL_PIPELINE_DIR)
os.chdir(LOCAL_PIPELINE_DIR)
print(f"  ✓ sys.path updated")
print(f"  ✓ Working directory set to {LOCAL_PIPELINE_DIR}")

# ── Step 5: Verify config is importable ──────────────────────
print("\nStep 4: Verifying setup...")
try:
    import config
    print(f"  ✓ config.py loaded  (GCP_PROJECT_ID = '{config.GCP_PROJECT_ID}')")
except ImportError as e:
    print(f"  ✗ Could not import config.py: {e}")

print("\n" + "="*55)
print("✓ Setup complete — you can now run the pipeline steps.")
print("="*55)
print("""
Next steps (run each in its own cell):

  exec(open("00_auth_drive.py").read())   # Mount Drive + create folders
  exec(open("01_auth_gcp.py").read())     # Authenticate to BigQuery
  exec(open("02_extract_data.py").read()) # Extract data from BigQuery
  exec(open("03_process_data.py").read()) # Process features + rewards
  exec(open("04_model.py").read())        # Train RL model
  exec(open("05_save_outputs.py").read()) # Save plots + report
""")

Step 1: Mounting Google Drive...
Mounted at /content/drive
  ✓ Drive mounted

Step 2: Copying pipeline files to /content/sepsis_rl_pipeline/ ...
  ✓ config.py

Step 3: Adding to sys.path...
  ✓ sys.path updated
  ✓ Working directory set to /content/sepsis_rl_pipeline

Step 4: Verifying setup...
  ✓ config.py loaded  (GCP_PROJECT_ID = 'silken-physics-467815-g5')

✓ Setup complete — you can now run the pipeline steps.

Next steps (run each in its own cell):

  exec(open("00_auth_drive.py").read())   # Mount Drive + create folders
  exec(open("01_auth_gcp.py").read())     # Authenticate to BigQuery
  exec(open("02_extract_data.py").read()) # Extract data from BigQuery
  exec(open("03_process_data.py").read()) # Process features + rewards
  exec(open("04_model.py").read())        # Train RL model
  exec(open("05_save_outputs.py").read()) # Save plots + report



In [2]:
# ============================================================
# config.py  —  Central configuration for the sepsis RL pipeline
#
# Running in Google Colab — no credential files needed.
# Authentication is handled by google.colab.auth (browser popup).
#
# *** The only value you must change before running: ***
#     GCP_PROJECT_ID  (line 14)
# ============================================================

# ----------------------------------------------------------------
# GCP / BigQuery
# ----------------------------------------------------------------
GCP_PROJECT_ID = "silken-physics-467815-g5"   # ← CHANGE THIS
MIMIC_DATASET  = "physionet-data.mimiciii_clinical"

# No key files or OAuth JSON needed — Colab authenticates as
# your logged-in Google account via a one-time browser popup.

# ----------------------------------------------------------------
# Drive folder layout
# All paths are relative to your Drive root ("My Drive")
# ----------------------------------------------------------------
DRIVE_BASE_FOLDER   = "sepsis_rl"          # top-level folder
DRIVE_RAW_FOLDER    = "raw"                # raw BigQuery extracts
DRIVE_PROC_FOLDER   = "processed"          # windowed / feature-engineered data
DRIVE_OUTPUT_FOLDER = "outputs"            # model artefacts + metrics

# ----------------------------------------------------------------
# Pipeline parameters
# ----------------------------------------------------------------
WINDOW_HOURS      = 4          # length of each time window
MAX_WINDOWS       = 20         # max windows per stay  (= 80 hours)
N_FLUID_BINS      = 5          # discrete fluid action levels  (0–4)
N_VASO_BINS       = 5          # discrete vasopressor levels   (0–4)
TERMINAL_REWARD   = 15.0       # ±15 reward at end of episode
SOFA_PENALTY      = 0.5        # weight on SOFA-change shaping reward
MORTALITY_HORIZON = 90         # days for primary outcome (90-day mortality)

# ----------------------------------------------------------------
# File names written to Drive
# ----------------------------------------------------------------
RAW_COHORT_FILE   = "cohort.parquet"
RAW_FEATURES_FILE = "features.parquet"
RAW_ACTIONS_FILE  = "actions_raw.parquet"
RAW_MORTALITY_FILE= "mortality.parquet"

PROC_STATES_FILE  = "states.parquet"
PROC_ACTIONS_FILE = "actions_binned.parquet"
PROC_DATASET_FILE = "sepsis_rl_dataset.parquet"

OUTPUT_MODEL_FILE   = "model.pt"
OUTPUT_METRICS_FILE = "metrics.json"
OUTPUT_PLOTS_DIR    = "plots"

In [3]:
# ============================================================
# 00_auth_drive.py  —  Google Drive auth for Google Colab
#
# Uses Colab's built-in drive.mount() — no credential files,
# no OAuth2 JSON, no token files. Just a one-click browser prompt.
#
# Drive is mounted as a local filesystem at /content/drive/
# so all file operations use standard Python file I/O —
# no upload/download API calls needed.
# ============================================================

import sys, os

# ── Colab path bootstrap ─────────────────────────────────────
# Tells Python where to find config.py and other modules.
# Update PIPELINE_DIR if your files are in a different location.
PIPELINE_DIR = "/content/sepsis_rl_pipeline"
if PIPELINE_DIR not in sys.path:
    sys.path.insert(0, PIPELINE_DIR)
os.chdir(PIPELINE_DIR)  # also set working directory
# ─────────────────────────────────────────────────────────────

from google.colab import drive

# Where Colab mounts your Drive
DRIVE_MOUNT_POINT = "/content/drive/MyDrive"


# ============================================================
# MOUNT DRIVE
# ============================================================

def mount_drive():
    """
    Mount Google Drive into the Colab filesystem.
    Shows a one-time browser prompt asking you to allow access.
    After that, your Drive is accessible like a normal folder.
    """
    drive.mount("/content/drive", force_remount=False)
    print(f"  Google Drive mounted at /content/drive")
    print(f"  Your files are at: {DRIVE_MOUNT_POINT}")


# ============================================================
# FOLDER HELPERS
# (Just local filesystem ops — Drive looks like a normal folder)
# ============================================================

def get_folder_path(relative_path: str) -> str:
    """
    Build an absolute path inside your Drive.
    e.g. get_folder_path("sepsis_rl/raw")
         → "/content/drive/MyDrive/sepsis_rl/raw"
    """
    return os.path.join(DRIVE_MOUNT_POINT, relative_path)


def setup_drive_folders() -> dict:
    """
    Create the pipeline folder structure inside Drive if it doesn't exist.
    Returns a dict of { name: absolute_path } for each folder.
    """
    import config

    folders = {
        "base":      get_folder_path(config.DRIVE_BASE_FOLDER),
        "raw":       get_folder_path(f"{config.DRIVE_BASE_FOLDER}/{config.DRIVE_RAW_FOLDER}"),
        "processed": get_folder_path(f"{config.DRIVE_BASE_FOLDER}/{config.DRIVE_PROC_FOLDER}"),
        "outputs":   get_folder_path(f"{config.DRIVE_BASE_FOLDER}/{config.DRIVE_OUTPUT_FOLDER}"),
    }

    for name, path in folders.items():
        os.makedirs(path, exist_ok=True)
        print(f"  Ready: {path}")

    return folders


# ============================================================
# FILE HELPERS
# (Simple wrappers so other scripts stay consistent)
# ============================================================

def save_parquet(df, folder_path: str, filename: str):
    """Save a DataFrame as parquet directly into a Drive folder."""
    path = os.path.join(folder_path, filename)
    df.to_parquet(path, index=False)
    print(f"  Saved → {path}")
    return path


def load_parquet(folder_path: str, filename: str):
    """Load a parquet file directly from a Drive folder."""
    import pandas as pd
    path = os.path.join(folder_path, filename)
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"File not found: {path}\n"
            "Make sure the previous pipeline step has been run."
        )
    df = pd.read_parquet(path)
    print(f"  Loaded {len(df):,} rows ← {path}")
    return df


def save_json(data: dict, folder_path: str, filename: str):
    """Save a dict as JSON directly into a Drive folder."""
    import json
    path = os.path.join(folder_path, filename)
    with open(path, "w") as fh:
        json.dump(data, fh, indent=2)
    print(f"  Saved → {path}")
    return path


def load_json(folder_path: str, filename: str) -> dict:
    """Load a JSON file directly from a Drive folder."""
    import json
    path = os.path.join(folder_path, filename)
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")
    with open(path) as fh:
        return json.load(fh)


def save_file(src_path: str, folder_path: str):
    """Copy any file (e.g. .pt, .png) into a Drive folder."""
    import shutil
    filename = os.path.basename(src_path)
    dst_path = os.path.join(folder_path, filename)
    shutil.copy2(src_path, dst_path)
    print(f"  Saved → {dst_path}")
    return dst_path


# ============================================================
# ENTRY POINT
# ============================================================

if __name__ == "__main__":
    print("=" * 55)
    print("STEP 0: Google Drive Mount & Folder Setup")
    print("=" * 55)

    mount_drive()
    folders = setup_drive_folders()

    print("\nDrive folder paths:")
    for k, v in folders.items():
        print(f"  {k:12s} →  {v}")

    print("\nDrive setup complete.")

STEP 0: Google Drive Mount & Folder Setup
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
  Google Drive mounted at /content/drive
  Your files are at: /content/drive/MyDrive
  Ready: /content/drive/MyDrive/sepsis_rl
  Ready: /content/drive/MyDrive/sepsis_rl/raw
  Ready: /content/drive/MyDrive/sepsis_rl/processed
  Ready: /content/drive/MyDrive/sepsis_rl/outputs

Drive folder paths:
  base         →  /content/drive/MyDrive/sepsis_rl
  raw          →  /content/drive/MyDrive/sepsis_rl/raw
  processed    →  /content/drive/MyDrive/sepsis_rl/processed
  outputs      →  /content/drive/MyDrive/sepsis_rl/outputs

Drive setup complete.


In [4]:
# ============================================================
# 01_auth_gcp.py  —  GCP / BigQuery auth for Google Colab
#
# Uses google.colab.auth.authenticate_user() — no service account
# key file, no JSON credentials on disk.
#
# Colab authenticates as YOUR Google account (the one logged
# into the notebook), so you just click "Allow" in a popup
# once per session. That's it.
#
# Your account needs these IAM roles on the GCP project:
#   - BigQuery Data Viewer
#   - BigQuery Job User
#
# Public helpers used by other steps:
#   get_bq_client()            →  google.cloud.bigquery.Client
#   run_query(client, sql, params)  →  pd.DataFrame
# ============================================================

import sys, os

# ── Colab path bootstrap ─────────────────────────────────────
# The pipeline files must be in /content/sepsis_rl_pipeline/
# Run the SETUP CELL in the Colab notebook first — it copies
# everything from Drive into this location automatically.
PIPELINE_DIR = "/content/sepsis_rl_pipeline"
os.makedirs(PIPELINE_DIR, exist_ok=True)   # safe to call if already exists
if PIPELINE_DIR not in sys.path:
    sys.path.insert(0, PIPELINE_DIR)
os.chdir(PIPELINE_DIR)
# ─────────────────────────────────────────────────────────────


from google.colab  import auth
from google.cloud  import bigquery
import google.auth

import config


# ----------------------------------------------------------------
# Authentication
# ----------------------------------------------------------------

def authenticate_gcp():
    """
    Trigger Colab's built-in GCP authentication popup.
    Safe to call multiple times — only prompts once per session.
    """
    auth.authenticate_user()
    print("  GCP authentication complete (logged in as your Google account).")


# ----------------------------------------------------------------
# Client factory
# ----------------------------------------------------------------

def get_bq_client() -> bigquery.Client:
    """
    Return an authenticated BigQuery client using Colab credentials.
    Call authenticate_gcp() before this if not already done.
    """
    credentials, _ = google.auth.default(
        scopes=["https://www.googleapis.com/auth/cloud-platform"]
    )
    client = bigquery.Client(
        project=config.GCP_PROJECT_ID,
        credentials=credentials,
    )
    print(f"  BigQuery client ready  (project: {config.GCP_PROJECT_ID})")
    return client


# ----------------------------------------------------------------
# Query helper
# ----------------------------------------------------------------

def run_query(
    client: bigquery.Client,
    sql: str,
    params: list | None = None,
) -> "pd.DataFrame":
    """
    Execute a BigQuery SQL string and return results as a DataFrame.

    Parameters
    ----------
    client : bigquery.Client
    sql    : str   — Standard SQL (may contain @param placeholders)
    params : list  — Optional list of bigquery.*QueryParameter objects
    """
    job_config = bigquery.QueryJobConfig(query_parameters=params or [])
    job        = client.query(sql, job_config=job_config)
    return job.result().to_dataframe()


# ----------------------------------------------------------------
# Validation
# ----------------------------------------------------------------

def validate_bq_access(client: bigquery.Client):
    """
    Run a cheap COUNT query to confirm MIMIC-III access.
    """
    test_sql = f"""
        SELECT COUNT(*) AS n_stays
        FROM `{config.MIMIC_DATASET}.icustays`
        LIMIT 1
    """
    try:
        df = run_query(client, test_sql)
        n  = df["n_stays"].iloc[0]
        print(f"  MIMIC-III access confirmed — {n:,} ICU stays found.")
    except Exception as exc:
        print(
            f"\n  ERROR: Could not query MIMIC-III.\n"
            f"  Reason: {exc}\n\n"
            "  Checklist:\n"
            "    1. GCP_PROJECT_ID is set correctly in config.py\n"
            "    2. Your Google account has BigQuery Data Viewer +\n"
            "       BigQuery Job User roles on the project\n"
            "    3. Your GCP project is linked to PhysioNet:\n"
            "       https://physionet.org/settings/credentialing/\n"
            "    4. BigQuery API is enabled in your GCP project\n"
        )
        raise


# ----------------------------------------------------------------
# Entry point
# ----------------------------------------------------------------

if __name__ == "__main__":
    print("=" * 55)
    print("STEP 1: GCP / BigQuery Authentication (Colab)")
    print("=" * 55)

    authenticate_gcp()
    client = get_bq_client()
    validate_bq_access(client)

    print("\nReady to query BigQuery.")

STEP 1: GCP / BigQuery Authentication (Colab)
  GCP authentication complete (logged in as your Google account).
  BigQuery client ready  (project: silken-physics-467815-g5)
  MIMIC-III access confirmed — 61,532 ICU stays found.

Ready to query BigQuery.


**EXTRACT DATA**

In [5]:
# ============================================================
# 02_extract_data.py  —  Extract MIMIC-III data from BigQuery
#                         and save raw parquet files to Drive
#
# SELF-CONTAINED: no imports from other pipeline files needed.
# Just paste this entire file into a Colab cell and run it.
#
# Outputs written to Drive → My Drive/sepsis_rl/raw/
#   cohort.parquet
#   features.parquet
#   actions_raw.parquet
#   mortality.parquet
# ============================================================

import os
import pandas as pd
from google.colab import drive, auth
from google.cloud  import bigquery
import google.auth

# ============================================================
# ★  CONFIG — only thing you need to change
# ============================================================
GCP_PROJECT_ID = "silken-physics-467815-g5"       # ← CHANGE THIS
MIMIC_DATASET  = "physionet-data.mimiciii_clinical"

DRIVE_BASE     = "/content/drive/MyDrive/sepsis_rl"
RAW_DIR        = os.path.join(DRIVE_BASE, "raw")

# ============================================================
# AUTH — Mount Drive + Authenticate GCP
# ============================================================
print("="*55)
print("STEP 2: Data Extraction → Google Drive")
print("="*55)

print("\n[Auth] Mounting Drive...")
drive.mount("/content/drive", force_remount=False)

print("[Auth] Authenticating to GCP...")
auth.authenticate_user()
credentials, _ = google.auth.default(
    scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
client = bigquery.Client(project=GCP_PROJECT_ID, credentials=credentials)
print(f"  ✓ BigQuery client ready  (project: {GCP_PROJECT_ID})")

os.makedirs(RAW_DIR, exist_ok=True)
print(f"  ✓ Output folder ready: {RAW_DIR}")

# ============================================================
# QUERY HELPER
# ============================================================
def run_query(sql, params=None):
    job_config = bigquery.QueryJobConfig(query_parameters=params or [])
    return client.query(sql, job_config=job_config).result().to_dataframe()

# ============================================================
# SQL DEFINITIONS
# ============================================================
COHORT_SQL = f"""
WITH
icu_adults AS (
  SELECT ie.subject_id, ie.hadm_id, ie.icustay_id,
    ie.intime, ie.outtime,
    DATETIME_DIFF(ie.outtime, ie.intime, HOUR) AS los_hours,
    DATETIME_DIFF(ie.intime, p.dob, YEAR)      AS age
  FROM `{MIMIC_DATASET}.icustays` ie
  JOIN `{MIMIC_DATASET}.patients` p USING (subject_id)
  WHERE DATETIME_DIFF(ie.intime, p.dob, YEAR) >= 18
    AND DATETIME_DIFF(ie.outtime, ie.intime, HOUR) BETWEEN 12 AND 240
),
abx AS (
  SELECT DISTINCT hadm_id, startdate AS abx_date
  FROM `{MIMIC_DATASET}.prescriptions`
  WHERE LOWER(drug) IN (
    'vancomycin','piperacillin','meropenem','ceftriaxone',
    'levofloxacin','metronidazole','ciprofloxacin','ampicillin',
    'cefepime','azithromycin','fluconazole','micafungin'
  )
),
cultures AS (
  SELECT DISTINCT hadm_id, DATE(charttime) AS cult_date
  FROM `{MIMIC_DATASET}.microbiologyevents`
),
suspected_infection AS (
  SELECT a.hadm_id FROM abx a
  JOIN cultures c USING (hadm_id)
  WHERE ABS(DATE_DIFF(a.abx_date, c.cult_date, DAY)) <= 1
)
SELECT ia.* FROM icu_adults ia
JOIN suspected_infection si USING (hadm_id)
ORDER BY subject_id, intime
"""

VITALS_SQL = """
SELECT ce.icustay_id, ce.charttime,
  CASE
    WHEN ce.itemid IN (211,220045)                                  THEN 'heart_rate'
    WHEN ce.itemid IN (51,442,455,6701,220179,220050)               THEN 'sysbp'
    WHEN ce.itemid IN (8368,8440,8441,8555,220180,220051)           THEN 'diasbp'
    WHEN ce.itemid IN (456,52,6702,443,220052,220181,225312)        THEN 'meanbp'
    WHEN ce.itemid IN (618,615,220210,224690)                       THEN 'resp_rate'
    WHEN ce.itemid IN (223761,678)                                  THEN 'temp_f'
    WHEN ce.itemid IN (223762,676)                                  THEN 'temp_c'
    WHEN ce.itemid IN (646,220277)                                  THEN 'spo2'
    WHEN ce.itemid IN (807,811,1529,3745,3744,225664,220621,226537) THEN 'glucose'
    WHEN ce.itemid = 226730                                         THEN 'weight_kg'
  END AS feature,
  ce.valuenum
FROM `{dataset}.chartevents` ce
WHERE ce.icustay_id IN UNNEST(@icustay_ids)
  AND ce.error IS DISTINCT FROM 1
  AND ce.itemid IN (
    211,220045,51,442,455,6701,220179,220050,
    8368,8440,8441,8555,220180,220051,
    456,52,6702,443,220052,220181,225312,
    618,615,220210,224690,223761,678,223762,676,
    646,220277,807,811,1529,3745,3744,225664,220621,226537,226730
  )
  AND ce.valuenum IS NOT NULL AND ce.valuenum > 0
"""

LABS_SQL = """
SELECT le.hadm_id, le.charttime,
  CASE
    WHEN le.itemid = 50912 THEN 'creatinine'   WHEN le.itemid = 50902 THEN 'chloride'
    WHEN le.itemid = 50882 THEN 'bicarbonate'  WHEN le.itemid = 50893 THEN 'calcium'
    WHEN le.itemid = 50960 THEN 'magnesium'    WHEN le.itemid = 50983 THEN 'sodium'
    WHEN le.itemid = 50971 THEN 'potassium'    WHEN le.itemid = 51006 THEN 'bun'
    WHEN le.itemid = 51221 THEN 'hematocrit'   WHEN le.itemid = 51222 THEN 'hemoglobin'
    WHEN le.itemid = 51265 THEN 'platelets'    WHEN le.itemid = 51301 THEN 'wbc'
    WHEN le.itemid = 50813 THEN 'lactate'      WHEN le.itemid = 50820 THEN 'ph'
    WHEN le.itemid = 50821 THEN 'pao2'         WHEN le.itemid = 50818 THEN 'paco2'
    WHEN le.itemid = 50811 THEN 'base_excess'  WHEN le.itemid = 50861 THEN 'alt'
    WHEN le.itemid = 50878 THEN 'ast'          WHEN le.itemid = 50863 THEN 'alp'
    WHEN le.itemid = 50885 THEN 'bilirubin_total' WHEN le.itemid = 51275 THEN 'ptt'
    WHEN le.itemid = 51237 THEN 'inr'          WHEN le.itemid = 50889 THEN 'crp'
    WHEN le.itemid = 50931 THEN 'glucose_lab'  WHEN le.itemid = 51484 THEN 'bands'
  END AS feature,
  le.valuenum
FROM `{dataset}.labevents` le
WHERE le.hadm_id IN UNNEST(@hadm_ids)
  AND le.valuenum IS NOT NULL
  AND le.itemid IN (
    50912,50902,50882,50893,50960,50983,50971,51006,
    51221,51222,51265,51301,50813,50820,50821,50818,
    50811,50861,50878,50863,50885,51275,51237,50889,50931,51484
  )
"""

GCS_SQL = """
SELECT icustay_id, charttime, 'gcs' AS feature,
       CAST(valuenum AS FLOAT64) AS valuenum
FROM `{dataset}.chartevents`
WHERE icustay_id IN UNNEST(@icustay_ids)
  AND itemid IN (198, 226755, 227013)
  AND valuenum IS NOT NULL
"""

URINE_SQL = """
SELECT icustay_id, charttime, 'urine_output' AS feature, value AS valuenum
FROM `{dataset}.outputevents`
WHERE icustay_id IN UNNEST(@icustay_ids)
  AND itemid IN (
    40055,43175,40069,40094,40715,40473,40085,40057,40056,
    40405,40428,40086,40096,40651,
    226559,226560,226561,226584,226563,226564,
    226565,226567,226557,226558,227488,227489
  )
  AND value > 0
"""

ACTIONS_SQL = """
WITH
fluids AS (
  SELECT icustay_id, starttime AS charttime, 'iv_fluid' AS drug_type, amount AS dose
  FROM `{dataset}.inputevents_mv`
  WHERE icustay_id IN UNNEST(@icustay_ids)
    AND itemid IN (
      225158,225943,226089,225168,225828,225823,220862,
      220970,220864,225159,220995,225170,225825,227531,229268,227072
    )
    AND amount > 0 AND amountuom = 'mL'
),
vasopressors AS (
  SELECT icustay_id, starttime AS charttime, 'vasopressor' AS drug_type,
    CASE
      WHEN itemid IN (221906,221289) THEN rate
      WHEN itemid = 221662           THEN rate * 0.1
      WHEN itemid = 221749           THEN rate * 10
      WHEN itemid = 222315           THEN rate * 0.4
    END AS dose
  FROM `{dataset}.inputevents_mv`
  WHERE icustay_id IN UNNEST(@icustay_ids)
    AND itemid IN (221906,221289,221662,221749,222315)
    AND rate > 0
)
SELECT * FROM fluids UNION ALL SELECT * FROM vasopressors
"""

MORTALITY_SQL = """
SELECT ie.icustay_id,
  CASE WHEN p.dod IS NOT NULL
            AND p.dod <= DATETIME_ADD(ie.outtime, INTERVAL 90 DAY)
       THEN 1 ELSE 0 END AS died_90d,
  CASE WHEN p.dod IS NOT NULL
            AND p.dod <= DATETIME_ADD(ie.intime,  INTERVAL 28 DAY)
       THEN 1 ELSE 0 END AS died_28d
FROM `{dataset}.icustays` ie
JOIN `{dataset}.patients` p USING (subject_id)
WHERE ie.icustay_id IN UNNEST(@icustay_ids)
"""

# ============================================================
# EXTRACT
# ============================================================

# --- Cohort ---
print("\n[1/4] Extracting sepsis cohort...")
cohort = run_query(COHORT_SQL)
cohort.to_parquet(os.path.join(RAW_DIR, "cohort.parquet"), index=False)
print(f"  ✓ {len(cohort):,} ICU stays  →  cohort.parquet")

icustay_ids = cohort["icustay_id"].tolist()
hadm_ids    = cohort["hadm_id"].tolist()
icu_params  = [bigquery.ArrayQueryParameter("icustay_ids", "INT64", icustay_ids)]
adm_params  = [bigquery.ArrayQueryParameter("hadm_ids",    "INT64", hadm_ids)]

# --- Features ---
print("\n[2/4] Extracting clinical features...")
print("  Fetching vitals...")
vitals = run_query(VITALS_SQL.format(dataset=MIMIC_DATASET), icu_params)
vitals = vitals.merge(cohort[["icustay_id","hadm_id"]], on="icustay_id")

print("  Fetching labs...")
labs = run_query(LABS_SQL.format(dataset=MIMIC_DATASET), adm_params)
labs = labs.merge(cohort[["hadm_id","icustay_id"]], on="hadm_id")

print("  Fetching GCS...")
gcs = run_query(GCS_SQL.format(dataset=MIMIC_DATASET), icu_params)

print("  Fetching urine output...")
urine = run_query(URINE_SQL.format(dataset=MIMIC_DATASET), icu_params)
urine.rename(columns={"value": "valuenum"}, inplace=True, errors="ignore")

features = pd.concat([
    vitals[["icustay_id","charttime","feature","valuenum"]],
    labs  [["icustay_id","charttime","feature","valuenum"]],
    gcs   [["icustay_id","charttime","feature","valuenum"]],
    urine [["icustay_id","charttime","feature","valuenum"]],
], ignore_index=True)
features.dropna(subset=["feature","valuenum"], inplace=True)
features.to_parquet(os.path.join(RAW_DIR, "features.parquet"), index=False)
print(f"  ✓ {len(features):,} rows | {features['feature'].nunique()} features  →  features.parquet")

# --- Actions ---
print("\n[3/4] Extracting actions (fluids + vasopressors)...")
actions_raw = run_query(ACTIONS_SQL.format(dataset=MIMIC_DATASET), icu_params)
actions_raw.to_parquet(os.path.join(RAW_DIR, "actions_raw.parquet"), index=False)
print(f"  ✓ {len(actions_raw):,} rows  →  actions_raw.parquet")

# --- Mortality ---
print("\n[4/4] Extracting mortality outcomes...")
mortality = run_query(MORTALITY_SQL.format(dataset=MIMIC_DATASET), icu_params)
mortality.to_parquet(os.path.join(RAW_DIR, "mortality.parquet"), index=False)
print(f"  ✓ 90-day mortality: {mortality['died_90d'].mean():.1%}  →  mortality.parquet")

# ============================================================
# SUMMARY
# ============================================================
print("\n" + "="*55)
print(f"✓ All raw data saved to: {RAW_DIR}")
print("="*55)
for f in ["cohort.parquet","features.parquet","actions_raw.parquet","mortality.parquet"]:
    path = os.path.join(RAW_DIR, f)
    size_mb = os.path.getsize(path) / 1e6 if os.path.exists(path) else 0
    print(f"  {f:35s}  {size_mb:.1f} MB")
print("\nNext step → run 03_process_data.py")

STEP 2: Data Extraction → Google Drive

[Auth] Mounting Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[Auth] Authenticating to GCP...
  ✓ BigQuery client ready  (project: silken-physics-467815-g5)
  ✓ Output folder ready: /content/drive/MyDrive/sepsis_rl/raw

[1/4] Extracting sepsis cohort...
  ✓ 88,544 ICU stays  →  cohort.parquet

[2/4] Extracting clinical features...
  Fetching vitals...
  Fetching labs...
  Fetching GCS...
  Fetching urine output...
  ✓ 102,961,257 rows | 38 features  →  features.parquet

[3/4] Extracting actions (fluids + vasopressors)...
  ✓ 83,893 rows  →  actions_raw.parquet

[4/4] Extracting mortality outcomes...
  ✓ 90-day mortality: 26.9%  →  mortality.parquet

✓ All raw data saved to: /content/drive/MyDrive/sepsis_rl/raw
  cohort.parquet                       1.3 MB
  features.parquet                     403.3 MB
  actions_raw.parquet                  1.5 MB
  mortal

**PROCESS DATA**

In [6]:
# ============================================================
# 03_process_data.py  —  Feature engineering and reward assignment
#
# SELF-CONTAINED: no imports from other pipeline files needed.
# Just paste this entire file into a Colab cell and run it.
#
# Reads from Drive  → My Drive/sepsis_rl/raw/
# Writes to Drive   → My Drive/sepsis_rl/processed/
# ============================================================

import os
import numpy as np
import pandas as pd
from google.colab import drive

# ============================================================
# ★  CONFIG
# ============================================================
DRIVE_BASE    = "/content/drive/MyDrive/sepsis_rl"
RAW_DIR       = os.path.join(DRIVE_BASE, "raw")
PROC_DIR      = os.path.join(DRIVE_BASE, "processed")

WINDOW_HOURS  = 4
MAX_WINDOWS   = 20
N_FLUID_BINS  = 5
N_VASO_BINS   = 5
TERMINAL_REWARD = 15.0
SOFA_PENALTY    = 0.5

# ============================================================
# MOUNT DRIVE
# ============================================================
print("="*55)
print("STEP 3: Data Processing")
print("="*55)

drive.mount("/content/drive", force_remount=False)
os.makedirs(PROC_DIR, exist_ok=True)
print(f"  ✓ Drive mounted")
print(f"  ✓ Output folder ready: {PROC_DIR}")

# ============================================================
# LOAD RAW DATA
# ============================================================
print("\n[Loading raw data from Drive]")

def load(filename):
    path = os.path.join(RAW_DIR, filename)
    if not os.path.exists(path):
        raise FileNotFoundError(f"Not found: {path}\nRun 02_extract_data.py first.")
    df = pd.read_parquet(path)
    print(f"  ✓ {filename}  ({len(df):,} rows)")
    return df

cohort      = load("cohort.parquet")
features    = load("features.parquet")
raw_actions = load("actions_raw.parquet")
mortality   = load("mortality.parquet")

cohort["intime"]         = pd.to_datetime(cohort["intime"])
cohort["outtime"]        = pd.to_datetime(cohort["outtime"])
features["charttime"]    = pd.to_datetime(features["charttime"])

# ============================================================
# A — Create 4-hour time windows
# ============================================================
print("\n[A] Creating 4-hour time windows...")
rows = []
for _, r in cohort.iterrows():
    for w in range(MAX_WINDOWS):
        ws = r["intime"] + pd.Timedelta(hours=w * WINDOW_HOURS)
        we = ws + pd.Timedelta(hours=WINDOW_HOURS)
        if ws >= r["outtime"]:
            break
        rows.append({
            "icustay_id": r["icustay_id"], "hadm_id": r["hadm_id"],
            "subject_id": r["subject_id"], "window_id": w,
            "window_start": ws, "window_end": we,
        })
windows = pd.DataFrame(rows)
print(f"  ✓ {len(windows):,} windows for {cohort['icustay_id'].nunique():,} stays")

# ============================================================
# B — Aggregate features into windows
# ============================================================
print("\n[B] Aggregating features into windows (this may take a few minutes)...")
records = []
for iid, stay_wins in windows.groupby("icustay_id"):
    stay_feat = features[features["icustay_id"] == iid]
    for _, win in stay_wins.iterrows():
        row = {"icustay_id": iid, "window_id": win["window_id"]}
        if not stay_feat.empty:
            mask  = (stay_feat["charttime"] >= win["window_start"]) & \
                    (stay_feat["charttime"] <  win["window_end"])
            chunk = stay_feat[mask]
            if len(chunk):
                row.update(chunk.groupby("feature")["valuenum"].mean().to_dict())
        records.append(row)

states = pd.DataFrame(records)
states = (states.sort_values(["icustay_id","window_id"])
                .groupby("icustay_id", group_keys=False)
                .apply(lambda g: g.ffill()))

if "temp_f" in states.columns and "temp_c" not in states.columns:
    states["temp_c"] = (states["temp_f"] - 32) * 5 / 9

states = states.reset_index(drop=True)
states.to_parquet(os.path.join(PROC_DIR, "states.parquet"), index=False)
print(f"  ✓ States: {states.shape[0]:,} rows × {states.shape[1]} cols  →  states.parquet")

# ============================================================
# C — Discretise actions
# ============================================================
print("\n[C] Discretising actions...")

def bin_dose(series, n_bins):
    bins    = pd.Series(0, index=series.index, dtype=int)
    nonzero = series > 0
    if nonzero.sum() == 0:
        return bins
    qs = series[nonzero].quantile([i/(n_bins-1) for i in range(1, n_bins)]).values
    for level, threshold in enumerate(qs, start=1):
        bins[nonzero & (series <= threshold)] = level
    bins[nonzero & (series > qs[-1])] = n_bins - 1
    return bins

fluids = (raw_actions[raw_actions["drug_type"]=="iv_fluid"]
          .groupby(["icustay_id","charttime"])["dose"].sum().reset_index())
vasos  = (raw_actions[raw_actions["drug_type"]=="vasopressor"]
          .groupby(["icustay_id","charttime"])["dose"].sum().reset_index())

fluids["fluid_bin"] = bin_dose(fluids["dose"], N_FLUID_BINS)
vasos ["vaso_bin"]  = bin_dose(vasos ["dose"], N_VASO_BINS)

actions_binned = fluids.merge(
    vasos[["icustay_id","charttime","vaso_bin"]],
    on=["icustay_id","charttime"], how="outer"
).fillna(0)
actions_binned["fluid_bin"] = actions_binned["fluid_bin"].astype(int)
actions_binned["vaso_bin"]  = actions_binned["vaso_bin"].astype(int)
actions_binned["action"]    = actions_binned["fluid_bin"] * N_VASO_BINS + actions_binned["vaso_bin"]
actions_binned.to_parquet(os.path.join(PROC_DIR, "actions_binned.parquet"), index=False)
print(f"  ✓ {actions_binned['action'].nunique()} unique actions  →  actions_binned.parquet")

# ============================================================
# D — Assign rewards
# ============================================================
print("\n[D] Assigning rewards...")
final_df = states.merge(mortality, on="icustay_id", how="left")
last_win = final_df.groupby("icustay_id")["window_id"].transform("max")
is_term  = final_df["window_id"] == last_win
final_df["reward"] = 0.0
final_df.loc[is_term, "reward"] = np.where(
    final_df.loc[is_term, "died_90d"] == 1,
    -TERMINAL_REWARD, +TERMINAL_REWARD
)
if "sofa" in final_df.columns:
    sofa_delta = final_df.groupby("icustay_id")["sofa"].diff().fillna(0)
    final_df["reward"] -= sofa_delta * SOFA_PENALTY

final_df.to_parquet(os.path.join(PROC_DIR, "sepsis_rl_dataset.parquet"), index=False)
print(f"  ✓ Final dataset: {final_df.shape[0]:,} rows × {final_df.shape[1]} cols")
print(f"  ✓ Reward stats:\n{final_df['reward'].value_counts().sort_index()}")

# ============================================================
# SUMMARY
# ============================================================
print("\n" + "="*55)
print(f"✓ Processed data saved to: {PROC_DIR}")
print("="*55)
for f in ["states.parquet","actions_binned.parquet","sepsis_rl_dataset.parquet"]:
    path = os.path.join(PROC_DIR, f)
    size_mb = os.path.getsize(path) / 1e6 if os.path.exists(path) else 0
    print(f"  {f:40s}  {size_mb:.1f} MB")
print("\nNext step → run 04_model.py")

STEP 3: Data Processing
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
  ✓ Drive mounted
  ✓ Output folder ready: /content/drive/MyDrive/sepsis_rl/processed

[Loading raw data from Drive]
  ✓ cohort.parquet  (88,544 rows)
  ✓ features.parquet  (102,961,257 rows)
  ✓ actions_raw.parquet  (83,893 rows)
  ✓ mortality.parquet  (21,182 rows)

[A] Creating 4-hour time windows...
  ✓ 1,349,117 windows for 21,182 stays

[B] Aggregating features into windows (this may take a few minutes)...


  .apply(lambda g: g.ffill()))


  ✓ States: 1,349,117 rows × 40 cols  →  states.parquet

[C] Discretising actions...
  ✓ 1 unique actions  →  actions_binned.parquet

[D] Assigning rewards...
  ✓ Final dataset: 1,349,117 rows × 43 cols
  ✓ Reward stats:
reward
-15.0      28314
 0.0     1260573
 15.0      60230
Name: count, dtype: int64

✓ Processed data saved to: /content/drive/MyDrive/sepsis_rl/processed
  states.parquet                            15.7 MB
  actions_binned.parquet                    0.6 MB
  sepsis_rl_dataset.parquet                 15.8 MB

Next step → run 04_model.py


**MODEL**

In [None]:
# ============================================================
# 04_model.py  —  Train and evaluate the sepsis RL model
#
# Reads from Drive  →  sepsis_rl/processed/sepsis_rl_dataset.parquet
#
# Architecture: Dueling Double DQN (as per WD3QNE paper)
#   - State  : normalised clinical feature vector
#   - Action : 25 discrete actions (5 fluid bins × 5 vaso bins)
#   - Reward : ±15 terminal  +  optional SOFA shaping
#
# Outputs saved to Drive  →  sepsis_rl/outputs/
#   model.pt      (PyTorch state dict)
#   metrics.json  (training loss, eval metrics)
# ============================================================

import os
import json
import tempfile
import warnings
warnings.filterwarnings("ignore")

import numpy  as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

import config
from auth_drive import setup_drive_folders, load_parquet, save_json, save_file


# ============================================================
# MODEL ARCHITECTURE  —  Dueling Double DQN
# ============================================================

class DuelingDQN(nn.Module):
    """
    Dueling network with separate value and advantage streams.
    Inputs  : state_dim  (number of clinical features after imputation)
    Outputs : Q-values for each of n_actions discrete actions
    """

    def __init__(self, state_dim: int, n_actions: int, hidden: int = 256):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        # Value stream  V(s)
        self.value_stream = nn.Sequential(
            nn.Linear(hidden, 128), nn.ReLU(),
            nn.Linear(128, 1),
        )
        # Advantage stream  A(s, a)
        self.adv_stream = nn.Sequential(
            nn.Linear(hidden, 128), nn.ReLU(),
            nn.Linear(128, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h   = self.shared(x)
        V   = self.value_stream(h)
        A   = self.adv_stream(h)
        # Q(s,a) = V(s) + A(s,a) - mean(A(s,·))
        return V + A - A.mean(dim=-1, keepdim=True)


# ============================================================
# REPLAY BUFFER
# ============================================================

class ReplayBuffer:
    def __init__(self, capacity: int = 50_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size: int):
        batch = random.sample(self.buffer, batch_size)
        s, a, r, ns, d = zip(*batch)
        return (
            torch.FloatTensor(s),
            torch.LongTensor(a),
            torch.FloatTensor(r),
            torch.FloatTensor(ns),
            torch.FloatTensor(d),
        )

    def __len__(self):
        return len(self.buffer)


# ============================================================
# DATA PREPARATION
# ============================================================

FEATURE_COLS = [
    "heart_rate", "sysbp", "diasbp", "meanbp", "resp_rate",
    "temp_c", "spo2", "glucose", "weight_kg",
    "creatinine", "chloride", "bicarbonate", "calcium", "magnesium",
    "sodium", "potassium", "bun", "hematocrit", "hemoglobin",
    "platelets", "wbc", "lactate", "ph", "pao2", "paco2",
    "base_excess", "bilirubin_total", "ptt", "inr",
    "alt", "ast", "alp", "crp",
    "gcs", "urine_output",
]


def prepare_dataset(df: pd.DataFrame):
    """
    Returns arrays of (states, actions, rewards, next_states, dones)
    suitable for experience-replay training.
    """
    # Use only columns present in this cohort extract
    feat_cols = [c for c in FEATURE_COLS if c in df.columns]

    # Median imputation for remaining NaNs
    df[feat_cols] = df[feat_cols].fillna(df[feat_cols].median())

    # Normalise features
    scaler = StandardScaler()
    df[feat_cols] = scaler.fit_transform(df[feat_cols])

    # Sort into episodes
    df = df.sort_values(["icustay_id", "window_id"]).reset_index(drop=True)

    states, actions, rewards, next_states, dones = [], [], [], [], []

    for _, episode in df.groupby("icustay_id"):
        ep = episode.reset_index(drop=True)
        for t in range(len(ep) - 1):
            s  = ep.loc[t,     feat_cols].values.astype(np.float32)
            a  = int(ep.loc[t, "action"]) if "action" in ep.columns else 0
            r  = float(ep.loc[t, "reward"])
            ns = ep.loc[t + 1, feat_cols].values.astype(np.float32)
            d  = 1.0 if t == len(ep) - 2 else 0.0
            states.append(s);  actions.append(a)
            rewards.append(r); next_states.append(ns)
            dones.append(d)

    return (
        np.array(states),
        np.array(actions),
        np.array(rewards),
        np.array(next_states),
        np.array(dones),
        feat_cols,
        scaler,
    )


# ============================================================
# TRAINING LOOP
# ============================================================

def train(
    states, actions, rewards, next_states, dones,
    state_dim: int,
    n_actions:  int = 25,
    n_epochs:   int = 200,
    batch_size: int = 512,
    lr:         float = 1e-4,
    gamma:      float = 0.99,
    tau:        float = 0.01,
    device:     str = "cpu",
):
    """
    Offline Dueling Double DQN training from a fixed replay buffer.
    Returns (online_net, metrics_dict).
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    print(f"  Training on: {device}  |  state_dim={state_dim}  |  n_actions={n_actions}")

    online_net = DuelingDQN(state_dim, n_actions).to(device)
    target_net = DuelingDQN(state_dim, n_actions).to(device)
    target_net.load_state_dict(online_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(online_net.parameters(), lr=lr, weight_decay=1e-5)
    loss_fn   = nn.SmoothL1Loss()

    # Populate replay buffer
    buffer = ReplayBuffer(capacity=len(states))
    for i in range(len(states)):
        buffer.push(states[i], actions[i], rewards[i], next_states[i], dones[i])

    loss_history = []

    for epoch in range(1, n_epochs + 1):
        if len(buffer) < batch_size:
            break

        s, a, r, ns, d = buffer.sample(batch_size)
        s, a, r, ns, d = s.to(device), a.to(device), r.to(device), ns.to(device), d.to(device)

        # Double DQN target
        with torch.no_grad():
            best_actions = online_net(ns).argmax(dim=1)
            target_q     = target_net(ns).gather(1, best_actions.unsqueeze(1)).squeeze()
            y            = r + gamma * target_q * (1 - d)

        q_vals = online_net(s).gather(1, a.unsqueeze(1)).squeeze()
        loss   = loss_fn(q_vals, y)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(online_net.parameters(), 1.0)
        optimizer.step()

        # Soft update target network
        for p_online, p_target in zip(online_net.parameters(), target_net.parameters()):
            p_target.data.copy_(tau * p_online.data + (1 - tau) * p_target.data)

        loss_val = loss.item()
        loss_history.append(loss_val)

        if epoch % 20 == 0:
            print(f"  Epoch {epoch:>4d}/{n_epochs}  —  loss: {loss_val:.4f}")

    metrics = {
        "n_transitions":    int(len(states)),
        "state_dim":        state_dim,
        "n_actions":        n_actions,
        "final_loss":       float(loss_history[-1]) if loss_history else None,
        "mean_loss":        float(np.mean(loss_history)) if loss_history else None,
        "loss_history":     loss_history,
    }

    return online_net.cpu(), metrics


# ============================================================
# MAIN
# ============================================================

def main():
    print("=" * 55)
    print("STEP 4: Model Training (Dueling Double DQN)")
    print("=" * 55)

    folders  = setup_drive_folders()
    proc_path = folders["processed"]
    out_path  = folders["outputs"]

    # --- Load processed dataset ---
    print("\n[Loading processed dataset from Drive]")
    df = load_parquet(proc_path, config.PROC_DATASET_FILE)
    print(f"  Loaded {len(df):,} rows × {df.shape[1]} columns")

    # --- Prepare ---
    print("\n[Preparing training data]")
    states, actions, rewards, next_states, dones, feat_cols, scaler = \
        prepare_dataset(df)
    print(f"  Transitions: {len(states):,}  |  Features used: {len(feat_cols)}")

    # --- Train ---
    print("\n[Training]")
    model, metrics = train(
        states, actions, rewards, next_states, dones,
        state_dim = len(feat_cols),
    )
    metrics["feature_columns"] = feat_cols

    # --- Save model directly to Drive ---
    print("\n[Saving model and metrics to Drive]")

    import tempfile, os
    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp:
        model_path = tmp.name
    torch.save(
        {"model_state_dict": model.state_dict(), "feature_columns": feat_cols},
        model_path,
    )
    save_file(model_path, out_path)
    os.unlink(model_path)

    save_json(metrics, out_path, config.OUTPUT_METRICS_FILE)

    print(f"\nFinal loss : {metrics['final_loss']:.4f}")
    print(f"Mean loss  : {metrics['mean_loss']:.4f}")
    print("\n✓ Model saved to Drive → sepsis_rl/outputs/")


if __name__ == "__main__":
    main()

**SAVE OUTPUT**

In [None]:
# ============================================================
# 05_save_outputs.py  —  Generate evaluation plots & summary report,
#                         then upload everything to Drive outputs/
#
# Reads from Drive  →  sepsis_rl/outputs/model.pt, metrics.json
#                   →  sepsis_rl/processed/sepsis_rl_dataset.parquet
#
# Writes to Drive   →  sepsis_rl/outputs/
#   loss_curve.png
#   action_distribution.png
#   reward_distribution.png
#   summary_report.json
# ============================================================

import warnings
warnings.filterwarnings("ignore")

import numpy  as np
import pandas as pd
import torch
import matplotlib
matplotlib.use("Agg")   # non-interactive backend
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler

import config
from auth_drive import setup_drive_folders, load_parquet, load_json, save_file
from model import DuelingDQN, FEATURE_COLS


# ============================================================
# LOAD HELPERS
# ============================================================

def load_model_from_drive(folder_path: str, state_dim: int, n_actions: int):
    import os
    path = os.path.join(folder_path, config.OUTPUT_MODEL_FILE)
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"'{config.OUTPUT_MODEL_FILE}' not found at {path}. "
            "Run 04_model.py first."
        )
    checkpoint = torch.load(path, map_location="cpu")
    model = DuelingDQN(state_dim, n_actions)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    print(f"  Loaded model ← {path}")
    return model, checkpoint.get("feature_columns", [])


# ============================================================
# PLOT FUNCTIONS
# ============================================================

def plot_loss_curve(loss_history: list, save_path: str):
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(loss_history, color="#2563EB", linewidth=1.2, alpha=0.9)
    ax.set_xlabel("Training Epoch")
    ax.set_ylabel("Smooth L1 Loss")
    ax.set_title("Dueling Double DQN — Training Loss Curve")
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"  Saved loss curve → {save_path}")


def plot_action_distribution(df: pd.DataFrame, save_path: str):
    if "action" not in df.columns:
        print("  No 'action' column — skipping action distribution plot.")
        return

    n_actions = config.N_FLUID_BINS * config.N_VASO_BINS
    counts    = df["action"].value_counts().reindex(range(n_actions), fill_value=0)

    fluid_labels = [f"Fluid {i}" for i in range(config.N_FLUID_BINS)]
    vaso_labels  = [f"Vaso {j}"  for j in range(config.N_VASO_BINS)]

    matrix = counts.values.reshape(config.N_FLUID_BINS, config.N_VASO_BINS)

    fig, ax = plt.subplots(figsize=(7, 5))
    im = ax.imshow(matrix, cmap="YlOrRd")
    ax.set_xticks(range(config.N_VASO_BINS));  ax.set_xticklabels(vaso_labels)
    ax.set_yticks(range(config.N_FLUID_BINS)); ax.set_yticklabels(fluid_labels)
    ax.set_title("Action Frequency Heatmap\n(Fluid bin × Vasopressor bin)")
    plt.colorbar(im, ax=ax, label="Count")
    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"  Saved action distribution → {save_path}")


def plot_reward_distribution(df: pd.DataFrame, save_path: str):
    if "reward" not in df.columns:
        print("  No 'reward' column — skipping reward distribution plot.")
        return

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))

    # All rewards
    axes[0].hist(df["reward"], bins=30, color="#10B981", edgecolor="white")
    axes[0].set_title("All-Step Reward Distribution")
    axes[0].set_xlabel("Reward"); axes[0].set_ylabel("Count")

    # Terminal rewards only
    last_w    = df.groupby("icustay_id")["window_id"].transform("max")
    term_df   = df[df["window_id"] == last_w]
    term_vals = term_df["reward"].value_counts().sort_index()
    axes[1].bar(
        term_vals.index.astype(str), term_vals.values,
        color=["#EF4444" if v < 0 else "#10B981" for v in term_vals.index]
    )
    axes[1].set_title("Terminal Reward Distribution")
    axes[1].set_xlabel("Reward"); axes[1].set_ylabel("Stays")

    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"  Saved reward distribution → {save_path}")


def plot_q_value_sample(
    model: DuelingDQN,
    df: pd.DataFrame,
    feat_cols: list,
    save_path: str,
    n_sample: int = 500,
):
    """Plot distribution of max Q-values over a random sample of states."""
    df2 = df[feat_cols].copy().fillna(df[feat_cols].median())
    scaler = StandardScaler()
    scaled = scaler.fit_transform(df2)
    idx    = np.random.choice(len(scaled), min(n_sample, len(scaled)), replace=False)
    sample = torch.FloatTensor(scaled[idx])

    with torch.no_grad():
        q_vals = model(sample).numpy()

    max_q = q_vals.max(axis=1)
    best_a = q_vals.argmax(axis=1)

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    axes[0].hist(max_q, bins=30, color="#6366F1", edgecolor="white")
    axes[0].set_title("Max Q-Value Distribution (sample)")
    axes[0].set_xlabel("Max Q(s,a)")

    n_actions = config.N_FLUID_BINS * config.N_VASO_BINS
    counts    = np.bincount(best_a, minlength=n_actions)
    axes[1].bar(range(n_actions), counts, color="#6366F1")
    axes[1].set_title("Greedy Action Distribution (sample)")
    axes[1].set_xlabel("Action index")

    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"  Saved Q-value sample plot → {save_path}")


# ============================================================
# SUMMARY REPORT
# ============================================================

def build_summary(metrics: dict, df: pd.DataFrame) -> dict:
    n_stays = df["icustay_id"].nunique() if "icustay_id" in df.columns else None

    last_w = df.groupby("icustay_id")["window_id"].transform("max")
    term_r = df.loc[df["window_id"] == last_w, "reward"]

    return {
        "n_icu_stays":           n_stays,
        "n_transitions":         metrics.get("n_transitions"),
        "state_dim":             metrics.get("state_dim"),
        "n_actions":             metrics.get("n_actions"),
        "training_final_loss":   metrics.get("final_loss"),
        "training_mean_loss":    metrics.get("mean_loss"),
        "pct_positive_terminal": float((term_r > 0).mean()),
        "pct_negative_terminal": float((term_r < 0).mean()),
        "feature_columns":       metrics.get("feature_columns", []),
    }


# ============================================================
# MAIN
# ============================================================

def main():
    print("=" * 55)
    print("STEP 5: Save Outputs to Drive")
    print("=" * 55)

    folders   = setup_drive_folders()
    proc_path = folders["processed"]
    out_path  = folders["outputs"]

    # --- Load artefacts directly from Drive filesystem ---
    print("\n[Loading artefacts from Drive]")
    metrics = load_json(out_path,  config.OUTPUT_METRICS_FILE)
    df      = load_parquet(proc_path, config.PROC_DATASET_FILE)

    feat_cols = metrics.get("feature_columns") or \
                [c for c in FEATURE_COLS if c in df.columns]
    state_dim = len(feat_cols)
    n_actions = config.N_FLUID_BINS * config.N_VASO_BINS

    model, _ = load_model_from_drive(out_path, state_dim, n_actions)

    # --- Generate plots directly into Drive outputs folder ---
    print("\n[Generating plots]")
    import os
    loss_path   = os.path.join(out_path, "loss_curve.png")
    action_path = os.path.join(out_path, "action_distribution.png")
    reward_path = os.path.join(out_path, "reward_distribution.png")
    qval_path   = os.path.join(out_path, "q_value_sample.png")

    plot_loss_curve(metrics.get("loss_history", []), loss_path)
    plot_action_distribution(df, action_path)
    plot_reward_distribution(df, reward_path)
    plot_q_value_sample(model, df, feat_cols, qval_path)

    # --- Build and save summary report ---
    summary = build_summary(metrics, df)
    save_json(summary, out_path, "summary_report.json")
    print(f"  Summary: {summary}")

    print("\n✓ All outputs saved to Drive → sepsis_rl/outputs/")
    print("  Files written:")
    print("    loss_curve.png")
    print("    action_distribution.png")
    print("    reward_distribution.png")
    print("    q_value_sample.png")
    print("    summary_report.json")


if __name__ == "__main__":
    main()