In [35]:
import numpy as np

np.set_printoptions(precision=4, suppress=True)

# ---------------------------
# 1) Triangular MDP Setup
# ---------------------------
# States: 0,1,2 (think of them arranged in a triangle)
# Actions: 0=left, 1=right, 2=stay
# We'll define transitions with a small noise probability eps.
eps = 0.05
n_states = 3
n_actions = 3
gamma = 0.9

def build_transition_matrix(eps=0.05):
    """
    Triangular 3-state MDP with 3 actions.
    P[(s, a)] = probability distribution over next states [p0, p1, p2].
    
    Action 0 = 'left' 
    Action 1 = 'right'
    Action 2 = 'stay'
    """
    # Initialize dictionary
    P = {}
    # Helper to set distribution with main target and small noise to others
    def dist(target):
        # With prob (1 - eps) go to the target, with prob eps the rest is evenly distributed
        d = np.ones(n_states) * (eps / (n_states))
        d[target] += (1 - eps)
        return d
    
    # For each state s, define transitions for each action
    for s in range(n_states):
        # Action 'left'
        if s == 0:
            P[(0, 0)] = dist(target=2)  # from 0, left goes to 2 ideally
        elif s == 1:
            P[(1, 0)] = dist(target=0)  # from 1, left goes to 0
        else: # s==2
            P[(2, 0)] = dist(target=1)  # from 2, left goes to 1
        
        # Action 'right'
        if s == 0:
            P[(0, 1)] = dist(target=1)  # from 0, right -> 1
        elif s == 1:
            P[(1, 1)] = dist(target=2)  # from 1, right -> 2
        else: # s==2
            P[(2, 1)] = dist(target=0)  # from 2, right -> 0
        
        # Action 'stay'
        P[(s, 2)] = dist(target=s)     # from s, stay in s
    return P

# Construct the transitions
P = build_transition_matrix(eps)

# ---------------------------
# 2) Features & True Rewards
# ---------------------------
# We'll use one-hot features for each state: phi(s) = e_s
features = np.eye(n_states)

# Set a "true" reward for demonstration. 
# The user can choose any arbitrary vector.
true_rewards = np.array([0.5, 0.25, 0.10])

# ---------------------------
# 3) Soft Value Iteration
# ---------------------------
def soft_value_iteration(reward, P, tol=1e-6, max_iter=200):
    """
    V(s) <- log sum_a exp( R(s) + gamma * sum_{s'} P(s'|s,a)*V(s') )
    """
    V = np.zeros(n_states)
    for _ in range(max_iter):
        V_prev = V.copy()
        for s in range(n_states):
            Q_sa = []
            for a in range(n_actions):
                Q_sa.append(reward[s] + gamma * np.dot(P[(s,a)], V_prev))
            # log-sum-exp
            V[s] = np.log(np.sum(np.exp(Q_sa)))
        if np.max(np.abs(V - V_prev)) < tol:
            break
    return V

# ---------------------------
# 4) Compute Policy
# ---------------------------
def compute_policy(V, reward, P):
    """
    pi(a|s) = exp(Q(s,a)) / sum_{a'} exp(Q(s,a'))
    Q(s,a) = R(s) + gamma * sum_{s'} P(s'|s,a]*V(s')
    """
    policy = np.zeros((n_states, n_actions))
    Q_values = np.zeros((n_states, n_actions))  # for debug printing
    for s in range(n_states):
        Q_sa = []
        for a in range(n_actions):
            q = reward[s] + gamma * np.dot(P[(s,a)], V)
            Q_sa.append(q)
        Q_sa = np.array(Q_sa)
        # For debugging
        Q_values[s] = Q_sa
        # Stable softmax
        shift = Q_sa - np.max(Q_sa)
        policy[s] = np.exp(shift) / np.sum(np.exp(shift))
    return policy, Q_values

# ---------------------------
# 5) Compute State Visitation Frequencies
# ---------------------------
def compute_svf(policy, P, start_state=0, trajectory_length=5):
    """
    Accumulate visitation frequencies over 'trajectory_length' steps,
    starting with a single initial state (prob=1).
    """
    d_t = np.zeros(n_states)
    d_t[start_state] = 1.0

    svf = np.zeros(n_states)
    for _ in range(trajectory_length):
        svf += d_t
        next_d_t = np.zeros(n_states)
        for s in range(n_states):
            for a in range(n_actions):
                next_d_t += d_t[s] * policy[s, a] * P[(s,a)]
        d_t = next_d_t
    return svf

