Skip to content

Commit

Permalink
fix jmlr inference
Browse files Browse the repository at this point in the history
  • Loading branch information
nttstar committed Aug 13, 2022
1 parent 8650520 commit e089d66
Showing 1 changed file with 49 additions and 2 deletions.
51 changes: 49 additions & 2 deletions reconstruction/jmlr/inference_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,39 @@
from pathlib import Path
from backbones import get_network
from skimage import transform as sktrans

from scipy.spatial.transform import Rotation

def batch_euler2matrix(batch_euler):
n = batch_euler.shape[0]
assert batch_euler.shape[1] == 3
batch_matrix = np.zeros([n, 3, 3], dtype=np.float32)

for i in range(n):
pitch, yaw, roll = batch_euler[i]
R = Rotation.from_euler('yxz', [yaw, pitch, roll], degrees=False).as_matrix().T
batch_matrix[i] = R

return batch_matrix

def euler2matrix(euler):
assert len(euler)==3
matrix = np.zeros([3, 3], dtype=np.float32)

pitch, yaw, roll = euler
R = Rotation.from_euler('yxz', [yaw, pitch, roll], degrees=False).as_matrix().T
matrix = R
return matrix

def Rt_from_6dof(pred_6dof):
assert pred_6dof.ndim==1 or pred_6dof.ndim==2
if pred_6dof.ndim==1:
R_pred = euler2matrix(pred_6dof[:3])
t_pred = pred_6dof[-3:]
return R_pred, t_pred
else:
R_pred = batch_euler2matrix(pred_6dof[:,:3])
t_pred = pred_6dof[:,-3:].reshape(-1,1,3)
return R_pred, t_pred

def solver_rigid(pts_3d , pts_2d , camera_matrix):
# pts_3d Nx3
Expand All @@ -22,7 +54,8 @@ def solver_rigid(pts_3d , pts_2d , camera_matrix):
dist_coeffs = np.zeros((4,1))
pts_3d = pts_3d.copy()
pts_2d = pts_2d.copy()
success, rotation_vector, translation_vector = cv2.solvePnP(pts_3d, pts_2d, camera_matrix.copy(), dist_coeffs, flags=0)
#print(pts_3d.shape, pts_3d.dtype, pts_2d.shape, pts_2d.dtype)
success, rotation_vector, translation_vector = cv2.solvePnP(pts_3d, pts_2d, camera_matrix, dist_coeffs, flags=0)
assert success
R, _ = cv2.Rodrigues(rotation_vector)
R = R.T
Expand Down Expand Up @@ -50,11 +83,24 @@ def __init__(self, cfg, local_rank=0):
backbone.load_state_dict(backbone_ckpt)
backbone.eval()
backbone.requires_grad_(False)
self.backbone = backbone
self.num_verts = cfg.num_verts
self.input_size = cfg.input_size
self.flipindex = cfg.flipindex.copy()
self.data_root = Path(cfg.root_dir)
txt_path = self.data_root / 'resources/projection_matrix.txt'
self.M_proj = np.loadtxt(txt_path, dtype=np.float32)
M1 = np.array([
[400.0, 0, 0, 0],
[ 0, 400.0, 0, 0],
[ 0, 0, 1, 0],
[400.0, 400.0, 0, 1]
])
camera_matrix = self.M_proj @ M1
camera_matrix = camera_matrix[:3,:3].T
camera_matrix[0,2] = 400
camera_matrix[1,2] = 400
self.camera_matrix = camera_matrix.copy()

def set_raw_image_size(self, width, height):
w = width / 2.0
Expand Down Expand Up @@ -112,6 +158,7 @@ def convert_2d(self, pred2, tforms, meta):
return verts2d, points2d

def solve(self, verts3d, verts2d):
print(verts3d.shape, verts2d.shape)
B = verts3d.shape[0]
R = np.zeros([B, 3, 3], dtype=np.float32)
t = np.zeros([B, 1, 3], dtype=np.float32)
Expand Down

0 comments on commit e089d66

Please sign in to comment.