In [None]:
import time
import datetime
import os
import argparse
import random
import numpy as np
import matplotlib.pyplot as plt
import yaml
import gc
from tqdm import tqdm

import torch
from pytorch3d import transforms
from scipy.spatial.transform import Rotation

from edf.utils import preprocess, voxel_filter, binomial_test, voxelize_sample
from edf.agent import PickAgent, PlaceAgent
from edf.pybullet_env.env import MugTask
from edf.visual_utils import scatter_plot_ax, draw_poincloud_arrow_new



# Seed every RNG used below (python `random`, numpy, torch) so figures are reproducible.
# Deterministic torch kernels are NOT forced here: the config-loading cell enables them
# only when `deterministic` is True, and enabling them unconditionally here would make
# that flag a no-op.
model_seed = 0
random.seed(model_seed)
np.random.seed(model_seed)
torch.manual_seed(model_seed)
torch.set_printoptions(precision=4, sci_mode=False)

In [None]:
# --- Config files and trained checkpoints -------------------------------------
eval_config_dir = "config/eval_config/figure.yaml"
task_config_dir = "config/task_config/mug_task.yaml"
pick_agent_config_dir = "config/agent_config/pick_agent.yaml"
checkpoint_path_pick = "checkpoint/train_pick/model_iter_600.pt"
place_agent_config_dir = "config/agent_config/place_agent.yaml"
checkpoint_path_place = "checkpoint/train_place/model_iter_600.pt"

# --- Visualisation / simulation switches ----------------------------------------
use_gui = False
visualize_plot = True
save_plot = False
place_max_distance_plan = [0.05, 1.5]

# Clear `plot_path` when plots are not being saved.
plot_path = 'whatever/' if save_plot else None

deterministic = True
load_model = True

In [None]:
# Voxel sizes passed to voxel_filter: a fine one for the "dense" display clouds
# and a coarser one for the cloud whose points serve as field-sampling centres.
d_dense, d_field = 0.001, 0.01

In [None]:
def pick(T, approach_distance=0.03):
    """Try to execute a pick at the end-effector pose ``T`` in the simulator.

    ``T`` is a pose tensor split as ``T[..., :4]`` (a quaternion, converted with
    pytorch3d's ``quaternion_to_matrix``) and ``T[..., 4:]`` (a translation),
    expressed in the preprocessed pick-scene frame. ``data_transform.inv_transform_T``
    maps it back to the simulator frame. The pre-grasp pose keeps the grasp
    orientation and is offset by ``approach_distance`` along the gripper's local
    -z axis.

    Args:
        T: pose tensor as described above.
        approach_distance: pre-grasp offset (same units as the simulator frame).

    Returns:
        True if ``task.pick`` succeeded, False if it raised StopIteration
        (treated as an IK failure).
    """
    R, X = transforms.quaternion_to_matrix(T[..., :4]), T[..., 4:]
    X_sdg, R_sdg = data_transform.inv_transform_T(X.detach().cpu().numpy(), R.detach().cpu().numpy())

    # Pre-grasp: identical orientation, translated along the gripper's local z axis.
    R_s_dgpre = R_sdg @ np.eye(3)
    X_s_dgpre = X_sdg + R_sdg @ np.array([0., 0., -approach_distance])

    pre_pick_pose = (X_s_dgpre, R_s_dgpre)
    pick_pose = (X_sdg, R_sdg)

    try:
        task.pick(pre_pick_pose, pick_pose)
    except StopIteration:
        return False
    print("Pick IK Success", flush=True)
    return True

def place(T, approach_distance=0.03):
    """Try to execute a place at the end-effector pose ``T`` in the simulator.

    ``T`` is a pose tensor split as ``T[..., :4]`` (a quaternion, converted with
    pytorch3d's ``quaternion_to_matrix``) and ``T[..., 4:]`` (a translation),
    expressed in the preprocessed place-scene frame. ``data_transform_K.inv_transform_T``
    maps it back to the simulator frame. The pre-place pose keeps the target
    orientation and is offset by ``approach_distance`` along the gripper's local
    -z axis. Motion planning uses the global ``place_max_distance_plan``.

    Args:
        T: pose tensor as described above.
        approach_distance: pre-place offset (same units as the simulator frame).

    Returns:
        True if ``task.place`` succeeded, False if it raised StopIteration
        (treated as an IK failure).
    """
    R, X = transforms.quaternion_to_matrix(T[..., :4]), T[..., 4:]
    X_sdg, R_sdg = data_transform_K.inv_transform_T(X.detach().cpu().numpy(), R.detach().cpu().numpy())

    # Pre-place: identical orientation, translated along the gripper's local z axis.
    R_s_dgpre = R_sdg @ np.eye(3)
    X_s_dgpre = X_sdg + R_sdg @ np.array([0., 0., -approach_distance])

    pre_place_pose = (X_s_dgpre, R_s_dgpre)
    place_pose = (X_sdg, R_sdg)

    try:
        task.place(pre_place_pose, place_pose, max_distance_plan=place_max_distance_plan)
    except StopIteration:
        return False
    print("Place IK Success", flush=True)
    return True



##### Load eval config #####
# Copy every evaluation hyper-parameter out of the eval config into module-level names.
# NOTE(review): yaml.safe_load would suffice if this file is plain YAML; FullLoader
# is kept so that any Python-specific tags in the config still load.
with open(eval_config_dir) as file:
    config = yaml.load(file, Loader=yaml.FullLoader)
device = config['device']
characteristic_length = config['characteristic_length']
plot_figsize = config['plot_figsize']
plot_result_figsize = config['plot_result_figsize']
pick_only = config['pick_only']

pick_policy = config['pick_policy']
pick_dist_temp = config['pick_dist_temp']
pick_policy_temp = config['pick_policy_temp']
pick_attempt_max = config['pick_attempt_max']
N_transform_pick = config['N_transform_pick']
mh_iter_pick = config['mh_iter_pick']
langevin_iter_pick = config['langevin_iter_pick']
optim_iter_pick = config['optim_iter_pick']
langevin_dt_pick = config['langevin_dt_pick']
optim_lr_pick = config['optim_lr_pick']
X_seed_mean_pick = config['X_seed_mean_pick']
X_seed_std_pick = config['X_seed_std_pick']
max_N_query_pick = config['max_N_query_pick']

place_policy = config['place_policy']
place_dist_temp = config['place_dist_temp']
place_policy_temp = config['place_policy_temp']
place_attempt_max = config['place_attempt_max']
N_transform_place = config['N_transform_place']
mh_iter_place = config['mh_iter_place']
langevin_iter_place = config['langevin_iter_place']
optim_iter_place = config['optim_iter_place']
langevin_dt_place = config['langevin_dt_place']
optim_lr_place = config['optim_lr_place']
X_seed_mean_place = config['X_seed_mean_place']
X_seed_std_place = config['X_seed_std_place']
# The eval config's 'max_N_query_place' is overridden by a hard-coded value here.
# Set the override to None to use the configured value instead.
MAX_N_QUERY_PLACE_OVERRIDE = 5
max_N_query_place = config['max_N_query_place'] if MAX_N_QUERY_PLACE_OVERRIDE is None else MAX_N_QUERY_PLACE_OVERRIDE
query_temp_place = config['query_temp_place']


