In [35]:
import numpy as np

np.set_printoptions(precision=4, suppress=True)  # print arrays with 4 decimals, no scientific notation for small values

# ---------------------------
# 1) Triangular MDP Setup
# ---------------------------
# States: 0,1,2 (think of them arranged in a triangle)
# Actions: 0=left, 1=right, 2=stay
# We'll define transitions with a small noise probability eps.
# These four names are module-level globals read directly by the functions below.
eps = 0.05      # noise level passed to build_transition_matrix
n_states = 3    # number of states; sizes every value / visitation vector
n_actions = 3   # number of actions; sizes every Q / policy row
gamma = 0.9     # discount factor used in the soft Bellman backups

def build_transition_matrix(eps=0.05, num_states=None):
    """
    Build the noisy transition model of a ring-shaped MDP with 3 actions.

    The states sit on a cycle (a triangle when there are three of them):

      Action 0 = 'left'  -> intended target (s - 1) mod n
      Action 1 = 'right' -> intended target (s + 1) mod n
      Action 2 = 'stay'  -> intended target s

    For n = 3 this reproduces the original hand-written table
    (left: 0->2, 1->0, 2->1; right: 0->1, 1->2, 2->0).

    Parameters
    ----------
    eps : float
        Noise mass spread uniformly over *all* n states, the intended target
        included. The target therefore ends up with (1 - eps) + eps / n and
        every other state with eps / n.
    num_states : int or None
        Number of states n. None (the default) falls back to the
        module-level ``n_states``.

    Returns
    -------
    dict
        P[(s, a)] is a length-n array with the distribution over next states.

    Raises
    ------
    ValueError
        If the number of states is smaller than 1.
    """
    n = n_states if num_states is None else num_states
    if n < 1:
        raise ValueError("the MDP needs at least one state")

    # Helper: uniform eps-noise everywhere plus the remaining mass on the target
    def dist(target):
        d = np.ones(n) * (eps / n)
        d[target] += (1 - eps)
        return d

    P = {}
    for s in range(n):
        P[(s, 0)] = dist(target=(s - 1) % n)  # 'left'  : previous state on the ring
        P[(s, 1)] = dist(target=(s + 1) % n)  # 'right' : next state on the ring
        P[(s, 2)] = dist(target=s)            # 'stay'
    return P

# Construct the transitions: P maps (state, action) -> distribution over next states
P = build_transition_matrix(eps)

# ---------------------------
# 2) Features & True Rewards
# ---------------------------
# We'll use one-hot features for each state: phi(s) = e_s
# (row s of the identity matrix), so each reward weight is the reward of one state.
features = np.eye(n_states)

# Set a "true" reward for demonstration (one entry per state).
# The user can choose any arbitrary vector.
true_rewards = np.array([0.5, 0.25, 0.10])

# ---------------------------
# 3) Soft Value Iteration
# ---------------------------
def soft_value_iteration(reward, P, tol=1e-6, max_iter=200, discount=None):
    """
    Soft (MaxEnt) value iteration.

    Starting from V = 0, repeats the synchronous backup

        V(s) <- log sum_a exp( R(s) + gamma * sum_{s'} P(s'|s,a)*V(s') )

    until the largest per-state change falls below ``tol`` or ``max_iter``
    sweeps have been performed.

    The log-sum-exp is evaluated in its shifted form
    ``m + log(sum_a exp(Q_a - m))`` with ``m = max_a Q_a`` so that large
    rewards cannot overflow ``exp``.

    Parameters
    ----------
    reward : sequence of float
        Per-state reward R(s); its length fixes the number of states.
    P : dict
        Transition model keyed by (state, action); each value is the
        distribution over next states. The number of actions is taken from
        the largest action id appearing in the keys.
    tol : float
        Convergence threshold on max_s |V_new(s) - V_old(s)|.
    max_iter : int
        Maximum number of sweeps.
    discount : float or None
        Discount factor; None (the default) uses the module-level ``gamma``.

    Returns
    -------
    np.ndarray
        The soft value of every state.
    """
    disc = gamma if discount is None else discount
    num_states = len(reward)
    num_actions = 1 + max(a for (_, a) in P)

    V = np.zeros(num_states)
    for _ in range(max_iter):
        V_prev = V.copy()
        for s in range(num_states):
            Q_sa = np.array(
                [reward[s] + disc * np.dot(P[(s, a)], V_prev)
                 for a in range(num_actions)],
                dtype=float,
            )
            # Numerically stable log-sum-exp
            q_max = np.max(Q_sa)
            if np.isfinite(q_max):
                V[s] = q_max + np.log(np.sum(np.exp(Q_sa - q_max)))
            else:
                # Shifting by +/-inf (or nan) would produce inf - inf = nan;
                # the value of the log-sum-exp in that case is q_max itself.
                V[s] = q_max
        if np.max(np.abs(V - V_prev)) < tol:
            break
    return V

# ---------------------------
# 4) Compute Policy
# ---------------------------
def compute_policy(V, reward, P):
    """
    Turn soft values into a Boltzmann (softmax) policy.

    Q(s,a)  = R(s) + gamma * sum_{s'} P(s'|s,a) * V(s')
    pi(a|s) = exp(Q(s,a)) / sum_{a'} exp(Q(s,a'))

    Returns the pair (policy, Q_values), both of shape (n_states, n_actions).
    """
    Q_values = np.zeros((n_states, n_actions))  # kept for debug printing
    policy = np.zeros((n_states, n_actions))
    for state in range(n_states):
        for act in range(n_actions):
            Q_values[state, act] = reward[state] + gamma * np.dot(P[(state, act)], V)
        # Subtract the row maximum before exponentiating so exp() cannot overflow
        centered = Q_values[state] - np.max(Q_values[state])
        unnormalized = np.exp(centered)
        policy[state] = unnormalized / np.sum(unnormalized)
    return policy, Q_values

