In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from igp2 import AgentState
from igp2.data.data_loaders import InDDataLoader
from igp2.data.episode import Frame
from igp2.data.scenario import InDScenario, ScenarioConfig
from igp2.opendrive.map import Map
from igp2.opendrive.plot_map import plot_map
from grit.core.data_processing import get_dataset

In [2]:
# Load the test split of the 'round' scenario via grit's data-processing helper.
# Judging by the displayed frame, there are several rows per observed frame, one per
# `possible_goal` (TODO confirm this holds for every frame).
test_set = get_dataset('round', 'test')

In [3]:
# Class balance of the true goal in the test set.
# `normalize=True` divides each count by the number of non-null labels, so the proportions
# always sum to 1. The previous `/ shape[0]` also counted rows with a missing label.
# NOTE: proportions are over rows of `test_set`. Each observed frame contributes one row
# per possible goal, and agents observed over more frames contribute more rows.
test_set.true_goal.value_counts(normalize=True)

0    0.386155
2    0.265927
1    0.180200
3    0.167718
Name: true_goal, dtype: float64

In [4]:
# Inspect the feature table. pandas' default repr truncates long frames to head/tail rows.
test_set

Unnamed: 0,path_to_goal_length,in_correct_lane,speed,acceleration,angle_in_lane,vehicle_in_front_dist,vehicle_in_front_speed,oncoming_vehicle_dist,oncoming_vehicle_speed,road_heading,...,agent_id,ego_agent_id,possible_goal,true_goal,true_goal_type,frame_id,initial_frame_id,fraction_observed,final_frame_id,episode
0,37.668293,True,4.425219,1.237664,-0.235979,100.000000,20.000000,100.0,20.0,-0.852636,...,12,13,0,2,exit-roundabout,0,0,0.000000,360.0,4
1,69.035355,True,4.425219,1.237664,-0.235979,100.000000,20.000000,100.0,20.0,-0.865193,...,12,13,1,2,exit-roundabout,0,0,0.000000,360.0,4
2,109.806009,True,4.425219,1.237664,-0.235979,78.646403,6.118256,100.0,20.0,-0.823158,...,12,13,2,2,exit-roundabout,0,0,0.000000,360.0,4
3,139.102393,True,4.425219,1.237664,-0.235979,78.646403,6.118256,100.0,20.0,-0.740689,...,12,13,3,2,exit-roundabout,0,0,0.000000,360.0,4
4,32.565355,True,5.144091,0.869355,-0.174028,21.533931,7.689521,100.0,20.0,-0.852636,...,12,13,0,2,exit-roundabout,25,0,0.069444,360.0,4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
76543,96.844461,True,6.505979,1.483180,-0.256828,84.817137,8.830591,100.0,20.0,-0.740689,...,459,453,3,2,exit-roundabout,23450,23365,0.361702,23600.0,21
76544,130.278941,True,7.123876,0.505092,0.454194,46.357458,10.425151,100.0,20.0,-0.852636,...,459,453,0,2,exit-roundabout,23475,23365,0.468085,23600.0,21
76545,19.631292,True,7.123876,0.505092,0.454194,100.000000,20.000000,100.0,20.0,-0.865193,...,459,453,1,2,exit-roundabout,23475,23365,0.468085,23600.0,21
76546,60.401946,True,7.123876,0.505092,0.454194,46.107603,10.425151,100.0,20.0,-0.823158,...,459,453,2,2,exit-roundabout,23475,23365,0.468085,23600.0,21


In [5]:
# Per-frame goal predictions of the occlusion-baseline model on the 'round' test split.
# In the comparison below this is the "simple" model. As the left side of the merge, its
# metric columns end up with the `_x` suffix.
results_simple = pd.read_csv(
    filepath_or_buffer='../../predictions/round_occlusion_baseline_test.csv'
)

In [6]:
# Per-frame goal predictions of the generalised GRIT model on the 'round' test split.
# In the comparison below this is the "privileged" model. As the right side of the merge,
# its metric columns end up with the `_y` suffix.
results_privileged = pd.read_csv(
    filepath_or_buffer='../../predictions/round_generalised_grit_test.csv'
)

In [7]:
# Inspect the baseline predictions: presumably one row per
# (episode, agent_id, ego_agent_id, frame_id). Each row carries the predicted goal,
# the probability given to the true goal, and whether the prediction was correct.
results_simple

