# REINFORCE

In [1]:
import gymnasium as gym
import math
import random
from itertools import count
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.normal import Normal

from tqdm import tqdm 

from IPython import display
import wandb

import os
import sys
import time

# Put the parent directory on sys.path so that the project-local `nb_utils`
# package imported below can be resolved (presumably it lives one level above
# this notebook; verify against the repository layout).
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from nb_utils.widgets import ArrayRenderWidget


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Single seed shared by every random number generator used in this notebook.
seed = 420

# Seed PyTorch, the stdlib `random` module and NumPy's legacy global RNG, in that order.
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(seed)

In [3]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [4]:
# Authenticate with Weights & Biases before any run is started.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmax-schik[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
# Hyperparameters for the experiment; the same dict is passed to wandb as the run config.
config = dict(
    num_episodes=10000,   # number of training episodes
    gamma=0.99,           # discount factor used when computing returns
    lr=1e-4,              # optimizer learning rate
    env="Pendulum-v1",    # environment id passed to gym.make
)

In [6]:
# Training environment, selected by the "env" entry of the config dict.
env = gym.make(config["env"])

In [7]:
class PolicyNetwork(nn.Module):
    """Gaussian policy network.

    Maps an observation to the parameters of a diagonal Normal distribution
    over actions: two shared tanh hidden layers of width 128 feed two linear
    heads, one for the mean and one for the standard deviation.
    """

    def __init__(self, n_observations, n_actions):
        """
        Args:
            n_observations: size of the last dimension of the input observations.
            n_actions: number of action dimensions (size of each output head).
        """
        super().__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)

        self.mean = nn.Linear(128, n_actions)
        self.stddev = nn.Linear(128, n_actions)

    def forward(self, x):
        """Return ``(mean, stddev)`` for the observations ``x``.

        Both outputs have ``x``'s leading dimensions followed by ``n_actions``.
        ``stddev`` is non-negative (softplus of the raw head output).
        """
        x = torch.tanh(self.layer1(x))
        x = torch.tanh(self.layer2(x))
        mean = self.mean(x)
        # F.softplus computes log(1 + exp(z)) in a numerically stable way. The
        # naive formula overflows to inf once exp(z) exceeds the float range.
        stddev = F.softplus(self.stddev(x))
        return mean, stddev

In [8]:
# Shape smoke test: push a batch of 128 all-zero observations of size 6 through
# a throwaway 6-in / 6-out policy and draw one action sample per batch row.
net = PolicyNetwork(6, 6)
x = torch.zeros(128, 6)
y = net(x)  # (mean, stddev) pair

dist = Normal(loc=y[0], scale=y[1])

dist.sample().shape

torch.Size([128, 6])

In [9]:
# Added to the predicted standard deviation so the Normal's scale is strictly positive.
EPS = 1e-6


class Agent:
    """REINFORCE (Monte-Carlo policy gradient) agent for a Gaussian policy.

    Args:
        env: environment handle; stored on the instance but not used by the
            methods of this class.
        policy_net: module whose forward pass returns ``(mean, stddev)``.
        update_interval: stored on the instance; not used by the methods of
            this class.
        lr: learning rate for the optimizer. Defaults to ``config["lr"]``.
        gamma: discount factor. When ``None`` (default), ``config["gamma"]``
            is read on every call to :meth:`learn`.
    """

    def __init__(self, env, policy_net, update_interval=1, lr=None, gamma=None):
        self.env = env
        self.policy_net = policy_net
        self.update_interval = update_interval
        self.lr = config["lr"] if lr is None else lr
        self.gamma = gamma
        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=self.lr, amsgrad=True)

        self.t_step = 0

    def act(self, state, deterministic=False):
        """Choose an action for ``state``.

        Args:
            state: observation tensor accepted by ``policy_net``.
            deterministic: if True return the distribution mean, otherwise a sample.

        Returns:
            ``(action, log_prob)``: ``action`` is a NumPy array detached from the
            graph; ``log_prob`` is the per-dimension log-density tensor of that
            action and keeps its autograd graph so it can be used in :meth:`learn`.
        """
        mean, stddev = self.policy_net(state)
        # Only the scale needs the EPS offset; shifting the mean would merely
        # bias the distribution away from what the network predicted.
        dist = Normal(mean, stddev + EPS)
        action = mean if deterministic else dist.sample()
        log_prob = dist.log_prob(action)
        return action.detach().cpu().numpy(), log_prob

    def learn(self, log_probs, rewards):
        """Run one REINFORCE update from a finished episode.

        Args:
            log_probs: list of per-step log-probability tensors from :meth:`act`.
            rewards: list of per-step rewards (Python floats or one-element tensors).

        Returns:
            The scalar loss as a detached 0-dim tensor. An empty episode
            performs no update and returns zero.

        Raises:
            ValueError: if ``log_probs`` and ``rewards`` differ in length.
        """
        if len(log_probs) != len(rewards):
            raise ValueError(
                f"log_probs and rewards must have the same length, "
                f"got {len(log_probs)} and {len(rewards)}"
            )
        if not log_probs:
            return torch.zeros(())

        gamma = config["gamma"] if self.gamma is None else self.gamma

        # Discounted return G_t = r_t + gamma * G_{t+1}, accumulated back to
        # front and reversed once (avoids the quadratic cost of insert(0, ...)).
        running_g = 0.0
        returns = []
        for reward in reversed(rewards):
            running_g = float(reward) + gamma * running_g
            returns.append(running_g)
        returns.reverse()

        # The joint log-probability of a multi-dimensional action is the SUM of
        # the independent per-dimension log-probabilities.
        step_log_probs = torch.stack([lp.sum() for lp in log_probs])
        returns = torch.as_tensor(
            returns, dtype=step_log_probs.dtype, device=step_log_probs.device
        )

        # Minimising -sum_t log pi(a_t | s_t) * G_t ascends the policy gradient.
        loss = -(step_log_probs * returns).sum()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.detach()

