In [None]:
#!/usr/bin/env python3
"""
A single-file Python script demonstrating a gridworld-based IRL experiment,
with all outputs saved to a 'results/' folder.

Sections (aligned with paper flow):
1) Environment & Expert Data Generation
2) IRL Algorithms (NFXP, MaxEnt)
3) Route Recommendation via Standard (Hard) Value Iteration
4) Metrics & Evaluation
5) Main Execution

Usage:
  python main.py
"""

import os
import numpy as np
import random
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import time
from collections import defaultdict
from scipy.optimize import minimize
from scipy.special import logsumexp
from scipy.stats import spearmanr
from typing import List, Tuple

# Attempt to import tabulate for nicer table printing in final output
try:
    from tabulate import tabulate
    USE_TABULATE = True
except ImportError:
    USE_TABULATE = False

###############################################################################
# Make sure a 'results/' folder exists to store outputs
###############################################################################
# Create the output folder up front so the later to_csv / savefig calls that
# target 'results/...' have somewhere to write; exist_ok keeps re-runs harmless.
os.makedirs("results", mode=0o777, exist_ok=True)
print("[main] Ensuring 'results/' directory exists.")

###############################################################################
# (1) ENVIRONMENT & EXPERT DATA GENERATION
###############################################################################
class GridMDP:
    """
    Square grid-world MDP with deterministic moves and per-cell rewards.

    Every cell's reward is drawn uniformly from [0, 1); afterwards the start
    cell (0, 0) is pinned to 0.05 and the goal cell (size-1, size-1) to 1.0.
    Callers may overwrite `true_rewards` later (e.g. to raise the goal value).
    These TRUE rewards are what the 'expert' data is generated from.

    Note: construction reseeds both NumPy's global RNG and Python's `random`.
    """

    def __init__(self, size=5, discount_factor=0.95, seed=42):
        print(f"[GridMDP] Initializing with size={size}, discount_factor={discount_factor}, seed={seed}")
        np.random.seed(seed)
        random.seed(seed)
        self.actions = ['up', 'down', 'left', 'right']
        self.size = size
        self.discount_factor = discount_factor
        self.true_rewards = np.zeros((size, size), dtype=float)
        self._init_rewards()

    def _init_rewards(self):
        """
        Fill `true_rewards` with uniform [0, 1) draws (row-major order), then
        pin the start and goal cells.
        """
        print("[GridMDP] Randomizing initial rewards in [0,1].")
        # np.ndindex walks (row, col) in C order, i.e. one draw per cell row by row.
        for cell in np.ndindex(self.size, self.size):
            self.true_rewards[cell] = np.random.rand()
        print("[GridMDP] Setting start=(0,0) reward to 0.05 and goal=(size-1,size-1) reward to 1.0.")
        last = self.size - 1
        self.true_rewards[0, 0] = 0.05
        self.true_rewards[last, last] = 1.0

    def step(self, state: Tuple[int, int], action: str) -> Tuple[int, int]:
        """
        Deterministic move. A move that would leave the grid, or an unknown
        action name, leaves the agent where it is.
        """
        row, col = state
        # action -> (move allowed?, resulting cell)
        moves = {
            'up': (row > 0, (row - 1, col)),
            'down': (row < self.size - 1, (row + 1, col)),
            'left': (col > 0, (row, col - 1)),
            'right': (col < self.size - 1, (row, col + 1)),
        }
        allowed, target = moves.get(action, (False, None))
        return target if allowed else (row, col)

    def get_reward(self, state: Tuple[int, int]) -> float:
        """True reward stored for `state`."""
        return self.true_rewards[state]

    def random_start_state(self) -> Tuple[int, int]:
        """Uniformly random cell, drawn with Python's `random` (row first, then column)."""
        return (random.randint(0, self.size - 1), random.randint(0, self.size - 1))

    def shape(self) -> Tuple[int, int]:
        """Grid dimensions as (rows, cols)."""
        return (self.size,) * 2

    def valid_states(self) -> List[Tuple[int,int]]:
        """Every cell of the grid as (row, col), in row-major order."""
        return [divmod(flat, self.size) for flat in range(self.size * self.size)]


def soft_value_iteration(mdp: GridMDP,
                         rewards_2d: np.ndarray,
                         discount_factor: float,
                         max_iter=200,
                         tol=1e-7) -> np.ndarray:
    """
    Soft (log-sum-exp) value iteration used to generate 'soft-optimal' data.

    Starting from V = 0, repeats the synchronous backup
      V(s) = log( sum_a exp( R(s') + gamma * V(s') ) ),   s' = mdp.step(s, a)
    until the largest per-state change in a sweep drops below `tol`, or
    `max_iter` sweeps have run.

    Args:
        mdp: provides `valid_states()`, `actions`, `step(s, a)` and `size`.
        rewards_2d: reward grid indexed by (row, col); the reward is collected
            on the *next* state s'.
        discount_factor: gamma.
        max_iter: maximum number of sweeps.
        tol: convergence threshold on the max absolute change per sweep.

    Returns:
        (size, size) float array of soft values. Cells not listed by
        `valid_states()` stay 0.

    Note:
        The error shrinks by roughly a factor gamma per sweep. With gamma near
        1 and large rewards, `max_iter` can be hit before `tol`. In that case a
        warning line is printed instead of silently returning an unconverged
        value function.
    """
    print(f"[soft_value_iteration] Starting with max_iter={max_iter}, tol={tol}.")
    V = np.zeros((mdp.size, mdp.size), dtype=float)
    states = mdp.valid_states()
    if not states:
        return V

    # The transition structure never changes between sweeps: precompute the
    # successor cell of every (state, action) pair and the reward it yields.
    s_rows = np.array([s[0] for s in states])
    s_cols = np.array([s[1] for s in states])
    succ = np.array([[mdp.step(s, a) for a in mdp.actions] for s in states])  # (S, A, 2)
    n_rows, n_cols = succ[..., 0], succ[..., 1]
    r_next = np.asarray(rewards_2d, dtype=float)[n_rows, n_cols]  # (S, A)

    delta = float("inf")
    for it in range(max_iter):
        # Synchronous (Jacobi) backup: every state reads last sweep's V.
        q = r_next + discount_factor * V[n_rows, n_cols]
        new_vals = logsumexp(q, axis=1)
        delta = float(np.max(np.abs(new_vals - V[s_rows, s_cols])))
        V[s_rows, s_cols] = new_vals
        if delta < tol:
            print(f"[soft_value_iteration] Converged after {it+1} iterations (delta={delta:.3g}).")
            break
    else:
        print(f"[soft_value_iteration] Warning: reached max_iter={max_iter} without converging (delta={delta:.3g}).")
    return V


