In [1]:
from constants import *
from utils import evaluate_model_policy
from trainer import get_trained_model
import optuna
from environment import StreetFighterEnv
from stable_baselines3 import PPO, A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from actor_critic import A2CCNNPolicy
from feature_extractors import CNNExtractorWithAttention, CNNExtractor
from tuner import Tuner
import os
import plotly.io as pio
# Plotly renderer used when figures (e.g. the optuna.visualization plots below) are displayed.
pio.renderers.default = 'notebook'

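# Model with attention

Optuna hyperparameter search for A2C with `A2CCNNPolicy`, using the `CNNExtractorWithAttention` feature extractor and `save_dir` set to `models/with_attention`.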

In [None]:
# Optuna search over A2C hyperparameters with the attention-based CNN feature extractor.
model = A2C
model_dir = 'models/with_attention'
policy_network = A2CCNNPolicy
# NOTE(review): frame_size=1, timesteps=1 and n_trials=2 look like smoke-test values;
# confirm before relying on these results.
frame_size = 1
timesteps = 1
policy_kwargs = dict(
    features_extractor_class=CNNExtractorWithAttention,
)

env = StreetFighterEnv()
try:
    tuner = Tuner(model=model, env=env, policy_network=policy_network, policy_args=policy_kwargs,
                  frame_size=frame_size, timesteps=timesteps, save_dir=model_dir)
    study = tuner.tune_study(n_trials=2)
finally:
    # Close the env even if Tuner/tune_study raises, so a failed run does not leave it open.
    # NOTE(review): matters if StreetFighterEnv wraps an emulator limited to one instance
    # per process (e.g. gym-retro) — confirm.
    env.close()

# Dedicated handle: `study` is rebound by the without-attention tuning cell further down.
study_with_attention = study
study.best_trial, study.best_params

[I 2022-04-17 01:47:29,493] A new study created in memory with name: no-name-3a379b3c-bf60-4afe-80ab-b0f7d46376a6


In [None]:
optuna.visualization.plot_parallel_coordinate(study)

In [None]:
optuna.visualization.plot_contour(study)

In [None]:
optuna.visualization.plot_slice(study)

In [None]:
optuna.visualization.plot_param_importances(study)

In [None]:
optuna.visualization.plot_edf(study)

In [None]:
optuna.visualization.plot_optimization_history(study)

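# Model without attention

Baseline for comparison: the same Optuna search for A2C with `A2CCNNPolicy`, but using the plain `CNNExtractor` feature extractor and `save_dir` set to `models/without_attention`.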

In [None]:
# Optuna search over A2C hyperparameters with the plain CNNExtractor (no attention);
# identical to the with-attention experiment above except for extractor and save_dir.
model = A2C
model_dir = 'models/without_attention'
policy_network = A2CCNNPolicy
# NOTE(review): frame_size=1, timesteps=1 and n_trials=2 look like smoke-test values;
# confirm before relying on these results.
frame_size = 1
timesteps = 1
policy_kwargs = dict(
    features_extractor_class=CNNExtractor,
)

env = StreetFighterEnv()
try:
    tuner = Tuner(model=model, env=env, policy_network=policy_network, policy_args=policy_kwargs,
                  frame_size=frame_size, timesteps=timesteps, save_dir=model_dir)
    study = tuner.tune_study(n_trials=2)
finally:
    # Close the env even if Tuner/tune_study raises, so a failed run does not leave it open.
    # NOTE(review): matters if StreetFighterEnv wraps an emulator limited to one instance
    # per process (e.g. gym-retro) — confirm.
    env.close()

# Dedicated handle so this study can be compared with the with-attention run.
study_without_attention = study
study.best_trial, study.best_params

In [None]:
optuna.visualization.plot_parallel_coordinate(study)

In [None]:
optuna.visualization.plot_contour(study)

In [None]:
optuna.visualization.plot_slice(study)

In [None]:
optuna.visualization.plot_param_importances(study)

In [None]:
optuna.visualization.plot_edf(study)

In [None]:
optuna.visualization.plot_optimization_history(study)