In [1]:
import trimesh
import numpy as np
import torch
from sample import *
from hermite_spline import *
from unet import *
from torch.optim import Adam, LBFGS
from tqdm import tqdm

In [2]:
def h0(diff, der = 0):
    abs_diff = torch.abs(diff)
    if der == 0:
        return (1-abs_diff)**2*(1+2*abs_diff)
    elif der == 1:
        return 6*diff*(abs_diff-1)
    elif der == 2:
        return 12*abs_diff-6

def h1(diff, der = 0):
    abs_diff = torch.abs(diff)
    if der == 0:
        return diff*(1-abs_diff)**2
    elif der == 1:
        return 3*abs_diff**2 - 4*abs_diff + 1
    elif der == 2:
        return 6*diff - 4*abs_diff

basis_functions = [h0, h1]

In [9]:
def get_field(coefficients, field_idx, points, x_support, y_support, z_support, step, basis, der_x = 0, der_y = 0, der_z = 0):
    field = torch.zeros(points.shape[0])
    for l, p in enumerate(points):
        result = 0.0
        coeff = coefficients[field_idx, :, x_support[0, l]:x_support[1, l]+1, y_support[0, l]:y_support[1, l]+1, z_support[0, l]:z_support[1, l]+1].view(2, 2, 2, 2, 2, 2)
        x_indices = torch.arange(x_support[0, l], x_support[1, l] + 1)  # x range
        y_indices = torch.arange(y_support[0, l], y_support[1, l] + 1)  # y range
        z_indices = torch.arange(z_support[0, l], z_support[1, l] + 1)  # z range
        x_coords = x_indices * step[0]
        y_coords = y_indices * step[1]
        z_coords = z_indices * step[2]
        # Generate all combinations of grid coordinates
        xx, yy, zz = torch.meshgrid(x_coords, y_coords, z_coords, indexing="ij")
        grid_points = torch.stack([xx.flatten(), yy.flatten(), zz.flatten()], dim=1)  # Shape: (8, 3)
        for a in range(2):
            for b in range(2):
                for c in range(2):
                    for i in range(2):
                        for j in range(2):
                            for k in range(2):
                                h_i = basis[i](((p[0] - grid_points[a * 4 + b * 2 + c, 0]))/step[0], der_x)
                                h_j = basis[j]((p[1] - grid_points[a * 4 + b * 2 + c, 1])/step[1], der_y)
                                h_k = basis[k]((p[2] - grid_points[a * 4 + b * 2 + c, 2])/step[2], der_z)
                                result += coeff[i, j, k, a, b, c] * h_i * h_j * h_k
        field[l] = result
    return field

In [4]:
def get_support_points(points):
    x = points[:,0]
    y = points[:,1]
    z = points[:,2]
    x_floor = (x//step[0]).long()
    y_floor = (y//step[1]).long()
    z_floor = (z//step[2]).long()
    x_support_indices = torch.vstack((x_floor, torch.clamp(x_floor+1,max=grid_resolution[0]-1)))
    y_support_indices = torch.vstack((y_floor, torch.clamp(y_floor+1,max=grid_resolution[1]-1)))
    z_support_indices = torch.vstack((z_floor, torch.clamp(z_floor+1,max=grid_resolution[2]-1)))
    return x_support_indices,y_support_indices,z_support_indices

In [5]:
obj = trimesh.load("./Baseline_ML4Science.stl")
grid_resolution = np.array([20,20,20])
step = torch.tensor(obj.bounding_box.extents/(grid_resolution-1))

In [6]:
device = 'cpu'
binary_mask = get_binary_mask(obj, grid_resolution)
unet_model = UNet3D().to(device)
optim = Adam(unet_model.parameters(), lr = 1e-3)
unet_input = prepare_mesh_for_unet(binary_mask)

In [7]:
spline_coeff = unet_model(unet_input)[0]
points = torch.tensor(trimesh.sample.volume_mesh(obj, 100))
points.requires_grad_(True)
x_support, y_support, z_support = get_support_points(points)

In [10]:
for i in range(3):
    spline_coeff = unet_model(unet_input)[0]

    points = torch.tensor(trimesh.sample.volume_mesh(obj, 100))
    points.requires_grad_(True)

    x_support, y_support, z_support = get_support_points(points)

    vx = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions)
    vy = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions)
    vz = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions)
    p = get_field(spline_coeff, 3, points, x_support, y_support, z_support, step, basis_functions)
    
    vx_x = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions, 1, 0, 0)
    vx_y = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions, 0, 1, 0)
    vx_z = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 1)

    vy_x = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions, 1, 0, 0)
    vy_y = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions, 0, 1, 0)
    vy_z = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 1)
    
    vz_x = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions, 1, 0, 0)
    vz_y = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions, 0, 1, 0)
    vz_z = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 1)
    
    p_x = get_field(spline_coeff, 3, points, x_support, y_support, z_support, step, basis_functions, 1, 0, 0)
    p_y = get_field(spline_coeff, 3, points, x_support, y_support, z_support, step, basis_functions, 0, 1, 0)
    p_z = get_field(spline_coeff, 3, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 1)
    
    vx_xx = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions, 2, 0, 0)
    vx_yy = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions, 0, 2, 0)
    vx_zz = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 2)
    
    vy_xx = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions, 2, 0, 0)
    vy_yy = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions, 0, 2, 0)
    vy_zz = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 2)
    
    vz_xx = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions, 2, 0, 0)
    vz_yy = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions, 0, 2, 0)
    vz_zz = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 2)
    
    rho = 1.010427
    mu = 2.02e-5
    loss_divergence = torch.mean((vx_x + vy_y + vz_z)**2)
    loss_momentum_x = torch.mean((vx*vx_x + vy*vx_y + vz*vx_z + (1/rho)*p_x - (mu/rho)*(vx_xx + vx_yy + vx_zz))**2)
    loss_momentum_y = torch.mean((vx*vy_x + vy*vy_y + vz*vy_z + (1/rho)*p_y - (mu/rho)*(vy_xx + vy_yy + vy_zz))**2)
    loss_momentum_z = torch.mean((vx*vz_x + vy*vz_y + vz*vz_z + (1/rho)*p_z - (mu/rho)*(vz_xx + vz_yy + vz_zz))**2)

    loss = loss_divergence + 10.0*loss_momentum_x + 10.0*loss_momentum_y + 10.0*loss_momentum_z
    print(f'Loss: {loss_divergence.item()}, {loss_momentum_x.item()}, {loss_momentum_y.item()}, {loss_momentum_z.item()}')
    optim.zero_grad()
    loss.backward()
    optim.step()

Loss: 0.00313269579783082, 0.0006003637099638581, 0.00217741378583014, 0.0003365481679793447
Loss: 0.0033193277195096016, 0.000492704682983458, 0.00025023528723977506, 0.00031928118551149964
Loss: 0.0020049293525516987, 0.0008497923263348639, 0.0014850469306111336, 0.00012643424270208925