def soft_policy(mdp: GridMDP, V_2d: np.ndarray, rewards_2d: np.ndarray) -> dict:
    """
    Convert a soft value function into a stochastic (Boltzmann) policy.

    For each state s, action a moves to s' = mdp.step(s, a) and gets weight
      exp( R(s') + gamma * V(s') ),   gamma = mdp.discount_factor,
    normalised over all actions.

    Returns:
        {state: {action_name: probability}}, actions in `mdp.actions` order.
    """
    all_states = mdp.valid_states()
    print("[soft_policy] Building soft policy from value function.")
    gamma = mdp.discount_factor
    policy = {}
    for state in all_states:
        successors = [mdp.step(state, act) for act in mdp.actions]
        scores = [rewards_2d[nxt] + gamma * V_2d[nxt] for nxt in successors]
        # Subtracting log-sum-exp normalises in log space before exponentiating.
        probabilities = np.exp(scores - logsumexp(scores))
        policy[state] = {act: p for act, p in zip(mdp.actions, probabilities)}
    return policy


def generate_trajectories(mdp: GridMDP,
                          N=20,
                          max_steps=15,
                          policy_type="soft",
                          seed=42) -> List[List[Tuple[Tuple[int,int], str, float]]]:
    """
    Sample 'expert' trajectories on the TRUE rewards.

    The expert follows the soft-optimal policy (policy_type="soft", or any
    value other than "optimal"), or the greedy argmax of the same soft
    Q-values (policy_type="optimal"). Each trajectory starts from
    `mdp.random_start_state()` and runs exactly `max_steps` steps.

    Both NumPy's global RNG and Python's `random` are reseeded with `seed`.

    Returns:
        N lists of (state, action, reward_of_next_state) tuples.
    """
    print(f"[generate_trajectories] N={N}, max_steps={max_steps}, policy_type={policy_type}, seed={seed}.")
    np.random.seed(seed)
    random.seed(seed)

    gamma = mdp.discount_factor
    R_true = mdp.true_rewards
    V_true = soft_value_iteration(mdp, R_true, gamma)
    pi_soft = soft_policy(mdp, V_true, R_true)

    def _greedy_row(state):
        # One-hot distribution on the first action with the highest soft Q-value.
        q_vals = [R_true[nxt] + gamma * V_true[nxt]
                  for nxt in (mdp.step(state, act) for act in mdp.actions)]
        winner = mdp.actions[np.argmax(q_vals)]
        return {act: (1.0 if act == winner else 0.0) for act in mdp.actions}

    if policy_type == "optimal":
        pi_dict = {state: _greedy_row(state) for state in mdp.valid_states()}
    else:
        pi_dict = pi_soft

    trajectories = []
    for _ in range(N):
        episode = []
        state = mdp.random_start_state()
        for _t in range(max_steps):
            dist = pi_dict[state]
            action = np.random.choice(list(dist.keys()), p=list(dist.values()))
            nxt = mdp.step(state, action)
            episode.append((state, action, mdp.get_reward(nxt)))
            state = nxt
        trajectories.append(episode)

    print(f"[generate_trajectories] Generated {N} trajectories (policy_type={policy_type}).")
    return trajectories


def save_trajectories_csv(trajectories, filename="results/expert_data.csv"):
    """
    Flatten trajectories to one CSV row per step.

    Columns: trajectory_id, step, row, col, action, reward (no index column).
    """
    records = [
        [traj_id, step_idx, state[0], state[1], action, reward]
        for traj_id, traj in enumerate(trajectories)
        for step_idx, (state, action, reward) in enumerate(traj)
    ]
    frame = pd.DataFrame(records, columns=["trajectory_id","step","row","col","action","reward"])
    frame.to_csv(filename, index=False)
    print(f"[save_trajectories_csv] Saved {len(trajectories)} trajectories to {filename}.")


