# repeat_KV with TensorRT

In [1]:
import numpy as np
from cuda import cudart
import torch
from torch import Tensor, nn
import tensorrt as trt
import math

In [2]:
print("PyTorch version: " + torch.__version__)
print("TensorRT version: " + trt.__version__)

PyTorch version: 2.1.0a0+4136153
TensorRT version: 8.6.1


## 0. Generate input and data shape

In [3]:
config = dict()

batch_size, seq_len, hidden_size = 4, 45, 4096
intermediate_size = 11008
num_attention_heads = 32
num_key_value_heads = 32
max_position_embeddings = 2048
rope_theta = 10000.0

config["hidden_size"] = hidden_size
config["num_heads"] = num_attention_heads
config["head_dim"] = config["hidden_size"] // config["num_heads"]
config["num_key_value_heads"] = num_key_value_heads
config["num_key_value_groups"] = config["num_heads"] // config["num_key_value_heads"]


In [4]:
data = torch.ones(batch_size, seq_len, hidden_size)
print(data.shape)

## Prepare Pytorch Testing Parameter
q_proj = nn.Linear(config["hidden_size"], config["num_heads"] * config["hidden_size"] // config["num_heads"], bias=False)
k_proj = nn.Linear(config["hidden_size"], config["num_key_value_heads"] * config["head_dim"], bias=False)
v_proj = nn.Linear(config["hidden_size"], config["num_key_value_heads"] * config["head_dim"], bias=False)

query_states = q_proj(data)
key_states = k_proj(data)
value_states = v_proj(data)
print("Input query_states: before reshape " +str(query_states.shape))
print(query_states[0])
bsz, q_len, _ = data.size()

# reshape
query_states = query_states.view(bsz, q_len, config["num_heads"], config["head_dim"]).transpose(1, 2)
key_states = key_states.view(bsz, q_len, config["num_key_value_heads"], config["head_dim"]).transpose(1, 2)
value_states = value_states.view(bsz, q_len, config["num_key_value_heads"], config["head_dim"]).transpose(1, 2)
print("Input query_states: after reshape " + str(query_states.shape))
print(query_states[0])


torch.Size([4, 45, 4096])
Input query_states: before reshape torch.Size([4, 45, 4096])
tensor([[ 0.6771,  0.0336,  0.7661,  ..., -0.1484,  0.5155, -0.7554],
        [ 0.6771,  0.0336,  0.7661,  ..., -0.1484,  0.5155, -0.7554],
        [ 0.6771,  0.0336,  0.7661,  ..., -0.1484,  0.5155, -0.7554],
        ...,
        [ 0.6771,  0.0336,  0.7661,  ..., -0.1484,  0.5155, -0.7554],
        [ 0.6771,  0.0336,  0.7661,  ..., -0.1484,  0.5155, -0.7554],
        [ 0.6771,  0.0336,  0.7661,  ..., -0.1484,  0.5155, -0.7554]],
       grad_fn=<SelectBackward0>)
Input query_states: after reshape torch.Size([4, 32, 45, 128])
tensor([[[ 0.6771,  0.0336,  0.7661,  ...,  0.7346, -0.6431, -0.2101],
         [ 0.6771,  0.0336,  0.7661,  ...,  0.7346, -0.6431, -0.2101],
         [ 0.6771,  0.0336,  0.7661,  ...,  0.7346, -0.6431, -0.2101],
         ...,
         [ 0.6771,  0.0336,  0.7661,  ...,  0.7346, -0.6431, -0.2101],
         [ 0.6771,  0.0336,  0.7661,  ...,  0.7346, -0.6431, -0.2101],
         [ 0.

## 1. Repeat_kv with Pytorch

In [5]:
# #####################################################
# # in hugging face, they do have kv cache, however, they don't have other attention optimization
# # this could be done directly in tensorRT by using dynamic shape
# kv_seq_len = key_states.shape[-2]
# if past_key_value is not None:
#     kv_seq_len += past_key_value[0].shape[-2]

# query_states, key_states = self.rotary_emb(query_states, key_states, value_states, position_ids, seq_len=q_len)

# if past_key_value is not None:
#     # reuse k, v, self_attention
#     key_states = torch.cat([past_key_value[0], key_states], dim=2)
#     value_states = torch.cat([past_key_value[1], value_states], dim=2)

# past_key_value = (key_states, value_states) if use_cache else None

# print(self.num_key_value_groups)
# # repeat k/v heads if n_kv_heads < n_heads
# key_states = repeat_kv(key_states, self.num_key_value_groups)
# value_states = repeat_kv(value_states, self.num_key_value_groups)
# #####################################################

In [6]:
def concat_kv(hidden_states: torch.Tensor, hidden_states_2: torch.Tensor,) -> torch.Tensor:
    print("Input shape : ")
    print(hidden_states.shape)
    new_states = torch.cat([hidden_states, hidden_states_2], dim=2)
    print("Output shape : ")
    print(new_states.shape)
    return new_states

In [7]:
concat_key_states = concat_kv(key_states, key_states)

Input shape : 
torch.Size([4, 32, 45, 128])
Output shape : 
torch.Size([4, 32, 90, 128])


In [8]:
concat_key_states.shape

torch.Size([4, 32, 90, 128])

## 2. Concat_kv with TensorRT

In [9]:
def trt_create(batch_size, num_attention_heads, dim):
    # Config TensorRT Logger, Builder, Network
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()

    # inputs: hidden_state, n_rep with dynamic shape
    hidden_states = network.add_input('hidden_states', trt.DataType.FLOAT, (batch_size, num_attention_heads, -1, dim))
    hidden_states_2 = network.add_input('hidden_states_2', trt.DataType.FLOAT, (batch_size, num_attention_heads, -1, dim))
    
    # dynamic shape optimization
    profile = builder.create_optimization_profile();
    profile.set_shape("hidden_states", 
                      (batch_size, num_attention_heads, 1, dim), 
                      (batch_size, num_attention_heads, 45, dim), 
                      (batch_size, num_attention_heads, 1024, dim))
    profile.set_shape("hidden_states_2", 
                      (batch_size, num_attention_heads, 1, dim), 
                      (batch_size, num_attention_heads, 45, dim), 
                      (batch_size, num_attention_heads, 1024, dim)) 
    
    config.add_optimization_profile(profile)
    
    print("- 0) input: hidden_states, repeat_states shape :")
    print(hidden_states.shape, hidden_states_2.shape)

    print("- 1) Get concat_states shape:")
    concat_states = network.add_concatenation([hidden_states, hidden_states_2])
    concat_states.axis = 2
    
    print(concat_states.get_output(0).shape)
  
    network.mark_output(concat_states.get_output(0))

    engineString = builder.build_serialized_network(network, config)
    
    return engineString

In [10]:
trt_engineStr = trt_create(batch_size = batch_size, 
                           num_attention_heads = config["num_heads"],
                           dim = config["head_dim"])

- 0) input: hidden_states, repeat_states shape :
(4, 32, -1, 128) (4, 32, -1, 128)
- 1) Get concat_states shape:
(4, 32, -1, 128)


In [11]:
def trt_inference(batch_size, num_attention_heads, dim, engineString, h_state, h_state_2): 

    print("Runtime")
    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()

    # dynamic shape configure
    print("Set input shape: hidden_states")
    #context.active_optimization_profile = 0
    
    h_shape = context.get_binding_shape(0)
    print(h_shape)
    context.set_input_shape("hidden_states", (batch_size, num_attention_heads, seq_len, dim))
    context.set_binding_shape(0, (batch_size, num_attention_heads, seq_len, dim))

    print("Set input shape: repeat_states")
    h_shape_2 = context.get_binding_shape(1)
    print(h_shape_2)
    context.set_input_shape("hidden_states_2", (batch_size, num_attention_heads, seq_len, dim))
    context.set_binding_shape(1, (batch_size, num_attention_heads, seq_len, dim))
 
    print("Set input shape completed")

    h_state_data = np.array(h_state)
    h_state_2_data = np.array(h_state_2)
    
    _, stream = cudart.cudaStreamCreate()

    inputH0 = np.ascontiguousarray(h_state_data.reshape(-1))
    inputH1 = np.ascontiguousarray(h_state_2_data.reshape(-1))
    outputH0 = np.empty(context.get_binding_shape(2), dtype=trt.nptype(engine.get_binding_dtype(2)))

    # initialize input and output data
    _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
    _, inputD1 = cudart.cudaMallocAsync(inputH1.nbytes, stream)
    _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)
   
    # move input to device
    cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)
    cudart.cudaMemcpyAsync(inputD1, inputH1.ctypes.data, inputH1.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)
    
    # execute
#     print("execute")
    context.execute_async_v2([int(inputD0), int(inputD1), int(outputD0)], stream)

    # move output back to host
    cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)
    
    # wait for everythidden_sizeg
    cudart.cudaStreamSynchronize(stream)

    cudart.cudaStreamDestroy(stream)
    cudart.cudaFree(inputD0)
    cudart.cudaFree(inputD1)
    cudart.cudaFree(outputD0)

    return outputH0

