In [None]:
import time
import datetime
import os
import argparse
import random
import numpy as np
import matplotlib.pyplot as plt
import yaml
import gc
from tqdm import tqdm

import torch
from pytorch3d import transforms
from scipy.spatial.transform import Rotation

from edf.utils import preprocess, voxel_filter, binomial_test, voxelize_sample
from edf.agent import PickAgent, PlaceAgent
from edf.pybullet_env.env import MugTask, StickTask
from edf.visual_utils import scatter_plot_ax, draw_poincloud_arrow_new



# Seed every RNG used in this notebook so sampling is reproducible.
model_seed = 0
random.seed(model_seed)
np.random.seed(model_seed)
torch.manual_seed(model_seed)
# Deterministic kernels are NOT forced here: the setup cell enables them only when the
# `deterministic` flag from the configuration cell is True. Enabling them unconditionally
# at this point made that flag a no-op.
torch.set_printoptions(precision=4, sci_mode=False)

In [None]:
# --- Config and checkpoint locations ---
eval_config_dir = "config/eval_config/figure.yaml"
task_config_dir = "config/task_config/mug_task.yaml"
pick_agent_config_dir = "config/agent_config/pick_agent.yaml"
checkpoint_path_pick = "checkpoint/train_pick/model_iter_600.pt"
place_agent_config_dir = "config/agent_config/place_agent.yaml"
checkpoint_path_place = "checkpoint/train_place/model_iter_600.pt"

# --- Runtime switches ---
use_gui = False                          # passed to the task environment constructor
visualize_plot = True
save_plot = False
place_max_distance_plan = [0.05, 1.5]    # forwarded to task.place(..., max_distance_plan=...)
deterministic = True                     # gates torch.use_deterministic_algorithms in the setup cell
load_model = True                        # load the pick/place checkpoints into the agents

# Figure output directory ('whatever/' is a placeholder); cleared to None when save_plot is off.
plot_path = None if save_plot is False else 'whatever/'

In [None]:
# Voxel sizes in scene units (presumably metres — TODO confirm).
# d_dense is the voxel_filter size used for the dense figure point cloud.
d_dense, d_field = 0.001, 0.01

In [None]:
def pick(T, transform=None):
    """Try to execute a pick at end-effector pose ``T`` in the simulated task.

    ``T[..., :4]`` is a scalar-first quaternion (converted with pytorch3d's
    ``quaternion_to_matrix``) and ``T[..., 4:]`` the translation, both in the model's
    normalized frame. They are mapped back to the scene frame with
    ``transform.inv_transform_T``. A pre-pick pose with the same orientation is placed
    0.03 (scene units) along the gripper's local -z axis, and ``task.pick`` is called
    with (pre-pick, pick).

    Args:
        T: pose tensor (quaternion followed by translation); presumably shape (7,) for
            a single pose — TODO confirm.
        transform: object providing ``inv_transform_T``. Defaults to the global
            ``data_transform`` produced by ``preprocess`` in the observation cell.

    Returns:
        True if ``task.pick`` succeeded, False if it raised ``StopIteration``
        (treated as an IK/planning failure).
    """
    if transform is None:
        transform = data_transform
    R, X = transforms.quaternion_to_matrix(T[..., :4]), T[..., 4:]
    X_sdg, R_sdg = transform.inv_transform_T(X.detach().cpu().numpy(), R.detach().cpu().numpy())

    # Pre-pick pose: same orientation, offset 0.03 along the gripper's local -z axis
    # (offset expressed in the gripper frame, rotated into the scene frame).
    R_dg_dgpre = np.eye(3)
    R_s_dgpre = R_sdg @ R_dg_dgpre
    X_dg_dgpre = np.array([0., 0., -0.03])
    X_s_dgpre = X_sdg + R_sdg @ X_dg_dgpre

    pre_pick_pose = (X_s_dgpre, R_s_dgpre)
    pick_pose = (X_sdg, R_sdg)

    try:
        task.pick(pre_pick_pose, pick_pose)
    except StopIteration:
        return False
    print("Pick IK Success", flush=True)
    return True

