In [None]:
import retro

import gymnasium as gym
import re
import random

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize
from stable_baselines3.common.callbacks import CheckpointCallback, BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from gym.wrappers import GrayScaleObservation, FrameStack
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.vec_env import SubprocVecEnv
from gymnasium.wrappers import TimeLimit
from stable_baselines3.common.atari_wrappers import ClipRewardEnv, WarpFrame

import torch as th
from torch import nn
import numpy as np
import cv2

import os

from pathlib import Path
import datetime
from pytz import timezone
from stable_baselines3.common.utils import get_schedule_fn

def custom_schedule(progress_remaining):
    """Return the learning rate for the given fraction of training left.

    The rate is interpolated linearly between 5e-5 (returned when
    ``progress_remaining`` is 1) and 2e-4 (returned when it is 0), so it
    grows as ``progress_remaining`` shrinks.

    Args:
        progress_remaining: fraction of training still to run; 1 at the
            beginning and 0 at the end.

    Returns:
        The learning rate as a float.
    """
    # An earlier variant of this schedule used 1e-4 -> 3e-4.
    lr_at_start = 5e-5
    lr_at_end = 2e-4
    fraction_done = 1 - progress_remaining
    return lr_at_start + (lr_at_end - lr_at_start) * fraction_done

# Training hyperparameters and run configuration.
# Trailing "previously tried" notes record earlier values and are not used.

# Passed to TrainAndLoggingCallback as check_freq (callback calls between
# checkpoint/log events).
CHECK_FREQ_NUMB = 1000
# Passed to model.learn() as total_timesteps.
TOTAL_TIMESTEP_NUMB = 500_000_000
LEARNING_RATE = custom_schedule # previously tried constants: 0.00025, 0.0002, 0.0001
GAE = 0.98 # passed as gae_lambda; previously tried: 0.9, 1.0, 0.95
ENT_COEF = 0.004 # previously tried: 0.001, 0.03, 0.01, 0.1, 0.02, 0.005
N_STEPS = 2048 # previously tried: 4096, 512
GAMMA = 0.99 # previously tried: 0.9
BATCH_SIZE = 512 # previously tried: 128, 64
CLIP_RANGE = 0.1 # previously tried: 0.15, 0.2, 0.4, 0.3
N_EPOCHS = 10 # previously tried: 6, 15
MAX_EPISODE=0# step cap used by ResetStateWrapper; only applied when > 0 (previously 15000)
# When True, make_env() adds CurriculumWrapper and the callback logs current_phase.
USE_CURRICULUM=False
# When True, make_env() adds ClipRewardEnv.
USE_CLIP_REWARD=False
# Save state passed to retro.make().
STATE="ControlA.MidRoute.Corneria"
# Directory passed to PPO as tensorboard_log.
TENSORBOARD="./tensorboard-starfox"
# Directory searched for, and written with, best_model_<n> checkpoints.
SAVE_DIR="./model-starfox"
# Number of environments handed to SubprocVecEnv.
NUM_ENV = 16

# Module-level handle to the PPO model; assigned further down once the model
# is loaded or created, and read by TrainAndLoggingCallback when saving.
model = None

# Test Param
# NOTE(review): EPISODE_NUMBERS is not referenced anywhere in the visible code.
EPISODE_NUMBERS = 20
# Stored on TrainAndLoggingCallback as save_freq.
SAVE_FREQ=1000

save_dir = Path(SAVE_DIR)
GAME = "StarFox-Snes"
# Names of the save states registered for GAME in the retro data directory.
states = retro.data.list_states(GAME)

# print(retro.data.list_games())
print(f"States for {GAME}: {states}")

