# So what are we doing here?

We're training our reinforcement learning agent alongside a learned world model that the agent can train on independently of the real environment (the Dyna architecture, shown below).
![The-Dyna-architecture-the-agent-is-learning-either-on-the-real-world-or-the-world-model.png](attachment:b142ea77-7917-4f91-ac74-3d536b9f7dc6.png)

The effective strategy is as follows:
1. Rollout agent in environment
2. Train agent + world model on environment rollout
3. Rollout agent in world model
4. Train agent on world model rollout
5. Play with length of world model rollouts

There's a lot to play around with here. We could first pre-train the world model on a prior, which may be helpful. Let's assume a small prior so our agent isn't awful to begin with.

What are we trying to capture?
First, let's make clear that the experiment will run for 2000 rollouts on **the real environment**, to keep it consistent with the PPO training.
Metrics:
1. Time it takes to complete training: how much extra time, on average, does the world model add to training? Note that sampling from the world model (the "fake" environment) is easy to parallelize.
2. How quickly the agent learns, measured against the number of rollouts on the real environment.

Next, let's define a couple parameters:
We'll define: 

$\tau \in \mathbb{N}$ as the world-model lookahead: the number of steps the world model is rolled forward when generating samples for our agent. Its natural maximum is 200, the length of a Pendulum episode.

$M \in \mathbb{N}$ as the number of rollouts we take from the world model in order to train the agent. Another way of thinking about this: suppose we want to consistently train our agent on 800 samples per update. Then we would run $M$ rollouts of lookahead $\tau$ to gather enough samples, i.e. $M \tau = 800$ (for example, $\tau = 10$ gives $M = 80$).

$\rho \in \mathbb{N}$ as the number of real-environment rollouts used to pre-train the world model (the prior).


# World Model Prior with Random Samples

In [1]:
import torch 
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import ncps 
from ncps.torch import LTC
from ncps.torch import CfC
from ncps.wirings import AutoNCP
import pytorch_lightning as pl
import torch.utils.data as data
from state_indep_ppo import PPO

In [2]:
import gym

env = gym.make("Pendulum-v1")
# rho: number of random-policy rollouts used as the world-model prior.
# At 200 steps per episode that is 100 * 200 = 20,000 training transitions.
rho = 100

In [3]:
# Collect (observation, action, next_observation) transitions with a uniformly
# random policy; they form the supervised training set for the world-model prior.
old_obs = []
actions = []
new_obs = []
for _ in range(rho):  # rho rollouts, configured in the previous cell
    observation, info = env.reset()

    terminated = False
    truncated = False

    while not (terminated or truncated):
        obs = observation

        # Random action: this prior is gathered without the PPO agent.
        action = env.action_space.sample()

        # Take a step in the environment
        observation, reward, terminated, truncated, info = env.step(action)

        old_obs.append(obs)
        actions.append(action)
        new_obs.append(observation)

# Close once, after every episode has finished (closing inside the loop tore the
# env down between episodes). Later cells still read env.action_space /
# env.observation_space.
env.close()

# Stack into (num_transitions, dim) arrays.
old_obs = np.vstack(old_obs)
actions = np.vstack(actions)
new_obs = np.vstack(new_obs)

  if not isinstance(terminated, (bool, np.bool8)):


In [4]:
# Model input is [observation, action]; the target is the next observation.
# Each transition is an independent sample, so give it an explicit length-1
# time axis -> (N, 1, features). The LTC is built with batch_first=True, and
# ncps.torch.LTC.forward treats a 2-D tensor as a single *unbatched* sequence:
# without this, every shuffled batch of 64 unrelated transitions would be
# unrolled as one 64-step sequence, leaking hidden state between them.
# (For multi-step context you would instead build contiguous, per-episode windows.)
inputs = torch.tensor(
    np.concatenate((old_obs, actions), axis=1), dtype=torch.float32
).unsqueeze(1)
outputs = torch.tensor(new_obs, dtype=torch.float32).unsqueeze(1)
dataloader = data.DataLoader(
    data.TensorDataset(inputs, outputs), batch_size=64, shuffle=True, num_workers=4
)