# ---------------------------
# 6) Generate "Expert" Trajectories
# ---------------------------
def generate_soft_optimal_trajectories(policy, P, n_trajectories=100, trajectory_length=5):
    """
    Sample states by following 'policy'. 
    Each trajectory has length 'trajectory_length'.
    """
    trajectories = []
    for _ in range(n_trajectories):
        traj = []
        # Start from a random state
        state = np.random.choice(n_states)
        for _ in range(trajectory_length):
            action = np.random.choice(n_actions, p=policy[state])
            next_state = np.random.choice(n_states, p=P[(state, action)])
            traj.append(state)
            state = next_state
        trajectories.append(traj)
    return trajectories

# ---------------------------
# 7) MaxEnt IRL
# ---------------------------
def maxent_irl(features, expert_trajectories,
               P,  # transition model
               true_rewards=None, true_policy=None, true_value=None,
               gamma=0.9, lr=0.01, n_iters=10000, print_every=1000):
    """
    Gradient-based MaxEnt IRL: R(s) = theta^T phi(s).
    Additional debugging prints added.
    """
    n_states, d_features = features.shape
    # Initialize random weights
    reward_weights = np.random.uniform(size=d_features)

    # Compute expert state visitation frequency (normalized)
    expert_svf = np.zeros(n_states)
    total_steps = 0
    for traj in expert_trajectories:
        for s in traj:
            expert_svf[s] += 1
        total_steps += len(traj)
    expert_svf /= total_steps

    for it in range(n_iters):
        # 1) Current reward
        reward_est = features @ reward_weights
        
        # 2) Soft Value Iteration
        V_est = soft_value_iteration(reward_est, P)
        
        # 3) Compute policy & Q-values
        policy_est, Q_values = compute_policy(V_est, reward_est, P)
        
        # 4) Predicted SVF
        svf_est = compute_svf(policy_est, P, start_state=0, trajectory_length=5)
        svf_est /= np.sum(svf_est)

        # 5) Gradient step
        grad = expert_svf - svf_est
        reward_weights += lr * features.T @ grad

        # 6) Debug prints
        if (it+1) % print_every == 0:
            loss_svf = np.linalg.norm(expert_svf - svf_est)  # L2
            grad_norm = np.linalg.norm(grad)
            msg = f"Iter {it+1:05d} | Loss(SVF): {loss_svf:.4f} | GradNorm: {grad_norm:.4f}"
            
            if true_policy is not None:
                # L1 difference across all states, all actions
                pol_diff = np.sum(np.abs(policy_est - true_policy))
                msg += f" | PolDiff: {pol_diff:.4f}"
            if true_value is not None:
                # L1 difference in value
                val_diff = np.sum(np.abs(V_est - true_value))
                msg += f" | ValDiff: {val_diff:.4f}"
            if true_rewards is not None:
                # L1 difference in reward vector
                rew_diff = np.sum(np.abs(reward_est - true_rewards))
                msg += f" | RewDiff: {rew_diff:.4f}"

            print(msg)
            # Print SVF side by side
            print(f"  Expert SVF: {expert_svf}")
            print(f"  Pred   SVF: {svf_est}")
            # Print Q-values
            print("  Q-values (s x a):")
            for s in range(n_states):
                print(f"    s={s}, Q={Q_values[s]}")
            # Print policy
            print("  Policy (s x a):")
            for s in range(n_states):
                print(f"    s={s}, π={policy_est[s]}")
            # Print reward
            print(f"  Reward: {reward_est}")
            print("")

    return features @ reward_weights

# ---------------------------
# 8) Main Demo
# ---------------------------
if __name__ == "__main__":
    np.random.seed(0)  # for reproducibility, if desired

    # 1) Compute the "true" V, policy from the known reward
    V_true = soft_value_iteration(true_rewards, P)
    policy_true, _ = compute_policy(V_true, true_rewards, P)

    # 2) Generate "expert" data from the "true" policy
    n_sample_trajectories = 500
    expert_trajectories = generate_soft_optimal_trajectories(
        policy_true, P, n_trajectories=n_sample_trajectories, trajectory_length=100
    )

    # 3) Run MaxEnt IRL with extra prints
    estimated_rewards = maxent_irl(
        features,
        expert_trajectories,
        P,
        true_rewards=true_rewards,
        true_policy=policy_true,
        true_value=V_true,
        lr=0.01,
        n_iters=10000,
        print_every=1000
    )

    # Final results
    print("\nFinal Results:")
    print("True Rewards:      ", true_rewards)
    print("Estimated Rewards: ", estimated_rewards)

