## 0. 数据准备

In [1]:
import sys
sys.path.append('/root/Workspace/project/krrnv3')
from dataset.linemod.batchdataset import PoseDataset
from lib.network.krrn import KRRN
from lib.utils.metric import Metric
from lib.utils.utlis import batch_intrinsic_transform
from lib.utils.utlis import load_part_module
from lib.transform.trans import estimateSimilarityTransform
from tools.trainer import Trainer
from lib.transform.coordinate import crop_resize_by_warp_affine

In [2]:
import torch
import torch.nn as nn
import cv2
import numpy as np
import kornia as kn
from mmcv import Config
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from torch.autograd import Variable

In [3]:
cfg = Config.fromfile('/root/Workspace/project/krrnv3/config/linemod/lm_v3.py')
%matplotlib notebook

In [4]:
# Test split of the preprocessed LineMOD data for a single object class,
# served one sample at a time.
dataset_root = '/root/Source/ymy_dataset/Linemod_preprocessed'
num_points = 500
num_cls = 1
num_kp = 8
cls_type = 'ape'

# With shuffle=True the DataLoader's sampler draws its permutation from torch's
# global RNG; seeding it makes the sequence of inspected samples reproducible.
SEED = 0
torch.manual_seed(SEED)

# NOTE(review): the positional arguments (True, 0.0, 8) are passed through
# unchanged; presumably 8 is the keypoint count (num_kp) — confirm against
# PoseDataset's signature.
dataset_test = PoseDataset('test', num_points, True, dataset_root, 0.0, 8, cls_type=cls_type, cfg=cfg)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=True, num_workers=0)

# Metric is told which objects are symmetric (dataset_test.sym_obj).
sym_list = dataset_test.sym_obj
metric = Metric(sym_list)

# Persistent iterator: each run of the inference cell pulls the next sample.
tst_ds_iter = iter(dataloader_test)

# Diameter of the selected object; the success threshold is a fraction of it.
diameters = dataset_test.diameter
diameter = diameters[0]

test Use Object:  [1]
len of real test:  1048
len of synthetic test:  0
len of all test:  1048 



## 1. 查看回归得到的xyz

In [5]:
# Build the network, load the trained checkpoint, then move it to the GPU and
# switch it to inference mode.
estimator = KRRN(num_cls=cfg.Module.NUM_CLS, cfg=cfg)
cpkt = '/root/Source/ymy_dataset/trained/krrnv3/linemod/01/pose_model_4_0.00843839899596593.pth'
estimator = load_part_module(estimator, cpkt)
# Apply cuda()/eval() to the module *returned* by load_part_module, so GPU
# placement and eval mode (dropout/batch-norm in inference behaviour) hold for
# the model actually used, even if the helper hands back a different object.
# nn.Module.cuda() and .eval() both return the module itself.
estimator = estimator.cuda().eval()

# Trainer is only used here for its pose-decoding helper (get_pose).
trainer = Trainer(estimator, None, None, None, cfg=cfg)

In [6]:
# Fetch the next test sample and run one forward pass without building a graph.
datas = next(tst_ds_iter)
net_inputs = (
    datas['img_croped'].cuda(),
    datas['cloud'].cuda(),
    datas['choose'].cuda(),
)
with torch.no_grad():
    pred = estimator(*net_inputs, False)

In [18]:
# Decode the predicted pose, then gather predictions and ground truth for the
# first (only) sample of the batch as CPU tensors/arrays for plotting.
ori_img = datas['ori_img'][0].cpu()
fx, fy, cx, cy = datas['intrinsic'][0].cpu().numpy()
rmin, rmax, cmin, cmax = datas['bbox'][0].cpu().numpy().astype(np.int16)
base_r, base_t = trainer.get_pose(pred, datas)
base_r, base_t = base_r.cpu(), base_t.cpu()

# Undo the xyz normalisation channel-wise: coord = xyz * extent + left_border.
# A single broadcast replaces three copy-pasted per-channel expressions;
# reshape (rather than view) also tolerates non-contiguous slices.
# NOTE(review): assumes 'extent' and 'lfborder' hold one value per axis,
# i.e. shape (B, 3) — confirm against the dataset.
roi_extents = datas['extent']
left_border = datas['lfborder']
xyz = pred['xyz'].detach().cpu()
coordinate = (
    xyz[:, 0:3] * roi_extents[:, 0:3].reshape(-1, 3, 1, 1)
    + left_border[:, 0:3].reshape(-1, 3, 1, 1)
)

# 2D keypoint / camera inputs and ground-truth pose
_bbox = datas['bbox'].cpu()
_intrinsic = datas['intrinsic'].cpu()
target = datas['target'][0].cpu()
target_r = datas['target_r'].cpu()
target_t = datas['target_t'].cpu()

