#### imports

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam
import matplotlib.pyplot as plt
%matplotlib inline

import lie_tools
from lie_tools import rodrigues

%config IPCompleter.greedy=True

#### functions

In [None]:
# point cloud creation and se3 creation + action
def get_pointcloud(k):
    """Sample a random point cloud.

    Args:
        k: number of points.

    Returns:
        float32 tensor of shape (k, 3) whose coordinates are i.i.d. normal
        with mean 0 and standard deviation 2, drawn from NumPy's global RNG.
    """
    samples = np.random.normal(0, 2, (k, 3))
    return torch.tensor(samples, dtype=torch.float32)

def get_se3(so3, r):
    """Assemble homogeneous 4x4 transform matrices ``[[so3, r], [0, 0, 0, 1]]``.

    Args:
        so3: rotation block, expected shape (..., 3, 3).
        r: translation column, expected shape (..., 3, 1). The leading batch
           dims of ``r`` determine how the constant bottom row is tiled.

    Returns:
        float32 tensor of shape (..., 4, 4).
    """
    batch = list(r.shape[:-2])

    # Upper 3 rows: rotation next to translation, cast to float32.
    top = torch.cat([so3, r], dim=-1).type(torch.float32)

    # Constant bottom row [0, 0, 0, 1], tiled over the batch dims of r.
    bottom_row = torch.tensor([[0., 0., 0., 1.]])
    bottom_row = bottom_row.view([1] * len(batch) + [1, 4])
    bottom_row = bottom_row.repeat(batch + [1, 1])

    return torch.cat([top, bottom_row], dim=-2)

def do_se3_action(se3, x):
    """Apply every transform in ``se3`` to every point in ``x``.

    Points are lifted to homogeneous coordinates, multiplied by each 4x4
    matrix, and the homogeneous coordinate is dropped again. The batch dims
    of the two inputs are combined as an outer product via broadcasting.

    Args:
        se3: transforms of shape (*tuple1, 4, 4).
        x: points of shape (*tuple2, 3).

    Returns:
        Tensor of shape (*tuple1, *tuple2, 3): transform batch dims first,
        then point batch dims.
    """
    tuple1 = list(se3.shape[:-2])
    tuple2 = list(x.shape[:-1])

    ones1 = [1] * len(tuple1)
    ones2 = [1] * len(tuple2)

    # Singleton point dims so each transform broadcasts over all points.
    # reshape (not view) so non-contiguous inputs are accepted too.
    se3 = se3.reshape(tuple1 + ones2 + [4, 4])
    # Homogeneous coordinate with x's dtype/device, so cat never mixes types.
    x_hat = torch.cat([x, x.new_ones(tuple2 + [1])], -1)
    # Singleton transform dims so each point broadcasts over all transforms.
    x_hat = se3 @ x_hat.view(ones1 + tuple2 + [4, 1])

    # Drop the homogeneous coordinate and the trailing column dim.
    return x_hat[..., :3, 0]

# plotting
def z_diff(z, z_var):
    """Mean absolute error between true and estimated depths, up to an offset.

    Both depth vectors are taken relative to their first entry, so a constant
    shift of ``z_var`` is not penalised. ``z_var`` is squeezed (e.g. (k, 1)
    to (k,)) before the comparison.
    """
    true_rel = z - z[0]
    est_rel = (z_var - z_var[0]).squeeze()
    return (true_rel - est_rel).abs().mean()

def print_progress(x, x_recon, x_label='original',
                   x_recon_label='recon', s=100, lim=8):
    """Scatter-plot two 2-D point sets on the same axes.

    Despite the name, this draws a new figure rather than printing. Both
    tensors are detached before conversion to NumPy, so either may require
    grad.

    Args:
        x: reference points; columns 0 and 1 of the last dim are plotted
           (green).
        x_recon: reconstructed/transformed points, same layout (red).
        x_label: legend label for ``x``.
        x_recon_label: legend label for ``x_recon``.
        s: marker size.
        lim: both axes span [-lim, lim].
    """
    pts = x.detach().numpy()
    pts_recon = x_recon.detach().numpy()

    plt.figure(figsize=(5, 5))
    plt.scatter(pts_recon[..., 0], pts_recon[..., 1],
                c='r', alpha=0.5, s=s, label=x_recon_label)
    plt.scatter(pts[..., 0], pts[..., 1],
                c='g', alpha=0.5, s=s, label=x_label)
    plt.xlim(-lim, lim)
    plt.ylim(-lim, lim)
    plt.legend()
    

#### create true data

In [None]:
N_POINTS = 5

def create_true_data(n_points=N_POINTS, lie_group='se3', no_rot=False, show=False):
    """Sample a random transform and a point cloud, and apply the transform.

    Args:
        n_points: number of points in the cloud.
        lie_group: 'se3' for a random rotation vector plus a random
            translation (std 1); 'so3' for a random rotation vector with zero
            translation.
        no_rot: if True, the rotation vector passed to ``rodrigues`` is
            replaced by 1e-5 in every component (the random draw still
            happens).
        show: if True, plot the xy projection of the cloud before and after
            the transform.

    Returns:
        (cloud, cloud_xy, cloud_z, rotated_cloud, rotated_cloud_xy,
         rotated_cloud_z, se3) where ``*_xy`` is all but the last coordinate,
        ``*_z`` is the last coordinate, and ``se3`` is the 4x4 transform.

    Raises:
        ValueError: if ``lie_group`` is neither 'se3' nor 'so3'.
    """
    # Fail fast, before any sampling work.
    if lie_group not in ('se3', 'so3'):
        raise ValueError('use either so3 or se3')

    v = torch.tensor(np.random.normal(0, 2, 3))

    if lie_group == 'se3':
        r = torch.tensor(np.random.normal(0, 1, (3, 1)))
    else:
        # 'so3': zero translation (float64, like the se3 branch).
        r = torch.zeros(3, 1, dtype=torch.float64)

    if no_rot:
        # Tiny non-zero vector rather than exact zeros -- presumably to keep
        # rodrigues away from a zero-angle singularity (TODO confirm).
        v = torch.tensor(np.zeros(3)) + 1e-5

    so3 = rodrigues(v)
    se3 = get_se3(so3, r)

    # generate point cloud and transform it
    cloud = get_pointcloud(n_points)
    rotated_cloud = do_se3_action(se3, cloud)

    cloud_xy = cloud[..., :-1]
    cloud_z = cloud[..., -1]

    rotated_cloud_xy = rotated_cloud[..., :-1]
    rotated_cloud_z = rotated_cloud[..., -1]

    if show:
        print_progress(cloud_xy, rotated_cloud_xy, 'original', 'rotated')

    return (cloud, cloud_xy, cloud_z,
            rotated_cloud, rotated_cloud_xy, rotated_cloud_z, se3)
    

#### train data

In [None]:
N_ITER = 10000  # not referenced in this part of the notebook

def create_train_data(cloud, use_z=True, rot=True, trans=False):
    """Initialise the variables the trainer optimises.

    Args:
        cloud: true point cloud; its last column is the depth (z) and its
            leading dims index the points.
        use_z: learn the depths (float32, N(0, 1) init) instead of fixing
            them to ``cloud[..., -1]``.
        rot: learn the rotation vector (N(0, 1) init) instead of fixing it
            to 1e-5 in every component.
        trans: learn the translation (N(0, 1) init) instead of fixing it to
            a zero column vector.

    Returns:
        (z_var, v_var, r_var, optim_params); ``optim_params`` holds the
        tensors created with ``requires_grad=True``, in the order z, v, r.
    """
    optim_params = []

    if use_z:
        depth_shape = list(cloud.shape[:-1]) + [1]
        z_var = torch.tensor(np.random.normal(0., 1., depth_shape),
                             dtype=torch.float32, requires_grad=True)
        optim_params.append(z_var)
    else:
        z_var = cloud[..., -1].clone().unsqueeze(-1)

    if rot:
        v_var = torch.tensor(np.random.normal(0, 1, 3), requires_grad=True)
        optim_params.append(v_var)
    else:
        v_var = torch.tensor(np.zeros(3)) + 1e-5

    # Std-0 draw: a float64 zero column vector.
    r_var = torch.tensor(np.random.normal(0, 0, (3, 1)))
    if trans:
        r_var = torch.tensor(np.random.normal(0, 1, (3, 1)),
                             requires_grad=True)
        optim_params.append(r_var)

    return z_var, v_var, r_var, optim_params
        
    
    

#### trainer

In [None]:
class trainer:
    """Model to estimate rotation and depth.

    Optimises the variables produced by ``create_train_data`` (depth
    ``z_var``, rotation vector ``v_var``, translation ``r_var``) so that
    transforming (true xy, estimated z) reproduces the xy of the transformed
    ground-truth cloud under a squared-error loss.
    """

    def __init__(self, true_data, train_data):
        """
        Args:
            true_data: the 7-tuple returned by ``create_true_data``.
            train_data: the 4-tuple returned by ``create_train_data``; its
                last element lists the tensors optimised with Adam (default
                hyper-parameters).
        """
        self.i_trained = 0  # total iterations across all train() calls
        (self.cloud, self.cloud_xy, self.cloud_z, self.rotated_cloud,
         self.rotated_cloud_xy, self.rotated_cloud_z, self.se3) = true_data
        self.z_var, self.v_var, self.r_var, optim_params = train_data

        self.optimize = Adam(optim_params)

    def forward(self, v, r, xy):
        """Transform the points (xy, self.z_var) by the transform built from (v, r).

        Returns:
            (xy_hat, z_hat): all-but-last and last coordinate of the
            transformed points.
        """
        so3 = rodrigues(v)
        se3 = get_se3(so3, r)

        xyz = torch.cat([xy, self.z_var], -1)
        xyz_hat = do_se3_action(se3, xyz)

        return xyz_hat[..., :-1], xyz_hat[..., -1]

    def loss(self, xy, xy_hat):
        """Squared Euclidean distance summed over the last dim, averaged over the rest."""
        # xy = (tuple2, 2)
        # xy_hat = (tuple1, tuple2, 2)
        sq_err = (xy - xy_hat) ** 2
        return sq_err.sum(-1).mean()

    def train(self, n_iter=10000, print_freq=500, plot_freq=2000):
        """Run ``n_iter`` optimiser steps, logging and plotting periodically.

        Prints loss and ``z_diff`` every ``print_freq`` iterations, plots the
        current fit every ``plot_freq`` iterations, and calls ``plot_res`` at
        the end.
        """
        # Report the actual cloud size, not the module-level default.
        n_points = self.cloud.shape[:-1].numel()
        print('train model with %d points in cloud:' % (n_points))
        for i in range(n_iter):
            self.i_trained += 1
            self.optimize.zero_grad()
            xy_hat, _ = self.forward(self.v_var, self.r_var, self.cloud_xy)
            l = self.loss(self.rotated_cloud_xy, xy_hat)
            l.backward()
            self.optimize.step()

            if (i % print_freq) == 0:
                print('\r it:%d:\t loss: %.6f \t z_diff: %.3f' %
                      (self.i_trained, l, z_diff(self.cloud_z, self.z_var)))
            if (i % plot_freq) == 0:
                print_progress(self.rotated_cloud_xy, xy_hat)

        self.plot_res()

    def plot_res(self):
        """Print the true vs. recovered transform and relative depths, and their differences."""
        # Reporting only: no autograd graph needed (also keeps grad_fn out of the printout).
        with torch.no_grad():
            so3_rec = rodrigues(self.v_var)
            se3_rec = get_se3(so3_rec, self.r_var)

            print('\nse3 analysis')
            print('true')
            print(self.se3)
            print('rec')
            print(se3_rec)
            print('diff')
            print(self.se3 - se3_rec)

            print('\nz analysis')
            # Depths relative to the first point: absolute depth offset is not compared.
            cloud_z_rel = (self.cloud_z - self.cloud_z[0])
            z_var_rel = (self.z_var - self.z_var[0]).squeeze()
            print('true')
            print(cloud_z_rel)
            print('rec')
            print(z_var_rel)
            print('diff')
            print(cloud_z_rel - z_var_rel)

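#### Different experiment settings:
- fixing z: the depth of every point is fixed to its true value instead of being learned
- no translation (SO3): the ground-truth transform has zero translation, and translation is not learned
- no rotation: the ground-truth rotation vector is set to (near-)zero, and rotation is not learned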

#### fixing z

In [None]:
# Depth is given (z fixed to the true values); learn rotation and translation.
# trans=True is required: the 'se3' ground truth includes a random translation,
# which the model could not fit with create_train_data's default trans=False.
true_data = create_true_data(lie_group='se3', no_rot=False, show=False)
train_data = create_train_data(true_data[0], use_z=False, trans=True)

model = trainer(true_data, train_data)

In [None]:
# Train the model built in the previous cell.
model.train(n_iter=10000, print_freq=1000, plot_freq=1500)

#### no translation (SO3)

In [None]:
# 'so3' ground truth has zero translation, so translation is not learned.
true_data = create_true_data(lie_group='so3', no_rot=False, show=False)
train_data = create_train_data(cloud=true_data[0], trans=False)
model = trainer(true_data=true_data, train_data=train_data)

In [None]:
# Train the model built in the previous cell.
model.train(n_iter=10000, print_freq=1000, plot_freq=1500)

#### no rotation

In [None]:
# Ground truth: rotation vector set to 1e-5 (no_rot=True) plus a random
# translation. Rotation is held fixed, so depth and translation are learned;
# trans=True is required, otherwise no variable can account for the translation.
true_data = create_true_data(lie_group='se3', no_rot=True, show=False)
train_data = create_train_data(true_data[0], rot=False, trans=True)

model = trainer(true_data, train_data)

In [None]:
# Train the model built in the previous cell.
model.train(n_iter=10000, print_freq=1000, plot_freq=1500)