In [94]:
%load_ext autoreload
%autoreload 2

import os
import torch
import time
import json
import numpy as np

from src import tools
from src.sdf_convertion import reconstruct_voxels, reconstruct_trimesh
from src.simulation import simulate

import plotly.graph_objects as go

# Device for the DeepSDF decoder and latent codes. Defaults to the original
# hard-coded "cuda:2"; set the SDF_DEVICE environment variable (e.g.
# SDF_DEVICE=cuda:0) to run on a machine with a different GPU layout.
DEVICE = os.environ.get("SDF_DEVICE", "cuda:2")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Load the trained DeepSDF decoder, the experiment specs / split mappings, and
# the per-shape latent codes of a MeshSDF experiment run.

import sys
sys.path.insert(1, "external/MeshSDF")
from lib.models.decoder import DeepSDF

experiment_dir = "runs/wolfish_e256/"


def _read_json(path):
    """Parse a JSON file and close its handle (bare ``json.load(open(p))`` leaks it)."""
    with open(path) as fh:
        return json.load(fh)


specs = _read_json(os.path.join(experiment_dir, "specs.json"))
train_mapping = _read_json(specs["TrainSplit"])
# mapping.json lives in the same directory as the train-split file.
data_mapping = _read_json(os.path.join(os.path.dirname(specs["TrainSplit"]), "mapping.json"))

# Wrapped in DataParallel, presumably so parameter names match the checkpoint's
# state-dict keys (confirm against how the run was trained).
decoder = torch.nn.DataParallel(DeepSDF(specs["CodeLength"], **specs["NetworkSpecs"]), device_ids=[DEVICE])
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
saved_model_state = torch.load(
    os.path.join(experiment_dir, "ModelParameters", "latest.pth"), map_location=DEVICE
)
decoder.load_state_dict(saved_model_state["model_state_dict"])
decoder = decoder.eval()

# Learned latent codes, one row per training shape.
orig_latents = torch.load(
    os.path.join(experiment_dir, "LatentCodes", "latest.pth"), map_location=DEVICE
)["latent_codes"]["weight"]

In [3]:
from pathlib import Path
import sys
import time
import math
import random
import copy
from collections import deque
from tqdm import trange, tqdm

import scipy
import scipy.optimize
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import trimesh
import pyvista as pv
import imageio

from diffpd.fem import DeformableHex, HydrodynamicsStateForceHex
from diffpd.sim import Sim
from diffpd.nn import OpenFoldController
from diffpd import transforms
from diffpd.mesh import MeshHex

In [4]:
seed = 42
# Seed every RNG up front. Previously `seed` was only consumed by a
# commented-out helper, so any random initialisation was left unseeded.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Start a virtual X framebuffer so pyvista can render off-screen.
pv.start_xvfb()

# Mesh parameters: spacing of the hex mesh built from voxels (MeshHex.load(..., dx=dx)).
length = 20
dx = 1. / length

# Hydrodynamics parameters (packed, in this order, into the 'hydrodynamics' state force).
rho = 1e1
v_water = [0, 0, 0]   # Velocity of the water.
# Cd_points = (angle, coeff) pairs where angle is normalized to [0, 1].
Cd_points = np.array([
    [0.0, 0.05],
    [0.4, 0.05],
    [0.7, 1.85],
    [1.0, 2.05],
])
# Ct_points = (angle, coeff) pairs where angle is normalized to [-1, 1].
Ct_points = np.array([
    [-1, -0.8],
    [-0.3, -0.5],
    [0.3, 0.1],
    [1, 2.5],
])

# FEM parameters.
youngs_modulus = 1e6
poissons_ratio = 0.45
average_density = 1e1
dt = 3.33e-3
method = 'pd_eigen'
options = {
    'max_pd_iter': 2000, 'max_ls_iter': 10, 'abs_tol': 1e-4, 'rel_tol': 1e-3,
    'verbose': 0, 'thread_ct': 64, 'use_bfgs': 1, 'bfgs_history_size': 10
}

# Optimisation

