In [None]:
import ast
from collections import Counter

import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.preprocessing import StandardScaler


csv_path = "output_with_is_dead.csv"
df = pd.read_csv(csv_path)
print("Loaded shape:", df.shape)


# The vitals panel at five time points; each one must be present in the CSV.
vital_cols = ["AT_START", "After_6hr", "After_12hr", "After_18hr", "After_24hr"]
missing_vital_cols = [name for name in vital_cols if name not in df.columns]
if missing_vital_cols:
    # Report the first missing column, in vital_cols order.
    raise ValueError(f"Required column '{missing_vital_cols[0]}' not found in CSV.")

# Element tokens (compared case-insensitively, whitespace-stripped) meaning "no reading".
_MISSING_VITAL_TOKENS = frozenset({"undefined", "nan", "none", ""})


def _vital_element_to_float(x):
    """Convert one list element to float; missing or unconvertible values become NaN."""
    if isinstance(x, str):
        x = x.strip()
        if x.lower() in _MISSING_VITAL_TOKENS:
            return np.nan
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        # e.g. None, nested lists, free text, out-of-range integers
        return np.nan


def _split_bracketed_list(text):
    """Tokenize '[a, b, c]' by commas when ast.literal_eval rejects it.

    Returns a list of str tokens with surrounding whitespace and quotes removed,
    or None when `text` is not wrapped in square brackets.
    """
    if not (text.startswith("[") and text.endswith("]")):
        return None
    inner = text[1:-1].strip()
    if not inner:
        return []
    return [tok.strip().strip("'\"") for tok in inner.split(",")]


def parse_vital_string(s):
    """
    Parse a '[..]' list string into a 1-D np.array of floats.

    Elements equal (case-insensitively, ignoring surrounding whitespace) to
    'Undefined', 'nan', 'none' or '' become NaN inside the array. So does any
    element that float() cannot convert. Both quoted (``"['nan', 98]"``) and
    bare (``"[nan, 98]"``) tokens are accepted. Bare words make
    ast.literal_eval fail, so a plain comma-split fallback handles them.

    Returns:
        np.ndarray of dtype float, or None when `s` is missing (NaN/None) or is
        not a list/tuple literal (e.g. a scalar or free text).
    """
    if pd.isna(s):
        return None
    text = str(s).strip()
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = _split_bracketed_list(text)

    # Scalars, strings, dicts, ... are not vital vectors. Iterating them would
    # crash (scalars) or silently split a string like "'98'" into [9.0, 8.0].
    if not isinstance(parsed, (list, tuple)):
        return None
    return np.array([_vital_element_to_float(x) for x in parsed], dtype=float)

parsed_vitals = {}
for vital_name in vital_cols:
    # Object Series holding one parsed float array (or None) per row.
    parsed_vitals[vital_name] = df[vital_name].apply(parse_vital_string)

# The most frequent AT_START vector length sets the width of every vitals matrix.
lengths = [len(vec) for vec in parsed_vitals["AT_START"] if vec is not None]
if not lengths:
    raise ValueError("No valid AT_START vitals parsed.")
length_counts = Counter(lengths)
target_len = length_counts.most_common(1)[0][0]
print("Vital vector length (target_len):", target_len)

def vitals_to_matrix(series, target_len):
    """
    Stack an iterable of vital vectors into an (n_rows, target_len) float matrix.

    - None -> all-NaN row
    - scalar (e.g. a NaN placeholder) -> treated as a one-element vector
    - longer -> truncate to the first target_len values
    - shorter -> pad with NaN on the right
    - empty input -> a (0, target_len) matrix (np.vstack would raise here)

    Args:
        series: any iterable (pandas Series, list, ...) of array-likes or None.
        target_len: number of columns in the result.

    Returns:
        np.ndarray of dtype float64 with shape (len(series), target_len).
    """
    rows = list(series)  # accept any iterable, and know the row count up front
    matrix = np.full((len(rows), target_len), np.nan, dtype=float)
    for i, arr in enumerate(rows):
        if arr is None:
            continue  # row stays all-NaN
        vec = np.atleast_1d(np.asarray(arr, dtype=float))
        n = min(vec.shape[0], target_len)
        matrix[i, :n] = vec[:n]  # remaining columns keep their NaN padding
    return matrix