Unnamed: 0,episode,agent_id,ego_agent_id,frame_id,true_goal,true_goal_type,fraction_observed,model_prediction,predicted_goal_type,model_probs,max_probs,min_probs,model_entropy,model_entropy_norm,true_goal_prob,cross_entropy,model_correct
0,4,12,13,0,2,exit-roundabout,0.000000,1,exit-roundabout,0.592404,0.592404,0.003045,0.724666,0.522736,0.398663,0.229910,False
1,4,12,13,25,2,exit-roundabout,0.069444,1,exit-roundabout,0.592404,0.592404,0.003045,0.724666,0.522736,0.398663,0.229910,False
2,4,12,13,50,2,exit-roundabout,0.138889,1,exit-roundabout,0.592404,0.592404,0.003045,0.724666,0.522736,0.398663,0.229910,False
3,4,12,13,75,2,exit-roundabout,0.208333,1,exit-roundabout,0.671037,0.671037,0.003896,0.692523,0.499549,0.316961,0.287244,False
4,4,12,13,100,2,exit-roundabout,0.277778,1,exit-roundabout,0.671037,0.671037,0.003896,0.692523,0.499549,0.316961,0.287244,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
46305,21,459,453,23375,2,exit-roundabout,0.042553,2,exit-roundabout,0.830856,0.830856,0.007683,0.519074,0.374433,0.830856,0.046325,True
46306,21,459,453,23400,2,exit-roundabout,0.148936,2,exit-roundabout,0.830856,0.830856,0.007683,0.519074,0.374433,0.830856,0.046325,True
46307,21,459,453,23425,2,exit-roundabout,0.255319,2,exit-roundabout,0.826431,0.826431,0.007642,0.540311,0.389752,0.826431,0.047660,True
46308,21,459,453,23450,2,exit-roundabout,0.361702,2,exit-roundabout,0.830856,0.830856,0.007683,0.519074,0.374433,0.830856,0.046325,True


In [8]:
# Inspect the generalised-GRIT predictions. They have the same columns as `results_simple`,
# so the two tables can be aligned on their identifying keys.
results_privileged

Unnamed: 0,episode,agent_id,ego_agent_id,frame_id,true_goal,true_goal_type,fraction_observed,model_prediction,predicted_goal_type,model_probs,max_probs,min_probs,model_entropy,model_entropy_norm,true_goal_prob,cross_entropy,model_correct
0,4,12,13,0,2,exit-roundabout,0.000000,1,exit-roundabout,0.592404,0.592404,0.003045,0.724666,0.522736,0.398663,0.229910,False
1,4,12,13,25,2,exit-roundabout,0.069444,1,exit-roundabout,0.592404,0.592404,0.003045,0.724666,0.522736,0.398663,0.229910,False
2,4,12,13,50,2,exit-roundabout,0.138889,1,exit-roundabout,0.592404,0.592404,0.003045,0.724666,0.522736,0.398663,0.229910,False
3,4,12,13,75,2,exit-roundabout,0.208333,1,exit-roundabout,0.671037,0.671037,0.003896,0.692523,0.499549,0.316961,0.287244,False
4,4,12,13,100,2,exit-roundabout,0.277778,1,exit-roundabout,0.671037,0.671037,0.003896,0.692523,0.499549,0.316961,0.287244,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
46305,21,459,453,23375,2,exit-roundabout,0.042553,2,exit-roundabout,0.826431,0.826431,0.007642,0.540311,0.389752,0.826431,0.047660,True
46306,21,459,453,23400,2,exit-roundabout,0.148936,2,exit-roundabout,0.826431,0.826431,0.007642,0.540311,0.389752,0.826431,0.047660,True
46307,21,459,453,23425,2,exit-roundabout,0.255319,2,exit-roundabout,0.826431,0.826431,0.007642,0.540311,0.389752,0.826431,0.047660,True
46308,21,459,453,23450,2,exit-roundabout,0.361702,2,exit-roundabout,0.826431,0.826431,0.007642,0.540311,0.389752,0.826431,0.047660,True


In [9]:
# Align the two models' predictions frame by frame. The first four entries of `cols` are the
# join keys and the rest are per-model metrics. Metric columns present on both sides get the
# suffixes spelled out below: `_x` = results_simple, `_y` = results_privileged.
cols = ['episode', 'agent_id', 'ego_agent_id', 'frame_id', 'model_correct', 'true_goal_prob', 'model_prediction']
results_merged = pd.merge(
    results_simple.loc[:, cols],
    results_privileged.loc[:, cols],
    how='inner',
    on=cols[:4],  # episode, agent_id, ego_agent_id, frame_id
    suffixes=('_x', '_y'),
)


In [10]:
# Inner-join the aligned predictions with the test-set features. `test_set` holds one row per
# possible goal of each frame, so every prediction row is repeated once per possible goal.
dataset_merged = pd.merge(
    results_merged,
    test_set,
    how='inner',
    on=['episode', 'agent_id', 'ego_agent_id', 'frame_id'],
)

In [14]:
# Inspect the joined table. Prediction columns come first (`_x` = simple baseline,
# `_y` = privileged model), followed by the test-set features. Each frame appears once per
# `possible_goal`.
dataset_merged