In [11]:
def simmm(voxels, num_frames=20):
    """Simulate a voxel fish swimming and return its mean forward displacement.

    A hex mesh (spacing ``dx``) is built from ``voxels``. A 'hydrodynamics' state
    force and a small block of x-direction actuators around the body centre are
    attached. An ``OpenFoldController`` then drives the actuators for
    ``num_frames`` steps.

    Args:
        voxels: torch tensor occupancy grid; its axes are treated as (x, y, z)
            below. It is detached and copied to the CPU before meshing, so no
            gradient flows back into it.
        num_frames: number of controller/simulation steps (default 20, the
            previously hard-coded value).

    Returns:
        (speed, voxel_mesh):
            speed: scalar float64 tensor, the mean x-displacement of all mesh
                nodes between the rest pose and the final frame.
            voxel_mesh: leaf tensor of flattened rest-node positions, passed to
                the simulator as ``shape``. After ``speed.backward()`` its
                ``.grad`` is expected to hold dL/d(rest positions).
    """
    # --- Init simulator -----------------------------------------------------
    shape = voxels.shape
    # .cpu() so CUDA-resident voxel grids work. The clone keeps MeshHex from
    # sharing memory with the caller's tensor.
    rest_mesh = MeshHex.load(voxels.detach().cpu().clone().numpy(), dx=dx)
    transform = [transforms.AddStateForce(
        'hydrodynamics',
        [rho] + v_water + Cd_points.ravel().tolist() + Ct_points.ravel().tolist()
        + rest_mesh.boundary.ravel().tolist())]

    # Actuators: a thin slab x in [0.45, 0.5) of the length, centred in y/z.
    actuator_scale = 0.04
    # At least one cell each way. Plain int() rounds to 0 for axes < 25 cells,
    # which would silently leave the fish without muscles.
    actuator_height = max(1, int(shape[2] * actuator_scale))
    actuator_width = max(1, int(shape[1] * actuator_scale))
    x_start, x_stop = int(0.45 * shape[0]), int(0.5 * shape[0])
    y_mid, z_mid = shape[1] // 2, shape[2] // 2

    shared_muscles = []
    for z in range(z_mid - actuator_height, z_mid + actuator_height):
        muscle_pair = []
        for y in range(y_mid - actuator_width, y_mid + actuator_width):
            indices = rest_mesh.cell_indices[x_start:x_stop, y, z].tolist()
            transform.append(transforms.AddActuationEnergy(1e6, [1.0, 0.0, 0.0], indices))
            muscle_pair.append(indices)
        shared_muscles.append(muscle_pair)
    all_muscles = [shared_muscles]

    transform = transforms.Compose(transform)

    deformable = DeformableHex(
        rest_mesh, density=average_density, dt=dt, method=method, options=options)
    deformable = transform(deformable)

    dofs = deformable.dofs()

    q0 = torch.as_tensor(rest_mesh.vertices).view(-1).clone().detach().to(torch.float64)
    v0 = torch.zeros(dofs, dtype=torch.float64)

    sim = Sim(deformable)
    sim.add_default_pd_energies(['corotated'], youngs_modulus, poissons_ratio)

    # Independent leaf copy of the rest positions. The simulator receives it as
    # `shape`, so gradients w.r.t. the rest geometry land in voxel_mesh.grad.
    voxel_mesh = q0.clone().requires_grad_(True)

    controller = OpenFoldController(
        deformable, all_muscles,
        num_steps=num_frames,
        segment_len=1,
        init_period=16.0,
        init_magnitude=128.0).to(torch.float64)

    # --- Roll out; only the final state is needed for the objective --------
    q, v = q0, v0
    for a in controller():
        q, v = sim(q, v, a, shape=voxel_mesh)

    # Mean x-displacement of all nodes (used as a proxy for swimming speed).
    speed = torch.mean((q - q0).reshape(-1, 3)[:, 0])

    return speed, voxel_mesh

In [115]:
# Optimise latent code #14 to maximise the simulated forward displacement.
# Gradients are back-propagated through the simulator (dL/d rest-vertex
# positions), then through the SDF surface with the MeshSDF inflow derivative.

# Voxel size in DeepSDF units. With N=[64, 32, 32] and the [-1, -0.5, -0.5]
# origin used below, the grid presumably spans [-1, 1] x [-0.5, 0.5]^2.
dx_sdf = 1. / 32

latent_codes = []
metric = []
latent = orig_latents[14].detach().clone().requires_grad_(True)
optimizer = torch.optim.Adam([latent], lr=3e-4)

