In [None]:
import os, glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
# ======================
# CONFIG
# ======================
# Weather table to score (read with pandas.read_csv further down).
CSV_PATH   = r"D:\NCKHSV-25-26-LastFire\weather_dec2025_all_nodes.xls"

# Artifacts written by the training run.
MODEL_PATH = r"D:\NCKHSV-25-26-LastFire\LastFile_1.0\best_timegnn.pt"
A_PATH     = r"D:\NCKHSV-25-26-LastFire\LastFile_1.0\A_norm.npy"
MU_PATH    = r"D:\NCKHSV-25-26-LastFire\LastFile_1.0\mu.npy"
SD_PATH    = r"D:\NCKHSV-25-26-LastFire\LastFile_1.0\sd.npy"
NODES_PATH = r"D:\NCKHSV-25-26-LastFire\LastFile_1.0\nodes.npy"
FEATS_PATH = r"D:\NCKHSV-25-26-LastFire\LastFile_1.0\feature_cols.npy"

# Last observed day: running on 2026-01-01 means the base day is 2025-12-31.
BASE_DAY = pd.Timestamp("2025-12-31")
# History window length in days (same value as used for training).
L = 30
# Alert threshold on the predicted probability (tune to the operational goal).
TH = 0.53

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ======================
# MODEL (same definition as used at training time)
# ======================
class TimeGCN_GRU(nn.Module):
    """Graph-mixing layer applied per time step, followed by a shared GRU per node.

    Each time step is mixed across nodes with the fixed matrix ``A`` and passed
    through ``Linear -> ReLU -> Dropout``. Every node's resulting sequence is
    then fed through one shared GRU, and the last hidden state is mapped to a
    single logit by ``head``.

    Args:
        A_norm: node-mixing matrix, stored as the float32 buffer ``A`` ([N,N]).
        in_dim: number of input features per node.
        gcn_dim: output width of the per-step linear layer.
        gru_dim: hidden size of the GRU.
        dropout: dropout probability used after the linear layer and before the head.
    """

    def __init__(self, A_norm, in_dim, gcn_dim=32, gru_dim=64, dropout=0.1):
        super().__init__()
        adjacency = torch.tensor(A_norm, dtype=torch.float32)  # [N,N]
        self.register_buffer("A", adjacency)
        self.gcn = nn.Linear(in_dim, gcn_dim)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(dropout)
        self.gru = nn.GRU(input_size=gcn_dim, hidden_size=gru_dim, batch_first=True)
        self.head = nn.Linear(gru_dim, 1)

    def _embed_step(self, frame):
        """Mix one time step across nodes and embed it: [B,N,F] -> [B,N,gcn_dim]."""
        mixed = torch.einsum("ij,bjf->bif", self.A, frame)
        return self.drop(self.act(self.gcn(mixed)))

    def _node_logit(self, node_seq):
        """Run the GRU over one node's sequence: [B,L,gcn_dim] -> [B]."""
        _, final_hidden = self.gru(node_seq)      # [1,B,gru_dim]
        summary = final_hidden[-1]                # [B,gru_dim]
        return self.head(self.drop(summary)).squeeze(-1)

    def forward(self, x_seq):
        """Return one logit per (batch item, node): [B,L,N,F] -> [B,N]."""
        _, n_steps, n_nodes, _ = x_seq.shape

        # Time steps are embedded first, in order, then nodes are processed in order.
        per_step = [self._embed_step(x_seq[:, t, :, :]) for t in range(n_steps)]
        embedded = torch.stack(per_step, dim=1)   # [B,L,N,gcn_dim]

        per_node = [self._node_logit(embedded[:, :, i, :]) for i in range(n_nodes)]
        return torch.stack(per_node, dim=1)       # [B,N]

# ======================
# LOAD ARTIFACTS
# ======================
# Numeric artifacts are cast to float32; mu and sd are noted as (1,F) in the original.
A_norm, mu, sd = (
    np.load(path).astype(np.float32) for path in (A_PATH, MU_PATH, SD_PATH)
)
# Node order and feature order are stored as arrays and converted to plain lists.
nodes, feature_cols = (
    np.load(path, allow_pickle=True).tolist() for path in (NODES_PATH, FEATS_PATH)
)

