# repeat_KV with TensorRT

In [1]:
import numpy as np
from cuda import cudart
import torch
from torch import Tensor, nn
import tensorrt as trt
import math

In [2]:
print("PyTorch version: " + torch.__version__)
print("TensorRT version: " + trt.__version__)

PyTorch version: 2.1.0a0+4136153
TensorRT version: 8.6.1


## 0. Generate input and data shape

In [3]:
config = dict()

batch_size, seq_len, hidden_size = 4, 45, 4096
intermediate_size = 11008
num_attention_heads = 32
num_key_value_heads = 32
max_position_embeddings = 2048
rope_theta = 10000.0

config["hidden_size"] = hidden_size
config["num_heads"] = num_attention_heads
config["head_dim"] = config["hidden_size"] // config["num_heads"]
config["num_key_value_heads"] = num_key_value_heads
config["num_key_value_groups"] = config["num_heads"] // config["num_key_value_heads"]


In [4]:
data = torch.ones(batch_size, seq_len, hidden_size)
print(data.shape)

## Prepare Pytorch Testing Parameter
q_proj = nn.Linear(config["hidden_size"], config["num_heads"] * config["hidden_size"] // config["num_heads"], bias=False)
k_proj = nn.Linear(config["hidden_size"], config["num_key_value_heads"] * config["head_dim"], bias=False)
v_proj = nn.Linear(config["hidden_size"], config["num_key_value_heads"] * config["head_dim"], bias=False)

query_states = q_proj(data)
key_states = k_proj(data)
value_states = v_proj(data)
print("Input query_states: before reshape " +str(query_states.shape))
print(query_states[0])
bsz, q_len, _ = data.size()

# reshape
query_states = query_states.view(bsz, q_len, config["num_heads"], config["head_dim"]).transpose(1, 2)
key_states = key_states.view(bsz, q_len, config["num_key_value_heads"], config["head_dim"]).transpose(1, 2)
value_states = value_states.view(bsz, q_len, config["num_key_value_heads"], config["head_dim"]).transpose(1, 2)
print("Input query_states: after reshape " + str(query_states.shape))
print(query_states[0])


torch.Size([4, 45, 4096])
Input query_states: before reshape torch.Size([4, 45, 4096])
tensor([[-0.2101, -0.5661,  0.1543,  ..., -0.9110,  0.0716,  0.8272],
        [-0.2101, -0.5661,  0.1543,  ..., -0.9110,  0.0716,  0.8272],
        [-0.2101, -0.5661,  0.1543,  ..., -0.9110,  0.0716,  0.8272],
        ...,
        [-0.2101, -0.5661,  0.1543,  ..., -0.9110,  0.0716,  0.8272],
        [-0.2101, -0.5661,  0.1543,  ..., -0.9110,  0.0716,  0.8272],
        [-0.2101, -0.5661,  0.1543,  ..., -0.9110,  0.0716,  0.8272]],
       grad_fn=<SelectBackward0>)
Input query_states: after reshape torch.Size([4, 32, 45, 128])
tensor([[[-0.2101, -0.5661,  0.1543,  ...,  0.5787, -0.5385, -0.3816],
         [-0.2101, -0.5661,  0.1543,  ...,  0.5787, -0.5385, -0.3816],
         [-0.2101, -0.5661,  0.1543,  ...,  0.5787, -0.5385, -0.3816],
         ...,
         [-0.2101, -0.5661,  0.1543,  ...,  0.5787, -0.5385, -0.3816],
         [-0.2101, -0.5661,  0.1543,  ...,  0.5787, -0.5385, -0.3816],
         [-0.

## 1. Repeat_kv with Pytorch

In [5]:
# #####################################################
# # in hugging face, they do have kv cache, however, they don't have other attention optimization
# # this could be done directly in tensorRT by using dynamic shape
# kv_seq_len = key_states.shape[-2]
# if past_key_value is not None:
#     kv_seq_len += past_key_value[0].shape[-2]

# query_states, key_states = self.rotary_emb(query_states, key_states, value_states, position_ids, seq_len=q_len)

# if past_key_value is not None:
#     # reuse k, v, self_attention
#     key_states = torch.cat([past_key_value[0], key_states], dim=2)
#     value_states = torch.cat([past_key_value[1], value_states], dim=2)

# past_key_value = (key_states, value_states) if use_cache else None

# print(self.num_key_value_groups)
# # repeat k/v heads if n_kv_heads < n_heads
# key_states = repeat_kv(key_states, self.num_key_value_groups)
# value_states = repeat_kv(value_states, self.num_key_value_groups)
# #####################################################

In [6]:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    
    repeat at the second dimension
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    print("Input shape: ")
    print(hidden_states.shape)
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    print("Input shape after reshape: ")
    print(hidden_states.shape)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

In [7]:
n_rep = 3

In [8]:
key_states_repeat = repeat_kv(key_states, n_rep)

Input shape: 
torch.Size([4, 32, 45, 128])
Input shape after reshape: 
torch.Size([4, 32, 3, 45, 128])


In [9]:
key_states_repeat.shape

torch.Size([4, 96, 45, 128])

In [10]:
x = torch.tensor([1, 2, 3])
print(x.shape)

print(x.expand(2,3))
print(x.expand(2,3).shape)

print(x.repeat(2,3))
print(x.repeat(2,3).shape)

torch.Size([3])
tensor([[1, 2, 3],
        [1, 2, 3]])
torch.Size([2, 3])
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
        [1, 2, 3, 1, 2, 3, 1, 2, 3]])
torch.Size([2, 9])


In [11]:
key_states.shape

torch.Size([4, 32, 45, 128])

In [12]:
key_states.repeat(1, 2, 1, 1).shape

torch.Size([4, 64, 45, 128])

## 2. Repeat_kv with TensorRT

In [13]:
def trt_create(batch_size, num_attention_heads, dim):
    # Config TensorRT Logger, Builder, Network
    logger = trt.Logger(trt.Logger.ERROR)
    builder = trt.Builder(logger)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()

    # inputs: hidden_state, n_rep with dynamic shape
    hidden_states = network.add_input('hidden_states', trt.DataType.FLOAT, (batch_size, num_attention_heads, -1, dim))
    repeat_states = network.add_input('repeat_states', trt.DataType.FLOAT, (batch_size, -1, -1, dim))
    
    # dynamic shape optimization
    profile = builder.create_optimization_profile();
    profile.set_shape("hidden_states", 
                      (batch_size, num_attention_heads, 1, dim), 
                      (batch_size, num_attention_heads, 45, dim), 
                      (batch_size, num_attention_heads, 1024, dim))
    profile.set_shape("repeat_states", 
                      (batch_size, num_attention_heads, 1, dim), 
                      (batch_size, num_attention_heads*2, 45, dim), 
                      (batch_size, num_attention_heads*10, 1024, dim)) 
    
    config.add_optimization_profile(profile)
    
    print("- 0) input: hidden_states, repeat_states shape :")
    print(hidden_states.shape, repeat_states.shape)

    # 1. Get repeat_hidden_states shape with repeat_states
    # Check detail how to do repeat like repeat in pytorch? #2408
    # https://github.com/NVIDIA/TensorRT/issues/2408
    print("- 1) Get repeat_hidden_states shape:")
    repeat_hidden_states_shape = network.add_shape(repeat_states)
    print(repeat_hidden_states_shape.get_output(0).shape)
    
    print("- 2) repeat_hidden_states :")
    repeat_hidden_states = network.add_slice(hidden_states, start=(0, 0, 0, 0), shape=(1, 1, 1, 1), stride=(1, 1, 1, 1))
    repeat_hidden_states.set_input(2, repeat_hidden_states_shape.get_output(0))
    repeat_hidden_states.mode = trt.SliceMode.WRAP
    
    print("- 3) check repeat_hidden_states shape :")
    print(repeat_hidden_states.get_output(0).shape)
    
    network.mark_output(repeat_hidden_states.get_output(0))

    engineString = builder.build_serialized_network(network, config)
    
    return engineString

In [14]:
trt_engineStr = trt_create(batch_size = batch_size, 
                           num_attention_heads = config["num_heads"],
                           dim = config["head_dim"])

- 0) input: hidden_states, repeat_states shape :
(4, 32, -1, 128) (4, -1, -1, 128)
- 1) Get repeat_hidden_states shape:
(4,)
- 2) repeat_hidden_states :
- 3) check repeat_hidden_states shape :
(4, -1, -1, 128)


