### Clone repo

In [None]:
# Clone the project repo and work from inside it.
# Guarded so that re-running the cell neither fails on an existing checkout
# nor clones a second copy inside the first (after the first run the working
# directory is already the repo).
import os
import subprocess

REPO_URL = "https://github.com/zbeucler2018/HotWheelsRL.git"
REPO_DIR = "HotWheelsRL"

if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.isdir(REPO_DIR):
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
    # os.chdir changes the process cwd, so later `!` shell lines run here too.
    os.chdir(REPO_DIR)

os.getcwd()

### Install Python dependencies

In [None]:
#!python -m pip -v install git+https://github.com/DLR-RM/stable-baselines3@feat/gymnasium-support
#!python -m pip -v install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib@feat/gymnasium-support
#!python -m pip -v install git+https://github.com/Farama-Foundation/Gymnasium --upgrade
#!python -m pip install moviepy
#!python -m pip install imageio

!python -m pip install pip --upgrade
!python -m pip install "stable_baselines3[extra]>=2.0.0a9" # for gymnasium support
!python -m pip -v install git+https://github.com/Farama-Foundation/stable-retro.git
!python -m pip install wandb
!python -m pip install tensorboard

### Import ROM

In [None]:
# Upload the ROM file(s) from the local machine and save them into rom/.
import os

from google.colab import files

ROM_DIR = "rom"
# The write below raises FileNotFoundError if rom/ does not exist yet.
os.makedirs(ROM_DIR, exist_ok=True)

uploaded_file = files.upload()  # mapping of uploaded filename -> file bytes
for filename, data in uploaded_file.items():
    # basename() keeps an upload named e.g. "../x" from escaping rom/.
    with open(os.path.join(ROM_DIR, os.path.basename(filename)), "wb") as f:
        f.write(data)

In [None]:
import importlib.util
import os

# Expose the repo's rom/ folder to stable-retro as the game's integration
# by symlinking it into retro's "stable" data directory.
source_path = os.path.join(os.getcwd(), 'rom')
link_name = 'HotWheelsStuntTrackChallenge-gba'

# Locate <site-packages>/retro/data/stable from the installed package instead
# of hard-coding a Python version (python3.10) that breaks on runtime upgrades.
# find_spec() locates the package without importing it.
_retro_spec = importlib.util.find_spec('retro')
if _retro_spec is None or not _retro_spec.origin:
    raise ModuleNotFoundError("Package 'retro' not found; install stable-retro first.")
lib_path = os.path.join(os.path.dirname(_retro_spec.origin), 'data', 'stable')

# Raise rather than exit(1): in a notebook, exit() shuts the kernel down
# instead of just stopping this cell.
if not os.path.isdir(source_path):
    raise FileNotFoundError(f'{source_path} is not a valid directory.')

if not os.path.isdir(lib_path):
    raise FileNotFoundError(f'{lib_path} is not a valid directory.')

dest_path = os.path.join(lib_path, link_name)

# Replace a stale symlink from a previous run so the cell can be re-run.
if os.path.islink(dest_path):
    print(f'Removing existing symlink: {dest_path}')
    os.remove(dest_path)
elif os.path.exists(dest_path):
    # A real file/directory is there; refuse to clobber it.
    raise FileExistsError(f'{dest_path} exists and is not a symlink; remove it manually.')

os.symlink(source_path, dest_path)
print(f'Created symlink: {dest_path} -> {source_path}')

In [None]:
# Run stable-retro's ROM importer on the rom/ directory so retro can
# find the game's ROM by name.
# NOTE(review): the importer only picks up ROMs whose checksums match a
# known integration — confirm the uploaded ROM is the expected dump.
!python -m retro.import rom/

### Make environment

In [None]:
from stable_baselines3.common.env_checker import check_env

from HotWheelsEnv import HotWheelsEnvFactory, CustomEnv, GameStates

# Environment options passed to the project's env factory.
# NOTE(review): the exact meaning of each flag lives in HotWheelsEnv (not shown).
env_config = CustomEnv(
    game_state=GameStates.MULTIPLAYER,
    discrete=True,
    multibinary=False,
    raw=True,
    grayscale=False,
    framestack=False
)

env = HotWheelsEnvFactory.make_env(env_config)

# Validate the env against the Stable-Baselines3 API.
# If the check fails, close the env so its resources are released, then
# re-raise the original error with a bare `raise`.
try:
    check_env(env)
except Exception:
    env.close()
    raise

### Train agent

In [None]:
import wandb

from trainer import Trainer, ModelConfig, WandbConfig, ValidAlgos

# Authenticate with Weights & Biases before starting the (long) training run.
# wandb.login() uses WANDB_API_KEY / cached credentials when available and
# otherwise prompts inside the notebook. Never paste the key into this cell,
# because it would be saved with the notebook.
wandb.login()

# Training knobs (previously inline magic numbers).
TOTAL_TRAINING_TIMESTEPS = 1_000_000
MAX_EPISODE_STEPS = 25_000
MODEL_SAVE_FREQ = 25_000

model_config = ModelConfig(
    policy="CnnPolicy",
    total_training_timesteps=TOTAL_TRAINING_TIMESTEPS,
    max_episode_steps=MAX_EPISODE_STEPS,
)

wandb_config = WandbConfig(
    model_save_freq=MODEL_SAVE_FREQ,
    hot_wheels_env_type=env_config,  # env_config comes from the "Make env" cell
)

# Reuses the env built and validated in the "Make env" cell.
trainer = Trainer(env)

trainer.train(
    algo=ValidAlgos.PPO,
    modelConfig=model_config,
    wandbConfig=wandb_config,
)

### Run agent

In [None]:
# Run the saved PPO model (ppo.zip) for 10 episodes and record a GIF.
# NOTE(review): flag semantics are defined in run_agent.py (not shown).
# Confirm that trainer.train() above actually writes its model to ppo.zip.
!python run_agent.py --algo=ppo --filename=ppo.zip --episodes=10 --record_gif