N, F = len(nodes), len(feature_cols)

# The checkpoint is either a bare state dict or a dict wrapping one under "model_state_dict".
ckpt = torch.load(MODEL_PATH, map_location="cpu")
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
    state = ckpt["model_state_dict"]
else:
    state = ckpt

# Read layer widths from the saved weights so the model always matches the checkpoint.
gcn_dim = state["gcn.weight"].shape[0]
gru_dim = state["head.weight"].shape[1]

model = TimeGCN_GRU(A_norm, in_dim=F, gcn_dim=gcn_dim, gru_dim=gru_dim, dropout=0.0)
model = model.to(DEVICE)
model.load_state_dict(state, strict=True)
model.eval()

print(f"Loaded model OK | N={N} F={F} gcn_dim={gcn_dim} gru_dim={gru_dim} DEVICE={DEVICE}")
print(
    "Forecast window: "
    + " -> ".join(str((BASE_DAY + pd.Timedelta(days=k)).date()) for k in (1, 7))
)

# ======================
# READ CSV WEATHER
# ======================
# NOTE(review): on_bad_lines="skip" silently drops malformed rows — confirm this is intended.
df = pd.read_csv(CSV_PATH, engine="python", on_bad_lines="skip")
if "date" not in df.columns or "node" not in df.columns:
    raise ValueError("CSV must contain columns: node, date (v√† c√°c feature th·ªùi ti·∫øt)")

# Normalise timestamps to midnight so rows can be matched against daily dates.
df["date"] = pd.to_datetime(df["date"]).dt.floor("D")

# Make sure the CSV actually reaches the base day.
max_date = df["date"].max()
print("CSV date range:", df["date"].min().date(), "->", max_date.date())
if BASE_DAY > max_date:
    raise ValueError(f"BASE_DAY={BASE_DAY.date()} nh∆∞ng CSV ch·ªâ c√≥ t·ªõi {max_date.date()}")

# Per-node lat/lon, used only for display; accept either the short or the long column name.
lat_col = next((c for c in ("lat", "latitude") if c in df.columns), None)
lon_col = next((c for c in ("lon", "longitude") if c in df.columns), None)

if lat_col and lon_col:
    tmp = df.groupby("node")[[lat_col, lon_col]].mean(numeric_only=True).reset_index()
    # Build the mapping column-wise rather than with iterrows(): iterrows() puts each
    # row into one Series with a common dtype, so a numeric node id would be upcast
    # to float next to lat/lon and its key would become "1.0" instead of "1".
    node_latlon = {
        str(node): (float(lat), float(lon))
        for node, lat, lon in zip(tmp["node"], tmp[lat_col], tmp[lon_col])
    }
else:
    # No coordinates in the CSV: keep every node, with NaN placeholders.
    node_latlon = {str(node): (np.nan, np.nan) for node in df["node"].unique()}

# ======================
# BUILD INPUT WINDOW [L,N,F]
# ======================
# "weather_missing" is a derived flag, not a column expected in the CSV.
base_weather_cols = [c for c in feature_cols if c != "weather_missing"]

end = BASE_DAY
start = end - pd.Timedelta(days=L-1)
full_dates = pd.date_range(start, end, freq="D")

# Training mean of each weather feature, used as the last-resort fill value.
# Computed once here instead of calling feature_cols.index() inside the node loop.
mu_fill = {c: float(mu[0, feature_cols.index(c)]) for c in base_weather_cols}