schedules = config['schedules']

##### Load task config #####
# Re-binds `config` to the task config; the eval-config values were already copied out above.
with open(task_config_dir) as file:
    config = yaml.load(file, Loader=yaml.FullLoader)
sleep, d, d_pick, d_place = (config[key] for key in ('sleep', 'd', 'd_pick', 'd_place'))

# Re-seed every RNG so agent construction and the environment below start from a
# known state, then (optionally) force deterministic torch kernels.
model_seed = 0
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(model_seed)
del seed_fn
if deterministic:
    torch.use_deterministic_algorithms(True)
torch.set_printoptions(precision=4, sci_mode=False)

##### Load agent models #####
# Both agents are built with gradients disabled, then restored from their checkpoints.
# NOTE(review): when `load_model` is False neither agent is defined, and every later
# cell that uses `pick_agent` / `place_agent` will fail.
if load_model:
    pick_agent = PickAgent(
        config_dir=pick_agent_config_dir,
        device=device,
        max_N_query=max_N_query_pick,
        langevin_dt=langevin_dt_pick,
    ).requires_grad_(False)
    place_agent = PlaceAgent(
        config_dir=place_agent_config_dir,
        device=device,
        max_N_query=max_N_query_place,
        langevin_dt=langevin_dt_place,
    ).requires_grad_(False)
    pick_agent.load(checkpoint_path_pick)
    place_agent.load(checkpoint_path_place)

##### Initialize task env #####
# Build the pybullet mug task used for every observation and execution below;
# `use_gui` is set in the configuration cell.
task = MugTask(use_gui=use_gui)

In [None]:
def lowvar_randn(N, d=3, max_norm=2.):
    """Draw ``N`` standard-normal samples in ``d`` dimensions with norm below ``max_norm``.

    Samples are drawn in batches of ``5 * N`` and rejected if their Euclidean norm
    is ``>= max_norm``, giving a norm-truncated Gaussian. Batches continue until at
    least ``N`` samples are accepted. For higher ``d`` the acceptance rate drops
    sharply, so a single batch is not guaranteed to be enough.

    Args:
        N: number of samples to return (0 gives an empty (0, d) tensor).
        d: dimensionality of each sample.
        max_norm: exclusive upper bound on each sample's norm; must be positive.

    Returns:
        A float tensor of shape (N, d).

    Raises:
        ValueError: if ``max_norm`` is not positive (no sample could be accepted).
    """
    if max_norm <= 0:
        raise ValueError(f"max_norm must be positive, got {max_norm}")
    samples = torch.randn(5 * N, d)
    samples = samples[samples.norm(dim=-1) < max_norm]
    while len(samples) < N:
        extra = torch.randn(5 * N, d)
        samples = torch.cat([samples, extra[extra.norm(dim=-1) < max_norm]], dim=0)
    return samples[:N]

def append_alpha(color, alpha):
    """Append a constant alpha column to an RGB colour array.

    Args:
        color: (N, 3) array with every value in [0, 1].
        alpha: scalar in [0, 1] used for every row.

    Returns:
        An (N, 4) RGBA array whose last column equals ``alpha``.
    """
    assert len(color.shape) == 2
    assert color.shape[-1] == 3
    assert color.max() <= 1.
    assert color.min() >= 0.
    assert 0. <= alpha <= 1.

    alpha_column = alpha * np.ones((color.shape[0], 1))
    return np.hstack((color, alpha_column))

def get_rot(*rots):
    """Compose 'xyz' Euler rotations (in degrees) and return the result as a quaternion.

    Each argument is a length-3 sequence of Euler angles. Later arguments are applied
    after earlier ones (left-multiplied). With no arguments the identity is returned.

    Returns:
        A numpy array of shape (4,) in scipy's scalar-last (x, y, z, w) order.
    """
    composed = Rotation.from_euler('xyz', [0., 0., 0.], degrees=True)
    for euler_deg in rots:
        step = Rotation.from_euler('xyz', euler_deg, degrees=True)
        composed = step * composed
    return np.array(composed.as_quat())

def transform_coord(coord, T):
    """Rotate and translate a point array by the rigid transform ``T = (q, x)``.

    ``q`` is a scalar-last quaternion of shape (4,) (as produced by ``get_rot``) and
    ``x`` a translation of shape (3,). The quaternion is reordered to scalar-first
    for pytorch3d's ``quaternion_apply``. ``coord`` is presumably an (N, 3) numpy
    array; the translation is broadcast over its rows.

    Returns:
        The transformed points as a numpy array.
    """
    quat_xyzw, translation = T
    assert quat_xyzw.shape == (4,) and translation.shape == (3,)
    quat_wxyz = np.array([quat_xyzw[3], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]])
    rotated = transforms.quaternion_apply(torch.from_numpy(quat_wxyz), torch.from_numpy(coord))
    return rotated.numpy() + translation.reshape(-1, 3)


def append_pc(*pcs):
    """Concatenate several coloured point clouds, transforming each one first.

    Each positional argument is a tuple of one of these forms:
      (coord, color)
      (coord, color, T)          -- T may be None (identity)
      (coord, color, T, alpha)   -- T may be None; a non-None alpha overrides the
                                    alpha column of ``color``
    ``coord`` is (N, 3), ``color`` is (N, 4) RGBA, and ``T = (q, x)`` is a
    scalar-last quaternion of shape (4,) plus a translation of shape (3,), as
    consumed by ``transform_coord``.

    Input arrays are never modified: an alpha override is applied to a copy, so a
    colour array reused across several figures keeps its original alpha.

    Returns:
        ``(coords, colors)`` stacked along axis 0.

    Raises:
        ValueError: if a tuple does not have 2, 3 or 4 entries.
    """
    identity = (get_rot([0, 0, 0]), np.array([0., 0., 0.]))
    coords = []
    colors = []
    for pc in pcs:
        if len(pc) == 2:
            coord, color = pc
            T, alpha = None, None
        elif len(pc) == 3:
            coord, color, T = pc
            alpha = None
        elif len(pc) == 4:
            coord, color, T, alpha = pc
        else:
            raise ValueError(f"point-cloud tuple must have 2, 3 or 4 entries, got {len(pc)}")
        if T is None:
            T = identity

        assert len(coord.shape) == 2 and len(color.shape) == 2 and len(T) == 2
        assert coord.shape[-1] == 3 and color.shape[-1] == 4 and coord.shape[0] == color.shape[0]
        assert T[0].shape == (4,) and T[1].shape == (3,)
        if alpha is not None:
            # Copy so the caller's colour array is not mutated.
            color = color.copy()
            color[..., -1] = alpha

        coords.append(transform_coord(coord, T))
        colors.append(color)

    return np.concatenate(coords, axis=0), np.concatenate(colors, axis=0)

