<a href="https://colab.research.google.com/github/vikaskapur/my_genai_apps/blob/llm_serving/3_Continuous_Batching.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Continuous Batching

In this notebook, we'll discuss continuous batching, the approach to batching used for LLM inference in production.

- The key idea behind continuous batching is to constantly swap requests that have completed generation out of the batch and replace them with requests from the queue that are waiting to be processed.
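To make the idea concrete before touching the model, here is a minimal sketch of the scheduling logic in plain Python. It is a toy model, not the notebook's implementation: each request is reduced to its token budget, every step is assumed to yield exactly one token per active request, and the prefill step is ignored. The function names `simulate_static_batching` and `simulate_continuous_batching` are made up for this illustration.

```python
from collections import deque


def simulate_static_batching(token_budgets, batch_size):
    """Count decode steps when each batch runs until its longest request ends."""
    steps = 0
    for start in range(0, len(token_budgets), batch_size):
        steps += max(token_budgets[start:start + batch_size])
    return steps


def simulate_continuous_batching(token_budgets, batch_size):
    """Count decode steps when finished requests are replaced immediately."""
    queue = deque(token_budgets)
    active = []  # tokens remaining for each in-flight request
    steps = 0
    while queue or active:
        # refill free slots from the waiting queue
        while queue and len(active) < batch_size:
            active.append(queue.popleft())
        # one decode step: every active request produces one token
        active = [remaining - 1 for remaining in active]
        steps += 1
        # drop requests that have used up their budget
        active = [remaining for remaining in active if remaining > 0]
    return steps


# same workload shape as used later in this notebook:
# every 8th request wants 100 tokens, the rest want 10
token_budgets = [100 if i % 8 == 0 else 10 for i in range(32)]
static_steps = simulate_static_batching(token_budgets, batch_size=8)
continuous_steps = simulate_continuous_batching(token_budgets, batch_size=8)
print(static_steps, continuous_steps)  # 400 130
```

Under these simplifying assumptions the continuous schedule needs far fewer steps, because short requests no longer wait for the longest request in their batch.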

### Import required packages and load the LLM

In [1]:
! git clone https://github.com/vikaskapur/my_genai_apps.git


Cloning into 'my_genai_apps'...
remote: Enumerating objects: 118, done.
remote: Counting objects: 100% (118/118), done.
remote: Compressing objects: 100% (108/108), done.
remote: Total 118 (delta 40), reused 23 (delta 0), pack-reused 0
Receiving objects: 100% (118/118), 318.47 KiB | 1.81 MiB/s, done.
Resolving deltas: 100% (40/40), done.


In [2]:
%cd my_genai_apps/llm_serving

/content/my_genai_apps/llm_serving


In [5]:
!ls

notebooks  README.md  requirements.txt	utils


In [6]:
import utils.helpers
from utils.helpers import init_batch, generate_next_token
from utils.helpers import merge_batches, filter_batch

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


cuda:0


In [7]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [8]:
model_name = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)
model = model.to(device)

cuda:0


### Add padding tokens to the model to prepare batches of prompts

In [9]:
# Define PAD Token = EOS Token = 50256
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# pad on the left so we can append new tokens on the right
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"

In [10]:
# multiple prompts of varying lengths to send to the model at once
prompts = [
    "The quick brown fox jumped over the",
    "The rain in Spain falls",
    "What comes up must",
]

# note: padding=True ensures the padding token will be inserted into the tokenized tensors
inputs = tokenizer(prompts, padding=True, return_tensors="pt").to(device)
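As a rough illustration of what left padding produces, the sketch below pads plain Python lists by hand. It only mimics the shape of the tokenizer output; the helper `left_pad` and the token IDs are hypothetical, and the real tokenizer returns tensors rather than lists.

```python
def left_pad(token_id_lists, pad_id):
    """Left-pad variable-length token ID lists to a common length.

    Returns (padded_ids, attention_mask), where the mask is 0 for padding
    positions and 1 for real tokens.
    """
    max_len = max(len(ids) for ids in token_id_lists)
    padded, mask = [], []
    for ids in token_id_lists:
        n_pad = max_len - len(ids)
        padded.append([pad_id] * n_pad + list(ids))
        mask.append([0] * n_pad + [1] * len(ids))
    return padded, mask


# hypothetical token IDs, chosen only for illustration
example_batch = [[11, 12, 13, 14], [21, 22], [31, 32, 33]]
padded_ids, attention_mask = left_pad(example_batch, pad_id=50256)

for ids_row, mask_row in zip(padded_ids, attention_mask):
    print(ids_row, mask_row)
```

Because the padding sits on the left, the last column always holds a real token, so each newly generated token can simply be appended on the right.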

### Define needed functions for batching

In [11]:
def generate_batch_tokens_with_past(inputs):
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    last_logits = logits[:, -1, :]
    next_token_ids = last_logits.argmax(dim=1)
    return next_token_ids, outputs.past_key_values


def generate_batch(inputs, max_tokens):
    # create a list of tokens for every input in the batch
    generated_tokens = [[] for _ in range(inputs["input_ids"].shape[0])]

    attention_mask = inputs["attention_mask"]
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    next_inputs = {
        "position_ids": position_ids,
        **inputs
    }
    for _ in range(max_tokens):
        next_token_ids, past_key_values = generate_batch_tokens_with_past(next_inputs)
        next_inputs = {
            "input_ids": next_token_ids.reshape((-1, 1)),  # '-1' here means the remaining elements for this dim
            "position_ids": next_inputs["position_ids"][:, -1].unsqueeze(-1) + 1,  # increment last, discard the rest
            "attention_mask": torch.cat([
                next_inputs["attention_mask"],
                # append a column of 1's, shape [batch_size, 1], matching the mask's dtype
                torch.ones(
                    (next_token_ids.shape[0], 1),
                    dtype=next_inputs["attention_mask"].dtype,
                    device=device,
                ),
            ], dim=1),
            "past_key_values": past_key_values,
        }

        next_tokens = tokenizer.batch_decode(next_token_ids)
        for i, token in enumerate(next_tokens):
            generated_tokens[i].append(token)
    return ["".join(tokens) for tokens in generated_tokens]
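The position ID computation in `generate_batch` (a cumulative sum over the attention mask, minus one, with padded positions overwritten by 1) can be traced by hand on a small mask. The sketch below redoes the same arithmetic on plain Python lists; `position_ids_from_mask` is a hypothetical helper written for this illustration, not part of the notebook.

```python
from itertools import accumulate


def position_ids_from_mask(attention_mask, pad_position=1):
    """Mirror `mask.cumsum(-1) - 1` followed by `masked_fill_(mask == 0, 1)`."""
    position_ids = []
    for row in attention_mask:
        running_counts = accumulate(row)  # number of real tokens seen so far
        position_ids.append([
            count - 1 if is_real else pad_position
            for count, is_real in zip(running_counts, row)
        ])
    return position_ids


# left-padded mask for three prompts of lengths 4, 2 and 3
example_mask = [[1, 1, 1, 1], [0, 0, 1, 1], [0, 1, 1, 1]]
example_position_ids = position_ids_from_mask(example_mask)
print(example_position_ids)
# [[0, 1, 2, 3], [1, 1, 0, 1], [1, 0, 1, 2]]

# for the next decode step only the last position is kept and incremented
next_position_ids = [[row[-1] + 1] for row in example_position_ids]
print(next_position_ids)  # [[4], [2], [3]]
```

The values written into padded positions are arbitrary placeholders, since the attention mask hides those positions from the model; what matters is that the first real token of every prompt gets position 0.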

### Define the requests to be processed

In [12]:
# seed the random number generator so our results are deterministic
random.seed(42)

# constants
queue_size = 32
batch_size = 8

# requests waiting to be processed
# requests are tuples (prompt, max_tokens)
request_queue = [
    (prompts[0], 100 if i % batch_size == 0 else 10)
    for i in range(queue_size)
]

In [13]:
request_queue

[('The quick brown fox jumped over the', 100),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 100),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 100),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped o

In [14]:
batches = [
    request_queue[i:i + batch_size]
    for i in range(0, len(request_queue), batch_size)
]

In [15]:
len(batches)

4

In [16]:
batches[0]

[('The quick brown fox jumped over the', 100),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10),
 ('The quick brown fox jumped over the', 10)]

### Processing batches



In [17]:
# generate tokens for all batches and record duration
t0 = time.time()
with tqdm(total=len(batches), desc=f"bs={batch_size}") as pbar:
    for i, batch in enumerate(batches):
        # to accommodate all the requests with our
        # current implementation, we take the max of
        # all the tokens to generate among the requests
        batch_max_tokens = [b[1] for b in batch]
        max_tokens = max(batch_max_tokens)
        pbar.set_postfix({'max_tokens': max_tokens})

        batch_prompts = [b[0] for b in batch]
        inputs = tokenizer(
            batch_prompts, padding=True, return_tensors="pt").to(device)
        generate_batch(inputs, max_tokens=max_tokens)

        pbar.update(1)

duration_s = time.time() - t0
print("duration", duration_s)

bs=8: 100%|██████████| 4/4 [00:07<00:00,  1.91s/it, max_tokens=100]

duration 7.6376793384552
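Taking the maximum of `max_tokens` over each batch means every request in the batch is decoded for 100 steps, even the seven that only asked for 10 tokens. A back-of-the-envelope estimate of the wasted work, assuming one token slot per request per decode step and ignoring prefill, looks like this:

```python
batch_size = 8
token_budgets = [100 if i % batch_size == 0 else 10 for i in range(32)]

useful_slots = 0    # tokens that were actually requested
computed_slots = 0  # tokens the static batches end up generating

for start in range(0, len(token_budgets), batch_size):
    batch_budgets = token_budgets[start:start + batch_size]
    useful_slots += sum(batch_budgets)
    computed_slots += max(batch_budgets) * len(batch_budgets)

utilization = useful_slots / computed_slots
print(useful_slots, computed_slots)  # 680 3200
print(utilization)
```

Under these assumptions only about 21% of the generated token slots were requested; the rest is work that continuous batching aims to reclaim.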





### Let's try continuous batching

- This time, rather than processing each batch to completion, you will use continuous batching: after every decode step, inputs that have finished generating are removed from the batch, and waiting requests from the queue take their place.
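The loop in the next cell relies on `filter_batch` and `merge_batches` from `utils.helpers`, whose source is not shown in this notebook. The sketch below illustrates the bookkeeping such helpers have to do, using plain dictionaries of lists. It is an assumption about their general behavior, not the repository's code: the names `filter_finished_sketch` and `merge_sketch` are hypothetical, and the real helpers presumably also have to slice and re-pad the tensors and the KV cache.

```python
def filter_finished_sketch(batch):
    """Split a batch into still-running entries and the indices that finished."""
    remaining = batch["tokens_remaining"]
    removed_indices = [i for i, n in enumerate(remaining) if n <= 0]
    keep_indices = [i for i, n in enumerate(remaining) if n > 0]
    kept = {
        key: [values[i] for i in keep_indices]
        for key, values in batch.items()
    }
    return kept, removed_indices


def merge_sketch(batch_a, batch_b):
    """Concatenate two batches entry by entry (both must share the same keys)."""
    return {key: batch_a[key] + batch_b[key] for key in batch_a}


# toy batch: the first and third requests have used up their budget
running = {"responses": ["a", "b", "c"], "tokens_remaining": [0, 3, 0]}
kept, removed = filter_finished_sketch(running)

# two freed slots are refilled from the queue
incoming = {"responses": ["d", "e"], "tokens_remaining": [10, 10]}
merged = merge_sketch(kept, incoming)

print(removed)  # [0, 2]
print(merged)   # {'responses': ['b', 'd', 'e'], 'tokens_remaining': [3, 10, 10]}
```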


In [24]:
# seed the random number generator so our results are deterministic
random.seed(42)

# constants
queue_size = 32
batch_size = 8

# requests waiting to be processed
# as before, requests are tuples (prompt, max_tokens)
request_queue = [
    (prompts[0], 100 if i % batch_size == 0 else 10)
    for i in range(queue_size)
]

t0 = time.time()
with tqdm(total=len(request_queue), desc=f"bs={batch_size}") as pbar:
    # first, let's seed the initial cached_batch
    # with the first `batch_size` inputs
    # and run the initial prefill step
    batch = init_batch(request_queue[:batch_size])
    cached_batch = generate_next_token(batch)
    request_queue = request_queue[batch_size:]

    print(f'batch : {batch}')
    print(f'cached_batch:')
    print(f'cached_batch: input_ids: {cached_batch["input_ids"]}')
    print(f'cached_batch: position_ids: {cached_batch["position_ids"]}')
    print(f'cached_batch: attention_mask: {cached_batch["attention_mask"]}')
    print(f'cached_batch: responses: {cached_batch["responses"]}')
    print(f'cached_batch: tokens_remaining: {cached_batch["tokens_remaining"]}')

    print(f'cached_batch: num_active_sequences: {cached_batch["input_ids"].size(0)}')




    # continue until both the request queue is
    # fully drained and every input
    # within the cached_batch has completed generation
    while (
        len(request_queue) > 0 or
        cached_batch["input_ids"].size(0) > 0
    ):
        batch_capacity = (
            batch_size - cached_batch["input_ids"].size(0)
        )
        print(f'request_queue len : {len(request_queue)}')
        print(f'batch_capacity : {batch_capacity}')

        if batch_capacity > 0 and len(request_queue) > 0:
            # prefill
            new_batch = init_batch(request_queue[:batch_capacity])
            new_batch = generate_next_token(new_batch)
            request_queue = request_queue[batch_capacity:]

            # merge
            cached_batch = merge_batches(cached_batch, new_batch)

        # decode
        cached_batch = generate_next_token(cached_batch)

        # remove any inputs that have finished generation
        cached_batch, removed_indices = filter_batch(cached_batch)
        pbar.update(len(removed_indices))

duration_s = time.time() - t0
print("duration", duration_s)

bs=8:   0%|          | 0/32 [00:00<?, ?it/s]

batch : {'position_ids': tensor([[0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6],
        [0, 1, 2, 3, 4, 5, 6]], device='cuda:0'), 'responses': ['The quick brown fox jumped over the', 'The quick brown fox jumped over the', 'The quick brown fox jumped over the', 'The quick brown fox jumped over the', 'The quick brown fox jumped over the', 'The quick brown fox jumped over the', 'The quick brown fox jumped over the', 'The quick brown fox jumped over the'], 'tokens_remaining': [100, 10, 10, 10, 10, 10, 10, 10], 'input_ids': tensor([[  464,  2068,  7586, 21831, 11687,   625,   262],
        [  464,  2068,  7586, 21831, 11687,   625,   262],
        [  464,  2068,  7586, 21831, 11687,   625,   262],
        [  464,  2068,  7586, 21831, 11687,   625,   262],
        [  464,  2068,  7586, 21831, 11687,   625,   262],
        [  464,  2

bs=8:  22%|██▏       | 7/32 [00:00<00:00, 66.61it/s]

request_queue len : 24
batch_capacity : 0
request_queue len : 24
batch_capacity : 0
request_queue len : 24
batch_capacity : 0
request_queue len : 24
batch_capacity : 0
request_queue len : 24
batch_capacity : 7
request_queue len : 17
batch_capacity : 0


bs=8:  41%|████      | 13/32 [00:00<00:00, 33.48it/s]

request_queue len : 17
batch_capacity : 0
request_queue len : 17
batch_capacity : 0
request_queue len : 17
batch_capacity : 0
request_queue len : 17
batch_capacity : 0
request_queue len : 17
batch_capacity : 0
request_queue len : 17
batch_capacity : 0
request_queue len : 17
batch_capacity : 0
request_queue len : 17
batch_capacity : 6
request_queue len : 11
batch_capacity : 0
request_queue len : 11
batch_capacity : 0
request_queue len : 11
batch_capacity : 0
request_queue len : 11
batch_capacity : 0
request_queue len : 11
batch_capacity : 0


bs=8:  56%|█████▋    | 18/32 [00:01<00:00, 23.13it/s]

request_queue len : 11
batch_capacity : 0
request_queue len : 11
batch_capacity : 0
request_queue len : 11
batch_capacity : 0
request_queue len : 11
batch_capacity : 5
request_queue len : 6
batch_capacity : 0
request_queue len : 6
batch_capacity : 0


bs=8:  69%|██████▉   | 22/32 [00:01<00:00, 19.05it/s]

request_queue len : 6
batch_capacity : 0
request_queue len : 6
batch_capacity : 0
request_queue len : 6
batch_capacity : 0
request_queue len : 6
batch_capacity : 0
request_queue len : 6
batch_capacity : 0
request_queue len : 6
batch_capacity : 0
request_queue len : 6
batch_capacity : 4
request_queue len : 2
batch_capacity : 0
request_queue len : 2
batch_capacity : 0
request_queue len : 2
batch_capacity : 0
request_queue len : 2
batch_capacity : 0
request_queue len : 2
batch_capacity : 0
request_queue len : 2
batch_capacity : 0


bs=8:  81%|████████▏ | 26/32 [00:01<00:00, 15.95it/s]

request_queue len : 2
batch_capacity : 0
request_queue len : 2
batch_capacity : 0
request_queue len : 2
batch_capacity : 4
request_queue len : 0
batch_capacity : 2
request_queue len : 0
batch_capacity : 2
request_queue len : 0
batch_capacity : 2
request_queue len : 0
batch_capacity : 2
request_queue len : 0
batch_capacity : 2


bs=8:  88%|████████▊ | 28/32 [00:01<00:00, 13.81it/s]

request_queue len : 0
batch_capacity : 2
request_queue len : 0
batch_capacity : 2
request_queue len : 0
batch_capacity : 2
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue len : 0
batch_capacity : 4
request_queue le

bs=8:  94%|█████████▍| 30/32 [00:03<00:00,  5.11it/s]

request_queue len : 0
batch_capacity : 5
request_queue len : 0
batch_capacity : 5
request_queue len : 0
batch_capacity : 5
request_queue len : 0
batch_capacity : 5
request_queue len : 0
batch_capacity : 6
request_queue len : 0
batch_capacity : 6
request_queue len : 0
batch_capacity : 6
request_queue len : 0
batch_capacity : 6
request_queue len : 0
batch_capacity : 6
request_queue len : 0
batch_capacity : 6


bs=8: 100%|██████████| 32/32 [00:03<00:00,  9.22it/s]

request_queue len : 0
batch_capacity : 6
request_queue len : 0
batch_capacity : 6
request_queue len : 0
batch_capacity : 6
request_queue len : 0
batch_capacity : 7
request_queue len : 0
batch_capacity : 7
request_queue len : 0
batch_capacity : 7
request_queue len : 0
batch_capacity : 7
request_queue len : 0
batch_capacity : 7
request_queue len : 0
batch_capacity : 7
request_queue len : 0
batch_capacity : 7
request_queue len : 0
batch_capacity : 7
request_queue len : 0
batch_capacity : 7
duration 3.484861373901367
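The two durations printed in this notebook can be turned into a rough speedup and throughput estimate. This is a sketch based on a single run of each variant; the continuous-batching run also includes the overhead of the debug prints, and the token count is the number of tokens requested rather than a measured count.

```python
# durations copied from the outputs above (single runs, in seconds)
static_duration_s = 7.6376793384552
continuous_duration_s = 3.484861373901367

# tokens requested across the 32 queued requests
requested_tokens = sum(100 if i % 8 == 0 else 10 for i in range(32))

speedup = static_duration_s / continuous_duration_s
static_tokens_per_s = requested_tokens / static_duration_s
continuous_tokens_per_s = requested_tokens / continuous_duration_s

print(f"speedup: {speedup:.2f}x")
print(f"static: {static_tokens_per_s:.0f} requested tokens/s")
print(f"continuous: {continuous_tokens_per_s:.0f} requested tokens/s")
```

On this workload continuous batching finishes the same queue a little over twice as fast as processing fixed batches to completion.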



