In [None]:
# Start Training

#!CUDA_VISIBLE_DEVICES=5 python external/MeshSDF/train_deep_sdf.py -e runs/wolfish_e32

In [None]:
%load_ext autoreload
%autoreload 2

import os
import igl
import torch
import json
import numpy as np
from src import tools
import plotly.graph_objects as go
from tqdm import tqdm

def load_logs(experiment_directory):
    """Load the training log written by the (Mesh/Deep)SDF training script.

    Args:
        experiment_directory: Run directory that contains ``Logs.pth``.

    Returns:
        A ``(loss, epoch)`` tuple taken from the ``"loss"`` and ``"epoch"``
        entries of the saved log.

    Raises:
        FileNotFoundError: If ``Logs.pth`` does not exist in the directory.
    """
    full_filename = os.path.join(experiment_directory, "Logs.pth")

    if not os.path.isfile(full_filename):
        # FileNotFoundError subclasses Exception, so existing `except Exception` callers keep working.
        raise FileNotFoundError('log file "{}" does not exist'.format(full_filename))

    # Map any stored tensors to the CPU so reading the log does not depend on the
    # GPU it was written from (or on a GPU being present at all).
    data = torch.load(full_filename, map_location="cpu")

    return data["loss"], data["epoch"]

## Training Log

In [None]:
# Runs to compare: legend label -> run directory.
TRAINING_RUNS = {
    "Emb 16": "runs/wolfish_e16/",
    "Emb 32": "runs/wolfish_e32/",
    "Emb 256": "runs/wolfish_e256/",
}

fig = go.Figure()
for run_name, run_dir in TRAINING_RUNS.items():
    losses, n_epochs = load_logs(run_dir)
    # Spread the logged losses evenly over [0, n_epochs) so x and y always have the
    # same length. With one entry per epoch this is exactly 0, 1, ..., n_epochs - 1.
    epochs = np.linspace(0, n_epochs, num=len(losses), endpoint=False)
    fig.add_trace(go.Scatter(x=epochs, y=losses, name=run_name))

fig.update_layout(
    title="DeepSDF Training Curve",
    xaxis_title="Epoch",
    yaxis_title="Loss",
)
fig.show()

## Load Trained Decoder

In [None]:
import sys
sys.path.insert(1, "external/MeshSDF")
from lib.models.decoder import DeepSDF

DEVICE = "cuda:2"

experiment_dir = "runs/wolfish_e256/"


def _read_json(path):
    """Parse a JSON file and close its handle afterwards."""
    with open(path) as f:
        return json.load(f)


specs = _read_json(os.path.join(experiment_dir, "specs.json"))
train_mapping = _read_json(specs["TrainSplit"])
# mapping.json is read from the same directory as the train-split file.
data_mapping = _read_json(os.path.join(os.path.dirname(specs["TrainSplit"]), "mapping.json"))

# Load the model. The checkpoint is loaded through a DataParallel wrapper
# (presumably because its keys carry the "module." prefix from training — confirm).
decoder = torch.nn.DataParallel(DeepSDF(specs["CodeLength"], **specs["NetworkSpecs"]), device_ids=[DEVICE])
saved_model_state = torch.load(
    os.path.join(experiment_dir, "ModelParameters", "latest.pth"), map_location=DEVICE
)
decoder.load_state_dict(saved_model_state["model_state_dict"])
# DataParallel does not move the wrapped module; its forward requires the parameters
# to already live on device_ids[0], so move them there explicitly.
decoder = decoder.to(DEVICE).eval()

# Load latent codes
orig_latents = torch.load(
    os.path.join(experiment_dir, "LatentCodes", "latest.pth"), map_location=DEVICE
)["latent_codes"]["weight"]

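## Reconstruct selected training meshes

Decode the chosen latent codes back into meshes (and voxel grids), then compare them with the ground-truth meshes.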

In [None]:
from src.sdf_convertion import reconstruct_trimesh, reconstruct_voxels

# Grid resolutions passed as `N` to the reconstruction helpers.
MESH_RESOLUTION = [64, 32, 32]
VOXEL_RESOLUTION = [128, 64, 64]

# Latent indices to reconstruct. For a reproducible random pick use e.g.
# np.random.default_rng(0).choice(orig_latents.shape[0], 3, replace=False).
latents_ids = [14]
meshes = []
voxels = []

for li in tqdm(latents_ids):
    latent = orig_latents[li]
    verts, faces = reconstruct_trimesh(decoder, latent, N=MESH_RESOLUTION)
    meshes.append((verts, faces))
    # Keep every voxel grid instead of overwriting a single variable each iteration.
    voxels.append(reconstruct_voxels(decoder, latent, N=VOXEL_RESOLUTION))

# Last voxel grid, kept under its previous name for the optional np.save below.
vox = voxels[-1] if voxels else None

In [None]:
# Optionally persist the last voxel grid:
# np.save("fish14.npy", vox.data.cpu().numpy())

# Predicted: one trace per reconstructed (verts, faces) pair.
traces = [tools.plot_3d_mesh(mesh[0], mesh[1]) for mesh in meshes]
tools.show_grid(*traces)

# GT: the original training meshes, looked up by zero-padded latent index.
# NOTE(review): assumes latent index `li` corresponds to key f"{li:04d}" in mapping.json — confirm.
traces = []
for li in latents_ids:
    gt_verts, _, _, gt_faces, _, _ = igl.read_obj(data_mapping[f"{li:04d}"])
    traces.append(tools.plot_3d_mesh(gt_verts, gt_faces))
tools.show_grid(*traces)

## Interpolation in latent space

In [None]:
steps_num = 5
# Seed for picking the interpolation endpoints; set to None for a fresh random pair.
INTERP_SEED = 0

# Notable shapes to interpolate between (latent index - shape):
# 32 - ray (skate)
# 15 - eel
# 10 - dolphin
# 14 - standard fish
# 41 - shark
rng = np.random.default_rng(INTERP_SEED)
lstart_id, lend_id = rng.choice(orig_latents.shape[0], 2, replace=False)

meshes = []
for aint in tqdm(range(steps_num + 1)):
    a = aint / steps_num  # blend weight: 0 -> start shape, 1 -> end shape
    latent = (1. - a) * orig_latents[lstart_id] + a * orig_latents[lend_id]
    # `create_mesh` is never imported in this notebook; use the reconstruction helper
    # imported from src.sdf_convertion, at the same resolution as the reconstruction cell.
    verts, faces = reconstruct_trimesh(decoder, latent, N=[64, 32, 32])

    meshes.append((verts, faces))

In [None]:
# Predicted
traces = []
for m in meshes:    
    traces.append(tools.plot_3d_mesh(m[0], m[1]))

tools.show_grid(*traces)