In [108]:
import numpy as np
import torch
from torch import nn
from torch import distributions
from torch.nn.parameter import Parameter
from torch.autograd import Variable
import ot # optimal transport solver

In [109]:
# Create data points, e.g. noisy points on a 2-d plane living in 3-d
# Seed every RNG this notebook draws from (NumPy for sample_sphere, torch for
# sampling, latent codes and layer initialisation) so that a fresh
# "Restart & Run All" reproduces the same run.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

num_points = 40   # number of surface points to sample
noise_var = 0.05  # NOTE: multiplies N(0, 1) samples directly, i.e. it acts as a standard deviation

def sample_plane(num_points, noise_var):
    """Sample points lying on the plane z = 0 plus a Gaussian perturbation.

    Args:
        num_points: number of points to draw.
        noise_var: scalar multiplying standard-normal samples. Despite the
            name it scales the samples directly rather than their variance.

    Returns:
        A pair ``(points, perturbation)`` of float64 tensors, each of shape
        ``(num_points, 3)``. The perturbation is returned separately; it is
        NOT added to the points here.
    """
    points = torch.empty(num_points, 3)
    # All three columns are filled from uniform(0, 2); the z column is then zeroed.
    points.uniform_(0, 2)
    points[:, 2] = 0
    perturbation = torch.randn(num_points, 3) * noise_var

    return points.double(), perturbation.double()

def sample_sphere(num_points, noise_var):
    """Sample points on the unit sphere in 3-d plus a Gaussian perturbation.

    Directions are drawn as standard-normal 3-vectors (NumPy RNG) and scaled to
    unit Euclidean norm; the perturbation comes from the torch RNG.

    Args:
        num_points: number of points to draw.
        noise_var: scalar multiplying standard-normal samples (applied to the
            samples themselves, not to their variance).

    Returns:
        A pair ``(points, perturbation)`` of float64 tensors, each of shape
        ``(num_points, 3)``. The perturbation is not added to the points.
    """
    gaussian = np.random.randn(3, num_points)
    # Normalise every column (one column per point) to unit length.
    on_sphere = gaussian / np.linalg.norm(gaussian, axis=0)
    points = torch.from_numpy(on_sphere.T)
    perturbation = torch.randn(num_points, 3) * noise_var

    return points.double(), perturbation.double()

def pairwise_distance_matrix(X, Y=None):
    '''
    Squared Euclidean distances between the rows of X and the rows of Y.

    Input: X is a Nxd matrix
           Y is an optional Mxd matrix; if it is not given, Y = X is used.
    Output: Dist is a NxM matrix where Dist[i,j] is the square norm between
            X[i,:] and Y[j,:], i.e. Dist[i,j] = ||X[i,:]-Y[j,:]||^2

    The result is computed as ||x||^2 + ||y||^2 - 2<x,y>. Because of round-off
    in that expansion, entries that are mathematically zero (e.g. the diagonal
    when Y = X) can come out as tiny values of either sign; no clamping is
    applied here.
    '''
    # Default to self-distances. Previously this path left Y_norm undefined
    # and raised NameError.
    if Y is None:
        Y = X

    X_norm = (X**2).sum(1).view(-1, 1)  # column vector of ||x_i||^2
    Y_norm = (Y**2).sum(1).view(1, -1)  # row vector of ||y_j||^2
    Y_t = torch.transpose(Y, 0, 1)

    Dist = X_norm + Y_norm - 2.0 * torch.mm(X, Y_t)
    return Dist

# Clean samples X, their noisy observations Y, and the total squared noise.
X, noise = sample_plane(num_points, noise_var)
X = X.to(torch.float32)
Y = noise.to(torch.float32) + X
# Sum of squared noise entries (kept in float64); used later as a reference level for the losses.
noise_level = torch.sum(noise.pow(2))

In [110]:
# Build network
# Two identical decoders, each an MLP D_in -> hidden1 -> hidden2 -> D_out with
# ReLU activations, map per-point 2-d latent codes to 3-d points.
from torchsummary import summary
dtype = torch.float
device = torch.device('cpu')

N = num_points
D_in = 2
D_out = 3
hidden1 = 200
hidden2 = 400


def _make_decoder():
    """Return a freshly initialised D_in -> hidden1 -> hidden2 -> D_out ReLU MLP."""
    layers = [
        torch.nn.Linear(D_in, hidden1),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden1, hidden2),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden2, D_out),
    ]
    return torch.nn.Sequential(*layers)