class Discretizer(gym.ActionWrapper):
    """
    Expose a fixed menu of button combinations as a Discrete action space.

    The wrapped environment must have a MultiBinary action space and its
    unwrapped env must provide a ``buttons`` list of button names. Each combo
    is pre-encoded once into a boolean mask over those buttons.

    Args:
        env: environment to wrap
        combos: ordered list of lists of valid button combinations; the
            position of a combo in this list is its discrete action id
    """

    def __init__(self, env, combos):
        super().__init__(env)
        assert isinstance(env.action_space, gym.spaces.MultiBinary)
        button_names = env.unwrapped.buttons
        n_buttons = env.action_space.n

        def encode(combo):
            # One boolean slot per button, set only for the pressed ones.
            mask = np.zeros(n_buttons, dtype=bool)
            for name in combo:
                mask[button_names.index(name)] = True
            return mask

        self._decode_discrete_action = [encode(combo) for combo in combos]
        self.action_space = gym.spaces.Discrete(len(self._decode_discrete_action))

    def action(self, act):
        """Translate discrete action id ``act`` into a fresh button mask."""
        # Hand out a copy so callers can never alter the cached masks.
        return np.copy(self._decode_discrete_action[act])

class MainDiscretizer(Discretizer):
    """
    Discrete action set used for this game's training runs.

    The Discretizer approach is based on the Sonic helper at
    https://github.com/openai/retro-baselines/blob/master/agents/sonic_util.py
    but the button combinations below are the ones chosen for this script.
    The index of a combo in the list is the discrete action id.
    """

    def __init__(self, env):
        button_combos = [
            # single buttons
            ["LEFT"],             # rotate to left
            ["RIGHT"],            # rotate to right
            ["UP"],
            ["DOWN"],
            ["B"],                # shoot
            ["Y"],
            ["L"],                # tilt to left
            ["R"],                # tilt to right
            # direction + fire
            ["LEFT", "B"],        # shoot left
            ["RIGHT", "B"],       # shoot right
            ["UP", "B"],          # shoot up
            ["DOWN", "B"],        # shoot down
            # tilt + direction
            ["L", "LEFT"],        # tilt to left and go to left
            ["R", "RIGHT"],       # tilt to right and go to right
            ["L", "RIGHT"],       # tilt to left and go to right
            ["R", "LEFT"],        # tilt to right and go to left
            # tilt + direction + fire
            ["L", "LEFT", "B"],   # tilt to left and go to left and fire
            ["R", "RIGHT", "B"],  # tilt to right and go to right and fire
        ]
        super().__init__(env=env, combos=button_combos)

class IgnorePauseActionWrapper(gym.ActionWrapper):
    """
    Force the button at index 3 of every action to 0 before it reaches the env.

    NOTE(review): going by the class name, index 3 is presumably the
    pause/START button of the wrapped env's button layout — confirm against
    ``env.unwrapped.buttons``.
    """

    def __init__(self, env):
        super().__init__(env)

    def action(self, act):
        """Return a copy of ``act`` with element 3 cleared.

        The input is copied first: the caller may keep using the same
        array/list afterwards (for example as the action recorded for
        training), so it must not be modified in place.
        """
        masked = act.copy()
        masked[3] = 0
        return masked