X_list = []
missing_nodes = []
for n in nodes:
    g = df[(df["node"] == n) & (df["date"] <= BASE_DAY)].copy()
    if g.empty:
        missing_nodes.append(n)
        continue

    # Stable sort keeps file order within a day, so "last" below is well defined.
    g = g.sort_values("date", kind="stable")

    # reindex() raises on duplicate index labels, so keep one row per day
    # (the last one in file order) and report that rows were dropped.
    n_dup = int(g["date"].duplicated(keep="last").sum())
    if n_dup:
        print(f"[warn] node {n}: dropped {n_dup} duplicate date row(s), kept the last per day")
        g = g.drop_duplicates(subset="date", keep="last")

    # Align to the exact L-day window; days without a row become all-NaN rows.
    g = g.set_index("date").reindex(full_dates).reset_index().rename(columns={"index":"date"})

    # A feature column absent from the CSV is created as all-NaN.
    for c in base_weather_cols:
        if c not in g.columns:
            g[c] = np.nan

    # Per-day missing flag, taken BEFORE any filling.
    miss_flag = g[base_weather_cols].isna().any(axis=1).astype(np.float32).to_numpy()

    # Forward-fill first: it only uses earlier days of the window.
    g[base_weather_cols] = g[base_weather_cols].ffill()

    # Whatever is still NaN (start of the window, or a column missing entirely)
    # falls back to the training mean of that feature.
    for c in base_weather_cols:
        if g[c].isna().any():
            g[c] = g[c].fillna(mu_fill[c])

    if "weather_missing" in feature_cols:
        g["weather_missing"] = miss_flag

    # Select columns in the training feature order.
    Xn = g[feature_cols].to_numpy(dtype=np.float32)  # (L,F)
    if Xn.shape != (L, F):
        raise ValueError(f"Node {n} got shape {Xn.shape}, expected {(L,F)}")
    X_list.append(Xn)

if missing_nodes:
    raise ValueError(f"Thi·∫øu data cho c√°c node n√†y trong CSV: {missing_nodes}")

X_raw = np.stack(X_list, axis=1)  # (L,N,F)

# Standardise with the training mean/std; the epsilon guards against a zero std.
Xz = (X_raw - mu.reshape(1,1,-1)) / (sd.reshape(1,1,-1) + 1e-6)

# ======================
# PREDICT
# ======================
# Add a batch axis of size 1 and move the window to the model's device.
x_seq = torch.tensor(Xz, dtype=torch.float32).unsqueeze(0).to(DEVICE)  # [1,L,N,F]
with torch.no_grad():
    prob = torch.sigmoid(model(x_seq)).cpu().numpy()[0]  # [N]

alert = (prob >= TH).astype(int)

# node_latlon is keyed by str(node), so look up by str(n); a node without
# coordinates gets NaN placeholders instead of raising KeyError.
latlon = [node_latlon.get(str(n), (np.nan, np.nan)) for n in nodes]

out = pd.DataFrame({
    "node": nodes,
    "lat": [pair[0] for pair in latlon],
    "lon": [pair[1] for pair in latlon],
    "prob_fire_next7": prob,
    "prob_%": prob * 100.0,
    "alert_next7": alert
}).sort_values("prob_fire_next7", ascending=False).reset_index(drop=True)

print(f"\n=== PREDICT next7 from base_day={BASE_DAY.date()} | TH={TH:.2f} ===")
try:
    display(out)
except NameError:
    # `display` only exists under IPython/Jupyter; fall back to plain text elsewhere.
    print(out.to_string())

# NOTE(review): the non-ASCII text in the messages below looks mis-encoded
# (UTF-8 decoded with the wrong codec) — confirm the file's encoding.
if out["alert_next7"].any():
    print("\n‚ö†Ô∏è  C·∫¢NH B√ÅO: C√≥ node v∆∞·ª£t threshold nguy c∆° ch√°y trong 7 ng√†y t·ªõi.")
else:
    print("\n‚úÖ  OK: Kh√¥ng node n√†o v∆∞·ª£t threshold trong 7 ng√†y t·ªõi.")


# ================== CONFIG ==================
TH_ALERT = 0.53   # alert threshold (alert_next7=1 when prob >= TH_ALERT)

# Risk bands, used only for display. Each entry is (lo, hi, name, icon, advice)
# and matches lo <= p < hi; the last upper bound is 1.01 so that p == 1.0 is included.
# NOTE(review): the icon/advice literals look mis-encoded (UTF-8 decoded with the
# wrong codec) — confirm the file's encoding before relying on them.
LEVELS = [
    (0.00, 0.45, "LOW",     "üü¢", "B√¨nh th∆∞·ªùng"),
    (0.45, 0.55, "WATCH",   "üü°", "Theo d√µi s√°t (c·∫£nh b√°o s·ªõm)"),
    (0.55, 0.70, "WARNING", "üü†", "C·∫£nh gi√°c cao (chu·∫©n b·ªã ph∆∞∆°ng √°n)"),
    (0.70, 1.01, "HIGH",    "üî¥", "C·∫£nh b√°o cao (∆∞u ti√™n ki·ªÉm tra)"),
]

