In [None]:
import os
import io
import csv
import json
import pickle
import base64
import warnings
from typing import Dict, Any, List, Tuple
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import networkx as nx
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from starlette.staticfiles import StaticFiles

import tensorflow as tf

# --------------------------
# Paths
# --------------------------
# Defaults reproduce the original author's layout; set GRAPHGUARD_BASE and/or
# GRAPHGUARD_DATA_DIR to run the service from another location without editing code.
BASE = os.environ.get("GRAPHGUARD_BASE", r"C:\Users\sagni\Downloads\GraphGuard")
DATA_DIR = os.environ.get(
    "GRAPHGUARD_DATA_DIR",
    os.path.join(BASE, "archive (1)", "elliptic_bitcoin_dataset"),
)
FEATURES_CSV = os.path.join(DATA_DIR, "elliptic_txs_features.csv")
CLASSES_CSV  = os.path.join(DATA_DIR, "elliptic_txs_classes.csv")
EDGES_CSV    = os.path.join(DATA_DIR, "elliptic_txs_edgelist.csv")

PREPROC_PKL = os.path.join(BASE, "preprocessor.pkl")  # pickled dict (scaler + column names)
H5_PATH     = os.path.join(BASE, "model.h5")
KERAS_PATH  = os.path.join(BASE, "model.keras")  # optional; tried before H5_PATH
THRESH_PATH = os.path.join(BASE, "threshold.json")  # JSON with "best_threshold"

# Globals (lazy init) -- filled on first use by the _lazy_load_* helpers.
_preproc = None        # preprocessor dict from PREPROC_PKL
_model   = None        # loaded Keras model
_best_t  = None        # decision threshold from THRESH_PATH
_df      = None        # merged transaction table (RangeIndex)
_df_idx  = None        # same table indexed by transaction id
_G       = None        # undirected transaction graph
_feature_cols = None   # model input columns, in order
_time_col    = None    # time-step column name
_txid_col    = None    # transaction-id column name

def robust_read_csv(path, expected_min_cols=2,
                    delimiters=(",", ";", "\t", "|"),
                    encodings=("utf-8", "utf-8-sig", "cp1252", "latin1")):
    """Read a CSV whose encoding and delimiter are not known in advance.

    Every (encoding, delimiter) pair is tried -- encodings in the outer loop,
    delimiters in the inner loop -- and the first parse with at least
    ``expected_min_cols`` columns is returned.

    Args:
        path: Path to the CSV file.
        expected_min_cols: Minimum column count for a parse to be accepted. A
            wrong delimiter typically collapses each row into one column, so
            this is what rejects a bad guess.
        delimiters: Candidate field separators, tried in order.
        encodings: Candidate text encodings, tried in order.

    Returns:
        The parsed ``pandas.DataFrame`` (pandas' default header inference).

    Raises:
        RuntimeError: if the file does not exist, or no combination produced
            enough columns. The message reports the widest parse seen and the
            last parser exception (if any).
    """
    if not os.path.isfile(path):
        # Retrying every encoding/delimiter pair cannot fix a missing file, and
        # the generic "Last error" message used to hide the real cause.
        raise RuntimeError(f"Could not parse {path}. File does not exist.")

    last_err = None
    widest = 0
    for enc in encodings:
        for sep in delimiters:
            try:
                df = pd.read_csv(path, encoding=enc, sep=sep, engine="python")
            except Exception as e:  # decode/parse failures are expected while probing
                last_err = e
                continue
            if df.shape[1] >= expected_min_cols:
                return df
            widest = max(widest, df.shape[1])
    # Reached both when every attempt raised and when every parse was too narrow
    # (in which case last_err may be None), so report both facts.
    raise RuntimeError(
        f"Could not parse {path} into at least {expected_min_cols} columns "
        f"(widest parse: {widest} column(s)). Last error: {last_err}"
    )