def normalize_color_range(color, elementwise = True):
    """Min-max rescale the first three columns of ``color`` into (0, 1].

    Args:
        color: 2-D array. Only its first three columns are rescaled; any further
            columns (e.g. alpha) are copied through unchanged.
        elementwise: if True, each of the three columns uses its own min/max;
            otherwise a single min/max over all three columns is used.

    Returns:
        A new floating-point array; the input is not modified. A 1e-6 offset keeps
        the minimum strictly positive and avoids dividing by zero for constant columns.
    """
    assert len(color.shape) == 2
    # Work on a floating copy: in-place updates both mutated the caller's array and
    # truncated the result to integers for integer input.
    color = color.astype(np.result_type(color.dtype, np.float32))
    rgb = color[..., :3]
    if elementwise:
        rgb = rgb - rgb.min(axis=0, keepdims=True) + 1e-6
        rgb = rgb / rgb.max(axis=0, keepdims=True)
    else:
        rgb = rgb - rgb.min() + 1e-6
        rgb = rgb / rgb.max()
    color[..., :3] = rgb
    return color

def normalize_scalar_feature(scalar_val, elementwise = True):
    """Optionally scale columns to unit std, then map them into (0, 1] for display.

    Args:
        scalar_val: 2-D array (rows = samples). Its first three columns are treated
            as colour channels by ``normalize_color_range``.
        elementwise: if True, each column is first divided by its own standard
            deviation, and the colour normalisation is done per column.

    Returns:
        The normalised array produced by ``normalize_color_range``.
    """
    if elementwise:
        std = scalar_val.std(axis=0, keepdims=True)
        # A constant column has zero std; leave it unscaled instead of producing NaN/inf.
        std = np.where(std > 0, std, 1.)
        scalar_val = scalar_val / std
    return normalize_color_range(scalar_val, elementwise=elementwise)

def get_edf_sample(feature_se3T, pos_se3T, sample_center, iters=10, std=0.3):
    """Query the pick agent's energy-model field at points scattered around ``sample_center``.

    In each of ``iters`` rounds, every centre is perturbed with isotropic Gaussian
    noise of scale ``std``. When ``iters == 1`` the centres are used unchanged. The
    resulting points are evaluated with ``pick_agent.energy_model.get_field_value``.

    Args:
        feature_se3T, pos_se3T: descriptor tensors (e.g. from ``pick_agent.get_descriptor``).
            They are moved to CPU once, before the loop.
        sample_center: (M, 3) float tensor of centres on CPU.
        iters: number of sampling rounds.
        std: scale of the Gaussian perturbation.

    Returns:
        ``(points, features)`` numpy arrays, each with ``iters * M`` rows. The feature
        width is whatever ``get_field_value`` returns.
    """
    feature_cpu, pos_cpu = feature_se3T.cpu(), pos_se3T.cpu()
    sample_points = []
    sample_features = []
    for _ in tqdm(range(iters)):
        if iters == 1:
            sample_point = sample_center.clone()
        else:
            sample_point = sample_center + torch.randn(len(sample_center), 3) * std
        field_val, _neighbor = pick_agent.energy_model.get_field_value(
            feature=feature_cpu, pos=pos_cpu, query_points=sample_point.unsqueeze(0))
        sample_points.append(sample_point)
        sample_features.append(field_val.squeeze(0))

    return (torch.cat(sample_points, dim=0).cpu().numpy(),
            torch.cat(sample_features, dim=0).cpu().numpy())


def get_edf_sample_K(feature_se3T, pos_se3T, sample_center, iters=10, std=0.3, shell = False):
    """Query the place agent's energy-model field at points around ``sample_center``.

    How each round's points are produced:
      * ``shell`` truthy: on a sphere of radius ``2 * std`` around every centre;
      * otherwise, ``iters == 1``: the centres themselves;
      * otherwise: centres plus isotropic Gaussian noise of scale ``std``.
    The points are evaluated with ``place_agent.energy_model.get_field_value``.

    Args:
        feature_se3T, pos_se3T: descriptor tensors. They are moved to CPU once, before the loop.
        sample_center: (M, 3) float tensor of centres on CPU.
        iters: number of sampling rounds.
        std: noise scale (or half the shell radius).
        shell: sample on spherical shells instead of Gaussian blobs.

    Returns:
        ``(points, features)`` numpy arrays, each with ``iters * M`` rows.
    """
    feature_cpu, pos_cpu = feature_se3T.cpu(), pos_se3T.cpu()
    sample_points = []
    sample_features = []
    for _ in tqdm(range(iters)):
        if shell:
            direction = torch.randn(len(sample_center), 3)
            direction = direction / direction.norm(dim=-1, keepdim=True)
            sample_point = sample_center + direction * std * 2
        elif iters == 1:
            sample_point = sample_center.clone()
        else:
            sample_point = sample_center + torch.randn(len(sample_center), 3) * std
        field_val, _neighbor = place_agent.energy_model.get_field_value(
            feature=feature_cpu, pos=pos_cpu, query_points=sample_point.unsqueeze(0))
        sample_points.append(sample_point)
        sample_features.append(field_val.squeeze(0))

    return (torch.cat(sample_points, dim=0).cpu().numpy(),
            torch.cat(sample_features, dim=0).cpu().numpy())

def get_edf_sample_Q(feature_se3T, pos_se3T, sample_center, iters=10, std=0.3):
    """Evaluate the place agent's query-model features at points around ``sample_center``.

    In each of ``iters`` rounds, the centres are perturbed with Gaussian noise of
    scale ``std`` (or used unchanged when ``iters == 1``). The points are evaluated
    with ``place_agent.query_model.get_feature``. Query points are sent to the same
    device as ``pos_se3T``, so this works whichever device the notebook is configured for.

    Args:
        feature_se3T, pos_se3T: query-side descriptor tensors.
        sample_center: (M, 3) float tensor of centres on CPU.
        iters: number of sampling rounds.
        std: scale of the Gaussian perturbation.

    Returns:
        ``(points, features)`` numpy arrays, each with ``iters * M`` rows.
    """
    query_device = pos_se3T.device
    sample_points = []
    sample_features = []
    for _ in tqdm(range(iters)):
        if iters == 1:
            sample_point = sample_center.clone()
        else:
            sample_point = sample_center + torch.randn(len(sample_center), 3) * std
        sample_feature = place_agent.query_model.get_feature(
            feature=feature_se3T, pos=pos_se3T,
            query_points=sample_point.unsqueeze(0).to(query_device)).squeeze(0).cpu()
        sample_points.append(sample_point)
        sample_features.append(sample_feature)

    return (torch.cat(sample_points, dim=0).cpu().numpy(),
            torch.cat(sample_features, dim=0).cpu().numpy())