# ================== HELPER ==================
def add_risk_levels(df: pd.DataFrame, prob_col="prob_fire_next7", th_alert=TH_ALERT) -> pd.DataFrame:
    """Return a copy of *df* with percentage, alert flag, risk band and advice columns.

    Args:
        df: frame containing the probability column; it is not modified.
        prob_col: name of the probability column.
        th_alert: probabilities >= this value set ``alert_next7`` to 1.

    Returns:
        A new frame sorted by ``prob_col`` descending with a fresh 0..n-1 index
        and the columns ``prob_%``, ``alert_next7``, ``risk`` and ``advice``
        added or overwritten.
    """
    # Fallback band for values that match no entry in LEVELS (i.e. NaN).
    unknown_level = ("UNKNOWN", "?", "No valid probability")

    out2 = df.copy()

    # Clamp to [0, 1] for safety; NaN stays NaN.
    p = out2[prob_col].astype(float).clip(0, 1)

    # Percentage form for readability.
    out2["prob_%"] = (p * 100).round(2)

    # Alert flag from the main threshold (NaN compares False, so it yields 0).
    out2["alert_next7"] = (p >= float(th_alert)).astype(np.int8)

    # Assign exactly one band per row. Without the fallback a NaN probability
    # matched nothing, the lists came out shorter than the frame, and the
    # column assignment below raised a length-mismatch ValueError.
    risk_name, risk_icon, advice = [], [], []
    for v in p.to_numpy():
        name, icon, adv = unknown_level
        for lo, hi, lvl_name, lvl_icon, lvl_adv in LEVELS:
            if lo <= v < hi:
                name, icon, adv = lvl_name, lvl_icon, lvl_adv
                break
        risk_name.append(name)
        risk_icon.append(icon)
        advice.append(adv)

    out2["risk"] = [f"{ic} {nm}" for ic, nm in zip(risk_icon, risk_name)]
    out2["advice"] = advice

    # Highest probability first.
    out2 = out2.sort_values(prob_col, ascending=False).reset_index(drop=True)
    return out2


def pretty_print_prediction(out2: pd.DataFrame, base_day=None, forecast_from=None, forecast_to=None, th_alert=TH_ALERT, topk=None):
    """Print a formatted risk table followed by a summary.

    Args:
        out2: prediction frame; only the known display columns that are present are shown.
        base_day: optional label for the base day, printed in the header.
        forecast_from: optional start of the forecast window (printed only with forecast_to).
        forecast_to: optional end of the forecast window (printed only with forecast_from).
        th_alert: threshold value shown in the header and summary (display only).
        topk: if given, show only the first ``topk`` rows of *out2*; the
            summary is always computed over the whole frame.
    """
    # Columns to display, in a fixed order, skipping any that are absent.
    show_cols = [c for c in ["node", "lat", "lon", "prob_fire_next7", "prob_%", "risk", "alert_next7", "advice"] if c in out2.columns]
    if topk is not None:
        # Positional slice: the previous label-based .loc[:topk-1] only selected
        # the first rows when the index was exactly 0..n-1.
        view = out2.iloc[:max(int(topk), 0)][show_cols].copy()
    else:
        view = out2[show_cols].copy()

    # Fixed-width number formatting for console output.
    if "prob_fire_next7" in view.columns:
        view["prob_fire_next7"] = view["prob_fire_next7"].map("{:.4f}".format)
    if "prob_%" in view.columns:
        view["prob_%"] = view["prob_%"].map("{:.2f}%".format)

    title = "=== PREDICT next7 (risk-level view) ==="
    meta = []
    if base_day is not None:
        meta.append(f"base_day: {base_day}")
    if forecast_from is not None and forecast_to is not None:
        meta.append(f"forecast: {forecast_from} -> {forecast_to}")
    meta.append(f"TH_ALERT: {th_alert}")

    rule = "-" * 80
    print("\n" + title)
    print(" | ".join(meta))
    print(rule)
    print(view.to_string(index=False))
    print(rule)

    # Summary over the full frame, not just the displayed rows.
    # NOTE(review): the non-ASCII text in the messages below looks mis-encoded
    # (UTF-8 decoded with the wrong codec) — confirm the file's encoding.
    if "risk" in out2.columns:
        print("Summary by risk:")
        print(out2["risk"].value_counts().to_string())
    if "alert_next7" in out2.columns:
        n_alert = int(out2["alert_next7"].sum())
        print(f"\nAlerts: {n_alert}/{len(out2)} nodes v∆∞·ª£t TH_ALERT={th_alert}")
        if n_alert == 0:
            print("‚úÖ OK: Kh√¥ng node n√†o v∆∞·ª£t threshold trong 7 ng√†y t·ªõi.")
        else:
            print("‚ö†Ô∏è  C√≥ node v∆∞·ª£t threshold ‚Äî xem c√°c d√≤ng alert_next7=1 ·ªü tr√™n.")