###############################################################################
# (2) IRL ALGORITHMS (NFXP, MAXENT)
###############################################################################
def reshape_1d_to_2d(vec: np.ndarray, size: int) -> np.ndarray:
    """View a flat parameter vector of length size*size as a (size, size) grid (row-major)."""
    grid_shape = (size, size)
    return np.reshape(vec, grid_shape)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# NFXP
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def nfxp_log_likelihood(param_vec: np.ndarray,
                        mdp: GridMDP,
                        trajectories) -> float:
    """
    Negative log-likelihood of the expert data under candidate rewards (NFXP).

      1) Inner loop: soft value iteration on the candidate reward grid.
      2) Induce the soft policy  log pi(a|s) = Q(s,a) - logsumexp_a' Q(s,a'),
         with Q(s,a) = R(s') + gamma * V(s') and s' = mdp.step(s, a).
      3) Sum log pi(a_t|s_t) over every (state, action) pair in the data.

    Args:
        param_vec: flat reward vector of length size*size (row-major).
        mdp: environment supplying dynamics, actions and the discount factor.
        trajectories: lists of (state, action, reward) tuples; reward is ignored.

    Returns:
        -sum log pi(a|s) as a Python float (0.0 for empty data).
    """
    size = mdp.size
    gamma = mdp.discount_factor
    param_2d = reshape_1d_to_2d(param_vec, size)
    V_2d = soft_value_iteration(mdp, param_2d, gamma)

    # Keep log-probabilities directly (a log-softmax). Going through
    # exp() and back through log(p + eps) loses precision for unlikely
    # actions and adds an eps bias.
    log_pi = {}
    for s in mdp.valid_states():
        q = np.array([param_2d[s_next] + gamma * V_2d[s_next]
                      for s_next in (mdp.step(s, a) for a in mdp.actions)])
        log_pi[s] = q - logsumexp(q)

    action_index = {a: i for i, a in enumerate(mdp.actions)}
    total_ll = sum(log_pi[s][action_index[a]]
                   for traj in trajectories
                   for (s, a, _) in traj)
    return -float(total_ll)


def estimate_nfxp(mdp: GridMDP,
                  trajectories,
                  max_iter=300) -> np.ndarray:
    """
    Estimate per-cell rewards by minimizing `nfxp_log_likelihood` with L-BFGS-B.

    Each cell's reward is box-bounded to [0, 1]. No analytic gradient is
    supplied, so SciPy approximates it numerically. The starting point is
    drawn from NumPy's global RNG.

    Note: adding a constant to every reward leaves the soft policy unchanged,
    so the likelihood only determines rewards up to a shift. Where the
    estimate lands inside [0, 1] depends on the bounds and the random start.

    Args:
        mdp: environment whose `size` fixes the parameter dimension.
        trajectories: expert (state, action, reward) sequences.
        max_iter: L-BFGS-B iteration cap.

    Returns:
        (size, size) array of estimated rewards. If the optimizer reports
        failure (e.g. the iteration limit was hit), a warning is printed and
        the last iterate is still returned.
    """
    print("[estimate_nfxp] Estimating reward via nested fixed-point approach.")
    size = mdp.size
    param_init = np.random.uniform(0.0, 1.0, size*size)
    bounds = [(0,1)] * (size*size)

    t0 = time.time()
    result = minimize(
        nfxp_log_likelihood,
        param_init,
        args=(mdp, trajectories),
        method='L-BFGS-B',
        bounds=bounds,
        options={'maxiter': max_iter, 'disp': False}
    )
    elapsed = time.time() - t0
    print(f"[estimate_nfxp] L-BFGS took {elapsed:.2f}s for NFXP estimation.")
    if not result.success:
        # minimize() does not raise on non-convergence; surface it so an
        # unconverged estimate is not mistaken for an optimum.
        print(f"[estimate_nfxp] Warning: optimizer did not converge ({result.message}).")

    return reshape_1d_to_2d(result.x, size)


def save_rewards_csv(reward_2d: np.ndarray, filename="results/nfxp_est_rewards.csv"):
    """
    Write a reward grid to CSV: one line per grid row, column indices as the
    header, no index column.
    """
    table = pd.DataFrame(reward_2d)
    table.to_csv(filename, index=False)
    print(f"[save_rewards_csv] NFXP estimated rewards saved to {filename}.")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# MAXENT
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def collect_state_visitation(mdp: GridMDP,
                             pi_dict: dict,
                             discount_factor: float,
                             T=15,
                             n_roll=1000) -> dict:
    """
    Monte-Carlo estimate of discounted state-visitation frequencies under a policy.

    Runs `n_roll` rollouts of length `T`, each starting from
    `mdp.random_start_state()`. At step t, the current state receives weight
    gamma^t. Actions are sampled from `pi_dict` with NumPy's global RNG.

    Args:
        mdp: environment supplying start states, actions and dynamics.
        pi_dict: {state: probability array aligned with mdp.actions}.
        discount_factor: gamma used to weight visits.
        T: rollout horizon (number of recorded steps).
        n_roll: number of rollouts to average over (previously fixed at 1000).

    Returns:
        defaultdict(float) mapping state -> average discounted visit count;
        states never visited are absent.

    Raises:
        ValueError: if `n_roll` < 1.
    """
    if n_roll < 1:
        raise ValueError(f"n_roll must be >= 1, got {n_roll}")
    # gamma^t does not depend on the rollout; compute it once.
    weights = [discount_factor ** t for t in range(T)]
    freq = defaultdict(float)
    for _ in range(n_roll):
        s = mdp.random_start_state()
        for w in weights:
            freq[s] += w
            pvals = pi_dict[s]
            a_idx = np.random.choice(len(pvals), p=pvals)
            s = mdp.step(s, mdp.actions[a_idx])
    for s in freq:
        freq[s] /= n_roll
    return freq


