## IMAPLA

## Models

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.double

In [2]:
class MlpPolicy(nn.Module):
    def __init__(self, obs_dim: int, action_dim: int, hidden_dim: int):
        super(MlpPolicy, self).__init__()
        self.model = (
            nn.Sequential(
                nn.Linear(obs_dim, hidden_dim),
                nn.Dropout(p=0.8),
                nn.ReLU(),
                nn.Linear(hidden_dim, action_dim),
            )
            .to(device)
            .to(dtype)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.model(x)
        return logits

    def select_action(self, obs: torch.Tensor, deterministic: bool = False):
        logits = self.forward(obs)
        if deterministic:
            action = torch.argmax(logits)
        else:
            action = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

        return action, logits


class MlpValueFn(nn.Module):
    def __init__(self, obs_dim: int, hidden_dim: int):
        super(MlpValueFn, self).__init__()
        self.model = (
            nn.Sequential(
                nn.Linear(obs_dim, hidden_dim),
                nn.Dropout(p=0.8),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),
            )
            .to(device)
            .to(dtype)
        )

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        return self.model(observation)


## Utils

In [3]:
import datetime
from collections import namedtuple
from pathlib import Path
from typing import List, Union

import gym
# import pybullet_envs  # noqa: F401
import numpy as np
import torch
import torch.multiprocessing as mp


In [4]:
Hyperparameters = namedtuple(
    "Hyperparameters",
    [
        "max_updates",
        "policy_hidden_dims",
        "value_fn_hidden_dims",
        "batch_size",
        "gamma",
        "rho_bar",
        "c_bar",
        # "policy_lr",
        # "value_fn_lr",
        "lr",
        "policy_loss_c",
        "v_loss_c",
        "entropy_c",
        "max_timesteps",
        "queue_lim",
        "max_norm",
        "n_actors",
        "env_name",
        "log_path",
        "save_every",
        "eval_every",
        "eval_eps",
        "verbose",
        "render",
    ],
)


class Trajectory:
    def __init__(
        self,
        id: int,
        observations: List[torch.Tensor] = [],
        actions: List[torch.Tensor] = [],
        rewards: List[torch.Tensor] = [],
        dones: List[torch.Tensor] = [],
        logits: List[torch.Tensor] = [],
    ):
        self.id = id
        self.obs = observations
        self.a = actions
        self.r = rewards
        self.d = dones
        self.logits = logits

    def add(
        self,
        obs: torch.Tensor,
        a: torch.Tensor,
        r: torch.Tensor,
        d: torch.Tensor,
        logits: torch.Tensor,
    ):
        self.obs.append(obs)
        self.a.append(a)
        self.r.append(r)
        self.d.append(d)
        self.logits.append(logits)


class Counter:
    def __init__(self, init_val: int = 0):
        self._val = mp.RawValue("i", init_val)
        self._lock = mp.Lock()

    def increment(self):
        with self._lock:
            self._val.value += 1

    @property
    def value(self):
        with self._lock:
            return self._val.value


def make_env(env_name: str):
    if "Bullet" in env_name:
        try:
            env = gym.make(env_name, isDiscrete=True)
        except TypeError:
            env = gym.make(env_name)
    else:
        env = gym.make(env_name)

    if env.action_space.__class__.__name__ != "Discrete":
        raise NotImplementedError("Continuous environments not supported yet")

    return env


