In [None]:
from idqn.environments.lqr import LinearQuadraticEnv
import numpy as np

env = LinearQuadraticEnv(A=0.8, B=-0.9, Q=-0.5, R=-0.5, S=0.2)
print("P:", env.P)

# Index convention used throughout the notebook: optimal_weights[0] is G (the s**2
# coefficient of Q) and optimal_weights[2] is M (the a**2 coefficient).
optimal_g = env.optimal_weights[0]
optimal_m = env.optimal_weights[2]
print("Optimal M:", optimal_m, "Optimal G:", optimal_g)

# Bound on |G| for the allowed weight space: half of |G*|.
max_g = np.abs(optimal_g / 2)
# Left edge of the plotted M range (twice M*).
min_m = 2 * optimal_m  # always negative

In [None]:
def optimal_bellman_iteration(point):
    """Apply the optimal Bellman operator to a point (M, G) of the quadratic Q-family.

    `point` follows the [M, G] ordering used for `initial_point`. The result is
    np.array([R + B**2 * G, Q + A**2 * G]); the input M (point[0]) is not read.
    """
    # NOTE(review): ignoring M presumably relies on max_a' (G s'^2 + M a'^2) being
    # attained at a' = 0, i.e. M <= 0 — confirm for points outside that region.
    g = point[1]
    next_m = env.R + env.B ** 2 * g
    next_g = env.Q + env.A ** 2 * g
    return np.array([next_m, next_g])

def projection(point):
    """Map a point (M, G) onto the allowed weight space.

    M is capped at 0 from above and G is clipped to [-max_g, max_g].
    Returns a new np.array([M, G]).
    """
    m_clipped = min(point[0], 0)
    g_clipped = np.clip(point[1], -max_g, max_g)
    return np.array([m_clipped, g_clipped])

def bellman_line(m):
    """Return the G value of the line containing every Bellman image, at M = m.

    optimal_bellman_iteration maps (M, G) to (R + B**2 G, Q + A**2 G). Eliminating
    G gives G' = Q - R A**2 / B**2 + (A**2 / B**2) M', which is this line.
    `m` may be a scalar or a numpy array.
    """
    intercept = env.Q - env.R * env.A ** 2 / env.B ** 2
    slope = env.A ** 2 / env.B ** 2
    return intercept + slope * m

In [None]:
n_bellman_iteration = 2
# Start at half of M* and at half of the (negative) G bound.
initial_point = np.array([0.5 * env.optimal_weights[2], -max_g / 2])

# Row k of list_iterated_points is the k-th Bellman image; row k of
# list_projected_points is its projection. Row 0 of both is the initial point.
list_iterated_points = [initial_point]
list_projected_points = [initial_point]
current_point = initial_point
for i in range(n_bellman_iteration):
    iterated_point = optimal_bellman_iteration(current_point)
    list_iterated_points.append(iterated_point)

    current_point = projection(iterated_point)
    list_projected_points.append(current_point)

# Stack into (n_bellman_iteration + 1, 2) arrays of [M, G] rows for plotting.
list_iterated_points, list_projected_points = (
    np.array(list_iterated_points),
    np.array(list_projected_points),
)

In [None]:
import jax
import jax.numpy as jnp

def Q(params, s, a):
    """Quadratic Q-function without a cross term: G * s**2 + M * a**2.

    `params` is a mapping with keys "G" and "M"; `s` and `a` may be scalars or arrays.
    """
    state_term = params["G"] * s ** 2
    action_term = params["M"] * a ** 2
    return state_term + action_term

def loss_dqn(params, target_params, batch_s, batch_a):
    """Mean squared TD error of Q(params) against a one-step LQR target.

    The next state is s' = A s + B a. The reward is Q s^2 + 2 S s a + R a^2.
    The bootstrap term is target_params["G"] * s'^2; target_params["M"] is not read.
    """
    next_states = env.A * batch_s + env.B * batch_a
    rewards = env.Q * batch_s ** 2 + 2 * env.S * batch_s * batch_a + env.R * batch_a ** 2
    td_targets = rewards + target_params["G"] * next_states ** 2
    td_errors = Q(params, batch_s, batch_a) - td_targets
    return jnp.square(td_errors).mean()

def loss_idqn(params, target_params, batch_s, batch_a):
    """Iterated-DQN loss: sum of two chained DQN losses.

    Online head "params_1" regresses onto the target built from target head "params_0".
    Online head "params_2" regresses onto the target built from target head "params_1".
    params["params_0"] is not read, so it receives no gradient.
    """
    head_pairs = (("params_1", "params_0"), ("params_2", "params_1"))
    first_loss, second_loss = (
        loss_dqn(params[online_key], target_params[target_key], batch_s, batch_a)
        for online_key, target_key in head_pairs
    )
    return first_loss + second_loss


