In [1]:
import os
import sys
import git
import pathlib

In [2]:
# Suppress TensorFlow's native (C++) log output in this kernel.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# The project root is the working tree of the git repository that contains
# the current directory (searched upward). It is appended to sys.path, once,
# so project-local packages can be imported from this notebook.
PROJ_ROOT_PATH = pathlib.Path(
    git.Repo('.', search_parent_directories=True).working_tree_dir
)
PROJ_ROOT = os.fspath(PROJ_ROOT_PATH)
if PROJ_ROOT not in sys.path:
    sys.path.append(PROJ_ROOT)

print(f"Project Root Directory: {PROJ_ROOT}")

Project Root Directory: /repos/drl_csense


In [3]:
import numpy as np
import imageio
import ipyplot

In [4]:
import gymnasium as gym

In [5]:
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import A2C

In [6]:
from IPython.display import Image

In [7]:
# Experiment identity: environment id plus a variant tag. The combined
# `exp_name` ("<env_id>--<exp_tag>") is used to name this experiment's
# log, model and GIF files.
env_id = "BreakoutNoFrameskip-v4"
exp_tag = "vanilla"
exp_name = "--".join((env_id, exp_tag))

In [8]:
# Output locations for this experiment (all are pathlib.Path objects):
#   training statistics -> <project root>/logging/<exp_name>
#   gif animations      -> <project root>/logging/<exp_name>/gifs
#   saved models        -> <project root>/models/<exp_name>
logfolder_root = PROJ_ROOT_PATH / "logging"
log_dir = logfolder_root / exp_name
gif_dir = log_dir / "gifs"
models_dir = PROJ_ROOT_PATH / "models" / exp_name

# Create each directory (and any missing parents); existing ones are kept.
log_dir.mkdir(parents=True, exist_ok=True)
gif_dir.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

In [9]:
trial_mother_seed = 202306
EXPERIMENT_NOS = [0,1,2,3,4]

# Animation settings are identical for every run, so define them once.
duration = 10 #sec
fps = 240 #fps
frame_stride = 2  # only every `frame_stride`-th rendered frame goes into the GIF
no_of_frames = duration*fps*frame_stride  # number of env steps simulated per run

for experiment in EXPERIMENT_NOS:
    # Make trial environment; every run uses an env created from the same seed.
    trial_env = make_atari_env(env_id,
                               n_envs=1,
                               seed=trial_mother_seed)
    try:
        # Frame-stacking with 4 frames
        trial_env = VecFrameStack(trial_env, n_stack=4)

        # Load the RL model trained in this run
        model_file = f"{models_dir}/{exp_name}-run_{experiment}"
        model = A2C.load(model_file)

        # Roll out the policy and keep only the frames used in the animation.
        # Buffering every frame and discarding half at save time would double
        # peak memory. np.array copies each frame so that a renderer which
        # reuses its output buffer cannot alias earlier frames.
        gif_frames = []
        obs = trial_env.reset()
        img = trial_env.render(mode="rgb_array")
        for i in range(no_of_frames):
            if i % frame_stride == 0:
                gif_frames.append(np.array(img))
            # NOTE(review): predict() samples actions by default
            # (deterministic=False) and no torch seed is set in this notebook,
            # so GIFs may differ between re-runs despite the fixed env seed.
            action, _ = model.predict(obs)
            obs, _, _, _ = trial_env.step(action)
            img = trial_env.render(mode="rgb_array")

        # Convert frames to animation
        gif_file = f"{gif_dir}/{exp_name}-run_{experiment}.gif"
        # NOTE(review): imageio's `duration` is the display time of EACH frame
        # (seconds or milliseconds depending on the installed imageio
        # version/plugin), not the total length of the animation, so passing
        # the 10 s total here does not yield a 10 s GIF. GIF frame delays are
        # also stored in 1/100 s units, so 240 fps cannot be represented.
        # TODO: confirm the imageio version and pass a per-frame delay instead.
        imageio.mimsave(gif_file, gif_frames, duration=duration)
    finally:
        # Release the emulator even if loading, rollout or saving fails.
        trial_env.close()

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]
  logger.warn(


In [10]:
# Show the GIF saved for each run side by side, captioned "run_<n>".
# Each path reuses its label because files are named "<exp_name>-run_<n>.gif".
labels = [f"run_{run_no}" for run_no in EXPERIMENT_NOS]
gif_paths = [f"{gif_dir}/{exp_name}-{label}.gif" for label in labels]
ipyplot.plot_images(
    gif_paths,
    labels,
    max_images=len(gif_paths),
    img_width=100,
    show_url=False,
)