def draw_arrow(pc, begin, end, arrow_color, arrowhead_size = 1., density=10, view_normal = None):
    """Render an arrow from ``begin`` to ``end`` as a point cloud.

    Args:
        pc: ``(coord, color)`` cloud to append the arrow to, or None to return the
            arrow on its own.
        begin, end: arrow tail and tip positions.
        arrow_color: RGBA colour (length 4).
        arrowhead_size, density: forwarded to ``draw_poincloud_arrow_new``.
        view_normal: forwarded to ``draw_poincloud_arrow_new`` in both cases, so the
            first arrow of a fresh cloud (``pc is None``) honours it as well.

    Returns:
        ``(coord, color)`` of the combined cloud, or whatever
        ``draw_poincloud_arrow_new`` returns when ``pc`` is None.
    """
    assert len(arrow_color) == 4
    arrow_kwargs = dict(begin=begin, end=end, color=arrow_color,
                        arrowhead_size=arrowhead_size, density=density)
    if pc is None:
        # Forward view_normal only when given, so the helper's own default applies otherwise.
        if view_normal is not None:
            arrow_kwargs['view_normal'] = view_normal
        return draw_poincloud_arrow_new(**arrow_kwargs)
    coord, color = pc
    return append_pc((coord, color), draw_poincloud_arrow_new(view_normal=view_normal, **arrow_kwargs))

def draw_blob(q, color, std ,N=100, shell = True):
    """Make a single-colour blob of points around the position ``q``.

    With ``shell`` truthy, ``N`` points lie on a sphere of radius ``std`` around ``q``
    (a 1e-5 term guards the normalisation). Otherwise the points are
    ``lowvar_randn(N)`` draws (a norm-truncated Gaussian) scaled by ``std``.

    Args:
        q: centre position, shape (3,).
        color: RGBA colour of shape (4,), repeated for every point.
        std: shell radius / Gaussian scale.
        N: number of points requested.
        shell: choose shell sampling instead of a filled blob.

    Returns:
        ``(coord, color)`` numpy arrays with the same number of rows.
    """
    assert color.shape == (4,)
    if not shell:
        coord = lowvar_randn(N).numpy() * std + q
    else:
        direction = torch.randn(N, 3)
        direction = direction / (1e-5 + direction.norm(dim=-1, keepdim=True))
        coord = direction.numpy() * std + q
    return coord, np.tile(color, (len(coord), 1))

def get_vector_field(sample_points, vector_feature, criteria = lambda p,norm: norm>.5):
    """Pair each sample point with its vector and keep the pairs accepted by ``criteria``.

    Args:
        sample_points: iterable of points.
        vector_feature: iterable of vectors, aligned with ``sample_points``.
        criteria: ``criteria(point, norm) -> bool``, where ``norm`` is the vector's
            Euclidean norm. The default keeps vectors with norm above 0.5.

    Returns:
        A list of ``(point, vector)`` tuples, in input order.
    """
    return [
        (point, vec)
        for point, vec in zip(sample_points, vector_feature)
        if criteria(point, np.linalg.norm(vec))
    ]

def draw_vector_field(vec_field, size = 1., length=4., append_to = None, view_normal = None, color = None):
    """Render ``(point, vector)`` pairs as arrows from ``point`` to ``point + vector * length``.

    Args:
        vec_field: iterable of ``(point, vector)`` pairs (e.g. from ``get_vector_field``).
        size: arrowhead size.
        length: multiplier applied to every vector.
        append_to: optional ``(coord, color)`` cloud that the arrows are appended to.
        view_normal: forwarded to ``draw_arrow``; defaults to (-1, 1, -1)/sqrt(3).
        color: RGBA arrow colour; defaults to opaque blue (0, 0, 1, 1). A fresh
            array is created per call rather than sharing a mutable default.

    Returns:
        The accumulated ``(coord, color)`` cloud, or ``append_to`` unchanged when
        ``vec_field`` is empty.
    """
    if view_normal is None:
        view_normal = np.array([-1, 1, -1]) / np.sqrt(3)
    if color is None:
        color = np.array([0, 0., 1., 1.])

    arrowed = append_to
    for p, vec in vec_field:
        arrowed = draw_arrow(arrowed, begin=p, end=p + vec * length, arrow_color=color,
                             density=10, arrowhead_size=size, view_normal=view_normal)
    return arrowed

def transform_vec(vec_field, T):
    """Apply the rigid transform ``T = (q, x)`` to a list of ``(point, vector)`` pairs.

    Points are rotated and then translated by ``x``; vectors are only rotated. ``q``
    is a scalar-last quaternion (as from ``get_rot``), reordered to scalar-first for
    pytorch3d's ``quaternion_apply``.

    Returns:
        A new list of ``(point, vector)`` numpy pairs.
    """
    quat_xyzw, shift = T
    assert quat_xyzw.shape == (4,) and shift.shape == (3,)
    quat = torch.from_numpy(np.array([quat_xyzw[-1], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]]))

    def _rotate(v):
        return transforms.quaternion_apply(quat, torch.from_numpy(v)).numpy()

    return [(_rotate(p) + shift, _rotate(vec)) for p, vec in vec_field]

# Dense Image (Pick)

In [None]:
# Scene used by every figure below. Alternatives tried: seed=1, mug_pose='lying'.
seed, mug_pose, mug_type = 0, 'upright', 'default'
distractor = use_support = False

In [None]:
##### Observe #####
# Reset the simulator to the chosen scene and capture its point cloud.
task.reset(seed = seed, mug_pose=mug_pose, mug_type=mug_type, distractor=distractor, use_support=use_support)
pc = task.observe_pointcloud(stride = (1,1))
sample_unprocessed = {}
# NOTE(review): this d_dense voxel_filter result is overwritten two lines below by
# voxelize_sample (voxel size `d` rather than `d_dense`, with coord/colour jitter),
# so the "dense" cloud built here is not the d_dense one -- confirm which is intended.
sample_unprocessed['coord'] ,sample_unprocessed['color'] = voxel_filter(pc['coord'], pc['color'], d=d_dense)
vox = voxelize_sample({'coord': pc['coord'], 'color': pc['color'], 'd': d}, coord_jitter=0.1, color_jitter=0.03, pick=True, place=False)
sample_unprocessed['coord'] ,sample_unprocessed['color'] = vox['coord'], vox['color']


sample_unprocessed['range'] = pc['ranges']
sample_unprocessed['images'] = task.observe()
sample_unprocessed['center'] = task.center
# Keep the colours as they were before `preprocess`, for display.
color_unprocessed = sample_unprocessed['color']
sample = preprocess(sample_unprocessed, characteristic_length)


##### Prepare Input #####
coord, color, ranges = sample['coord'], sample['color'], sample['ranges']
# `pick()` uses this transform to map sampled poses back to the simulator frame;
# it is overwritten by each later observation cell.
data_transform = sample['data_transform']
feature = torch.tensor(color, dtype=torch.float32, device=device)
pos = torch.tensor(coord, dtype=torch.float32, device=device)
# Keep only the indices selected by the pick agent's crop range.
in_range_cropped_idx = pick_agent.crop_range_idx(pos)
pos, feature = pos[in_range_cropped_idx], feature[in_range_cropped_idx]
inputs = {'feature': feature, 'pos': pos, 'edge': None, 'max_neighbor_radius': pick_agent.max_radius}




