# Beam search with vLLM

## Goal

Learn how beam search works in vLLM, and how the beam width affects the log-probability of the generated text and the generation time.

## Imports

In [None]:
import time
import textwrap
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import matplotlib.pyplot as plt
import matplotlib as mpl

# Create and immediately discard an empty figure before touching rcParams.
# NOTE(review): presumably this forces the notebook plotting backend to
# initialise so the settings below are not overwritten afterwards — confirm.
plt.plot()
plt.close("all")

# Wide, short figures with thick lines and a larger font, suited to the
# side-by-side panels drawn later in the notebook.
mpl.rcParams.update({
    "figure.figsize": (25, 4),
    "lines.linewidth": 3,
    "font.size": 16,
})

## Experiments with text generation

### Load model

In [None]:
class cfg:
    """Notebook-wide settings used when loading the model and tokenizer."""
    # Path handed to both `LLM(model=...)` and `AutoTokenizer.from_pretrained`.
    model_path: str = "/home/gbarbadillo/data/Qwen2-1.5B-Instruct"
    # Value forwarded as `max_model_len` when the engine is created.
    max_model_len: int = 8192 #61000 for phi-3

In [None]:
# Build the vLLM engine once; every experiment below reuses this instance.
llm = LLM(
    model=cfg.model_path,
    max_model_len=cfg.max_model_len,
    dtype="half",
    tensor_parallel_size=2,  # to use 2 gpus
    # Measured per-GPU memory: 13.9GB without enforce_eager, 13.3GB with it.
    enforce_eager=True,
    trust_remote_code=True,
    disable_log_stats=True,
    # kv_cache_dtype="fp8_e5m2" is intentionally not set: kv cache
    # quantization was disabled because it is hurtful.
)
# Tokenizer from the same checkpoint; used later to apply the chat template.
tokenizer = AutoTokenizer.from_pretrained(cfg.model_path)

### Prompt about Alexander the great and Tyre

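With beam search the model finds a completion with a higher cumulative log-probability than greedy decoding does (-16.08 vs -22.31 in the example below), and the generated text changes as a result.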

```
greedy decoding
Cumulative logprob: -22.31
Alexander the Great was a legendary conqueror who ruled over a vast empire that
stretched from Greece to India. One of his most famous conquests was the city of
Tyre, which was located on the coast of Lebanon. Alexander had heard of Tyre

beam search, best_of 100
Cumulative logprob: -16.08
Alexander the Great was one of the most famous conquerors in history. He was
born in 356 BC and died in 323 BC, but his legacy lives on to this day. One of
his most famous conquests was the
```

In [None]:
def beam_search_experiment():
    """Sweep the beam width and report sequence logprob and generation time.

    Reads the notebook globals ``llm``, ``prompt`` and ``max_tokens``. For each
    ``best_of`` value a single completion is requested (plain temperature-0
    decoding when ``best_of == 1``, beam search otherwise). The cumulative
    logprob, the generation time and the wrapped text of the first returned
    output are printed, and both metrics are then plotted against ``best_of``
    on a log-scaled x axis.

    Returns:
        None. Results are only printed and plotted.
    """
    best_of_range = [1, 2, 4, 8, 16, 32, 64, 128]
    logprob, runtime = [], []
    for n in best_of_range:
        if n == 1:
            # Baseline without beam search.
            sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        else:
            sampling_params = SamplingParams(
                n=1,
                temperature=0.0,
                max_tokens=max_tokens,
                use_beam_search=True,
                best_of=n,
            )
        # perf_counter is monotonic, unlike time.time(), so the interval cannot
        # be distorted by system clock adjustments.
        t0 = time.perf_counter()
        ret = llm.generate(prompt, sampling_params, use_tqdm=False)
        # Stop the clock immediately so only generation is measured.
        runtime.append(time.perf_counter() - t0)

        best = ret[0].outputs[0]
        logprob.append(best.cumulative_logprob)
        print(f'## Best of {n} (logprob: {logprob[-1]:.2f}, runtime: {runtime[-1]:.2f}s)')
        print(textwrap.fill(best.text, 80) + '\n\n')

    # One panel per metric: (series name, values, y-axis label).
    panels = [('logprob', logprob, 'Logprob'), ('runtime', runtime, 'Runtime')]
    for position, (name, values, ylabel) in enumerate(panels, start=1):
        print(f'{name}: {values}')
        plt.subplot(1, 2, position)
        plt.plot(best_of_range, values, label=name, marker='o')
        plt.xlabel('best_of')
        plt.ylabel(ylabel)
        plt.xscale('log')
        plt.grid(which='both')

    plt.suptitle(f'Effect of best_of when using beam search to generate {max_tokens} tokens')

In [None]:
# Generation budget; beam_search_experiment() reads it from the global scope.
max_tokens = 100

system_prompt = "You are a helpful AI assistant"
prompt = "Write a small story about how Alexander the great conquered the city of Tyre"
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": prompt},
]
# Rebind `prompt` to the chat-templated string (not token ids), with the
# generation prompt appended.
prompt = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)
prompt
beam_search_experiment()

## Pseudo beam search

In [None]:
def pseudo_beam_search_experiment(temperature):
    """Sweep the number of sampled completions and report logprob and time.

    Reads the notebook globals ``llm``, ``prompt`` and ``max_tokens``. For each
    ``n``, ``n`` completions are sampled at the given temperature. The
    cumulative logprobs of up to the first ten outputs are printed, followed by
    the logprob, generation time and wrapped text of the first output. Both
    metrics are then plotted against ``n`` on a log-scaled x axis.

    NOTE(review): the first output is treated as the best one, which assumes
    vLLM returns outputs ordered by cumulative logprob — confirm for the
    installed version.

    Args:
        temperature: Sampling temperature forwarded to ``SamplingParams``.

    Returns:
        A ``(logprob, runtime)`` tuple of lists with one entry per ``n`` tried,
        so callers can reuse the measurements instead of copying them from the
        printed output.
    """
    n_range = [1, 2, 4, 8, 16, 32, 64, 128]
    logprob, runtime = [], []
    for n in n_range:
        sampling_params = SamplingParams(n=n, temperature=temperature, max_tokens=max_tokens)
        # perf_counter is monotonic, unlike time.time(), so the interval cannot
        # be distorted by system clock adjustments.
        t0 = time.perf_counter()
        ret = llm.generate(prompt, sampling_params, use_tqdm=False)
        # Stop the clock before any printing so only generation is measured.
        elapsed = time.perf_counter() - t0

        candidates = ret[0].outputs
        print([round(output.cumulative_logprob, 2) for output in candidates[:10]])
        logprob.append(candidates[0].cumulative_logprob)
        runtime.append(elapsed)
        print(f'## Best of {n} (logprob: {logprob[-1]:.2f}, runtime: {runtime[-1]:.2f}s)')
        print(textwrap.fill(candidates[0].text, 80) + '\n\n')

    # One panel per metric: (series name, values, y-axis label).
    panels = [('logprob', logprob, 'Logprob'), ('runtime', runtime, 'Runtime')]
    for position, (name, values, ylabel) in enumerate(panels, start=1):
        print(f'{name}: {values}')
        plt.subplot(1, 2, position)
        plt.plot(n_range, values, label=name, marker='o')
        plt.xlabel('n')
        plt.ylabel(ylabel)
        plt.xscale('log')
        plt.grid(which='both')

    plt.suptitle(f'Effect of n when using pseudo beam search to generate {max_tokens} tokens with temperature={temperature}')
    return logprob, runtime

In [None]:
# Run the sampling-based sweep at temperature 0.2 and keep whatever it returns.
ret = pseudo_beam_search_experiment(0.2)

In [None]:
# Compare hard-coded measurements of beam search and pseudo beam search.
# Left panel: cumulative logprob; right panel: runtime. Both against n.
x = [1, 2, 4, 8, 16, 32, 64, 128]

# Beam search measurements.
logprob = [-51.2078690540302, -37.65839752664306, -35.001866647457064, -33.90833894416937, -32.81388856506328, -31.850933521825937, -31.82146939802442, -31.82100444171556]
runtime = [2.3626060485839844, 1.9727156162261963, 1.9264826774597168, 2.203904628753662, 3.0708730220794678, 4.34735894203186, 7.32841944694519, 12.297247409820557]
plt.subplot(1, 2, 1)
plt.plot(x, logprob, label='beam-search', marker='o')
plt.subplot(1, 2, 2)
plt.plot(x, runtime, label='beam-search', marker='o')

# Pseudo beam search measurements.
logprob = [-9.437750849371902, -4.664379475801766, -6.483303615268078, -5.2097757438402255, -3.5085881827150587, -3.4622852842907577, -3.637831517862125, -3.51099348212459]
runtime = [2.052835464477539, 1.9102702140808105, 2.0122568607330322, 1.8525443077087402, 2.0149896144866943, 2.2063076496124268, 2.0433428287506104, 2.752204418182373]
plt.subplot(1, 2, 1)
plt.plot(x, logprob, label='pseudo beam-search T=0.1', marker='o')
plt.subplot(1, 2, 2)
plt.plot(x, runtime, label='pseudo beam-search T=0.1', marker='o')

# Decorate each panel only after both series are drawn: legend() lists the
# artists that exist when it is called, so calling it earlier would leave the
# pseudo beam-search curves out of the legend.
for position, ylabel in ((1, 'Logprob'), (2, 'Runtime')):
    plt.subplot(1, 2, position)
    plt.xlabel('n')
    plt.ylabel(ylabel)
    plt.xscale('log')
    plt.grid(which='both')
    plt.legend(loc=0)

The logprobs of the two approaches cannot be compared directly because they depend on the sampling temperature. Measuring ARC accuracy would be a better way to compare them.

In [None]:
# A bare `raise` with no active exception fails with RuntimeError.
# NOTE(review): presumably a deliberate stop so that "Run All" halts before
# the unrelated batching section below — confirm.
raise

## Does batching speed up inference?

This is unrelated to the rest of the notebook, but I want to find out whether batching speeds up inference.

In [None]:
# Generation budget; run_batch_experiment() reads it from the global scope.
max_tokens = 200

system_prompt = "You are a helpful AI assistant"
prompt = "Write a small story about how Alexander the great conquered the city of Tyre"
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": prompt},
]
# Rebind `prompt` to the chat-templated string (not token ids), with the
# generation prompt appended.
prompt = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)
prompt

In [None]:
def run_batch_experiment(n):
    """Time ``n`` sequential generations against one batched call of size ``n``.

    Reads the notebook globals ``llm``, ``prompt`` and ``max_tokens``. The same
    prompt is generated with temperature 0, first ``n`` times one call at a
    time and then once as a list of ``n`` copies. Both elapsed times are
    printed; the generated outputs are discarded.

    Args:
        n: Number of sequential calls, and the size of the batched call.

    Returns:
        None.
    """
    sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)

    # perf_counter is monotonic, unlike time.time(), so the intervals cannot
    # be distorted by system clock adjustments.
    t0 = time.perf_counter()
    for _ in range(n):
        llm.generate(prompt, sampling_params, use_tqdm=False)
    print(f'## {n} inferences (runtime: {time.perf_counter() - t0:.2f}s)')

    t0 = time.perf_counter()
    llm.generate([prompt] * n, sampling_params, use_tqdm=False)
    print(f'## Batch size {n} (runtime: {time.perf_counter() - t0:.2f}s)')

In [None]:
# Sequential vs batched timing for 5 copies of the prompt.
run_batch_experiment(5)

In [None]:
# Sequential vs batched timing for 10 copies of the prompt.
run_batch_experiment(10)

In [None]:
# Sequential vs batched timing for 20 copies of the prompt.
run_batch_experiment(20)

Batching predictions seems to offer a large speedup:

```
## 20 inferences (runtime: 98.38s)
## Batch size 20 (runtime: 5.52s)
```

## Clean

In [None]:
def clear_vllm_gpu_memory():
    """Tear down the global ``llm`` engine and release cached CUDA memory.

    Destroys vLLM's model-parallel and distributed state, deletes the engine's
    model executor and the global ``llm`` binding, then runs the garbage
    collector and empties the CUDA cache. After this call the name ``llm`` no
    longer exists in the notebook namespace.
    """
    global llm
    # Teardown recipe taken from:
    # https://github.com/vllm-project/vllm/issues/1908
    # Imports are local, so they are only resolved when cleanup actually runs.
    from vllm.distributed.parallel_state import destroy_model_parallel, destroy_distributed_environment
    import torch
    import gc
    destroy_model_parallel()
    destroy_distributed_environment()
    # Delete the executor first, then the engine wrapper itself.
    del llm.llm_engine.model_executor
    del llm
    # Collect the now-unreferenced objects before emptying the CUDA cache.
    gc.collect()
    torch.cuda.empty_cache()

# Tear down the engine at the end of the notebook; `llm` is deleted by this call.
clear_vllm_gpu_memory()