# ---------------------------
# 5) Compute State Visitation Frequencies
# ---------------------------
def compute_svf(policy, P, start_state=0, trajectory_length=5):
    """
    Sum the state-occupancy distribution over 'trajectory_length' time steps.

    The walk starts deterministically in 'start_state' (probability 1) and is
    pushed forward one step at a time under 'policy' and the transition
    model P. The returned vector is the un-normalized sum of the per-step
    distributions.
    """
    state_action_pairs = [(s, a) for s in range(n_states) for a in range(n_actions)]

    occupancy = np.zeros(n_states)
    occupancy[start_state] = 1.0

    visits = np.zeros(n_states)
    for _ in range(trajectory_length):
        visits += occupancy
        # One-step push-forward: mass in s, times pi(a|s), spread by P(.|s,a)
        pushed = np.zeros(n_states)
        for s, a in state_action_pairs:
            pushed += occupancy[s] * policy[s, a] * P[(s, a)]
        occupancy = pushed
    return visits

# ---------------------------
# 6) Generate "Expert" Trajectories
# ---------------------------
def generate_soft_optimal_trajectories(policy, P, n_trajectories=100, trajectory_length=5):
    """
    Roll out 'policy' in the MDP and record the visited states.

    Returns a list of 'n_trajectories' lists, each holding the
    'trajectory_length' states visited (the state reached after the last
    action is not recorded).
    """
    def rollout():
        # Every rollout starts from a uniformly random state
        current = np.random.choice(n_states)
        visited = []
        for _ in range(trajectory_length):
            visited.append(current)
            chosen = np.random.choice(n_actions, p=policy[current])
            current = np.random.choice(n_states, p=P[(current, chosen)])
        return visited

    return [rollout() for _ in range(n_trajectories)]

# ---------------------------
# 7) MaxEnt IRL
# ---------------------------
def maxent_irl(features, expert_trajectories,
               P,  # transition model
               true_rewards=None, true_policy=None, true_value=None,
               gamma=0.9, lr=0.01, n_iters=10000, print_every=1000,
               svf_start_state=0, svf_horizon=5):
    """
    Gradient-based MaxEnt IRL with a linear reward R(s) = theta^T phi(s).

    Each iteration solves soft value iteration for the current reward,
    derives the softmax policy, rolls its state-visitation frequency (SVF)
    forward, and moves theta along phi^T (expert_svf - predicted_svf).

    Parameters
    ----------
    features : np.ndarray, shape (n_states, d_features)
        Row s is phi(s).
    expert_trajectories : iterable of sequences of int
        Each trajectory is a sequence of visited state indices.
    P : dict
        Transition model keyed by (state, action).
    true_rewards, true_policy, true_value : optional
        Ground-truth quantities; when given, their L1 distance to the current
        estimate is added to the debug line.
    gamma : float
        NOTE(review): accepted for API compatibility but not forwarded; the
        helpers called here take no discount argument and, as defined in this
        file, read the module-level ``gamma``.
    lr : float
        Gradient step size.
    n_iters : int
        Number of gradient steps.
    print_every : int
        Print diagnostics every this many iterations; 0 or None disables
        printing.
    svf_start_state : int
        Start state used for the predicted SVF (previously hard-coded to 0).
    svf_horizon : int
        Number of time steps of the predicted SVF (previously hard-coded
        to 5). For the gradient to compare like with like this should match
        the start distribution / length of the demonstrations.

    Returns
    -------
    np.ndarray
        The estimated per-state reward, features @ theta.

    Raises
    ------
    ValueError
        If the demonstrations contain no visited state at all.
    """
    n_states, d_features = features.shape
    # Initialize random weights
    reward_weights = np.random.uniform(size=d_features)

    # Compute expert state visitation frequency (normalized)
    expert_svf = np.zeros(n_states)
    total_steps = 0
    for traj in expert_trajectories:
        for s in traj:
            expert_svf[s] += 1
        total_steps += len(traj)
    if total_steps == 0:
        # Dividing by zero here would silently turn every reward into NaN
        raise ValueError("expert_trajectories contains no visited states")
    expert_svf /= total_steps

    for it in range(n_iters):
        # 1) Current reward
        reward_est = features @ reward_weights

        # 2) Soft Value Iteration
        V_est = soft_value_iteration(reward_est, P)

        # 3) Compute policy & Q-values
        policy_est, Q_values = compute_policy(V_est, reward_est, P)

        # 4) Predicted SVF, normalized to a distribution like expert_svf
        svf_est = compute_svf(policy_est, P,
                              start_state=svf_start_state,
                              trajectory_length=svf_horizon)
        svf_est /= np.sum(svf_est)

        # 5) Gradient step
        grad = expert_svf - svf_est
        reward_weights += lr * features.T @ grad

        # 6) Debug prints (all quantities are the pre-update ones)
        if print_every and (it+1) % print_every == 0:
            loss_svf = np.linalg.norm(expert_svf - svf_est)  # L2
            grad_norm = np.linalg.norm(grad)
            msg = f"Iter {it+1:05d} | Loss(SVF): {loss_svf:.4f} | GradNorm: {grad_norm:.4f}"

            if true_policy is not None:
                # L1 difference across all states, all actions
                pol_diff = np.sum(np.abs(policy_est - true_policy))
                msg += f" | PolDiff: {pol_diff:.4f}"
            if true_value is not None:
                # L1 difference in value
                val_diff = np.sum(np.abs(V_est - true_value))
                msg += f" | ValDiff: {val_diff:.4f}"
            if true_rewards is not None:
                # L1 difference in reward vector
                rew_diff = np.sum(np.abs(reward_est - true_rewards))
                msg += f" | RewDiff: {rew_diff:.4f}"

            print(msg)
            # Print SVF side by side
            print(f"  Expert SVF: {expert_svf}")
            print(f"  Pred   SVF: {svf_est}")
            # Print Q-values
            print("  Q-values (s x a):")
            for s in range(n_states):
                print(f"    s={s}, Q={Q_values[s]}")
            # Print policy
            print("  Policy (s x a):")
            for s in range(n_states):
                print(f"    s={s}, π={policy_est[s]}")
            # Print reward
            print(f"  Reward: {reward_est}")
            print("")

    return features @ reward_weights