def feature_vector(mdp: GridMDP, s: Tuple[int,int]) -> np.ndarray:
    """
    One-hot encoding of cell `s` over the flattened grid (length size*size,
    row-major index row*size + col).
    """
    n = mdp.size
    row, col = s[0], s[1]
    onehot = np.zeros(n * n)
    onehot[row * n + col] = 1.0
    return onehot


def compute_expert_feature_expectation(mdp: GridMDP,
                                       trajectories,
                                       discount_factor) -> np.ndarray:
    """
    Discounted expert feature expectations with one-hot cell features:
      mu_D = (1/N) * sum_{trajectories} sum_t gamma^t * one_hot(s_t)

    Args:
        mdp: provides `size`; cell (r, c) maps to flat index r*size + c.
        trajectories: non-empty list of (state, action, reward) sequences.
        discount_factor: gamma.

    Returns:
        Flat array of length size*size.

    Raises:
        ValueError: if `trajectories` is empty (the average is undefined;
            previously this silently produced NaNs).
    """
    print("[compute_expert_feature_expectation] Building expert feature expectations.")
    n_traj = len(trajectories)
    if n_traj == 0:
        raise ValueError("compute_expert_feature_expectation: need at least one trajectory")
    size = mdp.size
    mu_D = np.zeros(size*size)
    for traj in trajectories:
        for t, (s, _a, _r) in enumerate(traj):
            # Adding gamma^t to one slot equals adding gamma^t * one_hot(s),
            # without allocating a vector per step.
            mu_D[s[0]*size + s[1]] += discount_factor ** t
    mu_D /= float(n_traj)
    return mu_D


def maxent_estimation(mdp: GridMDP,
                      trajectories,
                      discount_factor=0.95,
                      learning_rate=0.01,
                      max_iter=50) -> np.ndarray:
    """
    Basic MaxEnt IRL with one reward parameter per cell, clipped to [0, 1].

    Each iteration:
      1) soft value iteration on the current rewards,
      2) the induced soft policy,
      3) Monte-Carlo discounted visitation mu_pi (see collect_state_visitation),
      4) gradient ascent step  param += learning_rate * (mu_D - mu_pi).

    The rollout horizon is taken from the expert data (longest trajectory), so
    mu_D and mu_pi sum discounted visits over the same number of steps. It
    falls back to the previous fixed horizon of 15 if no steps are available.
    Iteration stops early once ||mu_D - mu_pi|| < 1e-4.

    Args:
        mdp: environment supplying dynamics and actions.
        trajectories: expert (state, action, reward) sequences.
        discount_factor: gamma for both value iteration and visitation weights.
        learning_rate: gradient step size.
        max_iter: maximum number of gradient steps.

    Returns:
        (size, size) array of estimated rewards.
    """
    print("[maxent_estimation] Estimating reward via maximum entropy IRL.")
    size = mdp.size
    mu_D = compute_expert_feature_expectation(mdp, trajectories, discount_factor)
    horizon = max((len(traj) for traj in trajectories), default=15)
    param = np.random.uniform(0,1,size*size)

    t0 = time.time()
    for it in range(max_iter):
        R_2d = reshape_1d_to_2d(param, size)
        V_2d = soft_value_iteration(mdp, R_2d, discount_factor)

        pi_dict = {}
        for s in mdp.valid_states():
            q = [R_2d[s_next] + discount_factor * V_2d[s_next]
                 for s_next in (mdp.step(s, a) for a in mdp.actions)]
            pi_dict[s] = np.exp(q - logsumexp(q))

        freq = collect_state_visitation(mdp, pi_dict, discount_factor, T=horizon)
        mu_pi = np.zeros(size*size)
        for s, val in freq.items():
            mu_pi[s[0]*size + s[1]] += val

        grad = mu_D - mu_pi
        param = np.clip(param + learning_rate * grad, 0, 1)

        # Check convergence every iteration; log every 10th (and on exit).
        loss = np.linalg.norm(grad)
        converged = loss < 1e-4
        if it % 10 == 0 or converged:
            print(f"[maxent] iter={it}, loss={loss:.4f}")
        if converged:
            break

    elapsed = time.time() - t0
    print(f"[maxent_estimation] completed in {elapsed:.2f}s")
    return reshape_1d_to_2d(param, size)


def save_maxent_rewards_csv(reward_2d: np.ndarray, filename="results/maxent_est_rewards.csv"):
    """
    Write the MaxEnt reward grid to CSV (grid rows as lines, column indices as
    header, no index column).
    """
    pd.DataFrame(reward_2d).to_csv(filename, index=False)
    print(f"[save_maxent_rewards_csv] MaxEnt estimated rewards saved to {filename}.")


