In [1]:
# ============================================================
# GraphGuard — FastAPI Inference + Explainability pack
#   Writes to: C:\Users\sagni\Downloads\GraphGuard
#   Files: app.py, index.html, requirements_api.txt, run_api.bat
# Endpoints:
#   - GET  /health
#   - POST /score   {mode: "txid"|"payload", ...}
#   - POST /explain {txId, k=2, max_nodes=150} -> subgraph + base64 PNG + top features
# ============================================================
import os
# Output folder for every generated file (app.py, index.html, requirements, launcher).
# Created up front so the write() calls further down never hit a missing directory.
BASE = r"C:\Users\sagni\Downloads\GraphGuard"
os.makedirs(BASE, exist_ok=True)

def write(path, content):
    """Create or overwrite the text file at *path* with *content*, encoded as UTF-8."""
    with open(path, mode="w", encoding="utf-8") as out_file:
        out_file.write(content)

APP_PY = r'''import os, io, csv, json, pickle, base64, warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import networkx as nx
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from typing import Optional, List, Dict, Any
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse
from starlette.staticfiles import StaticFiles

import tensorflow as tf

# --------------------------
# Paths
# --------------------------
# Project root: the trained artifacts are read from here and the static UI is served from here.
BASE        = r"C:\Users\sagni\Downloads\GraphGuard"
# Dataset CSVs (features, class labels, edge list); all three are parsed at import time.
FEATURES_CSV= r"C:\Users\sagni\Downloads\GraphGuard\archive (1)\elliptic_bitcoin_dataset\elliptic_txs_features.csv"
CLASSES_CSV = r"C:\Users\sagni\Downloads\GraphGuard\archive (1)\elliptic_bitcoin_dataset\elliptic_txs_classes.csv"
EDGES_CSV   = r"C:\Users\sagni\Downloads\GraphGuard\archive (1)\elliptic_bitcoin_dataset\elliptic_txs_edgelist.csv"

# Artifacts read at import time: a pickled dict, the Keras model and a JSON decision threshold.
PREPROC_PKL = os.path.join(BASE, "preprocessor.pkl")
H5_PATH     = os.path.join(BASE, "model.h5")
KERAS_PATH  = os.path.join(BASE, "model.keras")  # optional; tried before model.h5 when present
THRESH_PATH = os.path.join(BASE, "threshold.json")

# --------------------------
# Utils
# --------------------------
def robust_read_csv(path, expected_min_cols=2):
    """Read a delimited text file whose separator and encoding are not known up front.

    The delimiter guessed by ``csv.Sniffer`` from the first 8 KiB is tried first,
    followed by the remaining candidates, for every candidate encoding in turn.
    The first parse that yields at least ``expected_min_cols`` columns is returned.

    Args:
        path: Location of the file to read.
        expected_min_cols: Minimum number of columns a parse must produce to be accepted.

    Returns:
        The parsed ``pandas.DataFrame``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        RuntimeError: If no encoding/delimiter combination produced an acceptable
            frame. The message reports the widest parse seen and/or the last
            parser error, and the last parser error is chained as the cause.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    delims = [",", ";", "\t", "|"]
    encs = ["utf-8", "utf-8-sig", "cp1252", "latin1"]
    try:
        with open(path, "rb") as f:
            # latin1 maps every byte to a character, so decoding the sample cannot fail.
            head = f.read(8192).decode("latin1", errors="ignore")
        sniff = csv.Sniffer().sniff(head)
        if sniff.delimiter in delims:
            delims = [sniff.delimiter] + [d for d in delims if d != sniff.delimiter]
    except Exception:
        # Deliberate best effort: sniffing only decides the order in which
        # delimiters are tried, so any failure falls back to the default order.
        pass

    last_err = None
    widest = None  # largest column count seen among parses that were too narrow
    for enc in encs:
        for sep in delims:
            try:
                df = pd.read_csv(path, encoding=enc, sep=sep, engine="python")
            except Exception as e:
                last_err = e
                continue
            if df.shape[1] >= expected_min_cols:
                return df
            widest = df.shape[1] if widest is None else max(widest, df.shape[1])

    reasons = []
    if widest is not None:
        reasons.append(
            f"widest parse had {widest} column(s), expected at least {expected_min_cols}"
        )
    if last_err is not None:
        reasons.append(f"Last error: {last_err}")
    if not reasons:
        reasons.append("no parse was attempted")
    raise RuntimeError(f"Could not parse {path}. " + "; ".join(reasons)) from last_err

def numpy_to_base64_png(arr, figsize=(6,4), title=None):
    """Render ``arr`` as a line plot and return the PNG as a base64-encoded string.

    Args:
        arr: Data handed to ``plt.plot``.
        figsize: Figure size passed to ``plt.figure``.
        title: Optional plot title; omitted when falsy.

    Returns:
        The PNG image (150 dpi) encoded as base64 text.
    """
    fig = plt.figure(figsize=figsize)
    try:
        plt.plot(arr)
        if title:
            plt.title(title)
        plt.tight_layout()
        buf = io.BytesIO()
        plt.savefig(buf, format="png", dpi=150)
    finally:
        # Close this figure even when plotting or saving fails; a long-running
        # server would otherwise accumulate open figures.
        plt.close(fig)
    return base64.b64encode(buf.getvalue()).decode("utf-8")

def fig_to_base64_png():
    """Serialize the current pyplot figure to a base64 PNG string and close it.

    The figure is closed even when layout or saving raises, so a failed render
    does not leave an open figure behind in the server process.

    Returns:
        The PNG image (150 dpi) encoded as base64 text.
    """
    buf = io.BytesIO()
    try:
        plt.tight_layout()
        plt.savefig(buf, format="png", dpi=150)
    finally:
        plt.close()
    return base64.b64encode(buf.getvalue()).decode("utf-8")

# --------------------------
# Load artifacts
# --------------------------
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load a preprocessor.pkl produced by your own training run.
with open(PREPROC_PKL, "rb") as f:
    preproc = pickle.load(f)

# The pickle is a dict; these are the entries the API relies on.
feature_cols: List[str] = list(preproc["feature_columns"])
scaler = preproc["scaler"]
time_col = preproc["time_column"]
txid_col = preproc["txid_column"]
splits = preproc["splits"]

with open(THRESH_PATH, "r", encoding="utf-8") as f:
    best_t = float(json.load(f)["best_threshold"])

# Model: prefer model.keras, fall back to model.h5.
model = None
_keras_load_error = None
if os.path.exists(KERAS_PATH):
    try:
        model = tf.keras.models.load_model(KERAS_PATH, safe_mode=False)
    except Exception as exc:
        # Best effort: try the H5 file next, but remember why .keras failed so the
        # final error does not claim the model is missing when it merely failed to load.
        _keras_load_error = exc
if model is None and os.path.exists(H5_PATH):
    model = tf.keras.models.load_model(H5_PATH)
if model is None:
    if _keras_load_error is not None:
        raise RuntimeError(
            f"Found {KERAS_PATH} but it could not be loaded, and model.h5 is not available: {_keras_load_error}"
        ) from _keras_load_error
    raise RuntimeError("Model not found (model.keras / model.h5). Train first.")

# --------------------------
# Load dataframes + build graph & lookups
# --------------------------
# Parse the three dataset files; each must yield at least the given number of columns.
# NOTE(review): robust_read_csv uses pandas' default header handling, so the first row
# of every file is taken as column names — confirm the features CSV really has a header row.
df_feat = robust_read_csv(FEATURES_CSV, expected_min_cols=3)
df_cls  = robust_read_csv(CLASSES_CSV,  expected_min_cols=2)
df_edge = robust_read_csv(EDGES_CSV,    expected_min_cols=2)

# Column names are resolved by position, not by header text:
# features -> (tx id, time step), classes -> (tx id, class), edges -> (source, target).
feat_cols = list(df_feat.columns)
tx_col_feat, time_col_feat = feat_cols[0], feat_cols[1]
cls_cols = list(df_cls.columns)
tx_col_cls, class_col = cls_cols[0], cls_cols[1]
edge_cols = list(df_edge.columns)
src_col, dst_col = edge_cols[0], edge_cols[1]

# Cast IDs to string so merges and graph lookups use one consistent key type
df_feat[tx_col_feat] = df_feat[tx_col_feat].astype(str)
df_cls[tx_col_cls]   = df_cls[tx_col_cls].astype(str)
df_edge[src_col]     = df_edge[src_col].astype(str)
df_edge[dst_col]     = df_edge[dst_col].astype(str)

# Degrees: number of edge rows per target (in) and per source (out)
in_deg  = df_edge.groupby(dst_col).size().rename("in_degree")
out_deg = df_edge.groupby(src_col).size().rename("out_degree")
# Ids present on only one side get 0 for the other degree; the id index becomes the first column.
deg_df  = pd.concat([in_deg, out_deg], axis=1).fillna(0.0).reset_index()
deg_df.rename(columns={deg_df.columns[0]: tx_col_feat}, inplace=True)

# Merge features + degrees + labels (labels optional here)
df_feat[time_col_feat] = pd.to_numeric(df_feat[time_col_feat], errors="coerce")
df = df_feat.merge(deg_df, on=tx_col_feat, how="left")
# Transactions that never appear in the edge list have degree 0.
df[["in_degree","out_degree"]] = df[["in_degree","out_degree"]].fillna(0.0)
# Optional labels
# NOTE(review): the published Elliptic data description lists class "1" as illicit and
# "2" as licit, while this map sends "1" -> 0 and "2" -> 1 — confirm against the training
# pipeline before relying on true_label / prob_illicit semantics.
label_map = {"1":0,"2":1,"licit":0,"illicit":1}
df_cls[class_col] = df_cls[class_col].astype(str).str.lower().str.strip()
# Any class value outside label_map becomes <NA> in the nullable integer column.
df_cls["label"] = df_cls[class_col].map(label_map).astype("Int64")
df = df.merge(df_cls[[tx_col_cls,"label"]], left_on=tx_col_feat, right_on=tx_col_cls, how="left")
# Drop the duplicate id column the merge adds when the two files name it differently.
if tx_col_cls in df.columns and tx_col_cls != tx_col_feat:
    df = df.drop(columns=[tx_col_cls])

# Keep numeric features: coerce every model feature, then drop rows with a missing
# time step or a missing/non-numeric feature value.
for c in feature_cols:
    df[c] = pd.to_numeric(df[c], errors="coerce")
df = df.dropna(subset=[time_col_feat] + feature_cols).reset_index(drop=True)

# Build networkx graph for explanations (undirected view is fine for ego)
G = nx.from_pandas_edgelist(df_edge, source=src_col, target=dst_col, create_using=nx.Graph())

# Index for quick lookup by txId
df_indexed = df.set_index(tx_col_feat)

# --------------------------
# FastAPI app
# --------------------------
app = FastAPI(title="GraphGuard API", version="1.0.0", description="Fraud scoring + k-hop subgraph explanations")
# CORS: the API is unauthenticated and open to every origin. Credentialed requests are
# disabled because a wildcard origin must not be combined with allow_credentials=True;
# list explicit origins here if cookies or auth headers are ever needed cross-origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=False, allow_methods=["*"], allow_headers=["*"],
)
# Serve static (index.html)
# NOTE(review): this mount exposes every file under BASE over HTTP, not just index.html.
# Consider moving the UI into a dedicated sub-folder before binding to a public interface.
app.mount("/static", StaticFiles(directory=BASE), name="static")

@app.get("/", response_class=HTMLResponse)
def home():
    """Landing page: an immediate meta-refresh that forwards the browser to the static UI."""
    redirect_markup = '<meta http-equiv="refresh" content="0; url=/static/index.html">'
    return HTMLResponse(redirect_markup)

@app.get("/health")
def health():
    """Report service status together with the size of the loaded artifacts.

    The payload carries the feature count, graph node/edge counts, the decision
    threshold, the preprocessor's time/txid column names and the length of each split.
    """
    split_sizes = {}
    for split_name, members in splits.items():
        split_sizes[split_name] = len(members)

    node_count = int(G.number_of_nodes())
    edge_count = int(G.number_of_edges())
    return dict(
        status="ok",
        feature_dim=len(feature_cols),
        n_nodes=node_count,
        n_edges=edge_count,
        threshold=best_t,
        time_column=time_col,
        txid_column=txid_col,
        splits=split_sizes,
    )

def _scale_and_predict(x_row: np.ndarray) -> (float, int):
    """Scale one unscaled feature row and run it through the model.

    Args:
        x_row: Unscaled features, expected as a (1, D) array.

    Returns:
        ``(prob, pred)`` where ``prob`` is the first model output as a float and
        ``pred`` is 1 when ``prob`` reaches the ``best_t`` threshold, else 0.
    """
    scaled = scaler.transform(x_row)
    raw_output = model.predict(scaled, verbose=0)
    prob = float(raw_output.ravel()[0])
    return prob, int(prob >= best_t)

def _grad_input_importance(x_scaled: np.ndarray) -> np.ndarray:
    """Compute gradient-times-input attributions for the first row of ``x_scaled``.

    Args:
        x_scaled: Already-scaled features; cast to float32 before the forward pass.

    Returns:
        Absolute value of (gradient * input) per feature for the first row.
    """
    inputs = tf.convert_to_tensor(x_scaled.astype("float32"))
    with tf.GradientTape() as tape:
        # A plain tensor is not tracked automatically, so watch it explicitly.
        tape.watch(inputs)
        outputs = model(inputs, training=False)
        # First output unit — presumably the sigmoid illicit probability; confirm against the model.
        target = outputs[:, 0]
    first_row_grads = tape.gradient(target, inputs).numpy()[0]
    first_row_inputs = inputs.numpy()[0]
    return np.abs(first_row_grads * first_row_inputs)

def _get_tx_features(txid: str) -> Dict[str, Any]:
    """Look up one transaction in ``df_indexed``.

    Args:
        txid: Transaction id, as a string.

    Returns:
        Dict with ``"x"`` (the feature values reshaped to a single row),
        ``"time"`` (the time-step column as int) and ``"label"`` (int, or
        ``None`` when no label is available).

    Raises:
        KeyError: If ``txid`` is not in the index.
    """
    if txid not in df_indexed.index:
        raise KeyError(f"txId '{txid}' not found.")

    record = df_indexed.loc[txid]
    features = record[feature_cols].values.reshape(1, -1)
    step = int(record[time_col])

    true_label = None
    if "label" in df_indexed.columns:
        raw_label = record.get("label", pd.NA)
        if not pd.isna(raw_label):
            true_label = int(raw_label)

    return {"x": features, "time": step, "label": true_label}

@app.post("/score")
async def score(payload: Dict[str, Any]):
    """
    Score one transaction.

    Two modes:
      - {"mode":"txid", "txId":"..."}
      - {"mode":"payload", "features": {<feature>:value,...}, "in_degree":..., "out_degree":..., "timeStep": <int>}
        Top-level in_degree / out_degree / timeStep are used only for model
        features that are not already present inside "features".
    Returns: prob_illicit, pred_label, (optional) true_label, top_features

    Errors: 404 when the txId is unknown, 400 for any other invalid request.
    """
    mode = payload.get("mode", "txid")
    try:
        if mode == "txid":
            txid = str(payload.get("txId", "")).strip()
            if not txid:
                raise ValueError("txId required for mode='txid'")
            info = _get_tx_features(txid)
            x = info["x"]
            leading = {"txId": txid, "timeStep": info["time"]}
            label_part = {"true_label": info["label"]}

        elif mode == "payload":
            fdict = dict(payload.get("features") or {})
            # Honour the documented top-level fields without overriding explicit features.
            for extra in ("in_degree", "out_degree"):
                if extra in payload and extra not in fdict:
                    fdict[extra] = payload[extra]
            if "timeStep" in payload and time_col in feature_cols and time_col not in fdict:
                fdict[time_col] = payload["timeStep"]
            missing = [c for c in feature_cols if c not in fdict]
            if missing:
                raise ValueError(f"Missing features: {missing[:10]}...")
            x = np.array([[float(fdict[c]) for c in feature_cols]], dtype="float32")
            leading = {}
            label_part = {}

        else:
            raise ValueError("mode must be 'txid' or 'payload'")

        # Shared tail: prediction plus grad*input attribution on the scaled row.
        prob, pred = _scale_and_predict(x)
        contrib = _grad_input_importance(scaler.transform(x))
        top_idx = np.argsort(contrib)[::-1][:10]
        top_feats = [{"feature": feature_cols[i], "score": float(contrib[i])} for i in top_idx]
        return {
            **leading,
            "prob_illicit": prob,
            "pred_label": pred,
            **label_part,
            "threshold": best_t,
            "top_features": top_feats,
        }
    except KeyError as e:
        # str(KeyError) is the repr of its argument and would wrap the message in quotes.
        detail = e.args[0] if e.args else str(e)
        raise HTTPException(status_code=404, detail=str(detail))
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

def _ego_subgraph_png(txid: str, k: int = 2, max_nodes: int = 150) -> (Dict[str, Any], str):
    """
    Returns (subgraph_json, base64_png) for k-hop ego network around txid.
    Limits nodes for readability. Colors center red, neighbors blue.

    When the neighbourhood exceeds ``max_nodes`` the nodes closest to the centre
    (by hop count) are kept, so the centre itself is always part of the result.
    At least one node (the centre) is always kept.

    Raises:
        KeyError: If ``txid`` is not a node of the graph.
    """
    if txid not in G:
        raise KeyError(f"txId '{txid}' has no edges in graph.")
    # Hop distance of every node within k hops; the centre has distance 0.
    hops = dict(nx.single_source_shortest_path_length(G, txid, cutoff=k))
    # Nearest first: a plain slice of the node list follows graph order and could
    # drop the centre, which then breaks label drawing below.
    limit = max(int(max_nodes), 1)
    nodes = sorted(hops, key=hops.get)[:limit]
    SG = G.subgraph(nodes).copy()

    # JSON
    nodes_json = [{"id": n} for n in SG.nodes()]
    edges_json = [{"source": u, "target": v} for u, v in SG.edges()]

    # PNG rendering
    plt.figure(figsize=(7, 6))
    try:
        pos = nx.spring_layout(SG, seed=42, k=1/np.sqrt(max(len(SG),1)))
        node_colors = ["red" if n == txid else "steelblue" for n in SG.nodes()]
        nx.draw_networkx_nodes(SG, pos, node_color=node_colors, node_size=80, alpha=0.9, linewidths=0.3, edgecolors="white")
        nx.draw_networkx_edges(SG, pos, alpha=0.4, width=0.8)
        # Only the centre is labelled to keep the picture readable.
        nx.draw_networkx_labels(SG, pos, labels={txid: txid}, font_size=8)
    except Exception:
        # Do not leave the half-drawn figure open in the server process.
        plt.close()
        raise
    b64 = fig_to_base64_png()
    return {"nodes": nodes_json, "edges": edges_json}, b64

@app.post("/explain")
async def explain(payload: Dict[str, Any]):
    """
    Input: {"txId": "...", "k": 2, "max_nodes": 150}
    Output:
      - prob_illicit, pred_label, (optional) true_label
      - top_features (grad*input)
      - subgraph {nodes, edges}
      - subgraph_png_base64

    Errors: 400 for a missing txId or invalid k / max_nodes, 404 when the txId is
    unknown or has no edges, 400 for any other failure.
    """
    txid = str(payload.get("txId", "")).strip()
    if not txid:
        raise HTTPException(status_code=400, detail="txId required")
    # Parse the numeric options up front so malformed values yield a 400
    # instead of escaping as an unhandled server error.
    try:
        k = int(payload.get("k", 2))
        max_nodes = int(payload.get("max_nodes", 150))
    except (TypeError, ValueError, OverflowError):
        raise HTTPException(status_code=400, detail="k and max_nodes must be integers")
    if k < 0 or max_nodes < 1:
        raise HTTPException(status_code=400, detail="k must be >= 0 and max_nodes must be >= 1")

    try:
        info = _get_tx_features(txid)
        prob, pred = _scale_and_predict(info["x"])
        x_scaled = scaler.transform(info["x"])
        contrib = _grad_input_importance(x_scaled)
        top_idx = np.argsort(contrib)[::-1][:10]
        top_feats = [{"feature": feature_cols[i], "score": float(contrib[i])} for i in top_idx]

        sgj, b64 = _ego_subgraph_png(txid, k=k, max_nodes=max_nodes)
        return {
            "txId": txid,
            "timeStep": info["time"],
            "prob_illicit": prob,
            "pred_label": pred,
            "true_label": info["label"],
            "threshold": best_t,
            "top_features": top_feats,
            "subgraph": sgj,
            "subgraph_png_base64": b64
        }
    except KeyError as e:
        # str(KeyError) is the repr of its argument and would wrap the message in quotes.
        detail = e.args[0] if e.args else str(e)
        raise HTTPException(status_code=404, detail=str(detail))
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
'''