# Cropped copies for plotting: coordinates after `preprocess`, colours from before it.
visual_info_dense = {'coord':coord[in_range_cropped_idx.cpu()].copy(), 
                'color': color_unprocessed[in_range_cropped_idx.cpu()].copy(), 
                'ranges': ranges.copy(),
                }#'ax': axes[0]}

In [None]:
# Render the pick-scene cloud (rotated 90 degrees about z for display) and save it.
coord_dense, color_dense, ranges_dense = (visual_info_dense[key] for key in ('coord', 'color', 'ranges'))
color_alpha_dense = append_alpha(color_dense, 1.)

theta = 90
T = (get_rot([0, 0, theta]), np.array([0., 0, 0.]))

fig, ax = plt.subplots(1, 1, figsize=(10, 10), subplot_kw={'projection': '3d'})
fig.tight_layout()
ax.axis('off')
scatter_plot_ax(ax, transform_coord(coord_dense, T), color_alpha_dense, ranges_dense,
                frame_infos=[], transform='default')
# Unrotated cloud, reused by the MCMC overlay further down.
pc_dense = (coord_dense, color_alpha_dense)
fig.savefig('pc_pick', transparent=True)

In [None]:
# Optional breakpoint after the pick point-cloud figure. Set to True to stop
# "Run All" here on purpose instead of relying on an undefined name raising NameError.
STOP_AFTER_PICK_FIGURE = False
if STOP_AFTER_PICK_FIGURE:
    raise RuntimeError("Stopping after the pick point-cloud figure (STOP_AFTER_PICK_FIGURE=True)")

# EDF Field

In [None]:
# seed=0
# mug_pose='upright'
# mug_type='default'
# distractor=False
# use_support=False

In [None]:
##### Observe #####
# Re-observe the same scene and voxelise at d_field (no jitter). The cropped
# coordinates become `coord_field`, which later cells use as field-sampling centres.
task.reset(seed = seed, mug_pose=mug_pose, mug_type=mug_type, distractor=distractor, use_support=use_support)
pc = task.observe_pointcloud(stride = (1,1))
sample_unprocessed = {}
sample_unprocessed['coord'] ,sample_unprocessed['color'] = voxel_filter(pc['coord'], pc['color'], d=d_field)
# vox = voxelize_sample({'coord': pc['coord'], 'color': pc['color'], 'd': d}, coord_jitter=0.1, color_jitter=0.03, pick=True, place=False)
# sample_unprocessed['coord'] ,sample_unprocessed['color'] = vox['coord'], vox['color']


sample_unprocessed['range'] = pc['ranges']
sample_unprocessed['images'] = task.observe()
sample_unprocessed['center'] = task.center
# Colours as they were before `preprocess`, for display.
color_unprocessed = sample_unprocessed['color']
sample = preprocess(sample_unprocessed, characteristic_length)


##### Prepare Input #####
# `inputs` / `data_transform` built here are superseded by the next observation cell.
coord, color, ranges = sample['coord'], sample['color'], sample['ranges']
data_transform = sample['data_transform']
feature = torch.tensor(color, dtype=torch.float32, device=device)
pos = torch.tensor(coord, dtype=torch.float32, device=device)
in_range_cropped_idx = pick_agent.crop_range_idx(pos)
pos, feature = pos[in_range_cropped_idx], feature[in_range_cropped_idx]
inputs = {'feature': feature, 'pos': pos, 'edge': None, 'max_neighbor_radius': pick_agent.max_radius}




# Cropped copies for plotting and for choosing sample centres.
visual_info_field = {'coord':coord[in_range_cropped_idx.cpu()].copy(), 
                'color': color_unprocessed[in_range_cropped_idx.cpu()].copy(), 
                'ranges': ranges.copy(),
                }#'ax': axes[0]}

In [None]:
# Unpack the d_field cloud (used as sampling centres) and give it opaque alpha.
coord_field, color_field, ranges_field = (visual_info_field[key] for key in ('coord', 'color', 'ranges'))
color_alpha_field = append_alpha(color_field, 1.)

In [None]:
# fig,ax = plt.subplots(1,1, figsize=(10,10), subplot_kw={'projection':'3d'})
# ax.axis('off')
# scatter_plot_ax(ax, coord_field, color_alpha_field, ranges_field, frame_infos = [], transform='default')

In [None]:
# Display-only jitter, applied to `visual_info` further down in this cell; None disables it.
coord_jitter = None # 0.1
color_jitter = None # 0.03

##### Observe #####
# Observe once more, this time voxelised with `d` from the task config. The `inputs`
# built here are the ones used by pick_agent.get_descriptor and pick_agent.forward below.
task.reset(seed = seed, mug_pose=mug_pose, mug_type=mug_type, distractor=distractor, use_support=use_support)
pc = task.observe_pointcloud(stride = (1,1))
sample_unprocessed = {}
sample_unprocessed['coord'] ,sample_unprocessed['color'] = voxel_filter(pc['coord'], pc['color'], d=d)
# vox = voxelize_sample({'coord': pc['coord'], 'color': pc['color'], 'd': d}, coord_jitter=0.1, color_jitter=0.03, pick=True, place=False)
# sample_unprocessed['coord'] ,sample_unprocessed['color'] = vox['coord'], vox['color']


sample_unprocessed['range'] = pc['ranges']
sample_unprocessed['images'] = task.observe()
sample_unprocessed['center'] = task.center
color_unprocessed = sample_unprocessed['color']
sample = preprocess(sample_unprocessed, characteristic_length)


##### Prepare Input #####
coord, color, ranges = sample['coord'], sample['color'], sample['ranges']
data_transform = sample['data_transform']
feature = torch.tensor(color, dtype=torch.float32, device=device)
pos = torch.tensor(coord, dtype=torch.float32, device=device)
in_range_cropped_idx = pick_agent.crop_range_idx(pos)
pos, feature = pos[in_range_cropped_idx], feature[in_range_cropped_idx]
inputs = {'feature': feature, 'pos': pos, 'edge': None, 'max_neighbor_radius': pick_agent.max_radius}




# NOTE(review): unlike visual_info_dense / visual_info_field, nothing here is
# `.copy()`-ed (notably `ranges` is stored by reference).
visual_info = {'coord':coord[in_range_cropped_idx.cpu()], 
                'color': color_unprocessed[in_range_cropped_idx.cpu()], 
                'ranges': ranges,
                }#'ax': axes[0]}


# Optional display-only jitter of the visualised pick-scene cloud (the model inputs
# are untouched). Colours are clipped back into [0, 1] after jittering.
# NOTE(review): the extra "* 100" on the coordinate jitter presumably converts `d`
# into the rescaled coordinate units -- confirm.
if coord_jitter:
    visual_info['coord'] = visual_info['coord'] + np.random.randn(*visual_info['coord'].shape) * d * coord_jitter * 100

if color_jitter:
    visual_info['color'] = visual_info['color'] + np.random.randn(*visual_info['color'].shape) * color_jitter
    visual_info['color'] = np.clip(visual_info['color'], 0., 1.)