# Models
# Latent codes: one random D_in-dimensional code per point and per decoder.
v1, v2 = (
    torch.randn((N, D_in), device=device, dtype=dtype, requires_grad=True)
    for _ in range(2)
)
model1 = _make_decoder()
model2 = _make_decoder()

# Initial predictions; these feed the transport-plan computation below.
Y1_pred = model1(v1)
Y2_pred = model2(v2)

def project_stochastic_to_perm(P):
    """Round a stochastic matrix to a 0/1 matrix by keeping each row's maximum.

    Each row of the result has a 1 wherever that row of ``P`` attains its own
    maximum and 0 elsewhere. Ties within a row yield several ones, and two rows
    may select the same column, so the result is only a permutation matrix
    when the row-wise maxima are unique and fall in distinct columns.

    Unlike the previous implementation, the comparison is done row by row (an
    entry equal to *another* row's maximum is no longer set to 1) and the input
    array is left unmodified.

    Args:
        P: 2-d array-like with non-negative entries.

    Returns:
        A new array with the same shape and dtype as ``P`` containing only
        zeros and ones.
    """
    P = np.asarray(P)
    if P.size == 0:
        # Nothing to maximise over; return an empty/zero array of the same shape.
        return np.zeros_like(P)

    row_max = P.max(axis=1, keepdims=True)  # shape (rows, 1): broadcasts against P
    return (P == row_max).astype(P.dtype)

In [111]:
# Optimization preliminaries:
# Per-iteration loss histories (appended to by both fitting stages).
error_train_net1 = []
error_test_net1 = []
error_train_net2 = []
error_test_net2 = []

loss_fn = torch.nn.MSELoss(reduction = "sum")
iteration = 0
maxiter = 10000
lr = 1e-4
lam = 0.01  # regularisation strength for the (optional) Sinkhorn solver
beta = 0 # if beta > 0, then we add a consistency loss when fitting local patches
# NOTE(review): only the network weights are optimised; the latent codes v1/v2
# are created with requires_grad=True but are not passed to either optimizer.
optimizer_w1 = torch.optim.Adam(model1.parameters(), lr)
optimizer_w2 = torch.optim.Adam(model2.parameters(), lr)

# Get permutation matrices using EMD or sinkhorn
# Cost matrices between the initial predictions and the noisy targets.
C1 = pairwise_distance_matrix(Y1_pred, Y).double()
C2 = pairwise_distance_matrix(Y2_pred, Y).double()
# Unit mass on every point on both sides of the transport problem.
a = np.ones((num_points,))
b = a
C1hat = C1.clone().detach().numpy()
C2hat = C2.clone().detach().numpy()
# Use EMD
P1hat = ot.emd(a, b, C1hat, numItermax=100000, log=False)
P2hat = ot.emd(a, b, C2hat, numItermax=100000, log=False)
# Or Sinkhorn
#P1hat = ot.sinkhorn(a, b, C1hat, lam, method='sinkhorn', numItermax=100000, stopThr=1e-09)
#P2hat = ot.sinkhorn(a, b, C2hat, lam, method='sinkhorn', numItermax=100000, stopThr=1e-09)

# With unit masses on both sides the EMD plan is a permutation matrix, whose
# inverse is its transpose. Taking .T instead of np.linalg.inv avoids an
# O(N^3) factorisation and cannot raise LinAlgError on a singular plan
# (e.g. if the Sinkhorn alternative above is enabled).
P1hat_inv = P1hat.T
P2hat_inv = P2hat.T

# Fixed (non-trainable) assignment matrices; torch.from_numpy already yields
# tensors with requires_grad=False, so the legacy Variable wrapper is not needed.
P1 = torch.from_numpy(P1hat)
P2 = torch.from_numpy(P2hat)

# Get permutation pi_p->q between parametric indices
# P3[i, k] links prediction i of net 1 to prediction k of net 2 when both are
# assigned to the same target point.
P3 = torch.from_numpy(np.matmul(P1hat, P2hat_inv))

