<a href="https://colab.research.google.com/github/ravishelnaicker/Fresh-5/blob/main/Fresh_5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title NASim + (Optional) LLM Planner with Strict Validation — Colab Ready
# @markdown **Usage**
# @markdown 1) Choose a scenario (e.g., "tiny", "small", "medium-multi-site").
# @markdown 2) (Optional) Set USE_LLM=True and set the OPENAI_API_KEY environment variable.
# @markdown 3) Run. A step-by-step table will be printed and also saved to CSV.
# Run parameters (Colab form fields; keep each "# @param" annotation on its own line):
#   SCENARIO_NAME - NASim benchmark name, passed to load_env and get_optimal_path_hints
#   SEED          - passed to load_env and to env.reset
#   MAX_STEPS     - upper bound on iterations of the episode loop
#   USE_LLM       - when False, llm_choose uses the heuristic policy only
#   OPENAI_MODEL  - model name sent to the OpenAI chat completions call
#   CSV_OUT       - path of the saved step log; by default derived from SCENARIO_NAME
SCENARIO_NAME = "tiny"            # @param ["tiny", "small", "medium", "large", "medium-multi-site", "large-multi-site"] {allow-input: true}
SEED = 123                        # @param {type:"integer"}
MAX_STEPS = 200                   # @param {type:"integer"}
USE_LLM = False                   # @param {type:"boolean"}
OPENAI_MODEL = "gpt-4o-mini"      # @param {type:"string"}
CSV_OUT = f"nasim_run_{SCENARIO_NAME}.csv"  # @param {type:"string"}

# --- Install deps
!pip -q install nasim pandas pyyaml openai > /dev/null

import os, re, json, math, textwrap, random
import pandas as pd
import numpy as np
import nasim

# OpenAI client (loaded lazily if USE_LLM)
try:
    from openai import OpenAI
except Exception:
    OpenAI = None

# ------------------------------
# Helper: load env
# ------------------------------
def load_env(name: str, seed=None, render=False, fully_obs=False):
    """Create the NASim benchmark ``name`` with flat action and observation spaces.

    ``render=True`` requests ``render_mode="human"``; otherwise no render mode
    is set. ``seed`` and ``fully_obs`` are forwarded to ``nasim.make_benchmark``.
    """
    # Flat actions/observations are what the rest of this notebook indexes into.
    benchmark_kwargs = {
        "seed": seed,
        "fully_obs": fully_obs,
        "flat_actions": True,
        "flat_obs": True,
        "render_mode": "human" if render else None,
    }
    return nasim.make_benchmark(name, **benchmark_kwargs)

# ------------------------------
# Read "optimal path" SOFT hints
# ------------------------------
def _optimal_path_text(txt: str) -> str:
    """Return the 'Optimal path' comment of a scenario YAML, joined into one string.

    The path may sit on the marker line itself, e.g.
        # Optimal path: (e_ssh, (6,1)) -> subnet_scan -> (pe_tomcat, (2,1))
    or continue on the following comment lines, e.g.
        # Optimal path:
        # (e_ssh, (1, 0)) -> subnet_scan -> ...
        #     -> (pe_tomcat, (2, 0))
    so the marker line plus every directly following non-empty comment line
    is collected. Collection stops at a bare "#" or a non-comment line.
    Returns "" when no marker is present.
    """
    lines = txt.splitlines()
    for i, line in enumerate(lines):
        if "Optimal path" not in line:
            continue
        collected = [line]
        for follower in lines[i + 1:]:
            stripped = follower.strip()
            if not stripped.startswith("#") or not stripped.lstrip("#").strip():
                break
            collected.append(stripped)
        return " ".join(collected)
    return ""


