In [12]:
%load_ext autoreload
%autoreload 2

import json
from run_optimisations import run_optimisation
import plotly.graph_objects as go
import numpy as np

import torch
from tqdm import tqdm
import sys
sys.path.insert(1, "external/MeshSDF")
from lib.models.decoder import DeepSDF

from src.simulation import simulate
from src.sdf_convertion import reconstruct_trimesh, reconstruct_voxels
from src import tools

DEVICE = "cuda:0"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Optimisation

In [13]:
sum = run_optimisation("optimisation_runs/test", num_iters=5, lr=1e-3) #, N=[128, 64, 64], dx_sdf=1./64

Start optimisation optimisation_runs/test from shape #14 with lr=0.001 and 5 iterations [64, 32, 32]
#   1 : Fish Speed: -0.24908       Time Elapsed: 119.37

# Visualize 

In [8]:
!ls optimisation_runs/test

summary.json


In [9]:

fig_name = "optimisation_runs/test/summary.json"
with open(fig_name, "r") as f:
    summary = json.load(f)

fig = go.Figure(data=[go.Scatter(y=summary["run"]["metrics"])])
fig.update_layout(template="plotly_dark", title=fig_name)
fig.show()

In [10]:
# Load the model
experiment_dir = "runs/wolfish_e256/"

specs = json.load(open(os.path.join(experiment_dir, "specs.json")))
train_mapping = json.load(open(specs["TrainSplit"]))
data_mapping = json.load(open("/".join(specs["TrainSplit"].split("/")[:-1] + ["mapping.json"])))

# Load the model
decoder = torch.nn.DataParallel(DeepSDF(specs["CodeLength"],  **specs["NetworkSpecs"]), device_ids=[DEVICE])
saved_model_state = torch.load(
    os.path.join(experiment_dir, "ModelParameters", "latest.pth"), map_location=DEVICE
)
decoder.load_state_dict(saved_model_state["model_state_dict"])
decoder = decoder.eval()

# Load latent codes
orig_latents = torch.load(os.path.join(experiment_dir, "LatentCodes/latest.pth"), map_location=DEVICE)["latent_codes"]["weight"]

shapes_reconstructed = []
rec_idxs = np.arange(0, len(summary["run"]["latents"]), len(summary["run"]["latents"]) // 1).tolist()
rec_idxs[-1] = len(summary["run"]["latents"]) - 1
for code in tqdm(np.array(summary["run"]["latents"])[rec_idxs]):
    latent = torch.tensor(code, dtype=torch.float32).to(DEVICE)
    mesh = reconstruct_trimesh(decoder, latent, N=[128, 64, 64])
    shapes_reconstructed.append(mesh)

100%|██████████| 5/5 [00:01<00:00,  2.54it/s]


In [11]:
traces = []
for verts, faces in shapes_reconstructed:
    traces.append(tools.plot_3d_mesh(verts, faces))
tools.show_grid(*traces, names=[f"iteration #{idx}" for idx in rec_idxs])
tools.show_grid([tr[0] for tr in traces])

# Generate Videos

In [17]:
video_root = "/".join(fig_name.split("/")[:-1])
rec_idxs = np.arange(0, len(summary["run"]["latents"]), len(summary["run"]["latents"]) // 3).tolist()
rec_idxs[-1] = len(summary["run"]["latents"]) - 2

for idx in tqdm(rec_idxs):
    latent = torch.tensor(summary["run"]["latents"][idx], dtype=torch.float32).to(DEVICE)
    voxels = reconstruct_voxels(decoder, latent, N=[64, 32, 32])

    video_name = os.path.join(video_root, f"{idx:04d}.mp4")
    speed, voxel_mesh = simulate(voxels, make_video=video_name)

100%|██████████| 4/4 [00:12<00:00,  3.21s/it]


# Misc

In [None]:
from diffpd.mesh import MeshHex
from chamferdist import ChamferDistance

dx = 1./20
latent = torch.tensor(summary["run"]["latents"][-1], dtype=torch.float32).to(DEVICE)
voxels_ours = reconstruct_voxels(decoder, latent, N=[128, 64, 64])
voxels_target = reconstruct_voxels(decoder, orig_latents[41], N=[128, 64, 64])

voxels_ours = MeshHex.load(voxels_ours.clone().detach().numpy(), dx=dx)
voxels_ours = torch.as_tensor(voxels_ours.vertices).view(-1).clone().detach().to(torch.float32)
voxels_target = MeshHex.load(voxels_target.clone().detach().numpy(), dx=dx)
voxels_target = torch.as_tensor(voxels_target.vertices).view(-1).clone().detach().to(torch.float32)

chamferDist = ChamferDistance()
chamferDist(voxels_ours.reshape(1, -1, 3), voxels_target.reshape(1, -1, 3))

tensor(0.)