<a href="https://colab.research.google.com/github/dsaldana/reinforcement-learning-course/blob/main/lab6_nonlinear.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Lab 6: Non-linear function approximation

## Exercise 1: Q-Learning with a Neural Network (PyTorch) on MountainCar

**Objective:**
Implement **Q-learning** with a **PyTorch neural network** to solve `MountainCar-v0`. You will approximate Q(s, a) with a small MLP, train it from batches of transitions sampled from a replay buffer, and evaluate the learned policy.

---

## Environment
- **Gym** environment: `MountainCar-v0`
- **State**: continuous (position, velocity) → shape `(2,)`
- **Actions**: {0: left, 1: no push, 2: right}
- **Reward**: -1 per step until the goal (`position >= 0.5`)
- **Episode limit**: 200 steps (the default time limit of `MountainCar-v0`; the episode is truncated when it is reached)
- **Goal**: reduce steps-to-goal and improve return over training

---

## What You Must Implement

### 1) Q-Network (PyTorch)
Create a small MLP `QNetwork` that maps `state -> Q-values for 3 actions`.
- Inputs: `(batch_size, 2)` float32
- Outputs: `(batch_size, 3)` Q-values
- Suggested architecture: `2 → 64 → 3` with ReLU
- Initialize weights reasonably (PyTorch defaults are fine)

### 2) Replay Buffer
A cyclic buffer to store transitions `(s, a, r, s_next, done)`:
- `append(s, a, r, s_next, done)`
- `sample(batch_size)` → tensors ready for PyTorch (float32 for states, int64 for actions, float32 for rewards/done)
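
One possible shape for such a buffer is sketched below. The class name `ReplayBuffer`, the `device` argument, and the choice to return actions with shape `(batch, 1)` (convenient for `gather`) are assumptions made for illustration, not requirements of the lab.

```python
import random
from collections import deque

import numpy as np
import torch


class ReplayBuffer:
    """Fixed-capacity cyclic buffer of (s, a, r, s_next, done) transitions."""

    def __init__(self, capacity):
        # A deque with maxlen drops the oldest transition once it is full
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def append(self, s, a, r, s_next, done):
        self.buffer.append((s, a, r, s_next, done))

    def sample(self, batch_size, device="cpu"):
        batch = random.sample(self.buffer, batch_size)
        s, a, r, s_next, done = zip(*batch)
        # Stack into NumPy arrays first, then convert each array to a tensor once
        states = torch.as_tensor(np.asarray(s, dtype=np.float32), device=device)
        actions = torch.as_tensor(np.asarray(a, dtype=np.int64), device=device).unsqueeze(1)
        rewards = torch.as_tensor(np.asarray(r, dtype=np.float32), device=device)
        next_states = torch.as_tensor(np.asarray(s_next, dtype=np.float32), device=device)
        dones = torch.as_tensor(np.asarray(done, dtype=np.float32), device=device)
        return states, actions, rewards, next_states, dones
```

With this layout, `Q(s).gather(1, actions)` can be applied to the sampled actions directly.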

### 3) ε-Greedy Policy
- With probability `epsilon`: pick a random action
- Otherwise: `argmax_a Q(s, a)` from the current network
- Use **decaying ε** (e.g., from 1.0 down to 0.05 over ~20–50k steps)
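
A minimal sketch of one such schedule follows. It assumes a linear decay indexed by the global environment step (not the episode number); the function name `linear_epsilon` and the default of 30,000 steps are illustrative choices within the range suggested above.

```python
def linear_epsilon(step, eps_start=1.0, eps_end=0.05, decay_steps=30_000):
    """Linearly anneal epsilon from eps_start to eps_end over decay_steps env steps."""
    # Clamp the progress to [0, 1] so epsilon stays at eps_end after the decay
    fraction = min(max(step / decay_steps, 0.0), 1.0)
    return eps_start + fraction * (eps_end - eps_start)
```

Passing the episode index instead of the step count makes the decay far slower than intended, because one episode contains up to a few hundred steps.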

### 4) Q-Learning Target and Loss
For a sampled batch:
- Compute `q_pred = Q(s).gather(1, a)`  (shape `(batch, 1)`)
- Compute target:
  - If `done`: `target = r`
  - Else: `target = r + gamma * max_a' Q(s_next, a').detach()`
- Loss: Mean Squared Error (MSE) between `q_pred` and `target`

> **Stabilization (recommended)**: Use a **target network** `Q_target` (periodically copy weights from `Q_online`) to compute the max over next-state actions. Update every `target_update_freq` steps.
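
Written as equations, the target and loss with a target network could be stated as follows. The symbols are assumptions introduced here: $\theta$ for the online parameters, $\theta^-$ for the target-network parameters, $d_i \in \{0, 1\}$ for the `done` flag, and $B$ for the sampled batch.

```latex
y_i = r_i + \gamma \,(1 - d_i)\, \max_{a'} Q_{\theta^-}(s'_i, a'),
\qquad
L(\theta) = \frac{1}{|B|} \sum_{i \in B} \bigl( Q_{\theta}(s_i, a_i) - y_i \bigr)^2
```

Without a target network, $\theta^-$ is simply the current $\theta$ with gradients blocked, which matches the `.detach()` in the list above.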

### 5) Deep Q-learning method
- For each environment step:
  1. Select action with ε-greedy
  2. Step the env, store transition in buffer
  3. If `len(buffer) >= batch_size`:
     - Sample a batch
     - Compute `q_pred`, `target`
     - Backprop: `optimizer.zero_grad(); loss.backward(); optimizer.step()`
     - (Optional) gradient clipping (e.g., `clip_grad_norm_` at 10)
  4. Periodically update `Q_target ← Q_online` (if using target net)
- Track episode returns (sum of rewards) and steps-to-goal
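
Step 3 and step 4 of the loop above could look roughly like the sketch below. The function names `dqn_update` and `hard_update` are hypothetical, and the batch is assumed to hold float32 states, int64 actions of shape `(batch, 1)`, and float32 rewards and done flags of shape `(batch,)`.

```python
import torch
import torch.nn as nn


def dqn_update(q_online, q_target, optimizer, batch, gamma=0.99, max_grad_norm=10.0):
    """One gradient step on the MSE TD error; returns the scalar loss."""
    states, actions, rewards, next_states, dones = batch
    # Q(s, a) for the actions actually taken, shape (batch,)
    q_pred = q_online(states).gather(1, actions).squeeze(1)
    with torch.no_grad():
        # Bootstrap from the target network; (1 - done) zeroes the term at terminal states
        next_q = q_target(next_states).max(dim=1).values
        target = rewards + gamma * (1.0 - dones) * next_q
    loss = nn.functional.mse_loss(q_pred, target)
    optimizer.zero_grad()
    loss.backward()
    # Optional gradient clipping, applied between backward() and step()
    nn.utils.clip_grad_norm_(q_online.parameters(), max_grad_norm)
    optimizer.step()
    return loss.item()


def hard_update(q_target, q_online):
    """Copy the online weights into the target network (Q_target <- Q_online)."""
    q_target.load_state_dict(q_online.state_dict())
```

In a training loop, `hard_update` would be called whenever the global step count is a multiple of `target_update_freq`.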

---

## Evaluation
- Run **evaluation episodes** with `epsilon = 0.0` (greedy) every N training episodes
- Report:
  - Average steps-to-goal (lower is better; random policy is ~200)
  - Average return (less negative is better)
- Plot:
  - Training episode return

---

## Deliverables
1. **Code**: In a notebook.
2. **Plots**:
   - Episode vs. return
   - Final value function: state (position and velocity) vs. `max_a Q(state, a)`

3. **Short write-up** (also in the notebook):
   - **Performance of your DQN agent**: How quickly does it learn? Does it reach the goal consistently?
   - **Comparison with tile coding**:
     - Which representation learns faster?
     - Which one is more stable?
     - How do the function approximation choices (linear with tiles vs. neural network) affect generalization?
     - Did the NN require more tuning (learning rate, ε schedule) compared to tile coding?
   - **Insights**: What are the trade-offs between hand-crafted features (tiles) and learned features (neural networks)?
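
For the value-function plot listed under the deliverables, the values can be computed on a grid of states as in the sketch below. The helper name `value_grid` and the grid size are assumptions; the default ranges are the observation-space bounds of `MountainCar-v0`.

```python
import torch


def value_grid(q_net, pos_range=(-1.2, 0.6), vel_range=(-0.07, 0.07), n=50):
    """Evaluate V(s) = max_a Q(s, a) on an n x n grid of (position, velocity)."""
    device = next(q_net.parameters()).device
    pos = torch.linspace(pos_range[0], pos_range[1], n)
    vel = torch.linspace(vel_range[0], vel_range[1], n)
    # indexing="ij": rows follow position, columns follow velocity
    P, V = torch.meshgrid(pos, vel, indexing="ij")
    states = torch.stack([P.reshape(-1), V.reshape(-1)], dim=1)
    with torch.no_grad():
        values = q_net(states.to(device)).max(dim=1).values.cpu()
    return pos, vel, values.reshape(n, n)
```

The result can then be drawn with, for example, `plt.pcolormesh(pos.numpy(), vel.numpy(), values.T.numpy(), shading="auto")`, with position on the x-axis and velocity on the y-axis.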



In [37]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt

# Set up environment
env = gym.make("MountainCar-v0")
n_actions = env.action_space.n
state_dim = env.observation_space.shape[0]

# Hyperparameters
gamma = 0.99
alpha = 0.001
epsilon = 1.0
epsilon_min = 0.01
epsilon_decay = 0.995
num_episodes = 5000
batch_size = 64
replay_buffer_size = 50000

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [38]:
print(env.observation_space)


Box([-1.2  -0.07], [0.6  0.07], (2,), float32)


In [39]:
# Define Q-Network
class QNetwork(nn.Module):
    def __init__(self, state_dim=2, n_actions=3):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, n_actions)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

In [40]:
# Initialize Q-network and optimizer
q_net = QNetwork(state_dim, n_actions).to(device)
optimizer = optim.Adam(q_net.parameters(), lr=alpha)
loss_fn = nn.MSELoss()
replay_buffer = deque(maxlen=replay_buffer_size)

In [41]:
def epsilon_greedy(state_loc, epsilon_loc, step, decay_steps=50_000):
    """Pick an action ε-greedily; ε decays exponentially with `step` toward epsilon_min."""
    ############ TODO ###########
    epsilon_loc = epsilon_min + (epsilon_loc - epsilon_min) * np.exp(-1.0 * step / decay_steps)
    if random.random() < epsilon_loc:
        return random.randint(0, n_actions - 1)
    state = torch.as_tensor(state_loc, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        # Drop the batch dimension so the indices found below are action indices
        q_values = q_net(state).squeeze(0).cpu().numpy()
    # Break ties between equally good actions at random
    max_actions = np.flatnonzero(q_values == q_values.max())
    return int(np.random.choice(max_actions))


In [42]:
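def train_dqn():
    """Run one gradient step on a mini-batch sampled from the replay buffer."""
    if len(replay_buffer) < batch_size:
        return
    batch = random.sample(replay_buffer, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)

    # Stack states into NumPy arrays first: building a tensor from a tuple of arrays is very slow
    states = torch.as_tensor(np.array(states), dtype=torch.float32, device=device)
    actions = torch.as_tensor(actions, dtype=torch.int64, device=device)
    rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device)
    next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=device)
    dones = torch.as_tensor(dones, dtype=torch.float32, device=device)

    # Q(s, a) for the actions that were actually taken
    q_values = q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    # Bootstrapped target; (1 - dones) removes the bootstrap term at terminal states
    with torch.no_grad():
        next_q_values = q_net(next_states).max(1)[0]
    targets = rewards + gamma * next_q_values * (1 - dones)

    loss = loss_fn(q_values, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()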


In [None]:
from tqdm import tqdm

## MAIN Loop ###
rewards_dqn = []
total_steps = 0  # global environment-step counter that drives the ε decay

for episode in tqdm(range(num_episodes), leave=True):
    state = env.reset()[0]
    total_reward = 0
    done = False
    while not done:
        action = epsilon_greedy(state, epsilon, total_steps)
        # Gymnasium returns (obs, reward, terminated, truncated, info)
        next_state, reward, terminated, truncated, _ = env.step(action)
        ############ TODO ###########
        # Store only `terminated` so the target still bootstraps when the time limit cuts the episode
        replay_buffer.append((state, action, reward, next_state, terminated))
        train_dqn()
        total_steps += 1
        total_reward += reward
        state = next_state
        # End the episode on success or when the time limit is reached
        done = terminated or truncated
    rewards_dqn.append(total_reward)
print(rewards_dqn)
    


## Exercise 2: Deep Q-Learning (DQN) on LunarLander-v2

## Problem Description
In this exercise, you will implement **Deep Q-Learning (DQN)** to solve the classic control problem **LunarLander-v2** in Gym.

### The Task
The agent controls a lander that starts at the top of the screen and must safely land on the landing pad between two flags.

- **State space**: Continuous vector of 8 variables, including:
  - Position (x, y)
  - Velocity (x_dot, y_dot)
  - Angle and angular velocity
  - Left/right leg contact indicators
- **Action space**: Discrete, 4 actions
  - 0: do nothing
  - 1: fire left orientation engine
  - 2: fire main engine
  - 3: fire right orientation engine
- **Rewards**:
  - +100 to +140 for successful landing
  - -100 for crashing
  - Small negative reward for firing engines (fuel cost)
  - Episode ends when lander crashes or comes to rest

The goal is to train an agent that lands successfully **most of the time**.

---

## Algorithm: Deep Q-Learning
You will implement a **DQN agent** with the following components:

1. **Q-Network**
   - Neural network that approximates Q(s, a).
   - Input: state vector (8 floats).
   - Output: Q-values for 4 actions.
   - Suggested architecture: 2 hidden layers with 128 neurons each, ReLU activation.

2. **Target Network**
   - A copy of the Q-network that is updated less frequently (e.g., every 1000 steps).
   - Used for stable target computation.

3. **Replay Buffer**
   - Stores transitions `(s, a, r, s_next, done)`.
   - Sample random mini-batches to break correlation between consecutive samples.

4. **ε-Greedy Policy**
   - With probability ε, take a random action.
   - Otherwise, take `argmax_a Q(s, a)`.
   - Decay ε over time (e.g., from 1.0 → 0.05).

5. **Q-Learning Method**
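
As a starting point for components 1 and 2 only, the network and its target copy could be set up as in the sketch below; the training loop is deliberately left out. The names `LunarQNetwork` and `make_target` are hypothetical, and freezing the target's parameters is one design choice among several.

```python
import copy

import torch
import torch.nn as nn


class LunarQNetwork(nn.Module):
    """MLP mapping an 8-dimensional state to Q-values for 4 actions."""

    def __init__(self, state_dim=8, n_actions=4, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_actions),
        )

    def forward(self, x):
        return self.net(x)


def make_target(q_online):
    """Return a frozen copy of the online network for computing targets."""
    q_target = copy.deepcopy(q_online)
    for p in q_target.parameters():
        p.requires_grad_(False)  # the target is never trained directly
    return q_target.eval()
```

The frozen copy can still be refreshed later with `q_target.load_state_dict(q_online.state_dict())`.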
   


**Final note:**
   No starter code is provided. At this point, you should know how to implement everything.
   A reference solution is available [here](https://colab.research.google.com/drive/1Gl0kuln79A__hgf2a-_-mwoGISXQDK_X?authuser=1#scrollTo=8Sd0q9DG8Rt8&line=56&uniqifier=1), but consulting it is not recommended.

---
## Deliverables
1. **Code**:
- Q-network (PyTorch).
- Training loop with ε-greedy policy, target network, and Adam optimizer.

2. **Plots**:
- Episode returns vs training episodes.
- Evaluation performance with a greedy policy (ε = 0).

3. **Short Write-up (≤1 page)**:
- Did your agent learn to land consistently?  
- How many episodes did it take before you saw improvement?  
- What effect did replay buffer size, target update frequency, and learning rate have on stability?  
- Compare results across different runs (does it sometimes fail to converge?).

Compare this task with the **MountainCar-v0** problem you solved earlier:
- What is **extra** or more challenging in LunarLander?  
- Consider state dimensionality, number of actions, reward shaping, and the difficulty of exploration.  
- Why might DQN be necessary here, whereas simpler methods (like tile coding) could work for MountainCar?
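
For the comparison across runs requested in the write-up, per-episode statistics over several training runs can be computed as sketched below. The helper name `aggregate_runs` is an assumption, and each run is assumed to be a list of episode returns of the same length.

```python
from statistics import mean, stdev


def aggregate_runs(runs):
    """Per-episode mean and sample standard deviation of returns across runs."""
    # Transpose: one tuple of returns per episode (zip stops at the shortest run)
    per_episode = list(zip(*runs))
    means = [mean(ep) for ep in per_episode]
    stds = [stdev(ep) if len(ep) > 1 else 0.0 for ep in per_episode]
    return means, stds
```

Plotting the mean with a band of one standard deviation makes it easier to see whether some runs fail to converge.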
