In [5]:
# general imports
import sys                       
import numpy as np                

# env checker
try:
    from stable_baselines3.common import env_checker
except ModuleNotFoundError: 
    !pip install stable-baselines3==1.2.0
    from stable_baselines3.common import env_checker

from stable_baselines3.common import env_checker

# stable baselines3 -> SAC
from stable_baselines3 import SAC
from stable_baselines3.sac import MlpPolicy

from sac_torch import Agent

# Tensorboard
from torch.utils.tensorboard import SummaryWriter
# Shared TensorBoard writer: stored in params["WRITER"] and passed to SimEnv below.
# NOTE(review): the run tag says "HER" and "buff20000", but params sets USE_HER=0 and
# BUFFER_SIZE=2000 -- confirm the tag matches the experiment that is actually run.
# NOTE(review): per the torch docs, `comment` appears to be ignored when an explicit
# log dir is given -- confirm whether the tag ends up anywhere at all.
writer = SummaryWriter('tensorboard_log/',comment="-SAC_HER_buff20000")

# grpc communication
sys.path.insert(1, '/tum_nrp/grpc/python/communication')
import experiment_api_wrapper as eaw

sys.path.insert(1, '/tum_nrp/rlmodel/sb3')
from env import SimEnv
from train_helpers import evaluate, train

# Parameters

In [6]:
# Experiment configuration. The dict is handed to SimEnv and to evaluate(), and a few
# entries (LR, BUFFER_SIZE, BATCH_SIZE, EXPLORATION) are read directly when the Agent is built.
params = {
    "img_pi_data": 0,                 # 1: load image data from .pi file, 0: do not load image data
    "train_loader": None,             # Specifies the loader for the training data
    "test_loader": None,              # Specifies the loader for the test data
    "VERBOSE": 1,                     
    "SETTING": 'reduced4',            # 'reduced', 'reduced2', 'reduced3', 'reduced3+', 'reduced4', 'reduced4+'
    "OBJ_SPACE_LOW": np.array([-0.92, -0.51, 0.58, -0.44, -0.48, 0, -np.pi/2, -np.pi/2, -0.001, -np.pi/2, -0.001, -np.pi]), # observation-space (ee-pos, cyl-pos, joints)
    "OBJ_SPACE_HIGH": np.array([0.92, 1.32, 2.07, 0.48, 0.44, 1.12, np.pi/2, 0.001, np.pi, np.pi/2, np.pi, np.pi]),
    "SPACE_NORM": 1,                  #  1 -> yes, 0 -> no (normalize the action and observation space)
    "CYLINDER": 'no',                 # 'no', 'fix', 'semi_random', 'semi_random_sides', 'half_table', '3/4-table', '7/8-table', 'whole_table'
    "BUFFER_SIZE": 2000,
    "THRESHOLD": 0.25,                # initial threshold
    "THRESHOLD_SCHEDULING": 1,        # 1-> yes, 0-> no
    "MIN_THRESHOLD": 0.10,
    "REWARD_TYPE": 'dense',           # 'sparse', 'dense', 'extra_dense'
    "LEARNING_STARTS": 15,            # number of random movements before learning starts
    "TOGGLE_REWARD": 0,
    "STEPS": 4000,                    # number of steps while training (=num_episodes when MAX_EPISODE_LENGTH is 1)
    "MAX_EPISODE_LENGTH": 1,          # 'None' (no limit) or value 
    "EXPLORATION": 25,                # original note: "leave it at 1 and ignore it". NOTE(review): value is 25 and it is passed to Agent as reward_scale -- confirm
    "WRITER": writer,
    "USE_HER": 0,                     # 1-> yes, 0-> no
    "ENTROPY_COEFFICIENT": 0.007,     # 'auto' or value between 0 and 1 // 0.007 turned out to work well
    "GLOBAL_STEPPER": 0, 
    "EVALUATION_STEPS": 30,           # number of evaluation steps per investigated threshold (x4)
    "EVALS": [0.30, 0.25, 0.2, 0.15], # the list MUST always contain exactly 4 thresholds for evaluation
    "BATCH_SIZE": 64,
    "ACTION_NOISE": None,
    "RANDOM_EXPLORATION": 0.0,
    "LR": 3e-4,
    "TB_LOGGER": None}