def place(T, transform=None):
    """Try to execute a place at end-effector pose ``T`` in the simulated task.

    ``T[..., :4]`` is a scalar-first quaternion (pytorch3d convention) and
    ``T[..., 4:]`` the translation, both in the model's normalized frame. They are
    mapped back to the scene frame with ``transform.inv_transform_T``. A pre-place pose
    with the same orientation is placed 0.03 (scene units) along the gripper's local -z
    axis. ``task.place`` is then called with the global ``place_max_distance_plan``.

    Args:
        T: pose tensor (quaternion followed by translation); presumably shape (7,)
            — TODO confirm.
        transform: object providing ``inv_transform_T``. Defaults to the global
            ``data_transform_K``.

    Returns:
        True if ``task.place`` succeeded, False if it raised ``StopIteration``
        (treated as an IK/planning failure).

    NOTE(review): ``data_transform_K`` is not assigned anywhere in this notebook, so
    ``place(T)`` without ``transform`` raises NameError unless that global is created
    elsewhere. Pass the place-scene transform explicitly.
    """
    if transform is None:
        transform = data_transform_K
    R, X = transforms.quaternion_to_matrix(T[..., :4]), T[..., 4:]
    X_sdg, R_sdg = transform.inv_transform_T(X.detach().cpu().numpy(), R.detach().cpu().numpy())

    # Pre-place pose: same orientation, offset 0.03 along the gripper's local -z axis.
    R_dg_dgpre = np.eye(3)
    R_s_dgpre = R_sdg @ R_dg_dgpre
    X_dg_dgpre = np.array([0., 0., -0.03])
    X_s_dgpre = X_sdg + R_sdg @ X_dg_dgpre

    pre_place_pose = (X_s_dgpre, R_s_dgpre)
    place_pose = (X_sdg, R_sdg)

    try:
        task.place(pre_place_pose, place_pose, max_distance_plan=place_max_distance_plan)
    except StopIteration:
        return False
    print("Place IK Success", flush=True)
    return True



##### Load eval config #####
# FullLoader kept as before; assumes these are local, trusted config files.
with open(eval_config_dir) as file:
    eval_config = yaml.load(file, Loader=yaml.FullLoader)
device = eval_config['device']
characteristic_length = eval_config['characteristic_length']
plot_figsize = eval_config['plot_figsize']
plot_result_figsize = eval_config['plot_result_figsize']
pick_only = eval_config['pick_only']

# Pick sampling hyperparameters
pick_policy = eval_config['pick_policy']
pick_dist_temp = eval_config['pick_dist_temp']
pick_policy_temp = eval_config['pick_policy_temp']
pick_attempt_max = eval_config['pick_attempt_max']
N_transform_pick = eval_config['N_transform_pick']
mh_iter_pick = eval_config['mh_iter_pick']
langevin_iter_pick = eval_config['langevin_iter_pick']
optim_iter_pick = eval_config['optim_iter_pick']
langevin_dt_pick = eval_config['langevin_dt_pick']
optim_lr_pick = eval_config['optim_lr_pick']
X_seed_mean_pick = eval_config['X_seed_mean_pick']
X_seed_std_pick = eval_config['X_seed_std_pick']
max_N_query_pick = eval_config['max_N_query_pick']

# Place sampling hyperparameters
place_policy = eval_config['place_policy']
place_dist_temp = eval_config['place_dist_temp']
place_policy_temp = eval_config['place_policy_temp']
place_attempt_max = eval_config['place_attempt_max']
N_transform_place = eval_config['N_transform_place']
mh_iter_place = eval_config['mh_iter_place']
langevin_iter_place = eval_config['langevin_iter_place']
optim_iter_place = eval_config['optim_iter_place']
langevin_dt_place = eval_config['langevin_dt_place']
optim_lr_place = eval_config['optim_lr_place']
X_seed_mean_place = eval_config['X_seed_mean_place']
X_seed_std_place = eval_config['X_seed_std_place']
# NOTE(review): hard-coded override of eval_config['max_N_query_place'];
# confirm 5 is still intended.
max_N_query_place = 5
query_temp_place = eval_config['query_temp_place']

schedules = eval_config['schedules']

##### Load task config #####
with open(task_config_dir) as file:
    task_config = yaml.load(file, Loader=yaml.FullLoader)
sleep = task_config['sleep']
d = task_config['d']
d_pick = task_config['d_pick']
d_place = task_config['d_place']
# `config` stays bound to the task config, as before, for any later cell that reads it.
config = task_config

##### Seeding (right before the agents are constructed) #####
model_seed = 0
random.seed(model_seed)
np.random.seed(model_seed)
torch.manual_seed(model_seed)
if deterministic:
    torch.use_deterministic_algorithms(True)