###############################################################################
# (3) ROUTE RECOMMENDATION (HARD VALUE ITERATION)
###############################################################################
def standard_value_iteration(mdp: GridMDP,
                             reward_2d: np.ndarray,
                             discount_factor: float,
                             max_iter=500,
                             tol=1e-7) -> np.ndarray:
    """
    Hard-optimal value iteration: V(s) = max_a [ R(s') + gamma V(s') ].

    Synchronous sweeps from V = 0 until the largest per-state change is below
    `tol` or `max_iter` sweeps have run. In the latter case a warning is
    printed. The max is taken over the actual Q-values, so arbitrarily
    negative rewards are handled. The old -1e9 starting sentinel clamped every
    value below -1e9 to -1e9.

    Args:
        mdp: provides `valid_states()`, `actions`, `step(s, a)` and `size`.
        reward_2d: reward grid indexed by (row, col), collected on s'.
        discount_factor: gamma.
        max_iter: maximum number of sweeps.
        tol: convergence threshold.

    Returns:
        (size, size) float array of values; cells not in `valid_states()` stay 0.
    """
    print("[standard_value_iteration] Running hard-optimal value iteration.")
    V = np.zeros((mdp.size, mdp.size), dtype=float)
    states = mdp.valid_states()
    if not states:
        return V

    # Successor cell for every (state, action); fixed across sweeps.
    s_rows = np.array([s[0] for s in states])
    s_cols = np.array([s[1] for s in states])
    succ = np.array([[mdp.step(s, a) for a in mdp.actions] for s in states])  # (S, A, 2)
    n_rows, n_cols = succ[..., 0], succ[..., 1]
    r_next = np.asarray(reward_2d, dtype=float)[n_rows, n_cols]  # (S, A)

    delta = float("inf")
    for it in range(max_iter):
        new_vals = (r_next + discount_factor * V[n_rows, n_cols]).max(axis=1)
        delta = float(np.max(np.abs(new_vals - V[s_rows, s_cols])))
        V[s_rows, s_cols] = new_vals
        if delta<tol:
            print(f"[standard_value_iteration] Converged after {it+1} iterations (delta={delta:.3g}).")
            break
    else:
        print(f"[standard_value_iteration] Warning: reached max_iter={max_iter} without converging (delta={delta:.3g}).")
    return V


def get_hard_greedy_path(mdp: GridMDP,
                         reward_2d: np.ndarray,
                         blocked_cells=None,
                         max_steps=30,
                         discount_factor=None,
                         vi_max_iter=500,
                         vi_tol=1e-7) -> List[Tuple[int,int]]:
    """
    Recommend a route from (0, 0) using hard-optimal values on `reward_2d`.

    1) Standard value iteration in which moves into (or out of) a blocked cell
       leave the agent in place.
    2) From (0, 0), repeatedly take the action with the highest
       R(s') + gamma V(s') (first action wins ties), until the goal
       (size-1, size-1) is reached or `max_steps` cells have been recorded.

    Args:
        mdp: environment providing `step`, `actions`, `valid_states` and `size`.
        reward_2d: reward grid indexed by (row, col).
        blocked_cells: iterable of (row, col) cells that cannot be entered.
        max_steps: maximum number of cells in the returned path.
        discount_factor: gamma; defaults to `mdp.discount_factor`.
        vi_max_iter: value-iteration sweep cap. The previous fixed cap of 200
            leaves a relative error of about 0.95**200 ~ 3.5e-5 at gamma=0.95;
            500 matches `standard_value_iteration`.
        vi_tol: value-iteration convergence threshold.

    Returns:
        List of visited cells, starting with (0, 0).
    """
    print("[get_hard_greedy_path] Computing path with blocked_cells=", blocked_cells)
    if discount_factor is None:
        discount_factor=mdp.discount_factor
    blocked = set(blocked_cells) if blocked_cells else set()

    def step_blocked(s, a):
        # Same grid dynamics as mdp.step, but a blocked cell can be neither
        # left nor entered.
        if s in blocked:
            return s
        nxt = mdp.step(s, a)
        return s if nxt in blocked else nxt

    def q_value(s, a, V):
        nxt = step_blocked(s, a)
        return reward_2d[nxt] + discount_factor * V[nxt]

    # Synchronous value iteration respecting blocked cells. max() over the
    # real Q-values (no -1e9 sentinel) keeps very negative rewards correct.
    states = mdp.valid_states()
    V = {s: 0.0 for s in states}
    for _ in range(vi_max_iter):
        V_new = {s: max(q_value(s, a, V) for a in mdp.actions) for s in states}
        delta = max((abs(V_new[s] - V[s]) for s in states), default=0.0)
        V = V_new
        if delta < vi_tol:
            break

    # Greedily extract a path.
    goal = (mdp.size - 1, mdp.size - 1)
    path = []
    s_cur = (0, 0)
    for _ in range(max_steps):
        path.append(s_cur)
        if s_cur == goal:
            break
        best_a = max(mdp.actions, key=lambda a: q_value(s_cur, a, V))
        s_cur = step_blocked(s_cur, best_a)
    return path

def get_hard_policy_matrix(mdp: GridMDP, reward_2d: np.ndarray)->np.ndarray:
    """
    Hard-optimal action for every cell, encoded 0=up, 1=down, 2=left, 3=right.

    Values come from `standard_value_iteration` with `mdp.discount_factor`.
    Each cell gets the first action (in `mdp.actions` order) maximizing
    R(s') + gamma V(s'). The argmax no longer depends on a -1e9 starting
    sentinel, which silently fell back to 'up' whenever every Q-value was
    below -1e9.

    Returns:
        (size, size) int array of action codes.
    """
    action_map={'up':0,'down':1,'left':2,'right':3}
    gamma = mdp.discount_factor
    V_2d = standard_value_iteration(mdp,reward_2d,gamma)
    mat=np.zeros((mdp.size,mdp.size),dtype=int)

    for s in mdp.valid_states():
        q = [reward_2d[s_next] + gamma * V_2d[s_next]
             for s_next in (mdp.step(s, a) for a in mdp.actions)]
        # np.argmax returns the first maximum: same tie-break as a strict '>' scan.
        best_a = mdp.actions[int(np.argmax(q))]
        mat[s[0], s[1]] = action_map[best_a]
    return mat

