In [1]:
import os

from edf.pc_utils import draw_geometry, voxel_filter
from edf.data import PointCloud, SE3, TargetPoseDemo, DemoSequence, DemoSeqDataset
from edf.preprocess import Rescale, NormalizeColor, Downsample
from edf.agent import PickAgent, PlaceAgent

import numpy as np
import yaml
import plotly as pl
import plotly.express as ple
import open3d as o3d

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_cluster import radius_graph
from torchvision.transforms import Compose

# Compact tensor printing for interactive inspection: 3 decimals, no scientific notation, wide lines.
torch.set_printoptions(linewidth=120, precision=3, sci_mode=False)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def _make_proc_fn(voxel_size):
    """Build a point-cloud preprocessing pipeline: voxel downsampling, then color normalization.

    `voxel_size` is presumably in the same coordinate units as the clouds it is applied to
    (the demos loaded below are rescaled by 1/unit_len).
    """
    return Compose([
        Downsample(voxel_size=voxel_size, coord_reduction="average"),
        NormalizeColor(color_mean=torch.tensor([0.5, 0.5, 0.5]),
                       color_std=torch.tensor([0.5, 0.5, 0.5])),
    ])


# The scene is downsampled more coarsely than the grasped object.
scene_proc_fn = _make_proc_fn(voxel_size=1.7)
grasp_proc_fn = _make_proc_fn(voxel_size=1.4)

In [3]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Points are multiplied by 1/unit_len on load (presumably metres -> centimetres; confirm against the data).
unit_len = 0.01

load_transforms = Compose([Rescale(rescale_factor=1/unit_len)])
trainset = DemoSeqDataset(dataset_dir="demo/test_demo1", annotation_file="data.yaml", load_transforms=load_transforms, device=device)
# Identity collate: each batch is the plain list of dataset items (dicts with 'raw' / 'processed' sequences).
train_dataloader = DataLoader(trainset, shuffle=False, collate_fn=lambda x: x)

In [4]:
# Take the first demo sequence from the loader. Raise explicitly if the loader is empty instead of
# silently leaving `demo_seq_raw` / `demo_seq` undefined (or stale from an earlier run).
first_batch = next(iter(train_dataloader), None)
if not first_batch:
    raise RuntimeError("train_dataloader yielded no demo sequences")
data = first_batch[0]
demo_seq_raw: DemoSequence = data['raw']
demo_seq: DemoSequence = data['processed']

In [5]:
# Step 1 of the raw demo-1 sequence (presumably the place step, given the PlaceAgent analysis
# later on — TODO confirm), run through the scene / grasp preprocessing pipelines.
step_demo1 = demo_seq_raw[1]
scene_raw1, grasp_raw1 = step_demo1.scene_pc, step_demo1.grasp_pc
scene_proc1 = scene_proc_fn(scene_raw1)
grasp_proc1 = grasp_proc_fn(grasp_raw1)

In [6]:
# Second demo, loaded with the same `device`, `unit_len` and `load_transforms` as the first demo.
# They are deliberately not redefined here, so both sequences share one rescaling.
trainset = DemoSeqDataset(dataset_dir="demo/test_demo2", annotation_file="data.yaml", load_transforms=load_transforms, device=device)
# Identity collate: each batch is the plain list of dataset items.
train_dataloader = DataLoader(trainset, shuffle=False, collate_fn=lambda x: x)

In [7]:
# Take the first demo sequence from the demo-2 loader. Raise explicitly if it is empty; otherwise
# `demo_seq_raw` / `demo_seq` would silently keep demo 1's values from the earlier cell.
first_batch = next(iter(train_dataloader), None)
if not first_batch:
    raise RuntimeError("train_dataloader yielded no demo sequences")
data = first_batch[0]
demo_seq_raw: DemoSequence = data['raw']
demo_seq: DemoSequence = data['processed']

In [8]:
# Step 1 of the raw demo-2 sequence (same index as for demo 1), run through the preprocessing pipelines.
step_demo2 = demo_seq_raw[1]
scene_raw2, grasp_raw2 = step_demo2.scene_pc, step_demo2.grasp_pc
scene_proc2 = scene_proc_fn(scene_raw2)
grasp_proc2 = grasp_proc_fn(grasp_raw2)