class ResetStateWrapper(gym.Wrapper):
    """
    Reward-shaping wrapper driven by the ``health`` and ``lives`` info values.

    Per step the reward coming from the wrapped env is reshaped as follows:

    * any positive reward is replaced by a flat 0.5;
    * exceeding ``MAX_EPISODE`` steps (only when ``MAX_EPISODE > 0``):
      -2 and the episode is ended;
    * losing a life: -2 and the episode is ended; gaining a life: +1;
    * ``lives`` below 1: a further -2 and the episode is ended;
    * health dropping: -0.5; health rising: +0.5.

    Args:
        env: environment to wrap; its step ``info`` must contain the keys
            ``'health'`` and ``'lives'``
        initial_health: health value assumed at the start of every episode
        initial_lives: number of lives assumed at the start of every episode
    """

    def __init__(self, env, initial_health=40, initial_lives=3):
        super().__init__(env)
        self.initial_health = initial_health
        self.initial_lives = initial_lives
        self._reset_trackers()

    def _reset_trackers(self):
        """Put the per-episode bookkeeping back to its start-of-episode values."""
        self.steps = 0
        # Only ever reset here; nothing in this class sets it to True.
        self.lose_lives = False
        self.current_health = self.initial_health
        self.current_lives = self.initial_lives

    def reset(self, **kwargs):
        """Reset the wrapped env and the trackers; return the env's reset result unchanged."""
        reset_result = self.env.reset(**kwargs)
        self._reset_trackers()
        return reset_result

    def step(self, action):
        """Step the wrapped env and return the transition with the shaped reward."""
        obs, reward, done, trunc, info = self.env.step(action)

        self.steps += 1

        # Flatten every score gain to the same bonus.
        if reward > 0:
            reward = 0.5

        # Example info: 'health': 36, 'kills': 15, 'lives': 2
        health = info['health']
        lives = info['lives']

        # Step budget exhausted: penalise and end the episode.
        if MAX_EPISODE > 0 and self.steps > MAX_EPISODE:
            done = True
            reward -= 2

        if lives < self.current_lives:
            # Lost a life.
            reward -= 2
            done = True
        elif lives > self.current_lives:
            # Gained a life.
            reward += 1
        self.current_lives = lives

        # No lives left: extra penalty on top of the lost-life one.
        if lives < 1:
            reward -= 2
            done = True

        if health < self.current_health:
            # Shield damage.
            reward -= 0.5
        elif health > self.current_health:
            # Shield restored.
            reward += 0.5
        self.current_health = health

        return obs, reward, done, trunc, info

class RandomStateWrapper(gym.Wrapper):
    """
    Start every episode from a randomly chosen ``.state`` file.

    ``step`` is intentionally not overridden. Resetting from inside ``step``
    would throw away the observation of the new episode and cause a second
    reset when the caller (e.g. a vectorized env) resets on episode end.

    Args:
        env: environment to wrap
        state_path: directory that holds the candidate ``.state`` files
    """

    def __init__(self, env, state_path="./States"):
        super().__init__(env)
        self.state_path = state_path

    def get_random_state(self):
        """Return the absolute path of a random ``.state`` file in ``state_path``.

        Raises:
            FileNotFoundError: if the directory contains no ``.state`` file.
        """
        candidates = [f for f in os.listdir(self.state_path) if f.endswith(".state")]
        if not candidates:
            raise FileNotFoundError("File not found!")
        chosen = random.choice(candidates)
        return os.path.abspath(os.path.join(self.state_path, chosen))

    def reset(self, **kwargs):
        """Load a random state into the base env, then reset and return the result."""
        # load_state is called on the unwrapped env so this does not depend on
        # intermediate wrappers forwarding unknown attributes.
        # NOTE(review): assumes the base env accepts a path to a state file in
        # load_state — confirm against the retro API.
        self.env.unwrapped.load_state(self.get_random_state())
        return self.env.reset(**kwargs)


