In [1]:
import torch
from math import factorial
from pytorch3d.io import load_obj

In [2]:
# Alternative source: read the mesh directly from the OBJ file with pytorch3d:
#   m = load_obj("model_normalized.obj", load_textures=False)
#   v_matrix, f_matrix = m[0], m[1].verts_idx
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
m = torch.load('mesh.pt')
v_matrix, f_matrix = m['verts'], m['faces']

In [3]:
def _coefficient_table(c_vertex, max_m, dtype):
    """Build the C table: C_ijk = c_x**i * c_y**j * c_z**k * (i+j+k)! / (i! j! k!).

    Only admissible triples are stored: each index in ``range(max_m)`` and
    ``i + j + k <= max_m``. Keys are inserted in lexicographic (i, j, k) order,
    which ``_vertex_recurrence`` relies on. ``c_vertex`` holds one 3-vector per
    face; every value is a per-face tensor cast to ``dtype``.
    """
    table = {}
    for i in range(max_m):
        for j in range(max_m):
            for k in range(max_m):
                if (i + j + k) <= max_m:
                    multinomial = factorial(i + j + k) / (factorial(i) * factorial(j) * factorial(k))
                    table[i, j, k] = ((c_vertex[:, 0] ** i) * (c_vertex[:, 1] ** j)
                                      * (c_vertex[:, 2] ** k) * multinomial).to(dtype)
    return table


def _vertex_recurrence(vertex, source, dtype):
    """Solve T_ijk = v_x*T_(i-1)jk + v_y*T_i(j-1)k + v_z*T_ij(k-1) + source_ijk.

    Terms with a negative index are zero, so they are skipped explicitly rather
    than read through Python's negative-index wraparound. Entries are built
    out-of-place in a fresh dict (no slice writes into a buffer that autograd
    has already saved), which keeps the result differentiable. ``source`` must
    list its keys in lexicographic order so that every predecessor exists when
    it is needed; ``vertex`` holds one 3-vector per face.
    """
    table = {}
    for (i, j, k), src in source.items():
        acc = 0
        if i > 0:
            acc = acc + vertex[:, 0] * table[i - 1, j, k]
        if j > 0:
            acc = acc + vertex[:, 1] * table[i, j - 1, k]
        if k > 0:
            acc = acc + vertex[:, 2] * table[i, j, k - 1]
        # Source term is added last, matching the original summation order.
        table[i, j, k] = (acc + src).to(dtype)
    return table


def M_ijk_torch(f_matrix,v_matrix,max_m,volume_winding_test=False):
    """Sum the per-face monomial moment terms of a triangle mesh.

    Each face (A, B, C) together with the origin spans a tetrahedron. For every
    admissible (i, j, k) -- each index in ``range(max_m)`` and ``i + j + k <= max_m``
    -- the per-face term is ``i! j! k! / (i+j+k+3)! * det([A; B; C]) * S_ijk``,
    where S comes from the recurrences C -> D (vertex B) -> S (vertex A).
    Summed over the faces of a closed, consistently wound mesh this is the
    volume moment ``integral of x^i y^j z^k dV``; entry (0, 0, 0) is the signed
    volume, which the winding check below relies on.

    Args:
        f_matrix: integer tensor of vertex indices, one row of 3 per face.
        v_matrix: tensor of vertex coordinates, one row of 3 per vertex.
        max_m: length of each moment axis (see the NOTE on the cutoff below).
        volume_winding_test: internal flag; when True the orientation check is
            skipped (set by the recursive call that computes the volume).

    Returns:
        float32 tensor of shape ``(max_m, max_m, max_m)`` on ``v_matrix``'s
        device; non-admissible entries are 0. Differentiable w.r.t. ``v_matrix``.

    Side effects:
        Prints 'Corrected face ordering' when the total signed volume is
        negative; every face's winding is then flipped (a global flip only --
        meshes with mixed winding are not repaired).

    NOTE(review): the cutoff keeps some order-``max_m`` moments, e.g.
    (1, 0, max_m - 1), but not (max_m, 0, 0), because each index stops at
    ``max_m - 1``. Kept unchanged for compatibility; confirm whether
    ``i + j + k < max_m`` or a ``max_m + 1`` grid was intended.
    """
    # Probe the total signed volume (M_000 via max_m=1). If it is negative the
    # faces are wound inward, so swap two vertices of every face.
    order = [0, 1, 2]
    if (not volume_winding_test) and torch.all(M_ijk_torch(f_matrix, v_matrix, 1, volume_winding_test=True) < 0):
        order = [0, 2, 1]
        print('Corrected face ordering')

    num_faces = len(f_matrix)
    dtype = torch.float32  # output precision kept identical to the original implementation
    device = v_matrix.device  # follow the input's device instead of forcing CPU

    # ABC[f] stacks the three vertices of face f as rows A, B, C.
    ABC = v_matrix[f_matrix[:, order]]
    dets = torch.det(ABC)

    C = _coefficient_table(ABC[:, 2], max_m, dtype)
    D = _vertex_recurrence(ABC[:, 1], C, dtype)
    S = _vertex_recurrence(ABC[:, 0], D, dtype)

    zeros = torch.zeros(num_faces, dtype=dtype, device=device)
    per_face = []
    for i in range(max_m):
        for j in range(max_m):
            for k in range(max_m):
                if (i, j, k) in S:
                    coef = (factorial(i) * factorial(j) * factorial(k)) / factorial(i + j + k + 3)
                    per_face.append((coef * (dets * S[i, j, k])).to(dtype))
                else:
                    per_face.append(zeros)

    if not per_face:
        # max_m == 0: nothing to stack; return the same empty (0, 0, 0) result.
        return torch.zeros((max_m, max_m, max_m), dtype=dtype, device=device)

    # Assemble (faces, max_m, max_m, max_m) out-of-place with stack instead of
    # slice assignment, so autograd can backpropagate through the result.
    M_tensor = torch.stack(per_face, dim=1).reshape(num_faces, max_m, max_m, max_m)
    return torch.sum(M_tensor, 0)
    

