In [5]:
import os
import json
from typing import List, Dict, Any

import pandas as pd

# ==============================
# 0. 경로 및 기본 설정
# ==============================

COHORT_PATH = "./cohort/cohort_ver50_only_subject_id.csv"

HOSP_DIR = "../../data/MIMIC4-hosp-icu"
ICU_DIR = "../../data/MIMIC4-hosp-icu"
ED_DIR = "../../data/mimic-iv-ed/ed"
ECG_DIR = "../../data/mimic-iv-ecg"

OUTPUT_DIR = "./cohort"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 140 버전 event log 출력 경로
EVENT_LOG_140_PATH = os.path.join(OUTPUT_DIR, "cohort_ver140_event_log.csv")

# Troponin 양성 기준
TROP_POS_THRESHOLD = 0.04

# PCI ICD 코드 prefix
PCI_ICD9_PREFIXES = ["00.66", "36.0"]
PCI_ICD10_PREFIXES = ["027", "929"]

# 항혈소판제 이름 리스트
ANTI_PLT_DRUGS = ["aspirin", "clopidogrel", "ticagrelor", "prasugrel"]


# ==============================
# 1. 공통 유틸
# ==============================

def _to_datetime(df: pd.DataFrame, cols: List[str]) -> pd.DataFrame:
    """
    지정된 컬럼들을 datetime으로 변환.
    연도 이상치는 그대로 허용하고, 파싱 실패는 NaT로 둔다.
    """
    for c in cols:
        if c in df.columns:
            df[c] = pd.to_datetime(df[c], errors="coerce")
    return df


def _attrs_to_json(attrs: Dict[str, Any]) -> str:
    """
    attributes dict를 JSON 문자열로 변환.
    NaN은 None으로 치환.
    """
    clean = {}
    for k, v in attrs.items():
        if isinstance(v, float) and pd.isna(v):
            clean[k] = None
        else:
            clean[k] = v
    return json.dumps(clean, ensure_ascii=False)


# ==============================
# 2. Cohort 및 원본 테이블 로딩
# ==============================

def load_labevents_troponin() -> pd.DataFrame:
    """
    Troponin 전용 labevents_troponin.csv 로딩.
    """
    candidates = [
        os.path.join(HOSP_DIR, "labevents_troponin.csv"),
        "./labevents_troponin.csv",
    ]
    for p in candidates:
        if os.path.exists(p):
            print(f"[LOAD] labevents_troponin.csv 로딩: {p}")
            return pd.read_csv(p)

    raise FileNotFoundError(
        "labevents_troponin.csv 파일을 찾을 수 없습니다.\n"
        f"다음 경로 중 하나에 존재해야 합니다.\n{candidates}"
    )


def load_cohort_first_stemi_admission(path: str) -> pd.DataFrame:
    """
    STEMI 대상 환자 코호트 로딩.
    - 입력: subject_id, hadm_id (중복 가능)
    - 처리: admissions.admittime 기준으로 각 subject_id당 가장 이른 hadm_id 하나만 남긴다.
    """
    cohort_raw = pd.read_csv(path)

    if "subject_id" not in cohort_raw.columns:
        raise ValueError(f"cohort 파일에 subject_id 컬럼이 없습니다. 현재 컬럼: {list(cohort_raw.columns)}")
    if "hadm_id" not in cohort_raw.columns:
        raise ValueError(
            "cohort 파일에 hadm_id 컬럼이 없습니다.\n"
            "STEMI 분석은 hadm_id(입원 단위) 기준으로 해야 합니다.\n"
            "cohort_ver50_only_subject_id.csv를 subject_id, hadm_id 두 컬럼을 포함하도록 다시 만들어 주세요."
        )

    cohort_all = cohort_raw[["subject_id", "hadm_id"]].drop_duplicates()

    admissions_path = os.path.join(HOSP_DIR, "admissions.csv")
    admissions = pd.read_csv(admissions_path)[["subject_id", "hadm_id", "admittime"]]
    admissions = _to_datetime(admissions, ["admittime"])

    merged = cohort_all.merge(admissions, on=["subject_id", "hadm_id"], how="left")
    merged = merged.sort_values(["subject_id", "admittime"])

    first_stemi_adm = merged.groupby("subject_id", as_index=False).head(1)
    cohort_final = first_stemi_adm[["subject_id", "hadm_id"]].drop_duplicates().reset_index(drop=True)

    print(f"[COHORT] 원래 STEMI 입원 후보 row 수: {len(cohort_all)}")
    print(f"[COHORT] subject_id 기준 '첫 STEMI 입원'만 남긴 row 수: {len(cohort_final)}")

    return cohort_final