In [12]:
h_state = key_states.detach().numpy()

In [13]:
trt_output = trt_inference(batch_size, config["num_heads"], config["head_dim"],
                           trt_engineStr, 
                           h_state, h_state)

trt_concat_states = trt_output

Runtime
Set input shape: hidden_states
(4, 32, -1, 128)
Set input shape: repeat_states
(4, 32, -1, 128)
Set input shape completed


  h_shape = context.get_binding_shape(0)
  context.set_binding_shape(0, (batch_size, num_attention_heads, seq_len, dim))
  h_shape_2 = context.get_binding_shape(1)
  context.set_binding_shape(1, (batch_size, num_attention_heads, seq_len, dim))
  outputH0 = np.empty(context.get_binding_shape(2), dtype=trt.nptype(engine.get_binding_dtype(2)))
  outputH0 = np.empty(context.get_binding_shape(2), dtype=trt.nptype(engine.get_binding_dtype(2)))


In [14]:
concat_key_states.shape, trt_concat_states.shape

(torch.Size([4, 32, 90, 128]), (4, 32, 90, 128))

## Is the result valid?

In [15]:
np.allclose(concat_key_states.clone().detach().cpu().numpy(), trt_concat_states, atol=1e-06)

True

In [16]:
concat_key_states[0]