def get_optimal_path_hints(scenario_name: str, path=None):
    """
    Read the scenario YAML header and extract the 'Optimal path' comment.

    Returns {'services': set(...), 'processes': set(...)}, used for soft
    re-ranking only:
    - an exploit token such as ``e_ssh`` contributes the service ``ssh``;
    - a privesc token such as ``pe_tomcat`` contributes the process ``tomcat``.

    ``path`` optionally points at the YAML file directly. By default it is the
    bundled ``<nasim package>/scenarios/benchmark/<scenario_name>.yaml``.

    Hints are optional. A missing or unreadable file, or a file with no
    path comment, yields empty sets rather than an error.
    """
    hints = {"services": set(), "processes": set()}
    if path is None:
        import nasim  # local so the helper works without nasim when `path` is given
        base = os.path.dirname(nasim.__file__)
        path = os.path.join(base, "scenarios", "benchmark", f"{scenario_name}.yaml")
    try:
        with open(path, "r", encoding="utf-8") as f:
            txt = f.read()
    except (OSError, UnicodeDecodeError):
        return hints

    path_text = _optimal_path_text(txt)
    # \b keeps "e_" from matching inside "pe_..." tokens.
    for token in re.findall(r"\b(e_[A-Za-z0-9_]+)\b", path_text):
        hints["services"].add(token.split("e_", 1)[1])
    for token in re.findall(r"\b(pe_[A-Za-z0-9_]+)\b", path_text):
        hints["processes"].add(token.split("pe_", 1)[1])
    return hints

# ------------------------------
# Scenario + action/state schema
# ------------------------------
def scenario_summary(env):
    """Collect headline facts about ``env.scenario`` into a plain dict.

    Keys: name, num_hosts, subnets, os_list, services_list, processes_list,
    exploits, privescs, sensitive_hosts, step_limit. The OS/service/process
    collections are copied into lists; the exploit/privesc tables,
    ``subnets`` and ``sensitive_hosts`` are returned as the scenario holds them.
    """
    scenario = env.scenario
    summary = dict(
        name=scenario.name,
        num_hosts=len(scenario.hosts),
        subnets=scenario.subnets,
        os_list=list(scenario.os),
        services_list=list(scenario.services),
        processes_list=list(scenario.processes),
        exploits=scenario.exploits,
        privescs=scenario.privescs,
        sensitive_hosts=scenario.sensitive_hosts,
        step_limit=scenario.step_limit,
    )
    return summary

# ------------------------------
# Action decoding / valid mask
# ------------------------------
# (predicate method names, label), checked in this order. NASim's Action class
# names its privilege-escalation predicate is_privilege_escalation(); is_privesc
# is also accepted in case an action type provides that name instead.
_ACTION_KIND_CHECKS = (
    (("is_exploit",), "exploit"),
    (("is_privilege_escalation", "is_privesc"), "privesc"),
    (("is_service_scan",), "service_scan"),
    (("is_os_scan",), "os_scan"),
    (("is_subnet_scan",), "subnet_scan"),
    (("is_process_scan",), "process_scan"),
    (("is_noop",), "noop"),
)


def _action_kind(action):
    """Return the label of the first is_* predicate that ``action`` has and that returns True."""
    for method_names, label in _ACTION_KIND_CHECKS:
        for method_name in method_names:
            check = getattr(action, method_name, None)
            if callable(check) and check():
                return label
    return "unknown"


def action_to_dict(env, idx: int):
    """Turn the NASim action at flat index ``idx`` into a readable dict.

    Keys: index, type, target, service, process, os, cost, prob, req_access.
    - ``type`` is one of exploit / privesc / service_scan / os_scan /
      subnet_scan / process_scan / noop / unknown.
    - Attributes the action does not define are reported as None.
    - ``req_access`` is always a string: ``str()`` of the attribute, or "" if absent.
    """
    action = env.action_space.actions[idx]
    return {
        "index": idx,
        "type": _action_kind(action),
        "target": getattr(action, "target", None),
        "service": getattr(action, "service", None),
        "process": getattr(action, "process", None),
        "os": getattr(action, "os", None),
        "cost": getattr(action, "cost", None),
        "prob": getattr(action, "prob", None),
        "req_access": str(getattr(action, "req_access", "")),  # usually an AccessLevel enum
    }

