In [1]:
import gymnasium as gym
import retro
import random
import time
import sounddevice as sd

from gymnasium.wrappers import TimeLimit
from stable_baselines3 import A2C, PPO
# from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize, VecTransposeImage
from stable_baselines3.common.callbacks import CheckpointCallback, BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from gymnasium.wrappers import FrameStackObservation
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.atari_wrappers import ClipRewardEnv, WarpFrame
from pathlib import Path
from sb3_contrib import RecurrentPPO
from gymnasium.wrappers import RecordVideo

import torch as th
from torch import nn
import numpy as np
import cv2
import re
import datetime

import os


# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Directory holding the best_model_<step> checkpoints produced by training.
LOAD_FROM = "./model-tmnt"
GAME = "TeenageMutantNinjaTurtlesIVTurtlesInTime-Snes"
STATE = "Leo.Level1.99Lives"

# List the save states available for GAME so a valid STATE can be picked.
states = retro.data.list_states(GAME)
print(f"States for {GAME}: {states}")

# Per-episode step cap enforced by ResetStateWrapper (values <= 0 disable it).
MAX_EPISODE = 122000
# True: record the first episode into videos/ instead of showing a live window.
SAVE_VIDEO = False
# True: load a RecurrentPPO checkpoint instead of a plain PPO one.
IS_RECURRENT = False
USE_CURRICULUM = False
DETERMINISTIC = False
# Checkpoint loaded when SAVE_VIDEO is enabled.
LOAD_FROM_VIDEO = f"{LOAD_FROM}/best_model_100000"

# Alternative configurations used in earlier experiments:
# print(retro.data.list_games())
# LOAD_FROM = "./model-sbx"
# LOAD_FROM = "./model2"
# GAME = "SuperMarioBros-Nes"
# STATE = "Level1-1"

class IgnorePauseActionWrapper(gym.ActionWrapper):
    """Force one button of a MultiBinary action to "released" on every step.

    Intended to keep the agent from pausing the game. Index 3 is assumed to be
    the START/pause button of the underlying button layout — TODO confirm
    against ``env.unwrapped.buttons`` for the integration in use.

    This wrapper expects a per-button action array, not a Discrete index, so it
    must sit *inside* any discretizing wrapper.

    Args:
        env: environment to wrap.
        pause_index: position of the button to suppress (default 3, as before).
    """

    def __init__(self, env, pause_index=3):
        super().__init__(env)
        self.pause_index = pause_index

    def action(self, act):
        # Work on a copy: the incoming array may be owned by the caller (for
        # example a shared lookup-table entry from an outer wrapper), and
        # editing it in place would leak the change into later steps.
        act = np.array(act, copy=True)
        act[self.pause_index] = 0
        return act

class Discretizer(gym.ActionWrapper):
    """
    Wrap a gym environment and make it use discrete actions.

    Discrete action ``i`` is translated into a boolean button mask in which
    exactly the buttons listed in ``combos[i]`` are pressed.

    Args:
        env: environment with a ``MultiBinary`` action space whose unwrapped
            env exposes the ``buttons`` list that names each mask position.
        combos: ordered list of lists of valid button combinations

    Raises:
        ValueError: if a combo names a button not present in ``buttons``.
    """

    def __init__(self, env, combos):
        super().__init__(env)
        assert isinstance(env.action_space, gym.spaces.MultiBinary)
        buttons = env.unwrapped.buttons
        self._decode_discrete_action = []
        for combo in combos:
            mask = np.zeros(env.action_space.n, dtype=bool)
            for button in combo:
                try:
                    mask[buttons.index(button)] = True
                except ValueError:
                    raise ValueError(
                        f"Unknown button {button!r} in combo {combo!r}; "
                        f"available buttons: {buttons}"
                    ) from None
            self._decode_discrete_action.append(mask)

        self.action_space = gym.spaces.Discrete(len(self._decode_discrete_action))

    def action(self, act):
        # Return a copy: the masks are shared lookup-table entries, and any
        # inner wrapper/env that edits the action in place would otherwise
        # corrupt the mapping for every later step.
        return self._decode_discrete_action[act].copy()