def test_policy(
    policy: MlpPolicy,
    env: Union[gym.Env, str],
    episodes: int,
    deterministic: bool,
    max_episode_len: int,
    log_dir: Union[str, None] = None,
    verbose: bool = False,
):
    start_time = datetime.datetime.now()
    start_text = f"Started testing at {start_time:%d-%m-%Y %H:%M:%S}\n"

    if type(env) == str:
        env = make_env(env)

    if log_dir is not None:
        Path(log_dir).mkdir(parents=True, exist_ok=True)
        fpath = Path(log_dir).joinpath(f"test_log_{start_time:%d%m%Y%H%M%S}.txt")
        fpath.write_text(start_text)
    if verbose:
        print(start_text)
    policy.eval()
    rewards = []
    for e in range(episodes):
        obs = env.reset()
        obs = torch.tensor(obs, device=device, dtype=dtype)
        d = False
        ep_rewards = []
        for t in range(max_episode_len):
            action, _ = policy.select_action(obs, deterministic)
            obs, r, d, _ = env.step(action.item())
            obs = torch.tensor(obs, device=device, dtype=dtype)
            ep_rewards.append(r)
            if d:
                break
        rewards.append(sum(ep_rewards))
        ep_text = f"Episode {e+1}: Reward = {rewards[-1]:.2f}\n"
        if log_dir is not None:
            with open(fpath, mode="a") as f:
                f.write(ep_text)
        if verbose:
            print(ep_text)
    avg_reward = np.mean(rewards)
    std_dev = np.std(rewards)
    complete_text = (
        f"-----\n"
        f"Testing completed in "
        f"{(datetime.datetime.now() - start_time).seconds} seconds\n"
        f"Average Reward per episode: {avg_reward}"
    )
    if verbose:
        print(complete_text)
    if log_dir is not None:
        with open(fpath, mode="a") as f:
            f.write(complete_text)

    return avg_reward, std_dev


## learner

In [5]:
import queue
from pathlib import Path
from typing import Union

import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.double


