## Evaluate Trained Reward Function Model and Visualize

In [1]:
import os
import sys

import numpy as np
import pandas as pd

import torch

# Resolve the directory this code runs from: the script's own folder when
# `__file__` is defined (plain Python), otherwise the kernel's working
# directory (Jupyter kernels do not define `__file__`).
current_directory = os.path.dirname(os.path.abspath(__file__)) if "__file__" in globals() else os.getcwd()
# The project packages imported below (`utils`, `optimization`, `models`) are
# resolved relative to the parent of that directory.
parent_directory = os.path.dirname(current_directory)
# Guard the insertion so re-running this cell does not keep appending
# duplicate entries to sys.path.
if parent_directory not in sys.path:
    sys.path.append(parent_directory)

from utils import constants
from utils.dataset_loader import PolicyDatasetLoader

from optimization.updater import Updater
from optimization.functions import setup_config, get_directories, load_policy_from_path, load_reward_from_path
from optimization.functions import find_indices_of_trajectory_changes, get_estimated_rewards

from models.policy_model import RobotPolicy
from models.reward_model import RewardFunction

In [2]:
# Never truncate columns when rendering the (very wide) trajectory dataframes.
pd.set_option("display.max_columns", None)

# Initialization

In [3]:
# Evaluate on CUDA when a GPU is visible to torch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Evaluating Device: ", device)

# Hyperparameter/config object bound to the selected device.
configs = setup_config(device=device)

# Paths to the test-demonstration JSON files and to the results folder; the
# data folder is chosen by constants.TEST_COLLECTION_DATE.
json_paths, results_path = get_directories(
    parent_directory=parent_directory,
    data_folder_name=constants.TEST_COLLECTION_DATE,
)

Evaluating Device:  cpu
Current Time:  Feb_06_2024-17_23_03


In [4]:
# Build the test-demonstration dataset from the JSON files listed in `json_paths`.
all_test_data = PolicyDatasetLoader(
    demo_data_json_paths=json_paths,
)



Number of Trajectories:  5
Each Trajectory Length:  30
Full Demo Dataset Size:  163


In [5]:
# Dataset positions at which a new trajectory starts. The evaluation loop
# below passes this list, plus a position within it, to get_estimated_rewards.
trajectory_indices = find_indices_of_trajectory_changes(
    dataset=all_test_data,
)

In [6]:
# Policy network architecture; pretrained weights are loaded into it further below.
policy_network = RobotPolicy(
    state_size=configs.state_size,
    hidden_size=configs.hidden_size,
    out_size=configs.action_size,
    log_std_min=configs.policy_log_std_min,
    log_std_max=configs.policy_log_std_max,
    log_std_init=configs.policy_log_std_init,
    device=configs.device,
)

In [7]:
# Reward network architecture; pretrained weights are loaded into it further below.
reward_network = RewardFunction(
    state_action_size=configs.state_action_size,
    hidden_size=configs.hidden_size,
    out_size=configs.reward_size,
    device=configs.device,
)

# Functions

# Test

In [8]:
# Where the pretrained policy parameters are stored:
# results / policy_network_params / <loading folder name>
policy_loading_folder_name, policy_params_name = (
    constants.POLICY_LOADING_FOLDER,
    constants.POLICY_PARAMS_NAME,
)

In [9]:
# Where the pretrained reward parameters are stored:
# results / reward_network_params / <loading folder name>
reward_loading_folder_name, reward_params_name = (
    constants.REWARD_LOADING_FOLDER,
    constants.REWARD_PARAMS_NAME,
)

In [10]:
# Restore the pretrained policy weights into the network constructed above.
policy_network = load_policy_from_path(
    policy_network=policy_network,
    results_path=results_path,
    policy_loading_folder_name=policy_loading_folder_name,
    policy_params_name=policy_params_name,
)

# Freeze the policy for inference: stop gradient tracking on every parameter,
# then switch to evaluation mode (assumes RobotPolicy is a torch.nn.Module,
# where eval() changes dropout/batch-norm behavior).
for weight in policy_network.parameters():
    weight.requires_grad_(False)
policy_network = policy_network.eval()

In [11]:
# Restore the pretrained reward weights into the network constructed above.
reward_network = load_reward_from_path(
    reward_network=reward_network,
    results_path=results_path,
    reward_loading_folder_name=reward_loading_folder_name,
    reward_params_name=reward_params_name,
)

# Freeze the reward model for inference: stop gradient tracking on every
# parameter, then switch to evaluation mode (assumes RewardFunction is a
# torch.nn.Module).
for weight in reward_network.parameters():
    weight.requires_grad_(False)
reward_network = reward_network.eval()

