In [2]:
%load_ext autoreload
%autoreload 2

import json
import os
import sys

import numpy as np
import plotly.graph_objects as go
import torch
from tqdm import tqdm

from run_optimisations import run_optimisation

# Make the vendored MeshSDF checkout importable (provides `lib.models.decoder`).
sys.path.insert(1, "external/MeshSDF")
from lib.models.decoder import DeepSDF

import src.simulation as simm
from src.sdf_convertion import reconstruct_trimesh, reconstruct_voxels
from src import tools

# Torch device used for the decoder, latent codes and reconstructions (first CUDA GPU).
DEVICE: str = "cuda:0"



# Optimisation

In [3]:
# Run the shape optimisation for 20 iterations under the given run name.
# Bound to `run_summary` (not `sum`) so the built-in sum() is not shadowed for the rest of the kernel.
# Previously tried extra kwargs: N=[128, 64, 64], dx_sdf=1./64
# NOTE(review): the Visualize section below reads optimisation_runs/test_2_actu_lr_1e3 (no `_round2`) — confirm which run is intended.
run_summary = run_optimisation("optimisation_runs/test_2_actu_lr_1e3_round2", num_iters=20)

Start optimisation optimisation_runs/test_2_actu_lr_1e3_round2 from shape #14 with lr=0.001 and 20 iterations [64, 32, 32]
#   4 : Fish Speed: -0.02608       Time Elapsed: 381.96

# Visualize the optimisation run

In [29]:
# List the optimisation run directories available on disk.
!ls optimisation_runs

test  test_1_actu  test_2_actu


In [38]:

# Plot the metric recorded per optimisation step in a run's summary.json.
# NOTE(review): this reads test_2_actu_lr_1e3, not the `_round2` run launched in the Optimisation cell — confirm which run is intended.
fig_name = "optimisation_runs/test_2_actu_lr_1e3/summary.json"
with open(fig_name, "r", encoding="utf-8") as f:
    summary = json.load(f)

fig = go.Figure(data=[go.Scatter(y=summary["run"]["metrics"])])
fig.update_layout(
    template="plotly_dark",
    title=fig_name,
    xaxis_title="iteration",
    yaxis_title="metric",
)
fig.show()

In [None]:
# Load the trained DeepSDF decoder and its training latent codes, then reconstruct
# meshes for the first, middle and last latent stored in the optimisation summary.
experiment_dir = "runs/wolfish_e256/"

with open(os.path.join(experiment_dir, "specs.json")) as f:
    specs = json.load(f)
with open(specs["TrainSplit"]) as f:
    train_mapping = json.load(f)
# mapping.json lives in the same directory as the train-split file.
with open(os.path.join(os.path.dirname(specs["TrainSplit"]), "mapping.json")) as f:
    data_mapping = json.load(f)

# Load the model
decoder = torch.nn.DataParallel(DeepSDF(specs["CodeLength"], **specs["NetworkSpecs"]), device_ids=[DEVICE])
saved_model_state = torch.load(
    os.path.join(experiment_dir, "ModelParameters", "latest.pth"), map_location=DEVICE
)
decoder.load_state_dict(saved_model_state["model_state_dict"])
decoder = decoder.eval()

# Load latent codes (the weight matrix of the saved latent-code table;
# presumably one row per training shape — TODO confirm).
orig_latents = torch.load(
    os.path.join(experiment_dir, "LatentCodes", "latest.pth"), map_location=DEVICE
)["latent_codes"]["weight"]


