In [3]:
import os
import sys

# Root directory that holds the local CybORG and tianshou checkouts.
# Defaults to the original machine-specific location; set the
# DRIVER_ORDER_RL_ROOT environment variable to run the notebook elsewhere.
REPO_ROOT = os.environ.get(
    "DRIVER_ORDER_RL_ROOT", "/tmp-data/zhx/DriverOrderOfflineRL"
)

for _subdir in ("cage-challenge-1/CybORG", "tianshou"):
    _dep_path = os.path.join(REPO_ROOT, _subdir)
    # Guard against duplicate entries when this cell is re-run in a live kernel.
    if _dep_path not in sys.path:
        sys.path.append(_dep_path)

### define args

In [4]:
import argparse
import datetime
import os
import pprint

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch import nn

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import DQNPolicy
from tianshou.policy.modelbased.icm import ICMPolicy
from tianshou.trainer import offpolicy_trainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
from tianshou.env import SubprocVectorEnv

from typing import Any, Dict, Optional, Sequence, Tuple, Union


def get_args():
    """Build the experiment configuration as an ``argparse.Namespace``.

    The parser is always run against an empty argument list, so the returned
    namespace contains exactly the defaults declared below (the notebook then
    edits attributes on the namespace directly).
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # Options that only need a type and a default, in declaration order.
    plain_options = (
        ("--task", str, "cyborg"),
        ("--seed", int, 0),
        ("--scale-obs", int, 0),
        ("--eps-test", float, 0.005),
        ("--eps-train", float, 1.),
        ("--eps-train-final", float, 0.05),
        ("--buffer-size", int, 100000),
        ("--lr", float, 0.0001),
        ("--gamma", float, 0.99),
        ("--n-step", int, 3),
        ("--target-update-freq", int, 500),
        ("--epoch", int, 100),
        ("--step-per-epoch", int, 1000),
        ("--step-per-collect", int, 10),
        ("--update-per-step", float, 0.1),
        ("--batch-size", int, 64),
        ("--training-num", int, 10),
        ("--test-num", int, 10),
        ("--logdir", str, "log"),
        ("--render", float, 0.),
        ("--device", str, default_device),
        ("--frames-stack", int, 1),
        ("--resume-path", str, None),
        ("--resume-id", str, None),
    )
    # Intrinsic-curiosity options: all floats, each with a help string.
    icm_options = (
        ("--icm-lr-scale", 0.,
         "use intrinsic curiosity module with this lr scale"),
        ("--icm-reward-scale", 0.01,
         "scaling factor for intrinsic curiosity reward"),
        ("--icm-forward-loss-weight", 0.2,
         "weight for the forward model loss in ICM"),
    )

    ap = argparse.ArgumentParser()
    for flag, value_type, default_value in plain_options:
        ap.add_argument(flag, type=value_type, default=default_value)

    ap.add_argument(
        "--logger", type=str, default="wandb", choices=["tensorboard", "wandb"]
    )
    ap.add_argument("--wandb-project", type=str, default="cyborg.dqn")
    ap.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only"
    )
    ap.add_argument("--save-buffer-name", type=str, default=None)

    for flag, default_value, help_text in icm_options:
        ap.add_argument(flag, type=float, default=default_value, help=help_text)

    # Empty argv: ignore the kernel's own command line.
    return ap.parse_args(args=[])


### env

In [5]:
# Run configuration shared by every later cell. get_args() parses an empty
# argument list, so these are the declared defaults; override by assigning
# attributes on `args` after this cell.
args = get_args()

  return torch._C._cuda_getDeviceCount() > 0


In [6]:
import inspect
from pprint import pprint
from CybORG import CybORG
from CybORG.Shared.Actions import *
from CybORG.Agents import RedMeanderAgent, B_lineAgent
from CybORG.Agents.Wrappers import *

# Scenario file shipped next to the module that defines the CybORG class.
# Resolved with os.path helpers rather than by slicing a fixed number of
# characters off the module path.
path = os.path.join(
    os.path.dirname(str(inspect.getfile(CybORG))),
    'Shared', 'Scenarios', 'Scenario1b.yaml'
)

# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Number of parallel worker environments for training and evaluation.
N_TRAIN_ENVS = 5
N_TEST_ENVS = 5


def make_env():
    """Build one Blue-agent ChallengeWrapper around its own CybORG simulation.

    The Red side is played by B_lineAgent and episodes are capped at
    ``args.step_per_epoch`` steps.
    """
    cyborg = CybORG(path, 'sim', agents={'Red': B_lineAgent})
    return ChallengeWrapper(
        env=cyborg, agent_name="Blue", max_steps=args.step_per_epoch
    )


# env
# `CYBORG` / `env` live in the notebook process and are used to read the
# observation and action spaces; the vector envs build their own instances
# through make_env so that no simulator object is shared between workers.
CYBORG = CybORG(path, 'sim', agents={'Red': B_lineAgent})
env = ChallengeWrapper(env=CYBORG, agent_name="Blue", max_steps=args.step_per_epoch)
train_envs = SubprocVectorEnv([make_env for _ in range(N_TRAIN_ENVS)])
test_envs = SubprocVectorEnv([make_env for _ in range(N_TEST_ENVS)])

In [7]:
# Record the observation / action dimensions on the config: use the space's
# `shape` when it is non-empty, otherwise fall back to its `n`.
obs_space = env.observation_space
act_space = env.action_space
args.state_shape = obs_space.shape or obs_space.n
args.action_shape = act_space.shape or act_space.n
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)

Observations shape: (52,)
Actions shape: 54


### network

In [8]:
class D5QN(nn.Module):
    """Small fully connected Q-network for flat (vector) observations.

    The default network is ``Linear(state, 512) -> ReLU -> Linear(512, actions)``
    and maps an observation to one Q-value per action.

    :param state_shape: observation size, as an int or a shape whose product
        is the flat input width.
    :param action_shape: number of actions, as an int or a shape whose product
        is the number of Q-values.
    :param device: device the input observations are moved to in ``forward``.
    :param features_only: if True, omit the Q-value head and return the
        512-wide hidden features (or ``output_dim`` features, see below).
    :param output_dim: only used when ``features_only`` is True; if given, a
        ``Linear(512, output_dim) -> ReLU`` projection is appended.

    The width of the tensor returned by ``forward`` is exposed as
    ``self.output_dim``.
    """

    def __init__(
        self,
        state_shape: Union[int, Sequence[int]],
        action_shape: Union[int, Sequence[int]],
        device: Union[str, int, torch.device] = "cpu",
        features_only: bool = False,
        output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.device = device
        hidden_dim = 512
        in_dim = int(np.prod(state_shape))
        # Layers are collected in one flat Sequential so that parameter names
        # of the default network stay `net.0.*` / `net.2.*`.
        layers = [nn.Linear(in_dim, hidden_dim), nn.ReLU(inplace=True)]
        self.output_dim = hidden_dim
        if not features_only:
            self.output_dim = int(np.prod(action_shape))
            layers.append(nn.Linear(hidden_dim, self.output_dim))
        elif output_dim is not None:
            layers.extend(
                [nn.Linear(hidden_dim, output_dim), nn.ReLU(inplace=True)]
            )
            self.output_dim = output_dim
        self.net = nn.Sequential(*layers)

    def forward(
        self,
        obs: Union[np.ndarray, torch.Tensor],
        state: Optional[Any] = None,
        info: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Any]:
        r"""Mapping: s -> Q(s, \*).

        ``obs`` is converted to a float32 tensor on ``self.device``; ``state``
        is passed through unchanged and ``info`` is accepted but not used.
        """
        obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
        return self.net(obs), state

### define policy

In [9]:
# Q-network and its optimizer
net = D5QN(args.state_shape[0], args.action_shape, args.device)
net = net.to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)

# DQN policy built on that network
policy = DQNPolicy(
    model=net,
    optim=optim,
    discount_factor=args.gamma,
    estimation_step=args.n_step,
    target_update_freq=args.target_update_freq,
)

In [10]:
# Optionally wrap the DQN policy with an Intrinsic Curiosity Module.
if args.icm_lr_scale > 0:
    # The Q-network class in this notebook is D5QN; the former reference to
    # `DQN` was an undefined name and raised NameError whenever ICM was enabled.
    feature_net = D5QN(
        args.state_shape[0], args.action_shape, args.device, features_only=True
    )
    action_dim = np.prod(args.action_shape)
    # Width of the feature vector fed to the ICM heads. Prefer the network's
    # own `output_dim`; otherwise read it off the last Linear layer.
    feature_dim = getattr(feature_net, "output_dim", None)
    if feature_dim is None:
        linear_layers = [
            layer for layer in feature_net.net if isinstance(layer, nn.Linear)
        ]
        feature_dim = linear_layers[-1].out_features
    icm_net = IntrinsicCuriosityModule(
        feature_net.net,
        feature_dim,
        action_dim,
        hidden_sizes=[512],
        device=args.device
    )
    icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
    policy = ICMPolicy(
        policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
        args.icm_forward_loss_weight
    ).to(args.device)

### load policy

In [11]:
# load a previous policy
if args.resume_path:
    # NOTE(review): torch.load unpickles the file; only point --resume-path at
    # files you trust.
    loaded_state = torch.load(args.resume_path, map_location=args.device)
    # Two on-disk layouts exist in this notebook: save_best_fn stores the bare
    # state dict, while save_checkpoint_fn wraps it as {"model": state_dict}.
    # Accept both so either file can be resumed from.
    if (
        isinstance(loaded_state, dict)
        and isinstance(loaded_state.get("model"), dict)
    ):
        loaded_state = loaded_state["model"]
    policy.load_state_dict(loaded_state)
    print("Loaded agent from: ", args.resume_path)

### replay buffer

In [12]:
# Replay buffer with one sub-buffer per training env. `obs_next` is not stored
# separately (ignore_obs_next=True) and full observations are kept
# (save_only_last_obs=False).
buffer_options = dict(
    ignore_obs_next=True,
    save_only_last_obs=False,
    stack_num=args.frames_stack,
)
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs), **buffer_options)

### collector

In [13]:
# Collectors: the training one writes into the replay buffer, the test one
# only rolls out episodes. Both request exploration noise from the policy.
train_collector = Collector(
    policy=policy, env=train_envs, buffer=buffer, exploration_noise=True
)
test_collector = Collector(policy=policy, env=test_envs, exploration_noise=True)

### log

In [14]:
# Run naming: <task>/<algo>/<seed>/<timestamp>, placed under args.logdir.
if args.icm_lr_scale > 0:
    args.algo_name = "dqn_icm"
else:
    args.algo_name = "dqn"
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
name_parts = (args.task, args.algo_name, str(args.seed), now)
log_name = os.path.join(*name_parts)
log_path = os.path.join(args.logdir, log_name)

In [15]:
# logger
# Build the experiment logger selected by args.logger. For wandb, the
# WandbLogger is created first and the TensorBoard writer is attached to it
# afterwards via logger.load(); for tensorboard, the writer is wrapped directly.
# NOTE(review): the wandb-before-SummaryWriter order looks deliberate
# (presumably so wandb can hook TensorBoard) — confirm before reordering.
if args.logger == "wandb":
    logger = WandbLogger(
        save_interval=1,
        # wandb run name: log_name with path separators replaced by "__"
        name=log_name.replace(os.path.sep, "__"),
        run_id=args.resume_id,
        config=args,
        project=args.wandb_project,
    )
# TensorBoard event files go under log_path; the full config is stored as text.
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
    logger = TensorboardLogger(writer)
else:  # wandb
    logger.load(writer)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhongxi[0m. Use [1m`wandb login --relogin`[0m to force relogin
/usr/bin/nvidia-modprobe: unrecognized option: "-s"

ERROR: Invalid commandline, please run `/usr/bin/nvidia-modprobe --help` for usage information.


/usr/bin/nvidia-modprobe: unrecognized option: "-s"

ERROR: Invalid commandline, please run `/usr/bin/nvidia-modprobe --help` for usage information.


  from IPython.core.display import HTML, display  # type: ignore


### train helper functions

In [16]:
def save_best_fn(policy):
    """Write the given policy's state dict to ``<log_path>/policy.pth``."""
    best_policy_path = os.path.join(log_path, "policy.pth")
    torch.save(policy.state_dict(), best_policy_path)