def load_source_tables() -> Dict[str, pd.DataFrame]:
    """
    MIMIC-IV 원본 테이블 로딩.
    """
    print("[LOAD] admissions, patients, icustays...")
    admissions = pd.read_csv(os.path.join(HOSP_DIR, "admissions.csv"))
    patients = pd.read_csv(os.path.join(HOSP_DIR, "patients.csv"))
    icustays = pd.read_csv(os.path.join(ICU_DIR, "icustays.csv"))

    print("[LOAD] edstays (있으면 로딩)...")
    edstays_path = os.path.join(ED_DIR, "edstays.csv")
    edstays = pd.read_csv(edstays_path) if os.path.exists(edstays_path) else None

    print("[LOAD] labevents_troponin ...")
    labevents_trop = load_labevents_troponin()

    print("[LOAD] procedures_icd, prescriptions, ECG machine_measurements...")
    procedures_icd = pd.read_csv(os.path.join(HOSP_DIR, "procedures_icd.csv"))
    prescriptions = pd.read_csv(os.path.join(HOSP_DIR, "prescriptions.csv"))
    ecg = pd.read_csv(os.path.join(ECG_DIR, "machine_measurements.csv"))

    return {
        "admissions": admissions,
        "patients": patients,
        "icustays": icustays,
        "edstays": edstays,
        "labevents_trop": labevents_trop,
        "procedures_icd": procedures_icd,
        "prescriptions": prescriptions,
        "ecg": ecg,
    }


# ==============================
# 3. 개별 Event 생성 함수들
# ==============================

def build_ed_events(cohort: pd.DataFrame,
                    edstays: pd.DataFrame,
                    admissions: pd.DataFrame) -> pd.DataFrame:
    """
    ED_ARRIVAL, ED_DEPARTURE, ED_ARRIVAL_SURR 이벤트 생성.

    - edstays 기준으로 ED_ARRIVAL / ED_DEPARTURE 생성 (있으면 그대로 사용)
    - 해당 hadm에 ED 이벤트가 전혀 없으면 admissions.admittime으로
      ED_ARRIVAL_SURR 생성
    """
    events = []
    hadm_with_ed = set()

    # 1) 실제 ED_ARRIVAL / ED_DEPARTURE
    if edstays is not None:
        ed = edstays.merge(cohort, on=["subject_id", "hadm_id"], how="inner")
        ed = _to_datetime(ed, ["intime", "outtime"])

        for _, row in ed.iterrows():
            hid = row["hadm_id"]
            sid = row["subject_id"]

            if pd.notnull(row.get("intime")):
                events.append({
                    "subject_id": sid,
                    "hadm_id": hid,
                    "event_name": "ED_ARRIVAL",
                    "timestamp": row["intime"],
                    "attributes": _attrs_to_json({}),
                })
                hadm_with_ed.add(hid)

            if pd.notnull(row.get("outtime")):
                events.append({
                    "subject_id": sid,
                    "hadm_id": hid,
                    "event_name": "ED_DEPARTURE",
                    "timestamp": row["outtime"],
                    "attributes": _attrs_to_json({}),
                })
                hadm_with_ed.add(hid)

    # 2) surrogate ED_ARRIVAL (ED가 전혀 없는 hadm에 대해)
    adm = admissions.merge(cohort, on=["subject_id", "hadm_id"], how="inner")
    adm = _to_datetime(adm, ["admittime"])

    for _, row in adm.iterrows():
        hid = row["hadm_id"]
        sid = row["subject_id"]
        if hid in hadm_with_ed:
            continue  # 이미 ED 이벤트 있음

        admt = row.get("admittime")
        if pd.isna(admt):
            continue

        events.append({
            "subject_id": sid,
            "hadm_id": hid,
            "event_name": "ED_ARRIVAL_SURR",
            "timestamp": admt,
            "attributes": _attrs_to_json({"source": "admittime"}),
        })
        hadm_with_ed.add(hid)

    if not events:
        return pd.DataFrame(columns=["subject_id", "hadm_id", "event_name", "timestamp", "attributes"])

    ed_events = pd.DataFrame(events)
    ed_events = _to_datetime(ed_events, ["timestamp"])
    return ed_events


