In [1]:
import numpy as np

from air_hockey_challenge.framework import AirHockeyChallengeWrapper, ChallengeCore
from air_hockey_challenge.utils.tournament_agent_wrapper import TrainingTournamentAgentWrapper
from baseline.baseline_agent.baseline_agent import BaselineAgent

from air_hockey_agent.agent_builder import build_agent
from air_hockey_agent.reward import *

from mushroom_rl.utils.dataset import compute_J
from mushroom_rl.core import Logger

In [2]:
# Tournament environment with a custom reward: a RewardList holding the score
# term and the two constraint terms ('ee_constr', 'link_constr'), each given
# weight 1.
mdp = AirHockeyChallengeWrapper(
    custom_reward_function=RewardList(
        weights=[1, 1, 1],
        rewards=[
            ScoreReward(),
            ConstraintReward('ee_constr'),
            ConstraintReward('link_constr'),
        ],
    ),
    env='tournament',
)

In [3]:
# Learning agent built from the environment description; `layer_sizes` is
# forwarded to the project's build_agent helper.
agent_1 = build_agent(
    mdp.env_info,
    layer_sizes=(256, 128, 64, 128, 256),
)

In [4]:
# Two-player wrapper: the learning agent plays as agent_1 against a
# BaselineAgent constructed with index 2 as agent_2.
agent = TrainingTournamentAgentWrapper(
    mdp.env_info, agent_1=agent_1, agent_2=BaselineAgent(mdp.env_info, 2)
)

In [5]:
# Core that drives the wrapped agent pair on the tournament environment,
# starting from the base environment's initial state.
core = ChallengeCore(
    agent,
    mdp,
    is_tournament=True,
    init_state=mdp.base_env.init_state,
    time_limit=0.02,
)

In [None]:
# Identifier of this run; the training cell embeds it in the checkpoint path.
train_id = 1

# Steps collected by the single learn() call that precedes the training loop.
initial_replay_size = 500

# Per-epoch step budgets: learning steps, then evaluation steps.
n_steps, n_steps_test = 400, 100

# Epoch count setting for the run.
n_epochs = 20

In [None]:
# Training cell.
#
# Runs `n_epochs` epochs of: (optional checkpoint) -> learn for `n_steps`
# steps -> evaluate for `n_steps_test` steps -> log metrics. The loop is
# bounded by `n_epochs` so that "Restart & Run All" terminates and the cells
# below are reached; raise `n_epochs` in the configuration cell for longer runs.
epoch = 0
checkpoint_every = 20  # epochs between periodic checkpoints

logger = Logger('train', results_dir='./logs', use_timestamp=True)
logger.strong_line()
logger.info('Training started')

# NOTE(review): reaches into a private attribute of the SAC agent to attach the
# logger to its critic approximator; confirm there is no public hook for this.
agent_1.sac._critic_approximator.set_logger(logger)

# Initial collection phase: gather `initial_replay_size` steps and fit once at
# the end (presumably to pre-fill the replay buffer -- verify against the agent).
core.learn(n_steps=initial_replay_size, n_steps_per_fit=initial_replay_size)

while epoch < n_epochs:
    if epoch % checkpoint_every == 0:
        checkpoint_path = f'checkpoints/{train_id}/epoch{epoch}.msh'
        agent_1.save(checkpoint_path)
        logger.info(f'Saved model to {checkpoint_path}')

    # One fit per environment step.
    core.learn(n_steps=n_steps, n_steps_per_fit=1)

    # Keep only the learning agent's transitions from the evaluation rollout.
    dataset = agent.get_dataset_1(core.evaluate(n_steps=n_steps_test))
    # Column-wise view of the transitions; only `states` is used below.
    states, actions, rewards, next_states, absorbing, _ = map(np.array, zip(*dataset))

    J = np.mean(compute_J(dataset, mdp.info.gamma))  # mean discounted return
    R = np.mean(compute_J(dataset))                  # mean undiscounted return
    E = agent_1.sac.policy.entropy(states)           # policy entropy on visited states

    logger.epoch_info(epoch, J=J, R=R, entropy=E)

    epoch += 1

# Always persist the final weights so the last epochs of training are not lost
# when `n_epochs` is not a multiple of `checkpoint_every` (or equals it).
checkpoint_path = f'checkpoints/{train_id}/epoch{epoch}.msh'
agent_1.save(checkpoint_path)
logger.info(f'Saved model to {checkpoint_path}')

                                              
                                               
                                               
                                               
 80%|████████  | 8/10 [02:52<00:00, 37.54it/s] 

11/11/2024 15:47:18 [INFO] ###################################################################################################
11/11/2024 15:47:18 [INFO] ###################################################################################################
11/11/2024 15:47:18 [INFO] Training started
11/11/2024 15:47:18 [INFO] Training started



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

KeyboardInterrupt: 


[A

In [None]:
# Rendered evaluation rollout of 1000 steps; the raw result is kept in
# `dataset` (this rebinds the name used by the training cell).
dataset = core.evaluate(
    n_steps=1000,
    render=True,
)

 82%|████████▏ | 3271/4000 [02:07<00:26, 27.26it/s]

NameError: name 'exit' is not defined

