In [6]:
import numpy as np
import torch

############################################################################
# 1) Triangular MDP Setup
############################################################################
n_states = 3
n_actions = 3
gamma = 0.9
eps = 0.05

def build_transition_matrix(eps=0.05):
    """
    P[s, a, s'] = probability of going to s' from s when taking action a.
    Actions: 0=left, 1=right, 2=stay.
    """
    P = np.zeros((n_states, n_actions, n_states))

    def dist(target):
        d = np.ones(n_states) * (eps / n_states)
        d[target] += (1.0 - eps)
        return d
    
    for s in range(n_states):
        if s == 0:
            P[s, 0] = dist(2)  # left from 0 -> 2
            P[s, 1] = dist(1)  # right from 0 -> 1
            P[s, 2] = dist(0)  # stay in 0
        elif s == 1:
            P[s, 0] = dist(0)  # left from 1 -> 0
            P[s, 1] = dist(2)  # right from 1 -> 2
            P[s, 2] = dist(1)  # stay in 1
        else:  # s == 2
            P[s, 0] = dist(1)  # left from 2 -> 1
            P[s, 1] = dist(0)  # right from 2 -> 0
            P[s, 2] = dist(2)  # stay in 2

    return P

P_np = build_transition_matrix(eps)
P = torch.tensor(P_np, dtype=torch.float)


############################################################################
# 2) True Rewards & Generate Expert Data
############################################################################
R_true_full = torch.tensor([0.5, 0.25, 0.1])  # one reward per state

def compute_softmax_policy(Q):
    """
    pi[s,a] = softmax(Q[s,a]) along actions
    """
    shifted = Q - Q.max(dim=1, keepdim=True).values  # stable
    expQ = shifted.exp()
    denom = expQ.sum(dim=1, keepdim=True)
    return expQ / denom

def generate_expert_trajectories(Q_true, n_trajectories=100, traj_len=5):
    """
    Sample states/actions from the policy induced by Q_true.
    Return list of (s,a) pairs for all trajectories.
    """
    policy = compute_softmax_policy(Q_true)
    transitions = []
    for _ in range(n_trajectories):
        s = np.random.choice(n_states)  # random start
        for _step in range(traj_len):
            a = np.random.choice(n_actions, p=policy[s].detach().numpy())
            transitions.append((s, a))
            # next state
            s_next = np.random.choice(n_states, p=P_np[s,a])
            s = s_next
    return transitions

# We'll do a quick approximate "soft" Q-iteration to get Q_true for demonstration:
Q_true_full = torch.zeros(n_states, n_actions)
for _ in range(200):
    Qnew = Q_true_full.clone()
    for s in range(n_states):
        for a in range(n_actions):
            sum_sprime = 0.0
            for s_next in range(n_states):
                # V(s_next) = logsumexp(Q[s_next,:])
                V_snext = torch.logsumexp(Q_true_full[s_next], dim=0)
                sum_sprime += P[s,a,s_next] * V_snext
            Qnew[s,a] = R_true_full[s] + gamma * sum_sprime
    if torch.max(torch.abs(Qnew - Q_true_full)) < 1e-6:
        break
    Q_true_full = Qnew

expert_data = generate_expert_trajectories(Q_true_full, n_trajectories=500, traj_len=10)
policy_true = compute_softmax_policy(Q_true_full)
V_true = torch.logsumexp(Q_true_full, dim=1)  # V(s)=logsumexp(Q(s,a))

############################################################################
# 3) Single-Loop IRL: Fix R(0)=0.5, L2 Regularization
############################################################################

# -------------------
# (a) Setup Learnable Parameters
# -------------------
# We fix the first state's reward = 0.5 (the true value),
# and only learn the latter two states' rewards.
R_param = torch.zeros(2, requires_grad=True)  # for states 1,2
Q = torch.zeros(n_states, n_actions, requires_grad=True)  

optimizer = torch.optim.Adam([R_param, Q], lr=0.05)
alpha_bellman = 10.0   # weight on Bellman error
alpha_l2      = 0.001  # L2 regularization strength
max_iters     = 1001
print_every   = 100

def full_reward_vector(R_param):
    """
    Reconstruct R(s) where R(0)=0.5, R(1)=R_param[0], R(2)=R_param[1].
    """
    # First state is pinned at 0.5
    return torch.cat([torch.tensor([0.5]), R_param], dim=0)

