In [1]:
import numpy as np
import ot
import gwb as gwb
import lgw
from gwb import GM as gm
from sklearn.decomposition import PCA
import trimesh
from tqdm import trange
import matplotlib.cm as cm
import os

import open3d as o3d

# Load FAUST and DEFORM (DF) data

In [2]:
def make_o3dmesh_from_node_tri_col(X,tri,col):
    """Build an Open3D triangle mesh carrying per-vertex colours.

    Parameters
    ----------
    X : array-like
        Vertex coordinates, wrapped in ``o3d.utility.Vector3dVector``
        (Open3D expects an N x 3 float array).
    tri : array-like
        Triangle vertex indices, wrapped in ``o3d.utility.Vector3iVector``.
    col : array-like
        Per-vertex RGB colours, wrapped in ``o3d.utility.Vector3dVector``.

    Returns
    -------
    o3d.geometry.TriangleMesh
        A new mesh with vertices, triangles and vertex colours set.
    """
    as_vec3d = o3d.utility.Vector3dVector
    out_mesh = o3d.geometry.TriangleMesh()
    out_mesh.vertices = as_vec3d(X)
    out_mesh.triangles = o3d.utility.Vector3iVector(tri)
    out_mesh.vertex_colors = as_vec3d(col)
    return out_mesh

In [3]:
# Decimation target handed to mesh.simplify_quadratic_decimation when the
# DEFORM meshes are loaded (presumably a target face count — confirm against
# the installed trimesh version).
sqd = 12_000

In [4]:
# Input folders, relative to the notebook. Both keep their trailing "/" because
# the loaders append file names with plain string concatenation.
filepath_FAUST, filepath_DF = (
    "../data/MPI-FAUST/training/registrations/",  # FAUST registration meshes
    "../data/DEFORM/",                            # Mesh Deformation dataset meshes
)

In [5]:
#LOAD FAUST: training registrations tr_reg_000.ply ... tr_reg_099.ply
n_FAUST = 100
Nodes_FAUST = []
Tris_FAUST = []
for i in range(n_FAUST):
    filepath = filepath_FAUST + "tr_reg_{:03d}.ply".format(i)
    # open3d only warns and returns an empty mesh when a file is missing, which
    # would surface much later as a ragged/empty array; fail loudly instead.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(filepath)
    o3d_mesh = o3d.io.read_triangle_mesh(filepath)
    mesh = trimesh.Trimesh(vertices=o3d_mesh.vertices, faces=o3d_mesh.triangles)
    Nodes_FAUST.append(np.array(mesh.vertices))
    Tris_FAUST.append(np.array(mesh.faces))

# Stacking into regular arrays assumes every registration has the same vertex
# and face count; otherwise np.array cannot build a regular stacked array.
Nodes_FAUST = np.array(Nodes_FAUST)
Tris_FAUST = np.array(Tris_FAUST)

In [6]:
#LOAD DEFORM: classes and the number of poses loaded per class
class_names_DF = ["camel", "cat", "elephant", "face", "head", "horse", "lion"]
lengths_DF = [11, 10, 11, 10, 10, 11, 10]

# Pose numbers 0 .. n-1 within each class (used to build the file names).
lengths_by_class_DF = {
    name: np.arange(n_poses)
    for name, n_poses in zip(class_names_DF, lengths_DF)
}

# Global indices into Nodes_DF / Tris_DF: classes are appended back to back in
# the order of class_names_DF, so class k starts after all earlier classes.
idxs_by_class_DF = {
    name: sum(lengths_DF[:k]) + np.arange(n_poses)
    for k, (name, n_poses) in enumerate(zip(class_names_DF, lengths_DF))
}

# Read "<class>-poses/<class>-NN.obj" for every class and pose, decimate with
# `sqd`, and append class by class so idxs_by_class_DF[c] indexes these lists.
# Kept as lists (not stacked), since the decimated meshes need not share a size.
Nodes_DF = []
Tris_DF = []
for class_name in class_names_DF:
    for i in lengths_by_class_DF[class_name]:
        filepath = filepath_DF + "{0}-poses/{0}-{1}.obj".format(class_name,str(i).zfill(2))
        # open3d only warns and returns an empty mesh for a missing file; fail loudly.
        if not os.path.isfile(filepath):
            raise FileNotFoundError(filepath)
        o3d_mesh = o3d.io.read_triangle_mesh(filepath)
        mesh = trimesh.Trimesh(vertices=o3d_mesh.vertices, faces=o3d_mesh.triangles)
        # NOTE(review): newer trimesh releases name this simplify_quadric_decimation;
        # confirm against the pinned trimesh version.
        mesh = mesh.simplify_quadratic_decimation(sqd)

        Nodes_DF.append(np.array(mesh.vertices))
        Tris_DF.append(np.array(mesh.faces))

