In [80]:
from google.colab import drive
drive.mount('/content/drive')
import os
os.chdir('/content/drive/MyDrive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [93]:
import numpy as np
import torch
import torch.nn.functional as F
import sys
# from scipy import sparse
import plotly.graph_objects as go

# Pytorch implementation of spherical conformal map proposed in https://scholar.harvard.edu/sites/scholar.harvard.edu/files/choi/files/2015_siims_flash.pdf

## Properties of triangulations

In [82]:
def tria_areas(v, t):
    """Compute the area of triangles using Heron's formula.

    `Heron's formula <https://en.wikipedia.org/wiki/Heron%27s_formula>`_
    computes the area of a triangle by using the three edge lengths.

    Returns
    -------
    areas : array
        Array with areas of each triangle.
    """
    v0 = v[t[:, 0], :]
    v1 = v[t[:, 1], :]
    v2 = v[t[:, 2], :]
    v1mv0 = v1 - v0
    v2mv1 = v2 - v1
    v0mv2 = v0 - v2
    a = torch.sqrt(torch.sum(v1mv0 * v1mv0, dim=1))
    b = torch.sqrt(torch.sum(v2mv1 * v2mv1, dim=1))
    c = torch.sqrt(torch.sum(v0mv2 * v0mv2, dim=1))
    ph = 0.5 * (a + b + c)
    areas = torch.sqrt(ph * (ph - a) * (ph - b) * (ph - c))
    return areas

def tria_qualities(v, t):
    """Compute triangle quality for each triangle in mesh where.

    q = 4 sqrt(3) A / (e1^2 + e2^2 + e3^2 )
    where A is the triangle area and ei the edge length of the three edges.
    Constants are chosen so that q=1 for the equilateral triangle.

    .. note::

        This measure is used by FEMLAB and can also be found in:
        R.E. Bank, PLTMG ..., Frontiers in Appl. Math. (7), 1990.

    Returns
    -------
    array
        Array with triangle qualities.
    """
    # Compute vertex coordinates and a difference vectors for each triangle:
    v0 = v[t[:, 0], :]
    v1 = v[t[:, 1], :]
    v2 = v[t[:, 2], :]
    v1mv0 = v1 - v0
    v2mv1 = v2 - v1
    v0mv2 = v0 - v2
    # Compute cross product
    n = torch.cross(v1mv0, -v0mv2)
    # compute length (2*area)
    ln = torch.sqrt(torch.sum(n * n, dim=1))
    q = 2.0 * 3**(0.5) * ln
    es = (v1mv0 * v1mv0).sum(1) + (v2mv1 * v2mv1).sum(1) + (v0mv2 * v0mv2).sum(1)
    return q / es

def euler(v, t):
    """Compute the Euler Characteristic.

    The Euler characteristic is the number of vertices minus the number
    of edges plus the number of triangles  (= #V - #E + #T). For example,
    it is 2 for the sphere and 0 for the torus.
    This operates only on triangles array.

    Returns
    -------
    int
        Euler characteristic.
    """
    # v can contain unused vertices so we get vnum from trias
    vnum = len(torch.unique(t.reshape(-1)))
    tnum = max(t.shape)

    t0 = t[:, 0]
    t1 = t[:, 1]
    t2 = t[:, 2]
    i = torch.column_stack((t0, t1, t1, t2, t2, t0)).reshape(-1)
    j = torch.column_stack((t1, t0, t2, t1, t0, t2)).reshape(-1)
    n = v.shape[0]
    adj_sym = torch.zeros((n, n))
    adj_sym[i, j] =  1
    enum = int(len(torch.nonzero(adj_sym)) / 2)
    return vnum - enum + tnum

## Laplacian Stiffness & Mass Solver

In [83]:
def computeAB(v, t):  # computeABtria(v,t)
    """Compute the 2 symmetric matices of the Laplace-Beltrami operator for a triangle mesh.

    The 2 symmetric matrices are computed for a given triangle mesh using the
    linear finite element method (assuming a closed mesh or Neumann boundary
    condition).

    Parameters
    ----------
    v : vertices
    t : faces

    Returns
    -------
    A : matrix of shape (n, n)
        symmetric positive semi definite matrix.
    B : matrix of shape (n, n)
        symmetric positive definite matrix.
    """


    # Compute vertex coordinates and a difference vector for each triangle:
    t1 = t[:, 0]
    t2 = t[:, 1]
    t3 = t[:, 2]
    v1 = v[t1, :]
    v2 = v[t2, :]
    v3 = v[t3, :]
    v2mv1 = v2 - v1
    v3mv2 = v3 - v2
    v1mv3 = v1 - v3
    # Compute cross product and 4*vol for each triangle:
    cr = torch.cross(v3mv2, v1mv3)
    vol = 2 * torch.sqrt(torch.sum(cr * cr, dim=1))
    # zero vol will cause division by zero below, so set to small value:
    vol_mean = 0.0001 * torch.mean(vol)
    vol[vol < sys.float_info.epsilon] = vol_mean
    # compute cotangents for A
    # using that v2mv1 = - (v3mv2 + v1mv3) this can also be seen by
    # summing the local matrix entries in the old algorithm
    a12 = torch.sum(v3mv2 * v1mv3, dim=1) / vol
    a23 = torch.sum(v1mv3 * v2mv1, dim=1) / vol
    a31 = torch.sum(v2mv1 * v3mv2, dim=1) / vol
    # compute diagonals (from row sum = 0)
    a11 = -a12 - a31
    a22 = -a12 - a23
    a33 = -a31 - a23
    # stack columns to assemble data
    local_a = torch.column_stack(
        (a12, a12, a23, a23, a31, a31, a11, a22, a33)
    ).reshape(-1)
    i = torch.column_stack((t1, t2, t2, t3, t3, t1, t1, t2, t3)).reshape(-1)
    j = torch.column_stack((t2, t1, t3, t2, t1, t3, t1, t2, t3)).reshape(-1)
    # Construct sparse matrix:
    # a = sparse.csr_matrix((local_a.detach().numpy(), (i.numpy(), j.numpy())))
    a = torch.zeros((len(v), len(v)))
    a.index_put_((i, j), local_a.float(), accumulate=True)
    # construct mass matrix (sparse or diagonal if lumped)
    # create b matrix data (account for that vol is 4 times area)
    b_ii = vol / 24
    b_ij = vol / 48
    local_b = torch.column_stack(
        (b_ij, b_ij, b_ij, b_ij, b_ij, b_ij, b_ii, b_ii, b_ii)
    ).reshape(-1)
    b = torch.zeros((len(v), len(v)))
    b.index_put_((i, j), local_b.float(), accumulate=True)
    return a, b

## Solvers

In [85]:
def symmetric_solve(A, b):
    """Symmetric solver for ``A x = b``.

    Parameters
    ----------
    A : matrix of shape (n, n)
        symmetric matrix.
    b : 1-D tensor of length n
        Vector for right hand side of equation.

    Returns
    -------
    x: Solution to  ``A x = b``.
    """

    x = torch.linalg.solve(A, b)
    return x

def inverse_stereographic(u):
    """Map from complex plane to sphere via inverse stereographic projection.

    Parameters
    ----------
    u : tensor
        two columns (real, img)
        for coordinates on complex plane.

    Returns
    -------
    v: tensor of shape (n, 3)
        Coordinates on sphere in 3D.
    """
    if len(u.shape) == 1:
      x = u.real
      y = u.imag
    else:
      x = u[:, 0]
      y = u[:, 1]
    z = 1 + x**2 + y**2
    v = torch.column_stack((2 * x / z, 2 * y / z, (-1 + x**2 + y**2) / z))
    return v


def beltrami_coefficient(v, t, mapping):
    """Compute the Beltrami coefficient of a mapping.

    Parameters
    ----------
    v, t: TriaMesh
        Genus-0 closed triangle mesh.
        Should be planar mapping on complex plane.
    mapping : array
        3D coordinates of the spherical conformal parameterization.

    Returns
    -------
    mu : tensor
        Complex Beltrami coefficient per triangle.
    """
    # here we should be in the plane
    if torch.amax(v[:, 2]) - torch.amin(v[:, 2]) > 0.001:
        print("ERROR: mesh should be on complex plane ..")
        raise ValueError("not planar")

    # get 2d vetrices, edges and area
    v0 = (v[t[:, 0], :])[:, :-1]
    v1 = (v[t[:, 1], :])[:, :-1]
    v2 = (v[t[:, 2], :])[:, :-1]
    e0 = v2 - v1
    e1 = v0 - v2
    e2 = v1 - v0
    # double areas
    e0_pad = torch.nn.functional.pad(e0, (0, 1))
    e1_pad = torch.nn.functional.pad(e1, (0, 1))
    # print(e0_pad.shape, e1_pad.shape)
    areas2 = torch.linalg.cross(e0_pad, e1_pad)[:, -1]  # returns z-component is length
    # print(areas2.shape)

    # create tria,vertex matrices (summing area normalized edge coords)
    nf = t.shape[0]
    tids = torch.arange(nf)
    i = torch.column_stack((tids, tids, tids)).reshape(-1)
    j = t.reshape(-1)
    datx = (
        torch.column_stack((e0[:, 1], e1[:, 1], e2[:, 1])) / areas2[:, None]
    ).reshape(-1)
    daty = -(
        torch.column_stack((e0[:, 0], e1[:, 0], e2[:, 0])) / areas2[:, None]
    ).reshape(-1)
    nv = v.shape[0]
    Dx = torch.zeros((nf, nv))
    Dx.index_put_((i, j), datx, accumulate=True)
    Dy = torch.zeros((nf, nv))
    Dy.index_put_((i, j), daty, accumulate=True)

    dXdu = Dx@mapping[:, 0]
    dXdv = Dy@mapping[:, 0]
    dYdu = Dx@mapping[:, 1]
    dYdv = Dy@mapping[:, 1]
    dZdu = Dx@mapping[:, 2]
    dZdv = Dy@mapping[:, 2]

    E = dXdu**2 + dYdu**2 + dZdu**2
    G = dXdv**2 + dYdv**2 + dZdv**2
    F = dXdu * dXdv + dYdu * dYdv + dZdu * dZdv
    mu = (E - G + 2j * F) / (E + G + 2.0 * torch.sqrt(E * G - F**2))

    return mu



def linear_beltrami_solver(v, t, mu, landmark, target):
    """Linear Beltrami solver.

    Parameters
    ----------
    v, t : Genus-0 closed triangle mesh.
        Should be planar mapping on complex plane.
    mu :Complex Beltrami coefficients.
    landmark :Fixed vertex indices.
    target : 3d tensor with 2d landmark target coordinates (3rd coordinate is zero).

    Returns
    -------
    mapping : tensor
        3d vertex coordinates of new mapping.
    """
    # here we should be in the plane
    if torch.amax(v[:, 2]) - torch.amin(v[:, 2]) > 0.001:
        print("ERROR: mesh should be on complex plane ..")
        raise ValueError("not planar")

    af = (1.0 - 2 * torch.real(mu) + torch.abs(mu) ** 2) / (1.0 - torch.abs(mu) ** 2)
    bf = -2.0 * torch.imag(mu) / (1.0 - torch.abs(mu) ** 2)
    gf = (1.0 + 2 * torch.real(mu) + torch.abs(mu) ** 2) / (1.0 - torch.abs(mu) ** 2)

    # get 2D vertices (drop 3rd dim)
    t0 = t[:, 0]
    t1 = t[:, 1]
    t2 = t[:, 2]
    v0 = (v[t0, :])[:, :-1]
    v1 = (v[t1, :])[:, :-1]
    v2 = (v[t2, :])[:, :-1]

    uxv0 = v1[:, 1] - v2[:, 1]
    uyv0 = v2[:, 0] - v1[:, 0]
    uxv1 = v2[:, 1] - v0[:, 1]
    uyv1 = v0[:, 0] - v2[:, 0]
    uxv2 = v0[:, 1] - v1[:, 1]
    uyv2 = v1[:, 0] - v0[:, 0]

    c0 = torch.sqrt(uxv0**2 + uyv0**2)
    c1 = torch.sqrt(uxv1**2 + uyv1**2)
    c2 = torch.sqrt(uxv2**2 + uyv2**2)
    s = 0.5 * (c0 + c1 + c2)
    area2 = 2 * torch.sqrt(s * (s - c0) * (s - c1) * (s - c2))

    v00 = (af * uxv0 * uxv0 + 2 * bf * uxv0 * uyv0 + gf * uyv0 * uyv0) / area2
    v11 = (af * uxv1 * uxv1 + 2 * bf * uxv1 * uyv1 + gf * uyv1 * uyv1) / area2
    v22 = (af * uxv2 * uxv2 + 2 * bf * uxv2 * uyv2 + gf * uyv2 * uyv2) / area2
    v01 = (
        af * uxv1 * uxv0 + bf * uxv1 * uyv0 + bf * uxv0 * uyv1 + gf * uyv1 * uyv0
    ) / area2
    v12 = (
        af * uxv2 * uxv1 + bf * uxv2 * uyv1 + bf * uxv1 * uyv2 + gf * uyv2 * uyv1
    ) / area2
    v20 = (
        af * uxv0 * uxv2 + bf * uxv0 * uyv2 + bf * uxv2 * uyv0 + gf * uyv0 * uyv2
    ) / area2

    # create symmetric A
    i = torch.column_stack((t0, t1, t2, t0, t1, t1, t2, t2, t0)).reshape(-1)
    j = torch.column_stack((t0, t1, t2, t1, t0, t2, t1, t0, t2)).reshape(-1)
    dat = torch.column_stack((v00, v11, v22, v01, v01, v12, v12, v20, v20)).reshape(-1)
    nv = v.shape[0]
    A = torch.zeros((nv, nv), dtype=torch.cfloat)
    A.index_put_((i, j), dat.type(torch.cfloat), accumulate=True)

    # convert target to complex and set b vector
    targetc = target[:, 0] + 1j * target[:, 1]
    b = -A[:, landmark] * targetc
    b[landmark] = targetc

    # set all rows and columns in landmark to zero and put diag 1
    mrow, mcol = torch.nonzero(A[landmark, :], as_tuple=True)
    A[landmark[mrow], mcol] = 0
    mrow, mcol = torch.nonzero(A[:, landmark], as_tuple=True)
    A[mrow,landmark[mcol]] = 0
    # Aones = torch.zeros((nv, nv))
    # Aones.index_put_((landmark, landmark), torch.ones(landmark.shape[0]), accumulate=True)
    A[landmark, landmark] += 1

    x = symmetric_solve(A, b)

    mapping = x.squeeze()
    mapping = torch.column_stack((torch.real(mapping), torch.imag(mapping)))
    return mapping

## Conformal map

In [86]:
def spherical_conformal_mapping(v, t):
    """Linear method for computing spherical conformal map of a genus-0 closed surface.

    Parameters
    ----------
    v, t : Triangle mesh.

    Returns
    -------
    mapping: tensor
        Vertex coordinates (3d) of the spherical conformal parameterization.
    """
    # Check whether the input mesh is spherical topology (genus-0)
    if euler(v, t) != 2:
        print("ERROR: The mesh is not a genus-0 closed surface ..")
        raise ValueError("not genus-0")

    # Find the most regular triangle as the "big triangle"
    tquals = tria_qualities(v, t)
    bigtri = torch.argmax(tquals)
    # print(bigtri, tquals[bigtri])
    # If it turns out that the spherical parameterization result is homogeneous
    # you can try to change bigtri to the id of some other triangles with good quality

    # North pole step: Compute spherical map
    #   by solving laplace equation on a big triangle
    nv = torch.tensor(v.shape[0])
    M = computeAB(v, t)[0]

    p0 = t[bigtri, 0]
    p1 = t[bigtri, 1]
    p2 = t[bigtri, 2]
    fixed = t[bigtri, :]

    # set all rows and cols with fixed vidxs to zero
    # and set diag entries to 1
    mrow, mcol = torch.nonzero(M[fixed, :], as_tuple=True)
    mval = M[fixed[mrow], mcol]
    M[fixed[mrow], mcol] = 0
    M[fixed, fixed] = 1

    # find embedding of the bigtria (boundary condition later)
    # arbitrarily set first two points
    x0, y0, x1, y1 = torch.tensor(0), torch.tensor(0), torch.tensor(1), torch.tensor(0)
    a = v[p1, :] - v[p0, :]
    b = v[p2, :] - v[p0, :]
    sin1 = torch.linalg.norm(torch.cross(a, b)) / (torch.linalg.norm(a) * torch.linalg.norm(b))
    ori_h = torch.linalg.norm(b) * sin1
    ratio = torch.sqrt(((x0 - x1) ** 2 + (y0 - y1) ** 2)) / torch.linalg.norm(a)
    y2 = ori_h * ratio  # compute the coordinates of the third vertex
    x2 = torch.sqrt(torch.linalg.norm(b) ** 2 * ratio**2 - y2**2)
    # should be around (0.5, sqrt(3)/2) if we found an equilateral bigtri

    # Solve the Laplace equation to obtain a harmonic map
    c = torch.zeros((nv, 1))
    c[p0], c[p1], c[p2] = x0, x1, x2
    d = torch.zeros((nv, 1))
    d[p0], d[p1], d[p2] = y0, y1, y2
    rhs = torch.empty(c.shape[:-1], dtype=torch.cfloat)
    rhs.real = c.flatten()
    rhs.imag = d.flatten()

    z = symmetric_solve(M.to(torch.cfloat), rhs)
    z = z.squeeze()
    z = z - torch.mean(z, dim=0)

    # inverse stereographic projection (not scaled well)
    S = inverse_stereographic(z)

    # Find optimal big triangle size
    w = torch.empty(S.shape[:-1], dtype=torch.cfloat)
    w.real = (S[:, 0] / (1 + S[:, 2])).flatten()
    w.imag = (S[:, 1] / (1 + S[:, 2])).flatten()

    # find the index of the southernmost triangle
    index = torch.argsort(
        torch.abs(z[t[:, 0]]) + torch.abs(z[t[:, 1]]) + torch.abs(z[t[:, 2]])
    )
    inner = index[0]
    if inner == bigtri:
        inner = index[1]

    # Compute the size of the northern most and the southern most triangles
    NorthTriSide = (
        torch.abs(z[t[bigtri, 0]] - z[t[bigtri, 1]])
        + torch.abs(z[t[bigtri, 1]] - z[t[bigtri, 2]])
        + torch.abs(z[t[bigtri, 2]] - z[t[bigtri, 0]])
    ) / 3.0

    SouthTriSide = (
        torch.abs(w[t[inner, 0]] - w[t[inner, 1]])
        + torch.abs(w[t[inner, 1]] - w[t[inner, 2]])
        + torch.abs(w[t[inner, 2]] - w[t[inner, 0]])
    ) / 3.0

    # rescale to get the best distribution
    z = z * torch.sqrt(NorthTriSide * SouthTriSide) / NorthTriSide

    # inverse stereographic projection (now distributed well)
    S = inverse_stereographic(z)

    if torch.isnan(torch.sum(S)):
        raise ValueError("Error: projection contains nan value(s)!")
        # could revert to spherical tutte map here

    # South pole step
    idx = torch.argsort(S[:, 2])

    # number of points near the south pole to be fixed
    # simply set it to be 1/10 of the total number of vertices (can be changed)
    # In case the spherical parameterization is not good, change 10 to
    # something smaller (e.g. 2)
    fixnum = torch.maximum(torch.round(nv*0.1), torch.tensor(3))
    fixed = idx[0 : torch.minimum(nv, fixnum).type(torch.int)]

    # south pole stereographic projection
    P = torch.column_stack(
        (S[:, 0] / (1 + S[:, 2]), S[:, 1] / (1 + S[:, 2]), torch.zeros(nv))
    )

    # compute the Beltrami coefficient (value per triangle)
    mu = beltrami_coefficient(P, t, v.float())

    # compose the map with another quasi-conformal map to cancel the distortion
    mapping = linear_beltrami_solver(P, t, mu, fixed, P[fixed, :])

    if torch.isnan(torch.sum(mapping)):
        # if the result has NaN entries, then most probably the number of
        # boundary constraints is not large enough
        # increase the number of boundary constrains and run again
        print("South pole compsed map has nan value(s)!")
        fixnum = fixnum * 5  # again, this number can be changed
        fixed = idx[0 : torch.minimum(nv, fixnum).type(torch.int)]
        mapping = linear_beltrami_solver(P, t, mu, fixed, P[fixed, :])
        if torch.isnan(torch.sum(mapping)):
            mapping = P  # use the old result

    # inverse south pole stereographic projection
    mapping = inverse_stereographic(mapping)
    return mapping

## Loading mesh

In [88]:
! pip install meshio
import meshio

mesh = meshio.read('dog.obj')
vert = mesh.points
face = mesh.cells



In [89]:
# Send vert and face to tensor
v = torch.from_numpy(vert)
v.requires_grad = True
t = torch.from_numpy(face[0].data)

In [92]:
s = spherical_conformal_mapping(v, t)
print(euler(s, t)==2) # Check if it's still genus zero

South pole compsed map has nan value(s)!
True


In [97]:
# Mesh plotting function
def plot_mesh(v, t, fig):
  tri_vertices = v[t]
  Xe = []
  Ye = []
  Ze = []

  for T in tri_vertices:
      Xe += [T[k%3][0] for k in range(4)] + [ None]
      Ye += [T[k%3][1] for k in range(4)] + [ None]
      Ze += [T[k%3][2] for k in range(4)] + [ None]


  fig.add_trace(go.Scatter3d(x=Xe,
                      y=Ye,
                      z=Ze,
                      mode='lines',
                      name='',
                      line=dict(color= 'rgb(40,40,40)', width=0.5)));
  fig.show()

In [98]:
fig = go.Figure()
plot_mesh(vert, t, fig)

In [99]:
# Spherical conformal mapped mesh
fig = go.Figure()
plot_mesh(s.detach().numpy(), t, fig)

## Appendix: Spectral basis

In [100]:
! pip install torch_geometric
from torch_geometric.utils.mesh_laplacian import get_mesh_laplacian
ind, w = get_mesh_laplacian(v, t.T)
M = torch.zeros((len(v), len(v)))
M[ind.T[:,0], ind.T[:,1]] = w.float()



In [106]:
L, V = torch.lobpcg(-M, k=4, largest=False)
vv = V[:, 1:] / torch.linalg.norm(V[:, 1:], dim=-1)[:, None]
fig = go.Figure()
plot_mesh(vv.detach().numpy(), t, fig)

## Additional human body mesh data if needed

In [None]:
# ! wget https://download.is.tue.mpg.de/faust/MPI-FAUST.zip -P ./FAUST/raw
!pip install openmesh
from torch_geometric.datasets import FAUST
train = FAUST(root='./FAUST',train=True)