# RoPE with TensorRT

In [1]:
import numpy as np
from cuda import cudart
import torch
from torch import Tensor, nn
import tensorrt as trt
import math

In [2]:
print("PyTorch version: " + torch.__version__)
print("TensorRT version: " + trt.__version__)

PyTorch version: 2.1.0a0+4136153
TensorRT version: 8.6.1


## 0. Generate input and data shape

In [3]:
config = dict()

batch_size, seq_len, hidden_size = 4, 45, 4096
intermediate_size = 11008
num_attention_heads = 32
num_key_value_heads = 32
max_position_embeddings = 2048
rope_theta = 10000.0

config["hidden_size"] = hidden_size
config["intermediate_size"] = intermediate_size
config["num_heads"] = num_attention_heads
config["head_dim"] = config["hidden_size"] // config["num_heads"]
config["num_key_value_heads"] = num_key_value_heads
config["num_key_value_groups"] = config["num_heads"] // config["num_key_value_heads"]
config["max_position_embeddings"] = max_position_embeddings
config["rope_theta"] = rope_theta

In [4]:
data = torch.ones(batch_size, seq_len, hidden_size)
attention_mask = torch.ones(batch_size, 1, seq_len, seq_len)
print("data : " + str(data.shape))
print("attention_mask : " + str(attention_mask.shape))
position_ids = torch.arange(0, seq_len)
print("position_ids : " + str(position_ids.shape))
position_ids = position_ids.repeat(batch_size, 1)
print("position_ids : " + str(position_ids.shape))

data : torch.Size([4, 45, 4096])
attention_mask : torch.Size([4, 1, 45, 45])
position_ids : torch.Size([45])
position_ids : torch.Size([4, 45])


## 1. RoPE with Pytorch

In [5]:
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        
        t = torch.arange(max_position_embeddings, device=device, dtype=self.inv_freq.dtype)
        # Outer Product: outer(A, B)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(torch.get_default_dtype()), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(torch.get_default_dtype()), persistent=False)
        
    def rotate_half(self, x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, q, k, v, position_ids, seq_len=None):
        # v: [bs, num_attention_heads, seq_len, head_size]
        cos = self.cos_cached[:, :, :seq_len, ...].to(dtype=v.dtype)
        sin = self.sin_cached[:, :, :seq_len, ...].to(dtype=v.dtype)
        cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
        sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
        
        # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
        print("Pytorch RoPE - forwarding: " + str(cos))
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

## 2. Test Pytorch