# Create Directory for Output Meshes

In [7]:
# Output folder for every .ply written below. exist_ok keeps the cell
# idempotent without a separate exists()/mkdir() check (and its race); if
# "./3d" exists as a regular file this now fails here rather than at write time.
os.makedirs("./3d", exist_ok=True)

# 3D 2-Interpolations between FAUST subjects

In [8]:
# Pairs of FAUST registrations to interpolate between, and the number of
# intermediate shapes produced per pair.
idxs_FAUST = [[27,29],[44,54]]
N_interpolation = 5

for i, j in idxs_FAUST:
    print("----------------")
    print("FAUST {0}  <-> {1}".format(i,j))
    print("----------------")

    # Metric-measure spaces on the two surfaces.
    # NOTE(review): the DEFORM and four-subject cells pass squared=False to gm,
    # this cell keeps gm's default — confirm the difference is intended.
    X = gm(X=Nodes_FAUST[i], Tris=Tris_FAUST[i], mode="surface", gauge_mode="djikstra")
    Y = gm(X=Nodes_FAUST[j], Tris=Tris_FAUST[j], mode="surface", gauge_mode="djikstra")
    print("GM spaces generated!")

    # Euclidean diameters of the raw vertex sets, for reference.
    print("Diam X: {0}".format(np.max(ot.dist(Nodes_FAUST[i],metric="euclidean"))))
    print("Diam Y: {0}".format(np.max(ot.dist(Nodes_FAUST[j],metric="euclidean"))))

    # A single tb iteration; its fourth return value is not used here.
    idxs, meas, Ps, _ = gwb.tb(0, [X, Y], numItermaxEmd=500000)
    print("Barycenter computed!")

    lgw_matrix = lgw.LGW_via_idxs([X, Y], idxs, meas)
    print("gw_dist: {0}".format(lgw_matrix[1, 0]))

    # Intermediate shapes with their colour values and PCA scores; one triangle
    # set is generated and shared by every intermediate shape.
    cX, cY, bary_embs, bary_embs_3d, bary_c, pca_scores = gwb.interpolate_two(
        X, Y, idxs, meas, Ps[1], N_interpolation)
    bary_tris = gwb.generate_triangles_from_single(X, idxs, 0)
    print("Surfaces constructed!")

    bary_diams = [np.max(ot.dist(emb, metric="euclidean")) for emb in bary_embs]
    print("Diameters of Barys: {0}".format(bary_diams))
    print("PCA Scores: {0}".format(pca_scores))

    # (file name, vertices, faces, colour values): X, Y, then each interpolant,
    # in the order they are written.
    outputs = [
        ("./3d/FAUST_{0}_{1}_X.ply".format(i,j), X.X, X.Tris, cX),
        ("./3d/FAUST_{0}_{1}_Y.ply".format(i,j), Y.X, Y.Tris, cY),
    ]
    outputs.extend(
        ("./3d/FAUST_{0}_{1}_B_{2}.ply".format(i,j,k), bary_embs_3d[k], bary_tris, bary_c[k])
        for k in range(N_interpolation)
    )
    for fname, verts, faces, values in outputs:
        # normalise by the maximum value, map through viridis, drop the alpha channel
        rgb = cm.viridis(values / np.max(values))[:, :3]
        o3d.io.write_triangle_mesh(filename=fname,
                                   mesh=make_o3dmesh_from_node_tri_col(verts, faces, rgb))
    print("Outputs saved!")

----------------
FAUST 27  <-> 29
----------------
GM spaces generated!
Diam X: 1.8200711879204465
Diam Y: 2.0612168142413303
Barycenter computed!
gw_dist: 0.06573626096634741
Surfaces constructed!
Diameters of Barys: [1.8619977559553098, 1.903303924688805, 1.9438334247963707, 1.9835628802893566, 2.022762523799196]
PCA Scores: [0.15336821478620918, 0.19938760850928358, 0.19735998501086469, 0.1740279213156925, 0.13008671856960966]
Outputs saved!
----------------
FAUST 44  <-> 54
----------------
GM spaces generated!
Diam X: 1.6176458713210502
Diam Y: 1.6464366764613665
Barycenter computed!
gw_dist: 0.055318974712936134
Surfaces constructed!
Diameters of Barys: [1.6210192515512465, 1.6257134376175062, 1.6303941083678193, 1.6350613798725913, 1.6401958900504618]
PCA Scores: [0.05642603527757578, 0.07232227896048742, 0.07690843165011564, 0.07213720817291962, 0.05656587686021616]
Outputs saved!


# 3D 2-Interpolations between DF animals

