In [1]:
import numpy as np
from cuda import cudart
import torch
from torch import Tensor, nn
import tensorrt as trt

## Generate input and data shape

In [2]:
config = dict()
batch_size, seq_len, hidden_size = 4, 1, 4096
intermediate_size = 11008
config['hidden_size'] = hidden_size
config['intermediate_size'] = intermediate_size

In [3]:
data = torch.ones(batch_size, seq_len, hidden_size)

## torch MLP

In [4]:
class SiLUActivation(nn.Module):
    """
    See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
    Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
    Approximation in Reinforcement Learbatch_sizeg (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
    Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
    later.
    """

    def forward(self, input: Tensor) -> Tensor:
        return nn.functional.silu(input)
    
    def b_forward(self, input: Tensor) -> Tensor:
        return torch.matmul(input.T, nn.functional.sigmoid(input))

In [40]:
class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config['hidden_size']
        self.intermediate_size = config['intermediate_size']
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = SiLUActivation()
        self.init = False

    def load(self, dir):
        weights = torch.load(dir)
        mlp_weights = dict()
        for key in weights.keys():
            if key.split(".")[3] == "mlp":
                mlp_weights[key[key.find(key.split(".")[4]):]] = weights[key]

        self.load_state_dict(mlp_weights)

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        return down_proj

## Test torch

In [70]:
model = LlamaMLP(config)

device = torch.device("cuda")

model.load("/home/fuchiang137/.cache/huggingface/hub/models--decapoda-research--llama-7b-hf/snapshots/5f98eefcc80e437ef68d457ad7bf167c2c6a1348/pytorch_model-00019-of-00033.bin")
model = model.to(device)

data_D = data.to(device)
# output = model(data)
output = model(data_D)

print(output)
print(output.shape)

tensor([[[ 2.0781,  4.5118,  3.0771,  ..., -2.6065, -2.1167, -2.7058]],

        [[ 2.0781,  4.5118,  3.0771,  ..., -2.6065, -2.1167, -2.7058]],

        [[ 2.0781,  4.5118,  3.0771,  ..., -2.6065, -2.1167, -2.7058]],

        [[ 2.0781,  4.5118,  3.0771,  ..., -2.6065, -2.1167, -2.7058]]],
       device='cuda:1', grad_fn=<UnsafeViewBackward0>)
torch.Size([4, 1, 4096])


In [7]:
print(model.up_proj.weight.shape)

torch.Size([11008, 4096])


## tensorRT MLP

In [42]:
# seq length is not specified, since it is a dynamic size
def trt_create(batch_size, hidden_size, intermediate_size, model):
    
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()

    # input
    inputT0 = network.add_input('inputT0', trt.DataType.FLOAT, (batch_size, 1, -1, hidden_size))

    # dynamic shape optimization
    profile = builder.create_optimization_profile();
    profile.set_shape("inputT0", (batch_size, 1, 1, hidden_size), (batch_size, 1, 2, hidden_size), (batch_size, 1, 3, hidden_size))
    config.add_optimization_profile(profile)


    # self.up_proj(x)
    up_proj_weight = model.up_proj.weight.clone().detach().cpu().numpy()
    up_proj_layer = network.add_fully_connected(inputT0, model.intermediate_size, up_proj_weight)

    # act_fn(self.gate_proj(x))
    gate_proj_weight = model.gate_proj.weight.clone().detach().cpu().numpy()
    gate_proj_layer = network.add_fully_connected(inputT0, model.intermediate_size, gate_proj_weight)

    selu_sigmoid_layer = network.add_activation(gate_proj_layer.get_output(0), type=trt.ActivationType.SIGMOID)
    selu_mult_layer = network.add_elementwise(gate_proj_layer.get_output(0), selu_sigmoid_layer.get_output(0), op=trt.ElementWiseOperation.PROD)

    # act_fn(self.gate_proj(x)) * self.up_proj(x)
    before_down_proj_layer = network.add_elementwise(selu_mult_layer.get_output(0), up_proj_layer.get_output(0), op=trt.ElementWiseOperation.PROD)

    down_proj_weight = model.down_proj.weight.clone().detach().cpu().numpy()
    down_proj_layer = network.add_fully_connected(before_down_proj_layer.get_output(0), hidden_size, down_proj_weight)

    # output
    network.mark_output(down_proj_layer.get_output(0))

    engineString = builder.build_serialized_network(network, config)
    
    return engineString

