# RMSNorm with TensorRT

In [1]:
import numpy as np
from cuda import cudart
import torch
from torch import Tensor, nn
import tensorrt as trt

In [2]:
print("PyTorch version: " + torch.__version__)
print("TensorRT version: " + trt.__version__)

PyTorch version: 2.1.0a0+4136153
TensorRT version: 8.6.1


## Generate input and data shape

In [3]:
# Input tensor shape NCHW
nIn, hIn, wIn = 1, 2, 2

# Output tensor shape C
cOut = 2

# Input tensor
data = np.arange(hIn * wIn, dtype=np.float32).reshape(nIn, hIn, wIn)

# fully connected weight
weight = np.ones(cOut * hIn * wIn, dtype=np.float32).reshape(cOut, hIn * wIn)

# fully connected bias
bias = np.zeros(cOut, dtype=np.float32)

print("inputH0 :", data.shape)
print(data)

inputH0 : (1, 2, 2)
[[[0. 1.]
  [2. 3.]]]


## 1. RMSNorm by PyTorch 

In [4]:
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # The gamma parameter
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
        # rsqrt: 1 / sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
        return self.weight * self._norm(x.float()).type_as(x)

In [5]:
def test_torch(nIn, hIn, wIn, cOut, raw_data, weight, bias):
    data = torch.tensor(raw_data).reshape(-1)
    
    model = RMSNorm(1)

    output = model(data)

    return output

## PyTorch Testing

In [6]:
torch_output = test_torch(nIn, hIn, wIn, cOut, data, weight, bias)
print("RMSNorm_output_torch :", torch_output.shape)
print(torch_output)

RMSNorm_output_torch : torch.Size([4])
tensor([0.0000, 0.5345, 1.0690, 1.6036], grad_fn=<MulBackward0>)


---

## 2. RMSNorm with TensorRT

In [7]:
def trt_create(nIn, hIn, cOut, weight, bias):
    # Config TensorRT Logger, Builder, Network
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()

    # input
    inputT0 = network.add_input('inputT0', trt.DataType.FLOAT, (nIn, -1, hIn))

    # dynamic shape optimization
    profile = builder.create_optimization_profile();
    profile.set_shape("inputT0", (nIn, 1, hIn), (nIn, 2, hIn), (nIn, 3, hIn)) 
    config.add_optimization_profile(profile)

    # RMSNorm Layer: 1) Square: X^2 -> 2) Sum: sum of all x^2 -> 3) Mean: 1/N -> 4) Root: sqrt(X) -> 5) Division: 1/X
    print("inputT0.shape :")
    print(inputT0.shape)
    # 1) Square: X^2
    RMSNorm_Square_layer = network.add_elementwise(inputT0, inputT0, op=trt.ElementWiseOperation.PROD)
    print("RMSNorm_Square_layer.get_output(0).shape :")
    print(RMSNorm_Square_layer.get_output(0).shape)
    # 2) Sum: sum of all X^2
    RMSNorm_Sum_layer = network.add_reduce(RMSNorm_Square_layer.get_output(0), op=trt.ReduceOperation.SUM, axes=1, keep_dims=True)
    print("RMSNorm_Sum_layer.get_output(0).shape :")
    print(RMSNorm_Sum_layer.get_output(0).shape)
    # 3) Mean: 1/N
    RMSNorm_Mean_layer = network.add_reduce(RMSNorm_Sum_layer.get_output(0), op=trt.ReduceOperation.AVG, axes=7, keep_dims=True)
    print("RMSNorm_Mean_layer.get_output(0).shape :")
    print(RMSNorm_Mean_layer.get_output(0).shape)
    # 4) Root: sqrt(X)
    RMSNorm_Sqrt_layer = network.add_unary(RMSNorm_Mean_layer.get_output(0), op=trt.UnaryOperation.SQRT)
    print("RMSNorm_Sqrt_layer.get_output(0).shape :")
    print(RMSNorm_Sqrt_layer.get_output(0).shape)
    # 5) Division: 1/X
    RMSNorm_Div_layer = network.add_elementwise(inputT0, RMSNorm_Sqrt_layer.get_output(0), op=trt.ElementWiseOperation.DIV)
    print("RMSNorm_Div_layer.get_output(0).shape :")
    print(RMSNorm_Div_layer.get_output(0).shape)
    # output
    network.mark_output(RMSNorm_Div_layer.get_output(0))

    engineString = builder.build_serialized_network(network, config)
    
    return engineString

In [8]:
trt_engineStr = trt_create(nIn, hIn, cOut, weight, bias)

inputT0.shape :
(1, -1, 2)
RMSNorm_Square_layer.get_output(0).shape :
(1, -1, 2)
RMSNorm_Sum_layer.get_output(0).shape :
(1, -1, 2)
RMSNorm_Mean_layer.get_output(0).shape :
(1, 1, 1)
RMSNorm_Sqrt_layer.get_output(0).shape :
(1, 1, 1)
RMSNorm_Div_layer.get_output(0).shape :
(1, -1, 2)


In [9]:
def trt_inference(nIn, hIn, cOut, engineString, raw_data):
    print(engineString)
    print("Runtime")
    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()

    # dynamic shape configure
    print("Set input shape")
    context.set_input_shape("inputT0", (nIn, 2, hIn))
    context.set_binding_shape(0, (nIn, 2, hIn))
    origin_inputshape = context.get_binding_shape(0)

    print("Set input shape completed")

    data = np.array(raw_data)

    _, stream = cudart.cudaStreamCreate()
    print("Reshaping")

    inputH0 = np.ascontiguousarray(data.reshape(-1))
    outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
    print("Reshaped")

    # initialize input and output data
    _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
    _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)

    # move input to device
    cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)

    # execute
    print("execute")
    context.execute_async_v2([int(inputD0), int(outputD0)], stream)

    # move output back to host
    cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)

    # wait for everything
    cudart.cudaStreamSynchronize(stream)

    cudart.cudaStreamDestroy(stream)
    cudart.cudaFree(inputD0)
    cudart.cudaFree(outputD0)

    return outputH0

In [10]:
trt_output = trt_inference(nIn, hIn, cOut, trt_engineStr, data)
trt_output = trt_output.reshape(-1)
print("output_trt :", trt_output.shape)
print(trt_output)

<tensorrt.tensorrt.IHostMemory object at 0x7f4381fa7370>
Runtime
Set input shape
Set input shape completed
Reshaping
Reshaped
execute
output_trt : (4,)
[0.        0.5345225 1.069045  1.6035675]


  context.set_binding_shape(0, (nIn, 2, hIn))
  origin_inputshape = context.get_binding_shape(0)
  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