def build_admission_events(cohort: pd.DataFrame,
                           admissions: pd.DataFrame,
                           patients: pd.DataFrame) -> pd.DataFrame:
    """
    DISCHARGE, DEATH 이벤트 생성.
    """
    adm = admissions.merge(cohort, on=["subject_id", "hadm_id"], how="inner")
    adm = _to_datetime(adm, ["admittime", "dischtime"])
    pat = _to_datetime(patients.copy(), ["dod"])

    adm = adm.merge(pat[["subject_id", "dod"]], on="subject_id", how="left")

    events = []
    for _, row in adm.iterrows():
        sid = row["subject_id"]
        hid = row["hadm_id"]
        if pd.notnull(row.get("dischtime")):
            events.append({
                "subject_id": sid,
                "hadm_id": hid,
                "event_name": "DISCHARGE",
                "timestamp": row["dischtime"],
                "attributes": _attrs_to_json({}),
            })
        if pd.notnull(row.get("dod")):
            events.append({
                "subject_id": sid,
                "hadm_id": hid,
                "event_name": "DEATH",
                "timestamp": row["dod"],
                "attributes": _attrs_to_json({}),
            })

    return pd.DataFrame(events)


def build_icu_events(cohort: pd.DataFrame, icustays: pd.DataFrame) -> pd.DataFrame:
    """
    ICU_INTIME, ICU_OUTTIME 이벤트 생성.
    """
    icu = icustays.merge(cohort, on=["subject_id", "hadm_id"], how="inner")
    icu = _to_datetime(icu, ["intime", "outtime"])

    events = []
    for _, row in icu.iterrows():
        sid = row["subject_id"]
        hid = row["hadm_id"]
        attrs = {
            "first_careunit": row.get("first_careunit", None),
            "last_careunit": row.get("last_careunit", None),
            "stay_id": row.get("stay_id", None),
        }
        if pd.notnull(row.get("intime")):
            events.append({
                "subject_id": sid,
                "hadm_id": hid,
                "event_name": "ICU_INTIME",
                "timestamp": row["intime"],
                "attributes": _attrs_to_json(attrs),
            })
        if pd.notnull(row.get("outtime")):
            events.append({
                "subject_id": sid,
                "hadm_id": hid,
                "event_name": "ICU_OUTTIME",
                "timestamp": row["outtime"],
                "attributes": _attrs_to_json(attrs),
            })

    return pd.DataFrame(events)


