In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Any, Dict, Optional, Callable
import torch.distributions as td
import numpy as np
import gym

In [2]:
from ezrl.optimizer import RLOptimizer
from ezrl.policy import GymPolicy
from ezrl.algorithms.reinforce import ReinforceOptimizer

In [3]:
class LunarLanderPolicy(GymPolicy):
    """Small tanh MLP policy that outputs a categorical action distribution.

    The default sizes (8 inputs, 4 outputs, 32 hidden units) are the ones this
    notebook uses with the ``LunarLander-v2`` environment; they can be
    overridden to reuse the same architecture elsewhere.

    Args:
        input_dims: Size of the observation vector fed to the network.
        output_dims: Number of discrete actions (size of the logits vector).
        hidden_dims: Width of each of the two hidden layers.
    """

    def __init__(self, input_dims: int = 8, output_dims: int = 4, hidden_dims: int = 32):
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims

        # Build the layers from the stored dimensions so the attributes and
        # the actual network can never disagree.
        self.net = nn.Sequential(
            nn.Linear(self.input_dims, hidden_dims),
            nn.Tanh(),
            nn.Linear(hidden_dims, hidden_dims),
            nn.Tanh(),
            # The final layer emits raw logits and deliberately has no bias.
            nn.Linear(hidden_dims, self.output_dims, bias=False)
        )

    def forward(self, obs: Any) -> Dict[str, Any]:
        """Sample an action for ``obs``.

        Args:
            obs: Tensor passed straight to the network; its last dimension
                must equal ``input_dims``.

        Returns:
            Dict with ``"action"`` (a sample from the distribution) and
            ``"dist"`` (the ``Categorical`` distribution built from the logits).
        """
        logits = self.net(obs)
        dist = td.Categorical(logits=logits)
        action = dist.sample()
        return {"action":action, "dist":dist}

    def act(self, obs: Any):
        """Return ``(action, forward_output)`` for a single observation.

        ``.item()`` is called on the sampled action, so ``obs`` must yield
        exactly one action (e.g. a batch of size one).
        """
        out = self.forward(obs)
        return out["action"].item(), out

In [4]:
from torch.utils.tensorboard import SummaryWriter
import os
from datetime import datetime

