# Install Dependencies

In [None]:
# ! pip install rl_warp_drive
# ! pip install pytorch_lightning

In [None]:
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint

In [None]:
import argparse
import numpy as np
import time
import torch

from example_envs.tag_continuous.tag_continuous import TagContinuous
from warp_drive.env_wrapper import EnvWrapper
from warp_drive.training.lightning_trainer import WarpDriveModel
from warp_drive.training.utils.data_loader import create_and_push_data_placeholders

In [None]:
# Choose how verbose logging should be (DEBUG, INFO, WARNING or ERROR).
# The level is applied to the root logger.
import logging

logging.root.setLevel(logging.ERROR)

In [None]:
# Run configuration for this experiment.
# Note: these values override some of the defaults found in
# 'warp_drive/training/run_configs/default_configs.yaml'.
run_config = {
    "name": "tag_continuous",
    # Environment settings
    "env": {
        "num_taggers": 5,
        "num_runners": 20,
        "episode_length": 100,
        "seed": 1234,
        "use_full_observation": False,
        "num_other_agents_observed": 10,
        "tagging_distance": 0.02,
    },
    # Trainer settings
    "trainer": {
        # number of environment replicas (number of GPU blocks used)
        "num_envs": 100,
        # total batch size used for training per iteration
        # (across all the environments)
        "train_batch_size": 10000,
        # total number of episodes to train for (can be arbitrarily high!)
        "num_episodes": 5000,
    },
    # Policy network settings, one entry per policy
    "policy": {
        "runner": {
            "to_train": False,  # whether this policy's model is trained
            "algorithm": "A2C",  # training algorithm for the policy
            "gamma": 0.98,  # discount rate
            "lr": 0.005,  # learning rate
            # policy model settings
            "model": {
                "type": "fully_connected",
                "fc_dims": [256, 256],
                "model_ckpt_filepath": "",
            },
        },
        "tagger": {
            "to_train": True,
            "algorithm": "A2C",
            "gamma": 0.98,
            "lr": 0.002,
            "model": {
                "type": "fully_connected",
                "fc_dims": [256, 256],
                "model_ckpt_filepath": "",
            },
        },
    },
    # Checkpoint saving settings
    "saving": {
        # how often (in iterations) to print the metrics
        "metrics_log_freq": 10,
        # how often (in iterations) to save the model parameters
        "model_params_save_freq": 5000,
        "basedir": "/tmp",  # base folder used for saving
        "name": "continuous_tag",  # experiment name
        "tag": "example",  # experiment tag
    },
}

In [None]:
# Wrap a TagContinuous environment (built from the "env" settings) with
# the EnvWrapper. use_cuda must be True for the run to happen on the GPU.
env_wrapper = EnvWrapper(
    TagContinuous(**run_config["env"]),
    use_cuda=True,
    num_envs=run_config["trainer"]["num_envs"],
)

# Policy models can be shared between agents: map each policy model name
# to the ids of the agents that use it.
policy_tag_to_agent_id_map = dict(
    tagger=[*env_wrapper.env.taggers],
    runner=[*env_wrapper.env.runners],
)

In [None]:
# Build the argument parser: start from the Lightning Trainer arguments,
# then let WarpDriveModel register its model-specific arguments.
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser = Trainer.add_argparse_args(parent_parser)
parser = WarpDriveModel.add_model_specific_args(parent_parser)

# parse_known_args() is used in place of parse_args() so that command-line
# arguments the parser does not recognize are returned separately instead of
# raising an error (presumably to tolerate the arguments a notebook kernel
# is launched with).
args, _ = parser.parse_known_args()

# The parsed arguments are forwarded to the model as keyword arguments.
wd_model = WarpDriveModel(
    env_wrapper=env_wrapper,
    config=run_config,
    policy_tag_to_agent_id_map=policy_tag_to_agent_id_map,
    **vars(args),
)

# Keep only the single best checkpoint, ranked by the highest "avg_reward".
checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor="avg_reward",
    mode="max",
    verbose=True,
)

# Training itself is launched in a later cell via trainer.fit(wd_model).
trainer = Trainer.from_argparse_args(
    args,
    deterministic=True,
    callbacks=checkpoint_callback,
    gpus=1,
)

In [None]:
# PATCH
# Turn off PyTorch's deterministic-algorithms mode before training starts.
# NOTE(review): presumably this undoes the mode enabled by creating the
# Trainer with deterministic=True, so that CUDA ops which need extra cuBLAS
# configuration do not raise — confirm.
# Alternatively, set the environment variable for cuBLAS (presumably
# CUBLAS_WORKSPACE_CONFIG — confirm against the PyTorch reproducibility notes)
# and keep deterministic mode on.
torch.use_deterministic_algorithms(False)

In [None]:
# Launch training: fit the WarpDrive model with the Lightning trainer
# created earlier in this notebook.
trainer.fit(wd_model)