In [10]:
# Derive the network's input size from an actual observation and its output
# size from the first dimension of the action space.
state, info = env.reset()
n_observations = len(state)
n_actions = env.action_space.shape[0]

# Build the policy on the selected device and wrap it in the REINFORCE agent.
policy_net = PolicyNetwork(n_observations, n_actions)
policy_net = policy_net.to(device)

agent = Agent(env=env, policy_net=policy_net)

In [11]:
# Smoke test: sample an observation from the observation space, let the agent
# pick an action for it, and take a single environment step with that action.
_probe_state = torch.tensor(env.observation_space.sample(), device=device)
_probe_action = agent.act(_probe_state)[0]
env.step(_probe_action)

(array([ 0.84871566, -0.5288495 ,  0.47445923], dtype=float32),
 -0.40965866722013544,
 False,
 False,
 {})

In [12]:
# Train the agent with one policy-gradient update per finished episode and
# log per-episode metrics to Weights & Biases.
wandb.init(
    project=f"learn-rl.REINFORCE.{config['env']}",
    config=config
)

try:
    for i_episode in tqdm(range(config["num_episodes"])):
        rewards = []
        log_probs = []

        # Seed the environment's RNG on the first reset only, so the sequence of
        # initial states is reproducible; later resets continue the same stream.
        state, info = env.reset(seed=seed if i_episode == 0 else None)
        state = torch.tensor(state, dtype=torch.float32, device=device)

        for t in count():
            action, log_prob = agent.act(state)
            observation, reward, terminated, truncated, _ = env.step(action)

            # Keep rewards as plain floats: the episode return logged below is
            # then a Python scalar rather than a tensor.
            rewards.append(float(reward))
            log_probs.append(log_prob)

            if terminated or truncated:
                break

            # Move to the next state
            state = torch.tensor(observation, dtype=torch.float32, device=device)

        loss = agent.learn(log_probs, rewards).item()

        wandb.log({"reward": sum(rewards), "loss": loss, "episode_length": t + 1})
finally:
    # Close the run even when training is interrupted or raises.
    wandb.finish()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [39:06<00:00,  4.26it/s]


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
episode_length,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,▆▅▅▄▄▅▆▆▃▃▅▄▃▄▃▅▂▇▆▃▃▄▄▅▄▅▇█▁█▁█▆▆▆▂▆▂▆▂
reward,█▆▇▄▅▂▃▂▁▂▁▂▁▂▂▂▂▂▂▁▂▁▂▂▂▂▂▁▂▂▂▁▂▂▁▂▂▂▂▂

0,1
episode_length,200.0
loss,-184772.35083
reward,-1525.43204


In [13]:
# Persist the trained policy weights. The target directory is created first so
# the save does not fail on a fresh checkout where it does not exist yet.
weights_path = f"../weights/REINFORCE_{config['env']}.pt"
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
torch.save(policy_net.state_dict(), weights_path)

In [14]:
# Widget used below to show rendered environment frames as PNG images.
_widget_size = {"width": 600, "height": 400}
image_widget = ArrayRenderWidget(format='png', **_widget_size)

In [None]:
# Roll out one episode with the deterministic (mean) action and show the
# rendered frames in the image widget.
MAX_EVAL_STEPS = 1000      # upper bound on rendered steps
FRAME_DELAY_S = 1 / 30     # pause between frames (about 30 frames per second)

# A separate name keeps the training `env` intact, so the training cell can be
# re-run after this one without picking up a closed render environment.
eval_env = gym.make(config["env"], render_mode="rgb_array")

try:
    observation, info = eval_env.reset()

    display.display(image_widget)
    image_widget.render(eval_env.render())

    for _ in range(MAX_EVAL_STEPS):
        # No gradients are needed for evaluation.
        with torch.no_grad():
            action = agent.act(
                torch.tensor(observation, dtype=torch.float32, device=device),
                deterministic=True,
            )[0]

        observation, reward, terminated, truncated, info = eval_env.step(action)
        image_widget.render(eval_env.render())
        time.sleep(FRAME_DELAY_S)

        if terminated or truncated:
            break
finally:
    # Release the environment's resources even if the rollout raises.
    eval_env.close()

ArrayRenderWidget(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\xf4\x00\x00\x01\xf4\x08\x02\x00\x00\…