INDEX_HTML = r'''<!doctype html>
<html>
<head>
  <meta charset="utf-8"/>
  <title>GraphGuard — Fraud Scoring & Explainability</title>
  <meta name="viewport" content="width=device-width, initial-scale=1"/>
  <style>
    body { font-family: ui-sans-serif, system-ui, Segoe UI, Roboto, Arial; max-width: 1000px; margin: auto; padding: 24px; }
    h1 { margin: 0 0 8px; }
    .card { border: 1px solid #e5e7eb; border-radius: 12px; padding: 16px; margin: 10px 0; box-shadow: 0 2px 12px rgba(0,0,0,0.04); }
    .row { display: flex; gap: 16px; flex-wrap: wrap; }
    .col { flex: 1 1 360px; }
    input, button { padding: 10px 12px; border-radius: 8px; border: 1px solid #d1d5db; }
    button { cursor: pointer; }
    button:hover { background: #f3f4f6; }
    pre { background: #0b1021; color: #8df; padding: 12px; border-radius: 8px; overflow: auto; }
    img { max-width: 100%; border-radius: 8px; }
    table { width: 100%; border-collapse: collapse; }
    th, td { text-align: left; padding: 6px 8px; border-bottom: 1px solid #eee; }
    .pill { display:inline-block; padding: 2px 8px; border-radius: 999px; background:#eef2ff; color:#1f2937; font-size: 12px; }
  </style>
</head>
<body>
  <h1>GraphGuard — Fraud Scoring & Explainability</h1>
  <p class="pill">Enter a <b>txId</b> from Elliptic to score and visualize its ego subgraph.</p>

  <div class="card">
    <div class="row">
      <div class="col">
        <label>txId</label><br/>
        <input id="txid" placeholder="e.g. 230425980" style="width:100%"/>
        <div style="margin-top:8px;">
          <button id="btnScore">Score</button>
          <button id="btnExplain">Explain (k=2)</button>
        </div>
      </div>
      <div class="col">
        <div><b>Top Features</b></div>
        <table id="featTable"><thead><tr><th>Feature</th><th>Score</th></tr></thead><tbody></tbody></table>
      </div>
    </div>
  </div>

  <div class="row">
    <div class="col card">
      <h3>Result</h3>
      <div id="result"></div>
    </div>
    <div class="col card">
      <h3>Ego Subgraph</h3>
      <div id="graph"></div>
    </div>
  </div>

  <script>
    // POST `data` as JSON to `url` and resolve with the parsed JSON body.
    // Rejects with an Error carrying the server's `detail` (or the raw body) on a non-2xx status.
    async function postJSON(url, data){
      const res = await fetch(url, {method:'POST', headers:{'Content-Type':'application/json'}, body: JSON.stringify(data)});
      // Error responses are not guaranteed to be JSON (e.g. a plain-text 500), so read text first.
      const raw = await res.text();
      let js = null;
      try { js = raw ? JSON.parse(raw) : null; } catch (parseErr) { js = null; }
      if(!res.ok){
        const detail = js && js.detail;
        // Validation errors carry a list/object in `detail`; stringify it instead of showing [object Object].
        const message = typeof detail === 'string' ? detail : (detail ? JSON.stringify(detail) : (raw || 'Error'));
        throw new Error(message);
      }
      if(js === null) throw new Error('Invalid JSON response');
      return js;
    }
    // Render a /score or /explain response into the Result card and the Top Features table.
    // Every server-provided value is inserted as text, never as markup.
    function setResult(js){
      // Build all rows first so a malformed response leaves the previous result untouched.
      const rows = [
        ['txId', js.txId ?? '(payload)'],
        ['timeStep', js.timeStep ?? ''],
        ['prob_illicit', js.prob_illicit.toFixed(4)],
        ['pred_label', js.pred_label + ' (' + (js.pred_label===1 ? 'illicit' : 'licit') + ')']
      ];
      // The key is absent in payload mode and null when the transaction has no label.
      if(js.true_label !== undefined){
        rows.push(['true_label', js.true_label === null ? 'unknown' : js.true_label]);
      }
      rows.push(['threshold', js.threshold.toFixed(3)]);

      const box = document.getElementById('result');
      box.innerHTML = '';
      rows.forEach(([label, value])=>{
        const p = document.createElement('p');
        const b = document.createElement('b');
        b.textContent = label + ':';
        p.appendChild(b);
        p.appendChild(document.createTextNode(' ' + value));
        box.appendChild(p);
      });

      // features
      const tb = document.querySelector('#featTable tbody');
      tb.innerHTML = '';
      (js.top_features||[]).forEach(r=>{
        const tr = document.createElement('tr');
        const nameCell = document.createElement('td');
        nameCell.textContent = r.feature;
        const scoreCell = document.createElement('td');
        scoreCell.textContent = r.score.toFixed(6);
        tr.appendChild(nameCell);
        tr.appendChild(scoreCell);
        tb.appendChild(tr);
      });
    }
    // Score the entered txId via POST /score and show the result; clears any previous subgraph.
    async function score(){
      const txid = document.getElementById('txid').value.trim();
      if(!txid){ alert('Enter txId'); return; }
      try{
        const js = await postJSON('/score', {mode:'txid', txId: txid});
        setResult(js);
        document.getElementById('graph').innerHTML = '';
      }catch(e){
        // The error text can echo the user's input, so insert it as text rather than HTML.
        const pre = document.createElement('pre');
        pre.textContent = e.toString();
        const box = document.getElementById('result');
        box.innerHTML = '';
        box.appendChild(pre);
      }
    }
    // Request score + 2-hop ego subgraph for the entered txId via POST /explain and render both.
    async function explain(){
      const txid = document.getElementById('txid').value.trim();
      if(!txid){ alert('Enter txId'); return; }
      const graphBox = document.getElementById('graph');
      // Insert server/error text as a text node, never as markup.
      const showPre = (text) => {
        const pre = document.createElement('pre');
        pre.textContent = text;
        graphBox.innerHTML = '';
        graphBox.appendChild(pre);
      };
      try{
        const js = await postJSON('/explain', {txId: txid, k:2, max_nodes:150});
        setResult(js);
        if(js.subgraph_png_base64){
          const img = document.createElement('img');
          img.src = 'data:image/png;base64,' + js.subgraph_png_base64;
          graphBox.innerHTML = '';
          graphBox.appendChild(img);
        }else{
          // No image available: fall back to the raw subgraph JSON.
          showPre(JSON.stringify(js.subgraph, null, 2));
        }
      }catch(e){
        showPre(e.toString());
      }
    }
    // Wire the two buttons to their handlers once the markup above has been parsed.
    document.getElementById('btnScore').addEventListener('click', score);
    document.getElementById('btnExplain').addEventListener('click', explain);
  </script>
</body>
</html>
'''