###############################################################################
# (4) METRICS & EVALUATION
###############################################################################
def compute_true_discounted_return(mdp: GridMDP, path)->float:
    """
    Discounted return of `path` under the TRUE rewards:
      sum_t gamma^t * mdp.true_rewards[path[t]],  gamma = mdp.discount_factor.
    The first cell of the path is counted with weight 1.
    """
    gamma = mdp.discount_factor
    total, weight = 0.0, 1.0
    for cell in path:
        total = total + weight * mdp.true_rewards[cell]
        weight = weight * gamma
    return total

def compute_est_discounted_return(reward_2d: np.ndarray, mdp: GridMDP, path)->float:
    """
    Discounted return of `path` scored with an arbitrary reward grid:
      sum_t gamma^t * reward_2d[path[t]],  gamma = mdp.discount_factor.
    """
    cells = list(path)
    weights = []
    w = 1.0
    for _ in cells:
        weights.append(w)
        w *= mdp.discount_factor
    return sum((wt * reward_2d[cell] for wt, cell in zip(weights, cells)), 0.0)

def reward_rmse(true_rewards_2d, est_rewards_2d)->float:
    """Root-mean-square difference between two equally shaped reward grids."""
    residual = true_rewards_2d - est_rewards_2d
    return float(np.sqrt(np.square(residual).mean()))

def reward_rank_corr(true_rewards_2d, est_rewards_2d)->float:
    """Spearman rank correlation between two reward grids, compared cell by cell after flattening."""
    spearman = spearmanr(true_rewards_2d.ravel(), est_rewards_2d.ravel())
    return float(spearman[0])

def policy_disagreement(mdp:GridMDP,
                        R_true_2d: np.ndarray,
                        R_est_2d: np.ndarray) -> float:
    """
    Fraction of states whose hard-optimal action under `R_est_2d` differs from
    the one under `R_true_2d` (0.0 = identical policies, 1.0 = none agree).
    """
    reference = get_hard_policy_matrix(mdp, R_true_2d)
    candidate = get_hard_policy_matrix(mdp, R_est_2d)
    cells = mdp.valid_states()
    mismatches = sum(1 for (r, c) in cells if reference[r, c] != candidate[r, c])
    return mismatches / len(cells)