# ---------------------------
# 8) Main Demo
# ---------------------------
if __name__ == "__main__":
    # Fixed seed so the sampled demonstrations and the random initial
    # weights are repeatable
    np.random.seed(0)

    # 1) Ground truth: soft-optimal value and policy for the known reward
    V_true = soft_value_iteration(true_rewards, P)
    policy_true = compute_policy(V_true, true_rewards, P)[0]

    # 2) "Expert" demonstrations sampled from the ground-truth policy
    n_sample_trajectories = 500
    expert_trajectories = generate_soft_optimal_trajectories(
        policy_true,
        P,
        trajectory_length=100,
        n_trajectories=n_sample_trajectories,
    )

    # 3) Recover a reward with MaxEnt IRL (periodic debug prints enabled)
    irl_settings = dict(
        true_rewards=true_rewards,
        true_policy=policy_true,
        true_value=V_true,
        lr=0.01,
        n_iters=10000,
        print_every=1000,
    )
    estimated_rewards = maxent_irl(features, expert_trajectories, P, **irl_settings)

    # Final results
    print("\nFinal Results:")
    for label, vector in (
        ("True Rewards:      ", true_rewards),
        ("Estimated Rewards: ", estimated_rewards),
    ):
        print(label, vector)

Iter 01000 | Loss(SVF): 0.0076 | GradNorm: 0.0076 | PolDiff: 0.9582 | ValDiff: 11.8977 | RewDiff: 1.5524
  Expert SVF: [0.3905 0.3233 0.2862]
  Pred   SVF: [0.3938 0.3171 0.2891]
  Q-values (s x a):
    s=0, Q=[16.5041 16.6004 16.0799]
    s=1, Q=[16.6887 17.1129 17.2093]
    s=2, Q=[17.0966 16.576  17.0002]
  Policy (s x a):
    s=0, π=[0.3629 0.3996 0.2375]
    s=1, π=[0.2375 0.3629 0.3996]
    s=2, π=[0.3996 0.2375 0.3629]
  Reward: [0.2974 0.9063 0.7935]

Iter 02000 | Loss(SVF): 0.0008 | GradNorm: 0.0008 | PolDiff: 0.9798 | ValDiff: 11.9721 | RewDiff: 1.5823
  Expert SVF: [0.3905 0.3233 0.2862]
  Pred   SVF: [0.3911 0.3227 0.2863]
  Q-values (s x a):
    s=0, Q=[16.5035 16.6286 16.0745]
    s=1, Q=[16.7225 17.1515 17.2766]
    s=2, Q=[17.1303 16.5762 17.0052]
  Policy (s x a):
    s=0, π=[0.3591 0.407  0.2339]
    s=1, π=[0.2339 0.3591 0.407 ]
    s=2, π=[0.407  0.2339 0.3591]
  Reward: [0.2825 0.9305 0.7842]

Iter 03000 | Loss(SVF): 0.0001 | GradNorm: 0.0001 | PolDiff: 0.9832 | Va

In [6]:
import numpy as np

np.set_printoptions(precision=4, suppress=True)  # print arrays with 4 decimals, no scientific notation for small values

# -------------------------------------------------
# 1) Triangular MDP Setup
# -------------------------------------------------
# Module-level globals read directly by the functions below.
eps = 0.05      # noise level passed to build_transition_matrix
n_states = 3    # number of states
n_actions = 3   # number of actions (0=left, 1=right, 2=stay)
gamma = 0.9     # discount factor used in the soft Bellman backups

def build_transition_matrix(eps=0.05, num_states=None):
    """
    Noisy transition model of a cyclic MDP with 3 actions.

    P[(s, a)] = probability distribution over next states [p0, p1, ...].

    Action 0=left  (intended target: (s - 1) mod n),
    Action 1=right (intended target: (s + 1) mod n),
    Action 2=stay  (intended target: s).

    With n = 3 states this is exactly the triangular MDP that used to be
    spelled out case by case; other sizes give the analogous ring instead of
    an incomplete table.

    Parameters
    ----------
    eps : float
        Noise mass distributed uniformly over all n states (the target
        included), so the target receives (1 - eps) + eps / n in total.
    num_states : int or None
        Number of states n; None falls back to the module-level ``n_states``.

    Returns
    -------
    dict mapping (state, action) -> np.ndarray of length n.

    Raises
    ------
    ValueError
        If fewer than one state is requested.
    """
    n = n_states if num_states is None else num_states
    if n < 1:
        raise ValueError("the MDP needs at least one state")

    def dist(target):
        d = np.ones(n) * (eps / n)
        d[target] += (1 - eps)
        return d

    # Offset on the ring associated with each action id
    ring_offset = {0: -1, 1: +1, 2: 0}  # left, right, stay

    P = {}
    for s in range(n):
        for a, offset in ring_offset.items():
            P[(s, a)] = dist(target=(s + offset) % n)
    return P

P = build_transition_matrix(eps)  # (state, action) -> distribution over next states

# -------------------------------------------------
# 2) Features & True Rewards
# -------------------------------------------------
features = np.eye(n_states)  # one-hot state features: phi(s) = e_s
true_rewards = np.array([0.0, 0.25, 0.9])  # ground-truth reward, one entry per state

# -------------------------------------------------
# 3) Soft Value Iteration
# -------------------------------------------------
def soft_value_iteration(reward, P, tol=1e-6, max_iter=200, discount=None):
    """
    Returns the value function V(s) under the "soft" (MaxEnt) Bellman update:
    V(s) = log sum_a exp( R(s) + gamma * sum_{s'} P(s'|s,a)*V(s') ).

    Iteration starts at V = 0 and stops once the largest per-state change is
    below ``tol`` or after ``max_iter`` sweeps. The log-sum-exp is computed
    after subtracting the largest Q-value, which keeps ``exp`` from
    overflowing when rewards (and hence values) are large.

    Parameters
    ----------
    reward : sequence of float
        Reward of every state; its length is the number of states.
    P : dict
        Transition model keyed by (state, action); the number of actions is
        read from the largest action id among the keys.
    tol : float
        Convergence tolerance on the sup-norm change of V.
    max_iter : int
        Upper bound on the number of sweeps.
    discount : float or None
        Discount factor; None uses the module-level ``gamma``.

    Returns
    -------
    np.ndarray with one soft value per state.
    """
    disc = gamma if discount is None else discount
    state_count = len(reward)
    action_count = 1 + max(a for (_, a) in P)

    def stable_logsumexp(q):
        # log(sum(exp(q))) evaluated as m + log(sum(exp(q - m)))
        m = np.max(q)
        if not np.isfinite(m):
            # inf - inf would be nan; the limit of the expression is m itself
            return m
        return m + np.log(np.sum(np.exp(q - m)))

    V = np.zeros(state_count)
    for _ in range(max_iter):
        V_prev = V.copy()
        for s in range(state_count):
            Q_sa = np.array(
                [reward[s] + disc * np.dot(P[(s, a)], V_prev)
                 for a in range(action_count)],
                dtype=float,
            )
            V[s] = stable_logsumexp(Q_sa)
        if np.max(np.abs(V - V_prev)) < tol:
            break
    return V

