In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam
import matplotlib.pyplot as plt
%matplotlib inline

import lie_tools
from lie_tools import rodrigues

%config IPCompleter.greedy=True

### Helper functions: building SE(3) matrices and applying them to points

In [None]:
def get_pointcloud(k):
    """Sample a random 3-D point cloud with ``k`` points.

    Each coordinate is an independent draw from a normal distribution with
    mean 0 and standard deviation 2, taken from NumPy's global RNG.

    Args:
        k: number of points to sample.

    Returns:
        float32 tensor of shape (k, 3).
    """
    samples = np.random.normal(loc=0, scale=2, size=(k, 3))
    return torch.from_numpy(samples).to(torch.float32)

def get_se3(so3, r):
    """Assemble homogeneous SE(3) matrices from rotation blocks and translations.

    Args:
        so3: rotation blocks, shape (..., 3, 3).
        r: translation column vectors, shape (..., 3, 1). The leading batch
           shape is read from ``r``; ``torch.cat`` requires ``so3`` to share it.

    Returns:
        float32 tensor of shape (..., 4, 4): ``[so3 | r]`` on top and the
        constant row ``[0, 0, 0, 1]`` at the bottom.
    """
    batch_shape = list(r.shape[:-2])
    top = torch.cat([so3, r], -1).type(torch.float32)
    # Create the constant bottom row on the same device as the inputs instead
    # of defaulting to the CPU, so GPU tensors can be concatenated as well.
    bottom = torch.tensor([0., 0., 0., 1.], dtype=torch.float32, device=top.device)
    bottom = bottom.view([1] * (len(batch_shape) + 1) + [4]).repeat(batch_shape + [1, 1])
    return torch.cat([top, bottom], -2)

def do_se3_action(se3, x):
    """Apply every SE(3) transform in ``se3`` to every point in ``x``.

    Shapes:
        se3:    (*tuple1, 4, 4)   homogeneous transforms
        x:      (*tuple2, 3)      points
        result: (*tuple1, *tuple2, 3)

    The two batch shapes are combined as an outer product through
    broadcasting of inserted singleton axes.
    """
    tuple1 = list(se3.shape[:-2])
    tuple2 = list(x.shape[:-1])

    # se3 -> (*tuple1, 1...1, 4, 4); reshape also accepts non-contiguous input.
    se3_b = se3.reshape(tuple1 + [1] * len(tuple2) + [4, 4])

    # Homogeneous coordinates; the appended 1 follows x's dtype and device.
    x_h = torch.cat([x, torch.ones_like(x[..., :1])], -1)
    # x_h -> (1...1, *tuple2, 4, 1) so it broadcasts against se3_b.
    x_h = x_h.reshape([1] * len(tuple1) + tuple2 + [4, 1])

    out = se3_b @ x_h  # (*tuple1, *tuple2, 4, 1)
    return out[..., :3, 0]


### Create Data

In [None]:
N_POINTS = 25
# Seed NumPy's global RNG (the only RNG used by the visible cells) so the data
# below and the later parameter initialisation are reproducible on Restart & Run All.
SEED = 0
np.random.seed(SEED)

# Ground-truth SE(3) element: `v` is the 3-vector passed to `rodrigues`
# (lie_tools) to obtain the 3x3 rotation block; `r` is the (3, 1) translation.
v = torch.tensor(np.random.normal(0, 2, (3)))
# v_ = torch.tensor(np.array([0., 0., 1.]))
# v0 = torch.tensor(np.zeros((3))) + 1e-5

r = torch.tensor(np.random.normal(0, 1, (3,1)))
# r0 = torch.tensor(np.random.normal(0, 0, (3,1)))

so3 = rodrigues(v)
se3_true = get_se3(so3, r)
#print(se3_true)

# Source point cloud, shape (N_POINTS, 3), and its image under the true transform.
cloud = get_pointcloud(N_POINTS)
# cloud = np.array([[2,0,0],
#                   [0,2,0],
#                   [0,0,2]])
# cloud = torch.tensor(cloud, dtype=torch.float32)
rotated_cloud = do_se3_action(se3_true, cloud)

# Split into planar coordinates (what training observes) and depth
# (cloud_z is only used to evaluate the learned z_var).
cloud_xy = cloud[...,:-1]
cloud_z = cloud[...,-1]

rotated_cloud_xy = rotated_cloud[...,:-1]
rotated_cloud_z = rotated_cloud[..., -1]

In [None]:
# Compare the xy projection of the source cloud with its image under se3_true.
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(cloud_xy.numpy()[..., 0], cloud_xy.numpy()[..., 1],
           c='r', alpha=0.5, label='source cloud')
ax.scatter(rotated_cloud_xy.numpy()[..., 0], rotated_cloud_xy.numpy()[..., 1],
           c='g', alpha=0.5, label='transformed cloud')
ax.set(xlim=(-5, 5), ylim=(-5, 5), xlabel='x', ylabel='y',
       title='xy projection before / after the true SE(3) action')
ax.legend();

### Training functions: forward model, loss and progress plots