def stop_fn(mean_rewards):
    """Return True once the mean test reward reaches the stop threshold (20)."""
    reward_threshold = 20
    return mean_rewards >= reward_threshold

def train_fn(epoch, env_step):
    """Set the training epsilon for the current environment step.

    Epsilon moves linearly from ``args.eps_train`` to ``args.eps_train_final``
    over the first 1e6 environment steps and stays at the final value after
    that. The value is also written to the logger every 1000 steps.
    """
    if env_step <= 1e6:
        # fraction of the decay window already consumed
        progress = env_step / 1e6
        eps = args.eps_train - progress * \
            (args.eps_train - args.eps_train_final)
    else:
        eps = args.eps_train_final
    policy.set_eps(eps)
    on_log_boundary = env_step % 1000 == 0
    if on_log_boundary:
        logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
    """Switch the policy to the evaluation epsilon before testing."""
    eval_eps = args.eps_test
    policy.set_eps(eval_eps)

def save_checkpoint_fn(epoch, env_step, gradient_step):
    """Save the policy state dict under the ``"model"`` key and return the path.

    The file is ``<log_path>/checkpoint.pth``.
    See also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
    """
    checkpoint = {"model": policy.state_dict()}
    ckpt_path = os.path.join(log_path, "checkpoint.pth")
    torch.save(checkpoint, ckpt_path)
    return ckpt_path