class CurriculumWrapper(gym.Wrapper):
    """
    Phase-based curriculum driven by the ``matches_won`` info values.

    Each step the wrapper reads ``info["matches_won"]`` (and, when needed,
    ``info["enemy_matches_won"]``). Once ``matches_won / current_phase >= 2``
    and either side holds a positive, even number of won matches, the episode
    is ended. Steps on which the agent satisfies that condition increase
    ``total_wins``; when ``total_wins`` reaches ``required_wins`` the wrapper
    moves to the next phase and clears its counters.

    NOTE(review): the sample info documented elsewhere in this file lists
    health/kills/lives only — confirm the game integration really exposes
    ``matches_won`` / ``enemy_matches_won`` before enabling this wrapper.

    Args:
        env: environment to wrap
        required_wins: value ``total_wins`` must reach to advance one phase
    """

    def __init__(self, env, required_wins=20): #required_avg_reward=1.0):
        super().__init__(env)
        self.required_wins = required_wins
        #self.required_avg_reward = required_avg_reward
        self.current_phase = 1
        self.total_wins = 0
        self.rewards_list = []

    def _average_reward(self):
        """Mean of all rewards collected since the current phase started."""
        return np.mean(self.rewards_list)

    def reset(self, **kwargs):
        """Reset the wrapped env; curriculum counters persist across episodes."""
        return self.env.reset(**kwargs)

    def step(self, action):
        """Step the wrapped env, update the curriculum and possibly end the episode."""
        obs, reward, done, truncated, info = self.env.step(action)

        self.rewards_list.append(reward)

        matches_won = info["matches_won"]
        could_to_next_stage = matches_won / self.current_phase >= 2
        agent_on_even_wins = matches_won % 2 == 0 and matches_won > 0

        if agent_on_even_wins and could_to_next_stage:
            self.total_wins += 1

        # The average is only needed for the messages below, so it is computed
        # on demand instead of re-averaging the whole history on every step.
        if could_to_next_stage and \
            (agent_on_even_wins \
                 or (info["enemy_matches_won"] % 2 == 0 and info["enemy_matches_won"] > 0)):
            print(info)
            print(f"🔥 stage {self.current_phase}! ({self.total_wins} fights win, avg rewards: {self._average_reward():.2f})")
            done = True

        if self.total_wins >= self.required_wins: #and avg_reward >= self.required_avg_reward:
            self.current_phase += 1
            print(f"🔥 Going to next stage {self.current_phase}! ({self.total_wins} fights win, avg rewards: {self._average_reward():.2f})")
            self.total_wins = 0
            self.rewards_list = []

        return obs, reward, done, truncated, info

class SkipFrame(gym.Wrapper):
    """
    Repeat each action for ``skip`` consecutive env steps.

    The rewards of the repeated steps are summed; the observation, flags and
    info returned are those of the last step actually executed.

    Args:
        env: environment to wrap
        skip: number of env steps performed per ``step`` call (at least 1)

    Raises:
        ValueError: if ``skip`` is smaller than 1.
    """

    def __init__(self, env, skip):
        super().__init__(env)
        # With skip < 1 the loop in step() would never run and there would be
        # nothing to return, so reject it up front.
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self._skip = skip

    def step(self, action):
        """Apply ``action`` up to ``skip`` times and return the accumulated transition."""
        total_reward = 0.0
        for _ in range(self._skip):
            observation, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            # Stop as soon as the episode is over for either reason; stepping
            # an ended episode again is not valid.
            if terminated or truncated:
                break
        return observation, total_reward, terminated, truncated, info


class GameNet(BaseFeaturesExtractor):
    """
    Convolutional feature extractor: four 3x3 stride-2 conv layers plus a linear head.

    Args:
        observation_space: observation space of the policy; ``shape[0]`` is
            used as the number of input channels, so a channel-first layout
            is assumed
        features_dim: size of the feature vector produced by ``forward``
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim):
        super().__init__(observation_space, features_dim)

        # Conv -> ReLU repeated four times; every conv layer outputs 32 channels.
        in_channels = observation_space.shape[0]
        conv_layers = []
        for _ in range(4):
            conv_layers.append(nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1))
            conv_layers.append(nn.ReLU())
            in_channels = 32
        conv_layers.append(nn.Flatten())
        self.cnn = nn.Sequential(*conv_layers)

        # Push one sampled observation through the conv stack to learn the
        # flattened size needed by the linear head.
        with th.no_grad():
            probe = th.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(probe).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Map a batch of observations to ``features_dim``-sized feature vectors."""
        conv_features = self.cnn(observations)
        return self.linear(conv_features)

