### Function Approximation Big Idea

So far, the states represented are finite and as such, it is possible for us to keep a Q-table (or dictionary) that stores the Q-value of each state. However, this becomes practically impossible if the number of states is very large or we havef a continuous state. One example would be the position and velocity of a car moving up a mountain - the state, represented by the position (x,y coodinate) and velocity of the car are continuous and the number of states is basically infinite. Hence, in such problem, we cannot use a Q-table to store the Q-value.

Therefore, we use a function, parameterized by $\mathbb{w}$ to approximate the Q-value of all the states. This function can be a linear function (taking the form of $Q(s,a; \mathbb{w}) = \mathbb{w}^T \phi(s,a)$) or a (deep) neural network. Here, we represent our states using a feature vector $\phi(s,a)$.

The key idea is that instead of memorizing the Q-value for every possible state-action pair, we learn a mapping that generalizes across similar states. For example, if two states share similar features (say the car is moving slowly uphill in both cases), the function approximation allows us to assign similar Q-values without having to visit and store both states explicitly. This generalization is what makes function approximation powerful in large or continuous environments (but can also be hard to find a good function that can approximate q-values correctly for all states)

### Update rule

Recall that in standard Q-learning with discrete states, the update rule is:

$$
Q(s,a) \leftarrow Q(s,a)+α[r+\gamma \max_{a'} Q(s′,a′)−Q(s,a)]
$$

When we use function approximation, we replace the Q-table with a parameterized function $Q(s, a; \mathbf{w})$, where $\mathbf{w}$ are the parameters (weights) we need to learn.Therefore, the goal is to find parameters $\mathbf{w}$ that minimize the difference between the current Q-value estimate and the target Q-value.

The temporal difference (TD) error for a transition $(s, a, r, s')$ is:

$$
\delta =r+\gamma \max_{a'} Q(s′,a′;w)−Q(s,a;w))
$$

Instead of updating a table entry, we now update the parameters $\mathbf{w}$ to reduce this error. Specifically, we want to optimize $\mathbf{w}$ to make $Q(s, a; \mathbf{w})$ closer to the target $r + \gamma \max_{a'} Q(s', a'; \mathbf{w})$.

One way to do that is to use Gradient Descent (or variants of it). The loss function for a single sample is:

$$
L(w)=\frac{1}{2}[r+\gamma \max_{a'} Q(s′,a′;w)−Q(s,a;w)]^2
$$

The gradient of this loss with respect to $\mathbf{w}$ is:

$$
\nabla_w L(w)= −[r+\gamma \max_{a'} Q(s′,a′;w)−Q(s,a;w)]\nabla_wQ(s,a;w)
$$

Therefore, the update rule for the parameters is:
$$
w\leftarrow w−\alpha\nabla_wL(w)
$$

where $\alpha$ is the learning rate.

In [None]:
#pip install gymnax

In [3]:
import jax
import jax.numpy as jnp
import optax
import gymnax

In [4]:
env, env_params = gymnax.make("CartPole-v1")
rng = jax.random.PRNGKey(0)
obs, state = env.reset(rng, env_params)

In [5]:
def rbf_features(x, centers, sigma=0.5):
    # x: (d,), centers: (n_centers, d)
    diffs = x - centers  # (n_centers, d)
    sq_dist = jnp.sum(diffs**2, axis=-1)
    return jnp.exp(-sq_dist / (2 * sigma**2))  # (n_centers,)

def init_params(rng, n_features, n_actions):
    W = jax.random.normal(rng, (n_features, n_actions)) * 0.1
    return W

def q_values(W, obs, centers, sigma=0.5):
    phi = rbf_features(obs, centers, sigma)  # (n_features,)
    return phi @ W  # (n_actions,)

# epsilon greedy
def select_action(W, obs, rng, centers, sigma=0.5, epsilon=0.1):
    q = q_values(W, obs, centers, sigma)
    greedy = jnp.argmax(q)
    explore = jax.random.bernoulli(rng, epsilon)
    return jnp.where(explore, jax.random.randint(rng, (), 0, q.shape[0]), greedy)

# TD loss (Q-learning)
def td_loss(W, obs, action, reward, next_obs, done, gamma, centers, sigma):
    q = q_values(W, obs, centers, sigma)[action]
    next_q = jnp.max(q_values(W, next_obs, centers, sigma))
    target = reward + gamma * (1 - done) * next_q
    return 0.5 * (q - target) ** 2


In [6]:
def train(num_episodes=500, lr=1e-2, gamma=0.99, n_centers=50, sigma=0.5):
    obs_dim = env.observation_space(env_params).shape[0]
    n_actions = env.action_space(env_params).n

    rng = jax.random.PRNGKey(0)
    rng, centers_rng, init_rng = jax.random.split(rng, 3)

    # Random RBF centers sampled from [-1,1]
    centers = jax.random.uniform(centers_rng, (n_centers, obs_dim), minval=-1, maxval=1)

    W = init_params(init_rng, n_centers, n_actions)
    opt = optax.sgd(lr)
    opt_state = opt.init(W)

    @jax.jit
    def update(W, opt_state, obs, action, reward, next_obs, done, rng):
        loss_fn = lambda p: td_loss(p, obs, action, reward, next_obs, done, gamma, centers, sigma)
        grads = jax.grad(loss_fn)(W)
        updates, opt_state = opt.update(grads, opt_state, W)
        W = optax.apply_updates(W, updates)
        return W, opt_state

    for ep in range(num_episodes):
        rng, ep_rng = jax.random.split(rng)
        obs, state = env.reset(ep_rng, env_params)
        done = False
        ep_reward = 0.0

        while not done:
            rng, act_rng = jax.random.split(rng)
            action = select_action(W, obs, act_rng, centers, sigma)
            next_obs, state, reward, done, _ = env.step(act_rng, state, action, env_params)

            W, opt_state = update(W, opt_state, obs, action, reward, next_obs, done, rng)
            obs = next_obs
            ep_reward += reward

        print(f"Episode {ep}, reward: {ep_reward:.1f}")

    return W, centers


In [7]:
train()

Episode 0, reward: 17.0
Episode 1, reward: 19.0
Episode 2, reward: 14.0
Episode 3, reward: 22.0
Episode 4, reward: 20.0
Episode 5, reward: 21.0
Episode 6, reward: 16.0
Episode 7, reward: 21.0
Episode 8, reward: 27.0
Episode 9, reward: 23.0
Episode 10, reward: 21.0
Episode 11, reward: 26.0
Episode 12, reward: 23.0
Episode 13, reward: 23.0
Episode 14, reward: 29.0
Episode 15, reward: 38.0
Episode 16, reward: 10.0
Episode 17, reward: 11.0
Episode 18, reward: 10.0
Episode 19, reward: 21.0
Episode 20, reward: 23.0
Episode 21, reward: 23.0
Episode 22, reward: 28.0
Episode 23, reward: 27.0
Episode 24, reward: 26.0
Episode 25, reward: 19.0
Episode 26, reward: 20.0
Episode 27, reward: 19.0
Episode 28, reward: 30.0
Episode 29, reward: 25.0
Episode 30, reward: 29.0
Episode 31, reward: 31.0
Episode 32, reward: 42.0
Episode 33, reward: 79.0
Episode 34, reward: 41.0
Episode 35, reward: 49.0
Episode 36, reward: 25.0
Episode 37, reward: 54.0
Episode 38, reward: 24.0
Episode 39, reward: 27.0
Episode 40

(Array([[ 0.23053566, -0.07308456],
        [ 0.20252419,  0.00577167],
        [-0.04803572,  0.01652985],
        [ 0.5268843 , -0.04535921],
        [ 0.15910895, -0.04975554],
        [-0.06341079,  0.24048017],
        [ 0.08211395,  0.24280427],
        [ 0.70962745,  0.49221927],
        [ 0.50684416,  0.24370238],
        [ 0.1339713 , -0.02061459],
        [-0.03336898, -0.1947515 ],
        [-0.03958366,  0.18594886],
        [ 0.76393896,  0.9209299 ],
        [-0.00718571,  0.21688199],
        [ 0.13983637, -0.23056711],
        [-0.2383598 ,  0.09403404],
        [-0.00434367,  0.02627548],
        [ 0.1793421 ,  0.03034162],
        [ 0.28782344,  0.57183343],
        [ 0.3927906 ,  0.19313508],
        [ 0.06950882, -0.03123685],
        [ 0.1913739 ,  0.06651889],
        [ 0.47488075,  0.3558313 ],
        [ 0.40317017,  0.18110329],
        [ 0.03161961,  0.01166398],
        [ 0.25133145,  0.10042507],
        [-0.03300701,  0.25252035],
        [ 0.02113672,  0.134