#### imports

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam
import matplotlib.pyplot as plt
%matplotlib inline

import lie_tools
from lie_tools import rodrigues

%config IPCompleter.greedy=True

In [9]:
# Display the current local time as an ISO-8601 string (a run timestamp for this notebook).
import datetime
datetime.datetime.now().isoformat()

'2018-10-01T17:00:32.640046'

#### functions

In [None]:
# point cloud creation and se3 creation + action
def get_pointcloud(k):
    """Sample a random 3-D point cloud.

    Args:
        k: number of points to draw.

    Returns:
        A float32 tensor of shape (k, 3) whose entries are drawn from a
        normal distribution with mean 0 and standard deviation 2, using
        NumPy's global random state.
    """
    samples = np.random.normal(0, 2, (k, 3))
    return torch.tensor(samples, dtype=torch.float32)

def get_se3(so3, r):
    """Assemble homogeneous 4x4 SE(3) matrices from rotations and translations.

    Args:
        so3: rotation matrices of shape [..., 3, 3].
        r: translation column vectors of shape [..., 3, 1], with the same
            leading (batch) dimensions as ``so3``.

    Returns:
        A float32 tensor of shape [..., 4, 4] laid out as ``[[so3, r], [0, 0, 0, 1]]``.
        The result lives on the same device as the inputs and stays connected
        to their autograd graph.
    """
    # Upper 3x4 part: rotation next to translation, always cast to float32.
    top = torch.cat([so3, r], -1).to(torch.float32)

    # Constant bottom row, created with the dtype/device of ``top`` so that
    # inputs living on a non-default device can still be concatenated with it.
    bottom = torch.tensor([0., 0., 0., 1.], dtype=top.dtype, device=top.device)
    # Broadcast (without copying) to one [1, 4] row per batch element.
    bottom = bottom.expand(list(top.shape[:-2]) + [1, 4])

    return torch.cat([top, bottom], -2)

def do_se3_action(se3, x):
    """Apply SE(3) transform(s) to 3-D point(s), broadcasting over both.

    Args:
        se3: homogeneous transforms of shape [*tuple1, 4, 4].
        x: points of shape [*tuple2, 3].

    Returns:
        Tensor of shape [*tuple1, *tuple2, 3]: every transform applied to
        every point, with the transform (batch) dimensions leading.
    """
    tuple1 = list(se3.shape[:-2])
    tuple2 = list(x.shape[:-1])

    # Homogeneous coordinates: append a trailing 1 to every point. The ones
    # are allocated on x's device so the concatenation works off-CPU as well.
    ones = torch.ones(tuple2 + [1], device=x.device)
    x_hom = torch.cat([x, ones], -1)

    # Insert singleton axes so the matmul broadcasts transforms against points:
    #   se3   -> [*tuple1, 1, ..., 1, 4, 4]
    #   x_hom -> [1, ..., 1, *tuple2, 4, 1]
    se3 = se3.view(tuple1 + [1] * len(tuple2) + [4, 4])
    x_hom = x_hom.view([1] * len(tuple1) + tuple2 + [4, 1])

    moved = se3 @ x_hom

    # Drop the homogeneous coordinate and the trailing column axis.
    return moved[..., :3, 0]

# plotting
def z_diff(z, z_var):
    """Mean absolute depth error, ignoring a global offset and a sign flip.

    Both inputs are first expressed relative to their first element. The
    error is then computed against the estimate as-is and against its
    negation, and the smaller of the two is returned.

    Args:
        z: reference depths.
        z_var: estimated depths; singleton dimensions are squeezed away
            after centring.

    Returns:
        The smaller of the two mean absolute differences.
    """
    z_rel = z - z[0]
    z_var_rel = (z_var - z_var[0]).squeeze()

    err_same_sign = (z_rel - z_var_rel).abs().mean()
    err_flipped = (z_rel - (-1 * z_var_rel)).abs().mean()

    return min(err_same_sign, err_flipped)

