In [1]:
import numpy as np
import json
import os
import copy
import pickle

import mesh_sampling
import trimesh
from shape_data import ShapeData

from autoencoder_dataset import autoencoder_dataset
from torch.utils.data import DataLoader

from spiral_utils import get_adj_trigs, generate_spirals
from models import SpiralAutoencoder
from train_funcs import train_autoencoder_dataloader
from test_funcs import test_autoencoder_dataloader

import torch
from tensorboardX import SummaryWriter

from sklearn.metrics.pairwise import euclidean_distances
# --- Global configuration ----------------------------------------------------
# Mesh backend handed to ShapeData and the spiral utilities further down.
meshpackage = 'trimesh' # 'mpi-mesh', 'trimesh'
root_dir = 'Data'

dataset = 'COMA'
name = ''

# GPU toggles CUDA usage; device_idx selects which CUDA device to use.
GPU = True
device_idx = 0
# Show which accelerator will be used. Asking for a device name raises when
# CUDA is unavailable, so fall back to a plain 'cpu' label in that case (the
# device-selection cell below applies the same fallback).
torch.cuda.get_device_name(device_idx) if GPU and torch.cuda.is_available() else 'cpu'

'NVIDIA RTX A3000 Laptop GPU'

In [2]:
# --- Experiment configuration and output folders -----------------------------
generative_model = 'autoencoder'
downsample_method = 'COMA_downsample' # choose 'COMA_downsample' or 'meshlab_downsample'


# NOTE(review): this comment originally read "arguments for the DFAUST run",
# but `dataset` is set to 'COMA' in the configuration cell -- confirm which
# experiment these values were tuned for.
reference_mesh_file = os.path.join(root_dir, dataset, 'template', 'template.obj')
downsample_directory = os.path.join(root_dir, dataset,'template', downsample_method)
ds_factors = [4, 4, 4, 4]
step_sizes = [1, 1, 1, 1, 1]
filter_sizes_enc = [[3, 64, 64, 64, 128],[[],[],[],[],[]]]
filter_sizes_dec = [[128, 64, 64, 64, 3],[[],[],[],[],[]]]
dilation_flag = False
dilation = [2, 2, 1, 1, 1] if dilation_flag else None
reference_points = [[3567,4051,4597]] #[[414]]  # [[3567,4051,4597]] used for COMA with 3 disconnected components

# Single dictionary holding every setting read by the cells below; it is also
# written to disk as JSON by the training cell.
args = {'generative_model': generative_model,
        'name': name, 'data': os.path.join(root_dir, dataset, 'preprocessed_identity',name),
        'results_folder':  os.path.join(root_dir, dataset,'results_identity/spirals_'+ generative_model),
        'reference_mesh_file':reference_mesh_file, 'downsample_directory': downsample_directory,
        'checkpoint_file': 'checkpoint',
        'seed':2, 'loss':'l1',
        'batch_size':16, 'num_epochs':250, 'eval_frequency':200, 'num_workers': 4,
        'filter_sizes_enc': filter_sizes_enc, 'filter_sizes_dec': filter_sizes_dec,
        'nz':128, 
        'ds_factors': ds_factors, 'step_sizes' : step_sizes, 'dilation': dilation,
        
        'lr':1e-3, 
        'regularization': 5e-5,         
        'scheduler': True, 'decay_rate': 0.99,'decay_steps':1,  
        'resume': False,
        
        'mode':'train', 'shuffle': True, 'nVal': 100, 'normalization': True}

# Results are grouped per latent size.
args['results_folder'] = os.path.join(args['results_folder'],'latent_'+str(args['nz']))

summary_path = os.path.join(args['results_folder'],'summaries',args['name'])
checkpoint_path = os.path.join(args['results_folder'],'checkpoints', args['name'])
samples_path = os.path.join(args['results_folder'],'samples', args['name'])
prediction_path = os.path.join(args['results_folder'],'predictions', args['name'])

# exist_ok=True makes the cell safely re-runnable and avoids the
# check-then-create race of a separate os.path.exists() test.
for output_dir in (args['results_folder'], summary_path, checkpoint_path,
                   samples_path, prediction_path, downsample_directory):
    os.makedirs(output_dir, exist_ok=True)

In [3]:
# --- Load shape data, down/up-sampling matrices and reference points ---------
np.random.seed(args['seed'])
print("Loading data .. ")

# The per-vertex statistics are cached next to the data. When both files exist
# ShapeData is built with load_flag=False and the cached statistics are
# attached; otherwise it is built with load_flag=True and the statistics it
# exposes are written to the cache.
mean_file = args['data']+'/mean.npy'
std_file = args['data']+'/std.npy'
stats_cached = os.path.exists(mean_file) and os.path.exists(std_file)