# -------------------------------------------------
# 4) Compute Policy
# -------------------------------------------------
def compute_policy(V, reward, P):
    """
    Build the softmax policy implied by value function V(s) and reward R(s).

    For every state the Q-values R(s) + gamma * E[V(s')] are computed for all
    actions and passed through a max-shifted softmax.
    Returns (policy, Q_values), each of shape (n_states, n_actions).
    """
    def softmax(scores):
        # shifting by the maximum leaves the result unchanged but avoids overflow
        exps = np.exp(scores - np.max(scores))
        return exps / np.sum(exps)

    Q_values = np.zeros((n_states, n_actions))
    policy = np.zeros((n_states, n_actions))
    for s in range(n_states):
        backups = []
        for a in range(n_actions):
            backups.append(reward[s] + gamma * np.dot(P[(s, a)], V))
        Q_values[s] = backups
        policy[s] = softmax(np.array(backups))
    return policy, Q_values

# -------------------------------------------------
# 5) Compute State Visitation Frequencies
# -------------------------------------------------
def compute_svf(policy, P, start_state=0, trajectory_length=5):
    """
    Un-normalized state visitation counts from a single start state.

    The state distribution starts as a point mass on 'start_state'; it is
    added to the running total and then advanced one step, repeated
    'trajectory_length' times. Callers normalize the result if they need a
    distribution.
    """
    def advance(current):
        # One step of the Markov chain induced by 'policy' and P
        following = np.zeros(n_states)
        for s in range(n_states):
            for a in range(n_actions):
                following += current[s] * policy[s, a] * P[(s, a)]
        return following

    state_dist = np.zeros(n_states)
    state_dist[start_state] = 1.0

    svf = np.zeros(n_states)
    for _ in range(trajectory_length):
        svf += state_dist
        state_dist = advance(state_dist)
    return svf

# -------------------------------------------------
# 6) Generate "Expert" Trajectories
# -------------------------------------------------
def generate_soft_optimal_trajectories(policy, P, n_trajectories=100, trajectory_length=5):
    """
    Sample 'n_trajectories' rollouts of 'policy', each a list of
    'trajectory_length' transitions stored as (s, a, s') tuples.
    The first state of every rollout is drawn uniformly at random.
    """
    def sample_one():
        current = np.random.choice(n_states)
        transitions = []
        for _ in range(trajectory_length):
            act = np.random.choice(n_actions, p=policy[current])
            successor = np.random.choice(n_states, p=P[(current, act)])
            transitions.append((current, act, successor))
            current = successor
        return transitions

    return [sample_one() for _ in range(n_trajectories)]

# -------------------------------------------------
# 7) Log-Likelihood of Expert Trajectories
# -------------------------------------------------
def compute_expert_log_likelihood(reward, expert_trajectories, P):
    """
    Sum of log p_theta(tau_i) over the expert trajectories, where the policy
    is the soft-optimal one derived from 'reward':

      p_theta(s0, a0, s1, a1, ...) = prod_t [ pi(a_t | s_t) * P(s_{t+1} | s_t, a_t) ]

    The initial-state distribution is left out (equivalently, treated as a
    constant), and no discounting is applied. A factor with zero probability
    contributes -1e9 instead of -inf.
    """
    # Soft-optimal policy for the candidate reward
    V_est = soft_value_iteration(reward, P)
    policy_est = compute_policy(V_est, reward, P)[0]

    def log_or_floor(prob):
        # log(prob), with a large negative stand-in when prob is not positive
        return np.log(prob) if prob > 0 else -1e9

    total_loglik = 0.0
    for traj in expert_trajectories:
        # each element of traj is one (s, a, s') transition
        logp_tau = 0.0
        for s, a, s_next in traj:
            logp_tau += log_or_floor(policy_est[s, a])      # log pi(a|s)
            logp_tau += log_or_floor(P[(s, a)][s_next])     # log P(s'|s,a)
        total_loglik += logp_tau
    return total_loglik

