In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
import matplotlib.patches as mpatches

# --- 1. GMM Parameters and Function ---
# Target density: a 3-component Gaussian mixture (weights w, means mu,
# covariances Sigma), evaluated through frozen scipy distributions.
w = [0.5, 0.2, 0.3]
mu = [
    np.array([0.35, 0.38]),
    np.array([0.68, 0.25]),
    np.array([0.56, 0.64]),
]
Sigma = [
    np.array([[0.01, 0.004], [0.004, 0.01]]),
    np.array([[0.005, -0.003], [-0.003, 0.005]]),
    np.array([[0.008, 0.0], [0.0, 0.004]]),
]
gaussians = [multivariate_normal(mean=mean, cov=cov) for mean, cov in zip(mu, Sigma)]


def gmm_pdf(x):
    """Evaluate the mixture density at x.

    x is passed straight to each component's ``pdf``, so a single point gives a
    scalar and an (N, 2) array of points gives one density value per row.
    """
    return sum(weight * component.pdf(x) for weight, component in zip(w, gaussians))

# --- 2. Fourier Basis and Target Coefficients ---
L = 1.0  # side length of the square domain [0, L] x [0, L]
N_basis = 7  # number of cosine modes per dimension


def get_k_vectors(N_basis, dim=2):
    """Return every integer wave-number vector with entries in [0, N_basis).

    Parameters
    ----------
    N_basis : int
        Number of modes per dimension.
    dim : int, optional
        Length of each wave-number vector (default 2).

    Returns
    -------
    numpy.ndarray
        Integer array of shape (N_basis**dim, dim) in lexicographic order with
        the last component varying fastest: [0, 0], [0, 1], ..., [1, 0], ...
    """
    # np.indices builds one index grid per axis; flattening each grid in C
    # order and transposing enumerates the same order as nested for-loops.
    return np.indices((N_basis,) * dim).reshape(dim, -1).T


k_vectors = get_k_vectors(N_basis)
num_k = len(k_vectors)

def F_k(x, k_vec):
    """Evaluate the unnormalized cosine basis function for wave vector k_vec.

    F_k(x) = prod_d cos(pi * k_d * x_d / L), with x first clipped to [0, L].

    Parameters
    ----------
    x : array_like
        A single point of shape (d,) or a stack of points of shape (N, d).
    k_vec : array_like, shape (d,)
        Integer wave-number vector.

    Returns
    -------
    float or numpy.ndarray
        A scalar for a single point, or an array of shape (N,) with one value
        per point.
    """
    x = np.clip(x, 0, L)
    # Reduce over the coordinate axis only. Without axis=-1, a stack of
    # points would collapse into one product over all points and coordinates.
    return np.prod(np.cos(np.pi * k_vec * x / L), axis=-1)

def grad_F_k(x, k_vec):
    """Gradient of F_k at a single 2-D point x, after clipping x to [0, L].

    Returns a float array of shape (2,). A component whose wave number is
    zero is left at exactly 0.
    """
    xc = np.clip(x, 0, L)
    gradient = np.zeros_like(xc, dtype=float)
    k1, k2 = k_vec
    x1, x2 = xc
    pi_L = np.pi / L
    cos_first = np.cos(k1 * pi_L * x1)
    cos_second = np.cos(k2 * pi_L * x2)
    # d/dx1 and d/dx2 of cos(k1*pi*x1/L) * cos(k2*pi*x2/L)
    if k1 != 0:
        gradient[0] = -k1 * pi_L * np.sin(k1 * pi_L * x1) * cos_second
    if k2 != 0:
        gradient[1] = -k2 * pi_L * cos_first * np.sin(k2 * pi_L * x2)
    return gradient

# Uniform grid over the domain, used as a Riemann-sum quadrature rule.
grid_size = 50
x_vals = np.linspace(0, L, grid_size)
y_vals = np.linspace(0, L, grid_size)
xx, yy = np.meshgrid(x_vals, y_vals)
grid_points = np.vstack([xx.ravel(), yy.ravel()]).T  # shape (grid_size**2, 2)
pdf_values = gmm_pdf(grid_points)
dA = (L / (grid_size - 1)) ** 2  # area of one grid cell