# watch agent's performance
def watch():
    """Evaluate the current policy, or roll it out to record a replay buffer.

    The policy is put in eval mode with epsilon ``args.eps_test`` and the test
    envs are seeded with ``args.seed``. Then:

    * if ``args.save_buffer_name`` is set, ``args.buffer_size`` steps are
      collected from the test envs into a fresh buffer, which is saved to that
      path in HDF5 format;
    * otherwise ``args.test_num`` episodes are run with the test collector.

    In both cases the mean episode reward is printed.
    """
    print("Setup test envs ...")
    policy.eval()
    policy.set_eps(args.eps_test)
    test_envs.seed(args.seed)
    if args.save_buffer_name:
        print(f"Generate buffer with size {args.buffer_size}")
        buffer = VectorReplayBuffer(
            args.buffer_size,
            buffer_num=len(test_envs),
            ignore_obs_next=True,
            # Observations here are flat vectors, not stacked image frames, so
            # the whole observation must be stored (same setting as the
            # training buffer). Keeping only the "last obs" would truncate it.
            save_only_last_obs=False,
            stack_num=args.frames_stack
        )
        collector = Collector(policy, test_envs, buffer, exploration_noise=True)
        result = collector.collect(n_step=args.buffer_size)
        print(f"Save buffer into {args.save_buffer_name}")
        # Unfortunately, pickle will cause oom with 1M buffer size
        buffer.save_hdf5(args.save_buffer_name)
    else:
        print("Testing agent ...")
        test_collector.reset()
        result = test_collector.collect(
            n_episode=args.test_num, render=args.render
        )
    rew = result["rews"].mean()
    print(f"Mean reward (over {result['n/ep']} episodes): {rew}")

