In [19]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from itertools import product

from unc.utils import load_info
from unc.agents import Agent, DQNAgent
from definitions import ROOT_DIR
from unc.envs.wrappers.lobster.belief import get_lobster_state_map

In [20]:
# Load the saved lobster rollout dictionary and keep the '2pb' entry.
# NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted files.
data_path = Path(ROOT_DIR) / 'results' / 'lobster_data.npy'
lobster_data = np.load(data_path, allow_pickle=True).item()
pb_data = lobster_data['2pb']
# gvf_data = lobster_data['2f']

In [21]:
# gvf_obs = gvf_data['obs'].reshape(-1, gvf_data['obs'].shape[-1])
# gvf_states = gvf_data['states'].reshape(-1, gvf_data['states'].shape[-1])
# gvf_predictions = gvf_obs[:, -2:]
# gvf_zero_obs_mask = gvf_obs[:, 0] == 1
# gvf_zero_obs = gvf_obs[gvf_zero_obs_mask]

# gvf_zero_predictions = gvf_predictions[gvf_zero_obs_mask]
# gvf_zero_predictions_range = gvf_zero_predictions.max(axis=0) - gvf_zero_predictions.min(axis=0)
# gvf_zero_predictions_normalized = (gvf_zero_predictions - gvf_zero_predictions.min(axis=0)) / gvf_zero_predictions_range
# gvf_zero_predictions_normalized.shape
# gvf_predictions

In [22]:
# L1_reward_observed = np.nonzero(gvf_obs[:, 4] == 1)[0]
# L1_reward_not_observed = np.nonzero(gvf_obs[:, 3] == 1)[0]
# L1_reward_present = np.nonzero(gvf_states[:, 1] == 1)[0]
# L1_position = gvf_states[L1_reward_present, 0]
# zero_position = gvf_states[:, 0] == 0
# zero_L1_reward_present = zero_position & (gvf_states[:, 1] == 1)
# zero_L1_reward_not_present = zero_position & (gvf_states[:, 1] == 0)
# # gvf_predictions[(gvf_obs[:, 4] == 1), 0].mean(), gvf_predictions[(gvf_obs[:, 3] == 1), 0].mean()
# np.nonzero(zero_L1_reward_present)[0][:10], gvf_predictions[[6, 7, 8, 9, 10], 0]

In [23]:
# Look up the stored optimal Q-values for the four states listed in `zero_states`.
optimal_lobster_fpath = Path(ROOT_DIR) / 'results' / 'optimal_lobster_results.npy'
optimal_lobster_res = load_info(optimal_lobster_fpath)

optimal_q = optimal_lobster_res['qs']
state_to_idx = optimal_lobster_res['state_to_idx']

# Every row has a 0 in the first component and one combination of the last two flags.
zero_states = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 1]])

# `state_to_idx` is keyed by the string form of a state row.
optimal_zero_qs = np.stack(
    [optimal_q[state_to_idx[str(state_row)]] for state_row in zero_states]
)
optimal_zero_qs, zero_states

