In [5]:
import numpy as np
import config.config as config
from validation import Validation
from Environment import Env
from Agent_RNN import *
from pathlib import Path

import sys; sys.path.append('../analysis/')
from my_utils import reset_seeds

In [6]:
def training(datapath, seed_number, Actor, Critic, task='gain', 
             init_expnoise_std=0.8, TOTAL_EPISODE=1e5, pro_noise=0.2, obs_noise=0.1, extensive=False):
    """Train one actor-critic RNN agent with a three-phase schedule.

    Phases, as implemented below:
      * Phase I   - exploration noise std is `init_expnoise_std` and the agent's
                    observation noise is switched off. Ends at the first log step
                    where the rewarded fraction of the logged trials exceeds 0.2.
      * Phase II  - exploration noise std 0.5, observation noise switched on,
                    replay memory and episode/step counters restarted, mirrored
                    trajectories stored, validation run at every checkpoint.
      * Phase III - entered once any validation reward fraction exceeds 0.8:
                    exploration noise std 0.4 and both optimizers switch to
                    `arg.decayed_lr`; validation is no longer run.

    The phase is tracked with the `pre_phase` flag rather than by comparing
    `noise.std` with `init_expnoise_std`. The comparison breaks when
    `init_expnoise_std` is 0.5 or lower (with 0.5, phase II would re-trigger at
    every log step, wiping the memory and resetting `episode` forever; at or
    below 0.5, validation and therefore phase III could never start).

    Parameters
    ----------
    datapath : passed to the config constructor.
    seed_number : stored as `arg.SEED_NUMBER` and passed to `reset_seeds`.
    Actor, Critic : forwarded to `Agent(arg, Actor, Critic)`.
    task : 'gain' or 'gain_control'; selects the config class.
    init_expnoise_std : std of the exploration noise during phase I.
    TOTAL_EPISODE : training stops once `episode` reaches this value.
    pro_noise, obs_noise : forwarded to `ConfigGainControl` only.
    extensive : forwarded to `Validation`.

    Returns
    -------
    None. Checkpoints are written through `agent.save` and validation results
    through `DataFrame.to_csv`-style `to_csv` on the validator's return value.

    Raises
    ------
    ValueError
        If `task` is neither 'gain' nor 'gain_control'.
    """
    if task == 'gain':
        arg = config.ConfigGain(datapath)
    elif task == 'gain_control':
        arg = config.ConfigGainControl(datapath, pro_noise=pro_noise, obs_noise=obs_noise)
    else:
        raise ValueError('No such a task!')
            
    arg.SEED_NUMBER = seed_number
    arg.save()
    
    # reproducibility
    reset_seeds(arg.SEED_NUMBER)

    # initialize environment and agent
    env = Env(arg)
    agent = Agent(arg, Actor, Critic)
    validator = Validation(arg.task, agent_name='RNN', extensive=extensive)
    
    # define exploration noise
    noise = ActionNoise(arg.ACTION_DIM, mean=0, std=init_expnoise_std)

    # Remove observation noise in the beginning to help learning in the early stage.
    agent.bstep.obs_noise_range = None
    
    # Loop now
    tot_t = 0
    episode = agent.initial_episode
    reward_log = []
    rewarded_trial_log = []
    step_log = []
    actor_loss_log = 0
    critic_loss_log = 0
    num_update = 1e-5    # tiny non-zero start so the logged averages never divide by zero
    dist_log = []

    LOG_FREQ = 100
    VALIDATION_FREQ = 500
    decrease_lr = True   # True until phase III has been entered
    REPLAY_PERIOD = 4    # critic update frequency
    PRE_LEARN_PERIOD = arg.BATCH_SIZE * 50     # no learning in the first n trials

    enable_mirror_traj = True     # store mirror image of each trajectory
    pre_phase=True                # training phase I

    # Start loop
    while episode < TOTAL_EPISODE:
        # initialize a trial
        cross_start_threshold = False
        reward = torch.zeros(1, 1, 1)

        x = env.reset()
        agent.bstep.reset(env.pro_gains)
        last_action = torch.zeros(1, 1, arg.ACTION_DIM)   # the action with exploration noise
        last_action_raw = last_action.clone()  # the exploration noise-free action

        # state contains observations of linear and angular velocities, last action (efference copy),
        # and target's x and y locations when visible
        state = torch.cat([x[-arg.OBS_DIM:].view(1, 1, -1), last_action,
                           env.target_position_obs.view(1, 1, -1)], dim=2).to(arg.device)

        hiddenin = None  # reset RNN memory

        states = []
        actions = []
        rewards = []

        for t in range(arg.EPISODE_LEN):
            # 1. Check if the agent's action crosses the start threshold
            if not cross_start_threshold and (last_action_raw.abs() > arg.TERMINAL_ACTION).any():
                cross_start_threshold = True

            # 2. Take an action based on current state and previous hidden states of RNN units
            action, action_raw, hiddenout = agent.select_action(state, hiddenin, action_noise=noise)

            # 3. Update the environment given the agent's action
            next_x, reached_target, relative_dist = env(x, action, t)

            # 4. Collect new observation and construct the next state
            next_ox = agent.bstep(next_x)
            next_state = torch.cat([next_ox.view(1, 1, -1), action,
                                    env.target_position_obs.view(1, 1, -1)], dim=2).to(arg.device)

            # 5. Check if agent stops
            is_stop = env.is_stop(x, action)

            # 6. Give reward if agent stopped         
            if is_stop and cross_start_threshold:
                reward = env.return_reward(x, reward_mode='mixed')

            # 7. Append data
            states.append(state)
            actions.append(action)
            rewards.append(reward)

            # 8. Update timestep
            last_action_raw = action_raw
            last_action = action
            state = next_state
            x = next_x
            hiddenin = hiddenout
            tot_t += 1

            # 9. Update model
            if len(agent.memory.memory) > PRE_LEARN_PERIOD and tot_t % REPLAY_PERIOD == 0:
                actor_loss, critic_loss = agent.learn()
                actor_loss_log += actor_loss
                critic_loss_log += critic_loss
                num_update += 1

            # 10. Trial ends if agent stops
            if is_stop and cross_start_threshold:
                break


        # End of a trial, store trajectory into buffer
        states = torch.cat(states)
        actions = torch.cat(actions).to(arg.device)
        rewards = torch.cat(rewards).to(arg.device)
        agent.memory.push(states, actions, rewards) 

        # Mirrored trajectories are only stored once phase I is over.
        if enable_mirror_traj and not pre_phase:
            # store mirrored trajectories reflected along the y-axis
            agent.memory.push(*agent.mirror_traj(states, actions), rewards) 

        # Logs
        reward_log.append(reward.item())
        rewarded_trial_log.append(int(reached_target & is_stop))
        step_log.append(t + 1)
        dist_log.append(relative_dist.item())

        if episode % LOG_FREQ == LOG_FREQ - 1:
            # NOTE(review): the actor loss is averaged over num_update/2, which presumably
            # reflects the actor being updated on every other learn() call - confirm in Agent.
            print(f"t: {tot_t}, Ep: {episode}, action std: {noise.std:0.2f}")
            print(f"mean steps: {np.mean(step_log):0.3f}, "
                  f"mean reward: {np.mean(reward_log):0.3f}, "
                  f"rewarded fraction: {np.mean(rewarded_trial_log):0.3f}, "
                  f"relative distance: {np.mean(dist_log) * arg.LINEAR_SCALE:0.3f}, "
                  f"obs noise: {agent.bstep.obs_noise_range}, "
                  f"critic loss: {critic_loss_log / num_update:0.3f}, "
                  f"actor loss: {-actor_loss_log / (num_update/2):0.3f}")

            # training phase III
            if decrease_lr and (validator.data.reward_fraction > 0.8).any():
                noise.reset(mean=0, std=0.4)
                agent.actor_optim.param_groups[0]['lr'] = arg.decayed_lr
                agent.critic_optim.param_groups[0]['lr'] = arg.decayed_lr
                decrease_lr = False
                print('Noise and learning rate are changed.')

            # training phase II
            # Keyed on the phase flag (not on noise.std) so this fires exactly once,
            # whatever value init_expnoise_std has.
            if pre_phase and np.mean(rewarded_trial_log) > 0.2:
                noise.reset(mean=0, std=0.5)
                agent.bstep.obs_noise_range = arg.obs_noise_range
                agent.memory.reset()
                tot_t = 0
                episode = 0
                pre_phase = False

            # reset logs
            reward_log = []
            rewarded_trial_log = []
            step_log = []
            actor_loss_log = 0
            critic_loss_log = 0
            num_update = 1e-5
            dist_log = []

        # save checkpoints and validation
        if episode % VALIDATION_FREQ == VALIDATION_FREQ - 1:
            # save
            agent.save(save_memory=False, episode=episode, pre_phase=pre_phase, full_param=False)
            # validation for deciding the training phase (phase II only)
            if not pre_phase and decrease_lr:
                validator(agent, episode).to_csv(arg.data_path / f'{arg.filename}.csv', index=False)
                # restore the training observation noise after validation
                agent.bstep.obs_noise_range = arg.obs_noise_range

        episode += 1
        
        # break if no learning
        if pre_phase and episode >= 5e4:
            break

