In [0]:
# Install pykeops (Kernel / kernel_product are imported from it below); pip output is redirected to install.log.
# NOTE(review): consider pinning the version — Kernel/kernel_product is pykeops' legacy API; confirm the installed release still provides it.
!pip install pykeops[full] > install.log

In [0]:
import importlib
from google.colab import files

# Upload the project helper module. Colab renames repeated uploads
# (e.g. "mesh (2).py"), so the bytes are written under the fixed name mesh.py.
src = list(files.upload().values())[0]
with open('mesh.py', 'wb') as fh:
    fh.write(src)
import mesh
# A plain `import` is a no-op when mesh is already loaded in this kernel;
# reload so a re-uploaded version actually takes effect.
mesh = importlib.reload(mesh)

Saving mesh.py to mesh (2).py


In [0]:
import os
import pickle
import time

import numpy as np
import torch
from torch.autograd import grad

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
import imageio
from plotly.offline import init_notebook_mode, iplot
import plotly.graph_objs as go

from pykeops.torch import Kernel, kernel_product
from pykeops.torch.kernel_product.formula import *

from scipy.optimize import minimize
from scipy.spatial import Delaunay

In [0]:
def configure_plotly_browser_state():
  """Configure RequireJS with the plotly CDN bundle in the current output cell.

  Emits a <script> block so that ``iplot`` figures can load plotly. This
  notebook calls it at the top of every plotting cell.
  """
  # Use the public display API rather than the notebook-injected `display`
  # global and the internal IPython.core.display module.
  from IPython.display import HTML, display
  display(HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              plotly: 'https://cdn.plot.ly/plotly-latest.min.js?noext',
            },
          });
        </script>
        '''))

In [0]:
# torch dtype and device used throughout the notebook.
use_cuda = torch.cuda.is_available()
# Always a torch.device (previously a plain 'cpu' string on CPU-only machines).
torchdeviceId = torch.device('cuda:0' if use_cuda else 'cpu')
torchdtype = torch.float32

In [0]:
def GaussKernel(sigma):
    """Return a closure ``K(x, y, b)`` over the pykeops Gaussian kernel product.

    ``gamma`` is recomputed on each call as ``1 / (sigma * sigma)``.
    NOTE(review): presumably K returns sum_j k(x_i, y_j) b_j — confirm against
    the pykeops legacy ``kernel_product`` documentation.
    """
    def K(x, y, b):
        inv_sq_width = 1 / (sigma * sigma)
        spec = dict(id=Kernel('gaussian(x,y)'), gamma=inv_sq_width, backend='auto')
        return kernel_product(spec, x, y, b)
    return K
  
def GaussLinKernel(sigma):
    """Return a closure ``K(x, y, u, v, b)`` over the kernel gaussian(x, y) * linear(u, v).

    Points x / y are paired with vectors u / v. The Gaussian factor gets
    ``gamma = 1 / (sigma * sigma)``; the linear factor gets no gamma (None).
    """
    def K(x, y, u, v, b):
        gamma_pair = (1 / (sigma * sigma), None)
        spec = dict(id=Kernel('gaussian(x,y) * linear(u,v)'), gamma=gamma_pair, backend='auto')
        return kernel_product(spec, (x, u), (y, v), b)
    return K
  
def sumGaussLinKernel(sigma1, sigma2):
    """Return ``K(x, y, u, v, b)`` = GaussLinKernel(sigma1) + GaussLinKernel(sigma2).

    Summing two widths gives a two-scale data kernel.
    """
    kernels = (GaussLinKernel(sigma1), GaussLinKernel(sigma2))
    def K(x, y, u, v, b):
        narrow, wide = kernels
        return narrow(x, y, u, v, b) + wide(x, y, u, v, b)
    return K

In [0]:
def lossHippSurfNp(FSj, fmap, VH, FH, K):
    """Build a kernel data-attachment loss between a generated source surface and a fixed target.

    Each surface is summarised by its face centroids C and face normals N
    (half the edge cross product, so |N| equals the triangle area). With
    <A, B> = sum of all entries of K(C_A, C_B, N_A, N_B, ones), the returned
    loss evaluates <T, T> + <S, S> - 2 <S, T>. T is the target (VH, FH), and S
    is the surface produced by ``mesh.generateSourceULfast`` from the parameters.

    Args:
        FSj: long tensor of source faces (rows of 3 vertex indices) on the
            generated vertices.
        fmap: passed unchanged to ``mesh.generateSourceULfast``.
        VH: target vertex coordinates, one row per vertex (3 columns).
        FH: target faces, long tensor with rows of 3 vertex indices.
        K: kernel ``K(x, y, u, v, b)``; all entries of its output are summed.

    Returns:
        ``loss(qn, wv)`` returning a scalar tensor. ``qn`` is passed through
        ``mesh.doubleQ`` and ``wv`` to ``mesh.generateSourceULfast``.
    """
    def compCN(V, F):
        # Face centroids and area-weighted face normals.
        V0, V1, V2 = (V.index_select(0, F[:, k]) for k in range(3))
        C = (V0 + V1 + V2) / 3
        # Explicit dim=1: without it torch.cross uses the first dimension of
        # size 3, i.e. dim 0 whenever the mesh has exactly 3 faces.
        N = 0.5 * torch.cross(V1 - V0, V2 - V0, dim=1)
        return C, N

    def kernel_sum(Cx, Cy, Nx, Ny):
        # <X, Y>: total of K against a column of ones sized by the second surface.
        ones = torch.ones([Cy.shape[0], 1], dtype=Cy.dtype, device=Cy.device)
        return K(Cx, Cy, Nx, Ny, ones).sum()

    CT, NT = compCN(VH, FH)
    # The target self-term does not depend on the parameters: evaluate K once.
    cst = kernel_sum(CT, CT, NT, NT)

    def loss(qn, wv):
        Qd = mesh.doubleQ(qn)
        VS = mesh.generateSourceULfast(Qd, wv, FSj, fmap)
        CS, NS = compCN(VS, FSj)
        CSdot = kernel_sum(CS, CS, NS, NS)
        CSTdot = kernel_sum(CS, CT, NS, NT)
        return cst + CSdot - 2 * CSTdot

    return loss

In [0]:
def RalstonIntegrator():
    """Return a fixed-step Ralston (two-stage, second-order Runge-Kutta) integrator.

    The returned ``f(ODESystem, x0, nt, deltat=1.0)`` advances the tuple of
    tensors ``x0`` over ``[0, deltat]`` in ``nt`` equal steps. It returns the
    list of all ``nt + 1`` states, the first being a clone of ``x0``.
    ``ODESystem(*state)`` must return a tuple of derivatives matching ``state``.
    """
    def f(ODESystem, x0, nt, deltat=1.0):
        dt = deltat / nt
        state = tuple(t.clone() for t in x0)
        trajectory = [state]
        for _ in range(nt):
            k1 = ODESystem(*state)
            probe = tuple(s + (2 * dt / 3) * d1 for s, d1 in zip(state, k1))
            k2 = ODESystem(*probe)
            state = tuple(s + (.25 * dt) * (d1 + 3 * d2) for s, d1, d2 in zip(state, k1, k2))
            trajectory.append(state)
        return trajectory

    return f
  
def Hamiltonian(K):
    """Return ``H(p, q) = 0.5 * sum(p * K(q, q, p))`` for the kernel ``K``."""
    def H(p, q):
        Kp = K(q, q, p)
        return .5 * (p * Kp).sum()
    return H
  
def HamiltonianSystem(K):
    """Return the vector field ``(p, q) -> (-dH/dq, dH/dp)`` for ``H = Hamiltonian(K)``.

    ``create_graph=True`` keeps the derivatives differentiable, so the loss can
    be backpropagated through the whole integrated trajectory.
    """
    H = Hamiltonian(K)
    def HS(p, q):
        energy = H(p, q)
        dH_dp, dH_dq = grad(energy, (p, q), create_graph=True)
        return -dH_dq, dH_dp
    return HS
  
def Shooting(p0, q0, K, nt=10, Integrator=RalstonIntegrator()):
    """Integrate Hamilton's equations for ``Hamiltonian(K)`` starting from ``(p0, q0)``.

    Returns whatever ``Integrator`` returns. For the default Ralston integrator
    this is a list of ``nt + 1`` ``(p, q)`` tuples covering t in [0, 1].
    """
    vector_field = HamiltonianSystem(K)
    initial_state = (p0, q0)
    return Integrator(vector_field, initial_state, nt)

def TotalLossInteg(K, dataloss, gamma = 0):
    """Return ``loss(p0, q0, w) = gamma * H(p0, q0) + dataloss(q(1), w)``.

    ``q(1)`` is the final configuration of ``Shooting(p0, q0, K)`` (default
    step count). ``H`` is ``Hamiltonian(K)``, acting as an energy regulariser
    weighted by ``gamma``.
    """
    def loss(p0, q0, w):
        _, q_end = Shooting(p0, q0, K)[-1]
        energy = Hamiltonian(K)(p0, q0)
        return gamma * energy + dataloss(q_end, w)
    return loss

In [0]:
class lossWrapper():
  """Expose a one-argument callable through a ``.loss(param)`` method."""

  def __init__(self, fun):
    # The wrapped objective; stored as ``obj``.
    self.obj = fun

  def loss(self, param):
    objective = self.obj
    return objective(param)

In [0]:
class PytorchObjective(object):
  """Adapter exposing a PyTorch loss to ``scipy.optimize.minimize``.

  The optimiser works on one flat NumPy vector. This class splits it into a
  block ``p`` reshaped to ``q.shape`` followed by a flat block ``w``, then
  evaluates ``objfun(p, q, w)`` and backpropagates. Value and gradient are
  cached, so the separate ``fun`` / ``jac`` callbacks share one
  forward/backward pass per point.

  Args:
    objfun: callable ``(p, q, w) -> scalar tensor``.
    param: 1-D tensor giving the initial point (exposed as ``x0``).
    q: tensor passed unchanged as the second argument of ``objfun``; its shape
      defines the leading block of the flat vector.
    dtype, device: dtype/device the unpacked tensors are cast to.
    verbose: when True (the default, matching earlier behaviour) print the loss
      and ``w`` at every new evaluation.
  """

  def __init__(self, objfun, param, q, dtype, device, verbose=True):
    self.f = objfun  # loss function
    self.x0 = param.detach().cpu().numpy()
    self.pshape = q.shape
    self.q0 = q
    self.dtype = dtype
    self.device = device
    self.verbose = verbose

  def is_new(self, x):
    """Return True unless ``x`` matches the last evaluated point to within 1e-8."""
    if not hasattr(self, 'cached_x'):
      return True
    return np.abs(np.asarray(x) - self.cached_x).max() > 1e-8

  def conv_param(self, x):
    """Split flat ``x`` into (p reshaped to ``q.shape``, flat w) CPU tensors."""
    n_p = int(np.prod(self.pshape))
    convp = torch.from_numpy(x[:n_p]).view(*self.pshape)
    convw = torch.from_numpy(x[n_p:])
    return convp, convw

  def cache(self, x):
    """Evaluate objective and gradient at ``x`` and store them."""
    ptensor, wtensor = self.conv_param(x)
    ptensor = ptensor.to(self.device).type(self.dtype).requires_grad_(True)
    wtensor = wtensor.to(self.device).type(self.dtype).requires_grad_(True)
    # Use the q given at construction (previously this read a global `q0`).
    L = self.f(ptensor, self.q0, wtensor)
    L.backward()
    self.cached_f = L.item()
    if self.verbose:
      print(self.cached_f)
      print('w =', wtensor)
    pgrad = ptensor.grad.type(torch.float64).cpu().numpy().ravel()
    wgrad = wtensor.grad.type(torch.float64).cpu().numpy().ravel()
    self.cached_jac = np.concatenate([pgrad, wgrad])
    # Record the point only after a successful evaluation, and as a copy, so
    # later in-place changes to the caller's array are detected by is_new.
    self.cached_x = np.array(x, copy=True)

  def fun(self, x):
    """Objective value at ``x`` (float)."""
    if self.is_new(x):
      self.cache(x)
    return self.cached_f

  def jac(self, x):
    """Gradient at ``x`` as a float64 NumPy vector."""
    if self.is_new(x):
      self.cache(x)
    return self.cached_jac

In [0]:
# Load the spline surface uploaded via files.upload(); only the first uploaded file is used.
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced yourself.
src = list(files.upload().keys())[0]
with open(src, 'rb') as fh:
    surface = pickle.load(fh)
    


Saving spline_splines_4_100.df to spline_splines_4_100 (3).df


Saving targetV to targetV (3)


Saving targetFf to targetFf (2)


In [0]:
# Load target vertices (VH) and faces (FH) from two separate uploads.
# NOTE(review): VH/FH are overwritten later by the synthetic target; pickle.load
# can execute arbitrary code — only load trusted files.
vsrc = list(files.upload().keys())[0]
with open(vsrc, 'rb') as fh:
    VH = pickle.load(fh)

fsrc = list(files.upload().keys())[0]
with open(fsrc, 'rb') as fh:
    FH = pickle.load(fh)

Saving targetV to targetV (6)


Saving targetFf to targetFf (4)


In [0]:
# Load downsampled target vertices (VHds) and faces (FHds) from two separate uploads.
# NOTE(review): VHds/FHds are not used by any later cell; pickle.load can
# execute arbitrary code — only load trusted files.
vsrcds = list(files.upload().keys())[0]
with open(vsrcds, 'rb') as fh:
    VHds = pickle.load(fh)

fsrcds = list(files.upload().keys())[0]
with open(fsrcds, 'rb') as fh:
    FHds = pickle.load(fh)

Saving targetVds to targetVds (1)


Saving targetFds to targetFds (1)


In [0]:
def meshSynthTarget(m, n, a, w):
    """Build a synthetic thick-sheet target mesh and the flat grid it is built on.

    The reference surface is z = a * sin(pi * x) over [0, 2] x [0, 1], sampled
    with ``m`` points along x and ``n`` along y. It is offset by +/- ``w`` along
    its unit normal to form an upper and a lower sheet. The two sheets are
    stitched together along the x = 0 and x = 2 edges.

    Args:
        m, n: number of grid samples along x and y.
        a: amplitude of the sine profile.
        w: offset distance along the normal (half the sheet thickness).

    Returns:
        tV: (2*m*n, 3) float tensor; upper-sheet vertices, then lower-sheet vertices.
        tF: long tensor of faces, in this order: upper sheet, lower sheet
            (orientation flipped), right-edge strip, left-edge strip (flipped).
        tVmid: (m*n, 3) float tensor, the flat grid at z = 0.
        tFmid: long tensor of the grid's Delaunay triangles.

    Note: the y = 0 and y = 1 borders are not stitched, so the target is open
    along them. NOTE(review): confirm this is intended.
    """
    def felev(x, y):
        # Height a*sin(pi*x) and its partial derivatives dz/dx, dz/dy.
        z = a * np.sin(np.pi * x)
        return z, a * np.pi * np.cos(np.pi * x), np.zeros_like(z)

    def flip(faces):
        # Swap the first two vertex indices to reverse triangle orientation.
        return faces[:, [1, 0, 2]]

    def edge_strip(upper_idx, lower_idx):
        # Two triangles per consecutive pair of edge vertices, joining the
        # upper sheet to the lower sheet.
        fa = np.stack([upper_idx[:-1], lower_idx[:-1], lower_idx[1:]], axis=1)
        fb = np.stack([lower_idx[1:], upper_idx[1:], upper_idx[:-1]], axis=1)
        return np.vstack([fa, fb])

    x_grid, y_grid = np.meshgrid(np.linspace(0, 2, m), np.linspace(0, 1, n))
    # Row-major flattening: vertex index = j*m + i (i along x, j along y).
    x_grid = x_grid.flatten()
    y_grid = y_grid.flatten()
    F = Delaunay(np.column_stack([x_grid, y_grid])).simplices
    numrefvtcs = x_grid.shape[0]

    z_grid, dxz, dyz = felev(x_grid, y_grid)
    midvtcs = np.column_stack([x_grid, y_grid, np.zeros_like(z_grid)])
    refvtcs = np.column_stack([x_grid, y_grid, z_grid])

    # Normal of the graph z(x, y) is (-dz/dx, -dz/dy, 1), normalised per row.
    normals = np.column_stack([-dxz, -dyz, np.ones(numrefvtcs)])
    normals = normals / np.linalg.norm(normals, axis=1, keepdims=True)

    vtcsfull = np.vstack([refvtcs + w * normals, refvtcs - w * normals])

    # Right edge (x = 2) is column i = m-1; left edge (x = 0) is column i = 0.
    right_upper = m * np.arange(n) + m - 1
    left_upper = m * np.arange(n)
    redgeF = edge_strip(right_upper, right_upper + numrefvtcs)
    ledgeF = edge_strip(left_upper, left_upper + numrefvtcs)

    Ffull = np.vstack([F, flip(F) + numrefvtcs, redgeF, flip(ledgeF)])

    tV = torch.as_tensor(vtcsfull, dtype = torch.float)
    tF = torch.as_tensor(Ffull, dtype = torch.long)
    tVmid = torch.as_tensor(midvtcs, dtype = torch.float)
    tFmid = torch.as_tensor(F, dtype = torch.long)

    return tV, tF, tVmid, tFmid
 

In [0]:
# Resolution passed as both size arguments of mesh.downsample.
num_points = 50
source = mesh.downsample(surface, num_points, num_points)
Q, FS = mesh.meshSource(source)

In [0]:
# Initial w and kernel width for the spline-based source.
# NOTE(review): superseded by the synthetic-target cells below, which redefine
# Q, FS, w, Qd, Fjoined and facemap; `sigma` is not used afterwards.
w = torch.tensor([4])
sigma = torch.tensor([1], dtype=torchdtype, device=torchdeviceId)

Qd = mesh.doubleQ(Q)
Fjoined = mesh.joinFlip(FS, num_points, num_points)
facemap = mesh.incidentFaceMap(2*num_points*num_points, Fjoined)

In [0]:
VHs, FHs, Q, FS = meshSynthTarget(30, 15, 0.2, 0.2)

In [0]:
w = torch.tensor([0.1])
sigma1 = torch.tensor([0.1], dtype=torchdtype, device=torchdeviceId)
sigma2 = torch.tensor([0.2], dtype=torchdtype, device=torchdeviceId)

Qd = mesh.doubleQ(Q)
Fjoined = mesh.joinFlip(FS, 30, 15)
facemap = mesh.incidentFaceMap(2*30*15, Fjoined)

In [0]:
def _to_compute(t, dtype=torchdtype):
    # Detached copy of `t` on the compute device with the requested dtype.
    return t.detach().clone().to(dtype=dtype, device=torchdeviceId)

# q0 must require grad: HamiltonianSystem differentiates H with respect to q.
q0 = _to_compute(Q).requires_grad_(True)
w0 = _to_compute(w).requires_grad_(True)
VH = _to_compute(VHs)
FH = _to_compute(FHs, torch.long)
Fjoined = _to_compute(Fjoined, torch.long)

In [0]:
# Data term: two-scale gaussian x linear kernel between source and target surfaces.
data_kernel = sumGaussLinKernel(sigma1, sigma2)
dataloss = lossHippSurfNp(Fjoined, facemap, VH, FH, data_kernel)

# Deformation kernel width; gamma weights the Hamiltonian regulariser in the total loss.
sigmadiff1 = torch.tensor([0.4], dtype=torchdtype, device=torchdeviceId)
deformation_kernel = GaussKernel(sigmadiff1)
loss = TotalLossInteg(deformation_kernel, dataloss, gamma = 0.001)

In [0]:
p0 = torch.zeros_like(q0, dtype = torchdtype, device = torchdeviceId).requires_grad_(True)  # zero initial p (first argument of Shooting)

In [0]:
pw = torch.cat([p0.reshape(-1), w0.reshape(-1)])  # flat [p0, w0] vector optimised by scipy

In [0]:
obj = PytorchObjective(objfun=loss, param=pw, q=q0, dtype=torchdtype, device=torchdeviceId)

In [0]:
lbfgs_options = dict(disp=True, maxiter=20); res = minimize(obj.fun, obj.x0, jac=obj.jac, method='L-BFGS-B', options=lbfgs_options)

0.6610786318778992
w = tensor([0.1000], device='cuda:0', requires_grad=True)
9.437012672424316
w = tensor([0.0705], device='cuda:0', requires_grad=True)
2.087275505065918
w = tensor([0.0917], device='cuda:0', requires_grad=True)
0.4713042974472046
w = tensor([0.0978], device='cuda:0', requires_grad=True)
0.34626972675323486
w = tensor([0.0993], device='cuda:0', requires_grad=True)
0.3162565231323242
w = tensor([0.1052], device='cuda:0', requires_grad=True)
0.25635185837745667
w = tensor([0.1290], device='cuda:0', requires_grad=True)
0.17728422582149506
w = tensor([0.1622], device='cuda:0', requires_grad=True)
0.17293085157871246
w = tensor([0.2097], device='cuda:0', requires_grad=True)
0.11901816725730896
w = tensor([0.1843], device='cuda:0', requires_grad=True)
0.07183104008436203
w = tensor([0.1923], device='cuda:0', requires_grad=True)
0.05228741839528084
w = tensor([0.2015], device='cuda:0', requires_grad=True)
0.04830857366323471
w = tensor([0.1971], device='cuda:0', requires_grad

In [0]:
# Split scipy's optimum back into (p, w) and put copies on the compute device.
pres, wres = obj.conv_param(res.x)
prestens, wrestens = (
    t.to(torchdeviceId).type(torchdtype).requires_grad_(True) for t in (pres, wres)
)

In [0]:
# Optimised w (a CPU tensor split out of scipy's res.x).
wres

tensor([0.1971], dtype=torch.float64)

In [0]:
nt = 15
# Re-integrate with the same deformation kernel used inside `loss`
# (`K` is not defined at notebook top level, so the old call raised NameError
# on a fresh kernel).
# NOTE: the loss used Shooting's default nt=10, so this endpoint can differ
# slightly from the one that was optimised.
Kdiff = GaussKernel(sigmadiff1)
pqlist = Shooting(prestens, q0, Kdiff, nt=nt)

In [0]:
configure_plotly_browser_state()
init_notebook_mode(connected=False)

# Template configuration at the last time step of the re-integrated trajectory.
FS = FS.clone().detach().to(dtype = torch.long, device = torchdeviceId)
q_final = pqlist[-1][1]
CMS, NMS = mesh.compCN(q_final, FS)
figMS = mesh.visualize(q_final.cpu(), FS.cpu(), CMS.cpu(), NMS.cpu(), 'Reds', normals = False)

# Keep only the first three traces of the returned figure.
figMS = go.Figure(data = [figMS.data[k] for k in range(3)])

scene_ratio = go.layout.scene.Aspectratio(x=2, y=1, z=0.4)
figMS['layout']['scene'].update(go.layout.Scene(aspectmode='manual', aspectratio=scene_ratio))

iplot(figMS)

In [0]:
# Regenerate the source surface VS at every time step of the trajectory, using the optimised w.
w_plot = wres.cpu().type(torchdtype)
VSlist = []
for _, q_t in pqlist:
  Qd = mesh.doubleQ(q_t)
  VSlist.append(mesh.generateSourceULfast(Qd.cpu(), w_plot, Fjoined.cpu(), facemap))

In [0]:
configure_plotly_browser_state()
init_notebook_mode(connected=False)

# Source surface generated from the final template configuration.
VS_final = VSlist[-1].cpu()
Fjoined_cpu = Fjoined.cpu()
CS, NS = mesh.compCN(VS_final, Fjoined_cpu)
figS = mesh.visualize(VS_final, Fjoined_cpu, CS.cpu(), NS.cpu(), 'Blues', normals = True)

figSonly = go.Figure(data = [figS.data[k] for k in range(3)])
iplot(figSonly)

In [0]:
configure_plotly_browser_state()
init_notebook_mode(connected=False)

# Target surface, then overlay: final template (Reds), final source (Blues), target (Portland).
CHs, NHs = mesh.compCN(VHs, FHs)
figT = mesh.visualize(VHs.cpu(), FHs.cpu(), CHs.cpu(), NHs.cpu(), 'Portland')

overlay_traces = [fig.data[k] for fig in (figMS, figS, figT) for k in range(3)]
figcomb = go.Figure(data = overlay_traces)

figcomb['layout']['scene'].update(go.layout.Scene(
    aspectmode='manual',
    aspectratio=go.layout.scene.Aspectratio(x=2, y=1, z=0.8),
))

iplot(figcomb)

In [0]:
configure_plotly_browser_state()
init_notebook_mode(connected=False)

# Initial state (first time step): template and generated source, overlaid on the target.
CMSo, NMSo = mesh.compCN(pqlist[0][1], FS)
# Use the initial centres/normals (previously the final CMS/NMS were passed).
figMSo = mesh.visualize(pqlist[0][1].cpu(), FS.cpu(), CMSo.cpu(), NMSo.cpu(), 'Reds', normals = False)
figMSo = go.Figure(data = [figMSo.data[0], figMSo.data[1], figMSo.data[2]])

Qdo = mesh.doubleQ(pqlist[0][1])
# Same generator as the loss and the VSlist cell (generateSourceULfast), with the initial w.
VSo = mesh.generateSourceULfast(Qdo.cpu(), w0.detach().cpu(), Fjoined.cpu(), facemap)
# Centres/normals of the initial source (previously the final CS/NS were drawn here).
CSo, NSo = mesh.compCN(VSo.cpu(), Fjoined.cpu())
figSo = mesh.visualize(VSo.cpu(), Fjoined.cpu(), CSo.cpu(), NSo.cpu(), 'Blues', normals = True)
figSo = go.Figure(data = [figSo.data[0], figSo.data[1], figSo.data[2]])

figcomb = go.Figure(data = [figMSo.data[0], figMSo.data[1], figMSo.data[2],
                            figSo.data[0], figSo.data[1], figSo.data[2],
                            figT.data[0], figT.data[1], figT.data[2]])

iplot(figcomb)

In [0]:
# Initial w fed to the optimiser.
# NOTE(review): the saved output below (many 4s) does not match w = torch.tensor([0.1])
# from the synthetic-target cell; it appears to be stale output from an earlier run.
w0

tensor([4., 4., 4.,  ..., 4., 4., 4.], device='cuda:0', requires_grad=True)

In [0]:
# Optimised w after L-BFGS-B.
# NOTE(review): the saved output below (values near 4, many entries) differs from the
# earlier wres display (tensor([0.1971])); it appears to be stale output from an earlier run.
wres

tensor([3.9679, 4.0078, 4.0135,  ..., 4.0303, 4.0215, 3.9396],
       dtype=torch.float64)