torch.set_printoptions(precision=4, sci_mode=False)

##### Load agent models #####
# Agents are always constructed because later cells use pick_agent unconditionally
# (e.g. crop_range_idx in the observation cell). Checkpoints are loaded only when
# load_model is set.
pick_agent = PickAgent(config_dir=pick_agent_config_dir, device=device, max_N_query=max_N_query_pick, langevin_dt=langevin_dt_pick).requires_grad_(False)
place_agent = PlaceAgent(config_dir=place_agent_config_dir, device=device, max_N_query=max_N_query_place, langevin_dt=langevin_dt_place).requires_grad_(False)
if load_model:
    pick_agent.load(checkpoint_path_pick)
    place_agent.load(checkpoint_path_place)

##### Initialize task env #####
# NOTE(review): MugTask is imported and task_config_dir points at mug_task.yaml, but a
# StickTask is created here. Confirm which environment this figure needs.
task = StickTask(use_gui=use_gui)

In [None]:
def lowvar_randn(N, d=3, max_norm=2., max_rounds=1000):
    """Draw ``N`` standard-normal samples in ``d`` dimensions with norm < ``max_norm``.

    Uses rejection sampling in batches of ``5 * N`` candidates. Batches are drawn until
    ``N`` samples are accepted, so the result always has exactly ``N`` rows; a single
    batch can fall short, especially for larger ``d``, where few candidates land inside
    the ball. When the first batch suffices, the result equals the first ``N`` accepted
    rows of that batch.

    Args:
        N: number of samples to return.
        d: dimensionality of each sample.
        max_norm: samples with Euclidean norm >= max_norm are rejected.
        max_rounds: maximum number of candidate batches before giving up.

    Returns:
        Tensor of shape (N, d).

    Raises:
        RuntimeError: if fewer than ``N`` samples are accepted within ``max_rounds``
            batches (acceptance rate too low for the given ``d`` / ``max_norm``).
    """
    batches = []
    n_accepted = 0
    rounds = 0
    while n_accepted < N:
        if rounds == max_rounds:
            raise RuntimeError(
                f"lowvar_randn: only {n_accepted}/{N} samples accepted after {max_rounds} batches"
            )
        candidates = torch.randn(5 * N, d)
        kept = candidates[candidates.norm(dim=-1) < max_norm]
        batches.append(kept)
        n_accepted += len(kept)
        rounds += 1

    if not batches:
        return torch.empty(0, d)
    return torch.cat(batches, dim=0)[:N]

def append_alpha(color, alpha):
    """Turn an (N, 3) RGB array into (N, 4) RGBA by appending a constant alpha column.

    RGB values and ``alpha`` must lie in [0, 1]; otherwise an AssertionError is raised.
    """
    assert color.ndim == 2
    assert color.shape[-1] == 3
    assert color.max() <= 1. and color.min() >= 0.
    assert 0. <= alpha <= 1.

    n_points = color.shape[0]
    alpha_column = alpha * np.ones((n_points, 1))
    return np.hstack([color, alpha_column])

def get_rot(*rots):
    """Compose XYZ Euler rotations (in degrees) into one scalar-last (x, y, z, w) quaternion.

    Each successive rotation is left-multiplied onto the running result, i.e. it is
    applied after the previous ones. With no arguments the identity quaternion is returned.
    """
    composed = Rotation.from_euler('xyz', [0., 0., 0.], degrees=True)
    for euler_deg in rots:
        step = Rotation.from_euler('xyz', euler_deg, degrees=True)
        composed = step * composed
    return np.array(composed.as_quat())

def transform_coord(coord, T):
    """Rotate then translate an array of 3-D points (e.g. shape (N, 3)).

    ``T`` is ``(quat, x)``: ``quat`` is a scalar-last (x, y, z, w) quaternion of shape (4,),
    the layout produced by ``get_rot`` / scipy's ``as_quat``; ``x`` is a (3,) translation.
    """
    quat_xyzw, translation = T
    assert quat_xyzw.shape == (4,) and translation.shape == (3,)
    # pytorch3d expects scalar-first (w, x, y, z) quaternions.
    quat_wxyz = torch.from_numpy(np.array([quat_xyzw[-1], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]]))
    rotated = transforms.quaternion_apply(quat_wxyz, torch.from_numpy(coord)).numpy()
    return rotated + translation.reshape(-1, 3)