Unnamed: 0,episode,agent_id,ego_agent_id,frame_id,model_correct_x,true_goal_prob_x,model_prediction_x,model_correct_y,true_goal_prob_y,model_prediction_y,...,goal_type,vehicle_in_front_missing,oncoming_vehicle_missing,exit_number_missing,possible_goal,true_goal,true_goal_type,initial_frame_id,fraction_observed,final_frame_id
0,4,12,13,0,False,0.398663,1,False,0.398663,1,...,exit-roundabout,False,False,False,0,2,exit-roundabout,0,0.000000,360.0
1,4,12,13,0,False,0.398663,1,False,0.398663,1,...,exit-roundabout,False,False,False,1,2,exit-roundabout,0,0.000000,360.0
2,4,12,13,0,False,0.398663,1,False,0.398663,1,...,exit-roundabout,False,False,False,2,2,exit-roundabout,0,0.000000,360.0
3,4,12,13,0,False,0.398663,1,False,0.398663,1,...,exit-roundabout,False,False,False,3,2,exit-roundabout,0,0.000000,360.0
4,4,12,13,25,False,0.398663,1,False,0.398663,1,...,exit-roundabout,False,False,False,0,2,exit-roundabout,0,0.069444,360.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
185234,21,459,453,23450,True,0.830856,2,True,0.826431,2,...,exit-roundabout,True,True,False,3,2,exit-roundabout,23365,0.361702,23600.0
185235,21,459,453,23475,True,0.826431,2,True,0.826431,2,...,exit-roundabout,False,False,False,0,2,exit-roundabout,23365,0.468085,23600.0
185236,21,459,453,23475,True,0.826431,2,True,0.826431,2,...,exit-roundabout,False,True,False,1,2,exit-roundabout,23365,0.468085,23600.0
185237,21,459,453,23475,True,0.826431,2,True,0.826431,2,...,exit-roundabout,False,False,False,2,2,exit-roundabout,23365,0.468085,23600.0


In [19]:
# Rows where the privileged model (`_y`) gives the true goal noticeably more probability than
# the simple baseline (`_x`). Rows are still repeated once per possible goal of each frame.
# (Variable name keeps its original spelling because the next cell refers to it.)
PROB_GAIN_THRESHOLD = 0.1  # minimum increase in true-goal probability to flag a row
descrepancies = dataset_merged.loc[
    (dataset_merged.true_goal_prob_y - dataset_merged.true_goal_prob_x) > PROB_GAIN_THRESHOLD
].copy()  # explicit copy: later edits to the subset won't trigger chained-assignment warnings


In [20]:
# Random sample of flagged rows to inspect by hand. The fixed seed keeps the displayed sample
# identical across Restart & Run All.
descrepancies.sample(10, random_state=0)

Unnamed: 0,episode,agent_id,ego_agent_id,frame_id,model_correct_x,true_goal_prob_x,model_prediction_x,model_correct_y,true_goal_prob_y,model_prediction_y,...,goal_type,vehicle_in_front_missing,oncoming_vehicle_missing,exit_number_missing,possible_goal,true_goal,true_goal_type,initial_frame_id,fraction_observed,final_frame_id
107165,4,755,759,30875,True,0.25,0,True,0.675366,0,...,exit-roundabout,True,False,True,2,0,exit-roundabout,30655,0.6875,30975.0
154511,21,260,263,13675,True,0.25,0,True,0.930993,0,...,exit-roundabout,True,False,True,0,0,exit-roundabout,13141,1.0,13675.0
30497,4,290,292,12325,False,0.25,0,True,0.871162,1,...,exit-roundabout,True,False,True,2,1,exit-roundabout,11962,0.930769,12352.0
175058,21,392,401,20000,False,0.25,0,True,0.542431,1,...,exit-roundabout,False,False,True,3,1,exit-roundabout,19796,0.784615,20056.0
30470,4,290,292,12150,False,0.25,0,False,0.352229,0,...,exit-roundabout,True,False,True,3,1,exit-roundabout,11962,0.482051,12352.0
96435,4,698,703,28950,False,0.25,0,True,0.874669,2,...,exit-roundabout,False,False,True,0,2,exit-roundabout,28637,0.978125,28957.0
183645,21,446,453,23225,False,0.25,0,True,0.873861,1,...,exit-roundabout,False,True,True,2,1,exit-roundabout,22972,0.937037,23242.0
119781,21,112,115,5825,False,0.25,0,True,0.708349,2,...,exit-roundabout,False,False,True,2,2,exit-roundabout,5718,0.516908,5925.0
52633,4,399,401,16325,False,0.25,0,True,0.684062,3,...,exit-roundabout,False,False,True,2,3,exit-roundabout,16142,1.0,16325.0
181482,21,439,443,23000,False,0.25,0,True,0.930993,2,...,exit-roundabout,True,True,True,3,2,exit-roundabout,22719,1.0,23000.0
