# Example for training an agent on the Bongard environment

In [None]:
import gym
import torch
import numpy as np
import random
from bongard_base_env import BPEnv
from stable_baselines3 import PPO, A2C, DQN
from eval_model import eval_model, dict_to_runname

## Set Hyperparameters

In [None]:
# Lookup tables mapping the string identifiers used in `params`
# to the corresponding algorithm / environment classes.
algo_dict = dict(PPO=PPO, A2C=A2C)
env_dict = dict(BPEnv=BPEnv)

# Hyperparameters and run settings for training.
params = dict(
    # Environment for training the agent (key into `env_dict`)
    env='BPEnv',
    # Algorithm used for training (key into `algo_dict`)
    algo='PPO',
    # Network architecture for the policy
    policy='SiaMlpPolicy',
    # Learning rate
    lr=7e-05,
    # Episode length
    eplength=False,
    # Causal Bounds
    CB=False,
    # PPO clipping range
    clip_range=0.2,
    # Whether to save the model or not
    save_model=False,
    # Run name for saving the model; if empty, a name is generated from the parameters
    run_name='',
    # Random seeds to train on
    seeds=[11, 61, 331],
    # Log directory used for tensorboard logs
    log_dir='logs/',
    # Total timesteps for training
    total_timesteps=2000000,
    # Whether to include the skip action or not (not at the same time as CB)
    skip_action=True,
    # Testing mode for the agent
    test_mode=False,
)

## Initialize Environment
For BP environments, the skip action can either be included or left out; this is controlled by the `skip_action` hyperparameter.

In [None]:
# Resolve the environment class named in the hyperparameters, then build it
# with or without the skip action, as configured.
_env_cls = env_dict[params['env']]
env = _env_cls(skip_action=params['skip_action'])

## Train the agent
We train the agent once for each random seed; the seeds can be set in the hyperparameters. The run name, which is used for saving the logs and the trained model, can also be specified in the hyperparameters; if no name is specified, one is generated from the settings for this run.

In [None]:
# Train one agent per random seed listed in params['seeds'].
#
# For every seed the loop:
#   1. seeds torch, numpy and Python's `random`,
#   2. determines the run name (user-supplied, or generated from the params),
#   3. builds the algorithm selected by params['algo'] and trains it,
#   4. saves the trained model only if params['save_model'] is set.
for seed in params['seeds']:

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    params['seed'] = seed

    # Honour a user-supplied run name and only generate one when it is empty.
    # params['run_name'] itself is left untouched, so later seeds (and
    # re-runs of this cell) still see the configured value rather than the
    # name generated for a previous seed.
    if params['run_name']:
        # Suffix the seed so that runs for different seeds do not overwrite
        # each other's saved model.
        run_name = f"{params['run_name']}_seed{seed}"
    else:
        run_name = dict_to_runname(params)

    algo_cls = algo_dict[params['algo']]
    algo_kwargs = {
        'learning_rate': params['lr'],
        'verbose': 1,
        # NOTE(review): `causal` is not an upstream stable-baselines3
        # argument; presumably provided by a project-specific fork - confirm.
        'causal': params['CB'],
        'tensorboard_log': f"./{params['log_dir']}{run_name}/",
    }
    # `clip_range` is a PPO-only argument; passing it to A2C raises TypeError.
    if algo_cls is PPO:
        algo_kwargs['clip_range'] = params['clip_range']

    algo = algo_cls(params['policy'], env, **algo_kwargs)

    algo.learn(total_timesteps=params['total_timesteps'], tb_log_name=f"{run_name}")

    # Only write the model to disk when requested in the hyperparameters.
    if params['save_model']:
        algo.save(run_name)