In [1]:
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from matcher import Dinov2Matcher
from utils.geometric_vision import solve_pnp_ransac
from utils.spd import read_pointcloud
from utils.spd import depth_map_to_pointcloud
from utils.spd import transform_pointcloud
from utils.quaternion_utils import *

In [1]:
# Inspect a single optimization run step by step (may become unnecessary once the C++ optimizer is used).
seq_id = 169       # selects the dump file ./vis_results/memory_pool_tydata/step{seq_id}.pkl
frame_id = 0       # frame index within that file (originally noted as 0..32 -- TODO confirm against the data)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# NOTE(review): pickle.load can execute arbitrary code -- only open result files produced locally.
with open(f'./vis_results/memory_pool_tydata/step{seq_id}.pkl', 'rb') as f:
    result_dict = pickle.load(f)

# Each entry of result_dict is indexed by frame; pick the requested frame (file is already closed here).
rgb = np.array(result_dict['rgbs'][frame_id])
matches_3d = np.array(result_dict['matches_3ds'][frame_id])
gt_pose = np.array(result_dict['gt_poses'][frame_id])
qt_pred = result_dict['qt_preds'][frame_id]   # optimization trajectory: a sequence of (q, t) pairs
keypoint_from_depth = np.array(result_dict['keypoint_from_depths'][frame_id])
# fullpoint_from_depth = np.array(result_dict['fullpoint_from_depths'][frame_id])
# Overlay the 2-D image locations of the matched keypoints on the RGB frame.
# rgb is channel-first (it is transposed to HWC for display below); scaled to [0, 1],
# presumably from 8-bit input -- TODO confirm.
fig_rgb, ax_rgb = plt.subplots(1, 1)
rgb = rgb / 255.0
for u, v in matches_3d[:, 1:3]:
    # matches_3d column 1 indexes the image width axis, column 2 the height axis.
    rgb[:, int(v), int(u)] = [1, 0, 0]   # matched keypoint -> red pixel
rgb[:, 5, 10] = [0, 1, 0]                # single green reference pixel at (row 5, col 10)
ax_rgb.imshow(rgb.transpose(1, 2, 0))
ax_rgb.axis('off')

# Initial (pre-optimization) pose estimate: the first (q, t) pair of the optimization trajectory.
q = torch.tensor(qt_pred[0][0], device=device)
t = torch.tensor(qt_pred[0][1], device=device)

# Small constant offset subtracted from every keypoint cloud so the keypoints are easier to
# tell apart from the full gripper clouds. Seeded so re-running the cell gives the same figure.
disturb = np.random.default_rng(0).standard_normal(3) * 0.005
gripper_pointcloud = read_pointcloud("./pointclouds/gripper.txt")
gt_pose_inv = np.linalg.inv(gt_pose)
keypoints_tensor = torch.tensor(matches_3d[:, 3:], device=device).float()
gripper_tensor = torch.tensor(gripper_pointcloud, device=device).float()

gt_key_pointcloud = transform_pointcloud(matches_3d[:, 3:], gt_pose_inv) - disturb   # ground-truth keypoints
gt_full_pointcloud = transform_pointcloud(gripper_pointcloud, gt_pose_inv)           # ground-truth full cloud
pred_begin_key_pointcloud = (quaternion_apply(q, keypoints_tensor) + t).detach().cpu().numpy() - disturb  # keypoints, initial estimate
pred_begin_full_pointcloud = (quaternion_apply(q, gripper_tensor) + t).detach().cpu().numpy()             # full cloud, initial estimate

# Segment order here must match the colour table below.
point_cloud = np.concatenate((
    # fullpoint_from_depth,
    gt_full_pointcloud,          # magenta: ground-truth full gripper cloud
    pred_begin_full_pointcloud,  # green:   full cloud under the initial estimate
    gt_key_pointcloud,           # yellow:  ground-truth keypoints
    pred_begin_key_pointcloud,   # blue:    keypoints under the initial estimate
    keypoint_from_depth,         # cyan:    keypoint_from_depth
), axis=0)

n1 = gt_full_pointcloud.shape[0]
n2 = pred_begin_full_pointcloud.shape[0]
n3 = gt_key_pointcloud.shape[0]
n4 = pred_begin_key_pointcloud.shape[0]
n5 = keypoint_from_depth.shape[0]

# One RGB row per point, sized from the actual clouds. The former fixed 30000-row buffer
# ended up with fewer colour rows than points whenever the total exceeded 30000.
segment_colors = np.array([
    [1, 0, 1],   # magenta
    [0, 1, 0],   # green
    [1, 1, 0],   # yellow
    [0, 0, 1],   # blue
    [0, 1, 1],   # cyan
], dtype=float)
point_color = np.repeat(segment_colors, [n1, n2, n3, n4, n5], axis=0)
assert point_color.shape[0] == point_cloud.shape[0]
def make_cloud_trace(points, colors):
    """Build the 3-D marker trace used by both the static figure and every animation frame.

    Args:
        points: array whose columns 0, 1, 2 are plotted as x, y, z.
        colors: per-point colour array, passed unchanged to ``marker.color``.

    Returns:
        A ``go.Scatter3d`` trace with marker size 8 and full opacity.
    """
    return go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(size=8, color=colors, opacity=1.0),
    )