Iter 01000 | Loss(SVF): 0.0076 | GradNorm: 0.0076 | PolDiff: 0.9582 | ValDiff: 11.8977 | RewDiff: 1.5524
  Expert SVF: [0.3905 0.3233 0.2862]
  Pred   SVF: [0.3938 0.3171 0.2891]
  Q-values (s x a):
    s=0, Q=[16.5041 16.6004 16.0799]
    s=1, Q=[16.6887 17.1129 17.2093]
    s=2, Q=[17.0966 16.576  17.0002]
  Policy (s x a):
    s=0, π=[0.3629 0.3996 0.2375]
    s=1, π=[0.2375 0.3629 0.3996]
    s=2, π=[0.3996 0.2375 0.3629]
  Reward: [0.2974 0.9063 0.7935]

Iter 02000 | Loss(SVF): 0.0008 | GradNorm: 0.0008 | PolDiff: 0.9798 | ValDiff: 11.9721 | RewDiff: 1.5823
  Expert SVF: [0.3905 0.3233 0.2862]
  Pred   SVF: [0.3911 0.3227 0.2863]
  Q-values (s x a):
    s=0, Q=[16.5035 16.6286 16.0745]
    s=1, Q=[16.7225 17.1515 17.2766]
    s=2, Q=[17.1303 16.5762 17.0052]
  Policy (s x a):
    s=0, π=[0.3591 0.407  0.2339]
    s=1, π=[0.2339 0.3591 0.407 ]
    s=2, π=[0.407  0.2339 0.3591]
  Reward: [0.2825 0.9305 0.7842]

Iter 03000 | Loss(SVF): 0.0001 | GradNorm: 0.0001 | PolDiff: 0.9832 | Va

# MAXENT + NXFP

In [18]:
import numpy as np
import time
from tabulate import tabulate
from scipy.special import logsumexp
from scipy.optimize import minimize

np.set_printoptions(precision=4, suppress=True)

# -------------------------------------------------
# Shared Setup
# -------------------------------------------------
eps = 0.05
n_states = 3
n_actions = 3
gamma = 0.9

def build_transition_matrix(eps=0.05):
    P = {}
    def dist(target):
        d = np.ones(n_states) * (eps / n_states)
        d[target] += (1 - eps)
        return d
    for s in range(n_states):
        if s == 0:
            P[(0,0)] = dist(target=2)
            P[(0,1)] = dist(target=1)
            P[(0,2)] = dist(target=0)
        elif s == 1:
            P[(1,0)] = dist(target=0)
            P[(1,1)] = dist(target=2)
            P[(1,2)] = dist(target=1)
        else:  # s=2
            P[(2,0)] = dist(target=1)
            P[(2,1)] = dist(target=0)
            P[(2,2)] = dist(target=2)
    return P

P = build_transition_matrix(eps)
features = np.eye(n_states)
true_rewards = np.array([0.0, 0.25, 0.9])

def soft_value_iteration(reward, P, tol=1e-6, max_iter=200):
    V = np.zeros(n_states)
    for _ in range(max_iter):
        V_prev = V.copy()
        for s in range(n_states):
            Q_sa = [reward[s] + gamma*np.dot(P[(s,a)], V_prev) for a in range(n_actions)]
            V[s] = logsumexp(Q_sa)
        if np.max(np.abs(V - V_prev)) < tol:
            break
    return V

def compute_policy(V, reward, P):
    policy = np.zeros((n_states, n_actions))
    Q_values = np.zeros((n_states, n_actions))
    for s in range(n_states):
        Q_sa = [reward[s] + gamma*np.dot(P[(s,a)], V) for a in range(n_actions)]
        Q_values[s] = Q_sa
        shift = Q_sa - np.max(Q_sa)
        policy[s] = np.exp(shift) / np.sum(np.exp(shift))
    return policy, Q_values

def generate_soft_optimal_trajectories(policy, P, n_trajectories=100, trajectory_length=5):
    trajectories = []
    for _ in range(n_trajectories):
        traj, state = [], np.random.choice(n_states)
        for __ in range(trajectory_length):
            a = np.random.choice(n_actions, p=policy[state])
            s_next = np.random.choice(n_states, p=P[(state,a)])
            traj.append((state, a, s_next))
            state = s_next
        trajectories.append(traj)
    return trajectories

def compute_svf(policy, P, start_state=0, trajectory_length=5):
    d_t = np.zeros(n_states)
    d_t[start_state] = 1.0
    svf = np.zeros(n_states)
    for _ in range(trajectory_length):
        svf += d_t
        nxt = np.zeros(n_states)
        for s in range(n_states):
            for a in range(n_actions):
                nxt += d_t[s] * policy[s,a] * P[(s,a)]
        d_t = nxt
    return svf