def _lazy_load_preproc():
    """Load and cache the fitted preprocessor and the decision threshold.

    On the first call reads ``PREPROC_PKL`` (a pickled dict; this function uses
    the keys ``"feature_columns"``, ``"time_column"`` and ``"txid_column"``, and
    ``_scale_and_predict`` uses ``"scaler"``) and ``THRESH_PATH`` (JSON with
    ``"best_threshold"``). It then publishes ``_preproc``, ``_feature_cols``,
    ``_time_col``, ``_txid_col`` and ``_best_t``. Later calls return the cache.

    Returns:
        The preprocessor dict.

    Raises:
        RuntimeError: if either file is missing or lacks a required key.
    """
    global _preproc, _feature_cols, _time_col, _txid_col, _best_t
    if _preproc is not None:
        return _preproc

    if not os.path.exists(PREPROC_PKL):
        raise RuntimeError(f"Missing preprocessor.pkl at {PREPROC_PKL}")
    # NOTE(review): pickle.load can execute arbitrary code -- PREPROC_PKL must
    # only ever point at a trusted, locally produced file.
    with open(PREPROC_PKL, "rb") as f:
        preproc = pickle.load(f)

    if not os.path.exists(THRESH_PATH):
        raise RuntimeError(f"Missing threshold.json at {THRESH_PATH}")
    with open(THRESH_PATH, "r", encoding="utf-8") as f:
        thresh_cfg = json.load(f)

    try:
        feature_cols = list(preproc["feature_columns"])
        time_col = preproc["time_column"]
        txid_col = preproc["txid_column"]
        best_t = float(thresh_cfg["best_threshold"])
    except KeyError as e:
        # The endpoints translate KeyError into 404 "not found"; a broken
        # artifact must not masquerade as an unknown transaction.
        raise RuntimeError(f"Preprocessor/threshold file is missing key {e}") from e

    # Publish only after everything loaded. Previously _preproc was cached
    # before threshold.json was read, so a failure there left _best_t as None
    # and later calls returned the half-initialised cache without retrying.
    _feature_cols, _time_col, _txid_col, _best_t = feature_cols, time_col, txid_col, best_t
    _preproc = preproc
    return _preproc

def _load_model_from_disk():
    """Load the model file: KERAS_PATH first, then H5_PATH as a fallback.

    A failure while loading KERAS_PATH is printed as a warning and the H5 file
    is tried instead. An exception while loading the H5 file propagates.

    Raises:
        RuntimeError: if no loadable KERAS_PATH exists and H5_PATH is absent.
    """
    if os.path.exists(KERAS_PATH):
        try:
            return tf.keras.models.load_model(KERAS_PATH, safe_mode=False)
        except Exception as err:
            print("[WARN] model.keras load failed, falling back to H5:", err)
    if not os.path.exists(H5_PATH):
        raise RuntimeError(f"Model not found: {KERAS_PATH} or {H5_PATH}")
    return tf.keras.models.load_model(H5_PATH)


def _lazy_load_model():
    """Return the module-wide Keras model, loading it on the first call.

    The cache is only filled when loading succeeds, so a failed load is
    retried on the next call.
    """
    global _model
    if _model is None:
        _model = _load_model_from_disk()
    return _model