In [12]:
# The Updater is used below only for its calculate_irl_loss helper; it is
# handed the frozen policy and reward networks.
updater_obj = Updater(
    configs=configs,
    policy_network=policy_network,
    reward_network=reward_network,
)

In [13]:
# Scalar (0-dim) tensor passed as `nu_factor` to calculate_irl_loss, fixed at 0.0 here.
# NOTE(review): presumably weights the log-probability term of the IRL loss — confirm in Updater.
nu_factor = torch.as_tensor(0.0)

## Run Through All Trajectories in the Test Demonstration Dataset

In [14]:
# Evaluate every trajectory in the test dataset. Per-trajectory outputs are
# accumulated rather than overwritten on each iteration, so the results for
# all trajectories survive the loop.
traj_df_list = []  # one dataframe per trajectory, with reward/log-prob columns appended
irl_losses = []    # IRL loss returned by the updater, one entry per trajectory

for traj_start_index in range(len(trajectory_indices)):

    traj_df, reward_values_demo_data, reward_values_estim_data, logprob_action_estim_avg = get_estimated_rewards(
        configs=configs,
        updater_obj=updater_obj,
        data_loader=all_test_data,
        policy_network=policy_network,
        reward_network=reward_network,
        trajectory_indices=trajectory_indices,
        traj_start_index=traj_start_index,
        is_inference_reward=True,
    )

    irl_train_loss = updater_obj.calculate_irl_loss(demo_traj_reward=reward_values_demo_data,
                                                    robot_traj_reward=reward_values_estim_data,
                                                    log_probability=logprob_action_estim_avg,
                                                    nu_factor=nu_factor)
    irl_losses.append(irl_train_loss)

    # Store outputs in the trajectory dataframe. Move them to host memory first:
    # `.numpy()` raises on CUDA tensors (the device is CUDA whenever one is
    # available) and on tensors that still require grad.
    traj_df[constants.ACTION_PREDICTION_AVG_LOGPROB_NAME] = logprob_action_estim_avg.detach().cpu().numpy().flatten()
    traj_df[constants.REWARD_DEMONSTRATION_TRAJECTORY_NAME] = reward_values_demo_data.detach().cpu().numpy().flatten()
    traj_df[constants.REWARD_ROBOT_TRAJECTORY_NAME] = reward_values_estim_data.detach().cpu().numpy().flatten()

    traj_df_list.append(traj_df)

# All evaluated trajectories stacked into one frame with a fresh row index;
# `traj_df` itself still refers to the last trajectory processed.
all_traj_df = pd.concat(traj_df_list, ignore_index=True) if traj_df_list else pd.DataFrame()

In [15]:
# Display the per-state table of the trajectory processed last by the
# evaluation loop above (each iteration rebinds `traj_df`), including the
# appended average log-probability and demonstration/robot reward columns.
traj_df