In [None]:
# coord_input, color_input, ranges_input = visual_info['coord'], visual_info['color'], visual_info['ranges']
# color_alpha_input = append_alpha(color_input, 1.)

# fig,ax = plt.subplots(1,1, figsize=(10,10), subplot_kw={'projection':'3d'})
# ax.axis('off')
# scatter_plot_ax(ax, coord_input, color_alpha_input, ranges_input, frame_infos = [], transform='default')

In [None]:
# Descriptor features/positions of the pick agent for the `inputs` built by the last
# observation cell; every field figure below samples from these.
feature_se3T, pos_se3T = pick_agent.get_descriptor(inputs=inputs, requires_grad=False)

### 10 Features

In [None]:
# Plot field channels 0..9 as grey-scale scalar fields on a 2 x 5 grid and save the figure.
sample_center = torch.from_numpy(coord_field).type(torch.float32)
iters = 100
std = 0.7
sample_points, sample_features = get_edf_sample(feature_se3T=feature_se3T, pos_se3T=pos_se3T,
                                                sample_center=sample_center, iters=iters, std=std)

ranges_pretty = np.array([[-20., 20.], [-20., 20.], [-40., 0.]])

fig, ax = plt.subplots(2, 5, figsize=(20 * 5, 20 * 2), subplot_kw={'projection': '3d'})
fig.tight_layout()
for i in range(2):
    for j in tqdm(range(5)):
        ax[i, j].axis('off')
        # Channel 5*i+j, replicated into three identical RGB columns.
        scatter_plot_ax(
            ax[i, j],
            sample_points,
            append_alpha(normalize_scalar_feature(scalar_val=np.repeat(sample_features[..., 5 * i + j:5 * i + j + 1], 3, axis=-1)), alpha=0.3),
            ranges_pretty,
            frame_infos=[],
            transform='default',
        )
fig.savefig(f'10feat_{mug_pose}_seed{seed}', transparent=True)

### Colormap

In [None]:
# Densely sample the field around the d_field cloud and colour each sample by three
# scalar channels (3, 1, 7 -> R, G, B).
sample_center = torch.from_numpy(coord_field).type(torch.float32)
iters = 500
std = 0.5
sample_points, sample_features = get_edf_sample(feature_se3T=feature_se3T, pos_se3T=pos_se3T,
                                                sample_center=sample_center, iters=iters, std=std)

feature_indices = [3, 1, 7]
scalar_features = sample_features[:, feature_indices]
sample_colors = append_alpha(normalize_scalar_feature(scalar_val=scalar_features), alpha=0.1)
# Per-channel remap: G is squared and scaled to 80%; R, B and alpha are unchanged.
sample_colors_remap = sample_colors ** np.array([1., 2., 1., 1.]) * np.array([1., 0.8, 1., 1.])

In [None]:
# Plot the remapped colormap samples, rotated 90 degrees about z, and save the figure.
theta = 90
T = (get_rot([0, 0, theta]), np.array([0., 0, 0.]))

ranges_pretty = np.array([[-20., 20.], [-20., 20.], [-40., 0.]])

fig, ax = plt.subplots(1, 1, figsize=(20, 20), subplot_kw={'projection': '3d'})
fig.tight_layout()
ax.axis('off')
scatter_plot_ax(ax, transform_coord(sample_points, T), sample_colors_remap, ranges_pretty,
                frame_infos=[], transform='default')
# Unrotated samples and colours, reused by the combined vector/colour figure.
scalar_field = (sample_points, sample_colors_remap)

fig.savefig(f'colormap_{mug_pose}_seed{seed}', transparent=True)

### Merged colormap

In [None]:
# sample_center = torch.from_numpy(coord_field).type(torch.float32)
# iters=500
# std=0.7
# sample_points, sample_features = get_edf_sample(feature_se3T=feature_se3T, pos_se3T=pos_se3T, sample_center=sample_center, iters=iters, std=std)

# feature_indices = [3,1,7]
# scalar_features = sample_features[:,[*feature_indices]]
# sample_colors = append_alpha(normalize_scalar_feature(scalar_val=scalar_features), alpha=0.1)
# sample_colors_remap = sample_colors ** np.array([1., 2., 1., 1.]) * np.array([1., 0.8, 1., 1.]) 

# coord_full, color_full = append_pc([coord_dense, color_alpha_dense, (get_rot([0,0,0]),np.array([0.,0.,0.]))  , 0.9],
#                                    [sample_points, sample_colors_remap, (get_rot([0,0,0]),np.array([0.,0.,0.]))  ,0.02],)

# fig,ax = plt.subplots(1,1, figsize=(20,20), subplot_kw={'projection':'3d'})
# ax.axis('off')
# scatter_plot_ax(ax, coord_full, color_full, ranges_field, frame_infos = [], transform='default')

### Vector Fields

In [None]:
def vec_criteria(p, norm, center, r_min=None, r_max=None):
    """Decide whether the vector sampled at ``p`` should be drawn.

    A vector is kept only if its norm exceeds 0.5. If ``r_min`` and/or ``r_max`` are
    given, ``p`` must also lie within that distance range of ``center`` (both bounds
    inclusive). With the defaults, only the norm test is applied.

    Args:
        p: sample position, shape (3,).
        norm: norm of the vector sampled at ``p``.
        center: reference position, shape (3,).
        r_min, r_max: optional distance bounds from ``center``.

    Returns:
        True if the vector passes every active test, else False.
    """
    if not norm > .5:
        return False
    if r_min is None and r_max is None:
        return True

    distance = np.linalg.norm(p - center)
    if r_min is not None and distance < r_min:
        return False
    if r_max is not None and distance > r_max:
        return False
    return True

In [None]:
# Query points on a circle (radius in `rs`) around the mug rim. The mug pose is
# re-centred on task.center, scaled by 100 and shifted by z_offset.
# NOTE(review): presumably this matches the frame produced by `preprocess` -- confirm.
z_offset = -20
rim_offset = np.array([0., 0., 7.5])
X_sm, R_sm = task.get_state()['T_sm']
X_sm = (X_sm - task.center) * 100 + np.array([0., 0., z_offset])
rim_center = X_sm + (rim_offset @ R_sm.T)
criteria = lambda p, norm: vec_criteria(p, norm, rim_center)

n = 20
rs = [5.9]
if mug_pose == 'upright':
    thetas = np.pi * 2 * np.linspace(0, 1, n + 1)[:-1]  # full circle
elif mug_pose == 'lying':
    thetas = np.pi * np.linspace(-0.1, 1., n + 1)[:-1]  # partial arc
else:
    # Otherwise `thetas` would be undefined, or silently stale from an earlier run.
    raise ValueError(f"unsupported mug_pose {mug_pose!r}; expected 'upright' or 'lying'")

sample_points = np.concatenate(
    [rim_center + np.stack([np.cos(thetas) * r, np.sin(thetas) * r, np.zeros(n)], axis=-1) @ R_sm.T
     for r in rs],
    axis=0,
)

