In [1]:
import torch

In [2]:
import numpy as np

In [3]:
torch.__version__

'1.4.0'

In [4]:
torch.cuda.device('cuda')

<torch.cuda.device at 0x22bdd83b668>

In [5]:
collect_point2d = []
views_list = []
point3d_list = []
views_count = 0
point3d_count = 0
uv_count = 0
with open('penguinguy_one_angle_matched.db.txt','r') as f:
    views_count, point3d_count, uv_count = f.readline().strip().split(' ')
    collect_point2d = []
    for _ in range(int(uv_count)):
        view, p3d, u,v = list(filter(None,f.readline().strip().split(' ')))
        collect_point2d.append([float(u),float(v)])
        views_list.append(int(view))
        point3d_list.append(int(p3d))

In [6]:
camera_matrix = torch.rand(int(views_count), 9, requires_grad=True)
point3d = torch.rand(int(point3d_count), 3, requires_grad=True)
point2d = torch.tensor(collect_point2d, requires_grad=False)
point3d_index = torch.tensor(point3d_list,dtype=torch.long,requires_grad=False)
extrinsic_index = torch.tensor(views_list,dtype=torch.long,requires_grad=False)

In [7]:
optimizer = torch.optim.Adam([camera_matrix,point3d], lr=0.01) 

In [8]:
# https://github.com/kashif/ceres-solver/blob/master/include/ceres/rotation.h#L457
def angle_axis_rotate_point(angle_axis, point3d):
    # define variable
    theta2 = angle_axis * angle_axis 
    theta2 = theta2.sum(1) #this is equal vector dot
    away_zero = theta2.gt(0.0)
    near_zero = theta2.lt(0.0)
    # away from zero
    theta = torch.sqrt(theta2) # some value will be nan in this step
    w = angle_axis.t() / theta
    w = w.t()
    costheta = torch.cos(theta)
    sintheta = torch.sin(theta)
    w_cross_pt = torch.cross(w,point3d)
    w_dot_pt = angle_axis*point3d
    w_dot_pt = w_dot_pt.sum(1)
    result_farzero = torch.empty(point3d.shape)
    for i in range(3):
        result_farzero[:,i] = point3d[:,i] * costheta + w_cross_pt[:,i] * sintheta + w[:,i] * (1.0 - costheta) * w_dot_pt
    # near zero
    w_cross_pt = torch.cross(angle_axis,point3d)
    result_nearzero = point3d + w_cross_pt
    result = torch.empty(point3d.shape)
    for i in range(3):
        result[:,i] = result_nearzero[:,i]*near_zero + result_farzero[:,i]*away_zero
    return result

In [9]:
# http://ceres-solver.org/nnls_tutorial.html#bundle-adjustment
def ceres_projector(camera_matrix, point3d):
    #camera[0,1,2] are the angle-axis rotation.
    p = angle_axis_rotate_point(camera_matrix[:,:3], point3d)
    # camera[3,4,5] are the translation.
    p = p + camera_matrix[:,3:6]
    # Compute the center of distortion. The sign change comes from
    # the camera model that Noah Snavely's Bundler assumes, whereby
    # the camera coordinate system has a negative z axis.
    xp = - p[:,0] / p[:,2]
    yp = - p[:,1] / p[:,2]
    #Apply second and fourth order radial distortion.
    l1 = camera_matrix[:,7]
    l2 = camera_matrix[:,8]
    r2 = xp*xp + yp*yp;
    distortion = 1.0 + r2  * (l1 + l2  * r2);
    # Compute final projected point position.
    focal = camera_matrix[:,6]
    predicted_x = focal * distortion * xp;
    predicted_y = focal * distortion * yp;
    # stack into predicted projection
    predicted_projection = torch.stack([predicted_x,predicted_y],dim=1)
    return predicted_projection;

In [10]:
for i in range(10000):
    optimizer.zero_grad()
    # build camera_matrix
    camera_list = camera_matrix[extrinsic_index]
    # build point3d matrix
    point3d_list = point3d[point3d_index]
    # Do projection
    projected = ceres_projector(camera_list, point3d_list)
    # compare after projection with point2d
    uv_difference = torch.pow(point2d - projected,2) 
    total_loss = uv_difference.sum()
    total_loss.backward(retain_graph=True)
    optimizer.step()
    if i % 10 == 0:
        print("EPOCH %4d - loss %30.6f" % (i,total_loss.item()))

EPOCH    0 - loss              2405446656.000000
EPOCH   10 - loss              2371023616.000000
EPOCH   20 - loss              2367840768.000000
EPOCH   30 - loss              2361380864.000000
EPOCH   40 - loss              2331547648.000000
EPOCH   50 - loss              2279141888.000000
EPOCH   60 - loss              2209241088.000000
EPOCH   70 - loss              2156968960.000000
EPOCH   80 - loss              2119727360.000000
EPOCH   90 - loss              2093940352.000000
EPOCH  100 - loss              2070022912.000000
EPOCH  110 - loss              2039484416.000000
EPOCH  120 - loss              2001260800.000000
EPOCH  130 - loss              1971859968.000000
EPOCH  140 - loss              1937112320.000000
EPOCH  150 - loss              1911097600.000000
EPOCH  160 - loss              1927401472.000000
EPOCH  170 - loss              1839246848.000000
EPOCH  180 - loss              1771015168.000000
EPOCH  190 - loss              1719340032.000000
EPOCH  200 - loss   

KeyboardInterrupt: 

In [None]:
np.save('point3d.npy',point3d.cpu().detach().numpy())

In [None]:
point3d.detach()