def append_pc(*pcs):
    """Concatenate several point clouds, applying an optional rigid transform to each.

    Each positional argument is a tuple of one of these forms:
        (coord, color)
        (coord, color, T)
        (coord, color, T, alpha)
    ``coord`` is (N, 3) and ``color`` is (N, 4) RGBA. ``T`` is ``(quat_xyzw, translation)``
    with shapes (4,) and (3,); None means identity. ``alpha``, when not None, overrides
    the alpha channel of that cloud. The caller's color array is copied first, so it is
    never modified.

    Returns:
        (coords, colors): stacked arrays of shape (sum N, 3) and (sum N, 4); empty
        (0, 3) / (0, 4) arrays when called with no point clouds.

    Raises:
        ValueError: if an entry does not have 2, 3 or 4 elements.
        AssertionError: on shape mismatches.
    """
    coords = []
    colors = []
    for pc in pcs:
        if len(pc) == 2:
            coord, color = pc
            T, alpha = None, None
        elif len(pc) == 3:
            coord, color, T = pc
            alpha = None
        elif len(pc) == 4:
            coord, color, T, alpha = pc
        else:
            raise ValueError(f"append_pc: expected a 2-, 3- or 4-tuple point cloud, got length {len(pc)}")

        if T is None:
            T = (get_rot([0, 0, 0]), np.array([0., 0., 0.]))
        if alpha is not None:
            # Copy so the override does not leak into the caller's array.
            color = color.copy()
            color[..., -1] = alpha

        assert len(coord.shape) == 2 and len(color.shape) == 2 and len(T) == 2
        assert coord.shape[-1] == 3 and color.shape[-1] == 4 and coord.shape[0] == color.shape[0]
        assert T[0].shape == (4,) and T[1].shape == (3,)
        coords.append(transform_coord(coord, T))
        colors.append(color)

    if not coords:
        return np.zeros((0, 3)), np.zeros((0, 4))
    return np.concatenate(coords, axis=0), np.concatenate(colors, axis=0)

def normalize_color_range(color, elementwise=True):
    """Min-max rescale the first three columns of a 2-D array into (0, 1], in place.

    With ``elementwise`` each of the three columns is scaled independently; otherwise
    one global min/max over all three is used. A 1e-6 shift keeps the minimum strictly
    positive and avoids division by zero. Columns beyond the third (e.g. alpha) are left
    untouched. The same (mutated) array is returned.
    """
    assert len(color.shape) == 2
    axis = 0 if elementwise else None
    rgb = color[..., :3]
    color[..., :3] = rgb - rgb.min(axis=axis, keepdims=True) + 1e-6
    rgb = color[..., :3]
    color[..., :3] = rgb / rgb.max(axis=axis, keepdims=True)
    return color

def normalize_scalar_feature(scalar_val, elementwise=True):
    """Map a feature array to a displayable [0, 1] color range.

    With ``elementwise`` the array is first divided by its per-column standard deviation,
    which produces a new array. It is then passed to ``normalize_color_range`` for
    per-column min-max scaling of the first three columns. Otherwise no std scaling
    happens, and ``normalize_color_range`` rescales the caller's array in place with a
    global min/max.
    """
    scaled = scalar_val / scalar_val.std(axis=0, keepdims=True) if elementwise else scalar_val
    return normalize_color_range(scaled, elementwise=elementwise)

def get_edf_sample(feature_se3T, pos_se3T, sample_center, iters=10, std=0.3, agent=None):
    """Evaluate an agent's energy-model field at points sampled around ``sample_center``.

    For ``iters == 1`` the centers themselves are queried. Otherwise each of the
    ``iters`` rounds adds isotropic Gaussian noise (scale ``std``) to every center.
    Inputs are moved to CPU before querying ``energy_model.get_field_value``.

    Args:
        feature_se3T, pos_se3T: tensors describing the input point cloud.
        sample_center: tensor of query centers; assumes shape (M, 3) on CPU, since the
            noise is drawn on CPU — TODO confirm.
        iters: number of sampling rounds.
        std: noise scale.
        agent: agent whose ``energy_model`` is queried; defaults to the global ``pick_agent``.

    Returns:
        (sample_points, sample_features) as numpy arrays, concatenated over all rounds.
    """
    if agent is None:
        agent = pick_agent
    # Loop-invariant host copies of the input cloud.
    feature_cpu, pos_cpu = feature_se3T.cpu(), pos_se3T.cpu()

    sample_points = []
    sample_features = []
    for _ in tqdm(range(iters)):
        if iters == 1:
            sample_point = sample_center.clone()
        else:
            sample_point = sample_center + torch.randn(len(sample_center), 3) * std
        field_val, _ = agent.energy_model.get_field_value(feature=feature_cpu, pos=pos_cpu, query_points=sample_point.unsqueeze(0))
        sample_points.append(sample_point)
        sample_features.append(field_val.squeeze(0))

    sample_points = torch.cat(sample_points, dim=0).cpu().numpy()
    sample_features = torch.cat(sample_features, dim=0).cpu().numpy()
    return sample_points, sample_features


