In [None]:
# Print the versions of the `python` and `pip` executables found on the shell PATH.
!python --version
!pip --version

# Utils

    import_rom(local=False)
    load_model_from_gdrive(model_path)
    save_model_in_gdrive(model_path)
    load_model_from_local_drive()
    download_colab_file_to_local_drive(filename)
    imshow(frame)
    create_video(frames, filename, fps=30)

In [None]:
import shutil
import os
import numpy as np
from pathlib import Path
import cv2
from google.colab import files, drive
from matplotlib import pyplot as plt
%matplotlib inline 




def uploadFromLocalDrive():
  """
  Asks the user to upload a file from their computer and moves it into the ROM folder.

  Opens the Colab upload dialog through files.upload() and moves the first
  uploaded file into /content/HotWheelsRL/rom, keeping its file name.

  Raises:
    ValueError: if the upload dialog returned no files.
    Exception: any error raised by shutil.move is propagated with its original type.
  """
  uploaded = files.upload()
  if not uploaded:
    raise ValueError("No file was uploaded.")
  destination_path = "/content/HotWheelsRL/rom"
  # Only the first uploaded file is used; any additional uploads are left where they are.
  # NOTE(review): assumes files.upload() stores the file in the current working
  # directory under the key it returns — confirm against the Colab docs.
  filename = next(iter(uploaded))
  shutil.move(filename, os.path.join(destination_path, filename))



def uploadFromGoogleDrive():
  """
  Mounts Google Drive and copies the ROM from it into the ROM folder.

  Drive is mounted at /content/gdrive and the file
  MyDrive/theLab_/HotWheelsRL/rom.gba is copied to /content/HotWheelsRL/rom/rom.gba.

  Raises:
    FileNotFoundError: if the ROM is not present in the mounted Drive.
    Exception: any error raised by shutil.copy is propagated with its original type.
  """
  drive.mount('/content/gdrive')
  source = "/content/gdrive/MyDrive/theLab_/HotWheelsRL/rom.gba"
  dest = "/content/HotWheelsRL/rom/rom.gba"
  # Fail early with a message that names the missing file.
  if not os.path.isfile(source):
    raise FileNotFoundError(f"ROM not found in Google Drive: {source}")
  shutil.copy(source, dest)



def import_rom(local=False):
  """
  Downloads the ROM from gdrive or local drive, creates a symlink to the HotWheelsRL/rom folder
  to allow for importing the rom into stable-retro

  Args:
    local: when True the ROM is uploaded from the user's computer, otherwise it
      is copied from Google Drive.

  Raises:
    ValueError: if the ROM folder or the retro data folder is not a directory.

  Safe to call more than once: an existing, correct symlink is reported and reused.
  Symlink and import failures are printed rather than raised.
  """
  import importlib.util
  import subprocess
  import sys

  if local:
    uploadFromLocalDrive()
  else:
    uploadFromGoogleDrive()
  folder_name="/content/HotWheelsRL/rom"
  link_name="HotWheelsStuntTrackChallenge-gba"
  lib_path="/usr/local/lib/python3.8/dist-packages/retro/data/stable"
  # Prefer the data folder of the retro package that this interpreter would import;
  # the hard-coded python3.8 path above is only the fallback.
  try:
    spec = importlib.util.find_spec("retro")
  except (ImportError, ValueError):
    spec = None
  if spec is not None and spec.submodule_search_locations:
    candidate = Path(list(spec.submodule_search_locations)[0]) / "data" / "stable"
    if candidate.is_dir():
      lib_path = str(candidate)
  # Validate input and handle errors
  if not Path(folder_name).is_dir():
      raise ValueError(f"{folder_name} is not a valid directory.")
  if not Path(lib_path).is_dir():
      raise ValueError(f"{lib_path} is not a valid directory.")
  # Define paths as Path objects
  source_path = Path(folder_name)
  dest_path = Path(lib_path) / link_name
  if dest_path.is_symlink() and dest_path.resolve() == source_path.resolve():
      # A previous call already created the link; nothing to do.
      print(f"Symlink already exists: {dest_path} -> {source_path}")
  else:
      # Use Path.symlink_to() to create the symbolic link
      try:
          dest_path.symlink_to(source_path)
          print(f"Created symlink: {dest_path} -> {source_path}")
      except OSError as e:
          print(f"Error creating symlink: {e}")

  # Run the importer with the kernel's own interpreter and surface its output and status.
  result = subprocess.run(
      [sys.executable, "-m", "retro.import", "/content/HotWheelsRL/rom/"],
      capture_output=True,
      text=True,
  )
  if result.stdout:
      print(result.stdout)
  if result.stderr:
      print(result.stderr)
  if result.returncode != 0:
      print(f"retro.import exited with status {result.returncode}")
  return