In [15]:
def trt_inference(batch_size, num_attention_heads, dim, engineString, h_state, r_state): 

    print("Runtime")
    logger = trt.Logger(trt.Logger.ERROR)
    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)
    context = engine.create_execution_context()

    # dynamic shape configure
    print("Set input shape: hidden_states")
    #context.active_optimization_profile = 0
    
    h_shape = context.get_binding_shape(0)
    print(h_shape)
    context.set_input_shape("hidden_states", (batch_size, num_attention_heads, seq_len, dim))
    context.set_binding_shape(0, (batch_size, num_attention_heads, seq_len, dim))

    print("Set input shape: repeat_states")
    r_shape = context.get_binding_shape(1)
    print(r_shape)
    context.set_input_shape("repeat_states", (batch_size, num_attention_heads, seq_len, dim))
    context.set_binding_shape(1, (batch_size, num_attention_heads*n_rep, seq_len, dim))
 
    print("Set input shape completed")

    h_state_data = np.array(h_state)
    r_state_data = np.array(r_state)
    
    _, stream = cudart.cudaStreamCreate()

    inputH0 = np.ascontiguousarray(h_state_data.reshape(-1))
    inputH1 = np.ascontiguousarray(r_state_data.reshape(-1))
    outputH0 = np.empty(context.get_binding_shape(2), dtype=trt.nptype(engine.get_binding_dtype(2)))

    # initialize input and output data
    _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)
    _, inputD1 = cudart.cudaMallocAsync(inputH1.nbytes, stream)
    _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)
   
    # move input to device
    cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)
    cudart.cudaMemcpyAsync(inputD1, inputH1.ctypes.data, inputH1.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)
    
    # execute