# -------------------------------------------------
# 8) MaxEnt IRL
# -------------------------------------------------
def maxent_irl(
    features,
    expert_trajectories,
    P,
    anchor_mode=0,           # 0=fix none, 1=fix first, 2=fix first&second
    anchor_values=None,      # e.g. [0.0], or [0.0, 0.25], or None
    reg_lambda=0.0,          # L2 penalty
    gamma=0.9,
    lr=0.01,
    n_iters=10000,
    print_every=1000,
    verbose=True,
    true_rewards=None,
    true_policy=None,
    true_value=None,
    svf_start_state=0,
    svf_horizon=5,
):
    """
    Gradient-based MaxEnt IRL with optional anchoring and an L2 penalty.

    anchor_mode:
      0 -> no anchoring
      1 -> fix R(0)=anchor_values[0]
      2 -> fix R(0)=anchor_values[0], R(1)=anchor_values[1]
    Anchored weights are overwritten after every gradient step, so they never
    move regardless of gradient or penalty.

    Parameters
    ----------
    features : np.ndarray, shape (n_states, d_features)
        Row s is phi(s); the reward is features @ weights.
    expert_trajectories : iterable of sequences of (s, a, s') tuples
        Demonstrations; only the s component is used here.
    P : dict
        Transition model keyed by (state, action).
    anchor_values : sequence or None
        Values for the anchored weights; must provide as many entries as
        ``anchor_mode`` pins.
    reg_lambda : float
        When positive, ``reg_lambda * weights`` is subtracted from the
        gradient (L2 shrinkage).
    gamma : float
        NOTE(review): accepted but not forwarded; the helpers called here
        take no discount argument and, as defined in this file, read the
        module-level ``gamma``.
    lr, n_iters : float, int
        Step size and number of gradient steps.
    print_every : int
        With ``verbose``, print a status line every this many iterations;
        0 or None disables printing.
    verbose : bool
        Enable the status line.
    true_rewards : optional
        When given, the L1 reward error is appended to the status line.
    true_policy, true_value : optional
        Accepted for signature parity; not used inside this function.
    svf_start_state : int
        Start state for the predicted SVF (previously hard-coded to 0).
    svf_horizon : int
        Number of steps for the predicted SVF (previously hard-coded to 5).
        To compare like with like it should reflect how the demonstrations
        were generated.

    Returns
    -------
    np.ndarray
        The final reward weights (callers convert to R(s) with
        ``features @ weights``).

    Raises
    ------
    ValueError
        If the demonstrations contain no transitions.
    """
    n_states, d_features = features.shape
    reward_weights = np.random.uniform(low=-1e-3, high=1e-3, size=d_features)

    def pin_anchors(weights):
        # Overwrite the anchored coordinates in place
        if anchor_mode >= 1:
            weights[0] = anchor_values[0]
        if anchor_mode == 2:
            weights[1] = anchor_values[1]

    pin_anchors(reward_weights)

    # Expert SVF (empirical distribution of the visited source states)
    expert_svf = np.zeros(n_states)
    total_steps = 0
    for traj in expert_trajectories:
        for (s, a, s_next) in traj:
            expert_svf[s] += 1
        total_steps += len(traj)
    if total_steps == 0:
        # Dividing by zero here would silently turn every weight into NaN
        raise ValueError("expert_trajectories contains no transitions")
    expert_svf /= total_steps

    for it in range(n_iters):
        reward_est = features @ reward_weights
        V_est = soft_value_iteration(reward_est, P)
        policy_est, _ = compute_policy(V_est, reward_est, P)

        # Predicted SVF under the current policy, normalized to a distribution
        svf_est = compute_svf(policy_est, P,
                              start_state=svf_start_state,
                              trajectory_length=svf_horizon)
        svf_est /= np.sum(svf_est)

        # IRL gradient
        grad_main = expert_svf - svf_est
        grad_for_weights = features.T @ grad_main

        # L2 penalty
        if reg_lambda > 0:
            grad_for_weights -= reg_lambda * reward_weights

        # Update, then re-pin anchors
        reward_weights += lr * grad_for_weights
        pin_anchors(reward_weights)

        # Optional prints (reward_est is the pre-update reward)
        if verbose and print_every and (it+1) % print_every == 0:
            loss_svf = np.linalg.norm(grad_main)
            msg = f"Iter {it+1:05d} | Loss(SVF): {loss_svf:.4f}"
            if true_rewards is not None:
                rew_diff = np.sum(np.abs(reward_est - true_rewards))
                msg += f" | RewDiff: {rew_diff:.4f}"
            print(msg)

    return reward_weights  # We'll convert to R(s) outside

# -------------------------------------------------
# 9) Main Demo
# -------------------------------------------------
if __name__ == "__main__":
    np.random.seed(0)

    # Sampling budget for the expert demonstrations
    n_sample_trajectories = 500
    trajectory_length = 50

    # -----------------------------------------
    # Print Base Experiment Info
    # -----------------------------------------
    for info_line in (
        "=== Base Experiment Info ===",
        f"States:        {n_states}",
        f"Actions:       {n_actions}",
        f"eps (noise):   {eps}",
        f"gamma:         {gamma}",
        f"True Rewards:  {true_rewards}",
        f"NumTraj:       {n_sample_trajectories}",
        f"TrajLen:       {trajectory_length}",
        "============================\n",
    ):
        print(info_line)

    # Ground-truth soft value / policy, and demonstrations sampled from it
    V_true = soft_value_iteration(true_rewards, P)
    policy_true = compute_policy(V_true, true_rewards, P)[0]
    expert_trajectories = generate_soft_optimal_trajectories(
        policy_true, P,
        n_trajectories=n_sample_trajectories,
        trajectory_length=trajectory_length
    )

    # Grid: anchor_mode in [0,1,2] crossed with reg_lambda in [0,0.05,0.1,0.2,0.7].
    anchor_modes = [0, 1, 2]
    reg_lambdas = [0.0, 0.05, 0.1, 0.2, 0.7]
    # anchor_mode -> (row label, values the anchored rewards are pinned to)
    anchor_setup = {
        0: ("FixNone", None),
        1: ("FixFirst", [0.0]),
        2: ("FixFirstSecond", [0.0, 0.25]),
    }

    # Empirical expert state distribution, used for the SVF-based metrics
    expert_svf_raw = np.zeros(n_states)
    tot_steps = 0
    for traj in expert_trajectories:
        tot_steps += len(traj)
        for step in traj:
            expert_svf_raw[step[0]] += 1
    expert_svf_raw /= tot_steps

    results = {}
    for am in anchor_modes:
        am_str, anc_vals = anchor_setup[am]

        for rl in reg_lambdas:
            # (1) Run IRL -> final reward weights -> final reward
            final_weights = maxent_irl(
                features,
                expert_trajectories,
                P,
                anchor_mode=am,
                anchor_values=anc_vals,
                reg_lambda=rl,
                gamma=gamma,
                lr=0.01,
                n_iters=10000,
                print_every=2000,
                verbose=False,
                true_rewards=true_rewards,
                true_policy=policy_true,
                true_value=V_true
            )
            est_rew = features @ final_weights

            # (2) Policy, value and predicted SVF implied by the estimate
            V_est = soft_value_iteration(est_rew, P)
            policy_est = compute_policy(V_est, est_rew, P)[0]
            svf_est = compute_svf(policy_est, P, 0, trajectory_length=5)
            svf_est /= np.sum(svf_est)
            grad_main = expert_svf_raw - svf_est

            # (3) Store all metrics for this grid cell
            results[(am_str, rl)] = {
                'R': est_rew,
                'PolDiff': np.sum(np.abs(policy_est - policy_true)),
                'RewDiff': np.sum(np.abs(est_rew - true_rewards)),
                'GradNorm': np.linalg.norm(grad_main),
                'ValDiff': np.sum(np.abs(V_est - V_true)),
                'SvfDiff': np.sum(np.abs(svf_est - expert_svf_raw)),
                'LogLik': compute_expert_log_likelihood(est_rew, expert_trajectories, P)
            }

    # Print final table
    print("\n=== Final Results Table ===")
    print("AnchorMode       | reg_lambda |   R(0)    R(1)    R(2)   | RewDiff | PolDiff | GradNorm | ValDiff | SvfDiff | LogLik")
    print("------------------------------------------------------------------------------------------------------------------")

    metric_columns = ('RewDiff', 'PolDiff', 'GradNorm', 'ValDiff', 'SvfDiff', 'LogLik')
    for am in anchor_modes:
        am_str = anchor_setup[am][0]
        for rl in reg_lambdas:
            rvals = results[(am_str, rl)]
            R_0, R_1, R_2 = rvals['R']
            metric_cells = "  ".join(f"| {rvals[col]:.4f}" for col in metric_columns)
            print(f"{am_str:<16} | {rl:<9.2f} | {R_0:7.4f} {R_1:7.4f} {R_2:7.4f} " + metric_cells)
    print("==================================================================================================================\n")