In [None]:
# Evaluate the field exactly at the rim points (iters=1 -> no perturbation) and keep
# the 3-vector for channel group `feature_idx` wherever `criteria` accepts it.
iters = 1
std = 0.5
feature_idx = 1  # alternative: 9
sample_points, sample_features = get_edf_sample(feature_se3T=feature_se3T, pos_se3T=pos_se3T,
                                                sample_center=torch.from_numpy(sample_points).type(torch.float32),
                                                iters=iters, std=std)
# NOTE(review): channels appear to be 10 scalars followed by 3-vectors -- confirm against the model.
vector_feature = sample_features[..., 10 + 3 * feature_idx:13 + 3 * feature_idx]
vec_field = get_vector_field(sample_points=sample_points, vector_feature=vector_feature, criteria=criteria)


In [None]:
# Draw the rim vector field (light-grey arrows) over the dense pick cloud, both
# rotated 90 degrees about z, and save the figure.
size, length = 2., 4.

theta = 90
T = (get_rot([0, 0, theta]), np.array([0., 0, 0.]))
arrowed = draw_vector_field(transform_vec(vec_field, T), size=size, length=length,
                            color=np.array([0.8, 0.8, 0.8, 1.]))

ranges_pretty = np.array([[-20., 20.], [-20., 20.], [-40., 0.]])
fig, ax = plt.subplots(1, 1, figsize=(20, 20), subplot_kw={'projection': '3d'})
ax.axis('off')
# The arrows are already rotated, so they are appended with T=None (identity).
scatter_plot_ax(ax, *append_pc((coord_dense, color_alpha_dense, T, 1.), (*arrowed, None, 0.3)),
                ranges_pretty, frame_infos=[], transform='default')
fig.savefig(f'vec_{mug_pose}_seed{seed}', transparent=True)

In [None]:
# mug_center = (task.get_state()['T_sm'][0] - task.center)* 100

# sample_points = torch.stack(torch.meshgrid(torch.linspace(-20,20,50),torch.linspace(-20,20,50),torch.linspace(-20,0,2)), dim=-1).reshape(-1,3)
# iters=1
# std=0.5
# sample_points, sample_features = get_edf_sample(feature_se3T=feature_se3T, pos_se3T=pos_se3T, sample_center=sample_points, iters=iters, std=std)



# feature_idx = 9

# #arrowed = (coord_dense, color_alpha_dense)
# arrowed = None
# for p, f in zip(sample_points, sample_features):
#     vec = f[...,10+feature_idx*3:10+feature_idx*3+3]
#     norm = np.linalg.norm(vec)


#     distance = np.linalg.norm((p-mug_center)[:2])
#     if distance < 5. or distance > 6.:
#         continue

#     if norm > .5:
#         vec = vec/norm
#         begin = p
#         #end = begin + vec * 3. * norm
#         end = begin + vec * 4.
#         arrowed = draw_arrow(arrowed, begin=begin, end=end, arrow_color=np.array([0, 0., 1.,1.]), N=10, arrowhead_size=2.)







# #ranges_pretty = np.array([[-20,20],[-20,20],[-40, -0]])
# ranges_pretty = np.array([[-15,15],[-15,15],[-40, -10]])
# #ranges_pretty = np.array([[-10,10],[-10,10],[-40, -20]])
# #ranges_pretty = ranges_dense

# ranges_pretty = np.array([[-15,15],[-15,15],[-30, -0]])

# fig,ax = plt.subplots(1,1, figsize=(20,20), subplot_kw={'projection':'3d'})
# ax.axis('off')
# scatter_plot_ax(ax, *append_pc((coord_dense, color_alpha_dense, (get_rot([0,0,120], [25,0,0]),np.array([0,0,0.])) , 1.), (*arrowed, (get_rot([0,0,120], [25,0,0]),np.array([0.,0,0.])) , 0.06)), ranges_pretty, frame_infos = [], transform='default')

## Vector field over colormap

In [None]:
# Overlay the rim arrows (alpha 0.2) on the colormap samples and save the figure.
# Both end up rotated 90 degrees about z: the samples via T here, the arrows
# already rotated when they were drawn.
theta = 90
T = (get_rot([0, 0, theta]), np.array([0., 0, 0.]))

ranges_pretty = np.array([[-20., 20.], [-20., 20.], [-40., 0.]])
fig, ax = plt.subplots(1, 1, figsize=(20, 20), subplot_kw={'projection': '3d'})
ax.axis('off')
scatter_plot_ax(ax, *append_pc((*scalar_field, T), (*arrowed, None, 0.2)),
                ranges_pretty, frame_infos=[], transform='default')
fig.savefig(f'featured_{mug_pose}_seed{seed}', transparent=True)

# Pick

In [None]:
# Sample pick poses with the pick agent on the current `inputs`. The MH / Langevin /
# optimisation iteration counts and temperatures come from the eval config.
N_transforms = N_transform_pick
mh_iter = mh_iter_pick
langevin_iter = langevin_iter_pick
# NOTE(review): T_seed is passed as an int -- presumably the number of random seed
# poses; confirm in PickAgent.forward.
T_seed = 1000
Ts, edf_outputs, logs = pick_agent.forward(
    inputs=inputs,
    T_seed=T_seed,
    policy=pick_policy,
    mh_iter=mh_iter,
    langevin_iter=langevin_iter,
    temperature=pick_dist_temp,
    policy_temperature=pick_policy_temp,
    optim_iter=optim_iter_pick,
    optim_lr=optim_lr_pick,
    resample=False,
)

In [None]:
# for T in Ts[:pick_attempt_max]:
#     pick_ik_success = pick(T)
#     if pick_ik_success:
#         break


# if not pick_ik_success:
#     print("Pick fail: Couldn't find IK solution", flush=True)

# if task.check_pick_success():
#     print("Pick success", flush=True)
# else:
#     print("Pick fail: Found IK solution but failed", flush=True)

In [None]:
# Retract the robot with the gripper value at 1. before the later observations.
task.retract_robot(gripper_val=1., IK_time=1., back=True)

# MCMC

In [None]:
# One small translucent blue blob per sampled pick pose (first 200), centred on the
# pose translation T[4:], overlaid on the dense pick cloud.
# `.detach().cpu()` is needed because the poses may live on the configured CUDA device.
samples = [
    (*draw_blob(T[4:].detach().cpu().numpy(), np.array([0., 0., 1., 1.]), std=0.8, shell=False, N=40), None, 0.08)
    for T in Ts[:200]
]

pc_mcmc = append_pc(pc_dense, *samples)

In [None]:
# samples = []
# for T in Ts[:200]:
#     X = T[4:].numpy()
#     samples.append(draw_blob(X, np.array([0.,0.,1.,1.]), std=0.2))

# pc_mcmc = append_pc(pc_dense, *samples)

In [None]:
# Plot the MCMC blobs over the dense pick cloud, rotated 90 degrees about z, and save the figure.
theta = 90
T = (get_rot([0, 0, theta]), np.array([0., 0, 0.]))