Unnamed: 0,state_label_norm_1,state_label_norm_2,state_label_norm_3,state_label_norm_4,state_label_denorm_1,state_label_denorm_2,state_label_denorm_3,state_label_denorm_4,action_label_norm_1,action_label_norm_2,action_label_norm_3,action_label_denorm_1,action_label_denorm_2,action_label_denorm_3,action_pred_logprob_1,action_pred_logprob_2,action_pred_logprob_3,action_pred_norm_1,action_pred_norm_2,action_pred_norm_3,action_pred_std_1,action_pred_std_2,action_pred_std_3,action_pred_denorm_1,action_pred_denorm_2,action_pred_denorm_3,trajectory_index,state_number,gnll_loss,state_est_norm_1,state_est_norm_2,state_est_norm_3,state_est_norm_4,state_est_denorm_1,state_est_denorm_2,state_est_denorm_3,state_est_denorm_4,next_state_label_norm_1,next_state_label_norm_2,next_state_label_norm_3,next_state_label_norm_4,next_state_label_denorm_1,next_state_label_denorm_2,next_state_label_denorm_3,next_state_label_denorm_4,next_state_est_norm_1,next_state_est_norm_2,next_state_est_norm_3,next_state_est_norm_4,next_state_est_denorm_1,next_state_est_denorm_2,next_state_est_denorm_3,next_state_est_denorm_4,action_pred_avg_logprob,reward_demo_traj,reward_robot_traj
0,0.357050865889,0.838037848473,0.0,0.391351431608,0.714101731777,1.676075696945,0.0,0.782702863216,-0.199934840202,-0.355778723955,0.307092785835,-0.399869680405,-0.711557388306,0.614185571671,1.38172018528,1.38238465786,1.381953120232,-0.254897743464,-0.334094941616,0.310078799725,0.100192822516,0.10012627393,0.100169487298,-0.509795427322,-0.668189883232,0.620157718658,4,0,0.301920861006,0.357050865889,0.838037848473,0.0,0.391351431608,0.714101731777,1.676075696945,0.0,0.782702863216,0.35704909917,0.838033201962,0.0,0.391342785835,0.71409819834,1.676066403923,0.0,0.782685571671,0.340187301223,0.823530480231,0.059160960251,0.394328859329,0.680374602445,1.647060960462,0.118321920502,0.788657718658,1.382019400597,0.99840170145,0.997605681419
1,0.357049137354,0.83803319931,1.479453e-05,0.391342788935,0.714098274708,1.676066398621,2.958906e-05,0.782685577869,-0.20088724792,-0.356528669596,0.305290669203,-0.401774525642,-0.713057279587,0.61058139801,1.38172018528,1.38238465786,1.381953120232,-0.298301160336,-0.302041709423,0.310738950968,0.100192822516,0.10012627393,0.100169487298,-0.596602320671,-0.604083418846,0.621477842331,4,1,0.305598944426,0.340187311172,0.823530495167,0.059160958976,0.394328862429,0.680374622345,1.647060990334,0.118321917951,0.788657724857,0.357650418535,0.837880037692,0.002171875745,0.389540699005,0.715300837071,1.675760075385,0.00434375149,0.77908139801,0.319422333527,0.801272564654,0.112146737658,0.394988921165,0.638844667054,1.602545129308,0.224293475315,0.789977842331,1.382019400597,0.998550713062,0.988874316216
2,0.35765042901,0.837880015373,0.002183987992,0.389540672302,0.715300858021,1.675760030746,0.004367975984,0.779081344604,-0.213343515992,-0.352634400129,0.306205183268,-0.426687002182,-0.705268859863,0.612410306931,1.38172018528,1.38238465786,1.381953120232,-0.334620386362,-0.264129817486,0.309306263924,0.100192822516,0.10012627393,0.100169487298,-0.66924071312,-0.528259634972,0.618612527847,4,2,0.309716969728,0.319422334433,0.801272571087,0.112146735191,0.394988924265,0.638844668865,1.602545142174,0.224293470383,0.78997784853,0.354093252101,0.835267363275,0.013800959226,0.390455153465,0.708186504202,1.670534726551,0.027601918451,0.780910306931,0.298212836032,0.77470220568,0.162925149673,0.393556263924,0.596425672063,1.54940441136,0.325850299346,0.787112527847,1.382019400597,0.998427450657,0.949690937996
3,0.35409322381,0.835267364979,0.013813394122,0.390455186367,0.70818644762,1.670534729958,0.027626788244,0.780910372734,-0.230947449803,-0.341997474432,0.3054240942,-0.461894869804,-0.683995008469,0.6108481884,1.38172018528,1.38238465786,1.381953120232,-0.368749916553,-0.222677960992,0.305788040161,0.100192822516,0.10012627393,0.100169487298,-0.737499833107,-0.445355892181,0.611576080322,4,3,0.314086019993,0.298212826252,0.774702191353,0.162925153971,0.393556267023,0.596425652504,1.549404382706,0.325850307941,0.787112534046,0.344569246684,0.826418121762,0.033977739144,0.3896740942,0.689138493369,1.652836243524,0.067955478289,0.7793481884,0.280874978003,0.746343546858,0.214979165268,0.390038040161,0.561749956006,1.492687093716,0.429958330536,0.780076080322,1.382019400597,0.997871398926,0.959940969944
4,0.344569206238,0.826418101788,0.033989518881,0.3896740973,0.689138412476,1.652836203575,0.067979037762,0.779348194599,-0.247186809778,-0.330925107002,0.305027574301,-0.494373559952,-0.661850214005,0.610055208206,1.38172018528,1.38238465786,1.381953120232,-0.405115872622,-0.179633289576,0.299891501665,0.100192822516,0.10012627393,0.100169487298,-0.810231685638,-0.359266519547,0.599782943726,4,4,0.320073366165,0.280874967575,0.746343553066,0.214979171753,0.390038043261,0.56174993515,1.492687106133,0.429958343506,0.780076086521,0.335445323772,0.817679567616,0.053429501048,0.389277604103,0.670890647545,1.635359135231,0.106859002095,0.778555208206,0.273614028288,0.719063648876,0.270514910514,0.384141471863,0.547228056577,1.438127297753,0.541029821029,0.768282943726,1.382019400597,0.996911883354,0.97309666872
5,0.335445314646,0.817679524422,0.053440917283,0.3892775774,0.670890629292,1.635359048843,0.106881834567,0.7785551548,-0.265253275633,-0.31778883934,0.304497599602,-0.53050661087,-0.63557767868,0.608995199203,1.38172018528,1.38238465786,1.381953120232,-0.430047333241,-0.133500486612,0.295430600643,0.100192822516,0.10012627393,0.100169487298,-0.860094666481,-0.26700091362,0.590861320496,4,5,0.325534552336,0.273614019156,0.719063639641,0.270514905453,0.384141474962,0.547228038311,1.438127279282,0.541029810905,0.768282949924,0.325582337149,0.807653994595,0.075607314201,0.388747599602,0.651164674298,1.61530798919,0.151214628401,0.777495199203,0.26661552231,0.689582334706,0.320148995659,0.379680660248,0.53323104462,1.379164669413,0.640297991319,0.759361320496,1.382019400597,0.995283007622,0.973658680916
6,0.325582325459,0.807653963566,0.075618445873,0.388747602701,0.651164650917,1.615307927132,0.151236891747,0.777495205402,-0.277430802584,-0.301107466221,0.305680006742,-0.554861545563,-0.602214932442,0.61136007309,1.38172018528,1.38238465786,1.381953120232,-0.439489096403,-0.082561343908,0.295017898083,0.100192822516,0.10012627393,0.100169487298,-0.878978252411,-0.165122747421,0.590035915375,4,6,0.330829441547,0.266615509987,0.68958234787,0.320149004459,0.379680663347,0.533231019974,1.37916469574,0.640298008919,0.759361326694,0.3122239185,0.795161394994,0.09485019025,0.389930036545,0.624447837001,1.590322789988,0.1897003805,0.77986007309,0.253960762321,0.655585561566,0.363565347093,0.379267957687,0.507921524643,1.311171123131,0.727130694186,0.758535915375,1.382019400597,0.990079164505,0.978941619396
7,0.312223941088,0.795161366463,0.094860710204,0.389930009842,0.624447882175,1.590322732925,0.189721420407,0.779860019684,-0.2920152843,-0.279585421085,0.303641468287,-0.584030628204,-0.559170842171,0.607282876968,1.38172018528,1.38238465786,1.381953120232,-0.442486822605,-0.025832075626,0.298340529203,0.100192822516,0.10012627393,0.100169487298,-0.88497364521,-0.051664113998,0.596681118011,4,7,0.33611792326,0.253960758448,0.655585587025,0.363565355539,0.379267960787,0.507921516895,1.311171174049,0.727130711079,0.758535921574,0.295541184875,0.777612358802,0.119566468263,0.387891438484,0.59108236975,1.555224717604,0.239132936526,0.775782876968,0.244790877617,0.619724008684,0.409600841077,0.382590559006,0.489581755234,1.239448017369,0.819201682154,0.765181118011,1.382019400597,0.978515923023,0.977792978287
8,0.295541167259,0.777612388134,0.119576558471,0.387891471386,0.591082334518,1.555224776268,0.239153116941,0.775782942772,-0.309009879827,-0.243681818247,0.294746935368,-0.61801981926,-0.487363576889,0.589493751526,1.38172018528,1.38238465786,1.381953120232,-0.44718080759,0.035779923201,0.304797112942,0.100192822516,0.10012627393,0.100169487298,-0.894361615181,0.071559906006,0.609594345093,4,8,0.340305089951,0.244790881872,0.619724035263,0.40960085392,0.382590562105,0.489581763744,1.239448070526,0.81920170784,0.76518112421,0.26754238172,0.74533413966,0.156893303029,0.378996875763,0.53508476344,1.49066827932,0.313786606059,0.757993751526,0.251299103619,0.58710980884,0.463091796939,0.389047172546,0.502598207237,1.174219617679,0.926183593879,0.778094345093,1.382019400597,0.972501218319,0.967071712017
9,0.267542392015,0.745334208012,0.156902998686,0.378996938467,0.535084784031,1.490668416023,0.313805997372,0.757993876934,-0.321786940098,-0.212996095419,0.286452949047,-0.643573880196,-0.425992250443,0.572906017303,1.38172018528,1.38238465786,1.381953120232,-0.461602926254,0.096433229744,0.313130229712,0.100192822516,0.10012627393,0.100169487298,-0.923205852509,0.192866563797,0.626260519028,4,9,0.347964644432,0.251299113035,0.587109804153,0.463091790676,0.389047175646,0.50259822607,1.174219608307,0.926183581352,0.778094351292,0.245537710974,0.717721131628,0.188840664693,0.370703008652,0.491075421947,1.435442263256,0.377681329387,0.741406017303,0.281142071798,0.566111322552,0.522496229298,0.397380259514,0.562284143597,1.132222645105,1.044992458596,0.794760519028,1.382019400597,0.988365709782,0.939524173737