def valid_actions(env):
    """Describe every action the env's action mask currently marks as valid.

    Entries of ``env.get_action_mask()`` equal to 1 are kept. Each kept index
    is turned into a dict by ``action_to_dict``, and the list preserves index order.
    """
    described = []
    for action_idx, flag in enumerate(env.get_action_mask()):
        if int(flag) != 1:
            continue
        described.append(action_to_dict(env, action_idx))
    return described

# ------------------------------
# Belief state + logging helpers
# ------------------------------
def access_to_str(x):
    """Map an access level to one of the labels 'root', 'user' or 'none'.

    An enum member is matched on its ``.name``. This matters because on
    Python 3.11+ ``str()`` of an ``IntEnum`` member is only its integer ("1"),
    which contains neither "user" nor "root". Any other value is matched on
    ``str(x)``, case-insensitively. Unrecognised values map to 'none'.
    """
    name = getattr(x, "name", None)
    try:
        label = (name if isinstance(name, str) else str(x)).lower()
    except Exception:  # str() of an exotic object failed: treat as no access
        return "none"
    if "root" in label:
        return "root"
    if "user" in label:
        return "user"
    return "none"

def update_belief_with_info(belief, info, target=None):
    """
    Merge one NASim ``info`` dict into ``belief`` in place.

    belief: dict[address] -> {access, os_known, services_known:set,
            processes_known:set, discovered:bool}
    info:   info dict from env.reset / env.step.
    target: optional address of the host the action was aimed at.

    Two shapes are accepted for os/services/processes/access:
      * address-keyed: ``{address: {name: bool}}`` and ``{address: level}``;
      * flat, describing the targeted host: ``{name: bool}`` and a single
        access level. NASim's step info appears to use this shape
        (NOTE(review): confirm against ``ActionResult.info()`` for your
        NASim version). Flat findings are attributed to ``target`` and are
        ignored when ``target`` is None.

    ``newly_discovered`` may be a list of addresses, or an ``{address: bool}``
    dict in which only True entries count.

    Access never moves down: a later 'user' report does not erase a known
    'root'.
    """
    rank = {"none": 0, "user": 1, "root": 2}
    tgt = tuple(target) if target is not None else None

    def entry(addr):
        # Fresh default record for hosts we have not seen before.
        return belief.setdefault(
            tuple(addr),
            {"access": "none", "os_known": None, "services_known": set(),
             "processes_known": set(), "discovered": True},
        )

    def raise_access(addr, level):
        record = entry(addr)
        new_label = access_to_str(level)
        if rank[new_label] >= rank.get(record.get("access"), 0):
            record["access"] = new_label

    # Newly discovered hosts (list of addresses or {address: bool}).
    newly = info.get("newly_discovered") or []
    if isinstance(newly, dict):
        newly = [addr for addr, flag in newly.items() if flag]
    for addr in newly:
        entry(addr)["discovered"] = True

    # os / services / processes, either address-keyed or flat for the target.
    for key, slot in (("os", "os_known"), ("services", "services_known"),
                      ("processes", "processes_known")):
        found = info.get(key) or {}
        if not isinstance(found, dict) or not found:
            continue
        if all(isinstance(v, dict) for v in found.values()):
            per_host = list(found.items())
        elif tgt is not None:
            per_host = [(tgt, found)]
        else:
            continue  # flat findings but no host to attribute them to
        for addr, vals in per_host:
            record = entry(addr)
            for name, present in vals.items():
                if not present:
                    continue
                if key == "os":
                    record["os_known"] = name
                else:
                    record[slot].add(name)

    # Access: {address: level} or a single level for the target.
    acc = info.get("access")
    if isinstance(acc, dict):
        for addr, level in acc.items():
            raise_access(addr, level)
    elif acc is not None and tgt is not None:
        raise_access(tgt, acc)

def compromised_hosts_from_belief(belief):
    """Return, sorted, the addresses whose believed access is 'user' or 'root'."""
    owned = [addr for addr in belief if belief[addr].get("access") in ("user", "root")]
    owned.sort()
    return owned