# M_ijk_torch(torch.tensor(f_matrix,dtype=torch.long),torch.tensor(v_matrix),3)

In [4]:
def moment_loss_torch(f_matrix_1, v_matrix_1, f_matrix_2, v_matrix_2, max_m=10):
    """Distance between the moment tensors of two meshes.

    Both meshes are passed through ``M_ijk_torch`` with the same ``max_m``; the
    result is ``torch.linalg.norm`` of the element-wise difference (the 2-norm
    of the flattened tensor), returned as a 0-d tensor.
    """
    moments_a = M_ijk_torch(f_matrix_1, v_matrix_1, max_m)
    moments_b = M_ijk_torch(f_matrix_2, v_matrix_2, max_m)
    return torch.linalg.norm(moments_a - moments_b)

In [5]:
# Sanity check: a mesh compared with itself should give zero loss.
moment_loss_torch(f_matrix, v_matrix, f_matrix, v_matrix, max_m=10)

tensor(0.)

In [6]:
# Display the (10, 10, 10) moment tensor of the loaded mesh.
M_ijk_torch(f_matrix, v_matrix, max_m=10)

tensor([[[ 5.9870e-02, -4.1261e-05,  5.8025e-04, -1.1728e-06,  1.0728e-05,
          -3.5438e-08,  2.4527e-07, -1.1064e-09,  6.2883e-09, -3.5317e-11],
         [-1.6944e-03,  1.1383e-06, -2.1911e-05,  4.3252e-08, -4.7772e-07,
           1.5162e-09, -1.2175e-08,  5.1949e-11, -3.3739e-10,  1.7620e-12],
         [ 3.3326e-04, -2.3302e-07,  3.2194e-06, -6.5359e-09,  6.0155e-08,
          -1.9666e-10,  1.4022e-09, -6.1469e-12,  3.6868e-11,  0.0000e+00],
         [-1.9756e-05,  1.3802e-08, -2.5480e-07,  5.1459e-10, -5.4768e-09,
           1.7683e-11, -1.3772e-10,  5.9463e-13,  0.0000e+00,  0.0000e+00],
         [ 3.7181e-06, -2.5943e-09,  3.7998e-08, -7.7019e-11,  7.3872e-10,
          -2.4065e-12,  1.7690e-11,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [-2.8839e-07,  2.0751e-10, -3.8626e-09,  7.8989e-12, -8.4164e-11,
           2.7380e-13,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 5.1920e-08, -3.6796e-11,  5.7844e-10, -1.1786e-12,  1.1825e-11,
           0.0000e+