def print_progress(x, x_recon, x_label='original', 
                   x_recon_label='recon', title='', s=100):
    """Scatter-plot two 2-D point sets on top of each other in a new figure.

    Args:
        x: reference points; coordinates 0 and 1 of the last axis are drawn
            in green. Converted with ``.numpy()`` directly, so it must not
            require grad.
        x_recon: reconstructed points, drawn in red; detached before plotting.
        x_label: legend label for ``x``.
        x_recon_label: legend label for ``x_recon``.
        title: figure title.
        s: marker size passed to ``plt.scatter``.
    """
    plt.figure(figsize=(5, 5))

    # Reconstruction first (red), so the reference points are drawn on top.
    recon_pts = x_recon.detach().numpy()
    plt.scatter(recon_pts[..., 0], recon_pts[..., 1],
                c='r', alpha=0.5, s=s, label=x_recon_label)

    ref_pts = x.numpy()
    plt.scatter(ref_pts[..., 0], ref_pts[..., 1],
                c='g', alpha=0.5, s=s, label=x_label)

    # Fixed axis limits keep successive progress plots comparable.
    plt.xlim(-8, 8)
    plt.ylim(-8, 8)
    plt.legend()
    plt.title(title)

#### create true data

In [None]:
# Default number of points in the synthetic cloud (default of create_true_data's n_points).
N_POINTS = 16
# Default number of transformed views (default n_views for data creation and variable init).
N_VIEWS = 1

def xy_z_decomp(xyz):
    """Split the last axis into (all but the last coordinate, last coordinate).

    Args:
        xyz: tensor of shape [..., d].

    Returns:
        A pair ``(xy, z)`` with shapes [..., d - 1] and [...].
    """
    xy = xyz[..., :-1]
    z = xyz[..., -1]
    return xy, z

def create_true_data(n_points=N_POINTS, n_views=N_VIEWS,
                     lie_group='se3', no_rot=False, show=False):
    """Generate a random point cloud and transformed views of it.

    Args:
        n_points: number of points in the cloud.
        n_views: number of transformed views to generate.
        lie_group: 'se3' draws the translation with standard deviation 1;
            'so3' draws it with standard deviation 0.
        no_rot: if True, the sampled rotation vector is replaced by a
            constant vector with all entries 1e-5.
        show: if True, plot the xy-projection of each view over the
            original cloud.

    Returns:
        (cloud, rotated_clouds, se3_elements): the original cloud, a list
        with one transformed cloud per view, and the list of 4x4 transforms
        that produced them.

    Raises:
        Exception: if ``lie_group`` is neither 'se3' nor 'so3' (raised while
            generating the first view).
    """
    cloud = get_pointcloud(n_points)

    transforms = []
    views = []
    for view_idx in range(n_views):
        # The rotation vector is always sampled, even when no_rot discards
        # it, so the random stream does not depend on no_rot.
        rot_vec = torch.tensor(np.random.normal(0, 2, (3)))
        tiny_rot_vec = torch.tensor(np.zeros((3))) + 1e-5

        if lie_group == 'se3':
            trans_std = 1
        elif lie_group == 'so3':
            trans_std = 0
        else:
            raise Exception('use either so3 or se3')
        translation = torch.tensor(np.random.normal(0, trans_std, (3, 1)))

        if no_rot:
            rot_vec = tiny_rot_vec

        transform = get_se3(rodrigues(rot_vec), translation)
        moved_cloud = do_se3_action(transform, cloud)

        transforms.append(transform)
        views.append(moved_cloud)

        if show:
            original_xy = xy_z_decomp(cloud)[0]
            moved_xy = xy_z_decomp(moved_cloud)[0]
            print_progress(original_xy, moved_xy,
                           'original', ('rotated %d' % view_idx),
                           title='creation')

    return cloud, views, transforms
    

#### train data