def download_colab_file_to_local_drive(filename):
  """
  Sends a file from the Colab runtime to the user's computer via files.download.

  If the download fails, a hint and the error are printed and the original
  exception is re-raised.
  """
  try:
    files.download(filename)
  except Exception as failure:
    for message in ('Could not save file, is the path correct?', failure):
      print(message)
    raise



def imshow(obs):
  """
  Draws an image array with matplotlib using nearest-neighbour interpolation
  and shows the figure.
  """
  render_options = {'interpolation': 'nearest'}
  plt.imshow(obs, **render_options)
  plt.show()



def create_video(frames, filename, fps=30):
  """
  converts frames into a mp4. create_video(frames, 'video.mp4')

  Args:
    frames: non-empty sequence of 3-d image arrays (height, width, channels);
      the video size is taken from the first frame.
      NOTE(review): assumes frames are RGB, which is why they are converted to
      BGR before writing — confirm against the env's observation format.
    filename: output path of the video file.
    fps: frames per second of the encoded video.

  Raises:
    ValueError: if frames is empty.
    OSError: if the video writer could not be opened for filename.
  """
  if len(frames) == 0:
    raise ValueError("frames must contain at least one frame")
  height, width, _ = frames[0].shape
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
  video_writer = cv2.VideoWriter(filename, fourcc, fps, (width, height))
  try:
    if not video_writer.isOpened():
      raise OSError(f"Could not open video writer for {filename}")
    for frame in frames:
      video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
  finally:
    # Release once, after every frame has been written (or on failure).
    video_writer.release()
  # abspath keeps the message correct for both relative and absolute filenames.
  print(f"Video saved as {os.path.abspath(filename)}")

# Env utils


- [ ] speed fix wrapper
- [ ] score reward wrapper
- [ ] Timestep penalty wrapper
- [ ] button restriction wrapper

In [None]:
from enum import Enum

import gymnasium as gym
import retro
from gymnasium.spaces import MultiBinary
from gymnasium.wrappers import GrayScaleObservation, TimeLimit
from stable_baselines3.common.atari_wrappers import WarpFrame, ClipRewardEnv
from stable_baselines3.common.vec_env import VecFrameStack



def make_hot_wheels_env():
  """
  Builds the Hot Wheels environment with the preprocessing wrappers applied.

  Wrapper order: retro env -> TimeLimit (15_000 steps) -> grayscale (channel
  dimension kept) -> DummyVecEnv -> VecFrameStack (4 frames).

  Returns:
    A single-environment stable-baselines3 VecEnv. It follows the VecEnv API
    rather than the gymnasium Env API.
  """
  # Local import: VecFrameStack only wraps vectorized envs, so the gymnasium env
  # has to be put inside a DummyVecEnv first.
  from stable_baselines3.common.vec_env import DummyVecEnv

  def _make_single_env():
    # Gymnasium-level wrappers must be applied before vectorizing.
    env = retro.make("HotWheelsStuntTrackChallenge-gba", render_mode="rgb_array")
    env = TimeLimit(env, max_episode_steps=15_000)
    env = GrayScaleObservation(env, keep_dim=True)
    return env

  vec_env = DummyVecEnv([_make_single_env])
  return VecFrameStack(vec_env, n_stack=4)




class HotWheelsButtons(Enum):
  """
  Button layouts for the game, each stored as a list of button names.

  ALL keeps every slot, with None in two of the positions; FILTERED keeps only
  the named gameplay buttons.
  """
  # No trailing comma here: it would turn the value into a 1-tuple wrapping the list.
  ALL = ['B', None, 'SELECT', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'A', None, 'L', 'R']
  FILTERED = ['B', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'A', 'L', 'R']