class Learner:
    def __init__(
        self,
        id: int,
        hparams: Hyperparameters,
        policy: MlpPolicy,
        value_fn: MlpValueFn,
        q: mp.Queue,
        update_counter: Counter,
        log_path: Union[str, Path, None] = None,
        timeout=200,
    ):
        self.id = id
        self.hp = hparams
        self.policy = policy
        self.value_fn = value_fn
        # self.policy_optimizer = torch.optim.Adam(
        #     self.policy.parameters(), lr=self.hp.policy_lr
        # )
        # self.value_fn_optimizer = torch.optim.Adam(
        #     self.value_fn.parameters(), lr=self.hp.value_fn_lr
        # )
        self.optimizer = torch.optim.Adam(
            [*self.policy.parameters(), *self.value_fn.parameters()], lr=self.hp.lr
        )
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lambda epoch: 0.95)
        self.timeout = timeout
        self.q = q
        self.update_counter = update_counter
        self.log_path = log_path
        if self.log_path is not None:
            self.log_path = Path(log_path) / Path(f"l{self.id}")
            self.log_path.mkdir(parents=True, exist_ok=False)

        self.completion = mp.Event()
        self.p = mp.Process(target=self._learn, name=f"learner_{self.id}")
        print(f"[main] learner_{self.id} Initialized")

    def start(self):
        self.completion.clear()
        self.p.start()
        print(f"[main] Started learner_{self.id} with pid {self.p.pid}")

    def terminate(self):
        self.p.terminate()
        print(f"[main] Terminated learner_{self.id}")

    def join(self):
        self.p.join()

    def _learn(self):
        try:
            update_count = 0

            if self.log_path is not None:
                writer = SummaryWriter(self.log_path)
                writer.add_text("hyperparameters", f"{self.hp}")

            while update_count < self.hp.max_updates:

                if self.hp.verbose >= 2:
                    print(f"[learner_{self.id}] Beginning Update_{update_count + 1}")

                # set up tracking variables
                traj_count = 0
                value_fn_loss = 0.0
                policy_loss = 0.0
                policy_entropy = 0.0
                loss = torch.zeros(1, device=device, dtype=dtype, requires_grad=True)
                reward = 0.0

                # process batch of trajectories
                while traj_count < self.hp.batch_size:
                    try:
                        traj = self.q.get(timeout=self.timeout)
                    except queue.Empty as e:
                        print(
                            f"[learner_{self.id}] No trajectory recieved for {self.timeout}"
                            f" seconds. Exiting!"
                        )
                        if self.log_path is not None:
                            writer.close()
                        self.completion.set()
                        raise e

                    if self.hp.verbose >= 2:
                        print(f"[learner_{self.id}] Processing traj_{traj.id}")
                    traj_len = len(traj.r)
                    obs = torch.stack(traj.obs)
                    actions = torch.stack(traj.a)
                    r = torch.stack(traj.r)
                    reward += torch.sum(r).item() / self.hp.batch_size
                    disc = self.hp.gamma * (~torch.stack(traj.d))

                    # compute value estimates and logits for observed states
                    v = self.value_fn(obs).squeeze(1)
                    curr_logits = self.policy(obs[:-1])

                    # compute log probs for current and old policies
                    curr_log_probs = action_log_probs(curr_logits, actions)
                    traj_log_probs = action_log_probs(torch.stack(traj.logits), actions)

                    # computing v trace targets recursively
                    with torch.no_grad():
                        imp_sampling = torch.exp(
                            curr_log_probs - traj_log_probs
                        ).squeeze(1)
                        rho = torch.clamp(imp_sampling, max=self.hp.rho_bar)
                        c = torch.clamp(imp_sampling, max=self.hp.c_bar)
                        delta = rho * (r + self.hp.gamma * v[1:] - v[:1])
                        vt = torch.zeros(traj_len + 1, device=device, dtype=dtype)

                        for i in range(traj_len - 1, -1, -1):
                            vt[i] = delta[i] + disc[i] * c[i] * (vt[i + 1] - v[i + 1])
                        vt = torch.add(vt, v)

                        # vt = (vt - torch.mean(vt)) / torch.std(vt)

                        pg_adv = rho * (r + disc * vt[1:] - v[:-1])

                    # print(f"v: {v}")
                    # print(f"vt: {vt}")
                    # print(f"pg_adv: {pg_adv}")
                    # print(f"rho: {rho}")

                    # compute loss as sum of value loss, policy loss and entropy
                    # traj_value_fn_loss = 0.5 * torch.sum(torch.pow(v - vt, 2))
                    # traj_policy_loss = torch.sum(curr_log_probs * pg_adv.detach())
                    # traj_policy_entropy = -1 * torch.sum(
                    #     F.softmax(curr_logits, dim=-1)
                    #     * F.log_softmax(curr_logits, dim=-1)
                    # )
                    traj_value_fn_loss = compute_baseline_loss(v - vt)
                    traj_policy_loss = compute_policy_gradient_loss(
                        curr_logits, actions, pg_adv
                    )
                    traj_policy_entropy = -1 * compute_entropy_loss(curr_logits)
                    traj_loss = (
                        self.hp.v_loss_c * traj_value_fn_loss
                        + self.hp.policy_loss_c * traj_policy_loss
                        - self.hp.entropy_c * traj_policy_entropy
                    )
                    loss = torch.add(loss, traj_loss / self.hp.batch_size)
                    value_fn_loss += traj_value_fn_loss.item() / self.hp.batch_size
                    policy_loss += traj_policy_loss.item() / self.hp.batch_size
                    policy_entropy += traj_policy_entropy.item() / self.hp.batch_size
                    traj_count += 1

                if self.hp.verbose >= 2:
                    print(
                        f"[learner_{self.id}] Updating model weights "
                        f" for Update {update_count + 1}"
                    )

                # backpropogating loss and updating weights
                # self.policy_optimizer.zero_grad()
                # self.value_fn_optimizer.zero_grad()
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    self.policy.parameters(), self.hp.max_norm
                )
                torch.nn.utils.clip_grad_norm_(
                    self.value_fn.parameters(), self.hp.max_norm
                )
                self.optimizer.step()
                self.scheduler.step()
                # self.policy_optimizer.step()
                # self.value_fn_optimizer.step()

                # log to console
                if self.hp.verbose >= 1:
                    print(
                        f"[learner_{self.id}] Update {update_count + 1} | "
                        f"Batch Mean Reward: {reward:.2f} | Loss: {loss.item():.2f}"
                    )

                # evaluate current policy
                if self.hp.eval_every is not None:
                    if (update_count + 1) % self.hp.eval_every == 0:
                        eval_r, eval_std = test_policy(
                            self.policy,
                            self.hp.env_name,
                            self.hp.eval_eps,
                            True,
                            self.hp.max_timesteps,
                        )
                        if self.hp.verbose >= 1:
                            print(
                                f"[learner_{self.id}] Update {update_count + 1} | "
                                f"Evaluation Reward: {eval_r:.2f}, Std Dev: {eval_std:.2f}"
                            )
                        if self.log_path is not None:
                            writer.add_scalar(
                                f"learner_{self.id}/rewards/evaluation_reward",
                                eval_r,
                                update_count + 1,
                            )

                # log to tensorboard
                if self.log_path is not None:
                    writer.add_scalar(
                        f"learner_{self.id}/rewards/batch_mean_reward",
                        reward,
                        update_count + 1,
                    )
                    writer.add_scalar(
                        f"learner_{self.id}/loss/policy_loss",
                        policy_loss,
                        update_count + 1,
                    )
                    writer.add_scalar(
                        f"learner_{self.id}/loss/value_fn_loss",
                        value_fn_loss,
                        update_count + 1,
                    )
                    writer.add_scalar(
                        f"learner_{self.id}/loss/policy_entropy",
                        policy_entropy,
                        update_count + 1,
                    )
                    writer.add_scalar(
                        f"learner_{self.id}/loss/total_loss", loss, update_count + 1
                    )

                # save model weights every given interval
                if (update_count + 1) % self.hp.save_every == 0:
                    path = self.log_path / Path(
                        f"IMPALA_{self.hp.env_name}_l{self.id}_{update_count+1}.pt"
                    )
                    self.save(path)
                    print(
                        f"[learner_{self.id}] Saved model weights at "
                        f"update {update_count+1} to {path}"
                    )

                # increment update counter
                self.update_counter.increment()
                update_count = self.update_counter.value

            if self.log_path is not None:
                writer.close()

            print(f"[learner_{self.id}] Finished learning")
            self.completion.set()
            return

        except KeyboardInterrupt:
            print(f"[learner_{self.id}] Interrupted")
            if self.log_path is not None:
                writer.close()
            self.completion.set()
            return

        except Exception as e:
            if self.log_path is not None:
                writer.close()
            print(f"[learner_{self.id}] Encoutered exception")
            raise e

    def save(self, path):
        """ Save model parameters """
        torch.save(
            {
                "policy_state_dict": self.policy.state_dict(),
                "value_fn_state_dict": self.value_fn.state_dict(),
            },
            path,
        )

    def load(self, path):
        """ Load model parameters """
        checkpoint = torch.load(path)
        self.policy.load_state_dict(checkpoint["policy_state_dict"])
        self.value_fn.load_state_dict(checkpoint["value_fn_state_dict"])

    @property
    def policy_weights(self) -> torch.Tensor:
        return self.policy.state_dict()


