In [1]:
import trimesh
import numpy as np
import torch
from sample import *
from hermite_spline import *
from unet import *
from torch.optim import Adam, LBFGS
from tqdm import tqdm

In [2]:
def h0(diff, der = 0):
    abs_diff = torch.abs(diff)
    if der == 0:
        return (1-abs_diff)**2*(1+2*abs_diff)
    elif der == 1:
        return 6*diff*(abs_diff-1)
    elif der == 2:
        return 12*abs_diff-6

def h1(diff, der = 0):
    abs_diff = torch.abs(diff)
    if der == 0:
        return diff*(1-abs_diff)**2
    elif der == 1:
        return 3*abs_diff**2 - 4*abs_diff + 1
    elif der == 2:
        return 6*diff - 4*abs_diff

basis_functions = [h0, h1]

In [3]:
def get_field(coefficients, field_idx, points, x_support, y_support, z_support, step, basis, der_x = 0, der_y = 0, der_z = 0):
    field = torch.zeros(points.shape[0])
    for l, p in enumerate(points):
        result = 0.0
        coeff = coefficients[field_idx, :, x_support[0, l]:x_support[1, l]+1, y_support[0, l]:y_support[1, l]+1, z_support[0, l]:z_support[1, l]+1].view(2, 2, 2, 2, 2, 2)
        x_indices = torch.arange(x_support[0, l], x_support[1, l] + 1)  # x range
        y_indices = torch.arange(y_support[0, l], y_support[1, l] + 1)  # y range
        z_indices = torch.arange(z_support[0, l], z_support[1, l] + 1)  # z range
        x_coords = x_indices * step[0]
        y_coords = y_indices * step[1]
        z_coords = z_indices * step[2]
        # Generate all combinations of grid coordinates
        xx, yy, zz = torch.meshgrid(x_coords, y_coords, z_coords, indexing="ij")
        grid_points = torch.stack([xx.flatten(), yy.flatten(), zz.flatten()], dim=1)  # Shape: (8, 3)
        for a in range(2):
            for b in range(2):
                for c in range(2):
                    for i in range(2):
                        for j in range(2):
                            for k in range(2):
                                h_i = basis[i](((p[0] - grid_points[a * 4 + b * 2 + c, 0]))/step[0], der_x)
                                h_j = basis[j]((p[1] - grid_points[a * 4 + b * 2 + c, 1])/step[1], der_y)
                                h_k = basis[k]((p[2] - grid_points[a * 4 + b * 2 + c, 2])/step[2], der_z)
                                result += coeff[i, j, k, a, b, c] * h_i * h_j * h_k
        field[l] = result
    return field

