In [None]:
import numpy as np
import torch

from dcv.dataset import PIFODataset, RandomImageWarper
from dcv.feature import KeyPoint_Feature
from dcv.frame import Frame
from dcv.utils import random_quaternions, to_device, view_scene_grasp_batch, view_scene_hang_batch

In [None]:
# Load the trained checkpoint, rebuild both dataset splits and the network it was trained
# with, and wrap the two keypoint heads as features used by the optimization cells below.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the numpy and torch RNGs so Restart & Run All reproduces the same initial poses
# (torch.randn below); presumably also covers random_quaternions / RandomImageWarper — confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

exp_name = "PIFO_best"  # alternative checkpoint: "noPixelAligned_best"
# Map every stored tensor onto the active device, so a checkpoint saved on a different
# GPU (or on CPU) still loads here.
state = torch.load(f"network/{exp_name}.pth.tar", map_location=device)
C = state["config"]  # configuration dict stored alongside the weights

# Both splits share identical settings; only the HDF5 source differs.
dataset_kwargs = dict(
    num_views=C["NUM_VIEWS"],
    num_points=C["NUM_POINTS"],
    num_grasps=C["NUM_GRASPS"],
    num_hangs=C["NUM_HANGS"],
    grasp_draw_points=C["GRASP_DRAW_POINTS"],
    hang_draw_points=C["HANG_DRAW_POINTS"],
    random_erase=False,
    on_gpu_memory=(device == "cuda"),
)
trainset = PIFODataset(C["DATA_FILENAME"], **dataset_kwargs)
testset = PIFODataset("data/test_batch.hdf5", **dataset_kwargs)
warper = RandomImageWarper(img_res=C["IMG_RES"])

# Model: rebuild the architecture from the stored config, then load the trained weights.
# NOTE(review): pretrained=True presumably fetches backbone weights that load_state_dict
# overwrites anyway — confirm build_backbone semantics before switching it off.
obj = Frame()
obj.build_backbone(pretrained=True, **C)
obj.build_sdf_head(C["SDF_HEAD_HIDDEN"])
obj.build_keypoint_head("grasp", C["GRASP_HEAD_HIDDEN"], C["GRIPPER_POINTS"])
obj.build_keypoint_head("hang", C["HANG_HEAD_HIDDEN"], C["HOOK_POINTS"])
obj.load_state_dict(state["network"])
obj.to(device).eval()
F_grasp = KeyPoint_Feature(obj, "grasp")
F_hang = KeyPoint_Feature(obj, "hang")

# Grasping

In [None]:
# Build a batch of B scenes from the selected split and sample N random candidate
# gripper poses per scene.
dataset = testset  # switch to `trainset` to inspect training scenes

B, N = 2, 20  # scenes per batch, candidate poses per scene
rgb_list = []
projections_list = []
filename_list = []
mass_list = []
com_list = []
for i in range(B):
    data = to_device(dataset[i], device)
    # Each scene is warped on its own, with a leading batch dimension of size 1 added.
    rgb, projections = warper(
        data["rgb"].unsqueeze(0),
        data["cam_extrinsic"].unsqueeze(0),
        data["cam_intrinsic"].unsqueeze(0),
    )
    rgb_list.append(rgb)
    projections_list.append(projections)
    filename_list.append(data["filenames"].decode())
    mass_list.append(data["masses"])
    com_list.append(data["coms"])

# Pose tensor of shape (B, N, 7): 3 Gaussian-sampled coordinates (std 0.2, presumably a
# position) followed by a 4-component random quaternion. randn is drawn before the quaternions.
positions = 0.2 * torch.randn(B, N, 3, device=device)
orientations = random_quaternions(B * N, device=device).view(B, N, 4)
x = torch.cat([positions, orientations], dim=2)

print("Init poses")
view_scene_grasp_batch(x.cpu().numpy(), np.ones((B, N)), filename_list, False).show()