def compute_expert_log_likelihood(reward, expert_trajectories, P, use_transitions=True):
    V_est = soft_value_iteration(reward, P)
    policy_est, _ = compute_policy(V_est, reward, P)
    total_loglik = 0.0
    eps_small = 1e-12
    for traj in expert_trajectories:
        for (s, a, s_next) in traj:
            p_a = policy_est[s,a]
            total_loglik += np.log(p_a + eps_small)
            if use_transitions:
                total_loglik += np.log(P[(s,a)][s_next] + eps_small)
    return total_loglik

# -------------------------------------------------
# MaxEnt IRL
# -------------------------------------------------
def maxent_irl(features, expert_trajectories, P, anchor_mode=0, anchor_values=None,
               reg_lambda=0.0, gamma=0.9, lr=0.01, n_iters=10000, print_every=1000,
               verbose=True, true_rewards=None):
    n_states, d_features = features.shape
    w = np.random.uniform(-1e-3, 1e-3, d_features)
    if anchor_mode >= 1:
        w[0] = anchor_values[0]
    if anchor_mode == 2:
        w[1] = anchor_values[1]

    # Expert SVF
    expert_svf = np.zeros(n_states)
    tot_steps = 0
    for traj in expert_trajectories:
        for (s,a,s_next) in traj:
            expert_svf[s] += 1
        tot_steps += len(traj)
    expert_svf /= tot_steps

    for it in range(n_iters):
        reward_est = features @ w
        V_est = soft_value_iteration(reward_est, P)
        policy_est, _ = compute_policy(V_est, reward_est, P)
        svf_est = compute_svf(policy_est, P, 0, 5)
        svf_est /= np.sum(svf_est)

        grad_main = expert_svf - svf_est
        grad_w = features.T @ grad_main
        if reg_lambda > 0:
            grad_w -= reg_lambda * w
        w += lr*grad_w

        if anchor_mode >= 1:
            w[0] = anchor_values[0]
        if anchor_mode == 2:
            w[1] = anchor_values[1]

        if verbose and (it+1) % print_every == 0:
            pass  # optional prints

    return w

# -------------------------------------------------
# NxFP IRL
# -------------------------------------------------
def nfxp_objective(params, anchor_mode, anchor_values, reg_lambda, expert_trajectories):
    if anchor_mode == 0:
        r0, r1, r2 = params
    elif anchor_mode == 1:
        r0 = anchor_values[0]
        r1, r2 = params
    else:
        r0, r1 = anchor_values
        (r2,) = params
    reward = np.array([r0, r1, r2])

    V_est = soft_value_iteration(reward, P)
    policy_est, _ = compute_policy(V_est, reward, P)
    nll = 0.0
    eps = 1e-12
    for traj in expert_trajectories:
        for (s,a,_) in traj:
            p_a = policy_est[s,a]
            nll += -np.log(p_a + eps)
    penalty = reg_lambda*np.sum(params**2)
    return nll + penalty

def estimate_nfxp(expert_trajectories, anchor_mode, anchor_values, reg_lambda):
    if anchor_mode == 0:
        x0 = np.array([0.1, 0.1, 0.1])
        bnds = [(-2,2)]*3
    elif anchor_mode == 1:
        x0 = np.array([0.1, 0.1])
        bnds = [(-2,2)]*2
    else:
        x0 = np.array([0.1])
        bnds = [(-2,2)]

    def wrapper(x):
        return nfxp_objective(x, anchor_mode, anchor_values, reg_lambda, expert_trajectories)

    res = minimize(wrapper, x0, method='L-BFGS-B', bounds=bnds,
                   options={'maxiter':300, 'disp':False})

    if anchor_mode == 0:
        r0, r1, r2 = res.x
    elif anchor_mode == 1:
        r0 = anchor_values[0]
        r1, r2 = res.x
    else:
        r0, r1 = anchor_values
        (r2,) = res.x
    return np.array([r0, r1, r2]), res.fun