In [112]:
# Step 1: fit local patches to the surface
# Both decoders are trained to match the noisy points Y under the fixed
# assignments P1 / P2; distances to the clean points X are logged as a test
# metric only.
# NOTE(review): v1 / v2 require grad but are in neither optimizer, so they stay
# fixed and their .grad buffers simply accumulate across iterations.
print("Fitting of local patches...")
while iteration < maxiter:

    # Prediction & pairwise distance matrices
    Y1_pred = model1(v1)
    Y2_pred = model2(v2)
    C1 = pairwise_distance_matrix(Y1_pred, Y).double()
    C2 = pairwise_distance_matrix(Y2_pred, Y).double()

    # Losses
    # Train loss: sum of squared distances between matched (prediction, target) pairs
    train_loss_net1 = torch.mul(P1, C1).sum()
    train_loss_net2 = torch.mul(P2, C2).sum()
    if beta > 0:
        C3 = pairwise_distance_matrix(Y1_pred, Y2_pred).double()
        train_consistency = torch.mul(P3, C3).sum()
        total_train_loss = train_loss_net1 + train_loss_net2 + beta*train_consistency
    else:
        total_train_loss = train_loss_net1 + train_loss_net2

    # Test loss: monitoring only, so keep it out of the autograd graph
    with torch.no_grad():
        C1x = pairwise_distance_matrix(Y1_pred, X).double()
        C2x = pairwise_distance_matrix(Y2_pred, X).double()
        test_loss_net1 = torch.mul(P1, C1x).sum()
        test_loss_net2 = torch.mul(P2, C2x).sum()

    # Optimization
    optimizer_w1.zero_grad()
    optimizer_w2.zero_grad()
    total_train_loss.backward()
    optimizer_w1.step()
    optimizer_w2.step()

    if iteration % 500 == 0:
        print(f"Iteration:{iteration}"
              f", Training loss (net 1):{train_loss_net1.item()}"
              f", Test loss (net 1):{test_loss_net1.item()}")
        print(f"Iteration:{iteration}"
              f", Training loss (net 2):{train_loss_net2.item()}"
              f", Test loss (net 2):{test_loss_net2.item()}")
        print(f"Noise level:{noise_level.item()}")

    error_train_net1.append(train_loss_net1.item())
    error_train_net2.append(train_loss_net2.item())
    error_test_net1.append(test_loss_net1.item())
    error_test_net2.append(test_loss_net2.item())

    iteration += 1
print('--------------------------------------------------------------')

Fitting of local patches...
Iteration:0, Training loss (net 1):93.33770030736923, Test loss (net 1):95.06112602353096
Iteration:0, Training loss (net 2):101.18487125635147, Test loss (net 2):102.99303275346756
Noise level:0.28329038341339585
Iteration:500, Training loss (net 1):1.826964259147644, Test loss (net 1):2.057450234889984
Iteration:500, Training loss (net 2):1.7186776399612427, Test loss (net 2):1.930122435092926
Noise level:0.28329038341339585
Iteration:1000, Training loss (net 1):1.0343542695045471, Test loss (net 1):1.3066586256027222
Iteration:1000, Training loss (net 2):0.9485240578651428, Test loss (net 2):1.1981324553489685
Noise level:0.28329038341339585
Iteration:1500, Training loss (net 1):0.6668016910552979, Test loss (net 1):0.9530415534973145
Iteration:1500, Training loss (net 2):0.5634211301803589, Test loss (net 2):0.8280289769172668
Noise level:0.28329038341339585
Iteration:2000, Training loss (net 1):0.4089251756668091, Test loss (net 1):0.6861580014228821
It

In [113]:
# Step 2: Ensure that the output of the networks is consistent with one another
# Only the consistency term (matched distances between the two decoders'
# outputs) is optimised here; the per-network train/test losses are logged
# for monitoring.
print("Consistency fitting...")
consistency_loss = []
iteration = 0
# Skip this stage when the consistency term was already part of step 1.
maxiter = 0 if beta > 0 else 1000

