In [None]:
from stk_actor.agent import UnifiedSACPolicy
from stk_actor.wrappers import PreprocessObservationWrapper, StuckStopWrapper, SkipFirstNStepsWrapper
from stable_baselines3 import PPO, A2C
from stable_baselines3.common.env_util import make_vec_env
from pystk2_gymnasium import AgentSpec

import torch

# Directory of pretrained agents, and the name of the agent whose weights are
# loaded below (the name is also passed to the observation preprocessor).
path = 'stk_actor/trained_agents/'
agent = 'normed_behavioral_cloning_num5'

# Tracks to train on. One environment is created per (kart count, track) pair.
tracks = [
    'abyss', 'black_forest', 'candela_city', 'cocoa_temple',
    'cornfield_crossing', 'fortmagma', 'gran_paradiso_island', 'hacienda',
    'lighthouse', 'mines', 'minigolf', 'olivermath',
    'ravenbridge_mansion', 'sandtrack', 'scotland', 'snowmountain',
    'snowtuxpeak', 'stk_enterprise', 'volcano_island', 'xr591',
    'zengarden',
]
# Commented-out alternative track subset kept from earlier experiments
# (the trailing numbers were left by the author; their meaning is not recorded):
#     'fortmagma', 'ravenbridge_mansion', 'snowmountain', 'cocoa_temple',
#     'sandtrack', 'scotland', 'stk_enterprise',
#     'volcano_island',  # 1104
#     'xr591',  # 864

# Kart counts to train with (combined with every track above).
karts = [12]
n_envs = len(karts) * len(tracks)


def _wrap_env(env):
    """Wrap one raw STK env with the training wrappers, innermost first.

    NOTE(review): the meaning of `n` for StuckStopWrapper / SkipFirstNStepsWrapper
    is defined in stk_actor.wrappers (presumably a step count) — confirm there.
    """
    env = PreprocessObservationWrapper(env, ret_dict=False, norm=True, agent_name=agent)
    env = StuckStopWrapper(env, n=92)
    return SkipFirstNStepsWrapper(env, n=20)


print('making', n_envs, 'environments')
vec_env = make_vec_env(
    "supertuxkart/flattened_multidiscrete-v0",
    # seed=12,
    n_envs=n_envs,
    wrapper_class=_wrap_env,
    env_kwargs={
        'render_mode': None,
        'agent': AgentSpec(use_ai=False, name="walid"),
        # 'track': 'minigolf',  (per-env tracks are assigned after creation)
        'laps': 1,
        'difficulty': 2,
        'num_kart': 12,
    },
)

# Give every environment its own (kart count, track) combination, in the same
# kart-major / track-minor order that `n_envs` was sized for.
# Iterate the values directly: `enumerate(...)` would bind (index, value)
# tuples, handing each env e.g. (0, 'abyss') as its track.
ix = 0
for num_kart in karts:
    for track in tracks:
        # Assign on the innermost (unwrapped) env: setting an attribute on a
        # gymnasium wrapper stores it on that wrapper only, not the base env.
        # NOTE(review): assumes the base STK env reads `default_track` and
        # `num_kart` when it resets — confirm against pystk2_gymnasium.
        base_env = vec_env.envs[ix].unwrapped
        base_env.default_track = track
        base_env.num_kart = num_kart
        print(ix, track)
        ix += 1
assert ix == len(vec_env.envs), f"configured {ix} of {len(vec_env.envs)} environments"

# Network shape shared by the pretrained policy and the RL model trained below.
net_arch = [1024, 1024, 1024]
activation_fn = torch.nn.Tanh
filename = f'{path}{agent}/statedict'

# Size of each discrete component of the (multi-discrete) action space.
action_dims = [sub_space.n for sub_space in vec_env.action_space]

unified_policy = UnifiedSACPolicy(
    vec_env.observation_space,
    action_dims,
    net_arch=net_arch,
    activation_fn=activation_fn,
)
# Load the pretrained weights onto the CPU regardless of where they were saved.
pretrained_state = torch.load(filename, map_location='cpu')
unified_policy.load_state_dict(pretrained_state)


In [None]:
# One training run per (n_steps, total_timesteps) entry.
# n_steps is A2C's rollout length per environment between updates
# (2048 was a commented-out alternative value).
steps = [(10 * n_envs, 300_000)]

# Commented-out PPO alternative (same network, different hyperparameters):
#     PPO("MlpPolicy", vec_env, verbose=1,
#         policy_kwargs=dict(net_arch=net_arch, activation_fn=activation_fn,),
#         device='cpu', learning_rate=0.0001, batch_size=128, n_epochs=100,
#         n_steps=n_steps, tensorboard_log="./outputs/", clip_range=0.2)

for n_steps, total_timesteps in steps:
    model = A2C(
        "MlpPolicy",
        vec_env,
        verbose=1,
        policy_kwargs={'net_arch': net_arch, 'activation_fn': activation_fn},
        device='cpu',
        learning_rate=0.001,
        n_steps=n_steps,
        tensorboard_log="./outputs/",
        use_rms_prop=False,
        normalize_advantage=True,
    )
    print('DOING', n_steps, total_timesteps)
    # Warm-start: copy the pretrained `shared` sub-network's parameters into the
    # A2C policy (strict load, so the parameter names/shapes must match).
    model.policy.load_state_dict(unified_policy.shared.state_dict())
    model.learn(total_timesteps=total_timesteps, progress_bar=True, log_interval=1)


## Saving

In [None]:
# unified_policy.shared.load_state_dict(model.policy.state_dict())
# torch.save(unified_policy.state_dict(), 'stk_actor/trained_agents/normed_a2c_num5_best/statedict')

In [None]:
# unified_policy.shared.load_state_dict(model.policy.state_dict())
# torch.save(unified_policy.state_dict(), 'stk_actor/trained_agents/normed_ppo_num5_best/statedict')