def build_troponin_events(cohort: pd.DataFrame,
                          labevents_trop: pd.DataFrame,
                          positive_threshold: float) -> pd.DataFrame:
    """
    TROP_TAKEN, TROP_POSITIVE 이벤트 생성.
    """
    lab = labevents_trop.merge(cohort, on=["subject_id", "hadm_id"], how="inner")
    lab = _to_datetime(lab, ["charttime"])

    events = []

    # TROP_TAKEN
    for _, row in lab.iterrows():
        if pd.notnull(row.get("charttime")):
            attrs = {
                "itemid": row.get("itemid", None),
                "valuenum": row.get("valuenum", None),
                "value": row.get("value", None),
                "flag": row.get("flag", None),
            }
            events.append({
                "subject_id": row["subject_id"],
                "hadm_id": row["hadm_id"],
                "event_name": "TROP_TAKEN",
                "timestamp": row["charttime"],
                "attributes": _attrs_to_json(attrs),
            })

    # TROP_POSITIVE: hadm_id 기준 첫 양성
    if "valuenum" in lab.columns:
        lab_pos = lab[lab["valuenum"] >= positive_threshold].copy()
        lab_pos = lab_pos.dropna(subset=["charttime"])
        lab_pos_sorted = lab_pos.sort_values(["subject_id", "hadm_id", "charttime"])
        first_pos = lab_pos_sorted.groupby(["subject_id", "hadm_id"], as_index=False).head(1)

        for _, row in first_pos.iterrows():
            attrs = {
                "itemid": row.get("itemid", None),
                "valuenum": row.get("valuenum", None),
                "value": row.get("value", None),
                "flag": row.get("flag", None),
            }
            events.append({
                "subject_id": row["subject_id"],
                "hadm_id": row["hadm_id"],
                "event_name": "TROP_POSITIVE",
                "timestamp": row["charttime"],
                "attributes": _attrs_to_json(attrs),
            })

    return pd.DataFrame(events)


def build_ecg_events(cohort: pd.DataFrame, ecg: pd.DataFrame) -> pd.DataFrame:
    """
    ECG_TAKEN, ECG_STEMI_FLAG 이벤트 생성.
    machine_measurements.csv: 보통 subject_id만 있으므로
    cohort에서 subject_id→hadm_id 매핑을 붙여 사용.
    """
    ecg_c = ecg.copy()

    # 시간 컬럼 통일
    if "charttime" not in ecg_c.columns:
        if "ecg_time" in ecg_c.columns:
            ecg_c = ecg_c.rename(columns={"ecg_time": "charttime"})
        else:
            raise ValueError(
                "ECG 데이터에 charttime/ecg_time 둘 다 없습니다. 시간 컬럼을 확인해 주세요."
            )

    subj_hadm_map = cohort[["subject_id", "hadm_id"]].drop_duplicates()
    ecg_c = ecg_c.merge(subj_hadm_map, on="subject_id", how="inner")
    ecg_c = _to_datetime(ecg_c, ["charttime"])

    def has_stemi_flag(row) -> bool:
        mm = str(row.get("machine_measurements", "")).upper()
        reports = []
        for i in range(30):
            col = f"report_{i}"
            if col in row.index:
                reports.append(str(row.get(col, "")))
        rep_text = " ".join(reports).upper()
        text = mm + " " + rep_text

        if "STEMI" in text:
            return True
        if "ST ELEVATION" in text:
            return True
        return False

    ecg_c["is_stemi"] = ecg_c.apply(has_stemi_flag, axis=1)

    events = []
    for _, row in ecg_c.iterrows():
        ts = row.get("charttime")
        if pd.isna(ts):
            continue

        events.append({
            "subject_id": row["subject_id"],
            "hadm_id": row["hadm_id"],
            "event_name": "ECG_TAKEN",
            "timestamp": ts,
            "attributes": _attrs_to_json({}),
        })

        if row["is_stemi"]:
            events.append({
                "subject_id": row["subject_id"],
                "hadm_id": row["hadm_id"],
                "event_name": "ECG_STEMI_FLAG",
                "timestamp": ts,
                "attributes": _attrs_to_json({}),
            })

    return pd.DataFrame(events)