while iteration < maxiter:

    # Prediction & pairwise distance matrices
    Y1_pred = model1(v1)
    Y2_pred = model2(v2)
    C3 = pairwise_distance_matrix(Y1_pred, Y2_pred).double()

    # Losses
    # The optimised objective. Round-off in the distance expansion can make it
    # dip marginally below zero once the two outputs coincide.
    train_consistency = torch.mul(P3, C3).sum()

    # Monitoring-only quantities: no gradient is taken through these, so they
    # are computed outside the autograd graph.
    with torch.no_grad():
        C1 = pairwise_distance_matrix(Y1_pred, Y).double()
        C2 = pairwise_distance_matrix(Y2_pred, Y).double()
        C1x = pairwise_distance_matrix(Y1_pred, X).double()
        C2x = pairwise_distance_matrix(Y2_pred, X).double()
        train_loss_net1 = torch.mul(P1, C1).sum()
        train_loss_net2 = torch.mul(P2, C2).sum()
        test_loss_net1 = torch.mul(P1, C1x).sum()
        test_loss_net2 = torch.mul(P2, C2x).sum()

    # Optimization
    optimizer_w1.zero_grad()
    optimizer_w2.zero_grad()
    train_consistency.backward()
    optimizer_w1.step()
    optimizer_w2.step()

    if iteration % 500 == 0:
        print(f"Iteration:{iteration}, Training loss (net 1):{train_loss_net1.item()}"
              f", Test loss (net 1):{test_loss_net1.item()}")
        print(f"Iteration:{iteration}, Training loss (net 2):{train_loss_net2.item()}"
              f", Test loss (net 2):{test_loss_net2.item()}")
        print(f"Iteration:{iteration}, Consistency loss:{train_consistency.item()}")
        print(f"Noise level:{noise_level.item()}")

    error_train_net1.append(train_loss_net1.item())
    error_train_net2.append(train_loss_net2.item())
    error_test_net1.append(test_loss_net1.item())
    error_test_net2.append(test_loss_net2.item())

    iteration += 1
print('--------------------------------------------------------------')

Consistency fitting...
Iteration:0, Training loss (net 1):0.0007739067077636719, Test loss (net 1):0.2833601236343384
Iteration:0, Training loss (net 2):6.818771362304688e-05, Test loss (net 2):0.28439807891845703
Iteration:0, Consistency loss:0.0005787014961242676
Noise level:0.28329038341339585
Iteration:500, Training loss (net 1):0.09056711196899414, Test loss (net 1):0.4015588164329529
Iteration:500, Training loss (net 2):0.09056729078292847, Test loss (net 2):0.401559054851532
Iteration:500, Consistency loss:-3.5762786865234375e-07
Noise level:0.28329038341339585
--------------------------------------------------------------


In [114]:
#%matplotlib widget
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import axes3d
Yhat1 = Y1_pred.detach().numpy()
Yhat2 = Y2_pred.detach().numpy()
#yt_limit = noise_level + 1e-2
#yb_limit = max(noise_level - 1e-2, 1e-6)


def _plot_loss_curves(curves):
    """Draw each (history, colour, linestyle, label) entry on the current pyplot figure."""
    for history, colour, style, label in curves:
        plt.plot(history, color=colour, linestyle=style, label=label)


def _scatter_clouds(axis, clouds):
    """Scatter the first 100 rows of each (points, colour, label) entry on a 3-d axis and add a legend."""
    for points, colour, label in clouds:
        axis.scatter(points[:100, 0], points[:100, 1], points[:100, 2],
                     s=30, c=colour, zorder=10, label=label)
    axis.legend()


# Plot all losses(training and test)
plt.figure()
_plot_loss_curves([
    (error_train_net1, 'b', '-', 'Train Loss (net 1)'),
    (error_test_net1, 'b', '--', 'Test Loss (net 1)'),
    (error_train_net2, 'r', '-', 'Train Loss (net 2)'),
    (error_test_net2, 'r', '--', 'Test Loss (net 2)'),
])
# Horizontal reference line at the total squared noise.
plt.axhline(y=noise_level, color='m', linestyle='-.', label='noise level')
plt.ylabel('l2 loss')
plt.xlabel('number of iterations')
plt.yscale('log')
plt.title('Comparison of losses for each net over time')
plt.legend()

# Plot data points
fig, ax = plt.subplots(1, 1, subplot_kw=dict(projection='3d', aspect='equal'))
_scatter_clouds(ax, [
    (Y, 'b', 'noisy points'),
    (X, 'r', 'clean points'),
])

# Plot reconstructed points
fig, ax = plt.subplots(1, 1, subplot_kw=dict(projection='3d', aspect='equal'))
_scatter_clouds(ax, [
    (Yhat1, 'b', 'recovered pts 1'),
    (Yhat2, 'r', 'recovered pts 2'),
    (Y, 'g', 'noisy pts'),
    (X, 'k', 'clean pts'),
])

plt.show();

FigureCanvasNbAgg()

FigureCanvasNbAgg()

FigureCanvasNbAgg()