In [6]:
## Prepare Pytorch Testing Parameter
q_proj = nn.Linear(config["hidden_size"], config["num_heads"] * config["hidden_size"] // config["num_heads"], bias=False)
k_proj = nn.Linear(config["hidden_size"], config["num_key_value_heads"] * config["head_dim"], bias=False)
v_proj = nn.Linear(config["hidden_size"], config["num_key_value_heads"] * config["head_dim"], bias=False)

query_states = q_proj(data)
key_states = k_proj(data)
value_states = v_proj(data)

bsz, q_len, _ = data.size()

# reshape
query_states = query_states.view(bsz, q_len, config["num_heads"], config["head_dim"]).transpose(1, 2)
key_states = key_states.view(bsz, q_len, config["num_key_value_heads"], config["head_dim"]).transpose(1, 2)
value_states = value_states.view(bsz, q_len, config["num_key_value_heads"], config["head_dim"]).transpose(1, 2)
print("Input query_states: after reshape " + str(query_states.shape))
print(query_states[0])

Input query_states: after reshape torch.Size([4, 32, 45, 128])
tensor([[[-0.7269, -0.0775, -0.7819,  ..., -0.4893, -0.1463, -0.1300],
         [-0.7269, -0.0775, -0.7819,  ..., -0.4893, -0.1463, -0.1300],
         [-0.7269, -0.0775, -0.7819,  ..., -0.4893, -0.1463, -0.1300],
         ...,
         [-0.7269, -0.0775, -0.7819,  ..., -0.4893, -0.1463, -0.1300],
         [-0.7269, -0.0775, -0.7819,  ..., -0.4893, -0.1463, -0.1300],
         [-0.7269, -0.0775, -0.7819,  ..., -0.4893, -0.1463, -0.1300]],

        [[-0.8873, -1.1440, -0.9657,  ...,  0.4321, -0.2121, -0.6659],
         [-0.8873, -1.1440, -0.9657,  ...,  0.4321, -0.2121, -0.6659],
         [-0.8873, -1.1440, -0.9657,  ...,  0.4321, -0.2121, -0.6659],
         ...,
         [-0.8873, -1.1440, -0.9657,  ...,  0.4321, -0.2121, -0.6659],
         [-0.8873, -1.1440, -0.9657,  ...,  0.4321, -0.2121, -0.6659],
         [-0.8873, -1.1440, -0.9657,  ...,  0.4321, -0.2121, -0.6659]],

        [[ 0.4959, -0.2269, -0.1815,  ...,  0.1882,  

In [7]:
model = LlamaRotaryEmbedding(dim=config["hidden_size"] // config["num_heads"],
                             max_position_embeddings = config["max_position_embeddings"],
                             base = config["rope_theta"])

device = torch.device("cuda")
model = model.to(device)
data_D = data.to(device)
attention_mask_D = attention_mask.to(device)
position_ids_D = position_ids.to(device)

query_states_D = query_states.to(device)
key_states_D = key_states.to(device)
value_states_D = value_states.to(device)

output_query_states, output_key_states = model(query_states_D, key_states_D, value_states_D, position_ids_D, seq_len=q_len)

print("output_query_states: " + str(output_query_states.shape))

Pytorch RoPE - forwarding: tensor([[[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.5403,  0.6479,  0.7318,  ...,  1.0000,  1.0000,  1.0000],
          [-0.4161, -0.1604,  0.0709,  ...,  1.0000,  1.0000,  1.0000],
          ...,
          [-0.4000,  0.2398,  0.9968,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.5551,  0.8949,  0.6752,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.9998,  0.9198, -0.0086,  ...,  1.0000,  1.0000,  1.0000]]],


        [[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.5403,  0.6479,  0.7318,  ...,  1.0000,  1.0000,  1.0000],
          [-0.4161, -0.1604,  0.0709,  ...,  1.0000,  1.0000,  1.0000],
          ...,
          [-0.4000,  0.2398,  0.9968,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.5551,  0.8949,  0.6752,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.9998,  0.9198, -0.0086,  ...,  1.0000,  1.0000,  1.0000]]],


        [[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],

In [8]:
output_query_states[0]

tensor([[[-0.7269, -0.0775, -0.7819,  ..., -0.4893, -0.1463, -0.1300],
         [-0.5390, -0.8500, -0.8107,  ..., -0.4892, -0.1461, -0.1299],
         [ 0.1444, -1.0239, -0.4045,  ..., -0.4892, -0.1460, -0.1298],
         ...,
         [ 0.4500,  1.0007, -0.8073,  ..., -0.4874, -0.1385, -0.1255],
         [-0.2589,  0.3992, -0.7861,  ..., -0.4873, -0.1383, -0.1254],
         [-0.7298, -0.4834, -0.3431,  ..., -0.4873, -0.1381, -0.1253]],

        [[-0.8873, -1.1440, -0.9657,  ...,  0.4321, -0.2121, -0.6659],
         [-0.6318, -0.1051, -0.4665,  ...,  0.4318, -0.2120, -0.6659],
         [ 0.2046,  1.0079,  0.2829,  ...,  0.4316, -0.2119, -0.6659],
         ...,
         [ 0.5209, -1.0852, -0.9346,  ...,  0.4223, -0.2067, -0.6659],
         [-0.3420, -1.3965, -0.3922,  ...,  0.4220, -0.2066, -0.6659],
         [-0.8904, -0.7245,  0.3607,  ...,  0.4218, -0.2065, -0.6659]],

        [[ 0.4959, -0.2269, -0.1815,  ...,  0.1882,  0.2129,  0.2076],
         [ 0.7157, -0.0068, -0.3643,  ...,  0

## 3. RoPE with TensorRT

In [9]:
# seq length is not specified, since it is a dynamic size
def trt_create(batch_size, num_attention_heads, dim, max_position_embeddings, base):
    # Config TensorRT Logger, Builder, Network
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()

    # inputs: Q, K with dynamic seq_len (seq_len_mask: provide dynamic shape)
    query_states = network.add_input('query_states', trt.DataType.FLOAT, (batch_size, num_attention_heads, -1, dim))
    key_states = network.add_input('key_states', trt.DataType.FLOAT, (batch_size, num_attention_heads, -1, dim))
    seq_len_mask = network.add_input('seq_len_mask', trt.DataType.FLOAT, (-1, dim))
    
    # dynamic shape optimization
    profile = builder.create_optimization_profile();
    profile.set_shape("query_states", 
                      (batch_size, num_attention_heads, 1, dim), 
                      (batch_size, num_attention_heads, 45, dim), 
                      (batch_size, num_attention_heads, 1024, dim))
    profile.set_shape("key_states",  
                      (batch_size, num_attention_heads, 1, dim), 
                      (batch_size, num_attention_heads, 45, dim), 
                      (batch_size, num_attention_heads, 1024, dim))
    profile.set_shape("seq_len_mask", (1, dim), (45, dim), (1024, dim)) 
    
    config.add_optimization_profile(profile)
    
    print("- 0) input: Q, K, seq_len_mask shape :")
    print(query_states.shape, key_states.shape, seq_len_mask.shape)

    # 1. Precompute sin & cos cache with max_position_embeddings
    print("- 1) Precompute sin & cos cache:")
    # Build the theta parameter
    # According to the formula theta_i = 10000^(-2(i-1)/dim) for i = [1, 2, ... dim/2]
    # Shape: (Head_Dim / 2) = (64,)
    theta_numerator = np.arange(0, dim, 2).astype('float32')
    # Shape: (Head_Dim / 2) = (64,)
    theta = 1.0 / (base ** (theta_numerator / dim)) # (dim / 2)
    # Construct the positions (the "m" parameter)
    # Shape: (Max_Position_Embeddings) = (2048,) 
    m = np.arange(max_position_embeddings).astype('float32')
    # Multiply each theta by each position using the outer product.
    # Shape: (Max_Position_Embeddings) outer_product* (Head_Dim / 2) -> (Max_Position_Embeddings, Head_Dim / 2) = (2048, 64)
    freqs = np.outer(m, theta)
    # Different from paper, but it uses a different permutation in order to obtain the same calculation
    # emb Shape: (2048, 128)
    emb = np.concatenate((freqs, freqs), axis=-1)
    
    # 2. Convert to cos, sin cache = (2048, 128)
    cos_cached_np = np.cos(emb) #[None, None, :, :]
    sin_cached_np = np.sin(emb) #[None, None, :, :]
    cached_shape = list(cos_cached_np.shape)
    
    cos_cached_layer = network.add_constant(shape=cached_shape, weights=trt.Weights(cos_cached_np))
    sin_cached_layer = network.add_constant(shape=cached_shape, weights=trt.Weights(sin_cached_np))
   
    print("- 2) cos, sin cached layer shape :")
    print(cos_cached_layer.get_output(0).shape)
    
    # 3. Dynamic Slicing: to fetch cos_cache by seq_len :  e.g. [2048, 128] -> [seq_len, 128]
    # See detail https://github.com/NVIDIA/TensorRT/issues/2282
    
    # Fill in start, shape, and stride with some sane defaults.
    # Later, we'll replace these with input tensors to make the slice dynamic.
    cos_cached_slice_layer = network.add_slice(cos_cached_layer.get_output(0), start=(0, 0), shape=(32, dim), stride=(1, 1))
    sin_cached_slice_layer = network.add_slice(sin_cached_layer.get_output(0), start=(0, 0), shape=(32, dim), stride=(1, 1))
    
    # Get the seq_len shape from IShapeLayer: (-1, 128)
    seq_len_mask_shape = network.add_shape(seq_len_mask)
    
    # Now that we know the seq_len_mask shape, let's add an input tensor for `shape`, which is at index 2.
    # Refer to the API documentation for details:
    # https://docs.nvidia.com/deeplearning/tensorrt/api/python_api/infer/Graph/Layers.html?highlight=islicelayer#tensorrt.ISliceLayer.set_input.
    cos_cached_slice_layer.set_input(2, seq_len_mask_shape.get_output(0))
    sin_cached_slice_layer.set_input(2, seq_len_mask_shape.get_output(0))
    
    print("- 3) cos_cached_slice_layer shape :")
    print(cos_cached_slice_layer.get_output(0).shape)
    
    # 4. repeat the cos , sin cache with batch_size (new dimension) = [seq_len, 128] -> [batch_size, seq_len, 128]
    print("- 4-1) cos_cached_slice_shuffle_layer shape :")
    cos_cached_slice_shuffle_layer = network.add_shuffle(cos_cached_slice_layer.get_output(0))
    sin_cached_slice_shuffle_layer = network.add_shuffle(sin_cached_slice_layer.get_output(0))
    
    cos_cached_slice_shuffle_layer.reshape_dims = trt.Dims([1, 1, -1, dim])
    sin_cached_slice_shuffle_layer.reshape_dims = trt.Dims([1, 1, -1, dim])
    
    
    print(cos_cached_slice_shuffle_layer.get_output(0).shape)
    print("- 4-2) cos_cached_slice_repeat_layer shape :")
    
    cos_cached_slice_repeat_layer = network.add_concatenation([cos_cached_slice_shuffle_layer.get_output(0), 
                                                               cos_cached_slice_shuffle_layer.get_output(0), 
                                                               cos_cached_slice_shuffle_layer.get_output(0), 
                                                               cos_cached_slice_shuffle_layer.get_output(0)])
    sin_cached_slice_repeat_layer = network.add_concatenation([sin_cached_slice_shuffle_layer.get_output(0), 
                                                               sin_cached_slice_shuffle_layer.get_output(0), 
                                                               sin_cached_slice_shuffle_layer.get_output(0), 
                                                               sin_cached_slice_shuffle_layer.get_output(0)])
    
    cos_cached_slice_repeat_layer.axis = 0
    sin_cached_slice_repeat_layer.axis = 0
    
    print(cos_cached_slice_repeat_layer.get_output(0).shape)
    
    # 5. rotate_half of Q, K tensor:
    print("- 5) rotate_half of Q, K shape : (slice half / negative x2 / concat(-x2, x1) )")
    # 5-0) get rotate_half shape 
    q_state_shape = network.add_shape(query_states)
    q_state_shape_divisor = network.add_constant(shape=(4,), weights=np.array([1, 1, 1, 2], dtype=np.int32))
    rotate_half_shape = network.add_elementwise(q_state_shape.get_output(0), 
                                                q_state_shape_divisor.get_output(0), 
                                                trt.ElementWiseOperation.DIV)
    # 5-1) slice half tensor
    rotate_half_q1 = network.add_slice(query_states, start=(0, 0, 0, 0), 
                                       shape=(4, 1, 64, 64), 
                                       stride=(1, 1, 1, 1))
    rotate_half_q2 = network.add_slice(query_states, start=(0, 0, 0, 64), 
                                       shape=(4, 1, 64, 64), 
                                       stride=(1, 1, 1, 1))
    rotate_half_k1 = network.add_slice(key_states, start=(0, 0, 0, 0), 
                                       shape=(4, 1, 64, 64), 
                                       stride=(1, 1, 1, 1))
    rotate_half_k2 = network.add_slice(key_states, start=(0, 0, 0, 64), 
                                       shape=(4, 1, 64, 64), 
                                       stride=(1, 1, 1, 1))
    rotate_half_q1.set_input(2, rotate_half_shape.get_output(0))
    rotate_half_q2.set_input(2, rotate_half_shape.get_output(0))
    rotate_half_k1.set_input(2, rotate_half_shape.get_output(0))
    rotate_half_k2.set_input(2, rotate_half_shape.get_output(0))
    print(rotate_half_q2.get_output(0).shape)
    
    # 5-2) negative x2
    rotate_half_q2_negative = network.add_unary(rotate_half_q2.get_output(0), op=trt.UnaryOperation.NEG)
    rotate_half_k2_negative = network.add_unary(rotate_half_k2.get_output(0), op=trt.UnaryOperation.NEG)
    print(rotate_half_q2_negative.get_output(0).shape)
    
    # 5-3) concat (-x2, x1)
    rotate_half_q2_negative_concat_q1 = network.add_concatenation([rotate_half_q2_negative.get_output(0),
                                                                   rotate_half_q1.get_output(0)])
    rotate_half_k2_negative_concat_k1 = network.add_concatenation([rotate_half_k2_negative.get_output(0),
                                                                   rotate_half_k1.get_output(0)])
    rotate_half_q2_negative_concat_q1.axis = 3
    rotate_half_k2_negative_concat_k1.axis = 3
    
    print(rotate_half_q2_negative_concat_q1.get_output(0).shape)
    
    # 6. Output: Matrix Multiply of Q * cos + rotate_half(Q) * sin
    print("- 6) Output: Q_embed = Q * cos + rotate_half(Q) * sin")
    q_cos = network.add_einsum(inputs=[query_states, cos_cached_slice_repeat_layer.get_output(0)], equation="ijkl,ijkl->ijkl")
    q_sin = network.add_einsum(inputs=[rotate_half_q2_negative_concat_q1.get_output(0), 
                                       sin_cached_slice_repeat_layer.get_output(0)], equation="ijkl,ijkl->ijkl")
    q_embed = network.add_elementwise(q_cos.get_output(0), q_sin.get_output(0), op=trt.ElementWiseOperation.SUM)
    
    k_cos = network.add_einsum(inputs=[key_states, cos_cached_slice_repeat_layer.get_output(0)], equation="ijkl,ijkl->ijkl")
    k_sin = network.add_einsum(inputs=[rotate_half_k2_negative_concat_k1.get_output(0), 
                                       sin_cached_slice_repeat_layer.get_output(0)], equation="ijkl,ijkl->ijkl")
    k_embed = network.add_elementwise(k_cos.get_output(0), k_sin.get_output(0), op=trt.ElementWiseOperation.SUM)
    
    print(q_embed.get_output(0).shape)
    print(k_embed.get_output(0).shape)
    
    print("- 7) check seq_len_mask shape :")
    print(seq_len_mask.shape)
    
    network.mark_output(q_embed.get_output(0))
    network.mark_output(k_embed.get_output(0))

    engineString = builder.build_serialized_network(network, config)
    
    return engineString

In [10]:
trt_engineStr = trt_create(batch_size = batch_size, 
                           num_attention_heads = config["num_heads"],
                           dim = config["head_dim"],
                           max_position_embeddings = config["max_position_embeddings"],
                           base = config["rope_theta"])

- 0) input: Q, K, seq_len_mask shape :
(4, 32, -1, 128) (4, 32, -1, 128) (-1, 128)
- 1) Precompute sin & cos cache:
- 2) cos, sin cached layer shape :
(2048, 128)
- 3) cos_cached_slice_layer shape :
(-1, 128)
- 4-1) cos_cached_slice_shuffle_layer shape :
(1, 1, -1, 128)
- 4-2) cos_cached_slice_repeat_layer shape :
(4, 1, -1, 128)
- 5) rotate_half of Q, K shape : (slice half / negative x2 / concat(-x2, x1) )
(4, 32, -1, 64)
(4, 32, -1, 64)
(4, 32, -1, 128)
- 6) Output: Q_embed = Q * cos + rotate_half(Q) * sin
(4, 32, -1, 128)
(4, 32, -1, 128)
- 7) check seq_len_mask shape :
(-1, 128)


In [11]:
def trt_inference(batch_size, num_attention_heads, dim, engineString, q_state, k_state, seq_len_mask): 

    print("Runtime")
    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()

    # dynamic shape configure
    print("Set input shape: query_states")
    #context.active_optimization_profile = 0
    
    q_shape = context.get_binding_shape(0)
    print(q_shape)
    context.set_input_shape("query_states", (batch_size, num_attention_heads, seq_len, dim))
    context.set_binding_shape(0, (batch_size, num_attention_heads, seq_len, dim))

    print("Set input shape: key_states")
    k_shape = context.get_binding_shape(1)
    print(k_shape)
    context.set_input_shape("key_states", (batch_size, num_attention_heads, seq_len, dim))
    context.set_binding_shape(1, (batch_size, num_attention_heads, seq_len, dim))
    
    print("Set input shape: seq_len_mask")
    mask_shape = context.get_binding_shape(2)
    print(mask_shape)
    context.set_input_shape("seq_len_mask", (seq_len, dim))
    context.set_binding_shape(2, (seq_len, dim))
    
    print("Set input shape completed")

    q_state_data = np.array(q_state)
    k_state_data = np.array(k_state)
    seq_len_mask_data = np.array(seq_len_mask)

    _, stream = cudart.cudaStreamCreate()
#     print("Reshaping")

    inputH0 = np.ascontiguousarray(q_state_data.reshape(-1))
    inputH1 = np.ascontiguousarray(k_state_data.reshape(-1))
    inputH2 = np.ascontiguousarray(seq_len_mask_data.reshape(-1))
    outputH0 = np.empty(context.get_binding_shape(3), dtype=trt.nptype(engine.get_binding_dtype(3)))
    outputH1 = np.empty(context.get_binding_shape(4), dtype=trt.nptype(engine.get_binding_dtype(4)))
    
#     print("Reshaped")

    # initialize input and output data
    _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
    _, inputD1 = cudart.cudaMallocAsync(inputH1.nbytes, stream)
    _, inputD2 = cudart.cudaMallocAsync(inputH2.nbytes, stream)
    _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)
    _, outputD1 = cudart.cudaMallocAsync(outputH1.nbytes, stream)


    # move input to device
    cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)
    cudart.cudaMemcpyAsync(inputD1, inputH1.ctypes.data, inputH1.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)
    cudart.cudaMemcpyAsync(inputD2, inputH2.ctypes.data, inputH2.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)

    # execute