# Denormalised coordinate map as (H, W, 3), sampled at the `choose` indices
# (flat indices into the row-major (H, W) grid, via reshape(-1, 3)).
coordinate_map = coordinate[0].permute(1, 2, 0).detach().cpu().numpy()
model_points = datas['model_points'][0].cpu().clone()
_choose = datas['choose'][0, 0].cpu().clone()
coordinate_choose = coordinate_map.reshape(-1, 3)[_choose]
coordinate_choose_gt = datas['coordinate_choosed'][0].cpu().clone().numpy()

# Mask variants from the dataset, plus the predicted mask
mask_gt = datas['mask'][0].cpu().clone()
mask_label = datas['mask_label'][0].cpu().clone()
mask_obj = datas['mask_obj'][0].cpu().clone()
mask_depth = datas['mask_depth'][0].cpu().clone()
mask_pred = pred['mask'][0].detach().cpu().clone()

# Region maps (prediction and GT), both restricted to the GT mask
region_pred = (pred['region'].detach().cpu()*datas['mask']).clone()[0]
region_gt = (datas['region'].detach().cpu()*datas['mask'].squeeze(1).long()).clone()[0]

# Input point cloud, and the error of the predicted pose applied to the model points
points = datas['cloud'][0].cpu().clone()
pred_base_point = model_points @ base_r.permute(0, 2, 1) + base_t
dis_base_add, _ = metric.cal_adds_cuda(pred_base_point[0].cpu(), target.cpu(), 0)

# Depth: comment out when not needed

# Success if the distance is below 10% of the object diameter.
if dis_base_add < 0.1 * diameter:
    print(f'Success, dis_base_add: {dis_base_add:.04f}')
else:
    print(f'Failed, dis_base_add: {dis_base_add:.04f}')

Success, dis_base_add: 0.0018


### 1.1 查看重建的normalized的xyz模型

In [8]:
def _scale_to_unit(img):
    """Min-max scale an image to [0, 1] for display.

    ``plt.imshow`` ignores ``vmin``/``vmax`` for 3-channel (RGB) input and clips
    float RGB values to [0, 1]. Per-image contrast scaling therefore has to be
    applied to the data itself.

    Args:
        img: array-like image (torch CPU tensor or ndarray).

    Returns:
        float64 ndarray of the same shape, scaled to [0, 1]; all zeros when the
        image is constant (or its range is NaN).
    """
    img = np.asarray(img, dtype=np.float64)
    lo, hi = img.min(), img.max()
    if not hi > lo:
        return np.zeros_like(img)
    return (img - lo) / (hi - lo)


# Predicted normalized xyz restricted to the GT mask, GT normalized xyz, and
# their per-pixel absolute difference.
masked_xyz = (xyz[0] * datas['mask'][0]).permute(1, 2, 0).cpu().clone()
masked_xyz_gt = datas['coordinate'][0].permute(1, 2, 0).cpu().clone()
masked_xyz_res = torch.abs(masked_xyz_gt - masked_xyz)

fig, axes = plt.subplots(1, 3)
panels = (('gt', masked_xyz_gt), ('pred', masked_xyz), ('difference', masked_xyz_res))
for ax, (title, img) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(_scale_to_unit(img))
plt.show()

<IPython.core.display.Javascript object>

### 1.2 查看重建模型

In [9]:
# Object coordinates at the chosen pixels: GT (red) vs. prediction (green).
fig = plt.figure()
# add_subplot(projection='3d') replaces Axes3D(fig, auto_add_to_figure=False)
# followed by fig.add_axes(ax); that pattern is deprecated and was removed in
# Matplotlib 3.6.
ax = fig.add_subplot(projection='3d')
ax.scatter(coordinate_choose_gt[:, 0], coordinate_choose_gt[:, 1], coordinate_choose_gt[:, 2],
           c='r', s=0.5, label='gt')
ax.scatter(coordinate_choose[:, 0], coordinate_choose[:, 1], coordinate_choose[:, 2],
           c='g', s=0.5, label='pred')
ax.legend()
plt.show()

<IPython.core.display.Javascript object>

In [10]:
# Object model points transformed by the GT pose (red) and by the predicted
# pose (green): p' = p @ R.T + t.
# Convert to numpy explicitly instead of relying on `Tensor @ ndarray` falling
# back to numpy's __rmatmul__.
model_points_np = model_points.numpy()
coordinate_choose_gt_trans = model_points_np @ target_r[0].numpy().T + target_t.numpy()
coordinate_choose_trans = model_points_np @ base_r[0].numpy().T + base_t[0].numpy()