shapedata = ShapeData(nVal=args['nVal'],
                      train_file=args['data']+'/train.npy',
                      test_file=args['data']+'/test.npy',
                      reference_mesh_file=args['reference_mesh_file'],
                      normalization = args['normalization'],
                      meshpackage = meshpackage, load_flag = not stats_cached)
if stats_cached:
    shapedata.mean = np.load(mean_file)
    shapedata.std = np.load(std_file)
    shapedata.n_vertex = shapedata.mean.shape[0]
    shapedata.n_features = shapedata.mean.shape[1]
else:
    np.save(mean_file, shapedata.mean)
    np.save(std_file, shapedata.std)

matrices_file = os.path.join(args['downsample_directory'],'downsampling_matrices.pkl')
if not os.path.exists(matrices_file):
    if shapedata.meshpackage == 'trimesh':
        raise NotImplementedError('Rerun with mpi-mesh as meshpackage')
    if downsample_method != 'COMA_downsample':
        # Only the COMA path can generate matrices in this cell; fail with a
        # clear message instead of a NameError on the undefined `M` below.
        raise NotImplementedError(
            "No precomputed matrices at %s and downsample_method %r cannot generate them here"
            % (matrices_file, downsample_method))
    print("Generating Transform Matrices ..")
    M,A,D,U,F = mesh_sampling.generate_transform_matrices(shapedata.reference_mesh, args['ds_factors'])
    # Build the payload before opening the file so a failure here does not
    # leave an empty pickle behind that the next run would try to load.
    M_verts_faces = [(M[i].v, M[i].f) for i in range(len(M))]
    with open(matrices_file, 'wb') as fp:
        pickle.dump({'M_verts_faces':M_verts_faces,'A':A,'D':D,'U':U,'F':F}, fp)
else:
    print("Loading Transform Matrices ..")
    # SECURITY: unpickling can execute arbitrary code -- only load a matrices
    # file that you generated yourself or obtained from a trusted source.
    with open(matrices_file, 'rb') as fp:
        # For a pickle written under Python 2, use pickle.load(fp, encoding='latin1').
        downsampling_matrices = pickle.load(fp)
    # Print only the keys; dumping the whole dict floods the cell output.
    print(sorted(downsampling_matrices.keys()))
    M_verts_faces = downsampling_matrices['M_verts_faces']
    if shapedata.meshpackage == 'mpi-mesh':
        # NOTE(review): `Mesh` is not imported anywhere in the visible part of
        # this notebook, so this branch raises NameError as written --
        # presumably it should come from the mpi-mesh package; confirm and import it.
        M = [Mesh(v=M_verts_faces[i][0], f=M_verts_faces[i][1]) for i in range(len(M_verts_faces))]
    elif shapedata.meshpackage == 'trimesh':
        M = [trimesh.base.Trimesh(vertices=M_verts_faces[i][0], faces=M_verts_faces[i][1], process = False) for i in range(len(M_verts_faces))]
    A = downsampling_matrices['A']
    D = downsampling_matrices['D']
    U = downsampling_matrices['U']
    F = downsampling_matrices['F']

# TODO: also needs an extra check to enforce that the points belong to different
# disconnected components at each hierarchy level.
print("Calculating reference points for downsampled versions..")
# Rebuild the list from its first entry (the reference points on the full
# resolution mesh) so that re-running this cell does not keep appending.
reference_points = [reference_points[0]]
for i in range(len(args['ds_factors'])):
    # For each downsampled mesh, pick the vertex closest to every reference
    # point of the full-resolution mesh M[0].
    if shapedata.meshpackage == 'mpi-mesh':
        dist = euclidean_distances(M[i+1].v, M[0].v[reference_points[0]])
    elif shapedata.meshpackage == 'trimesh':
        dist = euclidean_distances(M[i+1].vertices, M[0].vertices[reference_points[0]])
    reference_points.append(np.argmin(dist,axis=0).tolist())



Loading data .. 
Vertices normalized


NotImplementedError: Rerun with mpi-mesh as meshpackage

In [5]:
# --- Spirals and padded down/up-sampling matrices ----------------------------
# Number of vertices of every mesh in the hierarchy.
if shapedata.meshpackage == 'mpi-mesh':
    sizes = [mesh.v.shape[0] for mesh in M]
elif shapedata.meshpackage == 'trimesh':
    sizes = [mesh.vertices.shape[0] for mesh in M]

Adj, Trigs = get_adj_trigs(A, F, shapedata.reference_mesh, meshpackage = shapedata.meshpackage)

spirals_np, spiral_sizes,spirals = generate_spirals(
    args['step_sizes'],
    M, Adj, Trigs,
    reference_points = reference_points,
    dilation = args['dilation'], random = False,
    meshpackage = shapedata.meshpackage,
    counter_clockwise = True)