# -------------------------------------------------
# Unified Demo
# -------------------------------------------------
if __name__ == "__main__":
    np.random.seed(0)
    print("=== Base Experiment Info ===")
    print(f"States={n_states}, Actions={n_actions}, eps={eps}, gamma={gamma}")
    print("True Rewards:", true_rewards)
    n_sample_trajectories = 500
    trajectory_length = 50
    print(f"NumTraj={n_sample_trajectories}, TrajLen={trajectory_length}\n")

    # Generate expert data
    V_true = soft_value_iteration(true_rewards, P)
    policy_true, _ = compute_policy(V_true, true_rewards, P)
    expert_trajectories = generate_soft_optimal_trajectories(
        policy_true, P, n_trajectories=n_sample_trajectories, trajectory_length=trajectory_length
    )

    # Precompute an expert_svf_raw for scoring
    expert_svf_raw = np.zeros(n_states)
    tot_steps = 0
    for traj in expert_trajectories:
        for (s, a, _) in traj:
            expert_svf_raw[s] += 1
        tot_steps += len(traj)
    expert_svf_raw /= tot_steps

    anchor_modes = [0, 1, 2]
    reg_lambdas = [0.0, 0.05, 0.1, 0.2, 0.7]

    # Run both MaxEnt IRL and NxFP IRL
    results_maxent = {}
    results_nfxp = {}

    # --- MaxEnt IRL ---
    for am in anchor_modes:
        if am == 0:
            anc_vals = None
            am_str = "FixNone"
        elif am == 1:
            anc_vals = [0.0]
            am_str = "FixFirst"
        else:
            anc_vals = [0.0, 0.25]
            am_str = "FixFirstSecond"

        for rl in reg_lambdas:
            start_t = time.time()
            w = maxent_irl(features, expert_trajectories, P, anchor_mode=am,
                           anchor_values=anc_vals, reg_lambda=rl, verbose=False)
            end_t = time.time()

            est_rew = features @ w
            V_est = soft_value_iteration(est_rew, P)
            policy_est, _ = compute_policy(V_est, est_rew, P)

            pol_diff = np.sum(np.abs(policy_est - policy_true))
            rew_diff = np.sum(np.abs(est_rew - true_rewards))
            svf_est = compute_svf(policy_est, P, 0, 5)
            svf_est /= np.sum(svf_est)
            svf_diff = np.sum(np.abs(svf_est - expert_svf_raw))
            grad_norm = np.linalg.norm(expert_svf_raw - svf_est)
            val_diff = np.sum(np.abs(V_est - V_true))
            loglik = compute_expert_log_likelihood(est_rew, expert_trajectories, P, use_transitions=False)
            tsec = end_t - start_t

            results_maxent[(am_str, rl)] = {
                'R': est_rew,
                'RewDiff': rew_diff,
                'PolDiff': pol_diff,
                'GradNorm': grad_norm,
                'ValDiff': val_diff,
                'SvfDiff': svf_diff,
                'LogLik': loglik,
                'TimeSec': tsec
            }

    # --- NxFP IRL ---
    for am in anchor_modes:
        if am == 0:
            anc_vals = None
            am_str = "FixNone"
        elif am == 1:
            anc_vals = [0.0]
            am_str = "FixFirst"
        else:
            anc_vals = [0.0, 0.25]
            am_str = "FixFirstSecond"

        for rl in reg_lambdas:
            start_t = time.time()
            est_rew, _ = estimate_nfxp(expert_trajectories, am, anc_vals, rl)
            end_t = time.time()

            V_est = soft_value_iteration(est_rew, P)
            policy_est, _ = compute_policy(V_est, est_rew, P)

            pol_diff = np.sum(np.abs(policy_est - policy_true))
            rew_diff = np.sum(np.abs(est_rew - true_rewards))
            svf_est = compute_svf(policy_est, P, 0, 5)
            svf_est /= np.sum(svf_est)
            svf_diff = np.sum(np.abs(svf_est - expert_svf_raw))
            grad_norm = np.linalg.norm(expert_svf_raw - svf_est)
            val_diff = np.sum(np.abs(V_est - V_true))
            loglik = compute_expert_log_likelihood(est_rew, expert_trajectories, P, use_transitions=False)
            tsec = end_t - start_t

            results_nfxp[(am_str, rl)] = {
                'R': est_rew,
                'RewDiff': rew_diff,
                'PolDiff': pol_diff,
                'GradNorm': grad_norm,
                'ValDiff': val_diff,
                'SvfDiff': svf_diff,
                'LogLik': loglik,
                'TimeSec': tsec
            }

    # Print tables
    hdrs = ["AnchorMode","reg_lambda","R(0)","R(1)","R(2)","RewDiff","PolDiff","GradNorm","ValDiff","SvfDiff","LogLik","TimeSec"]

    def build_rows(res):
        rows = []
        for am in anchor_modes:
            if am == 0:
                am_str = "FixNone"
            elif am == 1:
                am_str = "FixFirst"
            else:
                am_str = "FixFirstSecond"
            for rl in reg_lambdas:
                vals = res[(am_str, rl)]
                r0, r1, r2 = vals['R']
                row = [
                    am_str, f"{rl:.2f}", f"{r0:.4f}", f"{r1:.4f}", f"{r2:.4f}",
                    f"{vals['RewDiff']:.4f}", f"{vals['PolDiff']:.4f}",
                    f"{vals['GradNorm']:.4f}", f"{vals['ValDiff']:.4f}",
                    f"{vals['SvfDiff']:.4f}", f"{vals['LogLik']:.4f}",
                    f"{vals['TimeSec']:.4f}"
                ]
                rows.append(row)
        return rows

    print("\n=== MaxEnt IRL Results ===")
    rows_maxent = build_rows(results_maxent)
    print(tabulate(rows_maxent, headers=hdrs, tablefmt="plain"))

    print("\n=== NxFP IRL Results ===")
    rows_nfxp = build_rows(results_nfxp)
    print(tabulate(rows_nfxp, headers=hdrs, tablefmt="plain"))
    print("===================================================")