gradient_steps = 2
lr = 0.002
batch_s = jnp.linspace(-5, 5, 10)
batch_a = jnp.linspace(-5, 5, 10)
# NOTE(review): batch_s and batch_a are identical, so every sample has s == a and the
# features s**2 and a**2 coincide; the regression only pins down G + M, not G and M
# separately. Confirm this is intended (a meshgrid of s and a would decouple them).
dqn_params = {"M": initial_point[0], "G": initial_point[1]}
target_dqn_params = {"M": initial_point[0], "G": initial_point[1]}
# All three iDQN heads start from the same initial point.
idqn_params = {"params_0": dqn_params, "params_1": dqn_params, "params_2": dqn_params}
target_idqn_params = {"params_0": dqn_params, "params_1": dqn_params, "params_2": dqn_params}


def sgd_step(params, grads, learning_rate=lr):
    """Return `params` moved one plain gradient-descent step along `-grads`.

    Works on any matching pytrees (e.g. the nested dicts above).
    """
    # Minus sign: jax.grad returns d(loss)/d(params) and we are *minimizing* the loss.
    # jax.tree_util.tree_map replaces the deprecated/removed jax.tree_map alias.
    return jax.tree_util.tree_map(lambda p, d: p - learning_rate * d, params, grads)


for step in range(gradient_steps):
    # The DQN target stays at the initial point for the whole run.
    grad_dqn = jax.grad(loss_dqn)(dqn_params, target_dqn_params, batch_s, batch_a)
    dqn_params = sgd_step(dqn_params, grad_dqn)

    grad_idqn = jax.grad(loss_idqn)(idqn_params, target_idqn_params, batch_s, batch_a)
    idqn_params = sgd_step(idqn_params, grad_idqn)
    # params_0 gets no gradient from loss_idqn, so this sync leaves the head-1 target
    # fixed and only moves the head-2 target (to the current head 1).
    target_idqn_params = idqn_params  # update the target parameters

extract_point = lambda params: np.array([params["M"], params["G"]])
dqn_point = extract_point(dqn_params)
idqn_point_1 = extract_point(idqn_params["params_1"])
idqn_point_2 = extract_point(idqn_params["params_2"])

In [None]:
import matplotlib.pyplot as plt


n_points = 10
ms = np.linspace(min_m, 0, n_points)

# Optimal point Q* in (M, G) coordinates
plt.scatter(env.optimal_weights[2], env.optimal_weights[0], color="black", s=50)
plt.text(-0.05 + env.optimal_weights[2], env.optimal_weights[0], "$Q^*$", fontsize=20)

# Allowed weight space: G in [-max_g, max_g] and M <= 0 (what `projection` maps onto),
# drawn over M in [min_m, 0]
plt.plot(ms, np.ones(n_points) * max_g, color="black")
plt.plot(ms, -np.ones(n_points) * max_g, color="black")
plt.fill_between(ms, -np.ones(n_points) * max_g, np.ones(n_points) * max_g, color="green", alpha=0.3)
plt.vlines(0, ymin=-max_g, ymax=max_g, color="black")

# Bellman line: every output of optimal_bellman_iteration lies on it; uncomment to overlay
# plt.plot(ms, bellman_line(ms))

# Bellman iterations
plt.scatter(initial_point[0], initial_point[1], color="black", s=50)
plt.text(-0.09 + initial_point[0], initial_point[1], "$Q_0$", fontsize=20)
for i in range(1, n_bellman_iteration + 1):
    # Solid arrow: previous projected point -> its Bellman image. The arrow is shortened
    # by 1/1.2, presumably so the head does not cover the marker.
    x, y = list_projected_points[i - 1, 0], list_projected_points[i - 1, 1]
    dx, dy = list_iterated_points[i, 0] - x, list_iterated_points[i, 1] - y
    plt.arrow(x, y, dx / 1.2, dy / 1.2, head_width=0.02, color="black")

    plt.scatter(list_iterated_points[i, 0], list_iterated_points[i, 1], color="black", s=50)
    # Label just below the point. y must use the G coordinate (column 1), as for the
    # Q_0 label. Raw string so "\G" reaches mathtext instead of being an invalid escape.
    plt.text(list_iterated_points[i, 0], -0.05 + list_iterated_points[i, 1], rf"$\Gamma^*Q_{i - 1}$", fontsize=20)

    # Dashed segment between the projected point and the Bellman image it came from.
    x, y = list_projected_points[i, 0], list_projected_points[i, 1]
    dx, dy = list_iterated_points[i, 0] - x, list_iterated_points[i, 1] - y
    plt.arrow(x, y, dx, dy, color="black", linestyle="dashed")

# DQN point (red) and the two iDQN heads (green) after `gradient_steps` updates
plt.scatter(dqn_point[0], dqn_point[1], color="red", s=50)
plt.scatter(idqn_point_1[0], idqn_point_1[1], color="green", s=10)
plt.scatter(idqn_point_2[0], idqn_point_2[1], color="green", s=10)

plt.xlabel("$M$ (coefficient of $a^2$)")
plt.ylabel("$G$ (coefficient of $s^2$)")
plt.xlim(-1, -0.25)
plt.ylim(-1, 0)
plt.show()