# Store detached snapshots: clone() of a grad-requiring tensor stays in the graph.
latent_codes.append(latent.detach().clone())

start_time = time.time()
for it in range(10):
    # Forward: decode the latent into voxels, simulate, maximise displacement.
    voxels = reconstruct_voxels(decoder, latent, N=[64, 32, 32])
    speed, voxel_mesh = simmm(voxels)
    loss = -speed
    loss.backward()
    if voxel_mesh.grad is None:
        raise RuntimeError("simulator produced no gradient w.r.t. the rest mesh (voxel_mesh.grad is None)")

    # Convert dL/dx from simulator units (spacing dx) to DeepSDF units (spacing dx_sdf).
    dL_dx_i = (dx / dx_sdf * voxel_mesh.grad.reshape(-1, 3)).float().to(DEVICE)

    # Map the rest-mesh vertices into the DeepSDF coordinate system (no graph needed).
    with torch.no_grad():
        xyz = voxel_mesh.reshape(-1, 3) * dx_sdf / dx + torch.tensor(
            [[-1, -0.5, -0.5]], dtype=voxel_mesh.dtype, device=voxel_mesh.device)
    xyz = xyz.float().to(DEVICE).requires_grad_(True)
    latent_inputs = latent.expand(xyz.shape[0], -1)

    pred_sdf = decoder(latent_inputs, xyz)

    # Surface normals = normalised dSDF/dxyz. autograd.grad computes only the xyz
    # gradient, so nothing accumulates into latent.grad. The clamp avoids NaN
    # normals where the SDF gradient vanishes.
    (sdf_grad,) = torch.autograd.grad(pred_sdf.sum(), xyz, retain_graph=True)
    normals = sdf_grad / sdf_grad.norm(dim=1, keepdim=True).clamp_min(1e-12)

    # Keep only vertices within half a voxel of the zero level set. Points farther
    # from the surface (inside or outside) do not give reliable gradients.
    near_surface = pred_sdf[:, 0].abs() <= dx_sdf / 2

    # MeshSDF Eq. (4): dL/ds_i = -dL/dx_i . n_i. Back-propagating
    # sum_i dL/ds_i * s_i then yields dL/d(latent).
    dL_ds_i = -(dL_dx_i[near_surface] * normals[near_surface]).sum(dim=1, keepdim=True)
    loss_backward = torch.sum(dL_ds_i * pred_sdf[near_surface])

    optimizer.zero_grad()
    loss_backward.backward()
    optimizer.step()

    time_elapsed = time.time() - start_time
    # Detach so each iteration's autograd graph can be released.
    metric.append(speed.detach())
    latent_codes.append(latent.detach().clone())
    print(f"#{it:4d} : Fish Speed: {speed.item():.5f}       Time Elapsed: {time_elapsed:.2f}")

In [111]:
import plotly.graph_objects as go

# Objective per optimisation iteration. float() accepts both 0-d tensors
# (on any device, with or without grad) and plain Python numbers.
fig = go.Figure(data=[go.Scatter(y=[float(s) for s in metric], mode="lines+markers")])
fig.update_layout(
    template="plotly_dark",
    title="Fish displacement per optimisation iteration",
    xaxis_title="iteration",
    yaxis_title="mean x-displacement",
)
fig.show()

In [112]:
# Mesh every stored latent code at a finer resolution for inspection. Each
# entry is a (verts, faces) pair. The comprehension avoids the old loop
# variable `voxels`, which overwrote the optimisation loop's voxel grid with a mesh.
shapes_reconstructed = [
    reconstruct_trimesh(decoder, code, N=[128, 64, 64])
    for code in tqdm(latent_codes, desc="meshing latents")
]

100%|██████████| 6/6 [00:02<00:00,  2.24it/s]


In [113]:
# Show every second reconstructed (verts, faces) shape side by side.
every_other_shape = shapes_reconstructed[::2]
traces = [tools.plot_3d_mesh(verts, faces) for verts, faces in every_other_shape]
tools.show_grid(*traces)
# Second grid: only the first element of each plot_3d_mesh result, passed as one list.
leading_traces = [mesh_trace[0] for mesh_trace in traces]
tools.show_grid(leading_traces)