for it in range(max_iters):
    optimizer.zero_grad()

    # 1) Reconstruct the full reward vector
    R_learn = full_reward_vector(R_param)  # shape [3]

    # 2) Compute policy from current Q
    policy = compute_softmax_policy(Q)

    # 3) Negative Log-Likelihood of expert data
    nll = 0.0
    for (s_exp, a_exp) in expert_data:
        nll = nll - torch.log(policy[s_exp, a_exp] + 1e-10)

    # 4) Soft Bellman Consistency: Q(s,a) ~ R(s) + gamma * Σ P(s,a,s') * logsumexp(Q[s',:])
    bellman_error = 0.0
    for s in range(n_states):
        for a in range(n_actions):
            sum_sprime = 0.0
            for s_next in range(n_states):
                V_snext = torch.logsumexp(Q[s_next], dim=0)
                sum_sprime += P[s,a,s_next] * V_snext
            target = R_learn[s] + gamma * sum_sprime
            bellman_error += (Q[s,a] - target)**2

    # 5) L2 regularization on Q and learnable R
    l2_reg = Q.pow(2).sum() + R_param.pow(2).sum()

    # 6) Total Loss
    loss = nll + alpha_bellman * bellman_error + alpha_l2 * l2_reg

    # 7) Backprop + Update
    loss.backward()
    optimizer.step()

    # ----------------
    # Debugging Info
    # ----------------
    if it % print_every == 0:
        with torch.no_grad():
            R_current = full_reward_vector(R_param)
            policy_diff = (policy - policy_true).abs().sum().item()
            V_current = torch.logsumexp(Q, dim=1)
            val_diff = (V_current - V_true).abs().sum().item()
            rew_diff = (R_current - R_true_full).abs().sum().item()

            print(f"Iter={it:04d}, "
                  f"Loss={loss.item():.4f}, "
                  f"NLL={nll.item():.4f}, "
                  f"Bellman={bellman_error.item():.4f}, "
                  f"L2={l2_reg.item():.4f}")
            print(f"  R={R_current.detach().numpy()}, RewDiff={rew_diff:.4f}")
            print(f"  PolDiff={policy_diff:.4f}, ValDiff={val_diff:.4f}")
            print("")

# ----------------
# Final Results
# ----------------
R_final = full_reward_vector(R_param).detach().numpy()
print("==== Final Results ====")
print(f"Learned Rewards = {R_final}")
print(f"True Rewards    = {R_true_full.numpy()}")
print("Learned Q =")
print(Q.detach().numpy())


Iter=0000, Loss=5618.3013, NLL=5493.1523, Bellman=12.5149, L2=0.0000
  R=[ 0.5  -0.05 -0.05], RewDiff=0.4500
  PolDiff=0.3829, ValDiff=38.5110

Iter=0100, Loss=5456.2183, NLL=5432.2593, Bellman=2.3951, L2=7.6653
  R=[ 0.5       -1.3857818 -1.475164 ], RewDiff=3.2109
  PolDiff=0.0947, ValDiff=38.4013

Iter=0200, Loss=5455.5454, NLL=5432.2168, Bellman=2.3321, L2=7.6924
  R=[ 0.5       -1.3609862 -1.4494882], RewDiff=3.1605
  PolDiff=0.0931, ValDiff=37.8794

Iter=0300, Loss=5454.6792, NLL=5432.1943, Bellman=2.2476, L2=8.6416
  R=[ 0.5       -1.3328481 -1.4214159], RewDiff=3.1043
  PolDiff=0.0926, ValDiff=37.2007

Iter=0400, Loss=5453.6729, NLL=5432.1694, Bellman=2.1493, L2=10.9694
  R=[ 0.5       -1.2993728 -1.3880069], RewDiff=3.0374
  PolDiff=0.0919, ValDiff=36.3942

Iter=0500, Loss=5452.5601, NLL=5432.1421, Bellman=2.0403, L2=15.1909
  R=[ 0.5       -1.2613928 -1.3501029], RewDiff=2.9615
  PolDiff=0.0912, ValDiff=35.4791

Iter=0600, Loss=5451.4160, NLL=5432.1602, Bellman=1.9234, L2=21.