class TrainAndLoggingCallback(BaseCallback):
    """
    Periodically save the model and log reward statistics of the first env.

    Every ``check_freq`` callback calls the model is written to
    ``save_path / best_model_<n>`` and, if any rewards were collected since the
    first env's last episode end, their average, maximum and sum are logged.

    Args:
        check_freq: number of callback calls between two checkpoints
        save_path: ``pathlib.Path`` of the checkpoint directory
        verbose: verbosity level forwarded to ``BaseCallback``
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        # Plain value (a trailing comma here would silently create a tuple).
        self.save_freq = SAVE_FREQ
        self.save_path = save_path

        self.best_reward = float('-inf')
        # Per-step rewards of env 0 since its last episode end.
        self.episode_rewards = []
        # Most recent step reward of env 0 (0 right after an episode end).
        self.current_episode_reward = 0

    def _init_callback(self):
        """Create the checkpoint directory if a save path was given."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _save_checkpoint(self):
        """Save the model under the next ``best_model_<n>`` name."""
        latest_model = get_latest_model(self.save_path)
        if latest_model:
            # Continue numbering after the newest checkpoint already on disk.
            last_step = int(re.search(r"best_model_(\d+)", latest_model.name).group(1))
            next_save_step = last_step + self.check_freq
        else:
            next_save_step = self.n_calls
        model_path = self.save_path / f"best_model_{next_save_step}"
        # self.model is the algorithm this callback is attached to, so the
        # callback does not depend on a module-level variable.
        self.model.save(model_path)
        print(f"Model saved in: {model_path}")

    def _log_rewards(self):
        """Record statistics over the rewards collected for env 0."""
        sum_rewards = sum(self.episode_rewards)
        average_reward = sum_rewards / len(self.episode_rewards)
        self.best_reward = max(self.best_reward, max(self.episode_rewards))

        self.logger.record("average_reward", average_reward)
        self.logger.record("best_reward", self.best_reward)
        self.logger.record("sum_rewards", sum_rewards)

        if USE_CURRICULUM:
            self.logger.record("current_phase", self.training_env.get_attr("current_phase")[0])

        # NOTE(review): the value printed as "Time steps" is n_calls, i.e. the
        # number of callback calls, not the total env steps across all envs.
        print(f"Time steps: {self.n_calls}, Average Reward: {average_reward}, Best Reward: {self.best_reward}")

    def _on_step(self):
        """Track env 0's reward and checkpoint/log every ``check_freq`` calls."""
        reward = self.locals["rewards"][0]
        done = self.locals["dones"][0]

        self.current_episode_reward = reward
        self.episode_rewards.append(reward)

        if done:
            # Statistics are per episode of env 0: start over.
            self.current_episode_reward = 0
            self.episode_rewards = []
            self.best_reward = float('-inf')

        if self.n_calls % self.check_freq == 0:
            # Always checkpoint; only the statistics need collected rewards.
            self._save_checkpoint()
            if self.episode_rewards:
                self._log_rewards()

        return True
          

# Policy configuration passed to PPO: GameNet produces 1024 features
# (512 was used previously), followed by separate actor and critic MLPs.
policy_kwargs = {
    "features_extractor_class": GameNet,
    "features_extractor_kwargs": {"features_dim": 1024},
    "net_arch": {
        "pi": [1024, 512, 256],   # Actor
        "vf": [1024, 1024, 512],  # Critic
    },
}

def get_latest_model(path):
    """Return the path of the highest-numbered ``best_model_<n>`` checkpoint.

    Entries of ``path`` whose names start with ``best_model_`` are inspected;
    those without a numeric suffix right after the prefix (for example
    ``best_model_final``) are ignored instead of raising.

    Args:
        path: directory to search, as a ``pathlib.Path`` or a string.

    Returns:
        ``Path(path) / "best_model_<n>"`` for the largest ``n`` found (without
        any file extension), or ``None`` if there is no numbered checkpoint.
    """
    directory = Path(path)
    numbered = re.compile(r"best_model_(\d+)")

    step_numbers = []
    for candidate in directory.glob("best_model_*"):
        # Match on the file name only so digits elsewhere in the path
        # cannot be mistaken for the checkpoint number.
        match = numbered.search(candidate.name)
        if match:
            step_numbers.append(int(match.group(1)))

    if not step_numbers:
        return None
    return directory / f"best_model_{max(step_numbers)}"