In [10]:
# Placement agent: architecture from the YAML config, weights from `place_agent_param_dir`.
place_agent_config_dir = "config/agent_config/place_agent.yaml"
place_agent_param_dir = "checkpoint/mug_10_demo/place/model_iter_600.pt"
max_N_query_place = 3       # passed as max_N_query (presumably the cap on kept query points)
langevin_dt_place = 0.001   # passed as langevin_dt

place_agent = PlaceAgent(
    config_dir=place_agent_config_dir,
    device=device,
    max_N_query=max_N_query_place,
    langevin_dt=langevin_dt_place,
)

# strict=False presumably tolerates checkpoint/model key mismatches; the following cells compare
# the checkpoint against the loaded model to see what actually differs.
place_agent.load(place_agent_param_dir, strict=False)

In [12]:
# Load the checkpoint the agent was loaded from, onto CPU, for a key-by-key comparison.
# Reuses `place_agent_param_dir` instead of repeating the path so the two cannot drift apart.
# NOTE(review): torch.load unpickles arbitrary objects — only use it on trusted checkpoints.
state_dict = torch.load(place_agent_param_dir, map_location='cpu')

In [15]:
# Snapshot of the query model's state as currently loaded in the agent.
query_model = place_agent.query_model
current_dict = query_model.state_dict()

In [35]:
def _resolve_dotted_attr(root, dotted_name):
    """Follow a dotted state-dict key such as 'a.b.c' through attributes, starting at `root`."""
    obj = root
    for attr in dotted_name.split('.'):
        obj = getattr(obj, attr)
    return obj


# Print the name of every checkpoint entry whose value differs from the query model's current one.
# Entries missing from `current_dict` are reported with "No such key" and then looked up directly
# as attributes on the query model (presumably buffers that state_dict() does not include).
# The loop variables keep the names `k` / `v` because a later cell displays the leaked `v`.
for k, v in state_dict['query_model_state_dict'].items():
    if k in current_dict:
        current = current_dict[k]
    else:
        print(f"No such key: {k}")
        current = _resolve_dotted_attr(place_agent.query_model, k)
    # Compare on CPU. Any shape or dtype difference counts as a mismatch up front, because
    # torch.allclose would otherwise broadcast differing shapes or raise on dtype/shape mismatch.
    current_cpu, saved_cpu = current.cpu(), v.cpu()
    if (current_cpu.shape != saved_cpu.shape
            or current_cpu.dtype != saved_cpu.dtype
            or not torch.allclose(current_cpu, saved_cpu)):
        print(k)

No such key: weight_field.layernorm.scatter_index_elementwise
No such key: weight_field.layernorm.scatter_index_irrepwise
No such key: weight_field.layernorm.counts_elementwise
No such key: weight_field.layernorm.counts_irrepwise
No such key: weight_field.layernorm.counts_irrepwise_unbiased
No such key: weight_field.layernorm.scalar_mask_elementwise
No such key: feature_field.layernorm.scatter_index_elementwise
No such key: feature_field.layernorm.scatter_index_irrepwise
No such key: feature_field.layernorm.counts_elementwise
No such key: feature_field.layernorm.counts_irrepwise
No such key: feature_field.layernorm.counts_irrepwise_unbiased
No such key: feature_field.layernorm.scalar_mask_elementwise


In [33]:
# NOTE(review): `v` is the loop variable leaked from the checkpoint-comparison cell, i.e. the last
# checkpoint tensor that loop visited; this only works after that cell has run. The saved output
# shows a CUDA tensor although the checkpoint is loaded with map_location='cpu', so it was
# presumably produced by an earlier, out-of-order execution.
v

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')

In [26]:
# Inspect the query model's weight-field submodule (plain attribute access rather than calling __getattr__).
place_agent.query_model.weight_field

TensorFieldLayerJIT(
  (tp): FullyConnectedTensorProduct(10x0e+10x1e+4x2e+2x3e x 1x0e+1x1e+1x2e+1x3e -> 10x0e | 260 paths | 260 weights)
  (fc): FullyConnectedNet[10, 16, 260]
  (linear_out): LinearLayerJIT(
    (layers): Sequential(
      (0): Linear(10x0e -> 10x0e | 100 weights)
    )
  )
  (layernorm): EquivLayerNormJIT()
)