=== Base Experiment Info ===
States:        3
Actions:       3
eps (noise):   0.05
gamma:         0.9
True Rewards:  [0.   0.25 0.9 ]
NumTraj:       500
TrajLen:       50


=== Final Results Table ===
AnchorMode       | reg_lambda |   R(0)    R(1)    R(2)   | RewDiff | PolDiff | GradNorm | ValDiff | SvfDiff | LogLik
------------------------------------------------------------------------------------------------------------------
FixNone          | 0.00      | -2.1748  0.7848  1.3909 | 3.2005  | 1.1914  | 0.0021  | 5.0764  | 0.0035  | -37579.8135
FixNone          | 0.05      | -1.1384  0.2913  0.8471 | 1.2327  | 0.7558  | 0.0724  | 6.6523  | 0.1138  | -31970.5889
FixNone          | 0.10      | -0.8596  0.1807  0.6789 | 1.1500  | 0.5266  | 0.1110  | 9.0881  | 0.1719  | -31019.4487
FixNone          | 0.20      | -0.6029  0.0991  0.5038 | 1.1500  | 0.3358  | 0.1584  | 10.9757  | 0.2412  | -30526.8096
FixNone          | 0.70      | -0.2574  0.0291  0.2284 | 1.1500  | 0.5490  | 0.2417  | 12.

In [14]:
import numpy as np
import random
from scipy.special import logsumexp
from scipy.optimize import minimize

np.set_printoptions(precision=4, suppress=True)  # print arrays with 4 decimals, no scientific notation for small values

# -------------------------------------------------
# 1) Triangular MDP Setup
# -------------------------------------------------
# Module-level globals read directly by the functions below.
eps = 0.05      # noise level passed to build_transition_matrix
n_states = 3    # number of states
n_actions = 3   # number of actions (0=left, 1=right, 2=stay)
gamma = 0.9     # discount factor used in the soft Bellman backups

def build_transition_matrix(eps=0.05, num_states=None):
    """
    Transition model of a cyclic MDP with 3 actions and uniform noise.

    P[(s, a)] = probability distribution over next states [p0, p1, ...].

    Action 0=left, 1=right, 2=stay. On the ring of n states, 'left' aims at
    (s - 1) mod n and 'right' at (s + 1) mod n; for n = 3 this is the
    triangular MDP (left: 0->2, 1->0, 2->1; right: 0->1, 1->2, 2->0).
    The previous case-by-case table was only valid for exactly three states.

    Parameters
    ----------
    eps : float
        Noise mass shared equally by all n states, target included; the
        target's total probability is (1 - eps) + eps / n.
    num_states : int or None
        Number of states n; None uses the module-level ``n_states``.

    Returns
    -------
    dict mapping (state, action) -> np.ndarray of length n.

    Raises
    ------
    ValueError
        If the number of states is below 1.
    """
    n = n_states if num_states is None else num_states
    if n < 1:
        raise ValueError("the MDP needs at least one state")

    def dist(target):
        d = np.ones(n) * (eps / n)
        d[target] += (1 - eps)
        return d

    P = {}
    for s in range(n):
        left, right = (s - 1) % n, (s + 1) % n
        P[(s, 0)] = dist(target=left)   # left
        P[(s, 1)] = dist(target=right)  # right
        P[(s, 2)] = dist(target=s)      # stay
    return P

P = build_transition_matrix(eps)  # (state, action) -> distribution over next states

# -------------------------------------------------
# 2) Features & True Rewards
# -------------------------------------------------
features = np.eye(n_states)  # one-hot state features: phi(s) = e_s
true_rewards = np.array([0.0, 0.25, 0.9])  # ground-truth reward, one entry per state

# -------------------------------------------------
# 3) Soft Value Iteration (Inner Loop)
# -------------------------------------------------
def soft_value_iteration(reward, P, tol=1e-6, max_iter=200):
    """
    Fixed-point iteration of the 'soft' (MaxEnt) Bellman operator:
      V(s) = log sum_a exp( reward[s] + gamma * sum_{s'} P(s'|s,a)*V(s') ).

    Starts from V = 0 and sweeps all states synchronously until the largest
    change is below 'tol' or 'max_iter' sweeps were done. Returns V.
    """
    values = np.zeros(n_states)
    for _ in range(max_iter):
        # Every state is backed up from the same (previous) value vector
        refreshed = np.array([
            logsumexp([
                reward[state] + gamma * np.dot(P[(state, act)], values)
                for act in range(n_actions)
            ])
            for state in range(n_states)
        ])
        largest_change = np.max(np.abs(refreshed - values))
        values = refreshed
        if largest_change < tol:
            break
    return values

def compute_policy(V, reward, P):
    """
    Softmax policy pi(a|s) = exp(Q(s,a)) / sum_{a'} exp(Q(s,a')),
    with Q(s,a) = reward[s] + gamma * sum_{s'} P(s'|s,a) * V(s').
    Returns policy[s,a] and Q_values[s,a].
    """
    # Pass 1: tabulate the Q-values
    Q_values = np.zeros((n_states, n_actions))
    for s in range(n_states):
        Q_values[s] = [
            reward[s] + gamma * np.dot(P[(s, a)], V)
            for a in range(n_actions)
        ]

    # Pass 2: row-wise softmax, shifted by the row maximum for stability
    policy = np.zeros((n_states, n_actions))
    for s, q_row in enumerate(Q_values):
        exp_row = np.exp(q_row - np.max(q_row))
        policy[s] = exp_row / np.sum(exp_row)
    return policy, Q_values