def _lazy_load_data_graph():
    """Build (once) the transaction table and transaction graph from the CSVs.

    Reads FEATURES_CSV, CLASSES_CSV and EDGES_CSV. It then:
      - adds per-transaction ``in_degree``/``out_degree`` counts from the edge list;
      - adds an optional ``label`` column (``"1"``/licit -> 0,
        ``"2"``/illicit -> 1, anything else -> <NA>);
      - coerces the preprocessor's feature columns to numeric;
      - drops rows whose time step or any feature is missing.

    Column roles are taken positionally: column 0 is the tx id in the features
    and classes CSVs, column 1 is the time step / class, and edge columns 0/1
    are source/target.

    Returns:
        ``(df, df_idx, G)``:
          - ``df``: the table;
          - ``df_idx``: the same table indexed by tx id (as str);
          - ``G``: an undirected ``networkx.Graph`` over the edge list.
        All three are cached in module globals.

    Raises:
        RuntimeError: if a CSV cannot be parsed, if the CSV id/time headers do
            not match the preprocessor, or if feature columns are missing.
    """
    global _df, _df_idx, _G
    if _df is not None and _df_idx is not None and _G is not None:
        return _df, _df_idx, _G

    _lazy_load_preproc()  # fills _feature_cols / _time_col / _txid_col
    feature_cols = _feature_cols
    time_col     = _time_col
    txid_col     = _txid_col

    df_feat = robust_read_csv(FEATURES_CSV, expected_min_cols=3)
    df_cls  = robust_read_csv(CLASSES_CSV,  expected_min_cols=2)
    df_edge = robust_read_csv(EDGES_CSV,    expected_min_cols=2)

    tx_col_feat, time_col_feat = df_feat.columns[0], df_feat.columns[1]
    tx_col_cls, class_col      = df_cls.columns[0], df_cls.columns[1]
    src_col, dst_col           = df_edge.columns[0], df_edge.columns[1]

    # Validate before doing any work or caching anything. This used to be an
    # `assert` run after the globals were assigned, so a mismatch got cached
    # (later calls skipped the check), and it disappeared under `python -O`.
    if time_col != time_col_feat or txid_col != tx_col_feat:
        raise RuntimeError(
            "Saved columns don't match CSV headers: preprocessor expects "
            f"txid={txid_col!r}, time={time_col!r}; CSV has "
            f"txid={tx_col_feat!r}, time={time_col_feat!r}."
        )

    # IDs as strings so they match the str txIds used for lookups.
    df_feat[tx_col_feat] = df_feat[tx_col_feat].astype(str)
    df_cls[tx_col_cls]   = df_cls[tx_col_cls].astype(str)
    df_edge[src_col]     = df_edge[src_col].astype(str)
    df_edge[dst_col]     = df_edge[dst_col].astype(str)

    # Degrees: outer-aligned on tx id; a tx absent from one side gets 0.
    in_deg  = df_edge.groupby(dst_col).size().rename("in_degree")
    out_deg = df_edge.groupby(src_col).size().rename("out_degree")
    deg_df  = pd.concat([in_deg, out_deg], axis=1).fillna(0.0).reset_index()
    deg_df.rename(columns={deg_df.columns[0]: tx_col_feat}, inplace=True)

    # Merge degrees onto features; transactions without edges get 0.
    df_feat[time_col_feat] = pd.to_numeric(df_feat[time_col_feat], errors="coerce")
    df = df_feat.merge(deg_df, on=tx_col_feat, how="left")
    df[["in_degree", "out_degree"]] = df[["in_degree", "out_degree"]].fillna(0.0)

    # Optional labels (nullable Int64 so unknown classes stay <NA>).
    df_cls[class_col] = df_cls[class_col].astype(str).str.lower().str.strip()
    df_cls["label"] = df_cls[class_col].map({"1": 0, "2": 1, "licit": 0, "illicit": 1}).astype("Int64")
    df = df.merge(df_cls[[tx_col_cls, "label"]], left_on=tx_col_feat, right_on=tx_col_cls, how="left")
    if tx_col_cls in df.columns and tx_col_cls != tx_col_feat:
        df = df.drop(columns=[tx_col_cls])

    # A missing feature column would otherwise raise KeyError, which the
    # endpoints report as 404 "not found".
    missing = [c for c in feature_cols if c not in df.columns]
    if missing:
        raise RuntimeError(f"Feature columns missing from CSV data: {missing[:10]}")

    # Keep numeric features
    for c in feature_cols:
        df[c] = pd.to_numeric(df[c], errors="coerce")
    df = df.dropna(subset=[time_col_feat] + feature_cols).reset_index(drop=True)

    # Build undirected graph for ego subgraph
    G = nx.from_pandas_edgelist(df_edge, source=src_col, target=dst_col, create_using=nx.Graph())

    _df, _df_idx, _G = df, df.set_index(tx_col_feat), G
    return _df, _df_idx, _G

def _scale_and_predict(x_row: np.ndarray) -> Tuple[float, int]:
    """Scale a raw feature row and return ``(probability, thresholded label)``.

    Args:
        x_row: Unscaled features in ``_feature_cols`` order, shape
            ``(1, n_features)``. A 1-D row of length ``n_features`` is also
            accepted. If several rows are given, only the first row's score is
            returned.

    Returns:
        ``(prob, pred)``:
          - ``prob``: the first element of the model output, as a float;
          - ``pred``: ``1`` if ``prob >= _best_t``, else ``0``.
    """
    pre = _lazy_load_preproc()
    model = _lazy_load_model()
    x = np.asarray(x_row)
    if x.ndim == 1:
        x = x.reshape(1, -1)  # scalers expect (n_samples, n_features)
    x_scaled = np.asarray(pre["scaler"].transform(x), dtype="float32")
    # Call the model directly rather than Model.predict: predict builds a
    # batched input pipeline on every call, which is pure overhead for one
    # request-sized row. float32 matches the path in _grad_input_importance.
    prob = float(np.asarray(model(x_scaled, training=False)).ravel()[0])
    pred = 1 if prob >= _best_t else 0
    return prob, pred