def build_pci_events(cohort: pd.DataFrame,
                     procedures_icd: pd.DataFrame,
                     pci_icd9_prefixes: List[str],
                     pci_icd10_prefixes: List[str]) -> pd.DataFrame:
    """
    PCI_START 이벤트 생성.
    """
    proc = procedures_icd.merge(cohort, on=["subject_id", "hadm_id"], how="inner")
    proc = _to_datetime(proc, ["chartdate"])

    def is_pci(code: Any, version: Any) -> bool:
        if pd.isna(code):
            return False
        code_str = str(code)
        if version == 9:
            return any(code_str.startswith(p) for p in pci_icd9_prefixes)
        elif version == 10:
            return any(code_str.startswith(p) for p in pci_icd10_prefixes)
        else:
            return False

    proc["is_pci"] = proc.apply(
        lambda r: is_pci(r.get("icd_code", None), r.get("icd_version", None)), axis=1
    )
    pci_rows = proc[proc["is_pci"]].copy()

    events = []
    for _, row in pci_rows.iterrows():
        ts = row.get("chartdate")
        if pd.isna(ts):
            continue
        attrs = {
            "icd_code": row.get("icd_code", None),
            "icd_version": row.get("icd_version", None),
        }
        events.append({
            "subject_id": row["subject_id"],
            "hadm_id": row["hadm_id"],
            "event_name": "PCI_START",
            "timestamp": ts,
            "attributes": _attrs_to_json(attrs),
        })

    return pd.DataFrame(events)


def build_antiplatelet_events(cohort: pd.DataFrame,
                              prescriptions: pd.DataFrame,
                              drug_name_list: List[str]) -> pd.DataFrame:
    """
    ANTI_PLT_ORDER, ANTI_PLT_ADMIN 이벤트 생성.
    """
    rx = prescriptions.merge(cohort, on=["subject_id", "hadm_id"], how="inner")
    rx = _to_datetime(rx, ["starttime", "stoptime"])

    drug_lower_list = [d.lower() for d in drug_name_list]
    rx["drug_lower"] = rx["drug"].astype(str).str.lower()
    rx = rx[rx["drug_lower"].isin(drug_lower_list)]

    events = []
    for _, row in rx.iterrows():
        if pd.notnull(row.get("starttime")):
            attrs = {
                "drug": row.get("drug", None),
                "route": row.get("route", None),
                "dose_val_rx": row.get("dose_val_rx", None),
                "dose_unit_rx": row.get("dose_unit_rx", None),
            }
            events.append({
                "subject_id": row["subject_id"],
                "hadm_id": row["hadm_id"],
                "event_name": "ANTI_PLT_ORDER",
                "timestamp": row["starttime"],
                "attributes": _attrs_to_json(attrs),
            })
            events.append({
                "subject_id": row["subject_id"],
                "hadm_id": row["hadm_id"],
                "event_name": "ANTI_PLT_ADMIN",
                "timestamp": row["starttime"],
                "attributes": _attrs_to_json(attrs),
            })

    return pd.DataFrame(events)


# ==============================
# 4. Event Log 통합 + 140 버전 생성
# ==============================