def _pad_transform(sparse_mat):
    """Return `sparse_mat` as a dense (1, rows+1, cols+1) array.

    The extra last row and column are zero except for a 1 in the corner
    (presumably the slot of a dummy/padding vertex -- confirm against the model).
    """
    n_rows, n_cols = sparse_mat.shape
    padded = np.zeros((1, n_rows + 1, n_cols + 1))
    padded[0, :n_rows, :n_cols] = sparse_mat.todense()
    padded[0, n_rows, n_cols] = 1
    return padded


# One padded matrix per entry of D; U is indexed over the same range.
bD = [_pad_transform(D[level]) for level in range(len(D))]
bU = [_pad_transform(U[level]) for level in range(len(D))]


[(3, 1, 0), (0, 356, 3), (0, 355, 356), (355, 0, 378), (378, 0, 1)]
[(3, 1, 0), (382, 2, 3), (0, 356, 3), (3, 347, 382), (3, 2, 1), (3, 356, 347)]
[(3, 1, 0), (2, 296, 1), (378, 1, 294), (1, 295, 294), (3, 2, 1), (378, 0, 1), (1, 296, 295)]
[(382, 2, 3), (2, 296, 1), (381, 1558, 2), (3, 2, 1), (382, 381, 2), (2, 1558, 296)]
[(382, 2, 3), (2, 296, 1), (381, 1558, 2), (3, 2, 1), (382, 381, 2), (2, 1558, 296)]
[(431, 381, 382), (404, 1556, 381), (381, 1558, 2), (382, 381, 2), (431, 404, 381), (381, 1556, 1558)]
[(3, 1, 0), (382, 2, 3), (0, 356, 3), (3, 347, 382), (3, 2, 1), (3, 356, 347)]
[(382, 2, 3), (431, 381, 382), (348, 382, 347), (3, 347, 382), (382, 381, 2), (348, 431, 382)]
[(7, 5, 4), (1528, 4, 1529), (1555, 4, 66), (162, 4, 161), (66, 4, 5), (1528, 161, 4), (1555, 1529, 4), (162, 7, 4)]
[(7, 5, 4), (66, 5, 67), (7, 6, 5), (66, 4, 5), (67, 5, 6)]
[(7, 5, 4), (66, 5, 67), (7, 6, 5), (66, 4, 5), (67, 5, 6)]
[(66, 5, 67), (346, 67, 345), (67, 6, 247), (345, 67, 247), (346, 66, 67), 

[(720, 722, 719), (721, 2620, 722), (722, 653, 3023), (720, 721, 722), (719, 722, 3023), (722, 2620, 653)]
[(720, 722, 719), (719, 2670, 720), (3019, 2669, 719), (719, 3023, 3019), (719, 2669, 2670), (719, 722, 3023)]
[(3572, 723, 713), (723, 2120, 715), (3555, 1649, 723), (713, 723, 715), (3572, 3555, 723), (723, 1649, 2120)]
[(713, 715, 714), (3572, 723, 713), (3704, 713, 3850), (713, 723, 715), (3850, 713, 714), (3704, 3572, 713)]
[(629, 724, 725), (724, 2095, 500), (630, 2070, 724), (629, 630, 724), (725, 724, 500), (724, 2070, 2095)]
[(629, 724, 725), (623, 630, 629), (629, 3902, 623), (725, 3891, 629), (629, 630, 724), (629, 3891, 3902)]
[(629, 724, 725), (725, 500, 501), (725, 3891, 629), (725, 724, 500), (3907, 725, 501), (725, 3907, 3891)]
[(3891, 2114, 3902), (3907, 2113, 3891), (725, 3891, 629), (3891, 2113, 2114), (629, 3891, 3902), (725, 3907, 3891)]
[(727, 183, 726), (1926, 726, 1922), (726, 1999, 1998), (726, 1928, 727), (726, 3760, 1922), (1926, 1927, 726), (726, 183, 1

[(1242, 1233, 1232), (1293, 1233, 1289), (1242, 1241, 1233), (1232, 1233, 1244), (1289, 1233, 1241), (1293, 1244, 1233)]
[(1340, 1244, 1293), (1343, 1293, 1292), (1293, 1233, 1289), (1343, 1340, 1293), (1292, 1293, 1289), (1293, 1244, 1233)]
[(1234, 1269, 1235), (1256, 1234, 1070), (1291, 1916, 1234), (1070, 1234, 1235), (1234, 1916, 1269), (1256, 1291, 1234)]
[(1070, 1235, 1071), (1255, 1070, 1069), (1256, 1234, 1070), (1069, 1070, 1071), (1070, 1234, 1235), (1255, 1167, 1070), (1256, 1070, 1167)]
[(1070, 1235, 1071), (1235, 1270, 1261), (1234, 1269, 1235), (1070, 1234, 1235), (1235, 1269, 1270), (1071, 1235, 1261)]
[(1069, 1071, 830), (1070, 1235, 1071), (1071, 1261, 1260), (1071, 2648, 2649), (1069, 1070, 1071), (830, 1071, 2649), (1071, 1235, 1261), (1071, 1260, 2648)]
[(1236, 1238, 744), (854, 2654, 1236), (1236, 743, 854), (1236, 1237, 1238), (1237, 1236, 2654), (1236, 744, 743)]
[(854, 855, 853), (854, 2654, 1236), (1236, 743, 854), (854, 743, 855), (2652, 854, 853), (854, 2652,

[(2042, 1890, 2154), (1891, 1889, 1890), (1890, 3909, 3890), (2154, 1890, 3890), (2042, 1891, 1890), (1890, 1889, 3909)]
[(3909, 1877, 1892), (1890, 3909, 3890), (1889, 3894, 3909), (3909, 3894, 1877), (3890, 3909, 1892), (1890, 1889, 3909)]
[(1994, 1891, 2042), (1891, 1889, 1890), (1996, 1888, 1891), (1994, 1996, 1891), (2042, 1891, 1890), (1891, 1888, 1889)]
[(1888, 1876, 1889), (1891, 1889, 1890), (1889, 3894, 3909), (1891, 1888, 1889), (1890, 1889, 3909), (1889, 1876, 3894)]
[(3909, 1877, 1892), (1893, 604, 1892), (3890, 1892, 604), (1892, 1825, 1893), (3890, 3909, 1892), (1892, 1877, 1825)]
[(1892, 1825, 1893), (1825, 638, 639), (639, 3567, 1825), (1892, 1877, 1825), (1825, 1877, 638), (1893, 1825, 3567)]
[(1893, 604, 1892), (1892, 1825, 1893), (1893, 3567, 3498), (1893, 603, 604), (603, 1893, 3498), (1893, 1825, 3567)]
[(3567, 2132, 2927), (2971, 3567, 2927), (639, 3567, 1825), (1893, 3567, 3498), (3567, 3559, 2132), (2971, 3498, 3567), (639, 3559, 3567), (1893, 1825, 3567)]
[(57

[(2471, 2470, 2472), (3689, 2471, 2494), (3687, 2467, 2471), (2471, 2467, 2470), (2494, 2471, 2472), (3689, 3687, 2471)]
[(3689, 2473, 3687), (3690, 2495, 3689), (2493, 3689, 2494), (3689, 2471, 2494), (3689, 2495, 2473), (2493, 3690, 3689), (3689, 3687, 2471)]
[(2471, 2470, 2472), (2472, 2462, 2466), (2494, 2472, 2491), (2472, 2470, 2462), (2491, 2472, 2466), (2494, 2471, 2472)]
[(2491, 2466, 2486), (2493, 2491, 2492), (2494, 2472, 2491), (2491, 2472, 2466), (2492, 2491, 2486), (2493, 2494, 2491)]
[(2473, 3700, 3687), (3689, 2473, 3687), (2453, 2473, 2495), (2454, 2460, 2473), (2473, 2460, 3700), (3689, 2495, 2473), (2453, 2454, 2473)]
[(3690, 2495, 3689), (2453, 2473, 2495), (2383, 2495, 2496), (3689, 2495, 2473), (3690, 2496, 2495), (2383, 2453, 2495)]
[(2474, 2476, 2477), (2474, 3557, 3523), (2474, 2475, 2476), (3069, 2474, 2477), (2474, 3069, 3557), (2475, 2474, 3523)]
[(3069, 2477, 2482), (2482, 2770, 3069), (3069, 3545, 3557), (3069, 2474, 2477), (3069, 2770, 3545), (2474, 3069,

[(3118, 3115, 3117), (3439, 3116, 3118), (2711, 3118, 3127), (3118, 3116, 3115), (3127, 3118, 3117), (2711, 3439, 3118)]
[(3127, 3117, 3124), (2019, 3127, 3128), (2711, 3118, 3127), (3127, 3118, 3117), (3128, 3127, 3124), (2019, 2711, 3127)]
[(3117, 3119, 3120), (3115, 3112, 3119), (3109, 3120, 3119), (3107, 3119, 3112), (3117, 3115, 3119), (3107, 3109, 3119)]
[(3117, 3119, 3120), (3124, 3120, 3125), (3109, 3120, 3119), (3111, 3125, 3120), (3124, 3117, 3120), (3109, 3111, 3120)]
[(3117, 3119, 3120), (3124, 3120, 3125), (3109, 3120, 3119), (3111, 3125, 3120), (3124, 3117, 3120), (3109, 3111, 3120)]
[(3123, 3125, 3126), (3124, 3120, 3125), (3121, 3126, 3125), (3111, 3125, 3120), (3123, 3124, 3125), (3111, 3121, 3125)]
[(3121, 3110, 3122), (3121, 1754, 1753), (3121, 3126, 3125), (3121, 3111, 3110), (3121, 3122, 1754), (3121, 1753, 3126), (3111, 3121, 3125)]
[(3121, 3110, 3122), (3110, 3105, 3122), (3122, 2206, 1754), (3121, 3122, 1754), (3122, 3105, 2206)]
[(3121, 3110, 3122), (3110, 3105

[(3718, 2196, 3720), (3782, 2197, 3718), (2186, 3718, 2090), (3718, 2197, 2196), (2186, 3782, 3718), (2090, 3718, 3720)]
[(335, 3773, 3721), (3721, 336, 335), (2199, 3729, 3721), (3721, 2196, 2199), (3721, 3729, 336), (3721, 3773, 2196)]
[(2197, 2199, 2196), (2198, 3068, 2199), (2199, 3729, 3721), (3721, 2196, 2199), (2197, 2198, 2199), (2199, 3068, 3729)]
[(3716, 3722, 3717), (3725, 3717, 3722), (3588, 3722, 3740), (3716, 3715, 3722), (3740, 3722, 3715), (3588, 3725, 3722)]
[(3716, 3722, 3717), (3666, 3717, 3672), (3726, 3672, 3717), (3725, 3717, 3722), (3666, 3716, 3717), (3725, 3726, 3717)]
[(3723, 2016, 3719), (3723, 3169, 3778), (3723, 2013, 2016), (3723, 3167, 3169), (2013, 3723, 3778), (3167, 3723, 3719)]
[(3723, 2016, 3719), (3719, 2032, 3758), (3167, 3719, 3166), (3719, 2016, 2032), (3166, 3719, 3758), (3167, 3723, 3719)]
[(3662, 3724, 3660), (3715, 3724, 3714), (3667, 3660, 3724), (3662, 3714, 3724), (3715, 3716, 3724), (3667, 3724, 3716)]
[(3663, 3667, 3670), (3667, 3660, 37

[(4264, 4278, 4265), (4278, 4290, 4291), (4264, 4277, 4278), (4265, 4278, 4279), (4278, 4277, 4290), (4279, 4278, 4291)]
[(4266, 4280, 4267), (4280, 4292, 4293), (4346, 4267, 4280), (4346, 4280, 4293), (4266, 4279, 4280), (4280, 4279, 4292)]
[(4265, 4279, 4266), (4279, 4291, 4292), (4266, 4279, 4280), (4265, 4278, 4279), (4280, 4279, 4292), (4279, 4278, 4291)]
[(4281, 4295, 4282), (4006, 4281, 4002), (4281, 3998, 4002), (4268, 4281, 4282), (4281, 4294, 4295), (4006, 4294, 4281), (4281, 4268, 3998)]
[(3997, 4002, 3998), (4001, 4006, 4002), (4006, 4281, 4002), (4281, 3998, 4002), (3997, 4001, 4002)]
[(4268, 4282, 4269), (4270, 4282, 4283), (4281, 4295, 4282), (4268, 4281, 4282), (4270, 4269, 4282), (4283, 4282, 4295)]
[(4281, 4295, 4282), (4006, 4281, 4002), (4281, 3998, 4002), (4268, 4281, 4282), (4281, 4294, 4295), (4006, 4294, 4281), (4281, 4268, 3998)]
[(4270, 4282, 4283), (4283, 4297, 4284), (4283, 4295, 4296), (4270, 4283, 4284), (4283, 4296, 4297), (4283, 4282, 4295)]
[(4268, 4282

[(4841, 4853, 4854), (4842, 4854, 4855), (4854, 4866, 4867), (4842, 4841, 4854), (4854, 4853, 4866), (4855, 4854, 4867)]
[(4841, 4853, 4854), (4560, 4853, 4556), (4841, 4840, 4853), (4854, 4853, 4866), (4560, 4866, 4853), (4556, 4853, 4840)]
[(4842, 4854, 4855), (4855, 4867, 4868), (4842, 4855, 4856), (4856, 4855, 4868), (4855, 4854, 4867)]
[(4841, 4853, 4854), (4842, 4854, 4855), (4854, 4866, 4867), (4842, 4841, 4854), (4854, 4853, 4866), (4855, 4854, 4867)]
[(4842, 4856, 4843), (4856, 4868, 4869), (4856, 4870, 4857), (4842, 4855, 4856), (4843, 4856, 4857), (4856, 4855, 4868), (4856, 4869, 4870)]
[(4842, 4854, 4855), (4855, 4867, 4868), (4842, 4855, 4856), (4856, 4855, 4868), (4855, 4854, 4867)]
[(4843, 4857, 4844), (4856, 4870, 4857), (4843, 4856, 4857), (4844, 4857, 4858), (4858, 4857, 4870)]
[(4842, 4856, 4843), (4856, 4868, 4869), (4856, 4870, 4857), (4842, 4855, 4856), (4843, 4856, 4857), (4856, 4855, 4868), (4856, 4869, 4870)]
[(4846, 4858, 4859), (4844, 4858, 4845), (4858, 4872

[(507, 362, 0), (0, 104, 161), (0, 362, 1), (161, 507, 0), (0, 1, 104)]
[(105, 161, 104), (161, 162, 395), (0, 104, 161), (105, 162, 161), (62, 161, 395), (507, 161, 62), (161, 507, 0)]
[(550, 1, 514), (1, 989, 104), (514, 1, 362), (0, 362, 1), (550, 989, 1), (0, 1, 104)]
[(105, 161, 104), (0, 104, 161), (104, 1009, 105), (1, 989, 104), (0, 1, 104), (104, 989, 1009)]
[(4, 3, 2), (3, 362, 2), (2, 588, 219), (4, 2, 219), (507, 2, 362), (2, 507, 588)]
[(588, 63, 566), (507, 62, 588), (2, 588, 219), (588, 62, 63), (219, 588, 566), (2, 507, 588)]
[(4, 3, 2), (3, 362, 2), (3, 852, 855), (3, 497, 362), (497, 3, 855), (3, 4, 852)]
[(3, 362, 2), (507, 362, 0), (513, 362, 497), (3, 497, 362), (514, 1, 362), (0, 362, 1), (513, 514, 362), (507, 2, 362)]
[(4, 3, 2), (4, 219, 845), (4, 850, 852), (4, 2, 219), (3, 4, 852), (4, 845, 850)]
[(4, 219, 845), (2, 588, 219), (4, 2, 219), (845, 219, 566), (219, 588, 566)]
[(5, 7, 6), (574, 7, 5), (1069, 574, 5), (5, 587, 1069), (587, 5, 6)]
[(587, 1066, 1069

[(621, 658, 656), (659, 658, 622), (656, 658, 1023), (658, 659, 666), (621, 622, 658), (1023, 658, 664), (664, 658, 666)]
[(664, 598, 568), (829, 665, 664), (1023, 664, 657), (1023, 658, 664), (657, 664, 662), (662, 664, 665), (664, 658, 666), (664, 666, 598), (568, 829, 664)]
[(659, 658, 622), (659, 622, 626), (658, 659, 666), (598, 666, 659), (1049, 659, 661), (659, 626, 661), (1049, 598, 659)]
[(659, 622, 626), (621, 626, 622), (626, 625, 624), (624, 627, 626), (661, 626, 627), (659, 626, 661), (626, 621, 625)]
[(660, 941, 902), (742, 660, 901), (335, 660, 582), (742, 582, 660), (901, 660, 902), (660, 335, 941)]
[(333, 334, 335), (335, 582, 333), (335, 913, 941), (913, 335, 334), (335, 660, 582), (660, 335, 941)]
[(623, 817, 661), (1049, 659, 661), (661, 627, 623), (661, 626, 627), (659, 626, 661), (661, 817, 1049)]
[(642, 627, 624), (627, 642, 623), (624, 627, 626), (661, 627, 623), (661, 626, 627)]
[(663, 657, 662), (595, 663, 662), (657, 664, 662), (662, 664, 665), (665, 595, 662

[(1190, 1196, 1188), (1196, 1200, 1201), (1194, 1188, 1196), (1190, 1200, 1196), (1196, 1201, 1194)]
[(1190, 1199, 1200), (1196, 1200, 1201), (1199, 1204, 1200), (1200, 1204, 1209), (1190, 1200, 1196), (1200, 1209, 1201)]
[(1170, 1171, 1197), (1189, 1170, 1197), (1197, 1198, 1192), (1189, 1197, 1192), (1197, 1171, 1202), (1198, 1197, 1202)]
[(1168, 1169, 1170), (1170, 1171, 1197), (1170, 1169, 1171), (1189, 1170, 1197), (1168, 1170, 1189)]
[(1197, 1198, 1192), (1195, 1198, 1203), (1195, 1192, 1198), (1198, 1202, 1203), (1198, 1197, 1202)]
[(1170, 1171, 1197), (1189, 1170, 1197), (1197, 1198, 1192), (1189, 1197, 1192), (1197, 1171, 1202), (1198, 1197, 1202)]
[(1190, 1199, 1200), (1195, 1199, 1193), (1199, 1204, 1200), (1199, 1203, 1204), (1190, 1193, 1199), (1199, 1195, 1203)]
[(1195, 1199, 1193), (1195, 1198, 1203), (1193, 1192, 1195), (1195, 1192, 1198), (1199, 1195, 1203)]
[(1190, 1199, 1200), (1196, 1200, 1201), (1199, 1204, 1200), (1200, 1204, 1209), (1190, 1200, 1196), (1200, 1209

[(226, 193, 210), (225, 226, 227), (226, 210, 227), (226, 230, 193), (225, 230, 226)]
[(210, 42, 63), (210, 63, 120), (226, 193, 210), (226, 210, 227), (227, 210, 163), (210, 193, 42), (210, 120, 163)]
[(227, 163, 229), (225, 226, 227), (227, 237, 225), (228, 234, 227), (229, 228, 227), (226, 210, 227), (227, 210, 163), (227, 234, 237)]
[(210, 42, 63), (210, 63, 120), (226, 193, 210), (226, 210, 227), (227, 210, 163), (210, 193, 42), (210, 120, 163)]
[(228, 229, 236), (228, 234, 227), (229, 228, 227)]
[(227, 163, 229), (225, 226, 227), (227, 237, 225), (228, 234, 227), (229, 228, 227), (226, 210, 227), (227, 210, 163), (227, 234, 237)]
[(227, 163, 229), (228, 229, 236), (231, 235, 229), (231, 229, 163), (229, 228, 227), (229, 235, 236)]
[(120, 8, 163), (227, 163, 229), (231, 229, 163), (231, 163, 127), (227, 210, 163), (210, 120, 163), (127, 163, 8)]
[(230, 109, 193), (230, 233, 79), (225, 238, 230), (226, 230, 193), (230, 79, 109), (225, 230, 226), (230, 238, 233)]
[(226, 193, 210), (

In [6]:
# --- Device selection and tensor conversion -----------------------------------
torch.manual_seed(args['seed'])

# Use the requested CUDA device only when GPU use is switched on and CUDA is
# actually available; everything else runs on the CPU.
use_cuda = GPU and torch.cuda.is_available()
device = torch.device("cuda:"+str(device_idx)) if use_cuda else torch.device("cpu")
print(device)

# Spiral indices become integer (long) tensors, the padded sampling matrices
# become float tensors; all of them are moved to the selected device.
tspirals = [torch.from_numpy(spiral).long().to(device) for spiral in spirals_np]
tD = [torch.from_numpy(down).float().to(device) for down in bD]
tU = [torch.from_numpy(up).float().to(device) for up in bU]

cpu


In [7]:
# Building datasets/loaders, model, optimizer, and loss function

def _make_loader(split, shuffle):
    """Build the dataset and its DataLoader for one split.

    split   -- value passed as `points_dataset` ('train', 'val' or 'test').
    shuffle -- whether the DataLoader shuffles the samples.
    Returns (dataset, dataloader).
    """
    split_dataset = autoencoder_dataset(root_dir = args['data'], points_dataset = split,
                                        shapedata = shapedata,
                                        normalization = args['normalization'])
    split_loader = DataLoader(split_dataset, batch_size=args['batch_size'],
                              shuffle = shuffle, num_workers = args['num_workers'])
    return split_dataset, split_loader


# Only the training split honours args['shuffle']; val/test keep a fixed order.
dataset_train, dataloader_train = _make_loader('train', args['shuffle'])
dataset_val, dataloader_val = _make_loader('val', False)
dataset_test, dataloader_test = _make_loader('test', False)


if 'autoencoder' in args['generative_model']:
    model = SpiralAutoencoder(filters_enc = args['filter_sizes_enc'],
                              filters_dec = args['filter_sizes_dec'],
                              latent_size=args['nz'],
                              sizes=sizes,
                              spiral_sizes=spiral_sizes,
                              spirals=tspirals,
                              D=tD, U=tU,device=device).to(device)
else:
    # Fail here with a clear message rather than with a NameError on `model` below.
    raise ValueError("Unsupported generative_model: %r" % (args['generative_model'],))


optim = torch.optim.Adam(model.parameters(),lr=args['lr'],weight_decay=args['regularization'])
if args['scheduler']:
    scheduler=torch.optim.lr_scheduler.StepLR(optim, args['decay_steps'],gamma=args['decay_rate'])
else:
    scheduler = None


def loss_l1(outputs, targets):
    """Mean absolute difference between `outputs` and `targets`."""
    L = torch.abs(outputs - targets).mean()
    return L


if args['loss']=='l1':
    loss_fn = loss_l1
else:
    # Fail here with a clear message rather than with a NameError on `loss_fn` later.
    raise ValueError("Unsupported loss: %r" % (args['loss'],))


In [8]:
# Count only the parameters that receive gradients, then show the architecture.
trainable_sizes = [p.numel() for p in model.parameters() if p.requires_grad]
params = sum(trainable_sizes)
print(f"Total number of parameters is: {params}")
print(model)

Total number of parameters is: 989827
SpiralAutoencoder(
  (conv): ModuleList(
    (0): SpiralConv(
      (conv): Linear(in_features=27, out_features=64, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (1): SpiralConv(
      (conv): Linear(in_features=576, out_features=64, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (2): SpiralConv(
      (conv): Linear(in_features=576, out_features=64, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (3): SpiralConv(
      (conv): Linear(in_features=576, out_features=128, bias=True)
      (activation): ELU(alpha=1.0)
    )
  )
  (fc_latent_enc): Linear(in_features=2688, out_features=128, bias=True)
  (fc_latent_dec): Linear(in_features=128, out_features=2688, bias=True)
  (dconv): ModuleList(
    (0): SpiralConv(
      (conv): Linear(in_features=1152, out_features=64, bias=True)
      (activation): ELU(alpha=1.0)
    )
    (1): SpiralConv(
      (conv): Linear(in_features=576, out_features=64, bias=True)
      (activation): 

In [9]:
# --- Training ------------------------------------------------------------------
if args['mode'] == 'train':
    writer = SummaryWriter(summary_path)
    # Persist the run configuration next to the checkpoints.
    with open(os.path.join(args['results_folder'],'checkpoints', args['name'] +'_params.json'),'w') as fp:
        saveparams = copy.deepcopy(args)
        json.dump(saveparams, fp)

    if args['resume']:
        # Report the exact file that is loaded (the '.pth.tar' suffix included).
        checkpoint_file = os.path.join(checkpoint_path,args['checkpoint_file']+'.pth.tar')
        print('loading checkpoint from file %s'%(checkpoint_file))
        checkpoint_dict = torch.load(checkpoint_file,map_location=device)
        start_epoch = checkpoint_dict['epoch'] + 1
        model.load_state_dict(checkpoint_dict['autoencoder_state_dict'])
        optim.load_state_dict(checkpoint_dict['optimizer_state_dict'])
        # `scheduler` is None when args['scheduler'] is False; only restore its
        # state when a scheduler was actually created.
        if scheduler is not None:
            scheduler.load_state_dict(checkpoint_dict['scheduler_state_dict'])
        print('Resuming from epoch %s'%(str(start_epoch)))
    else:
        start_epoch = 0

    try:
        if args['generative_model'] == 'autoencoder':
            train_autoencoder_dataloader(dataloader_train, dataloader_val,
                              device, model, optim, loss_fn,
                              bsize = args['batch_size'],
                              start_epoch = start_epoch,
                              n_epochs = args['num_epochs'],
                              eval_freq = args['eval_frequency'],
                              scheduler = scheduler,
                              writer = writer,
                              save_recons=True,
                              shapedata=shapedata,
                              metadata_dir=checkpoint_path, samples_dir=samples_path,
                              checkpoint_path = args['checkpoint_file'])
    finally:
        # Flush and release the TensorBoard event file even if training is
        # interrupted or fails.
        writer.close()

In [10]:
# --- Evaluation on the test split ----------------------------------------------
# Switch the notebook to test mode; the guard below therefore always runs.
args['mode'] = 'test'
if args['mode'] == 'test':
    # Resolve the checkpoint location once and reuse it for the message and the load.
    checkpoint_file = os.path.join(checkpoint_path,args['checkpoint_file']+'.pth.tar')
    print('loading checkpoint from file %s'%(checkpoint_file))
    checkpoint_dict = torch.load(checkpoint_file,map_location=device)
    model.load_state_dict(checkpoint_dict['autoencoder_state_dict'])

    target, predictions, norm_l1_loss, l2_loss = test_autoencoder_dataloader(
        device, model, dataloader_test, shapedata, mm_constant = 1000)

    # Store predictions and targets side by side in the predictions folder.
    predictions_file = os.path.join(prediction_path,'predictions')
    targets_file = os.path.join(prediction_path,'targets')
    np.save(predictions_file, predictions)
    np.save(targets_file, target)

    print('autoencoder: normalized loss', norm_l1_loss)

    print('autoencoder: euclidean distance in mm=', l2_loss)

loading checkpoint from file Data\COMA\results_identity/spirals_autoencoder\latent_128\checkpoints\checkpoint.pth.tar


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 92/92 [01:25<00:00,  1.08it/s]


autoencoder: normalized loss 0.19860273599624634
autoencoder: euclidean distance in mm= 0.9910764098167419


In [11]:
# Load sample 0 from the preprocessed test points.
test_sample_file = os.path.join(root_dir, dataset, 'preprocessed_identity/points_test/0.npy')
test = np.load(test_sample_file)

In [12]:
# Load sample 0 from the preprocessed training points.
train_sample_file = os.path.join(root_dir, dataset, 'preprocessed_identity/points_train/0.npy')
train = np.load(train_sample_file)

In [13]:
# Save the loaded test sample with a leading batch axis of size 1. The vertex
# count is inferred (-1) instead of hard-coding 5023, so the cell also works for
# templates with a different number of vertices; the last axis stays at 3.
np.save(os.path.join(prediction_path,'test_COMA_normalized'), test.reshape((1, -1, 3)))

In [14]:
# Display the loaded training sample (bare last expression -> cell output).
train

array([[ 0.06620687, -0.00801941, -0.05462748],
       [ 0.06969664, -0.00904169, -0.05315745],
       [ 0.07025648, -0.00751739, -0.05270511],
       ...,
       [-0.0377722 ,  0.01930822,  0.01464073],
       [-0.03565262,  0.01973136,  0.01348286],
       [-0.03335112,  0.02019059,  0.0127714 ]])