def get_edf_sample_K(feature_se3T, pos_se3T, sample_center, iters=10, std=0.3, shell = False, agent=None):
    """Evaluate an agent's energy-model field around ``sample_center`` (place scene).

    Sampling modes per round:
      * ``shell``: points at distance exactly ``2 * std`` from each center, in a random direction;
      * otherwise, ``iters == 1``: the centers themselves;
      * otherwise: centers plus isotropic Gaussian noise of scale ``std``.

    Args:
        feature_se3T, pos_se3T: tensors describing the input point cloud (moved to CPU).
        sample_center: tensor of query centers; assumes shape (M, 3) on CPU — TODO confirm.
        iters: number of sampling rounds.
        std: noise scale (the shell radius is ``2 * std``).
        shell: sample on a sphere shell instead of Gaussian noise.
        agent: agent whose ``energy_model`` is queried; defaults to the global ``place_agent``.

    Returns:
        (sample_points, sample_features) as numpy arrays, concatenated over all rounds.
    """
    if agent is None:
        agent = place_agent
    feature_cpu, pos_cpu = feature_se3T.cpu(), pos_se3T.cpu()

    sample_points = []
    sample_features = []
    for _ in tqdm(range(iters)):
        if shell:
            direction = torch.randn(len(sample_center), 3)
            direction = direction / direction.norm(dim=-1, keepdim=True)
            sample_point = sample_center + direction * std * 2
        elif iters == 1:
            sample_point = sample_center.clone()
        else:
            sample_point = sample_center + torch.randn(len(sample_center), 3) * std
        field_val, _ = agent.energy_model.get_field_value(feature=feature_cpu, pos=pos_cpu, query_points=sample_point.unsqueeze(0))
        sample_points.append(sample_point)
        sample_features.append(field_val.squeeze(0))

    sample_points = torch.cat(sample_points, dim=0).cpu().numpy()
    sample_features = torch.cat(sample_features, dim=0).cpu().numpy()
    return sample_points, sample_features

def get_edf_sample_Q(feature_se3T, pos_se3T, sample_center, iters=10, std=0.3, agent=None):
    """Evaluate an agent's query-model features at points sampled around ``sample_center``.

    For ``iters == 1`` the centers themselves are queried. Otherwise each round adds
    isotropic Gaussian noise (scale ``std``) to every center. Query points are moved to
    the device of ``pos_se3T`` rather than a hard-coded ``'cuda'``, so CPU-only runs
    work too.

    Args:
        feature_se3T, pos_se3T: tensors describing the input point cloud, passed as-is.
        sample_center: tensor of query centers; assumes shape (M, 3) on CPU — TODO confirm.
        iters: number of sampling rounds.
        std: noise scale.
        agent: agent whose ``query_model`` is queried; defaults to the global ``place_agent``.

    Returns:
        (sample_points, sample_features) as numpy arrays, concatenated over all rounds.
    """
    if agent is None:
        agent = place_agent
    query_device = pos_se3T.device

    sample_points = []
    sample_features = []
    for _ in tqdm(range(iters)):
        if iters == 1:
            sample_point = sample_center.clone()
        else:
            sample_point = sample_center + torch.randn(len(sample_center), 3) * std
        query_points = sample_point.unsqueeze(0).to(query_device)
        sample_feature = agent.query_model.get_feature(feature=feature_se3T, pos=pos_se3T, query_points=query_points).squeeze(0).cpu()
        sample_points.append(sample_point)
        sample_features.append(sample_feature)

    sample_points = torch.cat(sample_points, dim=0).cpu().numpy()
    sample_features = torch.cat(sample_features, dim=0).cpu().numpy()
    return sample_points, sample_features