# Vitals matrices for all rows (may contain NaNs).
# The unpacking order must follow vital_cols:
# AT_START, After_6hr, After_12hr, After_18hr, After_24hr.
v_start, v_6, v_12, v_18, v_24 = (
    vitals_to_matrix(parsed_vitals[vital_name], target_len)
    for vital_name in vital_cols
)

# =========================================
# 3. Vector-space features for clustering
#    (AT_START + deltas + dose + abx one-hot)
# =========================================
# Change of each vital relative to AT_START; NaN wherever either reading is missing.
delta_6  = v_6  - v_start
delta_12 = v_12 - v_start
delta_18 = v_18 - v_start
delta_24 = v_24 - v_start

# Dose: non-numeric entries become NaN; shaped (n_rows, 1) so it can be hstacked.
if "DOSE_VAL_RX" not in df.columns:
    raise ValueError("'DOSE_VAL_RX' column missing.")
dose_raw = pd.to_numeric(df["DOSE_VAL_RX"], errors="coerce").values.reshape(-1, 1)

# Antibiotics one-hot: every column named "abx__*", coerced to numbers (bad values -> NaN).
abx_cols = [c for c in df.columns if c.startswith("abx__")]
if abx_cols:
    abx_raw = df[abx_cols].apply(pd.to_numeric, errors="coerce").values
else:
    abx_raw = None
    print("⚠ No 'abx__' columns found; antibiotic info won't be used.")

# Build clustering feature matrix: [AT_START | 4 deltas | dose | abx flags].
X_cluster_parts = [
    v_start,
    delta_6,
    delta_12,
    delta_18,
    delta_24,
    dose_raw,
]
if abx_raw is not None:
    X_cluster_parts.append(abx_raw)

X_cluster_raw = np.hstack(X_cluster_parts)
print("X_cluster_raw shape:", X_cluster_raw.shape)

# Impute (median) + scale for KMeans (df itself untouched).
# NOTE: with its default keep_empty_features=False, SimpleImputer drops columns
# that are entirely NaN. X_cluster_imp can therefore be narrower than X_cluster_raw.
imp_cluster = SimpleImputer(strategy="median")
X_cluster_imp = imp_cluster.fit_transform(X_cluster_raw)

scaler = StandardScaler()
X_cluster_scaled = scaler.fit_transform(X_cluster_imp)

# =========================================
# 4. KMeans clustering on ALL rows
# =========================================
K = 4  # number of clusters (also the label range 0..K-1 used below)
kmeans = KMeans(n_clusters=K, n_init="auto", random_state=42)
cluster_labels = kmeans.fit_predict(X_cluster_scaled)
# Attach the integer labels to df's row index so later steps align by row.
cluster_series = pd.Series(data=cluster_labels, index=df.index)

# =========================================
# 5. Mortality per cluster -> poor_response pseudo-labels
# =========================================
if "is_dead" not in df.columns:
    raise ValueError("'is_dead' column missing.")

mort_df = pd.DataFrame({
    "cluster": cluster_series,
    "is_dead": pd.to_numeric(df["is_dead"], errors="coerce"),
})

# Mean of is_dead per cluster, ignoring rows whose is_dead is missing/non-numeric.
cluster_mortality = (
    mort_df
    .dropna(subset=["is_dead"])
    .groupby("cluster")["is_dead"]
    .mean()
)

print("\nMortality by cluster:")
print(cluster_mortality)

if cluster_mortality.empty:
    # Without any usable outcome, the fill value and the median threshold would be
    # NaN. Every cluster would then silently be "good", and the classifier would
    # only ever see one class.
    raise ValueError("'is_dead' has no numeric values; cannot rank clusters by mortality.")

# Clusters with no is_dead data get the mean of the observed per-cluster rates.
cluster_mortality = cluster_mortality.reindex(range(K)).fillna(cluster_mortality.mean())

# Strictly above the median: with K=4 at most two clusters are flagged, and
# clusters tied at the median (or all-equal rates) are not flagged.
mort_threshold = float(cluster_mortality.median())
bad_clusters = cluster_mortality[cluster_mortality > mort_threshold].index.tolist()

print("\nMortality threshold:", mort_threshold)
print("Bad (poor-response) clusters:", bad_clusters)

cluster_to_poor = {c: int(c in bad_clusters) for c in range(K)}
poor_response_series = cluster_series.map(cluster_to_poor)  # 0 or 1 for ALL rows