class HotWheelsEnv(retro.RetroEnv):
    """
    RetroEnv subclass that rescales the 'speed0' entry of the step info dict.
    """

    def __init__(self, game, state=None, scenario=None):
        """Forwards game, state and scenario to retro.RetroEnv."""
        super().__init__(game, state=state, scenario=scenario)

    def step(self, action):
        """
        Steps the emulator and returns the usual 5-tuple with info['speed0'] scaled.

        Raises:
            KeyError: if the info dict has no 'speed0' entry.
        """
        observation, reward, terminated, truncated, info = super().step(action)
        # fix speed
        # NOTE(review): the origin of the 0.702 factor is not shown in this file,
        # and FixSpeedWrapper scales info['speed'] instead — confirm which key
        # the game integration exposes.
        info['speed0'] = info['speed0'] * 0.702
        return observation, reward, terminated, truncated, info


class TimePenaltyWrapper(gym.Wrapper):
    """
    Reward-shaping wrapper: subtracts a fixed penalty on every step and adds a
    time-based bonus on the step that ends the episode.

    Args:
        env: environment to wrap.
        time_penalty: amount subtracted from the reward on each step.
    """

    def __init__(self, env, time_penalty=0.1):
        super().__init__(env)
        self.time_penalty = time_penalty

    def step(self, action):
        """Steps the wrapped env and returns the 5-tuple with the shaped reward."""
        observation, reward, terminated, truncated, info = self.env.step(action)
        # add a penalty for each time step that the agent takes
        reward -= self.time_penalty
        if terminated or truncated:
            # add a bonus reward for reaching the goal in the fewest number of time steps
            # NOTE(review): the bonus is also granted on truncation and grows with
            # `time`; the `time` attribute of the unwrapped env is not defined in
            # this file — confirm both against the intended reward design.
            reward += (self.env.unwrapped.time * 10)
        return observation, reward, terminated, truncated, info


import gymnasium as gym

class FixSpeedWrapper(gym.Wrapper):
  """
  Wrapper that rescales the 'speed' entry of the info dict on every step
  so the reported speed is accurate.
  """

  def step(self, action):
    """Steps the wrapped env and multiplies info['speed'] by 0.702."""
    obs, step_reward, done, cut_off, details = self.env.step(action)
    details['speed'] *= 0.702
    return obs, step_reward, done, cut_off, details




# load rom into gym-retro

In [None]:
# List the Python 3.8 dist-packages directory, the location import_rom uses for the retro data folder.
!ls -lA /usr/local/lib/python3.8/dist-packages/

In [None]:
# Copy the ROM from Google Drive (local defaults to False) and import it into retro.
import_rom()

# load tensorboard

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard
# Start TensorBoard on the relative folder `logs`.
# NOTE(review): the training cells log to /content/HotWheelsRL/logs; the two only
# match when the working directory is /content/HotWheelsRL — confirm.
%tensorboard --logdir logs

## Train PPO sb3 model