def _first_mid_last(n):
    """Return the sorted, de-duplicated indices {0, n // 2, n - 1} for a length-n sequence.

    Returns an empty list when n == 0. Unlike np.arange(0, n + 1, n // 2), this
    does not fail for n == 1 (zero step) nor repeat indices for n in (2, 3).
    """
    return sorted({0, n // 2, n - 1}) if n > 0 else []


run_latents = summary["run"]["latents"]
rec_idxs = _first_mid_last(len(run_latents))

shapes_reconstructed = []
for idx in tqdm(rec_idxs):
    latent = torch.tensor(run_latents[idx], dtype=torch.float32).to(DEVICE)
    mesh = reconstruct_trimesh(decoder, latent, N=[128, 64, 64])
    shapes_reconstructed.append(mesh)

100%|██████████| 3/3 [00:01<00:00,  1.67it/s]


In [40]:
# Build one plot trace per reconstructed (verts, faces) pair and show them side by side,
# labelled by the optimisation iteration each latent came from.
traces = [tools.plot_3d_mesh(verts, faces) for verts, faces in shapes_reconstructed]
iteration_labels = [f"iteration #{idx}" for idx in rec_idxs]
tools.show_grid(*traces, names=iteration_labels)
# NOTE(review): unlike the call above, this passes a single list of the first element of
# each trace — confirm this is the intended use of tools.show_grid.
tools.show_grid([trace[0] for trace in traces])

# Generate Videos

In [41]:
# Simulate the first, middle and last shapes of the run and write one video per shape
# into the directory that holds the run's summary.json.
video_root = os.path.dirname(fig_name)

# First, middle and last latent indices, de-duplicated; empty for an empty run.
# (np.arange(0, n + 1, n // 2) fails for n == 1 and repeats indices for n in (2, 3).)
n_latents = len(summary["run"]["latents"])
rec_idxs = sorted({0, n_latents // 2, n_latents - 1}) if n_latents > 0 else []

for idx in tqdm(rec_idxs):
    latent = torch.tensor(summary["run"]["latents"][idx], dtype=torch.float32).to(DEVICE)
    voxels = reconstruct_voxels(decoder, latent, N=[64, 32, 32])

    video_name = os.path.join(video_root, f"{idx:04d}.mp4")
    # `voxel_mesh` from the last iteration is what the point-cloud cell below displays.
    speed, voxel_mesh = simm.simulate(voxels, make_video=video_name, num_frames=100)

100%|██████████| 3/3 [00:49<00:00, 16.58s/it]


In [8]:
# Show the `voxel_mesh` returned by the most recent simm.simulate call as a point cloud.
sim_points = voxel_mesh.reshape(-1, 3).detach().cpu().numpy()
tools.show(tools.plot_3d_point_cloud(sim_points))

# Visualize muscles

In [34]:
from diffpd.mesh import MeshHex
from diffpd.mesh.utils import filter_unused_vertices
from diffpd import transforms

# Hex-cell size passed to both MeshHex.load calls in this cell.
MESH_DX = 1. / 20


def _boundary_triangles(mesh):
    """Return ``(verts, faces)`` of a MeshHex boundary, triangulated for plotting.

    Unused vertices are dropped with ``filter_unused_vertices``; each 4-vertex
    boundary face ``(v0, v1, v2, v3)`` is split into the triangles
    ``(v0, v1, v2)`` and ``(v0, v2, v3)``.
    """
    verts, quads = filter_unused_vertices(mesh._vertices, mesh._boundary)
    tris = np.concatenate([quads[:, :3], quads[:, [0, 2, 3]]])
    return verts, tris


def _mesh_trace(verts, faces, opacity):
    """Build a plotly Mesh3d trace from vertex positions and triangle vertex indices."""
    return go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
                     i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
                     opacity=opacity)


# Voxelise the first latent of the optimisation run.
latent = torch.tensor(summary["run"]["latents"][0], dtype=torch.float32).to(DEVICE)
voxels = reconstruct_voxels(decoder, latent, N=[64, 32, 32])

shape = voxels.shape
# .cpu() so this also works if reconstruct_voxels returns a CUDA tensor;
# .clone() keeps the numpy array from aliasing `voxels`.
rest_mesh = MeshHex.load(voxels.detach().cpu().clone().numpy(), dx=MESH_DX)
transform = [transforms.AddStateForce(
    'hydrodynamics',
    [simm.rho] + simm.v_water + simm.Cd_points.ravel().tolist() + simm.Ct_points.ravel().tolist()
    + rest_mesh.boundary.ravel().tolist())]

muscles = simm.add_muscles(shape, rest_mesh, transform)
# Flatten the three-level nested result into a single array
# (presumably cell indices, given the np.isin against rest_mesh.cell_indices below).
muscles = np.concatenate([np.concatenate([np.concatenate(mm) for mm in m]) for m in muscles])

verts, faces = _boundary_triangles(rest_mesh)

# Sub-mesh made only of the cells whose index appears in `muscles`.
voxels_muscles = np.isin(rest_mesh.cell_indices, muscles)
voxels_muscles_mesh = MeshHex.load(voxels_muscles, dx=MESH_DX)
verts_muscles, faces_muscles = _boundary_triangles(voxels_muscles_mesh)

# Semi-transparent body with the muscle sub-mesh drawn more opaquely.
traces = [
    _mesh_trace(verts, faces, opacity=0.5),
    _mesh_trace(verts_muscles, faces_muscles, opacity=0.8),
]

tools.show(traces)

# Misc: Chamfer distance to the target shape

In [None]:
from diffpd.mesh import MeshHex
from chamferdist import ChamferDistance

dx = 1./20
CHAMFER_N = [128, 64, 64]
# Index into orig_latents of the reference shape. TODO(review): document why shape #41 is the target.
TARGET_SHAPE_IDX = 41


def _hex_vertex_cloud(voxels, dx):
    """Load ``voxels`` into a MeshHex and return its vertices as a ``(1, V, 3)`` float32 tensor.

    The leading singleton dimension is the batch layout ChamferDistance is called with below.
    ``.cpu()`` makes this work if ``voxels`` is a CUDA tensor.
    """
    mesh = MeshHex.load(voxels.detach().cpu().clone().numpy(), dx=dx)
    return torch.as_tensor(mesh.vertices).to(torch.float32).reshape(1, -1, 3)


# Compare the final optimised latent against the target training shape.
latent = torch.tensor(summary["run"]["latents"][-1], dtype=torch.float32).to(DEVICE)
points_ours = _hex_vertex_cloud(reconstruct_voxels(decoder, latent, N=CHAMFER_N), dx)
points_target = _hex_vertex_cloud(
    reconstruct_voxels(decoder, orig_latents[TARGET_SHAPE_IDX], N=CHAMFER_N), dx
)

# NOTE(review): chamferdist's ChamferDistance is one-directional (source -> target) by default;
# pass bidirectional=True for a symmetric distance — confirm against the installed version.
chamferDist = ChamferDistance()
chamferDist(points_ours, points_target)