def make_env():
    """Return a zero-argument factory that builds one fully wrapped env.

    The factory form lets a vectorized env call it once per worker. Wrappers
    are applied innermost first:
    MainDiscretizer -> ResetStateWrapper -> SkipFrame(4) -> WarpFrame, then
    optionally CurriculumWrapper (``USE_CURRICULUM``) and ClipRewardEnv
    (``USE_CLIP_REWARD``).
    """

    def _build_env():
        # render_mode=None keeps the emulator headless; use "human" to watch.
        base_env = retro.make(game=GAME, render_mode=None, state=STATE)

        # Currently disabled alternatives that were tried at this point:
        # retro.Actions.DISCRETE, DonkeyKongCustomActions,
        # IgnorePauseActionWrapper, RandomStateWrapper, TimeLimit.
        discrete_env = MainDiscretizer(base_env)
        shaped_env = ResetStateWrapper(discrete_env)
        skipping_env = SkipFrame(shaped_env, skip=4)
        wrapped = WarpFrame(skipping_env)

        if USE_CURRICULUM:
            wrapped = CurriculumWrapper(wrapped, required_wins=50)

        if USE_CLIP_REWARD:
            wrapped = ClipRewardEnv(wrapped)

        return wrapped

    return _build_env

# Build NUM_ENV environments, each in its own subprocess.
# (Single-process alternative: DummyVecEnv([make_env()]); VecNormalize was
# also tried and is currently not used.)
env = SubprocVecEnv([make_env() for _ in range(NUM_ENV)])

# Everything after the workers exist runs under try/finally so the
# subprocesses are shut down even if training fails or is interrupted.
try:
    # Stack the 4 most recent frames along the last (channel) axis.
    env = VecFrameStack(env, 4, channels_order='last')

    checkpoint_callback = TrainAndLoggingCallback(check_freq=CHECK_FREQ_NUMB, save_path=save_dir)

    # Settings shared by a fresh model and a resumed one.
    ppo_settings = dict(
        verbose=0,
        policy_kwargs=policy_kwargs,
        tensorboard_log=TENSORBOARD,
        learning_rate=LEARNING_RATE,
        n_steps=N_STEPS,
        batch_size=BATCH_SIZE,
        n_epochs=N_EPOCHS,
        gamma=GAMMA,
        gae_lambda=GAE,
        clip_range=CLIP_RANGE,
        ent_coef=ENT_COEF,
    )

    latest_model_path = get_latest_model(save_dir)

    if latest_model_path:
        # Resume from the newest checkpoint found in save_dir.
        print(f"Loading existent model: {latest_model_path}")
        model = PPO.load(str(latest_model_path), env=env, **ppo_settings)
    else:
        print("None finded, starting from zero.")
        model = PPO('CnnPolicy', env, **ppo_settings)

    # reset_num_timesteps=False keeps the timestep counter of a resumed model.
    model.learn(total_timesteps=TOTAL_TIMESTEP_NUMB, reset_num_timesteps=False, callback=checkpoint_callback)
    model.save("starfox_final")
finally:
    env.close()



States for StarFox-Snes: ['ControlA.MidRoute.Corneria']
None finded, starting from zero.
Model saved in: model-starfox/best_model_1000
Time steps: 1000, Average Reward: 0.002, Best Reward: 0.5
Model saved in: model-starfox/best_model_2000
Time steps: 2000, Average Reward: 0.0013054830287206266, Best Reward: 0.5