In [4]:
def get_support_points(points):
    x = points[:,0]
    y = points[:,1]
    z = points[:,2]
    x_floor = (x//step[0]).long()
    y_floor = (y//step[1]).long()
    z_floor = (z//step[2]).long()
    x_support_indices = torch.vstack((x_floor, torch.clamp(x_floor+1,max=grid_resolution[0]-1)))
    y_support_indices = torch.vstack((y_floor, torch.clamp(y_floor+1,max=grid_resolution[1]-1)))
    z_support_indices = torch.vstack((z_floor, torch.clamp(z_floor+1,max=grid_resolution[2]-1)))
    return x_support_indices,y_support_indices,z_support_indices

In [5]:
obj = trimesh.load("./Baseline_ML4Science.stl")
grid_resolution = np.array([20,20,20])
step = torch.tensor(obj.bounding_box.extents/(grid_resolution-1))

In [6]:
device = 'cpu'
binary_mask = get_binary_mask(obj, grid_resolution)
unet_model = UNet3D().to(device)
optim = Adam(unet_model.parameters(), lr = 1e-3)
unet_input = prepare_mesh_for_unet(binary_mask)

In [7]:
spline_coeff = unet_model(unet_input)[0]
points = torch.tensor(trimesh.sample.volume_mesh(obj, 100))
points.requires_grad_(True)
x_support, y_support, z_support = get_support_points(points)

In [None]:
for i in range(3):
    spline_coeff = unet_model(unet_input)[0]

    points = torch.tensor(trimesh.sample.volume_mesh(obj, 100))
    points.requires_grad_(True)

    x_support, y_support, z_support = get_support_points(points)

    vx = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions)
    vy = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions)
    vz = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions)
    p = get_field(spline_coeff, 3, points, x_support, y_support, z_support, step, basis_functions)
    
    vx_x = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions, 1, 0, 0)
    vx_y = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions, 0, 1, 0)
    vx_z = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 1)

    vy_x = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions, 1, 0, 0)
    vy_y = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions, 0, 1, 0)
    vy_z = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 1)
    
    vz_x = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions, 1, 0, 0)
    vz_y = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions, 0, 1, 0)
    vz_z = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 1)
    
    p_x = get_field(spline_coeff, 3, points, x_support, y_support, z_support, step, basis_functions, 1, 0, 0)
    p_y = get_field(spline_coeff, 3, points, x_support, y_support, z_support, step, basis_functions, 0, 1, 0)
    p_z = get_field(spline_coeff, 3, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 1)
    
    vx_xx = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions, 2, 0, 0)
    vx_yy = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions, 0, 2, 0)
    vx_zz = get_field(spline_coeff, 0, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 2)
    
    vy_xx = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions, 2, 0, 0)
    vy_yy = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions, 0, 2, 0)
    vy_zz = get_field(spline_coeff, 1, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 2)
    
    vz_xx = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions, 2, 0, 0)
    vz_yy = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions, 0, 2, 0)
    vz_zz = get_field(spline_coeff, 2, points, x_support, y_support, z_support, step, basis_functions, 0, 0, 2)
    
    rho = 1.010427
    mu = 2.02e-5
    loss_divergence = torch.mean((vx_x + vy_y + vz_z)**2)
    loss_momentum_x = torch.mean((vx*vx_x + vy*vx_y + vz*vx_z + (1/rho)*p_x - (mu/rho)*(vx_xx + vx_yy + vx_zz))**2)
    loss_momentum_y = torch.mean((vx*vy_x + vy*vy_y + vz*vy_z + (1/rho)*p_y - (mu/rho)*(vy_xx + vy_yy + vy_zz))**2)
    loss_momentum_z = torch.mean((vx*vz_x + vy*vz_y + vz*vz_z + (1/rho)*p_z - (mu/rho)*(vz_xx + vz_yy + vz_zz))**2)

    loss = loss_divergence + 10.0*loss_momentum_x + 10.0*loss_momentum_y + 10.0*loss_momentum_z
    print(f'Loss: {loss_divergence.item()}, {loss_momentum_x.item()}, {loss_momentum_y.item()}, {loss_momentum_z.item()}')
    optim.zero_grad()
    loss.backward()
    optim.step()

Loss: 0.3221743702888489, 0.17541548609733582, 0.22631408274173737, 0.22136251628398895
Loss: 0.430877149105072, 0.3643985688686371, 0.22490420937538147, 0.218585804104805


KeyboardInterrupt: 

In [None]:
from sample import *

