# RMSNorm with TensorRT

In [1]:
import numpy as np
from cuda import cudart
import torch
from torch import Tensor, nn
import tensorrt as trt

In [2]:
print("PyTorch version: " + torch.__version__)
print("TensorRT version: " + trt.__version__)

PyTorch version: 2.1.0a0+4136153
TensorRT version: 8.6.1


## Generate input and data shape

In [3]:
# Input tensor shape NCHW
nIn, hIn, wIn = 4, 45, 4096

# Output tensor shape C
cOut = 2

# Input tensor
data = np.arange(nIn * hIn * wIn, dtype=np.float32).reshape(nIn, hIn, wIn)

# fully connected weight
#weight = np.ones(cOut * hIn * wIn, dtype=np.float32).reshape(cOut, hIn * wIn)

# fully connected bias
#bias = np.zeros(cOut, dtype=np.float32)

print("inputH0 :", data.shape)
print(data)

inputH0 : (4, 45, 4096)
[[[0.00000e+00 1.00000e+00 2.00000e+00 ... 4.09300e+03 4.09400e+03
   4.09500e+03]
  [4.09600e+03 4.09700e+03 4.09800e+03 ... 8.18900e+03 8.19000e+03
   8.19100e+03]
  [8.19200e+03 8.19300e+03 8.19400e+03 ... 1.22850e+04 1.22860e+04
   1.22870e+04]
  ...
  [1.72032e+05 1.72033e+05 1.72034e+05 ... 1.76125e+05 1.76126e+05
   1.76127e+05]
  [1.76128e+05 1.76129e+05 1.76130e+05 ... 1.80221e+05 1.80222e+05
   1.80223e+05]
  [1.80224e+05 1.80225e+05 1.80226e+05 ... 1.84317e+05 1.84318e+05
   1.84319e+05]]

 [[1.84320e+05 1.84321e+05 1.84322e+05 ... 1.88413e+05 1.88414e+05
   1.88415e+05]
  [1.88416e+05 1.88417e+05 1.88418e+05 ... 1.92509e+05 1.92510e+05
   1.92511e+05]
  [1.92512e+05 1.92513e+05 1.92514e+05 ... 1.96605e+05 1.96606e+05
   1.96607e+05]
  ...
  [3.56352e+05 3.56353e+05 3.56354e+05 ... 3.60445e+05 3.60446e+05
   3.60447e+05]
  [3.60448e+05 3.60449e+05 3.60450e+05 ... 3.64541e+05 3.64542e+05
   3.64543e+05]
  [3.64544e+05 3.64545e+05 3.64546e+05 ... 3.6863

## 1. RMSNorm by PyTorch 

In [4]:
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # The gamma parameter
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def _norm(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
        # rsqrt: 1 / sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
        return self.weight * self._norm(x.float()).type_as(x)

In [8]:
def test_torch(nIn, hIn, wIn, cOut, raw_data):
    data = torch.tensor(raw_data).reshape(-1)
    
    model = RMSNorm(1)

    output = model(data)

    return output

## PyTorch Testing

In [9]:
torch_output = test_torch(nIn, hIn, wIn, cOut, data)
print("RMSNorm_output_torch :", torch_output.shape)
print(torch_output)

RMSNorm_output_torch : torch.Size([737280])
tensor([0.0000e+00, 2.3492e-06, 4.6985e-06,  ..., 1.7320e+00, 1.7320e+00,
        1.7321e+00], grad_fn=<MulBackward0>)


---

## 2. RMSNorm with TensorRT

In [13]:
def trt_create(nIn, hIn, wIn, cOut):
    # Config TensorRT Logger, Builder, Network
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()

    # input
    inputT0 = network.add_input('inputT0', trt.DataType.FLOAT, [-1])
    avg_factor = np.array([nIn * hIn * wIn]).astype('float32')
    epsilon_weight = np.array([1e-06]).astype('float32')
    avg_tensor = network.add_constant(shape=list(avg_factor.shape), weights=trt.Weights(avg_factor))
    epsilon = network.add_constant(shape=list(epsilon_weight.shape), weights=trt.Weights(epsilon_weight))

    # dynamic shape optimization
    profile = builder.create_optimization_profile();
    profile.set_shape("inputT0", [1], [hIn*wIn], [nIn*hIn*wIn]) 
    config.add_optimization_profile(profile)

    # RMSNorm Layer: 1) Square: X^2 -> 2) Sum: sum of all x^2 -> 3) Mean: 1/N 
    # -> 4) Add epsilon -> 5) Root: sqrt(X) -> 6) Division: 1/X
    print("inputT0.shape :")
    print(inputT0.shape)
    # 1) Square: X^2
    RMSNorm_Square_layer = network.add_elementwise(inputT0, inputT0, op=trt.ElementWiseOperation.PROD)
    print("RMSNorm_Square_layer.get_output(0).shape :")
    print(RMSNorm_Square_layer.get_output(0).shape)
    # 2) Sum: sum of all X^2
    RMSNorm_Sum_layer = network.add_reduce(RMSNorm_Square_layer.get_output(0), op=trt.ReduceOperation.SUM, axes=1, keep_dims=True)
    print("RMSNorm_Sum_layer.get_output(0).shape :")
    print(RMSNorm_Sum_layer.get_output(0).shape)
    # 3) Mean: 1/N
    RMSNorm_Mean_layer = network.add_elementwise(RMSNorm_Sum_layer.get_output(0),
                                                 avg_tensor.get_output(0),
                                                 op=trt.ElementWiseOperation.DIV)
    print("RMSNorm_Mean_layer.get_output(0).shape :")
    print(RMSNorm_Mean_layer.get_output(0).shape)
    # 4) Add epsilon
    RMSNorm_Mean_with_epsilon_layer = network.add_elementwise(RMSNorm_Mean_layer.get_output(0),
                                                              epsilon.get_output(0), 
                                                              op=trt.ElementWiseOperation.SUM)
    print("RMSNorm_Mean_with_epsilon_layer.get_output(0).shape :")
    print(RMSNorm_Mean_with_epsilon_layer.get_output(0).shape)
    # 5) Root: sqrt(X)
    RMSNorm_Sqrt_layer = network.add_unary(RMSNorm_Mean_with_epsilon_layer.get_output(0), op=trt.UnaryOperation.SQRT)
    print("RMSNorm_Sqrt_layer.get_output(0).shape :")
    print(RMSNorm_Sqrt_layer.get_output(0).shape)
    # 6) Division: 1/X
    RMSNorm_Div_layer = network.add_elementwise(inputT0, RMSNorm_Sqrt_layer.get_output(0), op=trt.ElementWiseOperation.DIV)
    print("RMSNorm_Div_layer.get_output(0).shape :")
    print(RMSNorm_Div_layer.get_output(0).shape)
    # output
    network.mark_output(RMSNorm_Div_layer.get_output(0))

    engineString = builder.build_serialized_network(network, config)
    
    return engineString

In [14]:
trt_engineStr = trt_create(nIn, hIn, wIn, cOut)

inputT0.shape :
(-1,)
RMSNorm_Square_layer.get_output(0).shape :
(-1,)
RMSNorm_Sum_layer.get_output(0).shape :
(1,)
RMSNorm_Mean_layer.get_output(0).shape :
(1,)
RMSNorm_Mean_with_epsilon_layer.get_output(0).shape :
(1,)
RMSNorm_Sqrt_layer.get_output(0).shape :
(1,)
RMSNorm_Div_layer.get_output(0).shape :
(-1,)


In [15]:
def trt_inference(nIn, hIn, wIn, cOut, engineString, raw_data):
    #print(engineString)
    #print("Runtime")
    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()

    # dynamic shape configure
    #print("Set input shape")
    context.set_input_shape("inputT0", [nIn * hIn * wIn])
    context.set_binding_shape(0, [nIn * hIn * wIn])
    origin_inputshape = context.get_binding_shape(0)

    #print("Set input shape completed")

    data = np.array(raw_data).reshape(-1)

    _, stream = cudart.cudaStreamCreate()
    #print("Reshaping")

    inputH0 = np.ascontiguousarray(data.reshape(-1))
    outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
    #print("Reshaped")

    # initialize input and output data
    _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
    _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)

    # move input to device
    cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)

    # execute
    #print("execute")
    context.execute_async_v2([int(inputD0), int(outputD0)], stream)

    # move output back to host
    cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)

    # wait for everything
    cudart.cudaStreamSynchronize(stream)

    cudart.cudaStreamDestroy(stream)
    cudart.cudaFree(inputD0)
    cudart.cudaFree(outputD0)

    return outputH0

## Testing TensorRT Output

In [16]:
trt_output = trt_inference(nIn, hIn, wIn, cOut, trt_engineStr, data)
print("RMSNorm_output_trt - without reshape :", trt_output.shape)
trt_output = trt_output.reshape(-1)
print("RMSNorm_output_trt :", trt_output.shape)
print(trt_output)

RMSNorm_output_trt - without reshape : (737280,)
RMSNorm_output_trt : (737280,)
[0.0000000e+00 2.3492469e-06 4.6984937e-06 ... 1.7320457e+00 1.7320480e+00
 1.7320503e+00]


  context.set_binding_shape(0, [nIn * hIn * wIn])
  origin_inputshape = context.get_binding_shape(0)
  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))