=== Base Experiment Info ===
States=3, Actions=3, eps=0.05, gamma=0.9
True Rewards: [0.   0.25 0.9 ]
NumTraj=500, TrajLen=50


=== MaxEnt IRL Results ===
AnchorMode        reg_lambda     R(0)    R(1)    R(2)    RewDiff    PolDiff    GradNorm    ValDiff    SvfDiff    LogLik    TimeSec
FixNone                 0     -2.1748  0.7848  1.3909     3.2005     1.1914      0.0021     5.0764     0.0035  -33382.4    46.1096
FixNone                 0.05  -1.1384  0.2913  0.8471     1.2327     0.7558      0.0724     6.6523     0.1138  -27773.2    45.9838
FixNone                 0.1   -0.8596  0.1807  0.6789     1.15       0.5266      0.111      9.0881     0.1719  -26822      47.1638
FixNone                 0.2   -0.6029  0.0991  0.5038     1.15       0.3358      0.1584    10.9757     0.2412  -26329.4    46.0111
FixNone                 0.7   -0.2574  0.0291  0.2284     1.15       0.549       0.2417    12.7077     0.3604  -26512      45.9233
FixFirst                0      0       2.3865  3.0045     4.

# GRIDWORLD

In [9]:
import numpy as np
import time
from tabulate import tabulate
from scipy.special import logsumexp
from scipy.optimize import minimize
from scipy.stats import spearmanr, pearsonr

np.set_printoptions(precision=4, suppress=True)
np.random.seed(0)

# ---------------------------
# Environment Setup (5x5)
# ---------------------------
N = 5
n_states = N*N
actions = {"LEFT":0, "RIGHT":1, "UP":2, "DOWN":3, "STAY":4}
n_actions = len(actions)
gamma = 0.9

def to_rc(s):
    return divmod(s, N)  # row,col

def to_s(r,c):
    return r*N + c

def build_transition_matrix():
    P = {}
    for s in range(n_states):
        r,c = to_rc(s)
        # Possible next states
        left_s  = to_s(r, max(c-1, 0))
        right_s = to_s(r, min(c+1, N-1))
        up_s    = to_s(max(r-1, 0), c)
        down_s  = to_s(min(r+1, N-1), c)
        stay_s  = s
        next_map = [left_s, right_s, up_s, down_s, stay_s]
        for a in range(n_actions):
            dist = np.zeros(n_states)
            dist[next_map[a]] = 1.0  # deterministic transitions
            P[(s,a)] = dist
    return P

P = build_transition_matrix()

# One-hot features for each of 25 states
features = np.eye(n_states)

# ---------------------------
# True Rewards (fixed random)
# ---------------------------
true_rewards = np.random.uniform(-1,1,n_states)

# ---------------------------
# Helpers
# ---------------------------
def soft_value_iteration(reward, P, tol=1e-6, max_iter=200):
    V = np.zeros(n_states)
    for _ in range(max_iter):
        V_prev = V.copy()
        for s in range(n_states):
            Q_sa = []
            for a in range(n_actions):
                Q_sa.append(reward[s] + gamma * np.dot(P[(s,a)], V_prev))
            V[s] = logsumexp(Q_sa)
        if np.max(np.abs(V - V_prev)) < tol:
            break
    return V

