In [1]:
import torch
from validation import Validation
from pathlib import Path
import numpy as np
from Environment import Env
import config.config as config

import sys; sys.path.append('../analysis/')
from my_utils import reset_seeds

# parameters for testing task

In [4]:
# Only edit the settings in this cell.

# training task
training_task = 'gain'  # Possible arguments: 'gain'; 'gain_control'
# Noise levels are kept as strings: they are joined into the sub-folder name
# and converted with float() when the 'gain_control' config is built.
pro_noise = '0.2'       # Possible arguments: '0', '0.1', '0.2', '0.3'; for 'gain_control' task
obs_noise = '0.1'       # Possible arguments: '0', '0.1', '0.2', '0.3'; for 'gain_control' task


# testing task.
# Possible arguments: 'gain' or 'perturb' if training_task = 'gain'; 'gain_control' or 'perturb_control' if training_task = 'gain_control'
testing_task = 'gain'


# was the agent trained extensively?
extensive = False


# root folder for all agents' checkpoints
progress_agents_path = Path('../data/agents_temp/')
# architectures of the agents being evaluated (also the name of each agent's folder)
agent_archs = ['Actor3Critic5']
# seeds being evaluated: one tuple of seeds per entry of agent_archs
agent_seeds = [(1,)]


# root folder for storing validation results
save_path = Path('../data/training_curve_temp/')

# Run agents

In [None]:
# do not change this cell

# Build the environment config for the chosen (training_task, testing_task) pair.
# Only the 'gain_control' family takes the process/observation noise levels.
task_pair = (training_task, testing_task)
if task_pair == ('gain', 'gain'):
    arg = config.ConfigGain()
elif task_pair == ('gain', 'perturb'):
    arg = config.ConfigPerturb()
elif task_pair == ('gain_control', 'gain_control'):
    arg = config.ConfigGainControl(pro_noise=float(pro_noise), obs_noise=float(obs_noise))
elif task_pair == ('gain_control', 'perturb_control'):
    arg = config.ConfigPerturbControl(pro_noise=float(pro_noise), obs_noise=float(obs_noise))
else:
    # Fail with the offending values instead of a bare ValueError.
    raise ValueError(
        f'Unsupported task combination: training_task={training_task!r}, '
        f'testing_task={testing_task!r}. Supported testing tasks are '
        f"'gain'/'perturb' for training_task='gain' and "
        f"'gain_control'/'perturb_control' for training_task='gain_control'."
    )

# `task` and `subtask` are used below as folder names when locating the
# checkpoints and when saving results. `task` always follows the training task;
# `subtask` is '<pro_noise>_<obs_noise>' for 'gain_control' and empty otherwise.
task = training_task
subtask = '_'.join([pro_noise, obs_noise]) if training_task == 'gain_control' else ''
    
    
# (episode_max, validation_size) depends on whether the agent was trained extensively.
episode_max, validation_size = (int(1e5), 300) if extensive else (int(1e4), 500)

# Spacing between evaluated checkpoint episodes (see `episodes` below).
save_freq = 500
# Passed through to the Validation object.
enable_noise = True

# locate the agent being evaluated
def _checkpoint_prefix(agent_dir):
    """Return the file-name prefix shared by the checkpoints in ``agent_dir``.

    The prefix is the part of a ``*.pkl`` file stem before its first underscore;
    it is later combined with ``-<episode>`` to load a specific checkpoint.
    Matches are sorted so the chosen file does not depend on directory order.

    Raises:
        FileNotFoundError: if ``agent_dir`` contains no ``*.pkl`` file.
    """
    checkpoint_files = sorted(agent_dir.glob('*.pkl'))
    if not checkpoint_files:
        raise FileNotFoundError(f'No .pkl checkpoint found in {agent_dir}')
    return checkpoint_files[0].stem.split('_')[0]


# One folder per (architecture, seed): <root>/<arch>/<task>/<subtask>/seed<seed>.
# An empty `subtask` contributes no path component.
paths = [progress_agents_path / agent_arch / task / subtask / f'seed{seed}'
         for agent_arch, seeds in zip(agent_archs, agent_seeds) for seed in seeds]