# =========================================
# 6. Sweet-spot dose per antibiotic (good responders only)
# =========================================
# Minimum number of good-responder doses needed before a q25..q75 range is used.
MIN_SWEET_SPOT_N = 10

sweet_spots = {}

if abx_cols:
    # Positional boolean mask (poor_response_series is built on df.index in df's
    # row order), so duplicate index labels cannot multiply rows.
    good_mask = (poor_response_series == 0).to_numpy()
    df_good = df.loc[good_mask]
    good_doses = pd.to_numeric(df_good["DOSE_VAL_RX"], errors="coerce")

    for abx in abx_cols:
        # Coerce the flag the same way abx_raw does, so "1"/"1.0" strings count as given.
        given = (pd.to_numeric(df_good[abx], errors="coerce") == 1).to_numpy()
        doses_abx = good_doses[given].dropna().values
        if len(doses_abx) < MIN_SWEET_SPOT_N:
            continue
        q25, q50, q75 = np.percentile(doses_abx, [25, 50, 75])
        sweet_spots[abx] = {
            "q25": q25,
            "q50": q50,
            "q75": q75,
            "n": len(doses_abx),
        }

print("\nSweet-spot dose ranges (good responders):")
for abx, info in sweet_spots.items():
    print(f"{abx}: n={info['n']} | q25={info['q25']:.2f}, median={info['q50']:.2f}, q75={info['q75']:.2f}")

# =========================================
# 7. Classifier: predict poor_response (0/1) for all rows
#    Features: AT_START + dose + abx
# =========================================
X_clf_parts = [v_start, dose_raw]
if abx_raw is not None:
    X_clf_parts.append(abx_raw)

X_clf_raw = np.hstack(X_clf_parts)
y_clf = poor_response_series.to_numpy(dtype=int)

# NOTE: the imputer is fitted on all rows, so fold medians see a little of the test data.
imp_clf = SimpleImputer(strategy="median")
X_clf_imp = imp_clf.fit_transform(X_clf_raw)

# Full-data fit, kept so that `clf` remains available as a fitted model.
clf = RandomForestClassifier(
    n_estimators=300,
    random_state=42,
    class_weight="balanced",
    n_jobs=-1,
)
clf.fit(X_clf_imp, y_clf)

print("\nClassifier trained. Class counts:", np.bincount(y_clf))

# Probability of poor response for ALL rows.
# Scoring the rows the forest was trained on lets it reproduce the pseudo-labels
# (probabilities near 0/1), so the decision step would just echo
# poor_response. Use out-of-fold probabilities instead.
minority_count = int(np.bincount(y_clf, minlength=2).min())
n_folds = min(5, minority_count)
if n_folds >= 2:
    cv = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)
    prob_all = cross_val_predict(clf, X_clf_imp, y_clf, cv=cv, method="predict_proba")
    prob_poor_all = prob_all[:, 1]  # columns follow sorted labels [0, 1]
else:
    # Too few minority rows for stratified CV (or a single class): fall back to
    # in-sample probabilities, and look up the class-1 column explicitly because
    # predict_proba has only one column when y has one class.
    prob_all = clf.predict_proba(X_clf_imp)
    fitted_classes = list(clf.classes_)
    if 1 in fitted_classes:
        prob_poor_all = prob_all[:, fitted_classes.index(1)]
    else:
        prob_poor_all = np.zeros(len(y_clf))

# =========================================
# 8. Final decision: "increase" / "decrease"
#    Purely based on prob_poor + sweet-spot
# =========================================

# Numeric dose per row (non-numeric -> NaN), kept as a Series aligned with df.
dose_series = df["DOSE_VAL_RX"].pipe(pd.to_numeric, errors="coerce")

def pick_main_abx(row, abx_columns):
    """Return the first antibiotic column whose value in `row` equals 1 numerically.

    Args:
        row: mapping-like with a ``.get`` method (a DataFrame row or a dict).
        abx_columns: ordered column names to check; an empty/None value yields None.

    Returns:
        The first name whose value converts via float() to 1.0, otherwise None.
        Values that cannot be converted (None, text, ...) are skipped.
    """
    def _flag_is_one(value):
        try:
            return float(value) == 1.0
        except Exception:
            return False

    candidates = abx_columns or ()
    return next((name for name in candidates if _flag_is_one(row.get(name))), None)