class MainDiscretizer(Discretizer):
    """
    Discrete action set used for the TMNT IV (SNES) agent.

    The pattern of mapping a small discrete set onto button combos is based on
    https://github.com/openai/retro-baselines/blob/master/agents/sonic_util.py

    The position of each combo in ``COMBOS`` is the discrete action id the
    policy outputs, so reordering, inserting or removing entries invalidates
    previously trained checkpoints.

    NOTE(review): the gameplay meaning of each button (attack, jump, special,
    ...) is not established by this file; the labels below are unverified.
    """

    COMBOS = (
        # Single buttons
        ("LEFT",),
        ("RIGHT",),
        ("UP",),
        ("DOWN",),
        ("B",),             # presumably attack — verify
        ("Y",),             # presumably special attack — verify
        ("A",),             # presumably jump — verify
        ("X",),
        ("L",),
        ("R",),
        # Direction + B
        ("LEFT", "B"),
        ("RIGHT", "B"),
        ("UP", "B"),
        ("DOWN", "B"),
        # Direction + A
        ("LEFT", "A"),
        ("RIGHT", "A"),
        ("UP", "A"),
        ("DOWN", "A"),
        # Horizontal direction + A + B
        ("LEFT", "A", "B"),
        ("RIGHT", "A", "B"),
        # ("B", "Y") was tried and disabled.
    )

    def __init__(self, env):
        super().__init__(env=env, combos=self.COMBOS)


class ResetStateWrapper(gym.Wrapper):
    """Reshape the TMNT reward signal and cap the episode length.

    On every step:
      * any positive (score) reward is replaced by a flat +0.1;
      * a constant -0.05 "existential" penalty is applied;
      * -0.5 is applied whenever ``info['health']`` drops below its value on
        the previous step;
      * once more than ``MAX_EPISODE`` steps have elapsed (if ``MAX_EPISODE``
        is positive) the episode is ended with an extra -1.

    Requires ``info`` to contain a ``'health'`` key.
    """

    # Health value treated as "full"; assumes the game's RAM reads 16 at full
    # health — TODO confirm against the integration's data.json.
    FULL_HEALTH = 16
    # NOTE(review): a raw reading of 80 is mapped to full health; what 80
    # represents in the game is not established here.
    HEALTH_READS_AS_FULL = 80

    def __init__(self, env):
        super().__init__(env)
        self.steps = 0
        self.lose_lives = False
        self.current_health = self.FULL_HEALTH
        # Kept consistent with reset(); not currently used for reward shaping.
        self.current_lives = 3

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        self.steps = 0
        self.current_health = self.FULL_HEALTH
        self.current_lives = 3
        self.lose_lives = False
        return obs

    def step(self, action):
        obs, reward, done, trunc, info = self.env.step(action)
        self.steps += 1

        # Flatten score rewards to a constant bonus.
        if reward > 0:
            reward = 0.1

        # Existential penalty to encourage making progress.
        reward -= 0.05

        health = info['health']
        if health == self.HEALTH_READS_AS_FULL:
            health = self.FULL_HEALTH

        if health < self.current_health:
            reward -= 0.5
        # Track health in both directions: if it were only ever lowered, any
        # damage taken after a heal or respawn (while still above the lowest
        # value seen) would go unpenalized.
        self.current_health = health

        # NOTE(review): the step cap ends the episode via `done` (terminated)
        # rather than `trunc`, so value bootstrapping treats it as terminal.
        if MAX_EPISODE > 0 and self.steps > MAX_EPISODE:
            done = True
            reward -= 1

        # Ending the episode on life loss (via info['lives']) was tried and is
        # currently disabled.
        return obs, reward, done, trunc, info