# -------------------------------------------------
# 4) Generating Expert Data
# -------------------------------------------------
def generate_soft_optimal_trajectories(policy, P, n_trajectories=100, trajectory_length=5):
    """
    Draw 'n_trajectories' rollouts of 'policy' from uniformly random start
    states. Each rollout is a list of 'trajectory_length' (s, a, s') tuples.
    """
    def walk(state):
        # Lazily yields one transition per step, starting from 'state'
        for _ in range(trajectory_length):
            action = np.random.choice(n_actions, p=policy[state])
            landed = np.random.choice(n_states, p=P[(state, action)])
            yield (state, action, landed)
            state = landed

    rollouts = []
    for _ in range(n_trajectories):
        rollouts.append(list(walk(np.random.choice(n_states))))
    return rollouts

# -------------------------------------------------
# 5) NFXP (Nested Fixed Point) with L2 Penalty + Anchors
# -------------------------------------------------
def nfxp_objective(unconstrained_params, anchor_mode, anchor_values, reg_lambda, expert_trajectories):
    """
    Penalized negative log-likelihood minimized by the NFXP estimator.

    'unconstrained_params' holds only the free rewards, so its length depends
    on anchor_mode:
      anchor_mode=0 => [r0, r1, r2]  (nothing anchored)
      anchor_mode=1 => [r1, r2]      (r0 = anchor_values[0])
      any other     => [r2]          (r0, r1 = anchor_values[0], anchor_values[1])

    The full reward is rebuilt, the soft-optimal policy is solved for it
    (inner fixed point, using the module-level P), and the result is
      -sum log pi(a|s) over all expert steps + reg_lambda * ||free params||^2.
    Action probabilities below 1e-12 are clamped to 1e-12 before the log.
    Transition probabilities are not part of the likelihood.
    """
    # 1) Full reward vector = anchored entries followed by the free ones
    if anchor_mode == 0:
        r0, r1, r2 = unconstrained_params
    elif anchor_mode == 1:
        r0, (r1, r2) = anchor_values[0], unconstrained_params
    else:
        r0, r1, (r2,) = anchor_values[0], anchor_values[1], unconstrained_params
    reward = np.array([r0, r1, r2])

    # 2) Inner loop: soft-optimal policy for this reward
    policy_est = compute_policy(soft_value_iteration(reward, P), reward, P)[0]

    # 3) Negative log-likelihood of the demonstrated actions
    eps_small = 1e-12
    nll = 0.0
    all_steps = (step for traj in expert_trajectories for step in traj)
    for s, a, s_next in all_steps:
        nll -= np.log(max(policy_est[s, a], eps_small))

    # 4) L2 penalty on the free parameters only
    return nll + reg_lambda * np.sum(unconstrained_params**2)

def estimate_nfxp(expert_trajectories, anchor_mode, anchor_values, reg_lambda):
    """
    Minimize the NFXP objective for one (anchor_mode, anchor_values,
    reg_lambda) setting with box-constrained L-BFGS-B.

    Every free reward starts at 0.1 and is bounded to [-2, 2].
    Returns (est_rew, objective_value) where est_rew is the full
    3-entry reward vector with the anchored entries filled back in.
    """
    # Number of free rewards: 3 with no anchors, 2 with R(0) anchored,
    # 1 with R(0) and R(1) anchored
    n_free = {0: 3, 1: 2}.get(anchor_mode, 1)
    x0 = np.array([0.1] * n_free)
    bnds = [(-2,2)] * n_free

    result = minimize(
        lambda free: nfxp_objective(
            free, anchor_mode, anchor_values, reg_lambda, expert_trajectories
        ),
        x0,
        method='L-BFGS-B',
        bounds=bnds,
        options={'maxiter':300, 'disp':False}
    )

    # Put the anchored values back in front of the optimized ones
    if anchor_mode == 0:
        r0, r1, r2 = result.x
    elif anchor_mode == 1:
        r0, (r1, r2) = anchor_values[0], result.x
    else:
        r0, r1, (r2,) = anchor_values[0], anchor_values[1], result.x

    return np.array([r0, r1, r2]), result.fun

# -------------------------------------------------
# 6) Evaluate NFXP for All Table Scenarios
# -------------------------------------------------
def compute_expert_log_likelihood(reward, expert_trajectories, P):
    """
    Log-likelihood of the demonstrated actions under the soft-optimal policy
    for 'reward': the sum of log(pi(a|s) + 1e-12) over every expert step.
    Transition probabilities are not included.
    """
    eps_small = 1e-12  # keeps log() finite for zero-probability actions

    # Rebuild the policy implied by the reward
    V_est = soft_value_iteration(reward, P)
    policy_est = compute_policy(V_est, reward, P)[0]

    total_loglik = 0.0
    every_step = (step for traj in expert_trajectories for step in traj)
    for s, a, s_next in every_step:
        total_loglik += np.log(policy_est[s, a] + eps_small)
    return total_loglik

def compute_svf(policy, P, start_state=0, trajectory_length=5):
    """
    Accumulate expected state-visitation counts over a finite horizon.

    Starting from a point mass on `start_state`, the state distribution is
    pushed forward through `policy` and `P` one step at a time. The
    distributions at steps 0 .. trajectory_length-1 are summed and returned
    without normalization.

    Parameters
    ----------
    policy : 2-D array indexed as policy[s, a]; its shape defines the number
        of states and actions (no module-level constants are consulted).
    P : mapping from (s, a) to a length-n_states array of next-state
        probabilities.
    start_state : index of the state holding all initial mass.
    trajectory_length : number of time steps to accumulate.

    Returns
    -------
    np.ndarray of length n_states holding the summed visitation mass.
    """
    policy = np.asarray(policy)
    # Take the MDP size from the argument so any (n_states, n_actions) works.
    num_states, num_actions = policy.shape

    d_t = np.zeros(num_states)
    d_t[start_state] = 1.0
    svf = np.zeros(num_states)
    for _ in range(trajectory_length):
        # Count the current step before propagating to the next one.
        svf += d_t
        next_d_t = np.zeros(num_states)
        for s in range(num_states):
            for a in range(num_actions):
                next_d_t += d_t[s] * policy[s, a] * P[(s, a)]
        d_t = next_d_t
    return svf