In [None]:
def init_train_vars(cloud, n_views=N_VIEWS, use_z=True, rot=True, trans=True):
    """Create the depth, rotation and translation variables to optimise.

    Args:
        cloud: point cloud tensor of shape [..., 3]; its last coordinate is
            treated as the depth.
        n_views: number of views; one rotation and one translation variable
            is created per view.
        use_z: if True, depth is a randomly initialised trainable float32
            tensor of shape [..., 1]; otherwise it is a copy of the cloud's
            last coordinate and is not optimised.
        rot: if True, each rotation vector is random and trainable;
            otherwise it is a constant vector with all entries 1e-5.
        trans: if True, each translation is random and trainable; otherwise
            it is a constant zero column.

    Returns:
        (z_var, v_vars, r_vars, optim_params), where ``optim_params`` lists
        only the tensors created with ``requires_grad=True``.
    """
    optim_params = []
    v_vars = []
    r_vars = []

    # depth
    if use_z:
        z_shape = list(cloud.shape[:-1]) + [1]
        z_var = torch.tensor(np.random.normal(0., 1., z_shape),
                             dtype=torch.float32,
                             requires_grad=True)
        optim_params.append(z_var)
    else:
        z_var = cloud[..., -1].clone().unsqueeze(-1)

    for _ in range(n_views):
        # rotation vector
        if rot:
            v_var = torch.tensor(np.random.normal(0, 1, (3)), requires_grad=True)
            optim_params.append(v_var)
        else:
            v_var = torch.tensor(np.zeros((3))) + 1e-5
        v_vars.append(v_var)

        # translation: the zero-scale draw is always made (even when it is
        # replaced below) so the random stream is consumed identically
        r_var = torch.tensor(np.random.normal(0, 0, (3, 1)))
        if trans:
            r_var = torch.tensor(np.random.normal(0, 1, (3, 1)), requires_grad=True)
            optim_params.append(r_var)
        r_vars.append(r_var)

    return z_var, v_vars, r_vars, optim_params
        

#### trainer

In [None]:
# Default number of optimisation steps for depthEstimatorModel.train.
N_ITER = 10000


class depthEstimatorModel:
    """Model to estimate rotation and depth.

    Fits, per view, a rotation vector and a translation (and optionally the
    depth of every point) so that the xy-projection of the transformed cloud
    matches the xy-projection of the observed view.
    """

    def __init__(self, data, train_vars):
        """
        Args:
            data: (cloud, rotated_clouds, se3_elements), as returned by
                create_true_data.
            train_vars: (z_var, v_vars, r_vars, optim_params), as returned
                by init_train_vars. ``optim_params`` is handed to Adam.
        """
        self.i_trained = 0  # total number of optimisation steps taken so far
        self.cloud, self.rotated_clouds, self.se3_elements = data
        self.z_var, self.v_vars, self.r_vars, self.optim_params = train_vars
        self.optimizer = Adam(self.optim_params)

    def forward(self, v, r, z, xy):
        """Transform the points (xy, z) by the SE(3) element built from (v, r).

        Args:
            v: rotation vector passed to ``rodrigues``.
            r: translation column passed to ``get_se3``.
            z: depth of every point, shape [..., 1].
            xy: planar coordinates of every point, shape [..., 2].

        Returns:
            (xy_hat, z_hat): planar coordinates and depth after the transform.
        """
        so3 = rodrigues(v)
        se3 = get_se3(so3, r)

        xyz = torch.cat([xy, z], -1)
        xyz_hat = do_se3_action(se3, xyz)

        return xy_z_decomp(xyz_hat)

    def loss(self, xy, xy_hat):
        """Mean (over points) of the squared xy distance.

        Args:
            xy: target coordinates, shape (tuple2, 2).
            xy_hat: predicted coordinates, shape (tuple1, tuple2, 2).
        """
        sq_err = (xy - xy_hat) ** 2
        return sq_err.sum(-1).mean()

    def train(self, n_iter=N_ITER, print_freq=500, plot_freq=2000):
        """Run ``n_iter`` Adam steps on the summed per-view loss, then report.

        Args:
            n_iter: number of optimisation steps.
            print_freq: log the per-view loss and z_diff every this many steps.
            plot_freq: plot target vs. reconstruction every this many steps.

        Raises:
            ValueError: if steps are requested but the model holds no views
                (there would be no loss to back-propagate).
        """
        if n_iter > 0 and len(self.rotated_clouds) == 0:
            raise ValueError('cannot train without at least one rotated view')

        # Report the size of the cloud actually held by the model rather than
        # a module-level default that may not match it.
        n_cloud_points = self.cloud[..., 0].numel()
        print('train model with %d points in cloud:' % n_cloud_points)

        # The reference cloud never changes, so split it once.
        cloud_xy, cloud_z = xy_z_decomp(self.cloud)

        for i in range(n_iter):
            self.i_trained += 1
            self.optimizer.zero_grad()
            total_loss = 0

            do_print = (i % print_freq) == 0
            do_plot = (i % plot_freq) == 0

            if do_print:
                print('  it:%d:' % self.i_trained)

            for j, view in enumerate(self.rotated_clouds):
                rotated_cloud_xy, _ = xy_z_decomp(view)
                xy_hat, _ = self.forward(self.v_vars[j],
                                         self.r_vars[j],
                                         self.z_var, cloud_xy)
                view_loss = self.loss(rotated_cloud_xy, xy_hat)
                total_loss = total_loss + view_loss

                if do_print:
                    # Log this view's own loss, not the running sum.
                    print('\t view %d \t loss: %.6f \t z_diff: %.3f' %
                          (j, view_loss.item(),
                           float(z_diff(cloud_z, self.z_var))))
                if do_plot:
                    print_progress(rotated_cloud_xy, xy_hat,
                                   x_label='view %d' % j,
                                   x_recon_label='recon',
                                   title='iter: %d' % self.i_trained)

            total_loss.backward()
            self.optimizer.step()

        self.plot_res()

    def plot_res(self):
        """Print true vs. recovered SE(3) elements and relative depths per view."""
        # Depths are compared relative to the first point; this does not
        # depend on the view, so compute it once.
        _, cloud_z = xy_z_decomp(self.cloud)
        cloud_z_rel = (cloud_z - cloud_z[0])
        z_var_rel = (self.z_var - self.z_var[0]).squeeze()

        print('-'*50)
        for i in range(len(self.rotated_clouds)):
            print('view %d' % i)
            so3_rec = rodrigues(self.v_vars[i])
            se3_rec = get_se3(so3_rec, self.r_vars[i])

            print('\nse3 analysis')
            print('true')
            print(self.se3_elements[i])
            print('rec')
            print(se3_rec)
            print('diff')
            print(self.se3_elements[i] - se3_rec)

            print('\nz analysis')
            print('true')
            print(cloud_z_rel)
            print('rec')
            print(z_var_rel)
            print('diff')
            print(cloud_z_rel - z_var_rel)

            print('-'*50)