class RandomStateWrapper(gym.Wrapper):
    """Load a randomly chosen ``.state`` file before every reset.

    Args:
        env: environment to wrap; its innermost (unwrapped) env must provide
            ``load_state(path)``, as the retro env built in this file is
            expected to.
        state_dir: directory scanned for ``*.state`` files (default
            ``"./States"``, as before).
    """

    def __init__(self, env, state_dir="./States"):
        super().__init__(env)
        self.state_dir = state_dir

    def get_random_state(self):
        """Return the absolute path of a random ``.state`` file in ``state_dir``.

        Raises:
            FileNotFoundError: if ``state_dir`` contains no ``.state`` files
                (or does not exist).
        """
        # Sorted so that, under a fixed random seed, the choice is reproducible
        # regardless of the filesystem's listing order.
        candidates = sorted(
            name for name in os.listdir(self.state_dir) if name.endswith(".state")
        )
        if not candidates:
            raise FileNotFoundError(
                f"No .state files found in {os.path.abspath(self.state_dir)}"
            )
        return os.path.abspath(os.path.join(self.state_dir, random.choice(candidates)))

    def reset(self, **kwargs):
        state = self.get_random_state()
        print(f"Loading state: {state}")
        # Call load_state on the base env: gymnasium 1.x wrappers no longer
        # forward unknown attributes, so self.env.load_state fails whenever
        # another wrapper sits in between.
        self.env.unwrapped.load_state(state)
        return self.env.reset(**kwargs)

    def step(self, action):
        return self.env.step(action)


class CurriculumWrapper(gym.Wrapper):
    """Track match wins and advance a curriculum phase counter.

    Reads ``info["matches_won"]`` on every step and, when relevant,
    ``info["enemy_matches_won"]``. NOTE(review): nothing in this file shows the
    TMNT integration exposing these keys, so enabling it may raise KeyError —
    confirm against the game's data.json.

    Args:
        env: environment to wrap.
        required_wins: wins needed to advance a phase; also the size of the
            trailing window of per-step rewards that is averaged.
        required_avg_reward: minimum windowed average reward needed to advance.
    """

    def __init__(self, env, required_wins=20, required_avg_reward=1.0):
        super().__init__(env)
        self.required_wins = required_wins
        self.required_avg_reward = required_avg_reward
        self.current_phase = 1
        self.total_wins = 0
        # Trailing per-step rewards; trimmed to `required_wins` entries in step()
        # because only that window is ever averaged.
        self.rewards_list = []

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, truncated, info = self.env.step(action)

        self.rewards_list.append(reward)
        # Bound memory: without trimming the list grows by one entry per step.
        if self.required_wins > 0 and len(self.rewards_list) > self.required_wins:
            del self.rewards_list[:-self.required_wins]

        won = info["matches_won"]
        could_to_next_stage = won / self.current_phase >= 2
        won_even_positive = won % 2 == 0 and won > 0

        if won_even_positive and could_to_next_stage:
            self.total_wins += 1

        avg_reward = np.mean(self.rewards_list)

        # enemy_matches_won is only read when needed (short-circuit), as before.
        if could_to_next_stage and (
            won_even_positive
            or (info["enemy_matches_won"] % 2 == 0 and info["enemy_matches_won"] > 0)
        ):
            print(info)
            print(f"🔥 stage {self.current_phase}! ({self.total_wins} fights win, avg rewards: {avg_reward:.2f})")
            done = True

        if self.total_wins >= self.required_wins and avg_reward >= self.required_avg_reward:
            self.current_phase += 1
            print(f"🔥 Going to next stage {self.current_phase}! ({self.total_wins} fights win, avg rewards: {avg_reward:.2f})")
            self.total_wins = 0
            self.rewards_list.clear()

        return obs, reward, done, truncated, info