In [9]:
# Pairs of DEFORM meshes (global indices) to interpolate between:
# two lion poses, and a horse vs. a camel pose.
idxs_DF = [[idxs_by_class_DF["lion"][-1], idxs_by_class_DF["lion"][-2]],
           [idxs_by_class_DF["horse"][3], idxs_by_class_DF["camel"][3]],
          ]
N_interpolation = 5

for i, j in idxs_DF:
    print("----------------")
    print("DF {0}  <-> {1}".format(i,j))
    print("----------------")

    # Metric-measure spaces on the two surfaces (an earlier variant additionally
    # passed xi="surface_uniform" to gm).
    X = gm(X=Nodes_DF[i], Tris=Tris_DF[i], mode="surface", gauge_mode="djikstra", squared=False)
    Y = gm(X=Nodes_DF[j], Tris=Tris_DF[j], mode="surface", gauge_mode="djikstra", squared=False)
    print("GM spaces generated!")

    # Euclidean diameters of the raw vertex sets, for reference.
    print("Diam X: {0}".format(np.max(ot.dist(Nodes_DF[i],metric="euclidean"))))
    print("Diam Y: {0}".format(np.max(ot.dist(Nodes_DF[j],metric="euclidean"))))

    # A single tb iteration; its fourth return value is not used here.
    idxs, meas, Ps, _ = gwb.tb(0, [X, Y], numItermaxEmd=500000)
    print("Barycenter computed!")

    lgw_matrix = lgw.LGW_via_idxs([X, Y], idxs, meas)
    print("gw_dist: {0}".format(lgw_matrix[1, 0]))

    # Intermediate shapes with their colour values and PCA scores; one triangle
    # set is generated and shared by every intermediate shape.
    cX, cY, bary_embs, bary_embs_3d, bary_c, pca_scores = gwb.interpolate_two(
        X, Y, idxs, meas, Ps[1], N_interpolation)
    bary_tris = gwb.generate_triangles_from_single(X, idxs, 0)
    print("Surfaces constructed!")

    bary_diams = [np.max(ot.dist(emb, metric="euclidean")) for emb in bary_embs]
    print("Diameters of Barys: {0}".format(bary_diams))
    print("PCA Scores: {0}".format(pca_scores))

    # (file name, vertices, faces, colour values): X, Y, then each interpolant,
    # in the order they are written.
    outputs = [
        ("./3d/DF_{0}_{1}_X.ply".format(i,j), X.X, X.Tris, cX),
        ("./3d/DF_{0}_{1}_Y.ply".format(i,j), Y.X, Y.Tris, cY),
    ]
    outputs.extend(
        ("./3d/DF_{0}_{1}_B_{2}.ply".format(i,j,k), bary_embs_3d[k], bary_tris, bary_c[k])
        for k in range(N_interpolation)
    )
    for fname, verts, faces, values in outputs:
        # normalise by the maximum value, map through viridis, drop the alpha channel
        rgb = cm.viridis(values / np.max(values))[:, :3]
        o3d.io.write_triangle_mesh(filename=fname,
                                   mesh=make_o3dmesh_from_node_tri_col(verts, faces, rgb))
    print("Outputs saved!")

----------------
DF 72  <-> 71
----------------
GM spaces generated!
Diam X: 0.6069540744631324
Diam Y: 0.6258420332203521
Barycenter computed!
gw_dist: 0.029316152915260237
Surfaces constructed!
Diameters of Barys: [0.6099247708264421, 0.6130630961632622, 0.616185437743284, 0.6193334538001614, 0.6225488207045161]
PCA Scores: [0.02491672135989406, 0.03433129741184591, 0.03820806851442808, 0.03676780728318005, 0.029155259421088438]
Outputs saved!
----------------
DF 55  <-> 3
----------------
GM spaces generated!
Diam X: 1.0991010854365595
Diam Y: 1.1741971211637272
Barycenter computed!
gw_dist: 0.149645512523082
Surfaces constructed!
Diameters of Barys: [1.0722014468167407, 1.0930231388865677, 1.1135359738672272, 1.1339154637810602, 1.154221544999327]
PCA Scores: [0.1099636618596759, 0.14692569271688327, 0.17608465045956762, 0.17915426041626475, 0.1420583454749126]
Outputs saved!


# 3D Interpolation between 4 FAUST subjects

In [10]:
#pick four human subjects
idxs_input = [7,13,69,36]
# Tag used in every output file name. It is derived from idxs_input so the names
# stay correct when the subjects change (currently "7_13_69_36").
tag_input = "_".join(str(idx) for idx in idxs_input)

#generate mm spaces
Xs = [gm(X=Nodes_FAUST[i],Tris=Tris_FAUST[i],mode="surface",gauge_mode="djikstra",squared=False) for i in idxs_input]
print("GM spaces generated!")

