In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pprint

In [None]:
# Load dataset
# NOTE(review): both paths are absolute and machine-specific; point them at
# local copies of the files before running this notebook.
dataset_path = '/users/rrodri19/abs-mdp/experiments/pb_obstacles/pixel/data/simple.pt'
debug_data = '/Users/rrs/Desktop/abs-mdp/data/pinball_simple_obs_debug.pt'
# NOTE: torch.load may unpickle arbitrary objects -- only load trusted files.
# The file holds a 2-element object; only the first element is used here.
dataset, _ = torch.load(dataset_path)
# `debug` is read by later cells (debug['latent_states'], debug['stats'],
# debug["options"]), so it has to be loaded here for the notebook to survive
# "Restart Kernel -> Run All".
debug = torch.load(debug_data)


In [None]:

# Transpose the sequence of 7-field transition records into one tuple per
# field (states, actions, rewards, etc.).
(
    obs,
    actions,
    next_obs,
    rewards,
    executed,
    duration,
    initiation_masks,
) = zip(*dataset)

In [None]:
def process_reward(rewards, duration):
    """Reduce each reward sequence to a single scalar per transition.

    For every (reward sequence, duration) pair the result is
    ``sum(sequence) / duration`` when the duration is positive and ``0``
    otherwise. Pairs are matched positionally; iteration stops at the
    shorter of the two inputs.

    Args:
        rewards: iterable of summable reward sequences, one per transition.
        duration: iterable of durations, one per transition.

    Returns:
        np.ndarray with one value per (rewards, duration) pair.
    """
    averaged = []
    for reward_sequence, length in zip(rewards, duration):
        if length > 0:
            averaged.append(sum(reward_sequence) / length)
        else:
            # Zero (or negative) duration: avoid dividing by zero.
            averaged.append(0)
    return np.array(averaged)

In [None]:
# Convert the per-field tuples to numpy arrays
obs, actions, next_obs, executed, initiation_masks = (
    np.array(column)
    for column in (obs, actions, next_obs, executed, initiation_masks)
)
# Rewards are collapsed to one scalar per transition; this is done while
# `duration` is still the original sequence, before it becomes an array.
rewards = process_reward(rewards, duration)
duration = np.array(duration)
# Number of transitions in the dataset (length of the first axis of obs)
N = obs.shape[0]


In [None]:
# Collect the 'state' / 'next_state' entry of every latent-state record
latent_states = debug['latent_states']
s = np.array([record['state'] for record in latent_states])
next_s = np.array([record['next_state'] for record in latent_states])

# Scatter the first two state components of each record
# (presumably the x/y position -- TODO confirm)
plt.scatter(s[:, 0], s[:, 1], marker='x')


In [None]:
# Visualize overlaid pixel observations
def overlay_image(obs, next_obs):
    """Blend two observations by taking their element-wise mean.

    Each operand is halved before the addition. Adding first, as in
    ``(obs + next_obs) / 2``, silently wraps around for integer image
    dtypes (e.g. uint8: 200 + 200 -> 144); true division promotes to
    floating point first, so no overflow can occur. For floating-point
    inputs the result is unchanged.

    Args:
        obs: first image (array-like supporting ``/`` and ``+``).
        next_obs: second image, broadcast-compatible with ``obs``.

    Returns:
        The element-wise average of the two inputs, as floating point.
    """
    return obs / 2 + next_obs / 2

In [None]:
# Compute statistics
#
# For every action id from 0 up to the largest id in `actions`, print the
# reward and duration statistics of the transitions where the action was
# executed (executed == 1), plus the fraction of its transitions that were
# executed at all.
for a in range(actions.max() + 1):
    print(f'Action {a}:')
    idx = actions == a
    _executed = executed[idx]
    ran = _executed == 1
    # Reward and duration statistics all use the same population (executed
    # transitions only), so max / mean / min on one line are comparable.
    executed_rewards = rewards[idx][ran]
    executed_durations = duration[idx][ran]
    if executed_rewards.size > 0:
        # rewards
        print(f'\tMax reward: {executed_rewards.max()}', f'\tMean reward: {executed_rewards.mean()}', f'\tMin reward: {executed_rewards.min()}')
        # duration
        mean_duration = executed_durations.mean()
        print(f'\tMax duration: {executed_durations.max()}', f'\tMean duration: {mean_duration}', f'\tMin duration: {executed_durations.min()}')
    else:
        # .max()/.min() raise ValueError on an empty selection, so report
        # the situation instead of crashing the whole cell.
        print('\tNo executed transitions: reward and duration statistics unavailable')
    # executed
    if idx.any():
        mean_executed = _executed.mean()
        print(f'\tProbability of initial execution: {mean_executed}')
    else:
        print('\tProbability of initial execution: n/a (action never appears in the dataset)')