In [None]:
# Processed grasp clouds with flat colors for overlay: demo 1 in red, demo 2 in green.
# Colors are built with new_tensor + expand_as (matching the existing colors' dtype and device)
# instead of `colors * 0 + rgb`, which would propagate NaN/inf and hard-codes the global `device`.
g1 = PointCloud(grasp_proc1.points,
                grasp_proc1.colors.new_tensor([1., 0., 0.]).expand_as(grasp_proc1.colors).clone())
g2 = PointCloud(grasp_proc2.points,
                grasp_proc2.colors.new_tensor([0., 1., 0.]).expand_as(grasp_proc2.colors).clone())

In [None]:
# Both recolored grasp clouds in one viewer.
grasp_overlay = [g1, g2]
draw_geometry(grasp_overlay)

In [None]:
# Processed scene clouds with flat colors for overlay: demo 1 in red, demo 2 in green.
# Demo 1 is shifted by 0.01 along z. In the rescaled units that is a very small nudge, presumably
# to keep coincident points from overlapping exactly — TODO confirm intent.
# Colors use new_tensor + expand_as (dtype/device of the existing colors) rather than `colors * 0 + rgb`.
s1 = PointCloud(scene_proc1.points + scene_proc1.points.new_tensor([0., 0., 0.01]),
                scene_proc1.colors.new_tensor([1., 0., 0.]).expand_as(scene_proc1.colors).clone())
s2 = PointCloud(scene_proc2.points,
                scene_proc2.colors.new_tensor([0., 1., 0.]).expand_as(scene_proc2.colors).clone())

In [None]:
# Both recolored scene clouds in one viewer.
scene_overlay = [s1, s2]
draw_geometry(scene_overlay)

In [None]:
def draw_points(points, colors='blue', radius=0.5):
    """Build one Open3D sphere mesh per point, for overlaying points in `draw_geometry`.

    Args:
        points: tensor of shape (N, 3) with sphere centres (any device; may require grad).
        colors: 'blue' for the default teal color, any other string for red (the original
            convention), or an explicit RGB triple with components in [0, 1].
        radius: sphere radius, in the same units as `points` (default 0.5).

    Returns:
        A list of N `o3d.geometry.TriangleMesh` spheres, in the order of `points`.
    """
    if isinstance(colors, str):
        rgb = [0.1, 0.7, 0.7] if colors == 'blue' else [0.7, 0.1, 0.1]
    else:
        rgb = [float(c) for c in colors]

    spheres = []
    # detach() so tensors that still require grad can be converted; tolist() yields plain floats.
    for center in points.detach().cpu().tolist():
        sphere = o3d.geometry.TriangleMesh.create_sphere(radius=radius)
        sphere.compute_vertex_normals()
        sphere.paint_uniform_color(rgb)
        sphere.translate(center)
        spheres.append(sphere)
    return spheres

In [None]:
# Hyperparameters for PlaceAgent.forward (keyword each one is passed as, where it differs).
T_seed = 100
place_policy = 'sorted'          # policy
place_mh_iter = 1000             # mh_iter (presumably Metropolis-Hastings steps)
place_langevin_iter = 300        # langevin_iter
place_optim_iter = 100           # optim_iter
place_optim_lr = 0.005           # optim_lr
place_dist_temp = 1.             # temperature
place_policy_temp = 1.           # policy_temperature
place_query_temp = 1.            # query_temperature

In [None]:
# Settings for the query model's stein_vgd refinement of query points (iterations, step size).
stein_lr = 0.1
stein_iter = 100

# Demo 1 (`demo/test_demo1`): placement sampling and query-point analysis

In [None]:
# Sample placement poses for demo 1's processed scene and grasp with the shared hyperparameters.
place_forward_kwargs1 = dict(
    T_seed=T_seed, policy=place_policy, mh_iter=place_mh_iter, langevin_iter=place_langevin_iter,
    temperature=place_dist_temp, policy_temperature=place_policy_temp,
    optim_iter=place_optim_iter, optim_lr=place_optim_lr, query_temperature=place_query_temp,
)
Ts1, edf_outputs1, logs1 = place_agent.forward(scene=scene_proc1, grasp=grasp_proc1, **place_forward_kwargs1)