In [43]:
trt_engineStr = trt_create(batch_size, hidden_size, intermediate_size, model)

  up_proj_layer = network.add_fully_connected(inputT0, model.intermediate_size, up_proj_weight)
  gate_proj_layer = network.add_fully_connected(inputT0, model.intermediate_size, gate_proj_weight)
  down_proj_layer = network.add_fully_connected(before_down_proj_layer.get_output(0), hidden_size, down_proj_weight)


In [62]:
def trt_inference(batch_size, hidden_size, engineString, raw_data, up_proj):
#     print(engineString)
#     print("Runtime")
    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()

    # dynamic shape configure
#     print("Set input shape", (batch_size, 1, hidden_size))
#     context.set_input_shape("inputT0", (batch_size, 1, hidden_size))
#     context.set_binding_shape(0, (batch_size, 1, hidden_size))
#     origin_inputshape = context.get_binding_shape(0)

#     print("Set input shape completed")

    data = np.array(raw_data)

    _, stream = cudart.cudaStreamCreate()
#     print("Reshaping")

    inputH0 = np.ascontiguousarray(data.reshape(-1))
    outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
#     print("Reshaped")

    # initialize input and output data
    _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
    _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)

    # move input to device
    cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)

    # execute
#     print("execute")
    context.execute_async_v2([int(inputD0), int(outputD0)], stream)

    # move output back to host
    cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)

    # wait for everythidden_sizeg
    cudart.cudaStreamSynchronize(stream)

    cudart.cudaStreamDestroy(stream)
    cudart.cudaFree(inputD0)
    cudart.cudaFree(outputD0)

    return outputH0

In [63]:
up_proj_weight = model.up_proj.weight.clone().detach().cpu().numpy()

trt_output = trt_inference(batch_size, hidden_size, trt_engineStr, data, up_proj_weight)

trt_output = trt_output.reshape(batch_size, seq_len, hidden_size)
print("output_trt :", trt_output.shape)
print(trt_output)

output_trt : (4, 1, 4096)
[[[ 2.0781114  4.51184    3.0770564 ... -2.6064956 -2.116741  -2.7058382]]

 [[ 2.0781114  4.51184    3.0770564 ... -2.6064956 -2.116741  -2.7058382]]

 [[ 2.0781114  4.51184    3.0770564 ... -2.6064956 -2.116741  -2.7058382]]

 [[ 2.0781114  4.51184    3.0770564 ... -2.6064956 -2.116741  -2.7058382]]]


  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))


## Benchmark

In [None]:
import time

### Torch

In [74]:
torch_start = time.time_ns()

output = model(data_D)

torch_complete = time.time_ns()

print("torch memory exe", (torch_complete - torch_start) / 10e6, "ms")


torch memory exe 0.2631836 ms


### TensorRT

### profile CPU/GPU time for tensorRT

In [68]:
def profile_trt_inference(batch_size, hidden_size, engineString, raw_data, up_proj):
    trt_prep_start = time.time_ns()
    
    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()

    trt_prep_complete = time.time_ns()

    data = np.array(raw_data)

    inputH0 = np.ascontiguousarray(data.reshape(-1))
    outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))

    memory_alloc_complete = time.time_ns()

    _, stream = cudart.cudaStreamCreate()

    # initialize input and output data
    _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
    _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)

    # move input to device
    cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)

    # execute
    context.execute_async_v2([int(inputD0), int(outputD0)], stream)

    # move output back to host
    cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)

    # wait for everythidden_sizeg
    cudart.cudaStreamSynchronize(stream)

    cudart.cudaStreamDestroy(stream)
    cudart.cudaFree(inputD0)
    cudart.cudaFree(outputD0)

    trt_complete = time.time_ns()
    
    print("trt_prep", (trt_prep_complete - trt_prep_start) / 10e6, "ms")
    print("memory_alloc CPU", (memory_alloc_complete - trt_prep_complete) / 10e6, "ms")
    print("trt memory alloc & mv & exe", (trt_complete - memory_alloc_complete) / 10e6, "ms")

    return outputH0

In [69]:
trt_output = profile_trt_inference(batch_size, hidden_size, trt_engineStr, data, up_proj_weight)

trt_prep 15.1114614 ms
memory_alloc CPU 0.0241049 ms
trt memory alloc & mv & exe 0.1599591 ms


  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