# h_k: numerical L2 norm of each basis function (falls back to L when the
# computed norm is below 1e-6); phi_k: normalized coefficients of the target.
h_k = np.ones(num_k)
phi_k = np.zeros(num_k)
for idx, k_vec in enumerate(k_vectors):
    basis_vals = F_k(grid_points, k_vec)
    h_val = np.sqrt(np.sum(basis_vals ** 2) * dA)
    h_k[idx] = L if h_val < 1e-6 else h_val
    phi_k[idx] = np.sum(pdf_values * (basis_vals / h_k[idx])) * dA

# Spectral weights (1 + |k|^2)^(-s_lambda); the first entry (k = [0, 0]) is
# excluded from the ergodic metric.
s_lambda = 1.5
lambda_k = (1.0 + np.linalg.norm(k_vectors, axis=1) ** 2) ** (-s_lambda)
lambda_k[0] = 0

# --- 3. iLQR Parameters & System Dynamics ---
dt = 0.1
T_horizon = 10.0
# Derive the step count by rounding rather than using np.arange(0, T, dt):
# with a non-integer step, arange's length can be off by one due to rounding.
tsteps = int(round(T_horizon / dt))
tlist = np.arange(tsteps) * dt
x0 = np.array([0.3, 0.3])  # initial state
# Seeded generator so the random initial controls (and therefore the whole
# optimization run) are reproducible.
rng_seed = 0
rng = np.random.default_rng(rng_seed)
init_u_traj = rng.standard_normal((tsteps, 2)) * 0.5

# Cost matrices & Scaling
R_u = np.diag([0.001, 0.001])  # quadratic control-effort weight
Ergodic_Scale = 500.0  # multiplier on the ergodic term of the cost
n_state = 2
n_ctrl = 2

# Dynamics: single integrator x_{t+1} = clip(x_t + u_t * dt, 0, L).
# A and B are the Jacobians of the unclipped update. Where the clip in step()
# is active, the true Jacobian differs, so the linear model holds only in the
# interior of the domain.
A = np.eye(n_state)  # d x_{t+1} / d x_t, shape (n_state, n_state)
B = np.eye(n_state, n_ctrl) * dt  # d x_{t+1} / d u_t, shape (n_state, n_ctrl)


def dyn(xt, ut):
    """Continuous-time dynamics xdot = u (the control is the velocity)."""
    return ut


def step(xt, ut):
    """Advance one explicit Euler step of length dt and clip the state to [0, L]."""
    return np.clip(xt + dt * dyn(xt, ut), 0, L)

def traj_sim(x0, ulist):
    """Roll the dynamics forward from x0 under the control sequence ulist.

    Returns an array with one more row than ulist: row 0 is x0 and row t+1 is
    step(row t, ulist[t]).
    """
    horizon = ulist.shape[0]
    states = np.zeros((horizon + 1, n_state))
    states[0] = x0.copy()
    for t, u_t in enumerate(ulist):
        states[t + 1] = step(states[t], u_t)
    return states

# --- 4. Ergodic Cost Function & Gradients ---
def calculate_ck(x_traj):
    """Return the time-averaged normalized Fourier coefficients of x_traj.

    c_k = (1 / T) * sum_t F_k(x_t) / h_k, where T is the number of rows of
    x_traj. The result has one entry per wave vector in k_vectors.
    """
    def _coeffs_at(point):
        # Normalized basis values of every mode at one trajectory point.
        return np.array([F_k(point, kv) / hk for kv, hk in zip(k_vectors, h_k)])

    total = sum((_coeffs_at(p) for p in x_traj), np.zeros(num_k))
    return total / x_traj.shape[0]