# Model

In [7]:
if __name__ == "__main__":
    # Train the custom SAC agent against the simulation environment.
    #
    # Flow: connect to the simulation -> build and validate the environment ->
    # create the agent -> run episodes until the stopping criterion is met,
    # checkpointing on every new best average score and evaluating periodically.

    # create an experiment (connection)
    exp = eaw.ExperimentWrapper()

    # test if simulation can be reached
    server_id = exp.client.test()
    if server_id:
        print("Simulation is available, id: ", server_id)
    else:
        # NOTE(review): execution continues even when no id is returned -- confirm
        # whether training should abort here instead.
        print("Simulation NOT available")

    # create the environment on top of the connection
    env = SimEnv(exp, params, writer)

    # check that the environment follows the interface stable-baselines3 expects
    env_checker.check_env(env)

    # set up the custom SAC agent
    agent = Agent(alpha=0.0003, beta=params["LR"], input_dims=env.observation_space.shape, env=env,
                  gamma=0.99, n_actions=env.action_space.shape[0], max_size=params["BUFFER_SIZE"],
                  tau=0.005, layer1_size=256, layer2_size=256, batch_size=params["BATCH_SIZE"],
                  reward_scale=params["EXPLORATION"])

    # Checkpoints are only written once the running average exceeds this initial value.
    best_score = 0
    score_history = []
    # When True, the agent neither learns nor saves checkpoints (pure roll-out).
    load_checkpoint = False

    avg_score = 0

    # episode counter; also used to trigger the periodic evaluation
    i = 0

    # Keep running while the average score is below 0.9 OR fewer than 100 episodes
    # have been played, i.e. stop only once both conditions are satisfied.
    # NOTE(review): there is no upper bound on the number of episodes;
    # params["STEPS"] is not consulted here.
    while avg_score < 0.90 or i < 100:
        observation = env.reset()
        done = False
        score = 0

        # play one episode
        while not done:
            # choose an action
            action = agent.choose_action(observation)
            # execute the chosen action
            observation_, reward, done, info = env.step(action)

            score += reward

            # save the observation, reward, etc. in the replay buffer
            agent.remember(observation, action, reward, observation_, done)

            if not load_checkpoint:
                # run one training update
                agent.learn()

            # update the observation
            observation = observation_

        # track the score and calculate the average score over the last 100 episodes
        score_history.append(score)
        avg_score = np.mean(score_history[-100:])

        # save a checkpoint if the current average score is better than all previous ones
        if avg_score > best_score:
            best_score = avg_score
            if not load_checkpoint:
                print('SAVING MODEL: episode ', i, 'score %.1f' % score, 'avg_score %.1f' % avg_score)
                agent.save_models()

        # evaluate the model every 1000th episode
        if (i+1) % 1000 == 0:
            print("start evaluation")
            env.set_eval(ev=True)
            try:
                evaluate(agent, env, params)
            except Exception as err:
                # Evaluation is best-effort (plotting has failed on inconsistent
                # lengths before): report the cause and keep training. Unlike a bare
                # `except:`, this lets KeyboardInterrupt / SystemExit through.
                print("error while plotting:", repr(err))
            finally:
                # always leave evaluation mode, even if evaluate() was interrupted
                env.set_eval(ev=False)
            print("stop evaluation")

        print('episode ', i, 'score %.1f' % score, 'avg_score %.1f' % avg_score)
        i+=1


_InactiveRpcError: <_InactiveRpcError of RPC that terminated with:
	status = StatusCode.UNAVAILABLE
	details = "failed to connect to all addresses"
	debug_error_string = "{"created":"@1631786010.327216533","description":"Failed to pick subchannel","file":"src/core/ext/filters/client_channel/client_channel.cc","file_line":3008,"referenced_errors":[{"created":"@1631786010.327212322","description":"failed to connect to all addresses","file":"src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc","file_line":397,"grpc_status":14}]}"
>