def dose_zone_and_thresholds(dose_value, abx_name, sweet_spots_dict):
    """
    Place a dose relative to an antibiotic's q25..q75 sweet-spot range.

    Returns:
        (zone, low_q, high_q), where zone is "below" (dose < q25), "above"
        (dose > q75), "in" (otherwise, bounds inclusive) or "unknown" (the dose
        is missing or `abx_name` has no entry in `sweet_spots_dict`). In the
        "unknown" case both bounds are None.
    """
    no_reference = pd.isna(dose_value) or abx_name not in sweet_spots_dict
    if no_reference:
        return "unknown", None, None

    bounds = sweet_spots_dict[abx_name]
    lo, hi = bounds["q25"], bounds["q75"]
    zone = "below" if dose_value < lo else ("above" if dose_value > hi else "in")
    return zone, lo, hi

def decide_increase_or_decrease(
    prob_poor,
    dose_value,
    abx_name,
    sweet_spots_dict,
    *,
    high_risk_thr=0.75,
    mid_risk_thr=0.55,
):
    """
    Map a poor-response probability and the current dose to "increase" or "decrease".

    Threshold-based logic. The thresholds are keyword-only parameters so the
    policy can be tuned to be more aggressive or more conservative:

    - high_risk_thr: probability at/above which a row whose dose cannot be
      placed against a sweet spot (zone "unknown") is escalated
    - mid_risk_thr:  probability at/above which a row dosed below the sweet
      spot is escalated

    Cases (zone comes from dose_zone_and_thresholds):
      1) zone == "above"   -> always "decrease"
      2) zone == "below"   -> "increase" if prob_poor >= mid_risk_thr else "decrease"
      3) zone == "in"      -> always "decrease". In practice one might switch
                              drug; in this two-label setup "no further
                              escalation" is reported as "decrease".
      4) zone == "unknown" -> "increase" if prob_poor >= high_risk_thr else "decrease"

    A missing probability (None/NaN) returns "decrease" as the safe default.

    Returns:
        "increase" or "decrease".
    """
    if pd.isna(prob_poor):
        return "decrease"  # safe default

    zone, _low_q, _high_q = dose_zone_and_thresholds(dose_value, abx_name, sweet_spots_dict)

    if zone == "below":
        threshold = mid_risk_thr
    elif zone == "unknown":
        threshold = high_risk_thr
    else:
        # "above" (never increase), "in" (no escalation from the sweet spot),
        # and any unexpected zone value all fall back to the safe default.
        return "decrease"

    return "increase" if prob_poor >= threshold else "decrease"

# -----------------------------------------
# 8.1 Loop over rows and apply decision
# -----------------------------------------
# df rows, prob_poor_all and dose_series share the same positional order.
final_preds = [
    decide_increase_or_decrease(
        prob_poor=p_poor,
        dose_value=dose_val,
        abx_name=pick_main_abx(row, abx_cols),
        sweet_spots_dict=sweet_spots,
    )
    for (_, row), p_poor, dose_val in zip(df.iterrows(), prob_poor_all, dose_series)
]

# =========================================
# 9. Add ONLY ONE new column + save
# =========================================
# Exactly one column is appended; the original columns are written back unchanged.
df["final_prediction"] = final_preds

out_path = "output_with_prediction.csv"
df.to_csv(out_path, index=False)
print("Saved with predictions to:", out_path)

prediction_col = df["final_prediction"]
for header, as_fraction in (("\nPrediction counts:", False), ("\nPrediction proportions:", True)):
    print(header)
    print(prediction_col.value_counts(normalize=as_fraction))


  df = pd.read_csv(csv_path)


Loaded shape: (16431, 135)
Vital vector length (target_len): 4
X_cluster_raw shape: (16431, 142)

Mortality by cluster:
cluster
0    0.000000
1    0.162529
2    0.181435
3    0.013333
Name: is_dead, dtype: float64

Mortality threshold: 0.08793140540316909
Bad (poor-response) clusters: [1, 2]

Sweet-spot dose ranges (good responders):
abx__Cephalexin: n=75 | q25=500.00, median=500.00, q75=500.00

Classifier trained. Class counts: [   80 16351]
Saved with predictions to: output_with_prediction.csv

Prediction counts:
final_prediction
increase    16351
decrease       80
Name: count, dtype: int64

Prediction proportions:
final_prediction
increase    0.995131
decrease    0.004869
Name: proportion, dtype: float64