# ------------------------------
# Heuristic policy (deterministic, no-LLM)
# ------------------------------
def heuristic_choose(valids, belief, hints=None):
    """
    Deterministically pick one action index from ``valids`` by priority score.

    Scores are judged against the current ``belief``. From highest to lowest:
    1) exploits whose service/OS match what we know (higher prob and lower
       cost preferred; small bonus for 'optimal path' hint services);
    2) service_scan / os_scan on hosts whose services / OS are still unknown;
    3) subnet_scan from hosts we already have access to (to expand reach);
    4) privesc on hosts where we hold only user access (bonus when the
       process is known to run there or is hinted);
    5) process_scan on compromised hosts whose processes are unknown.

    Actions that cannot teach or gain anything new score lowest: re-scanning
    known facts, and exploiting or privesc-ing a root host. Ties keep the
    order of ``valids``.

    Raises ValueError if ``valids`` is empty.
    """
    if not valids:
        raise ValueError("heuristic_choose: no valid actions to choose from")
    hints = hints or {"services": set(), "processes": set()}
    hint_services = hints.get("services", set())
    hint_processes = hints.get("processes", set())

    def score(a):
        t = a["type"]
        tgt = tuple(a["target"]) if a.get("target") is not None else None
        st = belief.get(tgt, {})
        os_known = st.get("os_known")
        services_known = st.get("services_known", set())
        processes_known = st.get("processes_known", set())
        acc = st.get("access", "none")
        has_access = acc in ("user", "root")

        if t == "service_scan":
            # Only informative while the host's services are unknown.
            return 15.0 if not services_known else 1.0
        if t == "os_scan":
            return 13.0 if not os_known else 1.0
        if t == "subnet_scan":
            # Requires on-host access; great for expansion.
            return 8.5 if has_access else 0.0
        if t == "exploit":
            if acc == "root":
                return -1.0  # nothing more to gain on this host
            s = 0.0
            if a.get("service") and a["service"] in services_known:
                s += 8.0
            if a.get("os") and os_known and a["os"] == os_known:
                s += 8.0
            if a.get("service") in hint_services:
                s += 2.0
            prob = a.get("prob") or 0.5
            cost = a.get("cost") or 1.0
            s += float(prob) * 2.0 - float(cost) * 0.2
            if acc == "user":
                s -= 12.0  # foothold already held; privesc is the usual next step
            return s
        if t == "process_scan":
            if not has_access:
                return 0.0
            return 6.0 if not processes_known else 1.0
        if t == "privesc":
            if acc != "user":
                return -1.0 if acc == "root" else 0.0
            s = 5.5
            if a.get("process") in hint_processes:
                s += 2.0
            if a.get("process") and a["process"] in processes_known:
                s += 3.0
            return s
        return 0.0

    # max() returns the first maximal element, i.e. ties keep input order.
    return max(valids, key=score)["index"]