In [None]:
# Pretty-print the statistics stored under debug['stats'].
# NOTE(review): pprint.pprint() prints and returns None, so `pp` is bound to
# None here -- the assignment looks unintentional; confirm before relying on it.
pp = pprint.pprint(debug['stats'])

In [None]:
# Random observations
# Draw two distinct rows of `obs` (no seed is set, so the pick differs between
# runs) and show, left to right: their blend, the first one, the second one.
sample = np.random.choice(N, 2, replace=False)
obs1, obs2 = obs[sample]

plt.subplot(1, 3, 1)
blended = overlay_image(obs1, obs2)
plt.imshow(blended)

plt.subplot(1, 3, 2)
plt.imshow(obs1)

plt.subplot(1, 3, 3)
plt.imshow(obs2)

In [None]:
# Sample observations for executed actions
n_samples = 4
executed_idx = executed == 1
obs_executed = obs[executed_idx]
next_obs_executed = next_obs[executed_idx]
# Clamp the draw: np.random.choice(..., replace=False) raises ValueError when
# asked for more samples than there are matching transitions.
n_shown = min(n_samples, obs_executed.shape[0])
sampled_indices = np.random.choice(obs_executed.shape[0], n_shown, replace=False)

action_executed = actions[executed_idx][sampled_indices]
# Hoisted out of the loop: the boolean mask is the same for every sample.
duration_executed = duration[executed_idx]
# NOTE(review): assumes the rows of s / next_s line up one-to-one with the
# dataset rows, so the same mask applies -- confirm against the debug file.
s_executed = s[executed_idx]
next_s_executed = next_s[executed_idx]

# One figure per sampled transition:
#   overlay of obs and next_obs | obs | next_obs
for sample, row in enumerate(sampled_indices):
    fig, (ax_overlay, ax_obs, ax_next) = plt.subplots(1, 3)
    # set action name as title
    ax_overlay.set_title(f'Action: {debug["options"][action_executed[sample]]}')
    ax_overlay.imshow(overlay_image(obs_executed[row], next_obs_executed[row]))
    ax_obs.set_title(f'Duration: {duration_executed[row]}')
    ax_obs.imshow(obs_executed[row])
    # Displacement between the latent state before and after the transition
    disp = next_s_executed[row] - s_executed[row]
    print(s_executed[row])
    # Report whichever of the first two components moved the most
    i = np.argmax(np.abs(disp[:2]))
    ax_next.set_title(f'Disp{i}: {disp[i]:.2f}')
    ax_next.imshow(next_obs_executed[row])

In [None]:
# Sample observations for executed transitions of actions 0 and 1 only
n_samples = 4
actions_ = ((actions == 0) | (actions == 1)) & (executed == 1)
executed_idx = actions_
obs_executed = obs[executed_idx]
next_obs_executed = next_obs[executed_idx]
# Clamp the draw: np.random.choice(..., replace=False) raises ValueError when
# asked for more samples than there are matching transitions.
n_shown = min(n_samples, obs_executed.shape[0])
sampled_indices = np.random.choice(obs_executed.shape[0], n_shown, replace=False)

action_executed = actions[executed_idx][sampled_indices]
# Hoisted out of the loop: the boolean mask is the same for every sample.
duration_executed = duration[executed_idx]
# NOTE(review): assumes the rows of s / next_s line up one-to-one with the
# dataset rows, so the same mask applies -- confirm against the debug file.
s_executed = s[executed_idx]
next_s_executed = next_s[executed_idx]

# One figure per sampled transition:
#   overlay of obs and next_obs | obs | next_obs
for sample, row in enumerate(sampled_indices):
    fig, (ax_overlay, ax_obs, ax_next) = plt.subplots(1, 3)
    # set action name as title
    ax_overlay.set_title(f'Action: {debug["options"][action_executed[sample]]}')
    ax_overlay.imshow(overlay_image(obs_executed[row], next_obs_executed[row]))
    ax_obs.set_title(f'Duration: {duration_executed[row]}')
    ax_obs.imshow(obs_executed[row])
    # Displacement between the latent state before and after the transition
    disp = next_s_executed[row] - s_executed[row]
    print(s_executed[row])
    # Report whichever of the first two components moved the most
    i = np.argmax(np.abs(disp[:2]))
    ax_next.set_title(f'Disp{i}: {disp[i]:.2f}')
    ax_next.imshow(next_obs_executed[row])