# ================== RUN ==================
# `out` is the prediction frame built above (node, lat, lon, prob_fire_next7, ...).
out2 = add_risk_levels(out, prob_col="prob_fire_next7", th_alert=TH_ALERT)

# Header dates are derived from BASE_DAY (base day, then days +1 .. +7) rather than
# hard-coded, so the report stays correct when BASE_DAY changes.
pretty_print_prediction(
    out2,
    base_day=str(BASE_DAY.date()),
    forecast_from=str((BASE_DAY + pd.Timedelta(days=1)).date()),
    forecast_to=str((BASE_DAY + pd.Timedelta(days=7)).date()),
    th_alert=TH_ALERT,
)


Loaded model OK | N=7 F=46 gcn_dim=32 gru_dim=64 DEVICE=cpu
Forecast window: 2026-01-01 -> 2026-01-07
CSV date range: 2025-12-01 -> 2025-12-31

=== PREDICT next7 from base_day=2025-12-31 | TH=0.53 ===


Unnamed: 0,node,lat,lon,prob_fire_next7,prob_%,alert_next7
0,DL_FIRE_SV-C2_684275,12.2,108.2,0.278259,27.825945,0
1,DL_FIRE_SV-C2_684276,12.2,108.4,0.278259,27.825945,0
2,DL_FIRE_SV-C2_684286,12.2,108.8,0.277416,27.741608,0
3,DL_FIRE_SV-C2_684287,12.1,109.0,0.277416,27.741608,0
4,DL_FIRE_SV-C2_684294,11.9,108.85,0.277055,27.705505,0
5,DL_FIRE_SV-C2_684281,12.2,108.6,0.276125,27.612501,0
6,DL_FIRE_SV-C2_684292,11.9,108.6,0.276091,27.609062,0



‚úÖ  OK: Kh√¥ng node n√†o v∆∞·ª£t threshold trong 7 ng√†y t·ªõi.

=== PREDICT next7 (risk-level view) ===
base_day: 2025-12-31 | forecast: 2026-01-01 -> 2026-01-07 | TH_ALERT: 0.53
--------------------------------------------------------------------------------
                node  lat    lon prob_fire_next7 prob_%  risk  alert_next7      advice
DL_FIRE_SV-C2_684275 12.2 108.20          0.2783 27.83% üü¢ LOW            0 B√¨nh th∆∞·ªùng
DL_FIRE_SV-C2_684276 12.2 108.40          0.2783 27.83% üü¢ LOW            0 B√¨nh th∆∞·ªùng
DL_FIRE_SV-C2_684286 12.2 108.80          0.2774 27.74% üü¢ LOW            0 B√¨nh th∆∞·ªùng
DL_FIRE_SV-C2_684287 12.1 109.00          0.2774 27.74% üü¢ LOW            0 B√¨nh th∆∞·ªùng
DL_FIRE_SV-C2_684294 11.9 108.85          0.2771 27.71% üü¢ LOW            0 B√¨nh th∆∞·ªùng
DL_FIRE_SV-C2_684281 12.2 108.60          0.2761 27.61% üü¢ LOW            0 B√¨nh th∆∞·ªùng
DL_FIRE_SV-C2_684292 11.9 108.60          0.2761 27.61% üü¢ LOW            0 B√¨nh th