def J_ergodic(x_traj, u_traj):
    """Total cost: scaled ergodic metric plus dt-weighted quadratic control effort.

    The ergodic term uses x_traj[:-1], i.e. as many states as there are
    controls; the final state of the rollout is not included.
    """
    coeff_err = calculate_ck(x_traj[:-1]) - phi_k
    erg_cost = np.sum(lambda_k * coeff_err ** 2) * Ergodic_Scale
    effort = np.sum([u.T @ R_u @ u for u in u_traj]) * dt
    return erg_cost + effort

def get_cost_grads(x_traj, u_traj):
    """Calculates l_x, l_u, l_xx, l_uu, l_ux for all time steps.

    Derivatives of J_ergodic with respect to x_traj[t] and u_traj[t] for
    t in range(N), where N = len(u_traj):

    * l_x[t]  = Ergodic_Scale * sum_k 2 * lambda_k * (c_k - phi_k)
                * grad F_k(x_t) / (h_k * N)
    * l_u[t]  = 2 * dt * R_u @ u_t   (J_ergodic weights the control cost by dt)
    * l_uu[t] = 2 * dt * R_u
    * l_xx, l_ux are zero (no second-order model of the ergodic term).

    Returns arrays of shape (N, n_state), (N, n_ctrl), (N, n_state, n_state),
    (N, n_ctrl, n_ctrl) and (N, n_ctrl, n_state).
    """
    N = u_traj.shape[0]
    l_x = np.zeros((N, n_state))
    l_u = np.zeros((N, n_ctrl))
    l_xx = np.zeros((N, n_state, n_state))
    l_uu = np.zeros((N, n_ctrl, n_ctrl))
    l_ux = np.zeros((N, n_ctrl, n_state))
    if N == 0:
        # No controls: nothing to differentiate (and no 1/N to take).
        return l_x, l_u, l_xx, l_uu, l_ux

    ck = calculate_ck(x_traj[:-1])
    # Per-mode factor multiplying grad F_k / h_k; the same for every t.
    mode_weight = 2.0 * lambda_k * (ck - phi_k) * (1.0 / N) * Ergodic_Scale

    for t in range(N):
        grad_Fk_vals = np.array([grad_F_k(x_traj[t], k) / h for k, h in zip(k_vectors, h_k)])
        l_x[t] = mode_weight @ grad_Fk_vals
        l_u[t] = 2.0 * dt * (R_u @ u_traj[t])
        l_uu[t] = 2.0 * dt * R_u

    return l_x, l_u, l_xx, l_uu, l_ux

# --- 5. Riccati iLQR Implementation ---

def backward_pass(l_x, l_u, l_xx, l_uu, l_ux, mu_reg):
    """Performs the Riccati backward pass.

    Parameters
    ----------
    l_x, l_u, l_xx, l_uu, l_ux : numpy.ndarray
        Per-time-step cost derivatives (as returned by get_cost_grads).
    mu_reg : float
        Initial Levenberg-Marquardt style regularization added to Q_uu.

    Returns
    -------
    (k_list, K_list, mu_reg)
        Feedforward gains (N, n_ctrl), feedback gains (N, n_ctrl, n_state)
        and the regularization actually used. If Q_uu cannot be made positive
        definite, the gains are filled with NaN so the caller's isfinite check
        treats the pass as failed.
    """
    N = l_x.shape[0]
    k_list = np.zeros((N, n_ctrl))
    K_list = np.zeros((N, n_ctrl, n_state))

    V_x = np.zeros(n_state)  # No terminal cost gradient
    V_xx = np.zeros((n_state, n_state))  # No terminal cost hessian
    max_reg_doublings = 30

    for t in range(N - 1, -1, -1):
        Q_x = l_x[t] + A.T @ V_x
        Q_u = l_u[t] + B.T @ V_x
        Q_xx = l_xx[t] + A.T @ V_xx @ A
        Q_uu = l_uu[t] + B.T @ V_xx @ B
        Q_ux = l_ux[t] + B.T @ V_xx @ A

        # Increase regularization until Q_uu + mu*I is positive definite
        # (verified by Cholesky), so the resulting step is a descent direction.
        is_pd = False
        for _ in range(max_reg_doublings):
            Q_uu_reg = Q_uu + np.eye(n_ctrl) * mu_reg
            try:
                np.linalg.cholesky(Q_uu_reg)
                is_pd = True
                break
            except np.linalg.LinAlgError:
                print(f"Warning: Q_uu not positive definite at t={t}. Increasing regularization.")
                mu_reg = 2.0 * mu_reg if mu_reg > 0 else 1e-6
        if not is_pd:
            k_list.fill(np.nan)
            K_list.fill(np.nan)
            return k_list, K_list, mu_reg

        # Solve Q_uu_reg @ [k_t, K_t] = -[Q_u, Q_ux] in one call instead of
        # forming an explicit inverse.
        gains = np.linalg.solve(Q_uu_reg, -np.column_stack((Q_u, Q_ux)))
        k_t = gains[:, 0]
        K_t = gains[:, 1:]

        k_list[t] = k_t
        K_list[t] = K_t

        # Update value function derivatives (uses the unregularized Q_uu).
        V_x = Q_x + K_t.T @ Q_uu @ k_t + K_t.T @ Q_u + Q_ux.T @ k_t
        V_xx = Q_xx + K_t.T @ Q_uu @ K_t + K_t.T @ Q_ux + Q_ux.T @ K_t
        V_xx = 0.5 * (V_xx + V_xx.T)  # remove round-off asymmetry

    return k_list, K_list, mu_reg