fig = plt.figure()
# add_subplot(projection='3d') replaces the Axes3D(..., auto_add_to_figure=False)
# pattern, which is deprecated and was removed in Matplotlib 3.6.
ax = fig.add_subplot(projection='3d')
ax.scatter(coordinate_choose_gt_trans[:, 0], coordinate_choose_gt_trans[:, 1], coordinate_choose_gt_trans[:, 2],
           c='r', s=0.5, label='gt pose')
ax.scatter(coordinate_choose_trans[:, 0], coordinate_choose_trans[:, 1], coordinate_choose_trans[:, 2],
           c='g', s=0.5, label='pred pose')
ax.legend()
plt.show()

<IPython.core.display.Javascript object>

### 1.3 对mask的预测

In [11]:
# Ground-truth mask (left) next to the predicted mask (right).
fig, (ax_gt, ax_pred) = plt.subplots(1, 2)
ax_gt.imshow(mask_gt[0])
ax_gt.set_title('gt')
ax_pred.imshow(mask_pred[0])
ax_pred.set_title('pred')
plt.show()

<IPython.core.display.Javascript object>

### 1.4 对region的预测

In [12]:
# Collapse the per-channel region prediction to a label map via argmax over
# channels, then compare it with the GT labels.
region_pred_map = region_pred.argmax(dim=0)
fig, (ax_gt, ax_pred) = plt.subplots(1, 2)
ax_gt.imshow(region_gt)
ax_gt.set_title('gt')
ax_pred.imshow(region_pred_map)
ax_pred.set_title('pred')
plt.show()

<IPython.core.display.Javascript object>

### 1.5 depth

In [13]:
# Shape of the target point set (same torch.Size as `.shape`).
target.size()

torch.Size([2600, 3])

In [17]:
# Input point cloud fed to the network (green). Set show_target=True to
# overlay the target points (red) instead of editing commented-out code.
show_target = False

fig = plt.figure()
# add_subplot(projection='3d') replaces the Axes3D(..., auto_add_to_figure=False)
# pattern, which is deprecated and was removed in Matplotlib 3.6.
ax = fig.add_subplot(projection='3d')
if show_target:
    ax.scatter(target[..., 0], target[..., 1], target[..., 2], c='r', s=0.5)
ax.scatter(points[:, 0], points[:, 1], points[:, 2], c='g', s=0.5)
plt.show()

<IPython.core.display.Javascript object>

### 1.6 choose

In [19]:
# (row v, column u) coordinates of the nonzero pixels of each mask variant.
mv, mu = torch.where(mask_gt[0])
mlv, mlu = torch.where(mask_label[0])
mov, mou = torch.where(mask_obj[0])
mdv, mdu = torch.where(mask_depth[0])

# `_choose` holds flat indices into the row-major (H, W) grid, so
# column = index % W and row = index // W. The split must use the width and
# integer division: `/` on tensors is true division and yields fractional rows.
# NOTE(review): assumes mask_gt is (1, H, W) — consistent with mask_gt[0]
# being a 2-D image above.
h = mask_gt.shape[1]
w = mask_gt.shape[-1]
cu = _choose % w
cv = torch.div(_choose, w, rounding_mode='floor')

In [20]:
# Left: pixels picked by `choose`; right: nonzero GT-mask pixels. Both are drawn
# over the bbox crop of the original image.
crop = ori_img[rmin:rmax, cmin:cmax, :]
plt.figure()
for slot, (xs, ys) in enumerate([(cu, cv), (mu, mv)], start=1):
    plt.subplot(1, 2, slot)
    plt.imshow(crop)
    plt.scatter(xs, ys, s=1.0)
plt.show()

<IPython.core.display.Javascript object>

In [23]:
# 2x2 grid: nonzero pixels of each mask variant over the bbox crop.
mask_panels = [
    ('mask', mu, mv),
    ('mask_label', mlu, mlv),
    ('mask_obj', mou, mov),
    ('mask_depth', mdu, mdv),
]
image_crop = ori_img[rmin:rmax, cmin:cmax, :]

plt.figure()
for position, (title, xs, ys) in enumerate(mask_panels, start=1):
    panel_ax = plt.subplot(2, 2, position)
    panel_ax.set_title(title)
    plt.imshow(image_crop)
    plt.scatter(xs, ys, s=1.0)
plt.show()

<IPython.core.display.Javascript object>

In [None]:
# Re-draw of the choose-vs-mask comparison: `choose` pixels (left) and GT-mask
# pixels (right) over the bbox crop of the original image.
bbox_crop = ori_img[rmin:rmax, cmin:cmax, :]
fig = plt.figure()

left_ax = fig.add_subplot(1, 2, 1)
left_ax.imshow(bbox_crop)
left_ax.scatter(cu, cv, s=1.0)

right_ax = fig.add_subplot(1, 2, 2)
right_ax.imshow(bbox_crop)
right_ax.scatter(mu, mv, s=1.0)

plt.show()