def build_event_log_140(cohort: pd.DataFrame,
                        tables: Dict[str, pd.DataFrame]) -> pd.DataFrame:
    """
    140 버전 event log 생성.

    출력 컬럼:
      - case_id (hadm_id)
      - subject_id
      - hadm_id
      - event_name
      - timestamp
      - attributes
    """
    admissions = tables["admissions"]
    patients = tables["patients"]
    icustays = tables["icustays"]
    edstays = tables["edstays"]
    labevents_trop = tables["labevents_trop"]
    procedures_icd = tables["procedures_icd"]
    prescriptions = tables["prescriptions"]
    ecg = tables["ecg"]

    ed_events = build_ed_events(cohort, edstays, admissions)
    adm_events = build_admission_events(cohort, admissions, patients)
    icu_events = build_icu_events(cohort, icustays)
    trop_events = build_troponin_events(cohort, labevents_trop, TROP_POS_THRESHOLD)
    ecg_events = build_ecg_events(cohort, ecg)
    pci_events = build_pci_events(
        cohort,
        procedures_icd,
        pci_icd9_prefixes=PCI_ICD9_PREFIXES,
        pci_icd10_prefixes=PCI_ICD10_PREFIXES,
    )
    antiplatelet_events = build_antiplatelet_events(cohort, prescriptions, ANTI_PLT_DRUGS)

    all_events = pd.concat(
        [
            ed_events,
            adm_events,
            icu_events,
            trop_events,
            ecg_events,
            pci_events,
            antiplatelet_events,
        ],
        ignore_index=True
    )

    if all_events.empty:
        print("[WARN] 생성된 이벤트가 없습니다.")
        return pd.DataFrame(columns=["case_id", "subject_id", "hadm_id", "event_name", "timestamp", "attributes"])

    # timestamp 정리 및 정렬
    all_events = all_events.dropna(subset=["timestamp"])
    all_events["timestamp"] = pd.to_datetime(all_events["timestamp"], errors="coerce")
    all_events = all_events.dropna(subset=["timestamp"])

    all_events = all_events.sort_values(
        by=["subject_id", "hadm_id", "timestamp", "event_name"]
    ).reset_index(drop=True)

    # DISCHARGE 또는 DEATH 이후 이벤트 잘라내기
    def trim_after_end(df_one_adm: pd.DataFrame) -> pd.DataFrame:
        end_idx = df_one_adm[df_one_adm["event_name"].isin(["DISCHARGE", "DEATH"])].index
        if len(end_idx) == 0:
            return df_one_adm
        cutoff = end_idx.min()
        return df_one_adm.loc[:cutoff]

    all_events = (
        all_events.groupby("hadm_id", group_keys=False)
                  .apply(trim_after_end)
                  .reset_index(drop=True)
    )

    # case_id = hadm_id
    all_events["case_id"] = all_events["hadm_id"]

    # 최종 컬럼 순서
    all_events = all_events[
        ["case_id", "subject_id", "hadm_id", "event_name", "timestamp", "attributes"]
    ]

    # 디버깅용 요약
    print("\n[SUMMARY] 140 event log 요약")
    print(f"  전체 이벤트 row 수: {len(all_events)}")
    print(f"  hadm_id 수: {all_events['hadm_id'].nunique()} / 코호트 hadm_id 수: {len(cohort)}")
    ev_per_hadm = all_events.groupby("hadm_id")["event_name"].count()
    print("  hadm당 이벤트 개수 통계:")
    print(ev_per_hadm.describe())

    return all_events


# ==============================
# 5. MAIN
# ==============================

def main():
    # 1) 코호트 로딩 (각 subject당 첫 STEMI 입원)
    cohort = load_cohort_first_stemi_admission(COHORT_PATH)

    # 2) 원본 테이블 로딩
    tables = load_source_tables()

    print(f"[INFO] 최종 STEMI cohort size (첫 입원 기준): {len(cohort)}")

    # 3) 140 event log 생성
    event_log_140 = build_event_log_140(cohort, tables)

    # 4) 저장
    event_log_140.to_csv(EVENT_LOG_140_PATH, index=False)
    print(f"[SAVE] 140 버전 event log 저장 완료: {EVENT_LOG_140_PATH}")


if __name__ == "__main__":
    main()


[COHORT] 원래 STEMI 입원 후보 row 수: 1929
[COHORT] subject_id 기준 '첫 STEMI 입원'만 남긴 row 수: 1878
[LOAD] admissions, patients, icustays...
[LOAD] edstays (있으면 로딩)...
[LOAD] labevents_troponin ...
[LOAD] labevents_troponin.csv 로딩: ../../data/MIMIC4-hosp-icu\labevents_troponin.csv
[LOAD] procedures_icd, prescriptions, ECG machine_measurements...


  prescriptions = pd.read_csv(os.path.join(HOSP_DIR, "prescriptions.csv"))
  ecg = pd.read_csv(os.path.join(ECG_DIR, "machine_measurements.csv"))


[INFO] 최종 STEMI cohort size (첫 입원 기준): 1878


  all_events.groupby("hadm_id", group_keys=False)



[SUMMARY] 140 event log 요약
  전체 이벤트 row 수: 38674
  hadm_id 수: 1878 / 코호트 hadm_id 수: 1878
  hadm당 이벤트 개수 통계:
count    1878.000000
mean       20.593184
std        15.284606
min         1.000000
25%        11.000000
50%        18.000000
75%        25.000000
max       219.000000
Name: event_name, dtype: float64
[SAVE] 140 버전 event log 저장 완료: ./cohort\cohort_ver140_event_log.csv
