# Rocket DQN Training Notebook

This notebook trains a Deep Q-Network (DQN) agent to control a rocket.
The environment is defined in C++ (`rocket.cpp`) and exported as a shared library using **pybind11**.

## Install Dependencies and Compile Library

Install the build toolchain and the Python packages, then compile the shared library by running `make` in the parent directory.

In [None]:
# Install a C/C++ build toolchain via apt-get (requires a Debian/Ubuntu-style system).
!apt-get update
!apt-get install -y build-essential
# Python packages used by this notebook (pybind11 is needed to build the C++ extension).
!pip install pybind11 torch matplotlib tqdm gymnasium
# Build the rocket environment extension; assumes a Makefile in the parent directory.
!cd ..; make  # Compile the shared library

## Import the Shared Library

Add the directory that contains the compiled shared library (`../lib` by default) to `sys.path` so that Python can import it.

In [None]:
import sys
import os

# Add the root folder or 'lib' folder where the shared library is located
sys.path.append(os.path.abspath('../lib'))  # Adjust according to your setup

import Rocket as rck  # Import the Rocket class from the C++ library
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from collections import deque
import random
import copy
from tqdm import tqdm
import math

# PyTorch device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

## Neural Network Definition

The `Red` class defines a fully connected neural network with two hidden layers and ReLU activations.

In [None]:
class Red(nn.Module):
    """Fully connected network with two equally wide ReLU hidden layers.

    Args:
        input_size: length of the input feature vector.
        hidden_size: width of each of the two hidden layers.
        output_size: number of outputs (in this notebook, one Q-value per action).
    """

    def __init__(self, input_size=4, hidden_size=10, output_size=2):
        super().__init__()
        # The Linear/ReLU/Linear/ReLU/Linear order fixes the ``network.<idx>``
        # state_dict keys, so saved checkpoints stay loadable.
        first = nn.Linear(input_size, hidden_size)
        second = nn.Linear(hidden_size, hidden_size)
        head = nn.Linear(hidden_size, output_size)
        self.network = nn.Sequential(first, nn.ReLU(), second, nn.ReLU(), head)

    def forward(self, x):
        """Return the raw (un-activated) network outputs for input ``x``."""
        out = self.network(x)
        return out

## Replay Memory

The `ReplayMemory` class stores transitions in a bounded FIFO buffer and yields shuffled minibatches; each shuffled pass visits every stored transition once.

In [None]:
class ReplayMemory:
    """Bounded FIFO store of transitions that serves shuffled minibatches.

    ``minibatch`` walks through a shuffled snapshot of the buffer. If a
    previous pass was not fully consumed, the next call resumes it; once the
    pass is exhausted, a fresh snapshot is taken and shuffled.

    Args:
        capacity: maximum number of transitions kept; the oldest are dropped first.
    """

    def __init__(self, capacity=10000):
        self.deque = deque(maxlen=capacity)
        # Transitions still waiting to be served in the current shuffled pass.
        self.deque_shuffle = deque(maxlen=capacity)

    def append(self, transition):
        """Add a transition to the memory (evicting the oldest one when full)."""
        self.deque.append(transition)

    def minibatch(self, batch_size=32):
        """Yield shuffled minibatches, each built with ``np.array``, until the pass ends.

        The last batch of a pass may be smaller than ``batch_size``. Nothing is
        yielded when the memory is empty.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1. Otherwise every
                batch would be empty and the loop would never terminate.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if not self.deque_shuffle:
            # A shallow snapshot is enough: np.array(batch) below copies the
            # rows anyway. Deep-copying every stored array each pass is wasted
            # work. Shuffling a list also avoids deque's O(n) random indexing,
            # and yields the same permutation for the same `random` state.
            snapshot = list(self.deque)
            random.shuffle(snapshot)
            self.deque_shuffle = deque(snapshot, maxlen=self.deque.maxlen)
        while self.deque_shuffle:
            take = min(batch_size, len(self.deque_shuffle))
            batch = [self.deque_shuffle.popleft() for _ in range(take)]
            yield np.array(batch)

    def reset(self, capacity=10000):
        """Reset memory to empty, optionally with a new capacity."""
        self.deque = deque(maxlen=capacity)
        self.deque_shuffle = deque(maxlen=capacity)

## Agent Definition

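The `WrapAgent` class implements an epsilon-greedy policy. With probability epsilon it takes a random action; otherwise it takes the action with the highest predicted Q-value. Epsilon decreases linearly per episode, with a floor of 0.01.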

In [None]:
class WrapAgent():
    """Epsilon-greedy action selector around a Q-network.

    Args:
        epsilon: initial exploration probability.
        decay: stored on the instance; not used by ``choose_action``.
    """

    def __init__(self, epsilon=0.7, decay=10):
        self.epsilon = epsilon
        self.decay = decay  # NOTE(review): currently unused by the policy.
        self.rng = np.random.default_rng(234343)

    def choose_action(self, env, state, episode, Qnet, epsilon_o, step_epsilon, cont_epsilon):
        """Choose an action using an epsilon-greedy policy with linear decay.

        The exploration probability is
        ``max(self.epsilon - step_epsilon * cont_epsilon, 0.01)``.
        ``episode`` and ``epsilon_o`` are accepted for interface compatibility
        but are not used.

        Args:
            env: object providing ``sample()`` for a random action.
            state: current observation (numpy array or sequence of numbers).
            Qnet: network mapping a batch of states to per-action scores.
            step_epsilon: epsilon decrease per unit of ``cont_epsilon``.
            cont_epsilon: decay counter (episode count in this notebook).

        Returns:
            ``env.sample()`` when exploring; otherwise the index of the
            highest-scoring action as a Python ``int``.
        """
        # Same as "rand < decayed or rand < 0.01": always keep >= 1% exploration.
        threshold = max(self.epsilon - step_epsilon * cont_epsilon, 0.01)
        if self.rng.uniform(0, 1) < threshold:
            return env.sample()  # Random action
        # Put the input on the network's own device rather than relying on a
        # module-level global.
        param = next(Qnet.parameters(), None)
        net_device = param.device if param is not None else None
        with torch.inference_mode():
            x = torch.as_tensor(state, dtype=torch.float32, device=net_device).unsqueeze(0)
            return int(torch.argmax(Qnet(x)).item())

## Hyperparameters and Environment Initialization

In [None]:
# ---- Hyperparameters ----
gamma = 0.9999        # discount factor for future rewards
epsilon = 1           # starting exploration probability
decay = 100           # passed to WrapAgent (stored there, not used by its policy)
step_epsilon = 0.001  # epsilon decrease per unit of cont_epsilon
cont_epsilon = 0      # counter driving the epsilon schedule (incremented per episode)
tau = 0.05            # soft-update rate for the target network
batch_size = 128
start_train = 100     # episodes collected before the first gradient update
max_steps = 6000      # step cap per episode
episodes = 7000
state_length = 9      # number of state variables returned by the environment
rp_len = 10_000       # replay-memory capacity

# ---- Components ----
rp = ReplayMemory(capacity=rp_len)
agent = WrapAgent(epsilon=epsilon, decay=decay)
# Online network is created first, then the target; creation order fixes the
# random initialization drawn from torch's RNG.
Qnet, target = (Red(state_length, 30, 4).to(device) for _ in range(2))
target.load_state_dict(Qnet.state_dict())  # target starts as an exact copy
optimizer = torch.optim.SGD(params=Qnet.parameters(), lr=3e-4)
loss = nn.MSELoss()
rng = np.random.default_rng(seed=33234)
env = rck.Rocket()

## Training Loop

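Each episode is rolled out with the epsilon-greedy policy, and its transitions are stored in the replay memory. From episode `start_train` onward, the Q-network is trained on one shuffled pass over the replay memory, using targets computed by the target network. The target network is then soft-updated with rate `tau`.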

In [None]:
rew = []      # total (undiscounted) reward of every episode
t_steps = []  # number of environment steps taken in every episode
error_t = []  # summed minibatch loss per episode, recorded only for episodes >= start_train

os.makedirs('../models', exist_ok=True)  # torch.save does not create missing folders


def _soft_update(target_net, online_net, tau):
    """Blend weights in place: target <- tau * online + (1 - tau) * target."""
    online_state = online_net.state_dict()
    blended = {
        key: tau * online_state[key] + (1 - tau) * value
        for key, value in target_net.state_dict().items()
    }
    target_net.load_state_dict(blended)


def _train_on_batch(batch):
    """Take one gradient step of Qnet on a minibatch of transitions; return the loss.

    Each row follows the layout built by the rollout below:
    ``[state (state_length) | reward | action | next_state (state_length) | flag]``.
    """
    states = torch.from_numpy(batch[:, :state_length]).float().to(device)
    rewards = torch.from_numpy(batch[:, state_length]).float().to(device)
    actions = torch.from_numpy(batch[:, state_length + 1]).long().to(device)
    next_states = torch.from_numpy(batch[:, state_length + 2:-1]).float().to(device)
    # The rollout stops when this flag is 0, so it acts as a "not terminal"
    # mask: terminal transitions get no bootstrapped future value.
    not_terminal = torch.from_numpy(batch[:, -1]).float().to(device)

    # Q(s, a) for the action actually taken in each transition.
    q_pred = Qnet(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    # Bootstrapped target; gradients must not flow into the target network.
    with torch.no_grad():
        q_next = target(next_states).max(dim=1).values
        q_target = rewards + gamma * q_next * not_terminal

    err_train = loss(q_pred, q_target)
    optimizer.zero_grad()
    err_train.backward()
    optimizer.step()
    return err_train.item()


for i in range(episodes):
    # --- Roll out one episode; acting needs no autograd graph. ---
    reward = 0
    time_steps = 0
    state = env.reset()
    with torch.inference_mode():
        for _ in range(max_steps):
            action = agent.choose_action(env, state, i, Qnet, epsilon, step_epsilon, cont_epsilon)
            new = env.step(action)
            state_n = new[:state_length]
            r = new[state_length]
            done = new[-1]
            reward += r
            time_steps += 1
            rp.append(np.hstack((state, r, np.array(action), state_n, done)))
            if done == 0:
                break
            state = state_n
    rew.append(reward)
    t_steps.append(time_steps)

    # --- Train OUTSIDE inference mode: backward() needs a recorded graph. ---
    if i >= start_train:
        Qnet.train()
        error = sum(_train_on_batch(batch) for batch in rp.minibatch(batch_size))
        Qnet.eval()
        error_t.append(error)
        print(f"Process {i*100/episodes:.2f}%, Episode {i}, Steps {t_steps[-1]}, Reward {rew[-1]}, Error {error_t[-1]:.5f}, Epsilon {max(epsilon - step_epsilon*cont_epsilon,0.01):.3f}")
        _soft_update(target, Qnet, tau)
        torch.save(Qnet.state_dict(), '../models/net.pth')
    cont_epsilon += 1

## Save Trained Network

In [None]:
# Persist the final Q-network weights to ../models/net.pth.
os.makedirs('../models', exist_ok=True)  # torch.save does not create missing folders
torch.save(Qnet.state_dict(), '../models/net.pth')

## Visualize Training Metrics

In [None]:
# Training loss: error_t only gets an entry for episodes >= start_train, so
# shift the x-axis to show the actual episode numbers.
train_episodes = np.arange(start_train, start_train + len(error_t))
plt.figure(figsize=(10, 6))
plt.plot(train_episodes, error_t)
plt.title('Training Error per Episode')
plt.xlabel('Episode')
plt.ylabel('Error')
plt.show()

# Episode lengths: number of environment steps each episode lasted.
plt.figure(figsize=(10, 6))
plt.hist(t_steps, bins=100)
plt.title('Distribution of Steps per Episode')
plt.xlabel('Steps')
plt.ylabel('Frequency')
plt.show()

## Rocket Animation

Visualize the rocket trajectory using `matplotlib.animation`.

In [None]:
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Rectangle
%matplotlib inline
plt.rcParams['figure.dpi'] = 150
plt.ioff()

fig, ax = plt.subplots(figsize=(5,5))
state = env.reset()
flag = True

def animate(t):
    global state
    global flag
    ax.clear()
    xlim = 10
    ylim = 10
    plt.tight_layout()
    
    # Select action using the trained network
    action = torch.argmax(Qnet(torch.from_numpy(state[:state_length]).unsqueeze(0).float().to(device))).detach().cpu().item()
    new = env.step(action)
    state = new[:state_length]
    phi = state[2]
    x_cm = state[0]
    y_cm = state[1]
    
    ax.set_ylim(-ylim, ylim)
    ax.set_xlim(-xlim, xlim)
    
    # Draw the rocket as a rectangle
    rect = Rectangle((x_cm-0.25, y_cm-2), width=0.5, height=4, angle=phi*180/np.pi, rotation_point='center', edgecolor='blue', facecolor='lightblue')
    ax.add_patch(rect)
    if flag:
        plt.plot(np.array([-y_cm, y_cm]), np.array([x_cm, x_cm]), color='red')  # Example flame visualization

FuncAnimation(fig, animate, frames=100, interval=40)