### Clone repo

In [None]:
# Clone repo
!git clone https://www.github.com/zbeucler2018/HotWheelsRL.git
%cd HotWheelsRL/

### Install pip libraries

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

!python -m pip install pip --upgrade
!python -m pip install /content/gdrive/MyDrive/HotWheelsRL/stable_retro-0.9.0-cp310-cp310-linux_x86_64.whl
!python -m pip install stable_baselines3
!python -m pip install wandb
!python -m pip install tensorboard

### Import ROM

In [None]:
import os
from google.colab import files

# copy rom from drive
!cp /content/gdrive/MyDrive/HotWheelsRL/rom.gba /content/HotWheelsRL/HotWheelsStuntTrackChallenge-gba/rom.gba

link_name = 'HotWheelsStuntTrackChallenge-gba'
source_path = os.path.join(os.getcwd(), link_name)
lib_path = '/usr/local/lib/python3.10/dist-packages/retro/data/stable'

if not os.path.isdir(source_path):
    print(f'{source_path} is not a valid directory.')
    exit(1)

if not os.path.isdir(lib_path):
    print(f'{lib_path} is not a valid directory.')
    exit(1)

dest_path = os.path.join(lib_path, link_name)

if os.path.islink(dest_path):
    print(f'Removing existing symlink: {dest_path}')
    os.remove(dest_path)

os.symlink(source_path, dest_path)
print(f'Created symlink: {dest_path} -> {source_path}')

!python -m retro.import /content/HotWheelsRL/HotWheelsStuntTrackChallenge-gba

### Make env

In [5]:
import retro
from HotWheelsEnv import make_env, CustomEnv, GameStates

# Environment settings handed to make_env below; the same object is reused by
# the training cell (WandbConfig) and the "Run agent" cell.
# NOTE(review): the meaning of each flag is defined by CustomEnv in the
# HotWheelsEnv module — confirm there before changing values.
env_config = CustomEnv(
    game_state=GameStates.SINGLE,
    action_space=retro.Actions.DISCRETE,
    grayscale=False,
    framestack=False,
    encourage_tricks=True
)


# Build the environment from the settings above; `env` is what Trainer.train receives.
env = make_env(env_config)

### Log into WandB

In [None]:
# Do not paste the API key into this cell and save it: notebooks get shared and
# committed, and the key would go with them.
# NOTE(review): run without a key, `wandb login` should prompt for it — confirm on Colab.
!wandb login #<API_KEY_HERE>

### Train agent

In [None]:
from trainer import Trainer, ModelConfig, WandbConfig, ValidAlgos

# Training settings passed to Trainer.train below.
# NOTE(review): field semantics (e.g. units of max_episode_steps) are defined by
# ModelConfig in the trainer module — confirm there.
model_config = ModelConfig(
    policy="CnnPolicy",
    total_training_timesteps=1_000_000,
    max_episode_steps=25_000
)

# Logging/checkpoint settings; the environment config is attached via
# hot_wheels_env_type.
# NOTE(review): model_save_freq is presumably counted in timesteps — confirm in
# the trainer module.
wandb_config = WandbConfig(
    model_save_freq=25_000,
    hot_wheels_env_type=env_config
)


# Train with the PPO option on the environment built in the "Make env" cell.
Trainer.train(
    env=env,
    algo=ValidAlgos.PPO,
    modelConfig=model_config,
    wandbConfig=wandb_config
)

### Run agent

In [None]:
from stable_baselines3 import A2C, DQN, PPO
from gymnasium.wrappers import RecordVideo

import os

def get_filename(directory_path: str, extension: str):
    """Return the name of the first entry in ``directory_path`` ending with ``extension``.

    Entries are examined in sorted order so the result is reproducible when
    several entries match; ``os.listdir`` on its own returns them in arbitrary
    order.

    Args:
        directory_path: Directory to scan (top level only, not recursive).
        extension: Suffix to match, e.g. ``".mp4"``.

    Returns:
        The bare entry name (NOT joined with ``directory_path``), or ``None``
        when no entry matches.

    Raises:
        OSError: If ``directory_path`` cannot be listed (e.g. it does not exist).
    """
    for filename in sorted(os.listdir(directory_path)):
        if filename.endswith(extension):
            return filename
    return None

def find_file_with_extension(directory, extension):
    """Recursively search ``directory`` for the first file ending with ``extension``.

    The tree is walked top-down, visiting sub-directories and file names in
    sorted order, so the same file is returned on every run when several match.

    Args:
        directory: Root directory of the search.
        extension: Suffix to match, e.g. ``".zip"``.

    Returns:
        Absolute path of the first matching file, or ``None`` when nothing
        matches (including when ``directory`` does not exist, since ``os.walk``
        then yields nothing).
    """
    for root, dirs, filenames in os.walk(directory):
        # Sorting in place is what controls the order os.walk descends in.
        dirs.sort()
        for filename in sorted(filenames):
            if filename.endswith(extension):
                return os.path.abspath(os.path.join(root, filename))
    return None

# Locate a saved checkpoint anywhere under models/.
model_path = find_file_with_extension("models/", ".zip")
if model_path is None:
    # Fail here with a clear message; passing None on to PPO.load would only
    # surface as a less helpful error from inside the loader.
    raise FileNotFoundError('No ".zip" model checkpoint found under "models/".')

model = PPO.load(model_path)

# Fresh environment for the rollout, wrapped so the episode is recorded to videos/.
env = make_env(env_config)

env = RecordVideo(
    env=env,
    video_folder="videos/"
)

# Play one deterministic episode with the loaded model and accumulate its return.
total_reward = 0
observation, info = env.reset(seed=42)
try:
    while True:
        action, _ = model.predict(observation, deterministic=True)
        observation, reward, terminated, truncated, info = env.step(action)

        total_reward += reward

        if terminated or truncated:
            break
finally:
    # Always close the wrapped env (even if a step raises) so the RecordVideo
    # wrapper can finish writing the video file and the emulator is released.
    env.close()

# Report the episode return once instead of printing a line per step.
print(f"Total reward: {total_reward}")


from google.colab import files

# get_filename returns a bare file name, so it has to be joined with the folder
# it was found in; passing the bare name to files.download would look for the
# video in the working directory instead of videos/.
video_name = get_filename("videos/", ".mp4")
if video_name is None:
    # Without this guard the literal string "None" would be sent to files.download.
    raise FileNotFoundError('No ".mp4" recording found under "videos/".')
video_path = os.path.join("videos/", video_name)
files.download(video_path)