In [None]:
# Optimize the candidate grasps in two stages, then keep the best `num_best` per scene.
# NOTE: this overwrites `x`, so re-running the cell continues from the already-optimized poses.
rgb_batch = torch.cat(rgb_list)
projections_batch = torch.cat(projections_list)
# Stage 1 uses optimize()'s default settings; stage 2 re-runs with w_coll=1e3.
x, cost, coll = F_grasp.optimize(x, rgb_batch, projections_batch)
x, cost, coll = F_grasp.optimize(x, rgb_batch, projections_batch, w_coll=1e3)

num_best = 5
num_keep = min(num_best, x.shape[1])  # cannot keep more candidates than were optimized
COLL_PENALTY = 100.0  # weight of the collision term when ranking candidates
# as_tensor accepts both tensors and arrays, avoiding torch.tensor's copy-construct warning.
grasp_score = torch.as_tensor(cost + coll * COLL_PENALTY)
best_inds = grasp_score.argsort(dim=1)[:, :num_keep].to(device).unsqueeze(-1).expand(-1, -1, x.shape[-1])
best_poses = torch.gather(x, dim=1, index=best_inds)

print("optimized")
view_scene_grasp_batch(best_poses.cpu().numpy(), np.ones(tuple(best_poses.shape[:2])), filename_list, False).show()

# Hanging

In [None]:
# Build a fresh batch of B scenes from the selected split and sample N random candidate
# hanging poses per scene.
dataset = testset  # switch to `trainset` to inspect training scenes

B, N = 2, 20  # scenes per batch, candidate poses per scene
rgb_list = []
projections_list = []
filename_list = []
mass_list = []
com_list = []
for i in range(B):
    data = to_device(dataset[i], device)
    # Each scene is warped on its own, with a leading batch dimension of size 1 added.
    rgb, projections = warper(
        data["rgb"].unsqueeze(0),
        data["cam_extrinsic"].unsqueeze(0),
        data["cam_intrinsic"].unsqueeze(0),
    )
    rgb_list.append(rgb)
    projections_list.append(projections)
    filename_list.append(data["filenames"].decode())
    mass_list.append(data["masses"])
    com_list.append(data["coms"])

# Pose tensor of shape (B, N, 7): 3 Gaussian-sampled coordinates (std 0.2, presumably a
# position) followed by a 4-component random quaternion. randn is drawn before the quaternions.
positions = 0.2 * torch.randn(B, N, 3, device=device)
orientations = random_quaternions(B * N, device=device).view(B, N, 4)
x = torch.cat([positions, orientations], dim=2)

print("Init poses")
view_scene_hang_batch(x.cpu().numpy(), np.ones((B, N)), filename_list).show()

In [None]:
# Optimize the candidate hanging poses in two stages, then keep the best `num_best` per scene.
# NOTE: this overwrites `x`, so re-running the cell continues from the already-optimized poses.
rgb_batch = torch.cat(rgb_list)
projections_batch = torch.cat(projections_list)
# Stage 1 uses optimize()'s default settings; stage 2 re-runs with w_coll=1e2, coll_margin=1e-8.
x, cost, coll = F_hang.optimize(x, rgb_batch, projections_batch)
x, cost, coll = F_hang.optimize(x, rgb_batch, projections_batch, w_coll=1e2, coll_margin=1e-8)

num_best = 5
num_keep = min(num_best, x.shape[1])  # cannot keep more candidates than were optimized
# Candidates are ranked on `cost` alone.
# NOTE(review): the grasp ranking adds a collision penalty; confirm ignoring `coll` here is intended.
# as_tensor accepts both tensors and arrays, avoiding torch.tensor's copy-construct warning.
hang_score = torch.as_tensor(cost)
best_inds = hang_score.argsort(dim=1)[:, :num_keep].to(device).unsqueeze(-1).expand(-1, -1, x.shape[-1])
best_poses = torch.gather(x, dim=1, index=best_inds)

print("optimized")
# One score per *kept* pose: the shape must match best_poses' leading (B, num_keep) dims, not (B, N).
view_scene_hang_batch(best_poses.cpu().numpy(), np.ones(tuple(best_poses.shape[:2])), filename_list).show()