(array([[0.67423829, 0.67423829, 0.67195555],
        [1.44421562, 1.22964733, 1.31641296],
        [1.22964733, 1.44421562, 1.31641296],
        [1.83245927, 1.83245927, 1.64921334]]),
 array([[0, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [0, 1, 1]]))

In [24]:
tol = 0.001
pb_state_map = get_lobster_state_map()

# Collapse all leading axes so that each row is a single observation vector.
raw_pb_obs = pb_data['obs']
pb_obs = raw_pb_obs.reshape(-1, raw_pb_obs.shape[-1])
# pb_unique_obs = np.unique(np.floor(pb_obs / tol).astype(int), axis=0) * tol

# Keep rows that have some mass in the first four entries and none anywhere else.
mass_first_four = pb_obs[:, :4].sum(axis=-1)
mass_remaining = pb_obs[:, 4:].sum(axis=-1)
pb_zero_obs_mask = (mass_first_four > 0) & (mass_remaining == 0)
pb_zero_obs = pb_obs[pb_zero_obs_mask]
# pb_0_obs = np.unique(pb_unique_obs[:, :4], axis=0)

# Indices of the states (first component 0) whose second / third flag is set.
r1_pb_states = [pb_state_map[0, 1, r2_flag] for r2_flag in (0, 1)]
r2_pb_states = [pb_state_map[0, r1_flag, 1] for r1_flag in (0, 1)]
r1_pb_states, r2_pb_states

([2, 3], [1, 3])

In [25]:
# bins = np.linspace(0, 1, 50)
# # node_0_maybe_r1_states = [1, 3]
# fig, ax = plt.figure(), plt.axes()

# # ax.hist(pb_0_obs[:, node_0_maybe_r1_states].sum(axis=-1), bins=bins, color='blue', label=state, alpha=0.75, edgecolor='black', linewidth=1)
# ax.hist(gvf_zero_predictions_normalized[:, 1], bins=bins, color='blue', label=1, alpha=0.75, edgecolor='black', linewidth=1)

# # ax.set_xlabel('AUC over 50K steps')
# ax.set_ylabel(f"Frequency", rotation=0, labelpad=35)
# plt.legend(bbox_to_anchor=(1.00, 1), loc='upper left')

In [26]:
# Likelihood predictions: enumerate every pair of exponentially decayed values
# used as inputs for the '2e' agent.
rate = 1 / 10
counts = np.arange(201)
likelihoods = np.exp(-counts * rate)

# Cartesian product of the values with themselves; the first coordinate varies
# slowest, matching the ordering of itertools.product.
lik_first, lik_second = np.meshgrid(likelihoods, likelihoods, indexing='ij')
all_possible_likelihoods = np.stack((lik_first.ravel(), lik_second.ravel()), axis=-1)

# Write the two values into columns 4 and 7 of an otherwise all-zero 9-dim observation.
# NOTE(review): all other columns (including column 0) stay at zero — confirm this
# matches the '2e' observation encoding.
all_zero_obs_2e = np.zeros((len(all_possible_likelihoods), 9))
all_zero_obs_2e[:, [4, 7]] = all_possible_likelihoods

In [27]:
## Candidate observations at node 0 for '2g': four rows of width 12 whose first
## four columns are all ones.
zero_obs_2g = np.hstack([np.ones((4, 4)), np.zeros((4, 8))])

# so for unc encoding, 1 == observable and collected.
ot_obs_2g = np.array([[r1, r2] for r2 in (0, 1) for r1 in (0, 1)])

# For particle filter belief state, we simply find distribution over first four states.
# Grid of (r1, r2) pairs on [0, 1) with step 0.05, one pair per row.
r1r2 = np.mgrid[0:1.:0.05, 0:1.:0.05].reshape(2, -1).T
r1_and_r2 = (r1r2[:, 0] * r1r2[:, 1])[:, None]
not_r1_and_r2 = 1 - r1_and_r2

# Columns: [1 - r1*r2, r1, r2, r1*r2] followed by eight zero columns.
padding = np.zeros((r1r2.shape[0], 8))
zero_obs_2pb = np.hstack([not_r1_and_r2, r1r2, r1_and_r2, padding])


In [28]:

# The single observation used at node 0 for agent '2'.
zero_obs_2 = np.array([[1., 0., 0., 0., 0., 1., 0., 0., 1.]])

# Candidate observations at node 0 for '2o': every pair of discount ** k, k = 1..300.
discount = 0.95

obs_2o_range_single = discount ** np.arange(1, 301)
obs_2o_range_x, obs_2o_range_y = np.meshgrid(obs_2o_range_single, obs_2o_range_single)
obs_2o_range = np.dstack((obs_2o_range_x, obs_2o_range_y))
ot_obs_2o = obs_2o_range.reshape(-1, 2)

# Start from copies of the '2' observation and overwrite columns 3 and 6 with each pair.
zero_obs_2o = np.tile(zero_obs_2, (ot_obs_2o.shape[0], 1))
zero_obs_2o[:, [3, 6]] = ot_obs_2o

In [29]:
# Load the saved DQN agents, one checkpoint per observation variant.
fa = 'linear'
results_dir = Path(ROOT_DIR, 'results')

obs_agent_fname = results_dir / f'2_{fa}_agent.pth'
unc_agent_fname = results_dir / f'2o_{fa}_agent.pth'
# gt_agent_fname = Path(ROOT_DIR, 'results', f'2g_{fa}_agent.pth')
pb_agent_fname = results_dir / f'2pb_{fa}_agent.pth'
# gvf_agent_fname = Path(ROOT_DIR, 'results', f'2f_{fa}_agent.pth')
pred_agent_fname = results_dir / f'2e_{fa}_agent.pth'

obs_agent = DQNAgent.load(obs_agent_fname, DQNAgent)
unc_agent = DQNAgent.load(unc_agent_fname, DQNAgent)
# gt_agent = DQNAgent.load(gt_agent_fname, DQNAgent)
pb_agent = DQNAgent.load(pb_agent_fname, DQNAgent)
# gvf_agent = DQNAgent.load(gvf_agent_fname, DQNAgent)
pred_agent = DQNAgent.load(pred_agent_fname, DQNAgent)


In [30]:
# gvf_agent.network_params['linear']['w'][-2:], pb_agent.network_params['linear']['w'].shape

In [31]:
# Evaluate each agent's Q-function on its node-0 observations.
all_zero_2_qs = obs_agent.Qs(zero_obs_2, obs_agent.network_params)
all_zero_2o_qs = unc_agent.Qs(zero_obs_2o, unc_agent.network_params)
# all_zero_2g_qs = gt_agent.Qs(zero_obs_2g, gt_agent.network_params)
all_zero_2pb_qs = pb_agent.Qs(pb_zero_obs, pb_agent.network_params)
# all_zero_2pb_qs = pb_agent.Qs(zero_obs_2pb, pb_agent.network_params)
# all_zero_2f_qs = gvf_agent.Qs(gvf_zero_obs, gvf_agent.network_params)
# Only the first two action columns are kept for '2e'.
all_zero_2e_qs = pred_agent.Qs(all_zero_obs_2e, pred_agent.network_params)[:, :2]

# Spread (max - min) of each set of Q-values.
range_2, range_2o, range_optimal, range_2pb, range_2e = (
    qs.max() - qs.min()
    for qs in (all_zero_2_qs, all_zero_2o_qs, optimal_zero_qs, all_zero_2pb_qs, all_zero_2e_qs)
)
# range_2g = all_zero_2g_qs.max() - all_zero_2g_qs.min()
# range_2f = all_zero_2f_qs.max() - all_zero_2f_qs.min()


def _shift_and_scale(qs, span):
    """Subtract the minimum of `qs` and divide by `span`."""
    return (qs - qs.min()) / span


# NOTE(review): the '2' values are scaled by range_optimal rather than range_2
# (range_2 is otherwise unused) — confirm this is intentional.
normalized_2_qs = _shift_and_scale(all_zero_2_qs, range_optimal)
normalized_2o_qs = _shift_and_scale(all_zero_2o_qs, range_2o)
normalized_optimal_qs = _shift_and_scale(optimal_zero_qs, range_optimal)
# normalized_2g_qs = (all_zero_2g_qs - all_zero_2g_qs.min()) / range_2g
normalized_2pb_qs = _shift_and_scale(all_zero_2pb_qs, range_2pb)
# normalized_2f_qs = (all_zero_2f_qs - all_zero_2f_qs.min()) / range_2f
normalized_2e_qs = _shift_and_scale(all_zero_2e_qs, range_2e)


In [32]:
# Inspect the normalized Q-values of the '2e' agent (first two action columns only).
normalized_2e_qs

Array([[0.64435005, 0.9701441 ],
       [0.6602192 , 0.90487957],
       [0.6745782 , 0.84582573],
       ...,
       [0.16675834, 0.31417853],
       [0.16675834, 0.31417853],
       [0.16675834, 0.31417853]], dtype=float32)

In [38]:
# Interpolation figures: for every algorithm and action set, draw a 3D scatter of the
# normalized Q-values at node 0 against the algorithm's two reward features, together
# with the optimal Q-values at the four ground-truth states and the Q-values of the
# plain-observation ('2') agent. One PDF is written per (alg, action set).

# actions_to_plot = [0]
# action_sets = [[0], [0, 1]]
action_sets = [[0, 1]]

# algs = ['2o', '2pb', '2e']
algs = ['2o']
# algs = ['none']

show_legend = True

# Settings that do not depend on the algorithm or action set.
action_mapping = ['Left', 'Right', 'Collect']
actions_to_color = ['rgb(241, 196, 15)', 'rgb(52, 152, 219)']
axis_style = dict(
    backgroundcolor="rgb(255, 255, 255)",
    gridcolor="rgb(189, 195, 199)",
    tickvals=[0, 1],
    tickangle=0,
)
marker_outline = dict(width=0.5, color="black")
camera = dict(
    up=dict(x=0, y=0, z=1),
    center=dict(x=0, y=0, z=-0.2),
    eye=dict(x=-1.7, y=-0.85, z=0.6)
)
legend_suffix = '_legend' if show_legend else ''

for alg in algs:
    # Select (legend name, x feature, y feature, normalized Qs) for this algorithm.
    # Only the requested branch is evaluated, so arrays of disabled algorithms
    # (e.g. the commented-out '2f' data) are not needed unless that alg is requested.
    # Any other value (e.g. 'none') draws only the reference traces.
    feature_trace = None
    if alg == '2o':
        # We plot 1 - decaying trace b/c the smaller the trace, the longer it's been
        # since you've seen NO reward.
        feature_trace = ("Exp Trace", 1 - ot_obs_2o[:, 0], 1 - ot_obs_2o[:, 1], normalized_2o_qs)
    elif alg == '2pb':
        feature_trace = (
            "PF",
            pb_zero_obs[:, r1_pb_states].sum(axis=-1),
            pb_zero_obs[:, r2_pb_states].sum(axis=-1),
            normalized_2pb_qs,
        )
    elif alg == '2f':
        feature_trace = (
            "GVF",
            gvf_zero_predictions_normalized[:, 0],
            gvf_zero_predictions_normalized[:, 1],
            normalized_2f_qs,
        )
    elif alg == '2e':
        feature_trace = (
            "Likelihood",
            all_possible_likelihoods[:, 0],
            all_possible_likelihoods[:, 1],
            normalized_2e_qs,
        )

    for actions_to_plot in action_sets:
        fig_path = Path(
            ROOT_DIR, 'results',
            f'lobster_interpolation_{alg}_{actions_to_plot}{legend_suffix}.pdf'
        )

        fig = go.Figure(layout=go.Layout(
            margin=dict(l=0, r=0, t=0, b=0),
            showlegend=show_legend,
            font=dict(size=18),
            scene=dict(
                xaxis=dict(axis_style, title=r'r(L1) feature', range=[-0.1, 1.1]),
                yaxis=dict(axis_style, title=r'r(L2) feature', range=[-0.1, 1.1]),
                zaxis=dict(axis_style, title="Normalized Q", range=[-0.05, 1.05]),
            ),
        ))

        for action in actions_to_plot:
            color = actions_to_color[action]

            # Algorithm-specific feature cloud (small circles).
            if feature_trace is not None:
                trace_name, feat_x, feat_y, feat_qs = feature_trace
                fig.add_trace(go.Scatter3d(
                    x=feat_x,
                    y=feat_y,
                    z=feat_qs[:, action],
                    name=trace_name,
                    mode='markers',
                    marker={'size': 2, 'color': color, 'symbol': 'circle'}
                ))

            # Optimal Q-values at the four ground-truth states (large crosses).
            fig.add_trace(go.Scatter3d(
                x=zero_states[:, 1],
                y=zero_states[:, 2],
                z=normalized_optimal_qs[:, action],
                name="Ground-truth state",
                mode='markers',
                marker={'size': 10, 'color': color, 'symbol': 'cross', 'line': marker_outline}
            ))

            # Q-value of the plain-observation agent, drawn at the origin (diamond).
            fig.add_trace(go.Scatter3d(
                x=np.array([0]),
                y=np.array([0]),
                z=normalized_2_qs[:, action],
                name="Observations",
                mode='markers',
                marker={'size': 5, 'color': color, 'symbol': 'diamond', 'line': marker_outline}
            ))

        fig.update_layout(scene_camera=camera)
        # Static export goes through plotly's image-export engine (kaleido must be installed).
        fig.write_image(fig_path)
        fig.show()

In [29]:
# Here we plot all algorithms side by side (one 3D scene per alg) to get the best
# camera position among all of them.
# actions_to_plot = [0]
actions_to_plot = [0, 1]
algs = ['2o', '2pb', '2e']

action_mapping = ['Left', 'Right', 'Collect']
actions_to_color = ['rgb(241, 196, 15)', 'rgb(52, 152, 219)']
if actions_to_plot == [0, 1]:
    fig_path = Path(ROOT_DIR, 'results', f'lobster_s0_{algs}_all_qval.pdf')
elif actions_to_plot == [0]:
    fig_path = Path(ROOT_DIR, 'results', f'lobster_s0_a0_{algs}_qval_legend.pdf')
else:
    # Fail early instead of silently reusing a fig_path left over from another cell.
    raise ValueError(f"No output path defined for actions_to_plot={actions_to_plot}")

marker_outline = dict(width=0.5, color="black")

# One 'scene' (3D) subplot per algorithm.
fig = make_subplots(rows=1, cols=len(algs), specs=[
    [{'type': 'scene'} for _ in algs]
])
for i, alg in enumerate(algs, start=1):
    # Select (legend name, x feature, y feature, normalized Qs) for this algorithm.
    # Only the requested branch is evaluated, so arrays of disabled algorithms
    # (e.g. the commented-out '2f' data) are not needed unless that alg is requested.
    feature_trace = None
    if alg == '2o':
        # We do 1 - decaying trace here b/c the smaller the trace, the longer it's been
        # since you've seen NO reward.
        feature_trace = ("Exp Trace", 1 - ot_obs_2o[:, 0], 1 - ot_obs_2o[:, 1], normalized_2o_qs)
    elif alg == '2pb':
        feature_trace = (
            "PF",
            pb_zero_obs[:, r1_pb_states].sum(axis=-1),
            pb_zero_obs[:, r2_pb_states].sum(axis=-1),
            normalized_2pb_qs,
        )
    elif alg == '2f':
        feature_trace = (
            "GVF",
            gvf_zero_predictions_normalized[:, 0],
            gvf_zero_predictions_normalized[:, 1],
            normalized_2f_qs,
        )
    elif alg == '2e':
        feature_trace = (
            "Likelihood",
            all_possible_likelihoods[:, 0],
            all_possible_likelihoods[:, 1],
            normalized_2e_qs,
        )

    for action in actions_to_plot:
        color = actions_to_color[action]

        # Algorithm-specific feature cloud (small circles).
        if feature_trace is not None:
            trace_name, feat_x, feat_y, feat_qs = feature_trace
            fig.add_trace(go.Scatter3d(
                x=feat_x,
                y=feat_y,
                z=feat_qs[:, action],
                name=trace_name,
                mode='markers',
                marker={'size': 2, 'color': color, 'symbol': 'circle'}
            ), row=1, col=i)

        # Optimal Q-values at the four ground-truth states (large crosses).
        fig.add_trace(go.Scatter3d(
            x=zero_states[:, 1],
            y=zero_states[:, 2],
            z=normalized_optimal_qs[:, action],
            name="Ground-truth state",
            mode='markers',
            marker={'size': 10, 'color': color, 'symbol': 'cross', 'line': marker_outline}
        ), row=1, col=i)

        # Q-value of the plain-observation agent, drawn at the origin (diamond).
        fig.add_trace(go.Scatter3d(
            x=np.array([0]),
            y=np.array([0]),
            z=normalized_2_qs[:, action],
            name="Observations",
            mode='markers',
            marker={'size': 5, 'color': color, 'symbol': 'diamond', 'line': marker_outline}
        ), row=1, col=i)

# Axis styling shared by every scene.
axis_style = dict(
    backgroundcolor="rgb(255, 255, 255)",
    gridcolor="rgb(189, 195, 199)",
    tickvals=[0, 0.5, 1],
    tickangle=0,
)
axis = {
    'xaxis': dict(axis_style, title=r'r(L1) feature', range=[-0.1, 1.1]),
    'yaxis': dict(axis_style, title=r'r(L2) feature', range=[-0.1, 1.1]),
    'zaxis': dict(axis_style, title="Normalized Q", range=[-0.05, 1.05]),
}
# Alternative viewpoint tried earlier: center z=-0.2, eye=(0.85, 1.8, 1.85).
camera = dict(
    up=dict(x=0, y=0, z=1),
    center=dict(x=0, y=0, z=-0.1),
    eye=dict(x=-1.6, y=-0.75, z=0.5)
)
fig.update_scenes(camera=camera, **axis)
fig.update_layout(showlegend=False,
                  autosize=False,
                  width=1400,
                  height=500)
# Static export goes through plotly's image-export engine (kaleido must be installed).
fig.write_image(fig_path)
fig.show()

ValueError: 
Image export using the "kaleido" engine requires the kaleido package,
which can be installed using pip:
    $ pip install -U kaleido


In [None]:
# NOTE(review): in the visible notebook `normalized_2g_qs` is only assigned in
# commented-out code, so this cell raises NameError on a fresh kernel unless the
# '2g' agent lines are re-enabled.
ot_obs_2g, normalized_2g_qs

In [None]:
# Matplotlib 3D scatter of the raw (un-normalized) Q-values of the plain-observation
# agent at node 0, one point per action, all drawn at (0, 0) in the feature plane.
actions_to_plot = [0, 1]
action_mapping = ['left', 'right', 'collect']
actions_to_color = ["orange", "blue"]

fig = plt.figure(figsize=(8, 6), dpi=80)

ax = fig.add_subplot(projection='3d')
ax.view_init(32, 195)

ax.set_xlim(-0.1, 0.9)
ax.set_ylim(-0.1, 0.9)
ax.set_zlim(0.8, 1.4)

ax.set_xlabel("R1 obs")
ax.set_ylabel("R2 obs")
ax.set_zlabel("Q")

for action in actions_to_plot:
    z = all_zero_2_qs[:, action]
    # Colour each action explicitly; `cmap` has no effect without `c`, so it is not passed.
    ax.scatter(0, 0, z, color=actions_to_color[action], linewidth=0.5, label=action_mapping[action])

ax.legend(bbox_to_anchor=(1.00, 1), loc='upper left', title='action')
ax.set_title("Q-values of Lobster environment at node 0 (normal obs)");

In [None]:
# plotting trajectories
# Load the results file through the project's `load_info` helper. The path has the same
# components as `data_path` in the first data-loading cell (results/lobster_data.npy).
results_fname = Path(ROOT_DIR, "results", "lobster_data.npy")
loaded = load_info(results_fname)

In [None]:
# NOTE(review): later cells read `obs_res` and `unc_res`, which are only assigned in the
# two commented-out lines below; on a fresh kernel those cells raise NameError.
# Re-enable these lines if the results file contains the '2' and '2o' entries — TODO confirm.
# obs_res = loaded['2']
# unc_res = loaded['2o']
pb_res = loaded['2pb']

In [None]:
# NOTE: this rebinds `pb_obs`, which an earlier cell assigned to the already-flattened
# observation array; here it holds the un-flattened array.
pb_obs = pb_res['obs']
# Flatten all leading axes so each row is one observation vector, then deduplicate rows.
flat_pb = pb_obs.reshape(-1, pb_obs.shape[-1])
unique_pb = np.unique(flat_pb, axis=0)

In [None]:
# Distinct values taken by the first four observation components.
unique_0_pb = np.unique(unique_pb[:, :4], axis=0)

In [None]:
# Histogram of one component of the flattened 'pb' observations over 30 bin edges on [0, 1].
bins = np.linspace(0, 1, 30)
state = 0
fig, ax = plt.subplots()

ax.hist(flat_pb[:, state], bins=bins, color='blue', label=state, alpha=0.75, edgecolor='black', linewidth=1)

# The binned quantity is the value of observation component `state`.
ax.set_xlabel(f'Value of pb observation component {state}')
ax.set_ylabel("Frequency", rotation=0, labelpad=35)
ax.legend(bbox_to_anchor=(1.00, 1), loc='upper left');

In [None]:
# Report the range of observation components 3 and 6 over the steps whose component 0
# equals 1, for the 'unc' results and for the 'obs' results.
# NOTE(review): `unc_res` / `obs_res` must be defined by an earlier cell.
# reduced_unc_obses = get_distilled_obs(unc_res['obs'][0])
unc_obs = unc_res['obs']
at_zero_unc = unc_obs[:, :, 0] == 1

r1_ot_unc = unc_obs[:, :, 3]
r2_ot_unc = unc_obs[:, :, 6]
print(f"r1_ot_unc min: {r1_ot_unc[at_zero_unc].min()}, r1_ot_unc max: {r1_ot_unc[at_zero_unc].max()}")
print(f"r2_ot_unc min: {r2_ot_unc[at_zero_unc].min()}, r2_ot_unc max: {r2_ot_unc[at_zero_unc].max()}")

obs_obs = obs_res['obs']
at_zero_obs = obs_obs[:, :, 0] == 1

r1_ot_obs = obs_obs[:, :, 3]
r2_ot_obs = obs_obs[:, :, 6]
# These values come from the 'obs' results, so label them as such.
print(f"r1_ot_obs min: {r1_ot_obs[at_zero_obs].min()}, r1_ot_obs max: {r1_ot_obs[at_zero_obs].max()}")
print(f"r2_ot_obs min: {r2_ot_obs[at_zero_obs].min()}, r2_ot_obs max: {r2_ot_obs[at_zero_obs].max()}")

In [None]:


def get_distilled_obs(traj):
    """Reduce a 2-D trajectory of observations to (position index, column 3, column 6).

    The position index is the column of the nonzero entry among the first three
    columns of each row; the other two outputs are columns 3 and 6 copied as-is.
    Returns an array with one row per step and three columns.
    """
    # Column indices of the nonzero entries in the first three columns, in row order.
    _, position_idx = np.nonzero(traj[:, :3])
    return np.column_stack((position_idx, traj[:, 3], traj[:, 6]))

In [None]:
# reduced_unc_obses = get_distilled_obs(unc_res['obs'][0])
# Take the first trajectory of the 'obs' results and reduce it with get_distilled_obs.
# NOTE(review): `obs_res` must be defined by an earlier cell.
traj_obs = obs_res['obs'][0]

reduced_obs_obs = get_distilled_obs(traj_obs)


In [None]:

# 3D view of the first 'obs' trajectory: observation components 3 and 6 (r1_ot_obs,
# r2_ot_obs) against the time step, coloured by time step.
fig = plt.figure(figsize=(8, 6), dpi=80)
ax = fig.add_subplot(projection='3d')

for i in range(1):
    x = r1_ot_obs[i]
    y = r2_ot_obs[i]
    z = np.arange(x.shape[0])
#     x = reduced_obs_obs[:, 1]
#     y = reduced_obs_obs[:, 2]
#     z = reduced_obs_obs[:, 0]

    ax.scatter(x, y, z, c=z, cmap='viridis')
    ax.plot3D(x, y, z, color="black", linewidth=0.5)

ax.set_xlabel("R1 obs")
ax.set_ylabel("R2 obs")
ax.set_zlabel("time step")

# No legend: none of the artists carries a label.
ax.set_title("R1/R2 observation trajectory over time (normal obs)");

In [None]:
# Take the first trajectory of the 'unc' results and reduce it with get_distilled_obs.
# NOTE(review): `unc_res` must be defined by an earlier cell.
traj_unc = unc_res['obs'][0]

reduced_obs_unc = get_distilled_obs(traj_unc)

In [None]:

# 3D view of the first 'unc' trajectory: observation components 3 and 6 (r1_ot_unc,
# r2_ot_unc) against the time step, coloured by time step.
fig = plt.figure(figsize=(8, 6), dpi=80)
ax = fig.add_subplot(projection='3d')

for i in range(1):
    x = r1_ot_unc[i]
    y = r2_ot_unc[i]
    z = np.arange(x.shape[0])
#     x = reduced_obs_unc[:, 1]
#     y = reduced_obs_unc[:, 2]
#     z = reduced_obs_unc[:, 0]

    ax.scatter(x, y, z, c=z, cmap='viridis')
    ax.plot3D(x, y, z, color="black", linewidth=0.5)

ax.set_xlabel("R1 obs")
ax.set_ylabel("R2 obs")
ax.set_zlabel("time step")

# No legend: none of the artists carries a label.
ax.set_title("R1/R2 observation trajectory over time (uncertainty obs)");