def _grad_input_importance(x_scaled: np.ndarray) -> np.ndarray:
    """Return |gradient x input| feature attributions for the first row.

    Differentiates the model's first output column with respect to the
    already-scaled inputs, then multiplies element-wise by those inputs.

    Args:
        x_scaled: Scaled feature matrix; only row 0 is attributed.

    Returns:
        1-D array with one non-negative score per feature.
    """
    inputs = tf.convert_to_tensor(x_scaled.astype("float32"))
    model = _lazy_load_model()
    with tf.GradientTape() as tape:
        # A plain tensor is not a Variable, so the tape must be told to track it.
        tape.watch(inputs)
        first_output = model(inputs, training=False)[:, 0]
    input_grads = tape.gradient(first_output, inputs)
    return np.abs(input_grads.numpy()[0] * inputs.numpy()[0])

def _get_tx_features(txid: str) -> Dict[str, Any]:
    """Look up a stored transaction's model features, time step and label.

    Args:
        txid: Transaction id, as a string.

    Returns:
        A dict with:
          - ``"x"``: float64 array of shape ``(1, n_features)`` in
            ``_feature_cols`` order;
          - ``"time"``: the time step as int;
          - ``"label"``: 0/1, or None when unlabeled.

    Raises:
        KeyError: if ``txid`` is not in the loaded table (endpoints map this to 404).
    """
    _, df_idx, _ = _lazy_load_data_graph()  # also ensures the preprocessor globals are set
    if txid not in df_idx.index:
        raise KeyError(f"txId '{txid}' not found.")
    row = df_idx.loc[txid]
    if isinstance(row, pd.DataFrame):
        # Duplicate ids would otherwise yield a (1, 2*n_features) vector.
        row = row.iloc[0]
    # The row mixes float and nullable-Int64 columns, so it comes back as an
    # object Series; convert explicitly rather than leaving it to the scaler.
    x = row[_feature_cols].to_numpy(dtype="float64").reshape(1, -1)
    raw_label = row["label"] if "label" in df_idx.columns else pd.NA
    label = None if pd.isna(raw_label) else int(raw_label)
    return {"x": x, "time": int(row[_time_col]), "label": label}

def _ego_subgraph_png(txid: str, k: int = 2, max_nodes: int = 150) -> Tuple[Dict[str, Any], str]:
    """Render the k-hop neighbourhood of ``txid`` and return it as JSON + PNG.

    Args:
        txid: Centre transaction id.
        k: Neighbourhood radius in hops (values below 0 are treated as 0).
        max_nodes: Maximum nodes kept (values below 1 are treated as 1). The
            nodes closest to ``txid`` are kept first, so the centre is always
            included.

    Returns:
        ``(subgraph_json, png_base64)``:
          - ``subgraph_json``: ``{"nodes": [{"id"}], "edges": [{"source", "target"}]}``;
          - ``png_base64``: the rendered 150-dpi PNG, base64-encoded.

    Raises:
        KeyError: if ``txid`` is not a node of the graph (it has no edges).
    """
    _, _, G = _lazy_load_data_graph()
    if txid not in G:
        raise KeyError(f"txId '{txid}' has no edges.")
    k = max(0, int(k))
    max_nodes = max(1, int(max_nodes))

    # BFS distances come back nearest-first (centre first), so truncating keeps
    # the closest neighbours. ego_graph().nodes() follows the graph's insertion
    # order instead, so truncating it could drop the centre itself and then
    # crash in draw_networkx_labels.
    dist = nx.single_source_shortest_path_length(G, txid, cutoff=k)
    nodes = list(dist)[:max_nodes]
    SG = G.subgraph(nodes).copy()

    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        pos = nx.spring_layout(SG, seed=42, k=1 / np.sqrt(max(len(SG), 1)))
        node_colors = ["red" if n == txid else "steelblue" for n in SG.nodes()]
        nx.draw_networkx_nodes(SG, pos, ax=ax, node_color=node_colors, node_size=80,
                               alpha=0.9, linewidths=0.3, edgecolors="white")
        nx.draw_networkx_edges(SG, pos, ax=ax, alpha=0.4, width=0.8)
        nx.draw_networkx_labels(SG, pos, ax=ax, labels={txid: txid}, font_size=8)
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format="png", dpi=150)
    finally:
        # Always release the figure; an exception above used to leak it.
        plt.close(fig)

    subgraph_json = {
        "nodes": [{"id": n} for n in SG.nodes()],
        "edges": [{"source": u, "target": v} for u, v in SG.edges()],
    }
    return subgraph_json, base64.b64encode(buf.getvalue()).decode("utf-8")