if __name__ == "__main__":
    # Driver: sample expert data from the true reward, run estimate_nfxp for
    # every (anchor_mode, reg_lambda) pair, and print a comparison table.

    # 1) Basic Info
    print("=== NxFP Demo Over Table of (anchor_mode, reg_lambda) ===")
    print(f"States={n_states}, Actions={n_actions}, eps={eps}, gamma={gamma}")
    print("True Rewards:", true_rewards)
    n_sample_trajectories = 5000
    trajectory_length = 50
    print(f"NumTraj={n_sample_trajectories}, TrajLen={trajectory_length}\n")

    # 2) Generate Expert Data
    V_true = soft_value_iteration(true_rewards, P)
    policy_true, _ = compute_policy(V_true, true_rewards, P)
    expert_trajectories = generate_soft_optimal_trajectories(
        policy_true, P,
        n_trajectories=n_sample_trajectories,
        trajectory_length=trajectory_length
    )

    # Empirical state-visitation frequencies of the expert data (s only).
    # They depend on the data alone, so they are computed once here instead
    # of once per (anchor_mode, reg_lambda) combination.
    expert_svf_raw = np.zeros(n_states)
    tot_steps = 0
    for traj in expert_trajectories:
        for (s, a, s_next) in traj:
            expert_svf_raw[s] += 1
        tot_steps += len(traj)
    expert_svf_raw /= tot_steps

    # 3) Evaluate all combinations of anchor_mode and reg_lambda
    anchor_modes = [0, 1, 2]          # 0=FixNone, 1=FixFirst, 2=FixFirstSecond
    reg_lambdas = [0.0, 0.05, 0.1, 0.2, 0.7]

    # Label and anchor values per mode (single source for both loops below):
    #  mode=1 => anchor R(0)=0.0
    #  mode=2 => anchor R(0)=0.0, R(1)=0.25
    anchor_setups = {
        0: ("FixNone", None),
        1: ("FixFirst", [0.0]),
        2: ("FixFirstSecond", [0.0, 0.25]),
    }
    results = {}

    for am in anchor_modes:
        am_str, anc_vals = anchor_setups[am]

        for rl in reg_lambdas:
            # (1) Run NxFP estimation
            est_rew, final_obj = estimate_nfxp(expert_trajectories, am, anc_vals, rl)

            # (2) Evaluate
            # 2a) Policy: L1 distance between estimated and true policies
            V_est = soft_value_iteration(est_rew, P)
            policy_est, _ = compute_policy(V_est, est_rew, P)
            pol_diff = np.sum(np.abs(policy_est - policy_true))

            # 2b) RewDiff: L1 distance between estimated and true rewards
            rew_diff = np.sum(np.abs(est_rew - true_rewards))

            # 2c) SVF difference vs. the "raw" expert distribution (s-only).
            #     The model side is a 5-step distribution from state 0, while
            #     the expert side comes from the sampled rollouts, so the two
            #     are not expected to match exactly; it is tracked anyway.
            svf_est = compute_svf(policy_est, P, 0, trajectory_length=5)
            svf_est /= np.sum(svf_est)

            svf_diff = np.sum(np.abs(svf_est - expert_svf_raw))
            grad_norm = np.linalg.norm(expert_svf_raw - svf_est)

            # 2d) Value difference
            val_diff = np.sum(np.abs(V_est - V_true))

            # 2e) Log-likelihood of expert demos
            loglik = compute_expert_log_likelihood(est_rew, expert_trajectories, P)

            # Store
            results[(am_str, rl)] = {
                'R': est_rew,
                'Obj': final_obj,
                'PolDiff': pol_diff,
                'RewDiff': rew_diff,
                'GradNorm': grad_norm,
                'ValDiff': val_diff,
                'SvfDiff': svf_diff,
                'LogLik': loglik
            }

    # 4) Print table of results
    print("=== NxFP Results Table ===")
    print("AnchorMode       | reg_lambda |    R(0)    R(1)    R(2)   | RewDiff | PolDiff | GradNorm | ValDiff | SvfDiff | LogLik     | -LL(Obj)")
    print("--------------------------------------------------------------------------------------------------------------")

    for am in anchor_modes:
        am_str = anchor_setups[am][0]

        for rl in reg_lambdas:
            rvals = results[(am_str, rl)]
            R_0, R_1, R_2 = rvals['R']
            print(f"{am_str:<16} | {rl:<9.2f} | {R_0:8.4f} {R_1:8.4f} {R_2:8.4f} "
                  f"| {rvals['RewDiff']:.4f} "
                  f"| {rvals['PolDiff']:.4f} "
                  f"| {rvals['GradNorm']:.4f} "
                  f"| {rvals['ValDiff']:.4f} "
                  f"| {rvals['SvfDiff']:.4f} "
                  f"| {rvals['LogLik']:.2f} "
                  f"| {rvals['Obj']:.2f}")
    print("====================================================================================\n")


=== NxFP Demo Over Table of (anchor_mode, reg_lambda) ===
States=3, Actions=3, eps=0.05, gamma=0.9
True Rewards: [0.   0.25 0.9 ]
NumTraj=5000, TrajLen=50

=== NxFP Results Table ===
AnchorMode       | reg_lambda |    R(0)    R(1)    R(2)   | RewDiff | PolDiff | GradNorm | ValDiff | SvfDiff | LogLik     | -LL(Obj)
--------------------------------------------------------------------------------------------------------------
FixNone          | 0.00      |  -0.4380  -0.1861   0.4639 | 1.3102 | 0.0017 | 0.1875 | 13.0971 | 0.3032 | -260695.14 | 260695.14
FixNone          | 0.05      |  -0.4394  -0.1870   0.4630 | 1.3134 | 0.0021 | 0.1874 | 13.1281 | 0.3031 | -260695.15 | 260695.17
FixNone          | 0.10      |  -0.4303  -0.1795   0.4714 | 1.2883 | 0.0016 | 0.1875 | 12.8774 | 0.3033 | -260695.15 | 260695.19
FixNone          | 0.20      |  -0.4366  -0.1852   0.4660 | 1.3058 | 0.0023 | 0.1874 | 13.0498 | 0.3031 | -260695.16 | 260695.25
FixNone          | 0.70      |  -0.4362  -0.1853   0.4659