def compute_policy(V, reward, P):
    policy = np.zeros((n_states, n_actions))
    for s in range(n_states):
        Q_sa = []
        for a in range(n_actions):
            Q_sa.append(reward[s] + gamma * np.dot(P[(s,a)], V))
        Q_sa = np.array(Q_sa)
        shift = Q_sa - np.max(Q_sa)
        policy[s] = np.exp(shift)/np.sum(np.exp(shift))
    return policy

def generate_expert_trajectories(policy, P, n_trajectories=200, traj_len=10):
    trajectories = []
    for _ in range(n_trajectories):
        s = np.random.choice(n_states)
        traj = []
        for __ in range(traj_len):
            a = np.random.choice(n_actions, p=policy[s])
            s_next = np.random.choice(n_states, p=P[(s,a)])
            traj.append((s,a,s_next))
            s = s_next
        trajectories.append(traj)
    return trajectories

def compute_expert_loglik(rew, trajectories, P):
    V_est = soft_value_iteration(rew, P)
    pol_est = compute_policy(V_est, rew, P)
    loglik = 0.0
    eps = 1e-12
    for traj in trajectories:
        for (s,a,_) in traj:
            loglik += np.log(pol_est[s,a]+eps)
    return loglik

def compute_svf(policy, P, start_count=10, traj_len=10):
    # Approx. occupancy by simulating from random start states
    counts = np.zeros(n_states)
    total_steps = start_count * traj_len
    for _ in range(start_count):
        s = np.random.choice(n_states)
        for __ in range(traj_len):
            counts[s] += 1
            a = np.random.choice(n_actions, p=policy[s])
            s = np.random.choice(n_states, p=P[(s,a)])
    return counts / total_steps

# ---------------------------
# Metrics
# ---------------------------
def reward_metrics(true_r, est_r):
    # rank corr, corr, RMSE
    r_rank, _ = spearmanr(true_r, est_r)
    r_corr, _ = pearsonr(true_r, est_r)
    rmse = np.sqrt(np.mean((true_r - est_r)**2))
    return r_rank, r_corr, rmse

def policy_diff(pol1, pol2):
    return np.mean(np.abs(pol1 - pol2))

def vector_diff(v1, v2):
    return np.mean(np.abs(v1 - v2))

# ---------------------------
# Expert Data
# ---------------------------
V_true = soft_value_iteration(true_rewards, P)
policy_true = compute_policy(V_true, true_rewards, P)
expert_trajectories = generate_expert_trajectories(policy_true, P)
expert_svf = compute_svf(policy_true, P)

# ---------------------------
# MaxEnt IRL
# ---------------------------
def maxent_irl(features, expert_trajectories, P, anchor_mode=0, anchor_val=0.0, reg=0.0,
               lr=0.01, n_iters=5000):
    w = np.zeros(features.shape[1])
    # If anchor_mode = 1 or 2, fix w[0]
    # anchor_mode=2 => also apply a reg in gradient
    # Expert occupancy
    counts = np.zeros(n_states)
    for traj in expert_trajectories:
        for (s,a,sn) in traj:
            counts[s] += 1
    counts /= np.sum(counts)
    
    for _ in range(n_iters):
        r_est = features @ w
        V_est = soft_value_iteration(r_est, P)
        pol_est = compute_policy(V_est, r_est, P)
        svf_est = compute_svf(pol_est, P)
        grad = features.T @ (counts - svf_est)
        if anchor_mode == 2:
            grad -= reg * w
        w += lr * grad
        if anchor_mode >= 1:
            w[0] = anchor_val
    return features @ w

# ---------------------------
# NxFP IRL
# ---------------------------
def nfxp_objective(params, anchor_mode, anchor_val, reg, expert_trajectories):
    # anchor_mode=0 => free r, anchor_mode=1 => anchor r0, anchor_mode=2 => anchor + reg
    if anchor_mode == 0:
        r = params
    else:
        # Force r[0] = anchor_val
        r = np.concatenate([[anchor_val], params])
    V_est = soft_value_iteration(r, P)
    pol_est = compute_policy(V_est, r, P)
    nll = 0.0
    eps = 1e-12
    for traj in expert_trajectories:
        for (s,a,_) in traj:
            nll -= np.log(pol_est[s,a]+eps)
    if anchor_mode == 2:
        nll += reg * np.sum(params**2)
    return nll

