### P02 - Batching

This notebook will cover:
- What is batching?
- Throughput vs latency

#### Import required packages and load the LLM

In [15]:
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
# Load the tokenizer and the causal-LM weights from the local checkpoint
# directory, then place the model on the GPU.
model_name = "./models/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

In [3]:
# Total model size: add up the element count of every parameter tensor
num_params = sum(weights.numel() for weights in model.parameters())
print(f"The number of parameters in the model: {num_params:,}")

The number of parameters in the model: 124,439,808


#### KV-cache setup

In [4]:
# Tokenize a single prompt into PyTorch tensors, then move them to the GPU
prompt = "The quick brown fox jumped over the"
inputs = tokenizer(prompt, return_tensors="pt")
inputs = inputs.to("cuda")

def generate_token_with_past(inputs):
    """Run one forward pass and greedily pick the next token for a single sequence.

    Args:
        inputs: mapping of keyword arguments that is unpacked into ``model(...)``.

    Returns:
        A tuple ``(next_token_id, past_key_values)``: the argmax over the
        logits at the last position of the first sequence, and the
        ``past_key_values`` reported by the model for reuse on the next step.
    """
    # Inference only: no autograd bookkeeping is needed
    with torch.no_grad():
        model_out = model(**inputs)

    # Row 0 of the batch, final sequence position
    final_step_logits = model_out.logits[0, -1]
    return torch.argmax(final_step_logits), model_out.past_key_values


def generate(inputs, max_tokens):
    """Greedily decode ``max_tokens`` tokens for a single prompt, reusing the KV cache.

    Args:
        inputs: mapping handed to ``generate_token_with_past`` on the first
            step; it must contain an ``"attention_mask"`` tensor with one row.
        max_tokens: number of decoding steps to run.

    Returns:
        The decoded tokens concatenated into a single string (empty when
        ``max_tokens`` is 0).
    """
    generated_tokens = []
    next_inputs = inputs
    for _ in range(max_tokens):
        next_token_id, past_key_values = generate_token_with_past(next_inputs)

        # After the first step only the newest token is fed in; earlier
        # positions are carried by past_key_values.
        attention_mask = next_inputs["attention_mask"]
        next_inputs = {
            "input_ids": next_token_id.reshape((1, 1)),
            # Grow the mask by one column. new_ones inherits dtype and device
            # from the existing mask instead of hard-coding "cuda".
            "attention_mask": torch.cat(
                [attention_mask, attention_mask.new_ones((1, 1))],
                dim=1,
            ),
            "past_key_values": past_key_values,
        }
        generated_tokens.append(tokenizer.decode(next_token_id))
    return "".join(generated_tokens)


# Greedy-decode 10 tokens after the prompt and show the continuation
tokens = generate(inputs, max_tokens=10)
print(prompt, "-", tokens)

The quick brown fox jumped over the -  fence and ran to the other side of the fence


#### Add padding tokens to the model

In [5]:
# Reuse the end-of-sequence token (id 50256) as the padding token,
# both in the model config and in the tokenizer
model.config.pad_token_id = model.config.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

In [6]:
# Pad (and truncate) on the left: every prompt then ends in the last column,
# so newly generated tokens can simply be appended on the right
tokenizer.truncation_side = "left"
tokenizer.padding_side = "left"

Add padding to a batch of prompts such that they have the same length.

In [7]:
# Three prompts of different token lengths
prompts = [
    "The quick brown fox jumped over the",
    "The rain in Spain falls",
    "What comes up must",
]

# padding=True brings every row to a common length; the attention mask marks
# padding slots with 0 and real tokens with 1
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")
print("input_ids:", inputs["input_ids"])
print("attention_mask:", inputs["attention_mask"])

input_ids: tensor([[  464,  2068,  7586, 21831, 11687,   625,   262],
        [50256, 50256,   464,  6290,   287,  8602,  8953],
        [50256, 50256, 50256,  2061,  2058,   510,  1276]], device='cuda:0')
attention_mask: tensor([[1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1]], device='cuda:0')


Add position ids to track the original order of tokens in each prompt: padding tokens are assigned position `1`, and the first real token of each prompt starts at position `0`.

In [8]:
# position_ids tell the transformer the ordinal position of each token in the input sequence
# for single input inference, this is just [0 .. n-1] for n tokens
# for batch inference with left padding, the real tokens of every prompt must still
# count up from 0, so positions are derived from the running sum of the attention mask
attention_mask = inputs["attention_mask"]
position_ids = attention_mask.long().cumsum(-1) - 1
# padding slots (mask == 0) come out of the cumsum as -1; overwrite them with the
# placeholder position 1. masked_fill_ works in place, so position_ids stays on the
# same device as attention_mask and no extra device transfer is needed
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)

tensor([[0, 1, 2, 3, 4, 5, 6],
        [1, 1, 0, 1, 2, 3, 4],
        [1, 1, 1, 0, 1, 2, 3]], device='cuda:0')


In [9]:
# Forward pass over the padded batch to obtain the logits.
# Identical to the single-prompt case except that position_ids are supplied explicitly.
with torch.no_grad():
    outputs = model(**inputs, position_ids=position_ids)

logits = outputs.logits
print(logits.shape)

torch.Size([3, 7, 50257])


In [10]:
# retrieve the most likely next token for each prompt:
# take the logits at the final position of every row, then argmax over the vocabulary
last_logits = logits[:, -1]
next_token_ids = torch.argmax(last_logits, dim=1)
print(next_token_ids)

# turn the token ids back into text, one string per prompt
next_tokens = tokenizer.batch_decode(next_token_ids)
print(next_tokens)