def action_log_probs(policy_logits, actions):
    return -F.nll_loss(
        F.log_softmax(policy_logits, dim=-1),
        target=torch.flatten(actions),
        reduction="none",
    ).view_as(actions)


def compute_baseline_loss(advantages):
    return 0.5 * torch.sum(advantages ** 2)


def compute_entropy_loss(logits):
    """Return the entropy loss, i.e., the negative entropy of the policy."""
    policy = F.softmax(logits, dim=-1)
    log_policy = F.log_softmax(logits, dim=-1)
    return torch.sum(policy * log_policy)


def compute_policy_gradient_loss(logits, actions, advantages):
    cross_entropy = F.nll_loss(
        F.log_softmax(logits, dim=-1), target=torch.flatten(actions), reduction="none",
    ).view_as(advantages)
    return torch.sum(cross_entropy * advantages.detach())



## Actor

In [6]:
import queue
from pathlib import Path
from typing import Union

import torch
import torch.multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter

class Actor:
    def __init__(
        self,
        id: int,
        hparams: Hyperparameters,
        policy: MlpPolicy,
        learner: Learner,
        q: mp.Queue,
        update_counter: Counter,
        log_path: Union[Path, str, None] = None,
        timeout=10,
    ):
        self.id = id
        self.hp = hparams
        self.policy = policy
        for p in self.policy.parameters():
            p.requires_grad = False
        self.learner = learner
        self.timeout = timeout
        self.q = q
        self.update_counter = update_counter
        self.log_path = log_path
        if self.log_path is not None:
            self.log_path = Path(self.log_path) / Path(f"a{self.id}")
            self.log_path.mkdir(parents=True, exist_ok=False)

        self.completion = mp.Event()
        self.p = mp.Process(target=self._act, name=f"actor_{self.id}")
        print(f"[main] actor_{self.id} Initialized")

    def start(self):
        self.p.start()
        print(f"[main] Started actor_{self.id} with pid {self.p.pid}")

    def terminate(self):
        self.p.terminate()
        print(f"[main] Terminated actor_{self.id}")

    def join(self):
        self.p.join()

    def _act(self):
        try:

            if self.log_path is not None:
                writer = SummaryWriter(self.log_path)
                writer.add_text("hyperparameters", f"{self.hp}")

            env = make_env(self.hp.env_name)
            traj_no = 0

            while not self.learner.completion.is_set():
                traj_no += 1
                self.policy.load_state_dict(self.learner.policy_weights)
                traj_id = (self.id, traj_no)
                traj = Trajectory(traj_id, [], [], [], [], [])
                obs = env.reset()
                obs = torch.tensor(obs, device=device, dtype=dtype)
                traj.obs.append(obs)
                c = 0

                if self.hp.verbose >= 2:
                    print(f"[actor_{self.id}] Starting traj_{traj.id}")

                # record trajectory
                while c < self.hp.max_timesteps:
                    if self.hp.render:
                        env.render()
                    c += 1
                    a, logits = self.policy.select_action(obs)
                    # print(f"[actor_{self.id}] a_probs: {a_probs}")
                    obs, r, d, _ = env.step(a.item())
                    obs = torch.tensor(obs, device=device, dtype=dtype)
                    r = torch.tensor(r, device=device, dtype=dtype)
                    d = torch.tensor(d, device=device)
                    traj.add(obs, a, r, d, logits)

                    if d:
                        break

                if self.hp.verbose >= 2:
                    print(
                        f"[actor_{self.id}] traj_{traj.id} completed Reward = {sum(traj.r)}"
                    )
                if self.log_path is not None:
                    # action_one_hot = torch.zeros(env.action_space.n)
                    # action_one_hot[a] += 1
                    writer.add_histogram(
                        f"actor_{self.id}/actions/action_taken", a, traj_no
                    )
                    writer.add_histogram(
                        f"actor_{self.id}/actions/logits", logits.detach(), traj_no
                    )
                    writer.add_scalar(
                        f"actor_{self.id}/rewards/trajectory_reward",
                        sum(traj.r),
                        traj_no,
                    )

                while True:
                    try:
                        self.q.put(traj, timeout=self.timeout)
                        break
                    except queue.Full:
                        if self.learner.completion.is_set():
                            break
                        else:
                            continue

            if self.log_path is not None:
                writer.close()
            env.close()
            print(f"[actor_{self.id}] Finished acting")
            self.completion.set()
            return

        except KeyboardInterrupt:
            print(f"[actor_{self.id}] interrupted")
            if self.log_path is not None:
                writer.close()
            env.close()
            self.completion.set()
            return

        except Exception as e:
            if self.log_path is not None:
                writer.close()
            env.close()
            print(f"[actor_{self.id}] encoutered exception")
            raise e