In [6]:
# LightningModule for training a RNNSequence module
class SequenceLearner(pl.LightningModule):
    """Lightning wrapper that fits a sequence model with a mean-squared-error loss.

    The wrapped ``model.forward`` must return a 2-tuple whose first element is the
    prediction (the second, e.g. a hidden state, is ignored). The prediction is
    reshaped to the target's shape before the loss is taken.
    """

    def __init__(self, model, lr=0.005):
        super().__init__()
        self.model = model
        self.lr = lr

    def _batch_mse(self, batch):
        """Return the MSE between the model's prediction for *batch* and its target."""
        inputs, targets = batch
        prediction, _hidden = self.model.forward(inputs)
        criterion = nn.MSELoss()
        return criterion(prediction.view_as(targets), targets)

    def training_step(self, batch, batch_idx):
        train_loss = self._batch_mse(batch)
        self.log("train_loss", train_loss, prog_bar=True)
        return {"loss": train_loss}

    def validation_step(self, batch, batch_idx):
        val_loss = self._batch_mse(batch)
        self.log("val_loss", val_loss, prog_bar=True)
        return val_loss

    def test_step(self, batch, batch_idx):
        # Testing reuses the validation logic (so it also logs under "val_loss").
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        # Plain Adam over all model parameters at the configured learning rate.
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [9]:
# Derive the network widths from the training tensors rather than hard-coding
# them, so they always match the data: inputs are [observation, action] and
# outputs are the next observation (4 -> 3 for Pendulum-v1).
out_features = int(outputs.shape[-1])
in_features = int(inputs.shape[-1])

# AutoNCP wiring with 64 units, out_features of which are the model's outputs.
wiring = AutoNCP(64, out_features)