def forward_pass(x0, u_traj, k_list, K_list, alpha):
    """Performs forward simulation with line search.

    Tries step sizes gamma = alpha**0, alpha**1, ..., alpha**9 on the
    feedforward gains and accepts the first rollout whose J_ergodic is
    strictly lower than the current trajectory's. Returns
    (x_traj, u_traj, accepted); on failure, the unchanged trajectory and False.
    """
    x_ref = traj_sim(x0, u_traj)
    ref_cost = J_ergodic(x_ref, u_traj)
    horizon = u_traj.shape[0]

    for power in range(10):  # Try up to 10 powers of alpha
        gamma = alpha ** power
        x_try = np.zeros_like(x_ref)
        u_try = np.zeros_like(u_traj)
        x_try[0] = x0

        for t in range(horizon):
            # Feedforward step scaled by gamma plus feedback on the state deviation.
            u_try[t] = u_traj[t] + (gamma * k_list[t] + K_list[t] @ (x_try[t] - x_ref[t]))
            x_try[t + 1] = step(x_try[t], u_try[t])

        cost_new = J_ergodic(x_try, u_try)
        if cost_new < ref_cost:
            print(f"  Line search: gamma = {gamma:.4f}, New Cost = {cost_new:.4f}")
            return x_try, u_try, True

    print("  Line search failed.")
    return x_ref, u_traj, False

# --- 6. iLQR Iteration Loop ---
print("Starting Riccati iLQR iterations...")
u_traj = init_u_traj.copy()
n_iters = 30  # Number of iLQR iterations
mu_reg = 1e-6  # Initial regularization
alpha = 0.5  # Line search base
conv_tol = 1e-3  # Stop once an accepted step improves the cost by less than this

cost_history = []  # cost at the start of each iteration (+ final cost, appended below)

for iter_num in range(n_iters):
    x_traj = traj_sim(x0, u_traj)
    current_cost = J_ergodic(x_traj, u_traj)
    cost_history.append(current_cost)
    print(f"Iter: {iter_num:02d}, Cost: {current_cost:.6f}, Mu: {mu_reg:.2e}")

    l_x, l_u, l_xx, l_uu, l_ux = get_cost_grads(x_traj, u_traj)

    # Try backward pass, increasing mu_reg if it yields non-finite gains
    success_bp = False
    for _ in range(5):
        k_list, K_list, mu_reg_new = backward_pass(l_x, l_u, l_xx, l_uu, l_ux, mu_reg)
        if np.all(np.isfinite(k_list)) and np.all(np.isfinite(K_list)):
            mu_reg = mu_reg_new
            success_bp = True
            break
        print("Backward pass failed, increasing regularization.")
        mu_reg *= 10.0

    if not success_bp:
        print("Backward pass failed repeatedly. Stopping.")
        break

    x_traj_new, u_traj_new, success_fp = forward_pass(x0, u_traj, k_list, K_list, alpha)

    if success_fp:
        # forward_pass only accepts strict decreases, so this improvement is
        # positive and measures the progress made by this iteration's step.
        new_cost = J_ergodic(x_traj_new, u_traj_new)
        u_traj = u_traj_new
        if current_cost - new_cost < conv_tol:
            print("Converged.")
            break
        mu_reg = max(1e-8, mu_reg * 0.7)  # Decrease regularization, keep a minimum
    else:
        mu_reg *= 2.0  # Increase regularization if forward pass fails
        if mu_reg > 1e6:
            print("Regularization too high. Stopping.")
            break