## train

In [7]:
import datetime
from pathlib import Path

import torch.multiprocessing as mp


hparams = Hyperparameters(
    max_updates=50,
    policy_hidden_dims=128,
    value_fn_hidden_dims=128,
    batch_size=32,
    gamma=0.99,
    rho_bar=1.0,
    c_bar=1.0,
    # policy_lr=1e-3,
    # value_fn_lr=1e-3,
    lr=1e-3,
    policy_loss_c=1,
    v_loss_c=0.5,
    entropy_c=0.0006,
    max_timesteps=1000,
    queue_lim=8,
    max_norm=10,
    n_actors=1,
    env_name="CartPole-v1",          # "RacecarBulletEnv-v0",
    log_path="./logs/",
    save_every=50,
    eval_every=2,
    eval_eps=20,
    verbose=1,
    render=False,
)

if __name__ == "__main__":

    start_time = datetime.datetime.now()

    mp.set_start_method("fork", force=True)

    print(f"[main] Start time: {start_time:%d-%m-%Y %H:%M:%S}")
    print(f"[main] {hparams}\n")

    if hparams.log_path is not None:
        log_path = Path(Path(hparams.log_path) / f"{start_time:%d%m%Y%H%M%S}")
        log_path.mkdir(parents=True, exist_ok=True)
        with open(Path(log_path / "hyperparameters.txt"), "w+") as f:
            f.write(f"{hparams}")
        if not hparams.save_every > 0:
            raise ValueError(
                f"save_every hyperparameter should be greater than 0, "
                f"got {hparams.save_every}"
            )
    else:
        log_path = None

    q = mp.Queue(maxsize=hparams.queue_lim)
    update_counter = Counter(init_val=0)
    env = make_env(hparams.env_name)
    observation_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    env.close()
    policy = MlpPolicy(observation_size, action_size, hparams.policy_hidden_dims)
    policy.share_memory()
    value_fn = MlpValueFn(observation_size, hparams.value_fn_hidden_dims)
    learner = Learner(1, hparams, policy, value_fn, q, update_counter, log_path)

    actors = []
    for i in range(hparams.n_actors):
        policy = MlpPolicy(observation_size, action_size, hparams.policy_hidden_dims)
        actors.append(
            Actor(i + 1, hparams, policy, learner, q, update_counter, log_path)
        )

    print("[main] Initialized")

    for a in actors:
        a.start()
    learner.start()

    learner.completion.wait()
    for a in actors:
        a.completion.wait()

    learner.terminate()
    for a in actors:
        a.terminate()

    learner.join()
    for a in actors:
        a.join()

    print(
        f"[main] Completed in {(datetime.datetime.now() - start_time).seconds} seconds"
    )