def draw_arrow(pc, begin, end, arrow_color, arrowhead_size = 1., density=10, view_normal = None):
    """Render an arrow from ``begin`` to ``end`` as points, optionally merged into ``pc``.

    Args:
        pc: ``(coord, color)`` point cloud to append the arrow to, or None to return the
            arrow on its own.
        begin, end: arrow start and tip positions.
        arrow_color: RGBA color (length 4) for the arrow points.
        arrowhead_size, density: forwarded to ``draw_poincloud_arrow_new``.
        view_normal: forwarded to ``draw_poincloud_arrow_new`` on both paths. When ``pc``
            is None and ``view_normal`` is None, the helper's own default applies.

    Returns:
        When ``pc`` is given, the merged ``(coord, color)`` from ``append_pc``. Otherwise
        the output of ``draw_poincloud_arrow_new``, presumably a ``(coord, color)`` pair.
    """
    assert len(arrow_color) == 4
    arrow_kwargs = dict(begin=begin, end=end, color=arrow_color,
                        arrowhead_size=arrowhead_size, density=density)
    if pc is None:
        # view_normal used to be dropped on this path, so the first arrow drawn by
        # draw_vector_field (append_to=None) ignored the requested view direction.
        if view_normal is not None:
            arrow_kwargs['view_normal'] = view_normal
        return draw_poincloud_arrow_new(**arrow_kwargs)

    coord, color = pc
    arrow = draw_poincloud_arrow_new(view_normal=view_normal, **arrow_kwargs)
    return append_pc((coord, color), arrow)

def draw_blob(q, color, std, N=100):
    """Sample ``N`` points on a sphere of radius ~``std`` around ``q``, all with one RGBA color.

    Despite the name, the points lie on the sphere's surface: random Gaussian directions
    are normalized (with a 1e-5 guard against zero length) before scaling by ``std``.

    Returns:
        (coord, colors): arrays of shape (N, 3) and (N, 4).
    """
    assert color.shape == (4,)
    directions = torch.randn(N, 3)
    directions = directions / (1e-5 + directions.norm(dim=-1, keepdim=True))
    coord = directions.numpy() * std + q
    colors = np.tile(color.reshape(1, -1), (len(coord), 1))
    return coord, colors

def get_vector_field(sample_points, vector_feature, criteria=lambda p, norm: norm > .5):
    """Keep the (point, vector) pairs whose vectors pass ``criteria``.

    ``criteria(point, norm)`` receives the sample point and the Euclidean norm of its
    vector. By default, pairs with norm > 0.5 are kept. Input order is preserved and
    the original point/vector objects are returned, not copies.
    """
    return [
        (point, vec)
        for point, vec in zip(sample_points, vector_feature)
        if criteria(point, np.linalg.norm(vec))
    ]

def draw_vector_field(vec_field, size = 1., length=4., append_to = None, view_normal = None, color = None):
    """Draw each (point, vector) pair as an arrow point cloud from ``point`` to ``point + vector * length``.

    Args:
        vec_field: iterable of (point, vector) numpy pairs, e.g. from ``get_vector_field``.
        size: arrowhead size passed to ``draw_arrow``.
        length: scale factor applied to each vector.
        append_to: optional ``(coord, color)`` cloud to draw onto; None starts a new cloud.
        view_normal: view direction for arrow rendering; defaults to (-1, 1, -1)/sqrt(3).
        color: RGBA arrow color; defaults to opaque blue [0, 0, 1, 1].

    Returns:
        The accumulated point cloud, or ``append_to`` unchanged if ``vec_field`` is empty.
    """
    if view_normal is None:
        view_normal = np.array([-1, 1, -1]) / np.sqrt(3)
    if color is None:
        # Fresh array per call instead of a module-level mutable default shared by all calls.
        color = np.array([0, 0., 1., 1.])

    arrowed = append_to
    for p, vec in vec_field:
        arrowed = draw_arrow(arrowed, begin=p, end=p + vec * length, arrow_color=color,
                             density=10, arrowhead_size=size, view_normal=view_normal)
    return arrowed