#### Different Experiment Settings:
- fixing z
- no translation
- no rotation
- full se3

#### fixing z

In [None]:
# "Fixing z" experiment: with use_z=False, z_var is a copy of the cloud's true
# z column and is not handed to the optimiser, so only the per-view rotation
# and translation are fitted.
data = create_true_data(lie_group='se3', show=False)
train_vars = init_train_vars(data[0], use_z=False)

model = depthEstimatorModel(data, train_vars)

In [None]:
# 10000 optimisation steps; log every 1000 and plot every 2500. train() ends by printing the true-vs-recovered comparison.
model.train(n_iter=10000, print_freq=1000,plot_freq=2500)

#### no translation (SO3)

In [None]:
# "No translation" experiment: lie_group='so3' draws the true translation with
# standard deviation 0, and trans=False keeps the model's translation a
# constant that is not optimised; depth and rotation are fitted.
data = create_true_data(lie_group='so3', show=False)
train_vars = init_train_vars(data[0], trans=False)

model = depthEstimatorModel(data, train_vars)

In [None]:
# 10000 optimisation steps; log every 1000 and plot every 2500. train() ends by printing the true-vs-recovered comparison.
model.train(n_iter=10000, print_freq=1000,plot_freq=2500)

#### no rotation

In [None]:
# "No rotation" experiment: no_rot=True replaces the true rotation vector by a
# constant 1e-5 vector, and rot=False fixes the model's rotation vector to the
# same constant; depth and translation are fitted.
data = create_true_data(lie_group='se3', no_rot=True, show=False)
train_vars = init_train_vars(data[0], rot=False)

model = depthEstimatorModel(data, train_vars)

In [None]:
# 10000 optimisation steps; log every 1000 and plot every 2500. train() ends by printing the true-vs-recovered comparison.
model.train(n_iter=10000, print_freq=1000,plot_freq=2500)

#### full SE3

In [None]:
# "Full SE3" experiment: init_train_vars defaults make depth, rotation and
# translation all trainable.
data = create_true_data(lie_group='se3', show=False)
train_vars = init_train_vars(data[0])

model = depthEstimatorModel(data, train_vars)

In [None]:
# 10000 optimisation steps; log every 1000 and plot every 2500. train() ends by printing the true-vs-recovered comparison.
model.train(n_iter=10000, print_freq=1000,plot_freq=2500)