###############################################################################
# (5) MAIN EXECUTION
###############################################################################
def main():
    """
    Run the whole experiment end to end and write every artefact to 'results/'.

    Steps:
      0) build a GridMDP and raise the goal reward,
      1) sample soft-optimal expert trajectories,
      2-3) estimate rewards with NFXP and MaxEnt,
      4) score the estimates (RMSE, Spearman rank correlation, hard-policy
         disagreement),
      5-6) compare greedy routes without and with blocked cells,
      7) save two figures,
      8) print and save a results table.

    NOTE: statement order matters for reproducibility. generate_trajectories
    reseeds NumPy's global RNG. estimate_nfxp then draws its initial
    parameters from that stream, and maxent_estimation draws its initial
    parameters and rollout actions after it.
    """
    print("[main] Starting main execution...")

    # Hyperparameters
    K=5            # grid side length
    N=10           # number of expert trajectories
    max_steps=15   # steps per expert trajectory; also the route-length cap in 5) and 6)
    seed=42

    # 0) Build MDP
    mdp=GridMDP(size=K, discount_factor=0.95, seed=seed)
    # Overrides the 1.0 goal value set by GridMDP. Both estimators are bounded
    # to [0, 1] (NFXP via L-BFGS-B bounds, MaxEnt via clipping), so neither can
    # represent this value; expect the goal cell to dominate the RMSE in 4).
    mdp.true_rewards[K-1, K-1] = 20.0  # Increase goal reward for clarity
    print(f"[main] Created GridMDP shape={mdp.shape()}, discount={mdp.discount_factor:.2f}")

    # 1) Gather Expert Data
    print("[main] Generating expert data...")
    trajectories=generate_trajectories(mdp, N=N, max_steps=max_steps, policy_type="soft", seed=seed)
    save_trajectories_csv(trajectories, filename="results/expert_data.csv")

    # 2) Estimate Rewards: NFXP
    print("[main] Estimating NFXP rewards...")
    t0=time.time()
    nfxp_est=estimate_nfxp(mdp, trajectories, max_iter=500)
    time_nfxp=time.time()-t0   # wall-clock, including the nested soft VI calls
    save_rewards_csv(nfxp_est, filename="results/nfxp_est_rewards.csv")

    # 3) Estimate Rewards: MaxEnt
    print("[main] Estimating MaxEnt rewards...")
    t1=time.time()
    maxent_est=maxent_estimation(mdp, trajectories, mdp.discount_factor, learning_rate=0.01, max_iter=50)
    time_maxent=time.time()-t1
    save_maxent_rewards_csv(maxent_est, filename="results/maxent_est_rewards.csv")

    # 4) Evaluate reward quality (RMSE, rank correlation, policy difference)
    print("[main] Computing evaluation metrics (RMSE, rank corr, policy diff).")
    rmse_nfxp  = reward_rmse(mdp.true_rewards, nfxp_est)
    rmse_maxent= reward_rmse(mdp.true_rewards, maxent_est)
    rcorr_nfxp   = reward_rank_corr(mdp.true_rewards, nfxp_est)
    rcorr_maxent = reward_rank_corr(mdp.true_rewards, maxent_est)
    pdiff_nfxp   = policy_disagreement(mdp, mdp.true_rewards, nfxp_est)
    pdiff_maxent = policy_disagreement(mdp, mdp.true_rewards, maxent_est)

    # 5) Evaluate route recommendations in unblocked scenario
    # Each route is planned on one reward grid, then scored on both that grid
    # ("Est R") and the true rewards ("True R").
    print("[main] Building route recommendations in unblocked scenario...")
    route_gt    = get_hard_greedy_path(mdp, mdp.true_rewards, None, max_steps)
    route_nfxp  = get_hard_greedy_path(mdp, nfxp_est,       None, max_steps)
    route_maxent= get_hard_greedy_path(mdp, maxent_est,     None, max_steps)

    gt_est_unblocked    = compute_est_discounted_return(mdp.true_rewards, mdp, route_gt)
    gt_true_unblocked   = compute_true_discounted_return(mdp, route_gt)
    nfxp_est_unblocked  = compute_est_discounted_return(nfxp_est, mdp, route_nfxp)
    nfxp_true_unblocked = compute_true_discounted_return(mdp, route_nfxp)
    maxent_est_unblocked= compute_est_discounted_return(maxent_est, mdp, route_maxent)
    maxent_true_unblocked= compute_true_discounted_return(mdp, route_maxent)

    # 6) Evaluate route recommendations in blocked scenario
    print("[main] Building route recommendations in blocked scenario...")
    blocked_cells=None
    if K>=5:
        # A 2x2 obstacle in the lower-right interior; the goal (K-1, K-1) stays open.
        blocked_cells=[(2,2),(2,3),(3,2),(3,3)]
    route_gt_block     = get_hard_greedy_path(mdp, mdp.true_rewards, blocked_cells, max_steps)
    route_nfxp_block   = get_hard_greedy_path(mdp, nfxp_est,         blocked_cells, max_steps)
    route_maxent_block = get_hard_greedy_path(mdp, maxent_est,       blocked_cells, max_steps)

    gt_est_block       = compute_est_discounted_return(mdp.true_rewards, mdp, route_gt_block)
    gt_true_block      = compute_true_discounted_return(mdp, route_gt_block)
    nfxp_est_block     = compute_est_discounted_return(nfxp_est, mdp, route_nfxp_block)
    nfxp_true_block    = compute_true_discounted_return(mdp, route_nfxp_block)
    maxent_est_block   = compute_est_discounted_return(maxent_est, mdp, route_maxent_block)
    maxent_true_block  = compute_true_discounted_return(mdp, route_maxent_block)

    # 7) Visualization of final policies
    print("[main] Creating final policy comparison figure...")
    pol_gt     = get_hard_policy_matrix(mdp, mdp.true_rewards)
    pol_nfxp   = get_hard_policy_matrix(mdp, nfxp_est)
    pol_maxent = get_hard_policy_matrix(mdp, maxent_est)

    fig, axes = plt.subplots(1,3, figsize=(15,5))
    plot_reward_and_policy(axes[0], mdp.true_rewards, pol_gt,     title="GroundTruth (Hard Policy)")
    plot_reward_and_policy(axes[1], nfxp_est,         pol_nfxp,   title="NFXP (Hard Policy)")
    plot_reward_and_policy(axes[2], maxent_est,       pol_maxent, title="MaxEnt (Hard Policy)")
    plt.tight_layout()
    out_fig1 = "results/reward_policy_comparison.png"
    plt.savefig(out_fig1, dpi=150)
    print(f"[main] Saved figure: {out_fig1}")
    plt.close(fig)

    # Compare unblocked vs blocked routes for NFXP/MaxEnt
    print("[main] Creating recommended paths figure (unblocked vs. blocked).")
    fig2, axes2 = plt.subplots(1,2, figsize=(12,5))
    plot_two_paths(axes2[0], K, route_nfxp, route_maxent,
                   blocked_cells=None,
                   title="Unblocked: NFXP(blue) vs MaxEnt(red)")
    plot_two_paths(axes2[1], K, route_nfxp_block, route_maxent_block,
                   blocked_cells=blocked_cells,
                   title="Blocked: NFXP(blue) vs MaxEnt(red)")
    plt.tight_layout()
    out_fig2 = "results/recommended_paths.png"
    plt.savefig(out_fig2, dpi=150)
    print(f"[main] Saved figure: {out_fig2}")
    plt.close(fig2)

    # 8) Prepare final results table
    print("[main] Preparing final results table and saving to CSV.")
    # Fixed reference values for the ground-truth column (not measured).
    time_gt=0.0
    rmse_gt=0.0
    rcorr_gt=1.0
    pdiff_gt=0.0

    # NOTE(review): the "Hyperparams:" row has 4 cells while every other row
    # has 5 (trailing ""). pandas presumably pads the short row with an empty
    # cell when df_table is built below; confirm if the exact CSV shape matters.
    data_table = [
        ["Hyperparams:", f"K={K}, N={N}, max_steps={max_steps}", "", ""],
        ["Method","NFXP","MaxEnt","GroundTruth",""],
        ["Time (sec)",f"{time_nfxp:.3f}",f"{time_maxent:.3f}",f"{time_gt:.3f}",""],
        ["RMSE(Reward)",f"{rmse_nfxp:.3f}",f"{rmse_maxent:.3f}",f"{rmse_gt:.3f}",""],
        ["RankCorr(Reward)",f"{rcorr_nfxp:.3f}",f"{rcorr_maxent:.3f}",f"{rcorr_gt:.3f}",""],
        ["PolDiff(Reward)",f"{pdiff_nfxp:.3f}",f"{pdiff_maxent:.3f}",f"{pdiff_gt:.3f}",""],
        ["--Unblocked--","NFXP","MaxEnt","GroundTruth",""],
        ["Return (Est R)",f"{nfxp_est_unblocked:.3f}",f"{maxent_est_unblocked:.3f}",f"{gt_est_unblocked:.3f}",""],
        ["Return (True R)",f"{nfxp_true_unblocked:.3f}",f"{maxent_true_unblocked:.3f}",f"{gt_true_unblocked:.3f}",""],
        ["--Blocked--","NFXP","MaxEnt","GroundTruth",""],
        ["Return (Est R)",f"{nfxp_est_block:.3f}",f"{maxent_est_block:.3f}",f"{gt_est_block:.3f}",""],
        ["Return (True R)",f"{nfxp_true_block:.3f}",f"{maxent_true_block:.3f}",f"{gt_true_block:.3f}",""]
    ]

    # Print table in console
    if USE_TABULATE:
        print("\nFINAL RESULTS\n")
        print(tabulate(data_table, tablefmt="github"))
    else:
        print("\nFINAL RESULTS (Fallback)\n")
        for row in data_table:
            print(" | ".join(map(str, row)))

    # Also save final results as CSV
    # We'll assume the first column is a label row, subsequent columns are data
    final_csv_path = "results/final_results.csv"
    all_rows = []
    for row in data_table:
        # Make sure everything is string
        all_rows.append([str(x) for x in row])

    # header=False: the table already carries its own header rows.
    df_table = pd.DataFrame(all_rows)
    df_table.to_csv(final_csv_path, index=False, header=False)
    print(f"[main] Saved final results to {final_csv_path}")

    print("[main] Done. See results above.")