#     print("execute")
    context.execute_async_v2([int(inputD0), int(inputD1), int(inputD2), int(outputD0), int(outputD1)], stream)

    # move output back to host
    cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)
    cudart.cudaMemcpyAsync(outputH1.ctypes.data, outputD1, outputH1.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)

    # wait for everythidden_sizeg
    cudart.cudaStreamSynchronize(stream)

    cudart.cudaStreamDestroy(stream)
    cudart.cudaFree(inputD0)
    cudart.cudaFree(inputD1)
    cudart.cudaFree(inputD2)
    cudart.cudaFree(outputD0)
    cudart.cudaFree(outputD1)

    return outputH0, outputH1

In [12]:
q_state = query_states.detach().numpy()
k_state = key_states.detach().numpy()
seq_len_mask = np.ones((seq_len, config["head_dim"]))

In [13]:
q_state.shape, k_state.shape, seq_len_mask.shape

((4, 32, 45, 128), (4, 32, 45, 128), (45, 128))

In [14]:
trt_output = trt_inference(batch_size, config["num_heads"], config["head_dim"],
                           trt_engineStr, 
                           q_state, k_state, seq_len_mask)

trt_query_states, trt_key_states = trt_output

Runtime
Set input shape: query_states
(4, 32, -1, 128)
Set input shape: key_states
(4, 32, -1, 128)
Set input shape: seq_len_mask
(-1, 128)
Set input shape completed


  q_shape = context.get_binding_shape(0)
  context.set_binding_shape(0, (batch_size, num_attention_heads, seq_len, dim))
  k_shape = context.get_binding_shape(1)
  context.set_binding_shape(1, (batch_size, num_attention_heads, seq_len, dim))
  mask_shape = context.get_binding_shape(2)
  context.set_binding_shape(2, (seq_len, dim))
  outputH0 = np.empty(context.get_binding_shape(3), dtype=trt.nptype(engine.get_binding_dtype(3)))
  outputH0 = np.empty(context.get_binding_shape(3), dtype=trt.nptype(engine.get_binding_dtype(3)))
  outputH1 = np.empty(context.get_binding_shape(4), dtype=trt.nptype(engine.get_binding_dtype(4)))
  outputH1 = np.empty(context.get_binding_shape(4), dtype=trt.nptype(engine.get_binding_dtype(4)))


