In [1]:
from constants import *
from utils import evaluate_model_policy, plot_study, plot_fig
from trainer import get_trained_model
import optuna
from environment import StreetFighterEnv
from stable_baselines3 import PPO, A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from actor_critic import A2CCNNPolicy
from feature_extractors import CNNExtractorWithAttention, CNNExtractor
from tuner import Tuner
import os
from layers import ActorCriticLayer

# Original image: A2C tuning with `capture_movement=False`

In [2]:
# Passed to Tuner as its `timesteps` argument.
TIMESTEPS = 10_000_000
# Passed to tuner.tune_study as `n_trials`.
N_TRIALS = 5

# Forwarded as `config=` to every figure's show() call in the plotting cells.
PLOTLY_CONFIG = dict(staticPlot=True)


In [3]:
# Hyperparameter search for A2C on the environment built with
# capture_movement=False; artefacts go under `model_dir`.
model = A2C
model_dir = 'models/bias'
env = StreetFighterEnv(capture_movement=False)
policy_network = A2CCNNPolicy

# Policy options handed to the Tuner as `policy_args`.
policy_kwargs = dict(
    features_extractor_class=CNNExtractorWithAttention,
    features_extractor_kwargs=dict(features_dim=512),
    actor_critic_class=ActorCriticLayer,
)
tuner = Tuner(model=model, env=env, policy_network=policy_network, policy_args=policy_kwargs,
              timesteps=TIMESTEPS, save_dir=model_dir)

try:
    study = tuner.tune_study(n_trials=N_TRIALS)
finally:
    # Release the environment even if the search is interrupted (e.g. a manual
    # kernel interrupt), so a later section can build a fresh StreetFighterEnv
    # without this one lingering in the kernel.
    # NOTE(review): assumes StreetFighterEnv follows the gym `close()` contract
    # and that nothing reuses `env` after tuning — confirm.
    env.close()

# Last expression of the cell: shown as the cell's output.
study.best_trial.number, study.best_params

[32m[I 2022-04-18 10:01:57,758][0m A new study created in memory with name: no-name-52900f6e-1d6a-40fc-9177-ebf04130792e[0m


  0%|          | 0/5 [00:00<?, ?it/s]

[32m[I 2022-04-18 20:11:57,442][0m Trial 0 finished with value: 1000.0 and parameters: {'gamma': 0.8191287796298925, 'learning_rate': 1.0456226062552554e-05, 'gae_lambda': 0.8549013409989044}. Best is trial 0 with value: 1000.0.[0m


KeyboardInterrupt: 

In [None]:
# Build the figures for the study produced by the tuning cell and render
# each one with the "notebook" renderer, using the shared Plotly config.
plots = plot_study(study)
for plot in plots:
    plot.show(
        "notebook",
        config=PLOTLY_CONFIG,
    )

# Frame diff: A2C tuning with `capture_movement=True`

In [None]:
# Hyperparameter search for A2C on the environment built with
# capture_movement=True; artefacts go under `model_dir`.
model = A2C
model_dir = 'models/bias_with_movement'
env = StreetFighterEnv(capture_movement=True)
policy_network = A2CCNNPolicy

# Policy options handed to the Tuner as `policy_args`.
policy_kwargs = dict(
    features_extractor_class=CNNExtractorWithAttention,
    features_extractor_kwargs=dict(features_dim=512),
    actor_critic_class=ActorCriticLayer,
)
tuner = Tuner(model=model, env=env, policy_network=policy_network, policy_args=policy_kwargs,
              timesteps=TIMESTEPS, save_dir=model_dir)

try:
    study = tuner.tune_study(n_trials=N_TRIALS)
finally:
    # Release the environment even if the search is interrupted (e.g. a manual
    # kernel interrupt), so it does not stay open in the kernel afterwards.
    # NOTE(review): assumes StreetFighterEnv follows the gym `close()` contract
    # and that nothing reuses `env` after tuning — confirm.
    env.close()

# Last expression of the cell: shown as the cell's output.
study.best_trial.number, study.best_params

In [None]:
# Build the figures for the most recently assigned `study` and render each
# one with the "notebook" renderer, using the shared Plotly config.
plots = plot_study(study)
for plot in plots:
    plot.show(
        "notebook",
        config=PLOTLY_CONFIG,
    )