###############################################################################
# Additional Helpers for Visualization
###############################################################################
def plot_reward_and_policy(ax, reward_2d: np.ndarray, policy_mat: np.ndarray, title=""):
    """
    Draw `reward_2d` as a heat map and overlay one white arrow per cell for the
    action stored in `policy_mat` (0=up, 1=down, 2=left, 3=right).

    Row 0 is drawn at the top (origin='upper'), so an 'up' move is a negative
    y offset in data coordinates.
    """
    n = reward_2d.shape[0]
    ax.imshow(reward_2d, cmap='viridis', origin='upper')
    ax.set_title(title)
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    # action code -> (dx, dy) arrow offset
    offsets = {0: (0, -0.3), 1: (0, 0.3), 2: (-0.3, 0), 3: (0.3, 0)}
    for row, col in ((r, c) for r in range(n) for c in range(n)):
        dx, dy = offsets[policy_mat[row, col]]
        ax.arrow(col, row, dx, dy,
                 color='white', head_width=0.1,
                 head_length=0.1, length_includes_head=True)

def plot_two_paths(ax, size,
                   path_nfxp, path_maxent,
                   blocked_cells=None,
                   title=""):
    """
    Overlay two routes on an empty size x size grid.

    The NFXP route is drawn in blue with circle markers, then the MaxEnt route
    in red with square markers. Cells are (row, col) with row 0 at the top.
    Any `blocked_cells` are shaded gray.
    """
    half = 0.5
    ax.set_xlim(-half, size - half)
    ax.set_ylim(-half, size - half)
    ax.set_xticks(range(size))
    ax.set_yticks(range(size))
    ax.invert_yaxis()
    ax.set_aspect('equal','box')
    ax.set_title(title)

    # Cell borders lie on the half-integer coordinates.
    for edge in (k - half for k in range(size + 1)):
        ax.axhline(edge, color='black')
        ax.axvline(edge, color='black')

    for (r, c) in (blocked_cells or []):
        ax.fill_between([c - half, c + half], r - half, r + half, color='gray', alpha=0.5)

    styles = (
        (path_nfxp, 'o-', 'blue', 'NFXP path'),
        (path_maxent, 's-', 'red', 'MaxEnt path'),
    )
    for route, fmt, colour, label in styles:
        cols = [cell[1] for cell in route]
        rows = [cell[0] for cell in route]
        ax.plot(cols, rows, fmt, color=colour, label=label)

    ax.legend()

###############################################################################
# Execute main()
###############################################################################
if __name__ == "__main__":
    # Run the full experiment only when executed as a script, not on import.
    main()


[main] Ensuring 'results/' directory exists.
[main] Starting main execution...
[GridMDP] Initializing with size=5, discount_factor=0.95, seed=42
[GridMDP] Randomizing initial rewards in [0,1].
[GridMDP] Setting start=(0,0) reward to 0.05 and goal=(size-1,size-1) reward to 1.0.
[main] Created GridMDP shape=(5, 5), discount=0.95
[main] Generating expert data...
[generate_trajectories] N=10, max_steps=15, policy_type=soft, seed=42.
[soft_value_iteration] Starting with max_iter=200, tol=1e-07.
[soft_policy] Building soft policy from value function.
[generate_trajectories] Generated 10 trajectories (policy_type=soft).
[save_trajectories_csv] Saved 10 trajectories to results/expert_data.csv.
[main] Estimating NFXP rewards...
[estimate_nfxp] Estimating reward via nested fixed-point approach.
[soft_value_iteration] Starting with max_iter=200, tol=1e-07.
[soft_value_iteration] Starting with max_iter=200, tol=1e-07.
[soft_value_iteration] Starting with max_iter=200, tol=1e-07.
[soft_value_iterat