def estimate_nfxp(expert_trajectories, anchor_mode=0, anchor_val=0.0, reg=0.0):
    # For simplicity, dimension = 25. If anchor=1 or 2 => param dimension = 24
    if anchor_mode == 0:
        x0 = np.zeros(n_states)
        bnds = [(-2,2)]*n_states
    else:
        x0 = np.zeros(n_states-1)
        bnds = [(-2,2)]*(n_states-1)
    
    def wrapper(p):
        return nfxp_objective(p, anchor_mode, anchor_val, reg, expert_trajectories)
    
    res = minimize(wrapper, x0, method='L-BFGS-B', bounds=bnds, options={'maxiter':500, 'disp':False})
    if anchor_mode == 0:
        r = res.x
    else:
        r = np.concatenate([[anchor_val], res.x])
    return r

# ---------------------------
# Run 3 Experiments
# ---------------------------
# 1) Anchor None (no reg)
# 2) Anchor One (no reg)
# 3) Anchor One + reg
anchor_settings = [
    (0, 0.0, 0.0),   # no anchor
    (1, 1.0, 0.0),   # anchor r[0]=1
    (2, 1.0, 0.1)    # anchor r[0]=1 + regularization
]

results_maxent = []
results_nfxp = []

for (mode, aval, regval) in anchor_settings:
    # MaxEnt
    start_t = time.time()
    est_r = maxent_irl(features, expert_trajectories, P, anchor_mode=mode,
                       anchor_val=aval, reg=regval)
    tsec = time.time() - start_t
    
    # Evaluate
    rrank, rcorr, rmse = reward_metrics(true_rewards, est_r)
    V_est = soft_value_iteration(est_r, P)
    pol_est = compute_policy(V_est, est_r, P)
    pol_d = policy_diff(pol_est, policy_true)
    svf_est = compute_svf(pol_est, P)
    svf_d = vector_diff(svf_est, expert_svf)
    val_d = vector_diff(V_est, V_true)
    ll = compute_expert_loglik(est_r, expert_trajectories, P)
    
    results_maxent.append([
        f"Anch={mode}", f"{rrank:.2f}", f"{rcorr:.2f}", f"{rmse:.2f}",
        f"{pol_d:.3f}", f"{svf_d:.3f}", f"{val_d:.3f}",
        f"{tsec:.3f}", f"{ll:.1f}"
    ])
    
    # NxFP
    start_t = time.time()
    est_r_nfxp = estimate_nfxp(expert_trajectories, mode, aval, regval)
    tsec = time.time() - start_t
    
    # Evaluate
    rrank, rcorr, rmse = reward_metrics(true_rewards, est_r_nfxp)
    V_est = soft_value_iteration(est_r_nfxp, P)
    pol_est = compute_policy(V_est, est_r_nfxp, P)
    pol_d = policy_diff(pol_est, policy_true)
    svf_est = compute_svf(pol_est, P)
    svf_d = vector_diff(svf_est, expert_svf)
    val_d = vector_diff(V_est, V_true)
    ll = compute_expert_loglik(est_r_nfxp, expert_trajectories, P)
    
    results_nfxp.append([
        f"Anch={mode}", f"{rrank:.2f}", f"{rcorr:.2f}", f"{rmse:.2f}",
        f"{pol_d:.3f}", f"{svf_d:.3f}", f"{val_d:.3f}",
        f"{tsec:.3f}", f"{ll:.1f}"
    ])

# ---------------------------
# Print Tables
# ---------------------------
headers = ["Anchor","RankCorr","Corr","RMSE","PolDiff","SvfDiff","ValDiff","Time","LogLik"]

print("=== MaxEnt IRL (3 Rows) ===")
print(tabulate(results_maxent, headers=headers, tablefmt="plain"))
print("\n=== NxFP IRL (3 Rows) ===")
print(tabulate(results_nfxp, headers=headers, tablefmt="plain"))


=== MaxEnt IRL (3 Rows) ===
Anchor      RankCorr    Corr    RMSE    PolDiff    SvfDiff    ValDiff     Time    LogLik
Anch=0          0.97    0.94    0.3       0.028      0.027      2.404  212.634   -2913.7
Anch=1          0.88    0.84    0.33      0.043      0.028      1.09   215.785   -2978.9
Anch=2          0.73    0.53    0.48      0.099      0.03       1.708  215.803   -3211.8

=== NxFP IRL (3 Rows) ===
Anchor      RankCorr    Corr    RMSE    PolDiff    SvfDiff    ValDiff    Time    LogLik
Anch=0          0.97    0.98    0.35      0.017      0.046      3.106  45.401   -2893.8
Anch=1          0.97    0.98    1.04      0.017      0.034     10.497  47.572   -2893.8
Anch=2          0.98    0.99    0.95      0.016      0.044      9.58   64.357   -2894.1