# specify parameters

In [7]:
# Network architectures, tasks and seeds. The four lists below are zipped
# together by the training cell, so entry i of each list describes one run
# configuration and `seeds[i]` is the collection of seeds trained for it.
actors = ['Actor3']     # choose from: 'Actor1', 'Actor2', 'Actor3'
critics = ['Critic5']   # choose from: 'Critic1', 'Critic2', 'Critic3', 'Critic4', 'Critic5'
tasks = ['gain']        # choose from: 'gain', 'gain_control'
seeds = [list(range(3))]

# Exploration and training length.
init_expnoise_std = 0.8           # std of the exploration noise at the start of training
TOTAL_EPISODE = float(10_000)     # number of training trials (default: 1e4, extensive: 1e5)

# Noise levels, only used when the task is 'gain_control'
# (the paper uses 0, 0.1, 0.2 and 0.3 for both).
pro_noise = 0.2                   # process noise std
obs_noise = 0.1                   # observation noise std

extensive = True                  # extensive-training flag handed to training()

# Root folder under which every agent's checkpoints are stored.
folder_path = Path('..') / 'data' / 'agents_temp'

# training

In [None]:
import importlib

# Train one agent per (actor, critic, task) configuration and per seed.
for actor, critic, task, seed_ in zip(actors, critics, tasks, seeds):
    # Resolve the network classes by module name. An explicit lookup (instead of
    # exec'ing `from <module> import *`) fails loudly if a module does not define
    # the class, rather than silently reusing one left over from an earlier
    # iteration, and keeps the notebook namespace free of unrelated names.
    Actor = getattr(importlib.import_module(actor), 'Actor')
    Critic = getattr(importlib.import_module(critic), 'Critic')

    # Checkpoint folder shared by all seeds of this configuration; 'gain_control'
    # runs get an extra level encoding the process/observation noise.
    run_root = folder_path / f'{actor}{critic}' / task
    if task == 'gain_control':
        run_root = run_root / f'{pro_noise}_{obs_noise}'

    for seed in seed_:
        datapath = run_root / f'seed{seed}'
        training(datapath, seed, Actor, Critic, task, init_expnoise_std, TOTAL_EPISODE,
                 pro_noise, obs_noise, extensive)