In [None]:
def forward(v, r, xy, z=None):
    """Lift planar points to 3-D, transform them by SE(3)(v, r), then split.

    Args:
        v: 3-vector passed to ``rodrigues`` to build the rotation block.
        r: translation column vector, shape (3, 1).
        xy: planar coordinates, shape (..., 2) (e.g. ``cloud_xy``).
        z: per-point depth with a trailing axis of size 1, concatenated to
           ``xy``. Defaults to the notebook-level ``z_var`` so existing
           ``forward(v, r, xy)`` calls keep working.

    Returns:
        (xy_hat, z_hat): transformed planar coordinates (..., 2) and the
        transformed depth (...,).
    """
    if z is None:
        # Fall back to the learnable global from the initialisation cell;
        # it is looked up at call time, not when this cell is defined.
        z = z_var
    so3 = rodrigues(v)
    se3 = get_se3(so3, r)

    xyz = torch.cat([xy, z], -1)
    xyz_hat = do_se3_action(se3, xyz)

    return xyz_hat[..., :-1], xyz_hat[..., -1]

def loss(xy, xy_hat):
    """Mean squared Euclidean distance between target and predicted points.

    ``xy`` has shape (*tuple2, 2); ``xy_hat`` may carry extra leading batch
    axes, (*tuple1, *tuple2, 2), and is broadcast against ``xy``. Squared
    errors are summed over the coordinate axis, then averaged over the rest.
    """
    squared_err = (xy - xy_hat).pow(2)
    per_point = squared_err.sum(dim=-1)
    return per_point.mean()

def z_diff(z, z_var):
    """Mean absolute difference between true depths ``z`` and estimates ``z_var``.

    If the two tensors hold the same number of elements but have different
    shapes (e.g. ``z`` is (k,) and ``z_var`` is (k, 1)), ``z_var`` is reshaped
    to ``z``'s shape so the comparison is element-wise. Without this, the
    subtraction would broadcast to a (k, k) matrix of all pairwise differences.
    Otherwise the usual broadcasting rules apply.
    """
    if z_var.shape != z.shape and z_var.numel() == z.numel():
        z_var = z_var.reshape(z.shape)
    return (z - z_var).abs().mean()

def print_progress(x, x_recon):
    """Scatter-plot the current reconstruction (red) against the target (green).

    Both arguments are tensors of planar points (last axis = x, y).
    ``x_recon`` is detached before conversion, so it may still be part of the
    autograd graph; ``x`` is converted with ``.numpy()`` directly. Axis limits
    are fixed to [-6, 6] so successive snapshots are comparable.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    recon_np = x_recon.detach().numpy()
    ax.scatter(recon_np[..., 0], recon_np[..., 1], c='r', alpha=0.5)
    target_np = x.numpy()
    ax.scatter(target_np[..., 0], target_np[..., 1], c='g', alpha=0.5)
    ax.set_xlim(-6, 6)
    ax.set_ylim(-6, 6)
    
# Training hyper-parameters.
N_ITER = 10000
N_SAMPLES = 2  # NOTE(review): not referenced by any visible cell

# Learnable depth for every observed point: float32, shape (*cloud.shape[:-1], 1).
z_var = torch.tensor(
    np.random.normal(0., 1., list(cloud.shape[:-1]) + [1]),
    dtype=torch.float32,
    requires_grad=True,
)

# Learnable SE(3) parameters (float64, as produced by NumPy): the 3-vector
# fed to `rodrigues` and the (3, 1) translation column.
v_var = torch.tensor(np.random.normal(0, 1, 3), requires_grad=True)
r_var = torch.tensor(np.random.normal(0, 1, (3, 1)), requires_grad=True)

# Total optimisation steps taken; the training cell increments it and never resets it.
i_trained = 0

In [None]:
# A new Adam optimiser (default hyper-parameters) is built each time this cell
# runs, so its moment estimates restart; `i_trained` keeps counting across runs.
optimize = Adam([z_var, v_var, r_var])

print('train model:')
for i in range(N_ITER):
    i_trained += 1
    optimize.zero_grad()
    xy_hat, z_hat = forward(v_var, r_var, cloud_xy)
    # Only the transformed xy coordinates are supervised; z_var receives
    # gradients through the 3-D lift inside `forward`.
    l = loss(rotated_cloud_xy, xy_hat)
    l.backward()
    optimize.step()

    if (i % 500) == 0:
        # The depth error is a logging-only diagnostic: skip autograd tracking.
        with torch.no_grad():
            z_err = z_diff(cloud_z, z_var).item()
        print('\r it:%d:\t loss: %.6f \t z_diff: %.3f' %
              (i_trained, l.item(), z_err))
    if (i % 1000) == 0:
        print_progress(rotated_cloud_xy, xy_hat)
        
# Rebuild the learned transform from the trained parameters. Detach it so the
# printout shows plain values rather than autograd metadata (grad_fn=...).
so3_rec = rodrigues(v_var).detach()
se3_rec = get_se3(so3_rec, r_var.detach())

print()
print('true')
print(se3_true)
print('rec')
print(se3_rec)
print('diff')
print(se3_true - se3_rec)
# One-number summary of how close the recovered transform is.
print('max abs diff: %.6f' % (se3_true - se3_rec).abs().max().item())