#     print("execute")
    context.execute_async_v2([int(inputD0), int(inputD1), int(outputD0)], stream)

    # move output back to host
    cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)
    
    # wait for everythidden_sizeg
    cudart.cudaStreamSynchronize(stream)

    cudart.cudaStreamDestroy(stream)
    cudart.cudaFree(inputD0)
    cudart.cudaFree(inputD1)
    cudart.cudaFree(outputD0)

    return outputH0

In [16]:
repeat_state = key_states.repeat(1, n_rep, 1, 1)

In [17]:
h_state = key_states.detach().numpy()
r_state = repeat_state.detach().numpy()

In [18]:
trt_output = trt_inference(batch_size, config["num_heads"], config["head_dim"],
                           trt_engineStr, 
                           h_state, r_state)

trt_repeat_states = trt_output

Runtime
Set input shape: hidden_states
(4, 32, -1, 128)
Set input shape: repeat_states
(4, -1, -1, 128)
Set input shape completed


  h_shape = context.get_binding_shape(0)
  context.set_binding_shape(0, (batch_size, num_attention_heads, seq_len, dim))
  r_shape = context.get_binding_shape(1)
  context.set_binding_shape(1, (batch_size, num_attention_heads*n_rep, seq_len, dim))
  outputH0 = np.empty(context.get_binding_shape(2), dtype=trt.nptype(engine.get_binding_dtype(2)))
  outputH0 = np.empty(context.get_binding_shape(2), dtype=trt.nptype(engine.get_binding_dtype(2)))


In [19]:
key_states_repeat.shape, trt_repeat_states.shape

(torch.Size([4, 96, 45, 128]), (4, 96, 45, 128))

In [20]:
key_states_repeat[0]

tensor([[[-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087],
         [-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087],
         [-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087],
         ...,
         [-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087],
         [-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087],
         [-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087]],

        [[-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087],
         [-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087],
         [-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087],
         ...,
         [-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087],
         [-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087],
         [-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087]],

        [[-0.2323,  0.7661,  0.1553,  ..., -0.0610,  1.0893,  0.1087],
         [-0.2323,  0.7661,  0.1553,  ..., -0

In [21]:
trt_repeat_states[0]

array([[[-0.23231319,  0.7661275 ,  0.15527761, ..., -0.06097494,
          1.0892941 ,  0.1087038 ],
        [-0.23231319,  0.7661275 ,  0.15527761, ..., -0.06097494,
          1.0892941 ,  0.1087038 ],
        [-0.23231319,  0.7661275 ,  0.15527761, ..., -0.06097494,
          1.0892941 ,  0.1087038 ],
        ...,
        [-0.23231313,  0.7661277 ,  0.15527785, ..., -0.06097482,
          1.0892944 ,  0.10870392],
        [-0.23231313,  0.7661277 ,  0.15527785, ..., -0.06097482,
          1.0892944 ,  0.10870392],
        [-0.23231313,  0.7661277 ,  0.15527785, ..., -0.06097482,
          1.0892944 ,  0.10870392]],

       [[ 1.2015883 ,  0.4420061 , -0.05878636, ...,  0.0535775 ,
         -0.6432709 , -0.21525244],
        [ 1.2015883 ,  0.4420061 , -0.05878636, ...,  0.0535775 ,
         -0.6432709 , -0.21525244],
        [ 1.2015883 ,  0.4420061 , -0.05878636, ...,  0.0535775 ,
         -0.6432709 , -0.21525244],
        ...,
        [ 1.2015884 ,  0.4420063 , -0.05878617, ...,  