def get_tensorboard_logger(experiment_name: str, base_log_path: str = "tensorboard_logs") -> "SummaryWriter":
    """Create a TensorBoard ``SummaryWriter`` in a fresh, timestamped run directory.

    The run directory is ``<base_log_path>/<experiment_name>_<timestamp>``.
    The timestamp is formatted without spaces or colons so the directory name
    is valid on every platform and needs no shell quoting.

    Args:
        experiment_name: Prefix of the run directory name.
        base_log_path: Directory under which run directories are created.

    Returns:
        The ``SummaryWriter`` writing to the new run directory (flushing every
        10 seconds).
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
    log_path = os.path.join(
        base_log_path, "{}_{}".format(experiment_name, timestamp)
    )
    train_writer = SummaryWriter(log_path, flush_secs=10)
    # Print an absolute path so the command works from any working directory.
    full_log_path = os.path.abspath(log_path)
    print(
        "Follow tensorboard logs with: tensorboard --logdir '{}'".format(
            full_log_path
        )
    )
    return train_writer

In [5]:
from tqdm import tqdm


policy = LunarLanderPolicy()
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
policy = policy.to(device)

In [6]:
# Display the device the policy reports after being moved with `.to(device)`.
# NOTE(review): `device` is not defined in this notebook's policy class, so it is
# presumably an attribute provided by `GymPolicy` — confirm against ezrl.
policy.device

device(type='cuda', index=0)

In [7]:
def reinforce_rollout(
    policy: "GymPolicy",
    env_name: Optional[str] = None,
    env=None,
    env_creation_fn: Optional[Callable] = None,
):
    """Run one full episode with ``policy`` and record the trajectory.

    Either an existing ``env`` or an ``env_name`` must be supplied. When only
    ``env_name`` is given, the environment is created with ``env_creation_fn``
    (``gym.make`` by default) and is closed again before returning, so that
    repeated calls do not leak environments. An ``env`` passed in by the
    caller is never closed here.

    The environment is driven with the classic gym interface: ``reset()``
    returns the observation and ``step()`` returns a 4-tuple
    ``(observation, reward, done, info)``.

    Args:
        policy: Object exposing ``act(obs_tensor)`` -> ``(action, out)`` and a
            ``device`` attribute.
        env_name: Environment id passed to ``env_creation_fn``.
        env: Already-constructed environment to roll out in.
        env_creation_fn: Factory called as ``env_creation_fn(env_name)``.

    Returns:
        Tuple ``(observations, actions, rewards)`` of numpy arrays, one entry
        per step taken. Each observation is the one the action was chosen from.

    Raises:
        ValueError: If neither ``env_name`` nor ``env`` is provided.
    """
    if env_name is None and env is None:
        raise ValueError("env_name or env must be provided!")

    # Only environments created here are owned (and therefore closed) here.
    owns_env = env is None
    if owns_env:
        if env_creation_fn is None:
            env_creation_fn = gym.make
        env = env_creation_fn(env_name)

    observations, actions, rewards = [], [], []
    try:
        done = False
        observation = env.reset()
        # Rollouts only collect data; gradients are not needed at this stage.
        with torch.no_grad():
            while not done:
                # Add a leading batch dimension of one before calling the policy.
                action, out = policy.act(
                    torch.from_numpy(observation).unsqueeze(0).to(policy.device)
                )
                next_observation, reward, done, info = env.step(action)

                observations.append(observation)
                actions.append(action)
                rewards.append(reward)

                observation = next_observation
    finally:
        if owns_env:
            env.close()

    return np.array(observations), np.array(actions), np.array(rewards)


In [9]:
# Training hyper-parameters, named so they can be tuned in one place.
NUM_ITERATIONS = 50000
LEARNING_RATE = 0.001
MAX_GRAD_NORM = 1.0

bar = tqdm(np.arange(NUM_ITERATIONS))

writer = get_tensorboard_logger("ReinforceLunarLander")
optimizer = ReinforceOptimizer(policy, lr=LEARNING_RATE)

try:
    for i in bar:
        # Collect one episode.
        # NOTE(review): `optimizer.rollout` presumably invokes `reinforce_rollout`
        # with the optimizer's policy — confirm against ezrl.
        observations, actions, rewards = optimizer.rollout(reinforce_rollout, env_name = "LunarLander-v2")

        torch_observations = torch.from_numpy(observations).to(policy.device)
        torch_actions = torch.from_numpy(actions).float().to(policy.device)
        torch_rewards = torch.from_numpy(rewards).float().to(policy.device)

        optimizer.zero_grad()
        loss = optimizer.loss_fn(torch_observations, torch_actions, torch_rewards)
        loss.backward()
        # Clip only after backward(): before it the gradients are zeroed/absent,
        # so clipping there would have no effect on the update.
        torch.nn.utils.clip_grad_norm_(policy.parameters(), MAX_GRAD_NORM)
        optimizer.step()

        # metrics: log the sum of each parameter's (clipped) gradient
        grad_dict = {}
        for n, W in policy.named_parameters():
            if W.grad is not None:
                grad_dict["{}_grad".format(n)] = float(torch.sum(W.grad).item())

        loss_value = loss.item()
        sum_reward = np.sum(rewards)

        metrics_dict = {"loss":loss_value, "sum_reward":sum_reward, **grad_dict}

        for key in metrics_dict:
            writer.add_scalar(key, metrics_dict[key], i)

        bar.set_description("Loss: {}, Reward: {}".format(loss_value, sum_reward))
finally:
    # Flush and release the event file even when training is interrupted.
    writer.close()

  0%|          | 0/50000 [00:00<?, ?it/s]

Follow tensorboard logs with: tensorboard --logdir '/home/shyam/Code/ez-rl/examples/tensorboard_logs/ReinforceLunarLander_2022-02-28 19:04:00.752660'


Loss: -71.58198547363281, Reward: 122.60099044065312:   3%|▎         | 1374/50000 [06:28<3:49:16,  3.53it/s]    


KeyboardInterrupt: 

In [None]:
import ray

In [None]:
# Initialize Ray for this kernel session; the returned connection info is the cell output.
# NOTE(review): re-running this cell without restarting the kernel may raise because
# Ray is already initialized — consider `ignore_reinit_error=True`; confirm against Ray docs.
ray.init()

2022-02-28 18:03:47,512	INFO services.py:1338 -- View the Ray dashboard at [1m[32mhttp://127.0.0.1:8265[39m[22m


{'node_ip_address': '192.168.1.248',
 'raylet_ip_address': '192.168.1.248',
 'redis_address': '192.168.1.248:6379',
 'object_store_address': '/tmp/ray/session_2022-02-28_18-03-44_529855_107394/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2022-02-28_18-03-44_529855_107394/sockets/raylet',
 'webui_url': '127.0.0.1:8265',
 'session_dir': '/tmp/ray/session_2022-02-28_18-03-44_529855_107394',
 'metrics_export_port': 57252,
 'node_id': '2009b0a45221835f3a98b4378eb6e6695696a5b4643815d2ea2c09f8'}

In [None]:
from ray.util.multiprocessing import Pool