# ------------------------------
# LLM planner (strict, masked)
# ------------------------------
def llm_choose(valids, belief, hints=None, tried_signature_block=None):
    """
    Pick the next action index: via the LLM when enabled, otherwise heuristically.

    Filtering applies to BOTH the LLM and the heuristic path:
    - actions whose (type, target) signature is in ``tried_signature_block``
      are dropped first, so neither policy keeps re-issuing an action that
      produced no new compromise;
    - if that would drop everything, all of ``valids`` is allowed again.

    The model sees only the remaining, already mask-valid actions and must
    answer strict JSON {"action_index": int, "reason": str}. Any failure
    falls back to ``heuristic_choose`` on the same filtered list and prints
    a one-line warning. Failures include: no client library, a missing API
    key, an API error, an unparsable reply, or an index outside the list.
    """
    blocked = tried_signature_block or set()
    candidates = [
        a for a in valids
        if (a["type"], tuple(a["target"]) if a.get("target") else None) not in blocked
    ]
    if not candidates:
        candidates = list(valids)  # everything blocked: allow all rather than stall

    if not USE_LLM or OpenAI is None:
        return heuristic_choose(candidates, belief, hints)

    # Cap list length for prompt economy
    MAX_LIST = 120
    sample = candidates[:MAX_LIST]

    compromised = compromised_hosts_from_belief(belief)
    brief_state = {
        "compromised_hosts": compromised,
        "known": {
            str(addr): {
                "access": belief[addr].get("access"),
                "os": belief[addr].get("os_known"),
                "services": sorted(belief[addr].get("services_known", set())),
                "processes": sorted(belief[addr].get("processes_known", set())),
            } for addr in belief
        },
    }

    system_prompt = (
        "You are a cautious NASim pentest planner. "
        "Choose exactly ONE next action from the allowed list. "
        "Respect constraints: no repeat on same host unless a new compromise happened; "
        "prefer OS-appropriate exploits; process scans/priv-esc only after compromise; "
        "avoid obviously pointless actions. Output strict JSON: {\"action_index\": int, \"reason\": \"...\"}."
    )
    user = {
        "goal": "Maximise points in minimum steps by compromising sensitive hosts.",
        # Hint values are sets, which json.dumps rejects; send sorted lists.
        "priority_hints": {k: sorted(v) for k, v in (hints or {}).items()},
        "current_state": brief_state,
        "allowed_actions": sample,   # pre-masked, all executable
    }

    try:
        client = OpenAI()  # expects OPENAI_API_KEY in env; raises if it is missing
        resp = client.chat.completions.create(
            model=OPENAI_MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                # default=str: stringify any non-JSON scalar (e.g. an enum) in the action dicts
                {"role": "user", "content": json.dumps(user, default=str)},
            ],
            response_format={"type": "json_object"},
        )
        data = json.loads(resp.choices[0].message.content.strip())
        idx = int(data.get("action_index"))
    except Exception as err:  # best-effort: any LLM problem falls back to the heuristic
        print(f"[llm_choose] LLM unavailable or reply invalid ({type(err).__name__}: {err}); using heuristic.")
    else:
        # Safety: the chosen index must be one we actually offered.
        if idx in {a["index"] for a in sample}:
            return idx
        print(f"[llm_choose] LLM picked index {idx}, which was not offered; using heuristic.")

    return heuristic_choose(candidates, belief, hints)

# ------------------------------
# Main episode runner
# ------------------------------
def _address_keyed_info(raw_info, target):
    """
    Return a copy of a NASim ``info`` dict with every finding keyed by host address.

    ``update_belief_with_info`` reads:
    - os/services/processes as ``{address: {name: bool}}``;
    - access as ``{address: level}``;
    - ``newly_discovered`` as a list of addresses.

    NASim's step info appears to differ (NOTE(review): shape per nasim
    ``ActionResult.info()`` — confirm for your NASim version):
    - the targeted host's findings come as flat ``{name: bool}`` dicts plus a
      single access level;
    - ``newly_discovered`` comes as ``{address: bool}`` over every host.

    Flat findings are wrapped under ``target``, or dropped when ``target`` is
    None. Already address-keyed mappings pass through unchanged. The keys
    newly_discovered, os, services, processes and access are always present.
    """
    info = dict(raw_info) if isinstance(raw_info, dict) else {}
    tgt = tuple(target) if target is not None else None

    newly = info.get("newly_discovered") or []
    if isinstance(newly, dict):
        newly = [addr for addr, flag in newly.items() if flag]
    info["newly_discovered"] = [tuple(addr) for addr in newly]

    for key in ("os", "services", "processes"):
        found = info.get(key) or {}
        if not isinstance(found, dict):
            found = {}
        elif found and not all(isinstance(v, dict) for v in found.values()):
            found = {tgt: found} if tgt is not None else {}
        info[key] = found

    access = info.get("access")
    if not isinstance(access, dict):
        access = {tgt: access} if (access is not None and tgt is not None) else {}
    info["access"] = access
    return info