tensor([[[-0.1029,  0.4360,  1.7170,  ...,  1.0561,  0.6612,  0.4249],
         [-0.1029,  0.4360,  1.7170,  ...,  1.0561,  0.6612,  0.4249],
         [-0.1029,  0.4360,  1.7170,  ...,  1.0561,  0.6612,  0.4249],
         ...,
         [-0.1029,  0.4360,  1.7170,  ...,  1.0561,  0.6612,  0.4249],
         [-0.1029,  0.4360,  1.7170,  ...,  1.0561,  0.6612,  0.4249],
         [-0.1029,  0.4360,  1.7170,  ...,  1.0561,  0.6612,  0.4249]],

        [[-1.0432, -0.3430,  0.2046,  ...,  0.0911,  0.0408,  0.0475],
         [-1.0432, -0.3430,  0.2046,  ...,  0.0911,  0.0408,  0.0475],
         [-1.0432, -0.3430,  0.2046,  ...,  0.0911,  0.0408,  0.0475],
         ...,
         [-1.0432, -0.3430,  0.2046,  ...,  0.0911,  0.0408,  0.0475],
         [-1.0432, -0.3430,  0.2046,  ...,  0.0911,  0.0408,  0.0475],
         [-1.0432, -0.3430,  0.2046,  ...,  0.0911,  0.0408,  0.0475]],

        [[ 1.2805,  0.3576,  0.9282,  ...,  0.5669,  0.1782, -0.4079],
         [ 1.2805,  0.3576,  0.9282,  ...,  0

In [17]:
trt_concat_states[0]

array([[[-0.10292056,  0.4360279 ,  1.7169639 , ...,  1.0561283 ,
          0.66121805,  0.424856  ],
        [-0.10292056,  0.4360279 ,  1.7169639 , ...,  1.0561283 ,
          0.66121805,  0.424856  ],
        [-0.10292056,  0.4360279 ,  1.7169639 , ...,  1.0561283 ,
          0.66121805,  0.424856  ],
        ...,
        [-0.10292089,  0.436028  ,  1.7169638 , ...,  1.0561281 ,
          0.661218  ,  0.42485595],
        [-0.10292089,  0.436028  ,  1.7169638 , ...,  1.0561281 ,
          0.661218  ,  0.42485595],
        [-0.10292089,  0.436028  ,  1.7169638 , ...,  1.0561281 ,
          0.661218  ,  0.42485595]],

       [[-1.0431769 , -0.34295708,  0.20461127, ...,  0.09106764,
          0.04082745,  0.04753289],
        [-1.0431769 , -0.34295708,  0.20461127, ...,  0.09106764,
          0.04082745,  0.04753289],
        [-1.0431769 , -0.34295708,  0.20461127, ...,  0.09106764,
          0.04082745,  0.04753289],
        ...,
        [-1.0431767 , -0.3429572 ,  0.20461132, ...,  