In [None]:
import numpy as np
import pandas as pd
import retro
from gym import Env
from gym.spaces import MultiBinary, Box
import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
import cv2
import matplotlib.pyplot as plt
import os
import time
import retrowrapper
import warnings
warnings.simplefilter("ignore")

! rm -rf logs opt train train_logs
! mkdir logs opt train train_logs

In [None]:
! ls -R

In [None]:
# env.observation_space.shape

In [None]:
# env.action_space

In [None]:
# info

In [None]:
class StreetFighter(Env):
    """Gym environment wrapping the Gym Retro Street Fighter II (Genesis) game.

    Observations are declared as 84x84 single-channel ``uint8`` images and
    actions as a 12-element binary vector. ``step`` returns the difference
    between the current and the previous preprocessed frame, and the reward is
    the change in ``info['score']`` since the previous step.
    """

    def __init__(self):
        """Create the observation/action spaces and start the emulator."""
        super().__init__()
        self.observation_space = Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
        self.action_space = MultiBinary(12)
        self.game = retro.make(
            game='StreetFighterIISpecialChampionEdition-Genesis',
            record='./logs/',
            # state=sts[0],
            use_restricted_actions=retro.Actions.FILTERED,
            # scenario='scenario'
        )
        # Per-episode state. Initialised here so that step() does not raise
        # AttributeError if it is called before the first reset().
        self.previous_frame = np.zeros(self.observation_space.shape, dtype=np.uint8)
        self.score = 0

    def get_reward(self, info, reward=None):
        """Return the score gained since the previous step.

        Args:
            info: Info dict returned by the emulator; must contain ``'score'``.
            reward: Unused; kept so existing callers can still pass it.
        """
        return info['score'] - self.score

    def step(self, action):
        """Advance the game one step.

        Returns:
            ``(frame_delta, reward, done, info)`` where ``frame_delta`` is the
            current preprocessed frame minus the previous one and ``reward`` is
            the score delta computed by ``get_reward``.
        """
        obs, reward, done, info = self.game.step(action)
        obs = self.preprocess(obs)
        # NOTE(review): if both frames are uint8 (as the observation space
        # declares) this subtraction wraps modulo 256 - confirm that is intended.
        frame_delta = obs - self.previous_frame
        self.previous_frame = obs
        # The emulator's own reward is replaced by the score delta.
        reward = self.get_reward(info)
        self.score = info['score']
        return frame_delta, reward, done, info

    def render(self, *args, **kwargs):
        """Render the underlying game and pass back whatever it returns."""
        # Propagate the return value so modes that return a frame are usable.
        return self.game.render(*args, **kwargs)

    def reset(self):
        """Restart the game and return the first preprocessed frame."""
        obs = self.game.reset()
        obs = self.preprocess(obs)
        self.previous_frame = obs
        self.score = 0
        return obs

    def close(self):
        """Shut down the emulator."""
        self.game.close()

    def preprocess(self, obs):
        """Convert a colour frame to an 84x84x1 grayscale array."""
        # NOTE(review): the conversion assumes BGR channel order; confirm the
        # emulator does not deliver RGB frames.
        gray = cv2.cvtColor(obs, cv2.COLOR_BGR2GRAY)
        resized = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_CUBIC)
        return np.reshape(resized, (84, 84, 1))

In [None]:
# env = StreetFighter()

In [None]:
# obs = env.reset()
# done = False
# for game in range(1):
#     # one game
#     while not done:
#         action = env.action_space.sample()
#         # env.render()
#         obs, reward, done, info = env.step(action)
#         # time.sleep(0.01)
#         if reward: print(reward)
# env.close()

In [None]:
# obs = env.reset()


In [None]:
# obs, reward, done, info = env.step(env.action_space.sample())
# plt.imshow(obs)

In [None]:
# Directory handed to Monitor and used as PPO's tensorboard_log location.
LOG_DIR = './logs/'
# Directory where the model trained by each Optuna trial is saved.
OPT_DIR = './opt/'

In [None]:
def optimize_ppo(trial):
    """Sample a set of PPO hyperparameters from an Optuna trial.

    Args:
        trial: Optuna trial (or any object exposing ``suggest_int`` and
            ``suggest_float``) used to draw the values.

    Returns:
        dict with keys ``n_steps``, ``gamma``, ``learning_rate``,
        ``clip_range`` and ``gae_lambda``, suitable for ``PPO(**params)``.
    """
    return {
        # Sampled in steps of 64 so the rollout length stays a multiple of 64,
        # the constraint the training cell otherwise has to patch by hand.
        'n_steps': trial.suggest_int('n_steps', 2048, 8192, step=64),
        # suggest_float(..., log=True) replaces the deprecated suggest_loguniform.
        'gamma': trial.suggest_float('gamma', 0.8, 0.9999, log=True),
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-4, log=True),
        # suggest_float replaces the deprecated suggest_uniform.
        'clip_range': trial.suggest_float('clip_range', 0.1, 0.4),
        'gae_lambda': trial.suggest_float('gae_lambda', 0.8, 0.99),
    }

In [None]:
# Save location for trial number 1 (the trial number is hard-coded here).
# NOTE(review): no later cell visible in this notebook reads this global - confirm it is still needed.
SAVE_PATH = os.path.join(OPT_DIR, 'trial_{}_best_model'.format(1))

In [None]:

def optimize_agent(trial, total_timesteps=30000, n_eval_episodes=5):
    """Optuna objective: train a PPO agent with sampled hyperparameters and score it.

    Args:
        trial: Optuna trial used to sample hyperparameters and to name the
            saved model file.
        total_timesteps: Number of environment steps to train for.
        n_eval_episodes: Number of episodes used to compute the mean reward.

    Returns:
        Mean episode reward reported by ``evaluate_policy``.

    Exceptions raised during training or evaluation propagate to the caller;
    the environment is closed in every case so the emulator is not leaked.
    """
    model_params = optimize_ppo(trial)

    # Create environment
    game_env = StreetFighter()
    env = game_env
    try:
        monitored_env = Monitor(game_env, LOG_DIR)
        env = monitored_env
        # The factory returns a dedicated name rather than `env`, so it does
        # not depend on when `env` is rebound to the outer wrappers.
        env = DummyVecEnv([lambda: monitored_env])
        env = VecFrameStack(env, 4, channels_order='last')

        # Create algo
        model = PPO('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=0, **model_params)
        model.learn(total_timesteps=total_timesteps)

        # Evaluate model
        mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=n_eval_episodes)
    finally:
        # Always release the environment, even when a trial fails part-way.
        env.close()

    save_path = os.path.join(OPT_DIR, 'trial_{}_best_model'.format(trial.number))
    model.save(save_path)

    return mean_reward

In [None]:
# Create a study that maximises the objective's return value (the mean
# evaluation reward from optimize_agent) and run a single trial in one job.
study = optuna.create_study(direction='maximize')
study.optimize(optimize_agent, n_trials=1, n_jobs=1)

In [None]:
# Show the hyperparameters of the best trial in the study.
study.best_params

In [None]:
# Show the full record of the best trial in the study.
study.best_trial

In [None]:
# Show the best trial's number; it identifies the model file saved by optimize_agent.
study.best_trial.number

In [None]:
# Build the path of the model saved by the best trial (with the '.zip'
# suffix spelled out), print it, and load that model.
path = os.path.join(OPT_DIR, 'trial_{}_best_model.zip'.format(study.best_trial.number))
print(path)
model = PPO.load(path)

In [None]:
from stable_baselines3.common.callbacks import BaseCallback

In [None]:
class TrainAndLoggingCallback(BaseCallback):
    """Training callback that periodically saves the model to disk.

    Args:
        check_freq: Save a checkpoint every ``check_freq`` callback calls.
        save_path: Directory for checkpoints; ``None`` disables saving.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the checkpoint directory if a save path was given."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save a checkpoint on every ``check_freq``-th call.

        Returns:
            Always ``True``.
        """
        # Mirror _init_callback: a None save_path means "do not save", instead
        # of crashing inside os.path.join.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

In [None]:
# Directory that receives the periodic model checkpoints.
CHECKPOINT_DIR = './train/'
# Directory used as PPO's tensorboard_log location for the final training run.
TRAIN_LOG_DIR = './train_logs/'
# Save a checkpoint into CHECKPOINT_DIR every 10000 callback calls.
callback = TrainAndLoggingCallback(check_freq=10000, save_path=CHECKPOINT_DIR)

In [None]:

# Build the training environment: the game, wrapped in Monitor (logging to
# LOG_DIR), wrapped in a single-environment DummyVecEnv, with 4 frames
# stacked along the last (channel) axis.
env = StreetFighter()
env = Monitor(env, LOG_DIR)
# NOTE(review): the lambda looks up `env` when it is called, so this relies on
# DummyVecEnv invoking it before `env` is rebound on this same line - confirm.
env = DummyVecEnv([lambda: env])
env = VecFrameStack(env, 4, channels_order='last')

In [None]:
# Start from the best hyperparameters found by the study, then override n_steps.
model_params = study.best_params
model_params['n_steps'] = 7488  # set n_steps to 7488 or another multiple of 64 (presumably PPO's batch size - confirm)
# model_params['learning_rate'] = 5e-7
# Display the final parameter dict.
model_params

In [None]:
# Create a new PPO model with a CNN policy on the training environment,
# using the tuned hyperparameters and logging to TRAIN_LOG_DIR.
model = PPO('CnnPolicy', env, tensorboard_log=TRAIN_LOG_DIR, verbose=1, **model_params)

In [None]:
# `load` builds and returns a separate model object; calling it on the existing
# instance and discarding the result leaves `model` with its freshly
# initialised weights. Load the saved parameters into the existing model
# instead, so training continues from the best trial's checkpoint while
# keeping the hyperparameters chosen above.
model.set_parameters(path)

In [None]:
# Train for 30000 timesteps; `callback` handles the periodic checkpointing.
model.learn(total_timesteps=30000, callback=callback)

In [None]:
# Evaluate the trained model over a single episode without rendering;
# the second return value is discarded.
mean_reward, _ = evaluate_policy(model, env, render=False, n_eval_episodes=1)

In [None]:
# Show the mean reward from the evaluation above.
mean_reward

In [None]:
# Reset the wrapped environment and keep the initial observation.
obs = env.reset()

In [None]:
# Inspect the shape of the observation returned by the wrapped environment.
obs.shape

In [None]:
# Take one step using the model's predicted action (first element of the
# predict() result) and display what the environment returns.
env.step(model.predict(obs)[0])

In [None]:
# Let the trained agent play, printing every non-zero reward it receives.
for game in range(1):
    # Start every game from the initial state with a fresh termination flag,
    # so the inner loop also runs if the number of games is increased.
    obs = env.reset()
    done = False
    while not done:
#         env.render()
        action = model.predict(obs)[0]
        obs, reward, done, info = env.step(action)
        # Short pause between steps to slow playback down.
        time.sleep(0.01)
        if reward: print(reward)