In [17]:
### train core

In [None]:
# Imported under its own name: `pprint` in this notebook is rebound to the
# pprint *function* by an earlier `from pprint import pprint`, so the
# module-style call `pprint.pprint(...)` fails with AttributeError.
from pprint import pformat

if args.watch:
    # Evaluation only. No exit() here: in a notebook it would shut down the
    # kernel instead of just skipping training.
    watch()
else:
    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(
        policy,
        train_collector,
        test_collector,
        args.epoch,
        args.step_per_epoch,
        args.step_per_collect,
        args.test_num,
        args.batch_size,
        train_fn=train_fn,
        test_fn=test_fn,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
        update_per_step=args.update_per_step,
        test_in_train=False,
        resume_from_log=args.resume_id is not None,
        save_checkpoint_fn=save_checkpoint_fn,
    )

    # Pretty-print the trainer's summary, then evaluate the trained policy.
    print(pformat(result))
    watch()

Epoch #1: 1001it [00:08, 113.63it/s, env_step=1000, len=0, loss=1038.091, n/ep=0, n/st=10, rew=0.00]                          


Epoch #1: test_reward: -10505.610000 ± 3671.260844, best_reward: -10505.610000 ± 3671.260844 in #1


Epoch #2: 1001it [00:10, 96.79it/s, env_step=2000, len=0, loss=544.229, n/ep=0, n/st=10, rew=0.00]                          


Epoch #2: test_reward: -3060.630000 ± 275.558564, best_reward: -3060.630000 ± 275.558564 in #2


Epoch #3: 1001it [00:09, 102.96it/s, env_step=3000, len=0, loss=394.813, n/ep=0, n/st=10, rew=0.00]                          


Epoch #3: test_reward: -2810.960000 ± 374.251932, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #4: 1001it [00:08, 122.72it/s, env_step=4000, len=0, loss=298.952, n/ep=0, n/st=10, rew=0.00]                          


Epoch #4: test_reward: -11596.260000 ± 2436.488336, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #5: 1001it [00:08, 116.97it/s, env_step=5000, len=1000, loss=258.702, n/ep=0, n/st=10, rew=-3343.20]                         


Epoch #5: test_reward: -10887.510000 ± 3170.688636, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #6: 1001it [00:09, 104.26it/s, env_step=6000, len=1000, loss=225.124, n/ep=0, n/st=10, rew=-3343.20]                          


Epoch #6: test_reward: -11864.050000 ± 2894.592107, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #7: 1001it [00:09, 105.09it/s, env_step=7000, len=1000, loss=189.323, n/ep=0, n/st=10, rew=-3343.20]                         


Epoch #7: test_reward: -3036.060000 ± 100.565195, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #8: 1001it [00:08, 114.44it/s, env_step=8000, len=1000, loss=170.587, n/ep=0, n/st=10, rew=-3343.20]                          


