In [None]:
import matplotlib.pyplot as plt
import torch
import trimesh

from dcv.dataset import PIFODataset, RandomImageWarper
from dcv.feature import SDF_Feature
from dcv.frame import Frame
from dcv.utils import to_device

In [None]:
# Run on the GPU when one is available; the datasets below enable
# `on_gpu_memory` only in that case.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint to evaluate (another one used in this project: "noPixelAligned_best").
exp_name = "PIFO_best"
# `map_location=device` remaps every stored tensor onto the selected device, so the
# checkpoint loads regardless of which device (or GPU index) it was saved from.
state = torch.load(f"network/{exp_name}.pth.tar", map_location=device)
C = state["config"]  # training-time configuration stored alongside the weights

# Both splits use the sampling settings from the training config (with random
# erasing disabled); they differ only in the HDF5 file they read.
dataset_kwargs = dict(
    num_views=C["NUM_VIEWS"],
    num_points=C["NUM_POINTS"],
    num_grasps=C["NUM_GRASPS"],
    num_hangs=C["NUM_HANGS"],
    grasp_draw_points=C["GRASP_DRAW_POINTS"],
    hang_draw_points=C["HANG_DRAW_POINTS"],
    random_erase=False,
    on_gpu_memory=(device == "cuda"),
)
trainset = PIFODataset(C["DATA_FILENAME"], **dataset_kwargs)
testset = PIFODataset("data/test_batch.hdf5", **dataset_kwargs)
# Inspect a sample with e.g. `trainset.show_data(0)`.
warper = RandomImageWarper(img_res=C["IMG_RES"])

# Model: rebuild the architecture from the stored config, then load the trained weights.
obj = Frame()
# NOTE(review): any pretrained backbone weights are overwritten by load_state_dict below;
# pretrained=False might avoid an unnecessary download — confirm build_backbone creates
# identical modules either way before changing it.
obj.build_backbone(pretrained=True, **C)
obj.build_sdf_head(C["SDF_HEAD_HIDDEN"])
obj.build_keypoint_head("grasp", C["GRASP_HEAD_HIDDEN"], C["GRIPPER_POINTS"])
obj.build_keypoint_head("hang", C["HANG_HEAD_HIDDEN"], C["HOOK_POINTS"])
obj.load_state_dict(state["network"])
obj.to(device).eval()
F_sdf = SDF_Feature(obj)

# Visualize SDF

In [None]:
# Split to visualize (use `testset` instead for the held-out batch).
dataset = trainset

num_grid = 50  # sample points per axis of the slice
lim = 0.15  # the slice spans [-lim, lim] along both plotted axes
dx = torch.linspace(-lim, lim, num_grid, device=device)
# Every (dx[i], dx[j]) pair in row-major ("ij") order -- the same ordering as
# flattening torch.meshgrid(dx, dx, indexing="ij"), without meshgrid's
# missing-`indexing` deprecation warning.
grid_x, grid_y = torch.cartesian_prod(dx, dx).unbind(dim=1)

for ind in [0]:
    print(f"============================= {ind} ==============================")
    dataset.show_data(ind, image_only=True)
    data = to_device(dataset[ind], device)

    # Pure inference: keep both the image encoding and the SDF queries out of
    # autograd so no computation graph is built.
    with torch.no_grad():
        rgb, projections = warper(
            data["rgb"].unsqueeze(0),
            data["cam_extrinsic"].unsqueeze(0),
            data["cam_intrinsic"].unsqueeze(0),
        )
        obj.backbone.encode(rgb, projections)

        # Query points on the plane where the first coordinate is 0; grid_x / grid_y
        # fill the second / third coordinates (assumed x, y, z order).
        pos = torch.stack([torch.zeros_like(grid_x), grid_x, grid_y], dim=1)
        # NOTE(review): the network output is negated and divided by SDF_SCALE before
        # plotting; presumably this undoes the training-time sign/scale convention -- confirm.
        sdf_pred = -F_sdf(pos.unsqueeze(0)).cpu() / C["SDF_SCALE"]

    # Signed 0.2-power compresses large magnitudes so values near the zero level set
    # spread over more of the colormap.
    color = (sdf_pred.sign() * sdf_pred.abs().pow(0.2)).reshape(-1)

    fig, ax = plt.subplots(figsize=(7, 6))
    points = ax.scatter(grid_x.cpu(), grid_y.cpu(), c=color)
    fig.colorbar(points, ax=ax, label="sign(sdf) * |sdf|^0.2")
    ax.grid(True)
    ax.set(
        xlim=(-lim, lim),
        ylim=(-lim, lim),
        xlabel="y",
        ylabel="z",
        title=f"Predicted SDF on the plane x = 0 (sample {ind})",
    )
    ax.set_aspect("equal")
    plt.show()

# Mesh reconstruction via marching cubes

In [None]:
# Split to reconstruct (use `testset` instead for the held-out batch).
dataset = trainset
# Forwarded to extract_mesh as `draw`; when True, the input-view plots and the
# combined trimesh scene below are skipped.
draw_mesh = False
num_samples = 9  # reconstruct samples 0 .. num_samples-1
grid_cols = 3  # meshes are laid out row by row on a grid this many cells wide
spacing = 0.4  # distance between neighbouring grid cells
true_mesh_z_offset = 0.3  # ground-truth meshes are placed this far above their prediction


def _translation(x, y, z):
    """Return a fresh 4x4 homogeneous transform (numpy) translating by (x, y, z)."""
    transform = torch.eye(4)
    transform[:3, 3] = torch.tensor([x, y, z])
    return transform.numpy()


mesh_list, true_mesh_list = [], []
for ind in range(num_samples):
    data = to_device(dataset[ind], device)
    rgb, projections = warper(
        data["rgb"].unsqueeze(0),
        data["cam_extrinsic"].unsqueeze(0),
        data["cam_intrinsic"].unsqueeze(0),
    )

    vertices, faces, _normals = obj.extract_mesh(
        rgb, projections, scale=0.2, delta=0.00, sdf_scale=C["SDF_SCALE"], draw=draw_mesh
    )
    if draw_mesh:
        continue

    # Show this sample's input views side by side.
    render_images = rgb.squeeze(0).cpu()
    num_views = render_images.shape[0]
    fig, axes = plt.subplots(1, num_views, squeeze=False)
    for ax, image in zip(axes[0], render_images):
        ax.imshow(image.permute(1, 2, 0))
    fig.tight_layout()
    plt.show()

    # Grid cell of this sample: row -> x offset, column -> y offset.
    row, col = divmod(ind, grid_cols)
    x, y = spacing * row, spacing * col

    mesh = trimesh.Trimesh(vertices=vertices, faces=faces).apply_transform(_translation(x, y, 0.0))
    print(f"sample {ind}: predicted mesh watertight = {mesh.is_watertight}")
    mesh.visual.vertex_colors = [0.5, 0.5, 1.0]  # predicted: blue
    mesh_list.append(mesh)

    true_mesh = trimesh.load(f"data/meshes_coll/{data['filenames'].decode()}").apply_transform(
        _translation(x, y, true_mesh_z_offset)
    )
    true_mesh.visual.vertex_colors = [1.0, 0.5, 0.5]  # ground truth: red
    true_mesh_list.append(true_mesh)

# Kept as the cell's last expression so Jupyter renders the returned viewer; skipped
# when draw_mesh left the lists empty (nothing to show).
trimesh.Scene(mesh_list + true_mesh_list).show() if mesh_list else None