# TODO: try CfC (also imported above) in place of LTC.
world_model = LTC(in_features, wiring, batch_first=True)
learn = SequenceLearner(world_model, lr=0.001)
trainer = pl.Trainer(
    logger=pl.loggers.CSVLogger("log"),
    max_epochs=20,
    gradient_clip_val=1,  # Clip gradient to stabilize training
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [10]:
# Fit the world-model prior on the random-rollout transitions (no validation loader).
trainer.fit(learn, train_dataloaders=dataloader)

/home/tristongrayston/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.

  | Name  | Type | Params | Mode 
---------------------------------------
0 | model | LTC  | 22.0 K | train
---------------------------------------
17.6 K    Trainable params
4.4 K     Non-trainable params
22.0 K    Total params
0.088     Total estimated model params size (MB)
4         Modules in train mode
0         Modules in eval mode


Training: |                                                                                       | 0/? [00:00…

`Trainer.fit` stopped: `max_epochs=20` reached.


# Helper Functions for Training Loop

In [12]:
# first thing we have to do is define a reward function. I'll try to just straight up copy it
def angle_normalize(x):
    """Wrap an angle in radians into the interval [-pi, pi).

    Copied from the gym Pendulum environment code.
    """
    return ((x + np.pi) % (2 * np.pi)) - np.pi


def _as_flat_array(x):
    """Return *x* as a flat float64 NumPy array, detaching torch tensors first."""
    if hasattr(x, "detach"):  # torch.Tensor, e.g. a world-model prediction
        x = x.detach().cpu().numpy()
    return np.asarray(x, dtype=np.float64).reshape(-1)


def compute_reward(state, action, max_torque=2.0):
    """
    Compute the Pendulum reward for a single (state, action) pair.

    The Pendulum environment calculates cost as:
        cost = angle_normalize(theta)**2 + 0.1 * theta_dot**2 + 0.001 * (torque**2)
    and returns the reward as:
        reward = -cost

    Parameters:
      state : array-like or torch tensor with 2 or 3 elements
          Either the internal pendulum state [theta, theta_dot], or a
          Pendulum-v1 observation [cos(theta), sin(theta), theta_dot] (the
          3-dim form that env.step and the world model produce). For the
          observation form, theta is recovered with arctan2(sin, cos).
      action : scalar, array-like or torch tensor
          The torque. If it has several elements, the first one is used; it is
          clipped to [-max_torque, max_torque].
      max_torque : float
          The maximum allowed magnitude for the torque.

    Returns:
      reward : float
          The computed reward (a NumPy float).

    Raises:
      ValueError: if ``state`` has neither 2 nor 3 elements.
    """
    flat_state = _as_flat_array(state)
    if flat_state.size == 2:
        theta, theta_dot = flat_state
    elif flat_state.size == 3:
        cos_theta, sin_theta, theta_dot = flat_state
        theta = np.arctan2(sin_theta, cos_theta)
    else:
        raise ValueError(
            "state must be [theta, theta_dot] or "
            f"[cos(theta), sin(theta), theta_dot]; got {flat_state.size} elements"
        )

    # Flattening first also handles scalars and 0-d arrays, which cannot be indexed.
    torque = np.clip(_as_flat_array(action)[0], -max_torque, max_torque)

    # Calculate the cost as defined in the environment.
    cost = angle_normalize(theta) ** 2 + 0.1 * theta_dot ** 2 + 0.001 * (torque ** 2)
    return -cost

def bootstrap_gaes(model, last_state, rewards, values, gamma=None, lam=None):
    """
    Return the Generalized Advantage Estimates for one trajectory.

    Paper: https://arxiv.org/pdf/1506.02438.pdf
    Credit: Eden Meyer

    Parameters:
      model : object exposing ``get_vf(state)``, used to bootstrap the value of
          the state after the final step. When ``gamma`` / ``lam`` are omitted,
          ``model.gamma`` / ``model.lam`` are used (presumably the PPO agent's
          hyper-parameters -- confirm the agent defines them).
      last_state : state passed to ``model.get_vf`` for the bootstrap value.
      rewards : per-step rewards, N elements.
      values : per-step value estimates, shape (N,) or (N, 1).
      gamma : discount factor; defaults to ``model.gamma``.
      lam : GAE lambda; defaults to ``model.lam``.

    Returns:
      np.ndarray of advantages with the same shape as ``values``
      (empty when the trajectory is empty).

    Raises:
      ValueError: if ``rewards`` and ``values`` hold different numbers of elements.
    """
    # The original body read self.gamma / self.lam, which do not exist in a
    # module-level function; take them as arguments or from the model instead.
    if gamma is None:
        gamma = model.gamma
    if lam is None:
        lam = model.lam

    out_shape = np.shape(values)
    flat_values = np.asarray(values, dtype=np.float64).reshape(-1)
    flat_rewards = np.asarray(rewards, dtype=np.float64).reshape(-1)
    if flat_rewards.size != flat_values.size:
        raise ValueError(
            f"rewards ({flat_rewards.size}) and values ({flat_values.size}) "
            "must have the same length"
        )
    if flat_values.size == 0:
        return np.zeros(out_shape, dtype=np.float64)

    # .item() unwraps 1-element torch tensors / NumPy arrays; plain floats pass through.
    last_state_est = model.get_vf(last_state)
    if hasattr(last_state_est, "item"):
        last_value = last_state_est.item()
    else:
        last_value = float(last_state_est)

    # V(s_{t+1}) for every step, bootstrapping the final one from last_state.
    next_values = np.append(flat_values[1:], last_value)
    deltas = flat_rewards + gamma * next_values - flat_values

    # Backward recursion: A_t = delta_t + gamma * lam * A_{t+1}, with A_N = 0.
    gaes = np.empty_like(deltas)
    running = 0.0
    for i in reversed(range(deltas.size)):
        running = deltas[i] + lam * gamma * running
        gaes[i] = running

    return gaes.reshape(out_shape)

def next_step_prediction(agent, wm, tau, start_states):
    """Query the agent for ``tau`` steps starting from ``start_states``.

    Work in progress: the world model ``wm`` is not queried yet, so the state is
    never advanced and every step queries the agent on the same start state.
    Nothing is returned yet.

    Parameters:
      agent : object exposing ``get_action(state) -> (action, log_prob)`` and
          ``get_vf(state)`` (as called below).
      wm : the world model; currently unused (see TODO below).
      tau : number of look-ahead steps.
      start_states : array-like or tensor of starting observations.

    Returns:
      None.
    """
    b_obs = []
    actions = []
    advantages = []  # TODO: fill with bootstrap_gaes once rewards are collected
    returns = []  # TODO: fill alongside advantages
    act_log_probs = []
    values = []

    # as_tensor also accepts an existing tensor without torch.tensor's copy warning.
    new_state = torch.as_tensor(start_states, dtype=torch.float32)
    b_obs.append(new_state)
    for _ in range(tau):
        action, log_prob = agent.get_action(new_state)
        vals = agent.get_vf(new_state)

        actions.append(action)
        act_log_probs.append(log_prob)
        values.append(vals)

        # TODO: advance new_state with the world model `wm` (input is
        # [state, action]), score it with compute_reward, and append the new
        # state to b_obs.
        

In [16]:
# Start with the state-independent-variance PPO: my guess is it converges faster.

tau = 10

# Dimensions of the Pendulum action and observation vectors.
actionspace, obsspace = env.action_space.shape[0], env.observation_space.shape[0]
print(f"actionspace: {actionspace}, obs space {obsspace}")

agent = PPO(
    ob_space=obsspace,
    actions=actionspace,
    n_batches=10,
    lam=0.95,
    kl_coeff=0.2,
    clip_rewards=False,
    clip_param=0.2,
    vf_clip_param=10.0,
    entropy_coeff=0,
    a_lr=5e-4,
    c_lr=5e-4,
    device='cpu',
    max_ts=100,
    rollouts_per_batch=5,
    max_timesteps_per_episode=200,
    n_updates_per_iteration=3,
)

actionspace: 1, obs space 3


## Test helper functions