#iterate tb; `bary` starts as 0, the same first argument the two-shape cells pass to tb
n_its_tb = 3
bary = 0
for it in trange(n_its_tb):
    if it == 0:
        init_Ps = None
    else:
        # presumably warm-starts tb from the previous iteration's result — see gwb
        init_Ps = gwb.create_backwards_initplans(Xs,ref_idx,idxs,meas)
    bary_prev = bary
    idxs, meas, Ps, ref_idx = gwb.tb(bary,Xs,numItermaxEmd=500000,init_Ps=init_Ps)
    bary = gwb.bary_from_tb(Xs,idxs,meas)

    # loss of the *previous* barycenter, evaluated with the plans Ps just computed
    gwbl_prev = gwb.gwb_loss(bary_prev,Xs,Ps)
    print("GWBL (prev): {0}".format(gwbl_prev))

    #surface embedding of the central barycenter
    # colour function on Xs[0]: distance of each vertex to the vertex nearest the origin
    X0 = Xs[0].X
    cX = np.linalg.norm(X0 - X0[np.argmin(np.linalg.norm(X0, axis=1))], axis=1)
    # NOTE(review): each P is row-normalised, transposed, then applied to cX, which
    # requires P.shape[0] == len(cX) for every P in Ps; confirm this matches the
    # orientation of the plans returned by gwb.tb.
    cXs_center = [(P.T / np.sum(P,axis=1)).dot(cX) for P in Ps]
    bary_emb,bary_emb_3d,bary_tri,pca_s,bary_c = gwb.single_interpolate_four(Xs,idxs,meas,cXs_center,w=None)
    print("PCA Score (center): {0}".format(pca_s))

    col = cm.viridis(bary_c / np.max(bary_c))[:,:3]
    mesh = make_o3dmesh_from_node_tri_col(bary_emb_3d,bary_tri,col)
    o3d.io.write_triangle_mesh(filename="./3d/FAUST_four_{0}_center_per_it_{1}.ply".format(tag_input, it),mesh=mesh)

#generate 3d embeddings on an n_grid x n_grid grid, using the last tb iteration
n_grid = 5
cXs,bary_embs,bary_embs_3d,bary_cs,bary_tris,pcas = gwb.interpolate_four(Xs,idxs,meas,Ps,n_grid = n_grid)

print("PCA Scores:")
print(pcas)

#save the four input meshes, coloured by the cXs returned from interpolate_four
for i in range(len(Xs)):
    X = Xs[i]
    cX = cXs[i]
    col = cm.viridis(cX / np.max(cX))[:,:3]
    mesh = make_o3dmesh_from_node_tri_col(X.X,X.Tris,col)
    o3d.io.write_triangle_mesh(filename="./3d/FAUST_four_{0}_Xs_{1}.ply".format(tag_input, idxs_input[i]),mesh=mesh)

#save every grid barycenter
for i in range(n_grid):
    for j in range(n_grid):
        col = cm.viridis(bary_cs[i,j] / np.max(bary_cs[i,j]))[:,:3]
        mesh = make_o3dmesh_from_node_tri_col(bary_embs_3d[i,j],bary_tris[i,j],col)
        o3d.io.write_triangle_mesh(filename="./3d/FAUST_four_{0}_B_{1}_{2}.ply".format(tag_input,i,j),mesh=mesh)
print("Outputs saved!")

GM spaces generated!


  0%|                                                     | 0/3 [00:00<?, ?it/s]

GWBL (prev): 0.005306184587758289


 33%|█████████████▋                           | 1/3 [38:01<1:16:02, 2281.37s/it]

PCA Score (center): 0.13536133104263748
GWBL (prev): 0.0036027841539943337


 67%|████████████████████████████▋              | 2/3 [56:29<26:31, 1591.42s/it]

PCA Score (center): 0.1353376072687755
GWBL (prev): 0.0036016137175827944


100%|█████████████████████████████████████████| 3/3 [1:11:24<00:00, 1428.21s/it]

PCA Score (center): 0.13533247875591822





PCA Scores:
[[2.38191330e-16 1.62182078e-01 1.49418228e-01 1.14732696e-01
  7.09976039e-17]
 [1.18548619e-01 1.62398749e-01 1.46318415e-01 1.21552600e-01
  6.96010236e-02]
 [1.39336948e-01 1.49823139e-01 1.35332479e-01 1.13733178e-01
  6.99667413e-02]
 [1.21579814e-01 1.27410733e-01 1.16763681e-01 9.80612739e-02
  5.46945640e-02]
 [1.07318217e-16 8.34299610e-02 8.29792334e-02 6.96381969e-02
  1.09673368e-16]]
Outputs saved!