tensor([13990,   319,   307], device='cuda:0')
[' fence', ' on', ' be']


#### Put it all together

In [11]:
def generate_batch_tokens_with_past(inputs):
    """Run one forward pass and greedily pick the next token for every row of the batch.

    Args:
        inputs: mapping of keyword arguments that is unpacked into ``model(...)``.

    Returns:
        A tuple ``(next_token_ids, past_key_values)``: the per-row argmax over
        the logits at the last sequence position, and the ``past_key_values``
        reported by the model for reuse on the next step.
    """
    # Inference only: no autograd bookkeeping is needed
    with torch.no_grad():
        model_out = model(**inputs)

    # Keep every row, final sequence position only
    final_step_logits = model_out.logits[:, -1]
    return torch.argmax(final_step_logits, dim=1), model_out.past_key_values

In [12]:
def generate_batch(inputs, max_tokens):
    """Greedily decode ``max_tokens`` tokens for every prompt in a left-padded batch.

    Args:
        inputs: mapping with ``"input_ids"`` and ``"attention_mask"`` tensors
            whose first dimension is the batch; it is not modified.
        max_tokens: number of decoding steps to run.

    Returns:
        A list with one string per batch row, each the concatenation of the
        tokens decoded for that row.
    """
    # one list of decoded tokens per input row
    generated_tokens = [
        [] for _ in range(inputs["input_ids"].shape[0])
    ]

    # Positions count real tokens from 0; padding slots (mask == 0) get the
    # placeholder position 1. masked_fill_ is in place, so the result already
    # lives on the same device as the mask.
    attention_mask = inputs["attention_mask"]
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    next_inputs = {
        "position_ids": position_ids,
        **inputs
    }

    for _ in range(max_tokens):
        next_token_ids, past_key_values = generate_batch_tokens_with_past(next_inputs)

        previous_mask = next_inputs["attention_mask"]
        next_inputs = {
            # reshape from (batch_size,) to (batch_size, 1)
            "input_ids": next_token_ids.reshape((-1, 1)),
            # the new token sits one position after the last one of each row
            "position_ids": next_inputs["position_ids"][:, -1].unsqueeze(-1) + 1,
            # Append a column of ones for the new token. new_ones inherits the
            # mask's dtype and device, so the mask is not promoted to float
            # by torch.cat and no device is hard-coded.
            "attention_mask": torch.cat([
                previous_mask,
                previous_mask.new_ones((next_token_ids.shape[0], 1)),
            ], dim=1),
            "past_key_values": past_key_values,
        }

        next_tokens = tokenizer.batch_decode(next_token_ids)
        for row_tokens, token in zip(generated_tokens, next_tokens):
            row_tokens.append(token)
    return ["".join(tokens) for tokens in generated_tokens]

In [13]:
# Decode 10 new tokens for every prompt in the padded batch
generated_tokens = generate_batch(inputs, max_tokens=10)

# Show each prompt followed by its continuation; the ANSI escape codes
# render the generated part in red
for prompt, generated in zip(prompts, generated_tokens):
    print(prompt, f"\x1b[31m{generated}\x1b[0m")
    print("--------")

The quick brown fox jumped over the [31m fence and ran to the other side of the fence[0m
--------
The rain in Spain falls [31m on the first day of the month, and the[0m
--------
What comes up must [31m be a good idea.

"I think[0m
--------


#### Throughput vs Latency

In [14]:
# constants
max_tokens = 10

# observations (one entry per batch size, in the order of batch_sizes)
durations = []
throughputs = []
latencies = []

batch_sizes = [2**p for p in range(8)]
for batch_size in batch_sizes:
    # Build the batch by cycling through the example prompts. This is
    # benchmark bookkeeping, so it stays outside the timed region.
    batch_prompts = [prompts[i % len(prompts)] for i in range(batch_size)]

    # CUDA work is queued asynchronously: drain anything still pending before
    # starting the clock, and wait for this batch to finish before stopping it.
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    inputs = tokenizer(batch_prompts, padding=True, return_tensors="pt").to("cuda")
    generated_tokens = generate_batch(inputs, max_tokens=max_tokens)
    torch.cuda.synchronize()
    duration_ms = (time.perf_counter() - t0) * 1e3

    # throughput: tokens produced across the whole batch per millisecond
    # avg latency: time per decoding step, i.e. per token as seen by one request
    ntokens = batch_size * max_tokens
    throughput = ntokens / duration_ms
    avg_latency = duration_ms / max_tokens

    print(f"bs= {batch_size:3d}, duration= {duration_ms:.1f}ms, throughput= {throughput:4.1f} tokens/ms, avg latency= {avg_latency:3.1f}ms")

    durations.append(duration_ms)
    throughputs.append(throughput)
    latencies.append(avg_latency)

bs=   1, duration= 59.8ms, throughput=  0.2 tokens/ms, avg latency= 6.0ms
bs=   2, duration= 59.9ms, throughput=  0.3 tokens/ms, avg latency= 6.0ms
bs=   4, duration= 59.3ms, throughput=  0.7 tokens/ms, avg latency= 5.9ms
bs=   8, duration= 61.6ms, throughput=  1.3 tokens/ms, avg latency= 6.2ms
bs=  16, duration= 64.4ms, throughput=  2.5 tokens/ms, avg latency= 6.4ms
bs=  32, duration= 67.2ms, throughput=  4.8 tokens/ms, avg latency= 6.7ms
bs=  64, duration= 75.6ms, throughput=  8.5 tokens/ms, avg latency= 7.6ms
bs= 128, duration= 93.0ms, throughput= 13.8 tokens/ms, avg latency= 9.3ms
