In [1]:
import vedo
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import pycuda.driver as cuda
import pycuda.autoinit
import tensorrt as trt

import common

In [2]:
def knn(x, k):
    """Return the indices of the k nearest neighbours of every point.

    x is indexed as (batch, dims, num_points): squared norms are summed over
    dim 1 and neighbours are searched along the last axis. Points are ranked
    by negated squared Euclidean distance, so the largest score is the closest
    point (normally the point itself, at distance zero).

    Returns an index tensor of shape (batch, num_points, k).
    """
    squared_norms = (x ** 2).sum(dim=1, keepdim=True)
    cross_term = -2 * torch.matmul(x.transpose(2, 1), x)
    # -(|a|^2 - 2 a.b + |b|^2): a larger value means a smaller distance.
    negative_sq_dist = -squared_norms - cross_term - squared_norms.transpose(2, 1)
    _, idx = negative_sq_dist.topk(k=k, dim=-1)
    return idx

def get_graph_feature(x, k=20, idx=None, dim9=False, device='cpu'):
    """Build per-point edge features from each point's k neighbours.

    Args:
        x: tensor indexed as (batch, dims, num_points).
        k: number of neighbours per point; must match the last axis of idx
            when idx is supplied.
        idx: optional precomputed neighbour indices of shape
            (batch, num_points, k). When omitted they are computed with knn.
        dim9: when true, the neighbour search uses channels 6 onward of x only.
        device: kept for backward compatibility. The batch offsets are now
            created on the device of the neighbour indices, so this argument
            no longer has to match the tensors.

    Returns:
        Tensor of shape (batch, 2*num_dims, num_points, k) whose first
        num_dims channels are (neighbour - centre) and whose remaining
        channels are the centre point repeated for every neighbour.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    # reshape (rather than view) also accepts non-contiguous inputs.
    x = x.reshape(batch_size, -1, num_points)
    if idx is None:
        if dim9:
            idx = knn(x[:, 6:], k=k)
        else:
            idx = knn(x, k=k)   # (batch_size, num_points, k)

    # Offset every batch's indices so they address the flattened
    # (batch_size*num_points, num_dims) table built below. The offsets live on
    # the same device as idx, otherwise the addition fails for non-CPU tensors.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1)
    x = x.transpose(2, 1).contiguous()   # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    # expand gives the same values as repeat without copying; cat materialises the result.
    x = x.view(batch_size, num_points, 1, num_dims).expand(-1, -1, k, -1)
    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature      # (batch_size, 2*num_dims, num_points, k)

In [3]:
def centring(mesh: vedo.Mesh):
    """Shift the mesh points by the mesh's centre of mass and return the same mesh.

    The mesh is modified in place: its points are replaced by
    ``points - centerOfMass``.
    """
    centred_points = mesh.points() - mesh.centerOfMass()
    mesh.points(pts=centred_points)
    return mesh

def _row_normalised_adjacency(distances, radius):
    """Return the float32 adjacency ``distances < radius`` with every row scaled to sum to 1."""
    adjacency = (distances < radius).astype('float32')
    # Broadcasting the (N, 1) row sums avoids materialising a dense N x N
    # float64 normaliser (and a float64 copy of the adjacency itself).
    adjacency /= adjacency.sum(axis=1, keepdims=True)
    return adjacency


def get_metadata(model_name: str, mesh: vedo.Mesh, device='cpu'):
    """Turn a mesh into the input tensors expected by the segmentation network.

    Per cell, the feature vector is the three vertex positions (9 values,
    z-scored with the per-axis mean/std of the mesh points), the cell centre
    (min-max scaled with the per-axis point range) and the face normal
    (z-scored with the per-axis mean/std of the normals).

    Args:
        model_name: "iMeshSegNet" adds the graph features "KG_6" and "KG_12";
            "MeshSegNet" adds the adjacency matrices "A_S" and "A_L". Any
            other value yields only "input".
        mesh: mesh to convert. The reshape to (N, 9) requires three vertices
            per cell.
        device: device the returned tensors are moved to.

    Returns:
        dict of float32 tensors with a leading batch axis of size 1.
        "input" is (1, features, N), where features is 15 when centres and
        normals are (N, 3).
    """
    # mesh = centring(mesh)
    N = mesh.NCells()
    points = vedo.vtk2numpy(mesh.polydata().GetPoints().GetData())
    # Drop the first column of every poly record, keeping only the vertex ids.
    ids = vedo.vtk2numpy(mesh.polydata().GetPolys().GetData()).reshape((N, -1))[:,1:]
    cells = points[ids].reshape(N, 9).astype(dtype='float32')
    normals = vedo.vedo2trimesh(mesh).face_normals
    # The normals are normalised in place below, so the array must be writable.
    normals.setflags(write=1)
    barycenters = mesh.cellCenters()

    #normalized data
    maxs = points.max(axis=0)
    mins = points.min(axis=0)
    means = points.mean(axis=0)
    stds = points.std(axis=0)
    nmeans = normals.mean(axis=0)
    nstds = normals.std(axis=0)

    for i in range(3):
        cells[:, i] = (cells[:, i] - means[i]) / stds[i] #point 1
        cells[:, i+3] = (cells[:, i+3] - means[i]) / stds[i] #point 2
        cells[:, i+6] = (cells[:, i+6] - means[i]) / stds[i] #point 3
        barycenters[:,i] = (barycenters[:,i] - mins[i]) / (maxs[i]-mins[i])
        normals[:,i] = (normals[:,i] - nmeans[i]) / nstds[i]

    # Rows of X: 0-8 vertex coordinates, then the cell centres, then the normals.
    X = np.column_stack((cells, barycenters, normals))
    X = X.transpose(1, 0)

    meta = dict()
    meta["input"] = torch.from_numpy(X).unsqueeze(0).to(device, dtype=torch.float)

    if model_name == "iMeshSegNet":
        print("Getting KG6 and KG12.")
        # Rows 9-11 are the normalised cell centres; both graphs share them.
        centre_tensor = torch.from_numpy(X[9:12, :]).unsqueeze(0)
        meta["KG_6"] = get_graph_feature(centre_tensor, k=6).to(device, dtype=torch.float)
        meta["KG_12"] = get_graph_feature(centre_tensor, k=12).to(device, dtype=torch.float)
    elif model_name == "MeshSegNet":
        print("Getting A_S and A_L.")
        # computing A_S and A_L from pairwise distances between the cell centres
        centre_coords = torch.as_tensor(X[9:12, :].transpose(1, 0), device='cpu')
        D = torch.cdist(centre_coords, centre_coords).numpy()
        # D = distance_matrix(X[:, 9:12], X[:, 9:12])
        A_S = _row_normalised_adjacency(D, 0.1)
        A_L = _row_normalised_adjacency(D, 0.2)

        # numpy -> torch.tensor, adding the batch axis
        meta["A_S"] = torch.from_numpy(A_S[np.newaxis]).to(device, dtype=torch.float)
        meta["A_L"] = torch.from_numpy(A_L[np.newaxis]).to(device, dtype=torch.float)

    return meta

In [4]:
class HostDeviceMem:
    """Pairs a host-side buffer with its device-side counterpart."""

    def __init__(self, host_mem, device_mem):
        # host: buffer on the host side; device: matching device allocation.
        self.host, self.device = host_mem, device_mem

    def __str__(self):
        return "\n".join(("Host:", str(self.host), "Device:", str(self.device)))

    def __repr__(self):
        # The repr is deliberately the same text as str().
        return str(self)

def allocate_buffers(engine, batch):
    """Allocate one host/device buffer pair per engine binding.

    Bindings are paired positionally with the arrays in ``batch``; each
    buffer holds ``trt.volume(array.shape)`` elements of the binding's dtype.

    Returns:
        (inputs, outputs, bindings, stream): the HostDeviceMem pairs split by
        binding direction, the device addresses as ints in binding order, and
        a newly created CUDA stream.
    """
    stream = cuda.Stream()
    inputs, outputs, bindings = [], [], []
    for binding, sub in zip(engine, batch):
        element_count = trt.volume(sub.shape)
        np_dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Page-locked host buffer plus a device allocation of the same byte size.
        host_mem = cuda.pagelocked_empty(element_count, np_dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        bindings.append(int(device_mem))
        target = inputs if engine.binding_is_input(binding) else outputs
        target.append(HostDeviceMem(host_mem, device_mem))
    return inputs, outputs, bindings, stream

def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    """Run one inference pass via ``execute_async`` and return the host output buffers.

    Input buffers are copied host-to-device, the context is executed, output
    buffers are copied device-to-host, and the stream is synchronised before
    returning. All copies and the execution are enqueued on ``stream``.
    """
    for buffer in inputs:
        cuda.memcpy_htod_async(buffer.device, buffer.host, stream)
    context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
    for buffer in outputs:
        cuda.memcpy_dtoh_async(buffer.host, buffer.device, stream)
    # Wait for the queued work so the host buffers are complete.
    stream.synchronize()
    return [buffer.host for buffer in outputs]

def do_inference_v2(context, bindings, inputs, outputs, stream):
    """Run one inference pass via ``execute_async_v2`` and return the host output buffers.

    Input buffers are copied host-to-device, the context is executed, output
    buffers are copied device-to-host, and the stream is synchronised before
    returning. All copies and the execution are enqueued on ``stream``.
    """
    for buffer in inputs:
        cuda.memcpy_htod_async(buffer.device, buffer.host, stream)
    context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
    for buffer in outputs:
        cuda.memcpy_dtoh_async(buffer.host, buffer.device, stream)
    # Wait for the queued work so the host buffers are complete.
    stream.synchronize()
    return [buffer.host for buffer in outputs]

def load_np_to_input_buffer(item, pagelocked_buffer):
    """Copy ``item`` into ``pagelocked_buffer`` in place (numpy broadcasting and casting rules apply)."""
    # "same_kind" is np.copyto's default casting rule, spelled out for the reader.
    np.copyto(pagelocked_buffer, item, casting="same_kind")

In [5]:
# Path of the ONNX model that the engine-building cell opens and parses.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
onnx_file_path = "/home/ziyang/Desktop/iMeshSegNet-ONNX/onnx/model_sim.onnx"

In [6]:
# Build a TensorRT engine with dynamic batch/point axes from the ONNX model
# and create the execution context used by the inference cell.
TRT_LOGGER = trt.Logger()

builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(common.EXPLICIT_BATCH)
parser = trt.OnnxParser(network, TRT_LOGGER)
runtime = trt.Runtime(TRT_LOGGER)

# Parse model file
with open(onnx_file_path, "rb") as model:
    print("Beginning ONNX file parsing")
    if not parser.parse(model.read()):
        print("ERROR: Failed to parse the ONNX file.")
        for error in range(parser.num_errors):
            print(parser.get_error(error))
        # Stop here: building from a partially parsed network only fails later
        # with a less helpful error.
        raise RuntimeError(f"Failed to parse the ONNX file: {onnx_file_path}")
print("Completed parsing of ONNX file")

# One optimization profile covering every dynamic input:
# batch 1-2 and 1000-10000 cells per mesh (tuned for 5000).
opt_profile = builder.create_optimization_profile()
opt_profile.set_shape(input='input', min=(1,15,1000), opt=(2,15,5000), max=(2,15,10000))
opt_profile.set_shape(input='a_s', min=(1,1000,1000), opt=(2,5000,5000), max=(2,10000,10000))
opt_profile.set_shape(input='a_l', min=(1,1000,1000), opt=(2,5000,5000), max=(2,10000,10000))

config = builder.create_builder_config()

config.add_optimization_profile(opt_profile)

config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30) # 1 GiB

plan = builder.build_serialized_network(network, config)
if plan is None:
    # The builder signals failure by returning None; details are in the TRT log.
    raise RuntimeError("TensorRT failed to build the serialized network")

engine = runtime.deserialize_cuda_engine(plan)
if engine is None:
    raise RuntimeError("TensorRT failed to deserialize the engine")

context = engine.create_execution_context()

Beginning ONNX file parsing
[09/05/2022-11:57:41] [TRT] [W] onnx2trt_utils.cpp:369: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
Completed parsing of ONNX file
[09/05/2022-11:57:41] [TRT] [W] TensorRT was linked against cuDNN 8.4.1 but loaded cuDNN 8.3.2
[09/05/2022-11:58:00] [TRT] [W] The getMaxBatchSize() function should not be used with an engine built from a network created with NetworkDefinitionCreationFlag::kEXPLICIT_BATCH flag. This function will always return 1.
[09/05/2022-11:58:00] [TRT] [W] The getMaxBatchSize() function should not be used with an engine built from a network created with NetworkDefinitionCreationFlag::kEXPLICIT_BATCH flag. This function will always return 1.


In [7]:
# End-to-end timing of preprocessing + TensorRT inference for one mesh.
import time
t1 = time.time()

small_batch = get_metadata("MeshSegNet", vedo.Mesh("/home/ziyang/Desktop/iMeshSegNet-ONNX/mesh/test_1.ply"), "cpu")
print(small_batch["input"].shape)
batch_size = small_batch["input"].shape[0]
points_num = small_batch["input"].shape[2]

# Resolve the dynamic axes of the three input bindings for this mesh.
# set_binding_shape reports failure through its return value (for example a
# shape outside the optimization profile), so check it instead of running
# inference on an invalid context.
binding_shapes = (
    (batch_size, 15, points_num),
    (batch_size, points_num, points_num),
    (batch_size, points_num, points_num),
)
for binding_index, binding_shape in enumerate(binding_shapes):
    if not context.set_binding_shape(binding_index, binding_shape):
        raise ValueError(
            f"binding {binding_index} rejected shape {binding_shape}; "
            "check it against the optimization profile"
        )

# The output placeholder is only used to size the output buffer.
output = np.empty((batch_size, points_num, 17), dtype=np.float32)
small_batch = (small_batch["input"].numpy(), small_batch['A_S'].numpy(), small_batch['A_L'].numpy(), output)

t2 = time.time()

inputs, outputs, bindings, stream = allocate_buffers(engine, small_batch)

# Every array except the trailing output placeholder is an input.
for idx, item in enumerate(small_batch[:-1]):
    load_np_to_input_buffer(item.ravel(), pagelocked_buffer=inputs[idx].host)

output = do_inference_v2(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
print('Infer time: ', time.time()-t2)

out = np.asarray(output[0]).reshape(batch_size,points_num,17)
print('Total time: ', time.time()-t1)

print(out.max())
# Show the scores of the last cell; a fixed index (4999) breaks on smaller meshes.
print(out[0][-1])
print(out.shape)

Getting A_S and A_L.
torch.Size([1, 15, 5000])
Infer time:  0.10907363891601562
Total time:  0.8495852947235107
1.0
[8.49054632e-05 5.96667496e-06 4.25081089e-06 4.15956549e-08
 5.01430577e-08 1.41140597e-04 8.87975679e-04 1.21823945e-07
 1.01396523e-07 1.46261455e-05 3.64778716e-05 1.35645436e-07
 1.09602911e-06 9.80355963e-03 9.89018559e-01 1.03043874e-06
 3.19302771e-08]
(1, 5000, 17)


In [8]:
# Engine-level shapes for bindings 0-3, followed by the shapes currently set
# on the execution context for bindings 0-2.
for binding_index in range(4):
    print(engine.get_binding_shape(binding_index))
for binding_index in range(3):
    print(context.get_binding_shape(binding_index))

(-1, 15, -1)
(-1, -1, -1)
(-1, -1, -1)
(-1, -1, 17)
(1, 15, 5000)
(1, 5000, 5000)
(1, 5000, 5000)


In [9]:
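# NOTE: disabled leftover of a manual execution path; d_input, d_a_s, d_a_l and d_output are not defined in the cells shown here.
# bindings = [int(d_input), int(d_a_s), int(d_a_l), int(d_output)]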

# print(bindings)

# context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)

# cuda.memcpy_dtoh_async(output, d_output, stream)

# stream.synchronize()