# --------------------------
# FastAPI app
# --------------------------
app = FastAPI(title="GraphGuard API", version="1.0.1", description="Fraud scoring + ego subgraph (lazy-load)")
# Any origin may call the API, but without credentials. The endpoints use no
# cookies or auth, and a wildcard origin combined with credentials is
# disallowed by the CORS spec. Starlette's CORSMiddleware documentation notes
# that credentials require explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=False, allow_methods=["*"], allow_headers=["*"],
)
# NOTE(review): this publishes every file under BASE at /static -- including
# preprocessor.pkl, the model files, threshold.json and the raw dataset.
# Consider moving the front-end (index.html) into a dedicated sub-folder and
# mounting only that.
app.mount("/static", StaticFiles(directory=BASE), name="static")

@app.get("/", response_class=HTMLResponse)
def root():
    """Send browsers from ``/`` to the static front-end via a meta refresh."""
    redirect_page = '<meta http-equiv="refresh" content="0; url=/static/index.html">'
    return HTMLResponse(content=redirect_page)

@app.get("/health")
def health():
    """Report whether the preprocessor and threshold load.

    Always answers HTTP 200, with one of two bodies:
      - ``{"status": "ok", "threshold", "feature_dim"}`` on success;
      - ``{"status": "error", "detail"}`` on failure.
    """
    try:
        _lazy_load_preproc()
        report = {"status": "ok", "threshold": _best_t, "feature_dim": len(_feature_cols)}
    except Exception as exc:
        report = {"status": "error", "detail": str(exc)}
    return report

@app.get("/selftest")
def selftest():
    """Run the full pipeline once and report the outcome.

    Loads the data, scores the first transaction in the table, and reports the
    result together with the graph size. Any failure is returned as
    ``{"status": "error", "detail": ...}``.
    """
    try:
        _, df_idx, graph = _lazy_load_data_graph()
        sample_tx = str(df_idx.index[0])
        sample_x = _get_tx_features(sample_tx)["x"]
        prob, pred = _scale_and_predict(sample_x)
        result = {
            "status": "ok",
            "example_txId": sample_tx,
            "prob_illicit": prob,
            "pred_label": pred,
            "n_nodes": int(graph.number_of_nodes()),
            "n_edges": int(graph.number_of_edges()),
        }
    except Exception as exc:
        result = {"status": "error", "detail": str(exc)}
    return result

def _top_features(x_raw: np.ndarray, top_k: int = 10) -> List[Dict[str, Any]]:
    """Rank the features of one raw row by |gradient x input|.

    ``x_raw`` is unscaled, in ``_feature_cols`` order. It is scaled with the
    preprocessor's scaler before attribution. This reads ``_preproc`` and
    ``_feature_cols``, so ``_lazy_load_preproc`` must already have run.
    """
    x_scaled = _preproc["scaler"].transform(x_raw)
    contrib = _grad_input_importance(x_scaled)
    top_idx = np.argsort(contrib)[::-1][:top_k]
    return [{"feature": _feature_cols[i], "score": float(contrib[i])} for i in top_idx]