# Contents of requirements_api.txt, one requirement per line.
# tensorflow 2.15 itself depends on keras>=2.15.0,<2.16, so a "keras>=3.3.0" pin next to
# the exact tensorflow pin cannot be resolved by pip; the Keras range below matches it.
REQS = "\n".join([
    "fastapi==0.111.0",
    "uvicorn[standard]==0.30.1",
    "pydantic==2.8.2",
    "numpy>=1.24",
    "pandas>=2.0.0",
    "networkx>=3.2",
    "matplotlib>=3.7.0",
    "keras>=2.15.0,<2.16",
    "tensorflow==2.15.0.post1",
    "python-multipart>=0.0.9",
]) + "\n"

# Windows launcher. Every tool is invoked through "python -m" so that pip and uvicorn
# belong to the same interpreter, and the script stops as soon as a step fails
# instead of starting the server on a broken environment.
RUN_BAT = r'''@echo off
cd /d "C:\Users\sagni\Downloads\GraphGuard" || exit /b 1
python -m pip install --upgrade pip
python -m pip install -r requirements_api.txt || exit /b 1
python -m uvicorn app:app --host 0.0.0.0 --port 8000
'''

# Write each generated file into BASE, then echo the paths and the start-up steps.
_generated = (
    ("app.py", APP_PY),
    ("index.html", INDEX_HTML),
    ("requirements_api.txt", REQS),
    ("run_api.bat", RUN_BAT),
)
_targets = [os.path.join(BASE, _entry[0]) for _entry in _generated]
for _target, _entry in zip(_targets, _generated):
    write(_target, _entry[1])

print("[OK] Wrote:")
for _target in _targets:
    print(" -", _target)

print("\nStart the API:")
for _step in (
    r'  cd "C:\Users\sagni\Downloads\GraphGuard"',
    r'  pip install -r requirements_api.txt',
    r'  uvicorn app:app --host 0.0.0.0 --port 8000',
):
    print(_step)
print("Open: http://localhost:8000 (redirects to /static/index.html)")


[OK] Wrote:
 - C:\Users\sagni\Downloads\GraphGuard\app.py
 - C:\Users\sagni\Downloads\GraphGuard\index.html
 - C:\Users\sagni\Downloads\GraphGuard\requirements_api.txt
 - C:\Users\sagni\Downloads\GraphGuard\run_api.bat

Start the API:
  cd "C:\Users\sagni\Downloads\GraphGuard"
  pip install -r requirements_api.txt
  uvicorn app:app --host 0.0.0.0 --port 8000
Open: http://localhost:8000 (redirects to /static/index.html)