In [None]:
# Query points reported in demo 1's forward() outputs.
edf_outputs1['query_points']

In [None]:
# Demo-1 scene with the grasp cloud moved by the first sampled pose.
first_pose1 = Ts1[0].to(device)
draw_geometry([scene_proc1, grasp_proc1.transformed(first_pose1)])

In [None]:
# Query-model inputs for demo 1: the (normalized) colors as per-point features, the points as positions.
inputs_Q1 = dict(feature=grasp_proc1.colors, pos=grasp_proc1.points)

In [None]:
# Reference result: the query model's own query computation for demo 1. The following cells appear
# to re-derive its steps manually (se3T -> weights -> initial queries -> stein_vgd).
place_agent.query_model.get_query(inputs=inputs_Q1, temperature=1., requires_grad=False)

In [None]:
# Manual re-derivation of the initial query points for demo 1.
temperature = 1.

with torch.no_grad():
    # Run the query model's se3T stage on the grasp cloud; its output dict carries 'feature' and 'pos'.
    outputs1 = place_agent.query_model.se3T(inputs_Q1)
    feature_se3T1, pos1 = outputs1['feature'].detach(), outputs1['pos'].detach()

    # Connect every pair of points within query_radius; with loop=False a node has at most
    # num_nodes - 1 neighbours, so this cap never truncates the graph.
    num_nodes1 = pos1.shape[-2]
    max_num_neighbors1 = num_nodes1 - 1
    edge_src1, edge_dst1 = radius_graph(pos1, place_agent.query_model.query_radius,
                                        max_num_neighbors=max_num_neighbors1, loop=False)
    edge1 = (edge_src1, edge_dst1)

    # Weight logits evaluated at the points themselves seed the initial query positions.
    pos_weight_logit1 = place_agent.query_model.get_weight(
        feature=feature_se3T1, pos=pos1, query_points=pos1.unsqueeze(0), temperature=temperature,
    ).squeeze(0).squeeze(-1)
    query_points_init1 = place_agent.query_model.get_init_query_pos(
        pos=pos1, edge=edge1, weight_logit=pos_weight_logit1)  # (N_query, 3)

In [None]:
# Initial query points (default-colored spheres) on the demo-1 grasp cloud.
draw_geometry([grasp_proc1, *draw_points(query_points_init1)])

In [None]:
# Only the first two initial query points of demo 1.
first_two_init1 = query_points_init1[:2]
draw_geometry([grasp_proc1, *draw_points(first_two_init1)])

In [None]:
def log_P1(x):
    """Query-model weight logits at candidate points x (N_query, 3) for demo 1; the log-density for stein_vgd."""
    return place_agent.query_model.get_weight(
        feature=feature_se3T1, pos=pos1, query_points=x.unsqueeze(0), temperature=temperature,
    ).squeeze(0).squeeze(-1)  # (N_query,)


# No torch.no_grad() around log_P1: grad mode is checked when stein_vgd *calls* it, not where it is
# defined. A no_grad block around the definition would therefore have no effect, and stein_vgd
# presumably needs gradients of log_P1 with respect to x.
query_points1 = place_agent.query_model.stein_vgd(
    x=query_points_init1.detach(), log_P=log_P1, iters=stein_iter, lr=stein_lr)  # (N_query, 3)

In [None]:
# Refined query points (default color) against their initialization (red) on the demo-1 grasp cloud.
refined_spheres1 = draw_points(query_points1)
init_spheres1 = draw_points(query_points_init1, 'red')
draw_geometry([grasp_proc1] + refined_spheres1 + init_spheres1)

In [None]:
# Features and attention logits at demo 1's refined query points. Evaluated under no_grad: this is
# inspection only, so there is no reason to record an autograd graph through the model weights.
with torch.no_grad():
    query_feature1 = place_agent.query_model.get_feature(
        feature=feature_se3T1, pos=pos1, query_points=query_points1.unsqueeze(0)).squeeze(0)  # (N_query, f)
    query_attention1 = place_agent.query_model.get_weight(
        feature=feature_se3T1, pos=pos1, query_points=query_points1.unsqueeze(0),
        temperature=temperature).squeeze(0)  # (N_query, 1)