In [22]:
def get_supoort_points(points):
    x = points[:,0]
    y = points[:,1]
    z = points[:,2]
    x_floor = (x//step[0]).long()
    y_floor = (y//step[1]).long()
    z_floor = (z//step[2]).long()
    x_support_indices = torch.vstack((x_floor, torch.clamp(x_floor+1,max=grid_resolution[0]-1)))
    y_support_indices = torch.vstack((y_floor, torch.clamp(y_floor+1,max=grid_resolution[1]-1)))
    z_support_indices = torch.vstack((z_floor, torch.clamp(z_floor+1,max=grid_resolution[2]-1)))
    return x,y,z,x_support_indices,y_support_indices,z_support_indices

def f(spline_coeff, channel,points,der_x=0,der_y=0,der_z=0):
    x,y,z,x_support_indices,y_supports_indices,z_supports_indices = get_supoort_points(points)
    conv_sum = 0
    for type in range(spline_coeff[channel].shape[0]):
        i, j, k = binary_array(type)
        spline_coeff_ijk = spline_coeff[channel][type]
        for x_support_ind in x_support_indices:
            for y_support_ind in y_supports_indices:
                for z_support_ind in z_supports_indices:
                    # One of the 8 grid support points(enclosing cube vertices) for each sample point.
                    support_point_ind = torch.vstack((x_support_ind,y_support_ind,z_support_ind)).T

                    x_indices = support_point_ind[:, 0]
                    y_indices = support_point_ind[:, 1]
                    z_indices = support_point_ind[:, 2]

                    x_input = (x-x_indices*step[0])/step[0]
                    y_input = (y-y_indices*step[1])/step[1]
                    z_input = (z-z_indices*step[2])/step[2]
                    
                    conv_sum += (spline_coeff_ijk[x_indices, y_indices, z_indices]) * hermite_kernel_3d(i,j,k,x_input,y_input,z_input,der_x,der_y,der_z)
    return conv_sum

In [23]:
obj = trimesh.load("Baseline_ML4Science.stl")

grid_resolution = np.array([20,20,20])
binary_mask = get_binary_mask(obj, grid_resolution)
step = obj.bounding_box.extents/(grid_resolution-1)

# Instantiate the neural network
unet_model = UNet3D().to(device)
optimizer = Adam(unet_model.parameters(), lr = 1e-3)
unet_model.apply(initialize_weights)

UNet3D(
  (enc1): Sequential(
    (0): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (5): ReLU(inplace=True)
  )
  (enc2): Sequential(
    (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (2): ReLU(inplace=True)
    (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (5): ReLU(inplace=True)
  )
  (bottleneck): Sequential(
    (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d

In [26]:
def get_binary_mask(model, grid_resolution):
    bounds = model.bounds
    min_bound, max_bound = bounds[0], bounds[1]
    x = np.linspace(min_bound[0], max_bound[0], grid_resolution[0])
    y = np.linspace(min_bound[1], max_bound[1], grid_resolution[1])
    z = np.linspace(min_bound[2], max_bound[2], grid_resolution[2])
    xx, yy, zz = np.meshgrid(x, y, z)
    grid_points = np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T
    inside_points = grid_points[model.contains(grid_points)]
    binary_volume = np.zeros(grid_resolution, dtype=np.uint8)
    for point in inside_points:
        # Compute the index of the point in the grid
        index_x = np.searchsorted(x, point[0])
        index_y = np.searchsorted(y, point[1])
        index_z = np.searchsorted(z, point[2])
        # Set the corresponding position in the binary volume to 1
        if (
            0 <= index_x < grid_resolution[0]
            and 0 <= index_y < grid_resolution[1]
            and 0 <= index_z < grid_resolution[2]
        ):
            binary_volume[index_x, index_y, index_z] = 1
    return binary_volume

def get_inlet_surface_points(obj, num_points):
    threshold = 1e-5
    faces_x_zero = [
        i for i, face in enumerate(obj.faces)
        if np.all(np.abs(obj.vertices[face, 0]) < threshold)  # Check if all vertices' x-coordinates are 0
    ]
    subset_mesh = obj.submesh([faces_x_zero], only_watertight=False)[0]
    points, _ = trimesh.sample.sample_surface(subset_mesh, count=num_points)
    inlet_surface_points = torch.tensor(points / 1000.0, dtype = torch.float64)
    inlet_surface_labels = torch.ones(inlet_surface_points.size(0), dtype=torch.int64)
    return inlet_surface_points, inlet_surface_labels

def get_other_surface_points(obj, num_points):
    threshold = 1e-5
    points, _ = trimesh.sample.sample_surface(obj, count=num_points)
    filtered_points = points[np.abs(points[:, 0]) > threshold]
    other_surface_points = torch.tensor(filtered_points / 1000.0, dtype = torch.float64)
    other_surface_labels = 2*torch.ones(other_surface_points.size(0), dtype=torch.int64)
    return other_surface_points, other_surface_labels

def get_volume_points(obj, num_points):
    volume_points = torch.tensor(trimesh.sample.volume_mesh(obj, num_points) / 1000.0, dtype=torch.float64)
    volume_labels = torch.zeros(volume_points.size(0), dtype=torch.int64)
    return volume_points, volume_labels

In [36]:
from sample import *

# inlet_surface_points, inlet_surface_labels = get_inlet_surface_points(obj,100)
# other_surface_points, other_surface_labels = get_other_surface_points(obj,400)   
volume_points, volume_labels = get_volume_points(obj,500)   

volume_points.requires_grad_(True)

# Get Hermite Spline coefficients from the Unet
unet_input = prepare_mesh_for_unet(binary_mask)
spline_coeff = unet_model(unet_input)[0]

In [42]:
vx1 = f(spline_coeff,0,volume_points,0,0,0)

vx1.shape

torch.Size([482])

In [29]:
def binary_array(n):
    if n < 0 or n > 7:
        raise ValueError("Number must be between 0 and 7 inclusive.")
    # Convert the number to a binary string, remove the '0b' prefix, and pad it to 3 bits
    binary_str = format(n, '03b')
    binary_array = [int(bit) for bit in binary_str]
    return binary_array[0], binary_array[1], binary_array[2]

# Hermite Spline Kernels for 2nd Order
def h0(diff, der = 0):
    abs_diff = torch.abs(diff)
    match der:
        case 0:
            return (1-abs_diff)**2*(1+2*abs_diff)
        case 1:
            return 6*diff*(abs_diff-1)
        case 2:
            return 12*abs_diff-6

def h1(diff, der = 0):
    abs_diff = torch.abs(diff)
    match der:
        case 0:
            return diff*(1-abs_diff)**2
        case 1:
            return 3*abs_diff**2 - 4*abs_diff + 1
        case 2:
            return 6*diff - 4*abs_diff

hermite_kernel_1d = [h0, h1]

def hermite_kernel_3d(i,j,k,x,y,z,der_x=0,der_y=0,der_z=0):
    return hermite_kernel_1d[i](x, der_x) * hermite_kernel_1d[j](y, der_y) * hermite_kernel_1d[k](z, der_z)

In [44]:
def h_000(x):
    return h0(x[0])*h0(x[1])*h0(x[2])

def h_001(x):
    return h0(x[0])*h0(x[1])*h1(x[2])

def h_010(x):
    return h0(x[0])*h1(x[1])*h0(x[2])

def h_011(x):
    return h0(x[0])*h1(x[1])*h1(x[2])

def h_100(x):
    return h1(x[0])*h0(x[1])*h0(x[2])

def h_101(x):
    return h1(x[0])*h0(x[1])*h1(x[2])

def h_110(x):
    return h1(x[0])*h1(x[1])*h0(x[2])

def h_111(x):
    return h1(x[0])*h1(x[1])*h1(x[2])


hs = [h_000, h_001, h_010, h_011, h_100, h_101, h_110, h_111]

def get_field(coefficients, field_idx, points, x_support, y_support, z_support, step, der_x = 0, der_y = 0, der_z = 0):
    field = torch.zeros(points.shape[0])
    for l, p in enumerate(points):
        # Get 8 points of the cube
        # print(x_support[0, l], x_support[1, l])
        x_indices = torch.arange(x_support[0, l], x_support[1, l] + 1)  # x range
        y_indices = torch.arange(y_support[0, l], y_support[1, l] + 1)  # y range
        z_indices = torch.arange(z_support[0, l], z_support[1, l] + 1)  # z range
        x_coords = x_indices * step[0]
        y_coords = y_indices * step[1]
        z_coords = z_indices * step[2]
        # Generate all combinations of grid coordinates
        xx, yy, zz = torch.meshgrid(x_coords, y_coords, z_coords, indexing="ij")
        grid_points = torch.stack([xx.flatten(), yy.flatten(), zz.flatten()], dim=1) # 8x3
        h_points = torch.tensor([[h((p-x)/step) for h in hs] for x in grid_points])
        coeff = coefficients[field_idx, :, x_support[0, l]:x_support[1, l]+1, y_support[0, l]:y_support[1, l]+1, z_support[0, l]:z_support[1, l]+1]
        field[l] = torch.sum(coeff.flatten()*h_points.flatten())
    return field

x,y,z,x_support_indices,y_supports_indices,z_supports_indices = get_supoort_points(points)

vx2 = get_field(spline_coeff, 0, volume_points, x_support_indices, y_supports_indices, z_supports_indices, torch.tensor(step))

IndexError: index 97 is out of bounds for dimension 1 with size 97

In [43]:
x_support.shape

torch.Size([2, 97])