@app.post("/score")
async def score(payload: Dict[str, Any]):
    """Score one transaction, either by stored ``txId`` or from raw features.

    Request body, one of:
      - ``{"mode": "txid", "txId": "..."}`` (the default mode);
      - ``{"mode": "payload", "features": {name: value, ...}}``, which must
        contain every preprocessor feature.

    Responses:
      - 200: probability, thresholded label, threshold and the 10 top
        attributed features; txid mode also returns the time step and true label;
      - 404: unknown txId;
      - 400: any other error.
    """
    mode = payload.get("mode", "txid")
    try:
        # Payload mode never needs the dataset, so the large CSVs/graph are
        # loaded on demand by _get_tx_features in txid mode only.
        _lazy_load_preproc()
        _lazy_load_model()
        if mode == "txid":
            txid = str(payload.get("txId", "")).strip()
            if not txid:
                raise ValueError("txId required.")
            info = _get_tx_features(txid)
            prob, pred = _scale_and_predict(info["x"])
            return {"txId": txid, "timeStep": info["time"], "prob_illicit": prob, "pred_label": pred,
                    "true_label": info["label"], "threshold": _best_t,
                    "top_features": _top_features(info["x"])}
        if mode == "payload":
            feats = payload.get("features", {})
            if not isinstance(feats, dict):
                raise ValueError("features must be an object mapping feature name to value.")
            missing = [c for c in _feature_cols if c not in feats]
            if missing:
                more = "..." if len(missing) > 10 else ""
                raise ValueError(f"Missing features: {missing[:10]}{more}")
            x = np.array([[float(feats[c]) for c in _feature_cols]], dtype="float32")
            if not np.isfinite(x).all():
                # NaN/inf would flow through to a NaN probability, which the
                # JSON response renderer rejects (a 500 instead of a 400).
                raise ValueError("All feature values must be finite numbers.")
            prob, pred = _scale_and_predict(x)
            return {"prob_illicit": prob, "pred_label": pred, "threshold": _best_t,
                    "top_features": _top_features(x)}
        raise ValueError("mode must be 'txid' or 'payload'")
    except KeyError as e:
        # str(KeyError(msg)) wraps msg in quotes; report the bare message.
        raise HTTPException(status_code=404, detail=str(e.args[0]) if e.args else str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

# Upper bounds for /explain so one request cannot trigger an arbitrarily large
# neighbourhood expansion and spring-layout render.
_EXPLAIN_MAX_K = 5
_EXPLAIN_MAX_NODES = 1000


@app.post("/explain")
async def explain(payload: Dict[str, Any]):
    """Score a stored transaction and return attributions plus its ego subgraph.

    Request body: ``{"txId": "...", "k": 2, "max_nodes": 150}``.
      - ``k``: neighbourhood radius, 0.._EXPLAIN_MAX_K;
      - ``max_nodes``: maximum nodes drawn, 1.._EXPLAIN_MAX_NODES.

    Responses:
      - 200: the /score txid fields plus ``subgraph`` (node/edge lists) and
        ``subgraph_png_base64``;
      - 404: the txId is unknown or has no edges;
      - 400: invalid parameters or any other error.
    """
    txid = str(payload.get("txId", "")).strip()
    if not txid:
        raise HTTPException(status_code=400, detail="txId required")
    # Parsed under a handler: a non-integer k/max_nodes used to raise an
    # uncaught ValueError, i.e. a 500 instead of a client error.
    try:
        k = int(payload.get("k", 2))
        max_nodes = int(payload.get("max_nodes", 150))
    except (TypeError, ValueError) as e:
        raise HTTPException(status_code=400, detail="k and max_nodes must be integers") from e
    if not 0 <= k <= _EXPLAIN_MAX_K:
        raise HTTPException(status_code=400, detail=f"k must be between 0 and {_EXPLAIN_MAX_K}")
    if not 1 <= max_nodes <= _EXPLAIN_MAX_NODES:
        raise HTTPException(status_code=400,
                            detail=f"max_nodes must be between 1 and {_EXPLAIN_MAX_NODES}")
    try:
        _lazy_load_preproc(); _lazy_load_model(); _lazy_load_data_graph()
        info = _get_tx_features(txid)
        prob, pred = _scale_and_predict(info["x"])
        x_scaled = _preproc["scaler"].transform(info["x"])
        contrib = _grad_input_importance(x_scaled)
        top_idx = np.argsort(contrib)[::-1][:10]
        top_feats = [{"feature": _feature_cols[i], "score": float(contrib[i])} for i in top_idx]
        subgraph_json, b64 = _ego_subgraph_png(txid, k=k, max_nodes=max_nodes)
        return {"txId": txid, "timeStep": info["time"], "prob_illicit": prob, "pred_label": pred,
                "true_label": info["label"], "threshold": _best_t, "top_features": top_feats,
                "subgraph": subgraph_json, "subgraph_png_base64": b64}
    except KeyError as e:
        # str(KeyError(msg)) wraps msg in quotes; report the bare message.
        raise HTTPException(status_code=404, detail=str(e.args[0]) if e.args else str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

if __name__ == "__main__":
    # Development entry point. uvicorn is only needed when this module is run
    # directly, so it is imported here rather than at module level.
    # NOTE(review): "app:app" assumes the file is saved as app.py, and
    # reload=True watches the current working directory (the console output
    # shows the user's home folder) -- confirm both are intended.
    import uvicorn
    server_options = {"host": "127.0.0.1", "port": 8000, "reload": True}
    uvicorn.run("app:app", **server_options)


INFO:     Will watch for changes in these directories: ['C:\\Users\\sagni']
INFO:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
INFO:     Started reloader process [4508] using StatReload