[main] Start time: 07-08-2021 15:33:37
[main] Hyperparameters(max_updates=50, policy_hidden_dims=128, value_fn_hidden_dims=128, batch_size=32, gamma=0.99, rho_bar=1.0, c_bar=1.0, lr=0.001, policy_loss_c=1, v_loss_c=0.5, entropy_c=0.0006, max_timesteps=1000, queue_lim=8, max_norm=10, n_actors=1, env_name='CartPole-v1', log_path='./logs/', save_every=50, eval_every=2, eval_eps=20, verbose=1, render=False)

[main] learner_1 Initialized
[main] actor_1 Initialized
[main] Initialized
[main] Started actor_1 with pid 6932
[main] Started learner_1 with pid 6933
[learner_1] Update 1 | Batch Mean Reward: 17.75 | Loss: 94.44
[learner_1] Update 2 | Batch Mean Reward: 18.31 | Loss: 115.99
[learner_1] Update 2 | Evaluation Reward: 9.25, Std Dev: 0.89
[learner_1] Update 3 | Batch Mean Reward: 18.97 | Loss: 142.60
[learner_1] Update 4 | Batch Mean Reward: 22.44 | Loss: 210.46
[learner_1] Update 4 | Evaluation Reward: 9.30, Std Dev: 0.64
[learner_1] Update 5 | Batch Mean Reward: 18.47 | Loss: 152.07
[le

## test

In [9]:
import argparse

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.double


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-pp", "--policy_path", type=str, required=True, help="path to policy weights"
    )
    parser.add_argument(
        "-hd",
        "--policy_hidden_dim",
        type=int,
        required=True,
        help="dimension of hidden layer",
    )
    parser.add_argument(
        "-en",
        "--env_name",
        type=str,
        required=True,
        help="name of gym envrionment to test in",
    )
    parser.add_argument(
        "-ne",
        "--num_episodes",
        type=int,
        default=10,
        help="number of episodes to test for",
    )
    parser.add_argument(
        "-d", "--deterministic", help="use deterministic policy", action="store_true"
    )
    parser.add_argument(
        "-el",
        "--max_episode_len",
        type=int,
        default=1000,
        help="maximum number of steps per episode",
    )
    parser.add_argument(
        "-ld",
        "--log_dir",
        type=str,
        required=False,
        help="directory to store log file in",
    )
    parser.add_argument(
        "-v", "--verbose", help="increase output verbosity", action="store_true"
    )

    args = parser.parse_args()
    
    env = make_env(args.env_name)
    observation_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    policy = MlpPolicy(observation_size, action_size, args.policy_hidden_dim)
    checkpoint = torch.load(args.policy_path)
    policy.load_state_dict(checkpoint["policy_state_dict"])

    test_policy(
        policy,
        env,
        args.num_episodes,
        args.deterministic,
        args.max_episode_len,
        args.log_dir,
        args.verbose,
    )

    env.close()


usage: ipykernel_launcher.py [-h] -pp POLICY_PATH -hd POLICY_HIDDEN_DIM -en
                             ENV_NAME [-ne NUM_EPISODES] [-d]
                             [-el MAX_EPISODE_LEN] [-ld LOG_DIR] [-v]
ipykernel_launcher.py: error: the following arguments are required: -pp/--policy_path, -hd/--policy_hidden_dim, -en/--env_name


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