# Record the cost of the trajectory actually returned, so the last accepted
# update appears in the plotted history (skipped if it is already the last entry).
final_cost = J_ergodic(traj_sim(x0, u_traj), u_traj)
if not cost_history or final_cost != cost_history[-1]:
    cost_history.append(final_cost)

print("iLQR complete.")

# --- 7. Plot Final Results ---
x_final_traj = traj_sim(x0, u_traj)

fig = plt.figure(figsize=(18, 5))
ax1 = fig.add_subplot(131)
ax1.contourf(xx, yy, pdf_values.reshape(xx.shape), levels=20, cmap='Reds', alpha=0.8)
ax1.plot(x_final_traj[:, 0], x_final_traj[:, 1], 'k-', lw=1.5, label='Trajectory')
ax1.plot(x0[0], x0[1], 'go', markersize=10, label='Start')
ax1.plot(x_final_traj[-1, 0], x_final_traj[-1, 1], 'b*', markersize=12, label='End')
for m, cov in zip(mu, Sigma):
    ax1.plot(m[0], m[1], 'bo', markersize=8)
    # Align the ellipse with the covariance eigenvectors so correlated
    # components are drawn tilted. Each principal axis spans a total of
    # 3 standard deviations (+/- 1.5 sigma) along its direction.
    eigvals, eigvecs = np.linalg.eigh(cov)  # eigenvalues in ascending order
    major_dir = eigvecs[:, 1]
    angle_deg = np.degrees(np.arctan2(major_dir[1], major_dir[0]))
    ellipse = mpatches.Ellipse(m, width=3 * np.sqrt(eigvals[1]), height=3 * np.sqrt(eigvals[0]),
                               angle=angle_deg, fill=False, color='blue', ls='--', alpha=0.6)
    ax1.add_patch(ellipse)
ax1.set_title('Ergodic Trajectory (Riccati iLQR)')
ax1.set_xlabel('X'); ax1.set_ylabel('Y'); ax1.set_xlim(0, L); ax1.set_ylim(0, L)
ax1.set_aspect('equal', adjustable='box'); ax1.legend(); ax1.grid(True, linestyle='--', alpha=0.6)

ax2 = fig.add_subplot(132)
ax2.plot(tlist, u_traj[:, 0], label='$u_1(t)$'); ax2.plot(tlist, u_traj[:, 1], label='$u_2(t)$')
ax2.set_title('Optimal Control Signals (Riccati iLQR)'); ax2.set_xlabel('Time (s)'); ax2.set_ylabel('Control')
ax2.legend(); ax2.grid(True, linestyle='--', alpha=0.6)

ax3 = fig.add_subplot(133)
ax3.plot(range(len(cost_history)), cost_history)
ax3.set_title('Objective Value (iLQR Cost)'); ax3.set_xlabel('iLQR Iteration'); ax3.set_ylabel('Objective')
ax3.grid(True, linestyle='--', alpha=0.6)

plt.tight_layout()
plt.show()

Starting iLQR iterations...
Iter: 00, Cost: 1539.333665
Solving BVP...


  r_middle = 1.5 * col_res / h
  slope = (y[:, 1:] - y[:, :-1]) / h


KeyboardInterrupt: 