[Tensorboard PPO article](https://medium.com/aureliantactics/understanding-ppo-plots-in-tensorboard-cbc3199b9ba2)

In [None]:
import retro
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import PPO


# import the rom into retro
import_rom()


# make env
env = retro.make("HotWheelsStuntTrackChallenge-gba", render_mode="rgb_array")

# One try/finally closes the env on every path, including a failing reset().
# NOTE(review): retro is understood to allow only one emulator per process, so a
# leaked env would block later retro.make calls — confirm.
try:
  observation, info = env.reset(seed=42)

  # check if valid env
  check_env(env)

  # train model
  # NOTE(review): learning_rate=0.000001 is far below the value the later PPO
  # cell uses (the library default) — confirm it is intentional.
  model = PPO('CnnPolicy', env, verbose=1, tensorboard_log='/content/HotWheelsRL/logs', learning_rate=0.000001)
  model.learn(total_timesteps=10_000)
finally:
  env.close()

In [None]:
# Manual cleanup: close the current env, e.g. after interrupting a cell before it reached its own env.close().
env.close()

# PPO

## Train

In [None]:
from stable_baselines3 import PPO

# The previous cells close `env`, so training needs a fresh environment.
env = retro.make("HotWheelsStuntTrackChallenge-gba", render_mode="rgb_array")
try:
  model = PPO('CnnPolicy', env, verbose=1, tensorboard_log="/content/HotWheelsRL/logs")
  model.learn(25_000)
finally:
  # Closed here because the "Run agent" cell creates its own env.
  env.close()

## Save/load model

In [None]:
from stable_baselines3 import PPO


# Checkpoint name and its location under the Google Drive mount point (/content/gdrive).
filename = "ppo_25k_cnn"
filepath = f"/content/gdrive/MyDrive/theLab_/HotWheelsRL/{filename}"

# Saving is disabled: uncomment the next two lines to write the in-memory model
# to filepath before it is replaced by the loaded checkpoint.
#model.save(filepath)
#del model
# Rebind `model` to the checkpoint stored at filepath.
model = PPO.load(filepath)


## Run agent

In [None]:
# Play one episode with the current `model`, printing progress on every step.
env = retro.make("HotWheelsStuntTrackChallenge-gba", render_mode="rgb_array")

total_reward = 0

# try/finally closes the env even when a step or the model raises.
try:
    obs, info = env.reset(seed=42)

    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        # NOTE(review): the frame returned by render() is discarded here — this
        # call is probably removable in rgb_array mode; confirm.
        env.render()
        total_reward += reward

        print(info['progress'], reward, total_reward, env.get_action_meaning(action))

        if terminated or truncated:
            # Show the final observation of the episode.
            imshow(obs)
            break
finally:
    env.close()

print(total_reward, terminated, truncated)

# Train A2C

In [None]:
%%time


from stable_baselines3 import A2C

# import the rom into retro
import_rom()

# Create the environment used for training. It is not closed in this cell; the
# "Run A2C" cell reuses `env` and closes it.
env = retro.make("HotWheelsStuntTrackChallenge-gba", render_mode="rgb_array")
obs, info = env.reset(seed=42)


# Train A2C with a CNN policy for 25_000 timesteps, logging to the TensorBoard folder.
model = A2C('CnnPolicy', env, verbose=1, tensorboard_log="/content/HotWheelsRL/logs")
model.learn(25_000)

## Save/load A2C model

In [None]:
from stable_baselines3 import A2C


# Checkpoint name and its location under the Google Drive mount point (/content/gdrive).
filename = "a2c_25k_cnn"
filepath = f"/content/gdrive/MyDrive/theLab_/HotWheelsRL/{filename}"

# Saving is disabled: uncomment the next two lines to write the in-memory model
# to filepath before it is replaced by the loaded checkpoint.
#model.save(filepath)
#del model
# Rebind `model` to the checkpoint stored at filepath.
model = A2C.load(filepath)


## Run A2C

In [None]:
# Play one episode with the current `model` on the env created by the A2C
# training cell, collecting every observation in `frames` for the video export.
#env = retro.make("HotWheelsStuntTrackChallenge-gba", render_mode="rgb_array")
totalFrames = 0
totalReward = 0
frames = []

# try/finally closes the env even when a step or the model raises.
try:
  observation, info = env.reset()

  while True:
    action, _state = model.predict(observation, deterministic=True)
    observation, reward, terminated, truncated, info = env.step(action)

    totalFrames += 1
    totalReward += reward
    frames.append(observation)

    # Report progress every 50 steps.
    if totalFrames % 50 == 0:
      print(info, totalFrames, env.get_action_meaning(action))

    if terminated or truncated:
      # Show the final observation of the episode.
      imshow(observation)
      break
finally:
  env.close()




## Save agent trial to mp4

In [None]:
gif_name = "a2c_25k_cnn"
# create_video encodes with the mp4v codec, so the output path carries an .mp4
# extension to let OpenCV select the matching container.
gif_path = f"/content/gdrive/MyDrive/theLab_/HotWheelsRL/{gif_name}.mp4"
create_video(frames, gif_path)

# Misc