def transform_vec(vec_field, T):
    """Apply a rigid transform to a list of (point, vector) pairs.

    ``T`` is ``(quat, x)``: a scalar-last (x, y, z, w) quaternion of shape (4,) and a (3,)
    translation. Points are rotated then translated; vectors are only rotated.
    """
    quat_xyzw, translation = T
    assert quat_xyzw.shape == (4,) and translation.shape == (3,)
    # pytorch3d expects scalar-first (w, x, y, z) quaternions.
    quat_wxyz = torch.from_numpy(np.array([quat_xyzw[-1], quat_xyzw[0], quat_xyzw[1], quat_xyzw[2]]))

    def rotate(v):
        return transforms.quaternion_apply(quat_wxyz, torch.from_numpy(v)).numpy()

    return [(rotate(point) + translation, rotate(vec)) for point, vec in vec_field]

# Dense Image (Pick)

In [None]:
# Extra orbit cameras around task.center, appended after the task's own cam_configs.
# The first entry of each 'ypr' (presumably yaw, in degrees) is listed relative to 90 to
# mirror the original layout. Re-running this cell appends these cameras again.
extra_camera_yaws = (90 + 135 + 90, 90 + 135, 90, 90 + 60, 90 - 60)
for yaw in extra_camera_yaws:
    cam = {'target_pos': task.center,
           'distance': 0.5,
           'ypr': (yaw, -30, 0),
           'W': 480,
           'H': 360,
           'up': [0, 0, 1],
           'up_axis_idx': 2,
           'near': 0.01,
           'far': 100,
           'fov': 60,
           }
    task.cam_configs.append(cam)

In [None]:
# Scene to render.
seed = 27              # e.g. 1 for another layout
mug_pose = 'upright'   # alternative: 'lying'
mug_type = 'cup8'
distractor = True
use_support = True


##### Observe #####
task.reset(seed=seed, mug_pose=mug_pose, mug_type=mug_type, distractor=distractor, use_support=use_support)
pc = task.observe_pointcloud(stride=(1, 1))
sample_unprocessed = {}
# Voxel-filter with the dense voxel size d_dense rather than the task config's d.
sample_unprocessed['coord'], sample_unprocessed['color'] = voxel_filter(pc['coord'], pc['color'], d=d_dense)
sample_unprocessed['range'] = pc['ranges']
sample_unprocessed['images'] = task.observe()
sample_unprocessed['center'] = task.center
# Keep the colors as they were before `preprocess`, for plotting.
color_unprocessed = sample_unprocessed['color']
sample = preprocess(sample_unprocessed, characteristic_length)


##### Prepare Input #####
coord, color, ranges = sample['coord'], sample['color'], sample['ranges']
data_transform = sample['data_transform']
feature = torch.tensor(color, dtype=torch.float32, device=device)
pos = torch.tensor(coord, dtype=torch.float32, device=device)
# Indices of points kept by the pick agent's range crop.
in_range_cropped_idx = pick_agent.crop_range_idx(pos)
pos, feature = pos[in_range_cropped_idx], feature[in_range_cropped_idx]
inputs = {'feature': feature, 'pos': pos, 'edge': None, 'max_neighbor_radius': pick_agent.max_radius}


##### Dense visualization point cloud #####
crop_idx_cpu = in_range_cropped_idx.cpu()
visual_info_dense = {'coord': coord[crop_idx_cpu].copy(),
                     'color': color_unprocessed[crop_idx_cpu].copy(),
                     'ranges': ranges.copy(),
                     }
coord_dense, color_dense, ranges_dense = visual_info_dense['coord'], visual_info_dense['color'], visual_info_dense['ranges']
color_alpha_dense = append_alpha(color_dense, 1.)

# Rotate the cloud by theta degrees about z before plotting.
theta = 90
T = (get_rot([0, 0, theta]), np.array([0., 0, 0.]))

ranges_pretty = np.array([[-20., 20.],
                          [-20., 20.],
                          [-40., 0.]])

fig, ax = plt.subplots(1, 1, figsize=(20, 20), subplot_kw={'projection': '3d'})
fig.tight_layout()
ax.axis('off')
scatter_plot_ax(ax, transform_coord(coord_dense, T), color_alpha_dense, ranges_pretty, frame_infos=[], transform='default')
pc_dense = (coord_dense, color_alpha_dense)

# Honor the save_plot / plot_path switches (plot_path is None when saving is off). The
# file name follows the actual mug_pose instead of a hard-coded 'upright'.
if plot_path is not None:
    os.makedirs(plot_path, exist_ok=True)
    fig.savefig(os.path.join(plot_path, f'full_{mug_pose}_seed{seed}'), transparent=True)

In [None]:
# End of notebook.