# Static figure: the clouds under the initial (pre-optimization) pose.
fig = go.Figure()
fig.add_trace(make_cloud_trace(point_cloud, point_color))

# One animation frame per optimization step. The keypoint and gripper clouds are the same
# for every step, so they are copied to the device once rather than on each iteration.
keypoints_tensor = torch.tensor(matches_3d[:, 3:], device=device).float()
gripper_tensor = torch.tensor(gripper_pointcloud, device=device).float()

frames = []
for i, (q_step, t_step) in enumerate(qt_pred):
    q_step = torch.tensor(q_step, device=device)
    t_step = torch.tensor(t_step, device=device)

    pred_key_pointcloud = (quaternion_apply(q_step, keypoints_tensor) + t_step).detach().cpu().numpy() - disturb   # keypoints at this step
    pred_full_pointcloud = (quaternion_apply(q_step, gripper_tensor) + t_step).detach().cpu().numpy()              # full cloud at this step

    # Same segment order as the static figure, so point_color still lines up.
    frame_cloud = np.concatenate((
        gt_full_pointcloud,
        pred_full_pointcloud,
        gt_key_pointcloud,
        pred_key_pointcloud,
        keypoint_from_depth,
    ), axis=0)
    frames.append(go.Frame(data=[make_cloud_trace(frame_cloud, point_color)], name=f'frame{i}'))

fig.frames = frames

# Layout for the animated figure: fixed axis ranges with a cube aspect, plus Play/Pause controls.
scene_layout = dict(
    xaxis=dict(range=[-0.4, 0.4]),
    yaxis=dict(range=[-0.4, 0.4]),
    zaxis=dict(range=[0.6, 1.5]),
    aspectmode='cube',
)

# Play: step through fig.frames at 100 ms per frame, continuing from the current frame.
play_button = {
    "args": [None, {"frame": {"duration": 100, "redraw": True}, "fromcurrent": True}],
    "label": "Play",
    "method": "animate",
}
# Pause: animate to "no frame" immediately, which stops playback.
pause_button = {
    "args": [[None], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate", "transition": {"duration": 0}}],
    "label": "Pause",
    "method": "animate",
}
playback_menu = {
    "buttons": [play_button, pause_button],
    "direction": "left",
    "pad": {"r": 10, "t": 87},
    "showactive": False,
    "type": "buttons",
    "x": 0.1,
    "xanchor": "right",
    "y": 0,
    "yanchor": "top",
}

fig.update_layout(scene=scene_layout, updatemenus=[playback_menu], width=1000, height=1000)

# Render the figure.
fig.show()

NameError: name 'pickle' is not defined

In [None]:
# Inspect the final result of a single optimization, read from a saved results table.
frame_id = 170   # row of the results file to visualise

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda:3' if torch.cuda.is_available() else 'cpu'
results = np.loadtxt('results/memory32_key8_iter500_use_depth_no_full_no_adjust.txt')

# Column layout of one results row (each slice below is a single row, hence 1-D / 4x4).
q_pred = results[frame_id, :4]                    # optimized quaternion, (4,)
t_pred = results[frame_id, 4:7]                   # optimized translation, (3,)
gt_pose = results[frame_id, 7:23].reshape(4, 4)   # ground-truth pose, (4, 4)
q_before_opt = results[frame_id, 23:27]           # quaternion before optimization, (4,)
t_before_opt = results[frame_id, 27:30]           # translation before optimization, (3,)
gripper_pointcloud = read_pointcloud("./pointclouds/gripper.txt")   # originally noted as 8192 x 3

q = torch.tensor(q_pred, device=device)
t = torch.tensor(t_pred, device=device)
q_before = torch.tensor(q_before_opt, device=device)
t_before = torch.tensor(t_before_opt, device=device)
gripper_tensor = torch.tensor(gripper_pointcloud, device=device).float()

gt_full_pointcloud = transform_pointcloud(gripper_pointcloud, np.linalg.inv(gt_pose))   # ground-truth cloud
# NOTE: despite the "begin" in its name, this is the cloud under the *optimized* pose.
pred_begin_full_pointcloud = (quaternion_apply(q, gripper_tensor) + t).detach().cpu().numpy()
pred_init_pointcloud = (quaternion_apply(q_before, gripper_tensor) + t_before).detach().cpu().numpy()   # before optimization
print(gt_full_pointcloud.shape)
print(pred_begin_full_pointcloud.shape)
point_cloud = np.concatenate((gt_full_pointcloud, pred_begin_full_pointcloud, pred_init_pointcloud), axis=0)

n1 = gt_full_pointcloud.shape[0]
n2 = pred_begin_full_pointcloud.shape[0]

# One RGB row per point, sized from the actual clouds (no fixed-size buffer):
# red = ground truth, green = optimized, blue = before optimization.
point_color = np.repeat(
    np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float),
    [n1, n2, pred_init_pointcloud.shape[0]],
    axis=0,
)


# Static 3-D view of point_cloud, coloured per point by point_color.
fig = go.Figure()
fig.add_trace(
    go.Scatter3d(
        x=point_cloud[:, 0],
        y=point_cloud[:, 1],
        z=point_cloud[:, 2],
        mode='markers',
        marker=dict(size=8, color=point_color, opacity=1.0),
    )
)
# Axis ranges are left to plotly's automatic scaling.
fig.update_layout(width=1000, height=1000)
fig.show()