env = load_env(SCENARIO_NAME, seed=SEED)
obs, info = env.reset(seed=SEED)

summ = scenario_summary(env)
hints = get_optimal_path_hints(SCENARIO_NAME)

print(f"Loaded NASim benchmark: {summ['name']}  |  Hosts: {summ['num_hosts']}  |  Step limit: {summ['step_limit']}")
print(f"OS: {summ['os_list']}\nServices: {summ['services_list']}\nProcesses: {summ['processes_list']}")
print(f"Exploits: {list(summ['exploits'].keys())}\nPrivEsc: {list(summ['privescs'].keys())}")
if any(hints.values()):
    print(f"Soft hints from scenario header (not hard-coded): {hints}")

# Belief tracking + logging
belief = {}
total = 0.0
rows = []
tried_since_last_compromise = set()  # {(action_type, target)} to avoid repeats until new compromise

# initial discovery from reset info (if any); no action target at reset
if isinstance(info, dict):
    update_belief_with_info(belief, _address_keyed_info(info, None))

done = False
truncated = False
for step in range(1, MAX_STEPS + 1):
    # enumerate valid actions
    valids = valid_actions(env)
    if not valids:
        print("No valid actions — stopping.")
        break

    # choose via LLM (strict) or heuristic
    a_idx = llm_choose(
        valids, belief, hints=hints,
        tried_signature_block=tried_since_last_compromise
    )

    a_desc = next((v for v in valids if v["index"] == a_idx), None)
    target = tuple(a_desc["target"]) if a_desc and a_desc.get("target") else None
    sig = (a_desc["type"] if a_desc else None, target)

    # step
    next_obs, reward, done_flag, step_limit_reached, step_info = env.step(a_idx)
    total += float(reward)

    # attribute per-target findings to the acted-on host, then update belief
    keyed_info = _address_keyed_info(step_info, target)
    prev_compromised = set(compromised_hosts_from_belief(belief))
    update_belief_with_info(belief, keyed_info)
    now_compromised = set(compromised_hosts_from_belief(belief))

    # If any NEW host got compromised, clear the anti-repeat block
    if now_compromised - prev_compromised:
        tried_since_last_compromise.clear()
    else:
        tried_since_last_compromise.add(sig)

    # compact "what's newly discovered" for the log; JSON keys must be strings
    newly = {"new_hosts": [list(addr) for addr in keyed_info["newly_discovered"]]}
    for key in ("os", "services", "processes"):
        newly[key] = {
            str(addr): {str(name): bool(present) for name, present in found.items()}
            for addr, found in keyed_info[key].items()
        }
    newly["access"] = {str(addr): access_to_str(level) for addr, level in keyed_info["access"].items()}

    rows.append({
        "step": step,
        "action_index": a_idx,
        "action_type": a_desc["type"] if a_desc else None,
        "target": a_desc.get("target") if a_desc else None,
        "service": a_desc.get("service") if a_desc else None,
        "process": a_desc.get("process") if a_desc else None,
        "os_required": a_desc.get("os") if a_desc else None,
        "reward": reward,
        "total_points": total,
        "newly_discovered": json.dumps(newly),
        "compromised_hosts_now": json.dumps([list(x) for x in sorted(now_compromised)]),
        "truncated_or_step_limit": bool(step_limit_reached),
    })

    obs = next_obs
    if done_flag or step_limit_reached:
        done = bool(done_flag)
        truncated = bool(step_limit_reached)
        break

# Results table
df = pd.DataFrame(rows)
pd.set_option("display.max_colwidth", 200)
print("\n=== Run Summary ===")
print(f"Steps taken: {len(df)}  |  Done: {done}  |  Step-limit/truncated: {truncated}  |  Total points: {df['total_points'].iloc[-1] if len(df) else 0.0}")
try:
    display(df)  # IPython/Colab rich table
except NameError:  # running outside IPython: no display() builtin
    print(df.to_string())

# Save CSV
df.to_csv(CSV_OUT, index=False)
print(f"\nSaved step log to: {CSV_OUT}")