ranges_pretty = np.array([[-20., 20.], [-20., 20.], [-40., 0.]])

fig, ax = plt.subplots(1, 1, figsize=(20, 20), subplot_kw={'projection': '3d'})
fig.tight_layout()
ax.axis('off')
# NOTE(review): `ranges_pretty` is defined above, but this figure is framed with
# `ranges_dense` (the other field figures use `ranges_pretty`) -- confirm intent.
scatter_plot_ax(ax, transform_coord(pc_mcmc[0], T), pc_mcmc[1], ranges_dense,
                frame_infos=[], transform='default')
fig.savefig(f'mcmc_{mug_pose}_seed{seed}', transparent=True)

# Dense Image (Place)

In [None]:
# Build the place-stage inputs. Q is the cloud from observe_pointcloud_pick (the
# "query" side); K is the scene cloud from observe_pointcloud. Both are voxelised at d_dense.
# NOTE(review): this extends the `sample_unprocessed` dict left over from the last
# pick-observation cell (it still holds 'coord', 'color', 'range', 'images', 'center'),
# so the result depends on which cell ran last -- confirm whether
# preprocess(pick_and_place=True) relies on those keys.
pc_pick = task.observe_pointcloud_pick(stride = (1, 1))
sample_unprocessed['range_pick'] = pc_pick['ranges']
sample_unprocessed['pick_pose'] = (pc_pick['X_sg'], pc_pick['R_sg'])
sample_unprocessed['images_pick'] = task.observe_pick()
pc_place = task.observe_pointcloud(stride = (1, 1))
sample_unprocessed['range_place'] = pc_place['ranges']
sample_unprocessed['images_place'] = task.observe()

sample_unprocessed['coord_pick'], sample_unprocessed['color_pick'] = voxel_filter(pc_pick['coord'], pc_pick['color'], d=d_dense)
sample_unprocessed['coord_place'], sample_unprocessed['color_place'] = voxel_filter(pc_place['coord'], pc_place['color'], d=d_dense)
# vox = voxelize_sample({'coord_pick': pc_pick['coord'], 'color_pick': pc_pick['color'], 'd_pick': d_pick, 
#                        'coord_place': pc_place['coord'], 'color_place': pc_place['color'], 'd_place': d_place,}, coord_jitter=0.1, color_jitter=0.03, pick=False, place=True)
# sample_unprocessed['coord_pick'], sample_unprocessed['color_pick'], sample_unprocessed['coord_place'], sample_unprocessed['color_place'] = vox['coord_pick'], vox['color_pick'], vox['coord_place'], vox['color_place']

# Colours as they were before `preprocess`, for display.
color_unprocessed_Q = sample_unprocessed['color_pick']
color_unprocessed_K = sample_unprocessed['color_place']
sample = preprocess(sample_unprocessed, characteristic_length, pick_and_place=True)


##### Prepare input #####
coord_Q, color_Q, ranges_Q = sample['coord_Q'], sample['color_Q'], sample['ranges_Q']
data_transform_Q = sample['data_transform_Q']
coord_K, color_K, ranges_K = sample['coord_K'], sample['color_K'], sample['ranges_K']
# `place()` uses this transform to map poses back to the simulator frame.
data_transform_K = sample['data_transform_K']

# Q side: cropped with the place agent's query-side crop range.
feature_Q = torch.tensor(color_Q, dtype=torch.float32, device=device)
pos_Q = torch.tensor(coord_Q, dtype=torch.float32, device=device)
in_range_cropped_idx_Q = place_agent.crop_range_idx_Q(pos_Q)
pos_Q, feature_Q  = pos_Q[in_range_cropped_idx_Q], feature_Q[in_range_cropped_idx_Q]

# K side: cropped with the place agent's scene crop range.
feature_K = torch.tensor(color_K, dtype=torch.float32, device=device)
pos_K = torch.tensor(coord_K, dtype=torch.float32, device=device)
in_range_cropped_idx_K = place_agent.crop_range_idx(pos_K)
pos_K, feature_K = pos_K[in_range_cropped_idx_K], feature_K[in_range_cropped_idx_K]

inputs_Q = {'feature': feature_Q, 'pos': pos_Q, 'edge': None, 'max_neighbor_radius': place_agent.max_radius_Q}
inputs_K = {'feature': feature_K, 'pos': pos_K, 'edge': None, 'max_neighbor_radius': place_agent.max_radius}

# Cropped copies of both clouds for plotting.
visual_info_K_dense = {'coord':coord_K[in_range_cropped_idx_K.cpu()].copy(), 
                'color': color_unprocessed_K[in_range_cropped_idx_K.cpu()].copy(), 
                'ranges': ranges_K.copy(),
                #'ax': axes[1],
                'coord_query': coord_Q[in_range_cropped_idx_Q.cpu()].copy(),
                'color_query': color_unprocessed_Q[in_range_cropped_idx_Q.cpu()].copy(),
                'ranges_query': ranges_Q.copy(),
                #'ax_query': axes[2]
                }

In [None]:
# Unpack the cropped query (Q) cloud and give it opaque alpha.
coord_Q_dense, color_Q_dense, ranges_Q_dense = (visual_info_K_dense[key] for key in ('coord_query', 'color_query', 'ranges_query'))
color_alpha_Q_dense = append_alpha(color_Q_dense, 1.)

In [None]:
# Preview the query (Q) cloud with an identity transform, framed at 30% of its range.
T = (get_rot([0, 0, 0]), np.array([0., 0., 0.]))
coord_Q_rot = transform_coord(coord_Q_dense, T)

fig, ax = plt.subplots(1, 1, figsize=(13, 13), subplot_kw={'projection': '3d'})
ax.axis('off')
# Reuse the transformed coordinates instead of recomputing them.
scatter_plot_ax(ax, coord_Q_rot, color_alpha_Q_dense, ranges_Q_dense * 0.3, frame_infos=[], transform='default')

In [None]:
# An explicit identity view followed by 20 random orientations (drawn from the global
# numpy RNG). The identity is built here rather than taken from whatever `T` a
# previous cell left behind.
Ts_ = [(get_rot([0, 0, 0]), np.array([0., 0., 0.]))]
Ts_ += [
    (get_rot([360 * np.random.rand(), 180 * np.random.rand(), 360 * np.random.rand()]), np.array([0., 0., 0.]))
    for _ in range(20)
]

In [None]:
# Save the query cloud under each orientation in Ts_ as EE_<i>. Each figure is
# closed after saving, so the loop does not pile up open figures (matplotlib warns
# once more than 20 are open).
for i, T in enumerate(Ts_):
    fig, ax = plt.subplots(1, 1, figsize=(15, 15), subplot_kw={'projection': '3d'})
    fig.tight_layout()
    ax.axis('off')
    scatter_plot_ax(ax, transform_coord(coord_Q_dense, T), color_alpha_Q_dense, ranges_Q_dense * 0.4,
                    frame_infos=[], transform='default')
    fig.savefig(f'EE_{i}', transparent=True)
    plt.close(fig)