Epoch #8: test_reward: -2963.760000 ± 275.327373, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #9: 1001it [00:10, 96.25it/s, env_step=9000, len=1000, loss=166.237, n/ep=0, n/st=10, rew=-3343.20]                           


Epoch #9: test_reward: -2886.690000 ± 349.444990, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #10: 1001it [00:09, 106.01it/s, env_step=10000, len=1000, loss=161.529, n/ep=0, n/st=10, rew=-1814.68]                          


Epoch #10: test_reward: -2992.560000 ± 507.831910, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #11: 1001it [00:09, 101.49it/s, env_step=11000, len=1000, loss=154.707, n/ep=0, n/st=10, rew=-1814.68]                         


Epoch #11: test_reward: -3007.920000 ± 117.930970, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #12: 1001it [00:10, 92.93it/s, env_step=12000, len=1000, loss=138.595, n/ep=0, n/st=10, rew=-1814.68]                           


Epoch #12: test_reward: -2987.580000 ± 178.094507, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #13: 1001it [00:08, 114.27it/s, env_step=13000, len=1000, loss=131.553, n/ep=0, n/st=10, rew=-1814.68]                          


Epoch #13: test_reward: -2879.090000 ± 266.559935, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #14: 1001it [00:09, 108.75it/s, env_step=14000, len=1000, loss=126.611, n/ep=0, n/st=10, rew=-1814.68]                         


Epoch #14: test_reward: -2916.100000 ± 190.750177, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #15: 1001it [00:08, 112.65it/s, env_step=15000, len=1000, loss=117.218, n/ep=0, n/st=10, rew=-2084.18]                          


Epoch #15: test_reward: -12739.770000 ± 150.972322, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #16: 1001it [00:10, 92.75it/s, env_step=16000, len=1000, loss=120.056, n/ep=0, n/st=10, rew=-2084.18]                          


Epoch #16: test_reward: -2849.810000 ± 260.233862, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #17: 1001it [00:09, 111.12it/s, env_step=17000, len=1000, loss=115.304, n/ep=0, n/st=10, rew=-2084.18]                         


Epoch #17: test_reward: -3002.750000 ± 72.922661, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #18: 1001it [00:08, 118.83it/s, env_step=18000, len=1000, loss=114.134, n/ep=0, n/st=10, rew=-2084.18]                          


Epoch #18: test_reward: -12756.370000 ± 315.103793, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #19: 1001it [00:11, 90.27it/s, env_step=19000, len=1000, loss=106.817, n/ep=0, n/st=10, rew=-2084.18]                           


Epoch #19: test_reward: -10766.970000 ± 3539.879850, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #20: 1001it [00:10, 92.69it/s, env_step=20000, len=1000, loss=100.896, n/ep=0, n/st=10, rew=-2125.10]                          


Epoch #20: test_reward: -11705.300000 ± 3194.149124, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #21: 1001it [00:09, 104.05it/s, env_step=21000, len=1000, loss=112.520, n/ep=0, n/st=10, rew=-2125.10]                          


Epoch #21: test_reward: -2907.550000 ± 294.739761, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #22: 1001it [00:09, 102.43it/s, env_step=22000, len=1000, loss=111.195, n/ep=0, n/st=10, rew=-2125.10]                         


Epoch #22: test_reward: -3067.450000 ± 13.904478, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #23: 1001it [00:10, 98.83it/s, env_step=23000, len=1000, loss=108.758, n/ep=0, n/st=10, rew=-2125.10]                          


Epoch #23: test_reward: -2925.170000 ± 290.167479, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #24: 1001it [00:09, 108.15it/s, env_step=24000, len=1000, loss=107.511, n/ep=0, n/st=10, rew=-2125.10]                          


Epoch #24: test_reward: -2954.690000 ± 267.274407, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #25: 1001it [00:08, 112.28it/s, env_step=25000, len=1000, loss=106.597, n/ep=0, n/st=10, rew=-2230.28]                          


Epoch #25: test_reward: -3004.070000 ± 131.749824, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #26: 1001it [00:08, 111.82it/s, env_step=26000, len=1000, loss=91.303, n/ep=0, n/st=10, rew=-2230.28]                          


Epoch #26: test_reward: -3022.770000 ± 124.248791, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #27: 1001it [00:10, 96.91it/s, env_step=27000, len=1000, loss=95.131, n/ep=0, n/st=10, rew=-2230.28]                           


Epoch #27: test_reward: -3068.150000 ± 4.510488, best_reward: -2810.960000 ± 374.251932 in #3


Epoch #28: 1001it [00:10, 96.04it/s, env_step=28000, len=1000, loss=94.454, n/ep=0, n/st=10, rew=-2230.28]                           