filenames = [_checkpoint_prefix(path) for path in paths]
# Episode indices to evaluate: save_freq - 1, 2 * save_freq - 1, ... below episode_max.
episodes = np.arange(save_freq - 1, episode_max, save_freq)


# sample targets and perturbations
arg.device = 'cpu'
arg.process_gain_range = [1, 1]  # presumably pins the process gain to 1 -- confirm against Env
env = Env(arg)

# Perturbation details are only recorded for perturbation tasks; otherwise
# every per-trial entry is None.
has_perturbation = 'perturbation' in arg.task

# Pass 1: reset the environment once per validation trial and record what it exposes.
reset_seeds(0)
target_positions, perturb_peaks, perturb_start_times = [], [], []
for _ in range(validation_size):
    env.reset()  # return value is not needed
    target_positions.append(env.target_position)
    perturb_peaks.append(env.perturbation_velocities if has_perturbation else None)
    perturb_start_times.append(env.perturbation_start_t if has_perturbation else None)

# Pass 2: draw one "large" perturbation per trial. Kept as a separate pass after
# all resets so the order of random draws is unchanged.
perturb_peaks_large = []
for _ in range(validation_size):
    if not has_perturbation:
        perturb_peaks_large.append(None)
        continue
    peak_large = torch.zeros(2)
    # First two range entries bound component 0, the remaining entries bound component 1.
    peak_large[0].uniform_(*arg.perturbation_velocity_range_large[:2])
    peak_large[1].uniform_(*arg.perturbation_velocity_range_large[2:])
    perturb_peaks_large.append(peak_large)
        

        
# validating the agent across checkpoints
for path, filename in zip(paths, filenames):
    # Name of the architecture folder (e.g. 'Actor3Critic5'). It sits one level
    # further up when the path contains a noise `subtask` folder.
    arch_name = path.parents[1 if subtask == '' else 2].stem

    if 'EKF' not in str(path):
        agent_name = 'RNN'
        from Agent_RNN import *
        # NOTE(review): the import statements are built from the folder name and
        # run through exec(); the first 6 characters name the actor module and
        # characters 6-13 name the critic module. Only use trusted folder names.
        exec(f'from {arch_name[:6]} import *')  # import Actor
        exec(f'from {arch_name[6:13]} import *')  # import Critic

        reset_seeds(0)
        agent_temp = Agent(arg, Actor, Critic)
    else:
        agent_name = 'EKF'
        if 'ActorEKF' in str(path):
            from Agent_RNN_EKF import *
            from ActorEKF import *; from Critic1 import *
            reset_seeds(0)
            agent_temp = Agent(arg, Actor, Critic)
        else:
            from Agent_EKF import *
            reset_seeds(0)
            agent_temp = Agent(arg)
    agent_temp.data_path = path

    validator = Validation(arg.task, agent_name=agent_name, validation_size=validation_size,
                           target_positions=target_positions,
                           perturbs_info=[perturb_peaks, perturb_start_times, perturb_peaks_large],
                           enable_noise=enable_noise, extensive=extensive)

    # The output file does not depend on the checkpoint, so resolve it once:
    #   <save_path>/<perturbation | task>/<arch>_<seed folder>.csv             (no subtask)
    #   <save_path>/<perturbation_control | task>/<subtask>/<arch>_<seed folder>.csv
    uses_perturbation = 'perturbation' in arg.task
    if subtask == '':
        result_dir = save_path / ('perturbation' if uses_perturbation else task)
    else:
        result_dir = save_path / ('perturbation_control' if uses_perturbation else task) / subtask
    filepath = result_dir / f'{arch_name}_{path.stem}.csv'
    filepath.parent.mkdir(parents=True, exist_ok=True)

    for episode in episodes:
        print(f'evaluating {episode}')
        agent_temp.load(f'{filename}-{episode}', load_memory=False, load_optimzer=False)
        validator(agent_temp, episode)

        # save: the CSV is rewritten from validator.data after every checkpoint
        validator.data.to_csv(filepath, index=False)