In [None]:
# The five refined demo-1 query points with the highest attention scores.
top5_idx1 = query_attention1[:, 0].argsort(descending=True)[:5]
query_points1[top5_idx1]

In [None]:
# Draw the five highest-attention refined query points on the demo-1 grasp cloud.
top5_points1 = query_points1[query_attention1[:, 0].argsort(descending=True)[:5]]
draw_geometry([grasp_proc1, *draw_points(top5_points1)])

# Demo 2 (`demo/test_demo2`): placement sampling and query-point analysis

In [None]:
# Sample placement poses for demo 2's processed scene and grasp with the shared hyperparameters.
place_forward_kwargs2 = dict(
    T_seed=T_seed, policy=place_policy, mh_iter=place_mh_iter, langevin_iter=place_langevin_iter,
    temperature=place_dist_temp, policy_temperature=place_policy_temp,
    optim_iter=place_optim_iter, optim_lr=place_optim_lr, query_temperature=place_query_temp,
)
Ts2, edf_outputs2, logs2 = place_agent.forward(scene=scene_proc2, grasp=grasp_proc2, **place_forward_kwargs2)

In [None]:
# Demo-2 scene with the grasp cloud moved by the first sampled pose.
first_pose2 = Ts2[0].to(device)
draw_geometry([scene_proc2, grasp_proc2.transformed(first_pose2)])

In [None]:
# Query points reported in demo 2's forward() outputs.
edf_outputs2['query_points']

In [None]:
# Query-model inputs for demo 2: the (normalized) colors as per-point features, the points as positions.
inputs_Q2 = dict(feature=grasp_proc2.colors, pos=grasp_proc2.points)

In [None]:
# Reference result: the query model's own query computation for demo 2. The following cells appear
# to re-derive its steps manually (se3T -> weights -> initial queries -> stein_vgd).
place_agent.query_model.get_query(inputs=inputs_Q2, temperature=1., requires_grad=False)

In [None]:
# Manual re-derivation of the initial query points for demo 2.
temperature = 1.

with torch.no_grad():
    # Run the query model's se3T stage on the grasp cloud; its output dict carries 'feature' and 'pos'.
    outputs2 = place_agent.query_model.se3T(inputs_Q2)
    feature_se3T2, pos2 = outputs2['feature'].detach(), outputs2['pos'].detach()

    # Connect every pair of points within query_radius; with loop=False a node has at most
    # num_nodes - 1 neighbours, so this cap never truncates the graph.
    num_nodes2 = pos2.shape[-2]
    max_num_neighbors2 = num_nodes2 - 1
    edge_src2, edge_dst2 = radius_graph(pos2, place_agent.query_model.query_radius,
                                        max_num_neighbors=max_num_neighbors2, loop=False)
    edge2 = (edge_src2, edge_dst2)

    # Weight logits evaluated at the points themselves seed the initial query positions.
    pos_weight_logit2 = place_agent.query_model.get_weight(
        feature=feature_se3T2, pos=pos2, query_points=pos2.unsqueeze(0), temperature=temperature,
    ).squeeze(0).squeeze(-1)
    query_points_init2 = place_agent.query_model.get_init_query_pos(
        pos=pos2, edge=edge2, weight_logit=pos_weight_logit2)  # (N_query, 3)

In [None]:
# The first five initial query points of demo 2.
first_five_init2 = query_points_init2[:5]
draw_geometry([grasp_proc2, *draw_points(first_five_init2)])

In [None]:
def log_P2(x):
    """Query-model weight logits at candidate points x (N_query, 3) for demo 2; the log-density for stein_vgd."""
    return place_agent.query_model.get_weight(
        feature=feature_se3T2, pos=pos2, query_points=x.unsqueeze(0), temperature=temperature,
    ).squeeze(0).squeeze(-1)  # (N_query,)


# No torch.no_grad() around log_P2: grad mode is checked when stein_vgd *calls* it, not where it is
# defined. A no_grad block around the definition would therefore have no effect, and stein_vgd
# presumably needs gradients of log_P2 with respect to x.
query_points2 = place_agent.query_model.stein_vgd(
    x=query_points_init2.detach(), log_P=log_P2, iters=stein_iter, lr=stein_lr)  # (N_query, 3)