def make_test_env(video_prefix):
    """Build the evaluation environment for GAME/STATE.

    Wrapper stack (inner to outer): retro env -> MainDiscretizer ->
    ResetStateWrapper -> WarpFrame -> [CurriculumWrapper if USE_CURRICULUM]
    -> [RecordVideo if SAVE_VIDEO] -> DummyVecEnv -> VecFrameStack(4).

    Args:
        video_prefix: checkpoint path (as a string) used to name recorded
            videos; only its final path component is kept. May be None.

    Returns:
        (vec_env, raw_env): the frame-stacked vectorized env used for
        inference, and the bare retro env for direct access.
    """
    env = retro.make(
        game=GAME,
        render_mode="rgb_array",  # "rgb_array" or "human"
        state=STATE,
    )
    raw_env = env

    env = MainDiscretizer(env)
    env = ResetStateWrapper(env)
    env = WarpFrame(env)

    if USE_CURRICULUM:
        env = CurriculumWrapper(env, required_wins=3, required_avg_reward=0.3)

    if video_prefix is not None:
        # Keep only the checkpoint name. Stripping LOAD_FROM with str.replace
        # missed Path-derived strings, because Path drops the leading "./"
        # ("./model-tmnt" -> "model-tmnt") and produced names like
        # "model-tmntbest_model_7644000".
        video_prefix = Path(video_prefix).name

    print("video_prefix:", video_prefix)

    if SAVE_VIDEO:
        env = RecordVideo(
            env,
            video_folder="videos/",
            episode_trigger=lambda episode_id: episode_id == 0,  # record only first episode
            fps=60,
            name_prefix=f"gameplay_{video_prefix}",
        )

    # Bind the finished gym env to its own name so the factory closure does not
    # depend on `env` being rebound to the VecEnv on the same line.
    gym_env = env
    vec_env = DummyVecEnv([lambda: gym_env])
    vec_env = VecFrameStack(vec_env, 4)
    return vec_env, raw_env

# Directory searched for best_model_<step> checkpoints.
save_dir = Path() / LOAD_FROM


def get_latest_model(path):
    """Return the ``best_model_<step>`` checkpoint with the highest step in *path*.

    Steps are compared as integers (``best_model_1000`` beats
    ``best_model_999``). Entries matching ``best_model_*`` whose name carries
    no step number (e.g. ``best_model_final.zip``) are ignored instead of
    crashing.

    Args:
        path: directory to search, as a ``pathlib.Path`` or ``str``.

    Returns:
        ``path / "best_model_<step>"`` (no file extension), or ``None`` if no
        numbered checkpoint exists.
    """
    path = Path(path)
    step_pattern = re.compile(r"best_model_(\d+)")
    steps = []
    for candidate in path.glob("best_model_*"):
        # Match on the file name only, so digits in parent directories can
        # never be mistaken for a step number.
        match = step_pattern.match(candidate.name)
        if match is not None:
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return path / f"best_model_{max(steps)}"

latest_model_path = get_latest_model(save_dir)

# When recording, always use the fixed checkpoint configured in LOAD_FROM_VIDEO.
if SAVE_VIDEO:
    latest_model_path = LOAD_FROM_VIDEO

if latest_model_path is None:
    raise FileNotFoundError(f"No best_model_<step> checkpoint found in {save_dir}")

print(f"Loading existent model: {latest_model_path}")

env, my_env = make_test_env(f"{latest_model_path}")
env.metadata['render_fps'] = 60

if IS_RECURRENT:
    # "auto" selects CUDA when available instead of failing on CPU-only hosts.
    model = RecurrentPPO.load(str(latest_model_path), device="auto", verbose=0)
else:
    model = PPO.load(str(latest_model_path), env=env, verbose=0)

# Show a live window unless recording.
render_mode = "rgb_array" if SAVE_VIDEO else "human"

obs = env.reset()
# A recurrent policy needs its hidden state carried between predict() calls and
# must be told where episodes begin; non-recurrent policies ignore both values.
lstm_states = None
episode_starts = np.ones((env.num_envs,), dtype=bool)

try:
    while True:
        env.render(render_mode)
        action, lstm_states = model.predict(
            obs,
            state=lstm_states,
            episode_start=episode_starts,
            deterministic=DETERMINISTIC,
        )
        obs, reward, done, info = env.step(action)
        episode_starts = done

        # The VecEnv resets finished episodes inside step(), so `obs` already
        # belongs to the next episode; calling env.reset() here would reset twice.
        if SAVE_VIDEO and done.any():
            break
finally:
    # Also runs on KeyboardInterrupt, so windows and video writers are released.
    env.close()

States for TeenageMutantNinjaTurtlesIVTurtlesInTime-Snes: ['1Player.Leonardo.Level1', 'Leo.Level1.99Lives', 'Leo.Level1.InfinitiLives']
Loading existent model: model-tmnt/best_model_7644000
video_prefix: model-tmntbest_model_7644000


KeyboardInterrupt: 