In [15]:
print(trt_query_states.shape, trt_key_states.shape)

(4, 32, 45, 128) (4, 32, 45, 128)


In [16]:
trt_query_states[0]

array([[[-0.7268712 , -0.07753751, -0.7819331 , ..., -0.48927367,
         -0.14633395, -0.12997538],
        [-0.5389836 , -0.8500305 , -0.81065595, ..., -0.48922914,
         -0.14614709, -0.12986952],
        [ 0.14444299, -1.023942  , -0.4044796 , ..., -0.4891846 ,
         -0.14596021, -0.12976368],
        ...,
        [ 0.4500357 ,  1.0007448 , -0.8072869 , ..., -0.4873935 ,
         -0.1384836 , -0.125528  ],
        [-0.25892752,  0.3992397 , -0.78606105, ..., -0.48734847,
         -0.13829663, -0.12542208],
        [-0.72983396, -0.48340797, -0.34312868, ..., -0.48730347,
         -0.13810965, -0.12531614]],

       [[-0.88734734, -1.1440415 , -0.9657183 , ...,  0.43205604,
         -0.21212707, -0.6659065 ],
        [-0.63178384, -0.10507695, -0.466529  , ...,  0.4318235 ,
         -0.21199818, -0.66590583],
        [ 0.20463876,  1.0078816 ,  0.28294283, ...,  0.4315909 ,
         -0.21186927, -0.6659051 ],
        ...,
        [ 0.5208617 , -1.0851523 , -0.93463045, ...,  

## Is the result valid?

In [17]:
print("output query state of Pytorch:" + str(output_query_states.shape) )
print("output query state of TensorRT:" + str(trt_query_states.shape))

output query state of Pytorch:torch.Size([4, 32, 45, 128])
output query state of TensorRT:(4, 32, 45, 128)


In [18]:
np.allclose(output_query_states.clone().detach().cpu().numpy(), trt_query_states, atol=1e-06)

True