In [None]:
# Refined query points (default color) against their initialization (red) on the demo-2 grasp cloud.
refined_spheres2 = draw_points(query_points2)
init_spheres2 = draw_points(query_points_init2, 'red')
draw_geometry([grasp_proc2] + refined_spheres2 + init_spheres2)

In [None]:
# Features and attention logits at demo 2's refined query points. Evaluated under no_grad: this is
# inspection only, so there is no reason to record an autograd graph through the model weights.
with torch.no_grad():
    query_feature2 = place_agent.query_model.get_feature(
        feature=feature_se3T2, pos=pos2, query_points=query_points2.unsqueeze(0)).squeeze(0)  # (N_query, f)
    query_attention2 = place_agent.query_model.get_weight(
        feature=feature_se3T2, pos=pos2, query_points=query_points2.unsqueeze(0),
        temperature=temperature).squeeze(0)  # (N_query, 1)

In [None]:
# Draw the five highest-attention refined query points on the demo-2 grasp cloud.
top5_points2 = query_points2[query_attention2[:, 0].argsort(descending=True)[:5]]
draw_geometry([grasp_proc2, *draw_points(top5_points2)])

In [None]:
def select_top_queries(query_points, query_feature, query_attention, max_N_query=None):
    """Keep the highest-attention query points (the selection logic this cell was copied from).

    Args:
        query_points: (N_query, 3) query positions.
        query_feature: (N_query, f) query features.
        query_attention: (N_query, 1) attention logits.
        max_N_query: keep at most this many queries; None keeps all of them (sorted).

    Returns:
        (query_points, query_feature, query_attention) restricted to the kept queries and sorted by
        descending attention. Attention is softmax-normalized over all queries and, when truncated
        by max_N_query, renormalized to sum to 1 over the kept ones.
    """
    query_attention = F.softmax(query_attention.squeeze(-1), dim=-1)  # (N_query,)
    assert query_attention.dim() == 1
    sorted_idx = query_attention.argsort(descending=True)
    if max_N_query is not None:
        sorted_idx = sorted_idx[:max_N_query]
    query_attention = query_attention[sorted_idx]
    query_feature = query_feature[sorted_idx]
    query_points = query_points[sorted_idx]
    if max_N_query is not None:
        query_attention = query_attention / query_attention.sum()  # renormalize
    return query_points, query_feature, query_attention


# Apply the selection to both demos, with the same cap that was passed to PlaceAgent as max_N_query.
top_points1, top_feature1, top_attention1 = select_top_queries(
    query_points1, query_feature1, query_attention1, max_N_query=max_N_query_place)
top_points2, top_feature2, top_attention2 = select_top_queries(
    query_points2, query_feature2, query_attention2, max_N_query=max_N_query_place)

In [None]:
# Processed demo-1 grasp cloud on its own, in its (normalized) colors.
draw_geometry([grasp_proc1])

In [None]:
# Raw (non-downsampled) demo-1 grasp cloud moved by the second sampled placement pose.
# Uses demo 1's `grasp_raw1` / `Ts1`; the bare names `grasp_raw` / `Ts` are not defined anywhere in
# the notebook. TODO confirm demo 1 (rather than demo 2) was intended.
grasps = grasp_raw1.transformed(Ts1[1].to(device))

In [None]:
# Demo-1 processed scene with the transformed raw grasp (`scene_proc` is not defined in the
# notebook; demo 1's `scene_proc1` is presumably what was meant — TODO confirm).
draw_geometry([scene_proc1, grasps])

In [None]:
# Per-axis minimum of the raw demo-1 grasp points (lower bounding-box corner).
# Uses `grasp_raw1`, since `grasp_raw` is not defined in the notebook — TODO confirm the intended demo.
grasp_raw1.points.min(dim=0).values

In [None]:
# Per-axis maximum of the raw demo-1 grasp points (upper bounding-box corner).
# Uses `grasp_raw1`, since `grasp_raw` is not defined in the notebook — TODO confirm the intended demo.
grasp_raw1.points.max(dim=0).values