# MCTS vs Best-of-N Experiment - Google Colab (vLLM Backend)

This notebook compares **MCTS** and **Best-of-N** search methods on GSM8K math problems.

**WARNING: vLLM has compatibility issues with Google Colab. Use the Transformers version instead if you encounter errors.**

**This version uses vLLM for faster inference and proper logprobs (MCTS priors).**

**Requirements:**
- GPU runtime (T4 or better)
- Linux environment (Kaggle or Cloud VM recommended)

**Important:** Download the results before the session ends!

## 1. Check GPU and Setup

In [None]:
!nvidia-smi

import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    raise RuntimeError("No GPU found!")

Wed Dec 10 22:49:56 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   76C    P8             13W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## 2. Install Dependencies

**Important:** The install cell pins the numpy and pillow versions that are already installed, so that installing vLLM cannot replace Colab's numpy, which causes compatibility issues.

In [None]:
#@title Colab Install - Latest vLLM { display-mode: "form" }
# Check current versions
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")

!pip install --upgrade -qqq uv

# vLLM is unpinned and may pull in a different PyTorch build.
# Pin only the numpy and pillow versions that are already installed.
try:
    import numpy, PIL
    get_numpy = f"numpy=={numpy.__version__}"
    get_pil = f"pillow=={PIL.__version__}"
except ImportError:
    get_numpy, get_pil = "numpy", "pillow"

print("\nInstalling latest vLLM...")
print(f"Preserving: {get_numpy}, {get_pil}")

# Install latest vLLM without version pin
!uv pip install --upgrade vllm {get_numpy} {get_pil}

# Install other requirements
!uv pip install transformers datasets accelerate pyyaml tqdm matplotlib seaborn huggingface_hub

print("\n✓ Installation complete!")
!pip show vllm | grep -E "^(Name|Version)"
!pip show torch | grep -E "^(Name|Version)"

PyTorch: 2.9.0+cu126
CUDA: 12.6

Installing latest vLLM...
Preserving: numpy==2.0.2, pillow==11.3.0
Using Python 3.12.12 environment at: /usr
Resolved 153 packages in 956ms
Prepared 2 packages in 1ms
Uninstalled 2 packages in 4ms
Installed 2 packages in 24ms
 - dill==0.3.8
 + dill==0.4.0
 - fsspec==2025.3.0
 + fsspec==2025.12.0
Using Python 3.12.12 environment at: /usr
Resolved 69 packages in 136ms
Uninstalled 2 packages in 2ms
Installed 2 packages in 8ms
 - dill==0.4.0
 + dill==0.3.8
 - fsspec==2025.12.0
 + fsspec==2025.3.0

✓ Installation complete!
Name: vllm
Version: 0.12.0
Name

## 2.1 HuggingFace Authentication

In [None]:
import os
from huggingface_hub import login

# Get token from Colab secrets (recommended), else fall back to the environment variable
try:
    from google.colab import userdata
    HF_TOKEN = userdata.get('HF_TOKEN')
except Exception:
    HF_TOKEN = os.environ.get('HF_TOKEN')

if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found! Add it to Colab secrets or set it as an environment variable.")

login(token=HF_TOKEN)
print("Logged in to HuggingFace!")

## 3. Create Project Structure

In [None]:
import os
os.makedirs("src", exist_ok=True)
os.makedirs("configs", exist_ok=True)
os.makedirs("outputs", exist_ok=True)
os.makedirs("logs", exist_ok=True)
print("Project structure created!")

Project structure created!


In [None]:
%%writefile src/__init__.py
# MCTS Math Package

Overwriting src/__init__.py


In [None]:
%%writefile src/utils.py
import random
import numpy as np
import torch
import logging
import os
import sys
from typing import Optional

def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

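def setup_logging(name: str, log_dir: str):
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger(name)
    if logger.handlers: return logger  # already configured (e.g. cell re-run)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt="%(asctime)s - %(levelname)s - %(message)s")
    sh = logging.StreamHandler(sys.stdout)
    sh.setFormatter(formatter)
    logger.addHandler(sh)
    # Also write to <log_dir>/<name>.log so that log_dir is actually used
    fh = logging.FileHandler(os.path.join(log_dir, f"{name}.log"))
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    return logger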

def extract_answer(text: str) -> Optional[str]:
    if not text: return None
    boxed_start = text.rfind("\\boxed{")
    if boxed_start != -1:
        start = boxed_start + len("\\boxed{")
        depth, end = 1, start
        while end < len(text) and depth > 0:
            if text[end] == '{': depth += 1
            elif text[end] == '}': depth -= 1
            end += 1
        if depth == 0: return text[start:end-1].strip()
    if "####" in text: return text.split("####")[-1].strip()
    return None

def is_correct(pred: str, truth: str) -> bool:
    if pred is None or truth is None: return False
    def norm(s):
        # Drop thousands separators/spaces and keep the right-hand side of "x = 5"
        s = str(s).replace(",", "").replace(" ", "").split("=")[-1]
        try: return float(s)
        except ValueError: return s
    return norm(pred) == norm(truth)

Overwriting src/utils.py
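The brace-depth scan in `extract_answer` is what lets it recover nested LaTeX such as `\boxed{\frac{1}{2}}`, which a lazy `\\boxed\{(.*?)\}` regex would cut off at the first closing brace. The sketch below copies that logic, plus the normalization used inside `is_correct`, into standalone functions with hypothetical names (`extract_boxed_or_hash`, `norm`) so the behavior can be checked without the rest of the package.

```python
import re
from typing import Optional, Union


def extract_boxed_or_hash(text: str) -> Optional[str]:
    """Hypothetical standalone copy of src.utils.extract_answer."""
    if not text:
        return None
    start_tag = "\\boxed{"
    boxed_start = text.rfind(start_tag)  # the last \boxed{ wins
    if boxed_start != -1:
        start = boxed_start + len(start_tag)
        depth, end = 1, start
        # Walk forward, tracking brace depth until the matching '}' closes the box
        while end < len(text) and depth > 0:
            if text[end] == "{":
                depth += 1
            elif text[end] == "}":
                depth -= 1
            end += 1
        if depth == 0:
            return text[start:end - 1].strip()
    # GSM8K reference solutions end with "#### <answer>"
    if "####" in text:
        return text.split("####")[-1].strip()
    return None


def norm(s: str) -> Union[float, str]:
    """Same normalization as the helper inside is_correct."""
    s = str(s).replace(",", "").replace(" ", "").split("=")[-1]
    try:
        return float(s)
    except ValueError:
        return s


text = r"So the answer is \boxed{\frac{1}{2}}."
naive = re.search(r"\\boxed\{(.*?)\}", text).group(1)
print(naive)                        # \frac{1
print(extract_boxed_or_hash(text))  # \frac{1}{2}
print(extract_boxed_or_hash("She earns 9 * 2 = 18 dollars.\n#### 18"))  # 18
print(norm("1,000") == norm("1000.0"))  # True
```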


In [None]:
%%writefile src/config_parser.py
import yaml
def load_config(path: str):
    with open(path) as f: return yaml.safe_load(f)

Overwriting src/config_parser.py


In [None]:
%%writefile src/dataset.py
from datasets import load_dataset

class GSM8KDataset:
    def __init__(self, config):
        self.dataset = load_dataset("openai/gsm8k", "main", split="test")
        if config.get("num_samples"):
            self.dataset = self.dataset.select(range(config["num_samples"]))
    def __len__(self): return len(self.dataset)
    def __getitem__(self, idx):
        item = self.dataset[idx]
        q, a = item.get("question", ""), item.get("answer", "")
        prompt = ("<|im_start|>system\nPlease reason step by step. "
                  "End your final answer with \\boxed{answer}.<|im_end|>\n"
                  "<|im_start|>user\n" + q + "<|im_end|>\n<|im_start|>assistant\n")
        return {"question": q, "ground_truth": a, "prompt": prompt}

Overwriting src/dataset.py


In [None]:
%%writefile src/llm_engine.py
from vllm import LLM, SamplingParams
from typing import List, Tuple
import math
from collections import Counter
from src.utils import extract_answer

def normalize_answer(s):
    if s is None: return None
    s = str(s).replace(",", "").replace(" ", "").strip()
    try:
        v = float(s)
        return str(int(v)) if v == int(v) else str(v)
    except (ValueError, OverflowError): return s

class LLMEngine:
    def __init__(self, config):
        self.config = config
        self.llm = LLM(
            model=config['model']['model_id'],
            dtype=config['model']['dtype'],
            gpu_memory_utilization=config['model']['gpu_memory_utilization'],
            max_model_len=config['model'].get('max_model_len', 2048),
            trust_remote_code=True,
        )

    def _generate(self, prompt, n, temp, max_tokens):
        params = SamplingParams(n=n, temperature=temp, max_tokens=max_tokens)
        outputs = self.llm.generate([prompt], params, use_tqdm=False)
        return [seq.text for seq in outputs[0].outputs]

    def generate_steps(self, state: str, n: int = 3) -> List[Tuple[str, float]]:
        params = SamplingParams(
            n=n, temperature=self.config['generation']['temperature'],
            max_tokens=512,  # Ceiling only; the stop tokens (["\n"] in the config) usually end a step sooner
            stop=self.config['generation']['stop_tokens'],
            logprobs=1
        )
        outputs = self.llm.generate([state], params, use_tqdm=False)
        candidates = []
        for seq in outputs[0].outputs:
            text = seq.text.strip()
            if text:
                # Use the mean per-token logprob, not the cumulative one. cumulative_logprob
                # grows more negative with step length (exp(-50) is ~0), so raw priors would
                # mostly reflect how long a step is. exp(mean logprob) is the geometric-mean
                # token probability, always in (0, 1]; _expand renormalizes across siblings.
                num_tokens = max(len(seq.token_ids), 1)
                avg_logprob = seq.cumulative_logprob / num_tokens
                prior = math.exp(avg_logprob)
                candidates.append((text, prior))
        return candidates if candidates else [("Let me solve this step by step.", 1.0)]

    def get_consensus_value(self, state: str, n_rollouts: int = 3) -> float:
        if n_rollouts == 0: return 0.0
        texts = self._generate(state, n_rollouts, 0.7, 512)
        answers = [normalize_answer(extract_answer(t)) for t in texts]
        answers = [a for a in answers if a]
        if not answers: return 0.0
        return Counter(answers).most_common(1)[0][1] / len(answers)

    def greedy_complete(self, state: str) -> str:
        texts = self._generate(state, 1, 0.0, 512)
        return texts[0] if texts else ""

Overwriting src/llm_engine.py
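`generate_steps` turns each sampled step into a prior by exponentiating its mean per-token log-probability (the geometric mean of the token probabilities), and `MCTSSearch._expand` then renormalizes sibling priors to sum to 1. The arithmetic below uses made-up `(cumulative_logprob, num_tokens)` pairs, which are assumptions rather than values from a run, to show the length bias that the per-token average removes.

```python
import math

# Hypothetical (cumulative_logprob, num_tokens) pairs for three sampled steps
candidates = [(-4.0, 20), (-6.0, 40), (-3.0, 6)]

raw = [math.exp(cum) for cum, _ in candidates]              # length-biased
geo = [math.exp(cum / max(n, 1)) for cum, n in candidates]  # per-token geometric mean


def normalize(ps):
    """Mirror of the renormalization in MCTSSearch._expand (uniform if all zero)."""
    total = sum(ps)
    return [p / total for p in ps] if total > 0 else [1.0 / len(ps)] * len(ps)


print([round(p, 4) for p in raw])
print([round(p, 4) for p in geo])
print([round(p, 3) for p in normalize(geo)])
```

With these numbers the 40-token step has the smallest raw prior but the largest per-token prior, so ranking by `exp(cumulative_logprob)` would mostly penalize longer steps.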


In [None]:
# %%writefile src/mcts.py
# import math
# import numpy as np

# class MCTSNode:
#     def __init__(self, state, parent=None, action=None, prior=0.0):
#         self.state, self.parent, self.action = state, parent, action
#         self.children, self.visits, self.value_sum, self.prior = [], 0, 0.0, prior

#     @property
#     def value(self):
#         return self.value_sum / (self.visits + 1e-6)

#     def is_fully_expanded(self):
#         return len(self.children) > 0

#     def is_terminal(self):
#         return "\\boxed" in (self.action or "") or "boxed{" in (self.action or "")

# class MCTSSearch:
#     def __init__(self, engine, config):
#         self.llm, self.config = engine, config
#         self.c_puct = config.get("c_puct", 1.25)

#     def search(self, root_state, simulations):
#         root = MCTSNode(state=root_state, prior=1.0)

#         for sim in range(simulations):
#             node = root

#             # Selection: traverse tree using UCB until we find unexpanded non-terminal
#             while node.is_fully_expanded() and not node.is_terminal():
#                 node = self._select_child(node)

#             # Expansion: add children if not terminal
#             if not node.is_terminal():
#                 node = self._expand(node)

#             # Evaluation: get value for this node
#             value = self._evaluate(node)

#             # Backpropagation
#             self._backpropagate(node, value)

#         return self._get_best_path(root)

#     def _select_child(self, node):
#         sqrt_n = math.sqrt(max(1, node.visits))

#         def ucb_score(c):
#             exploitation = c.value
#             exploration = self.c_puct * c.prior * sqrt_n / (1 + c.visits)
#             return exploitation + exploration

#         return max(node.children, key=ucb_score)

#     def _expand(self, node):
#         candidates = self.llm.generate_steps(node.state, n=self.config.get("n_expand", 3))
#         if not candidates:
#             return node

#         # Normalize priors
#         priors = np.array([c[1] for c in candidates])
#         priors = priors / priors.sum() if priors.sum() > 0 else np.ones(len(candidates)) / len(candidates)

#         for i, (text, _) in enumerate(candidates):
#             child = MCTSNode(node.state + "\n" + text, node, text, priors[i])
#             node.children.append(child)

#         # Return child with highest prior
#         return max(node.children, key=lambda c: c.prior)

#     def _evaluate(self, node):
#         """
#         Value = probability this path leads to a consistent answer.

#         For terminal nodes: complete the solution and return 1.0 if valid answer found.
#         For non-terminal: run rollouts, check answer consistency AND validity.
#         """
#         if node.is_terminal():
#             # Check if terminal node actually has a valid boxed answer
#             from src.utils import extract_answer
#             ans = extract_answer(node.action)
#             return 1.0 if ans else 0.5  # Penalize malformed \boxed{}

#         # Get consensus value from rollouts
#         n_rollouts = self.config["value_function"]["n_rollouts"]
#         consensus = self.llm.get_consensus_value(node.state, n_rollouts)

#         # Weight by prior: good steps (high prior) with good outcomes (high consensus)
#         # This combines "is this step likely?" with "does it lead to consistent answers?"
#         weighted_value = 0.5 * consensus + 0.5 * node.prior

#         return weighted_value

#     def _backpropagate(self, node, value):
#         while node:
#             node.visits += 1
#             node.value_sum += value
#             node = node.parent

#     def _get_best_path(self, root):
#         """Follow most-visited path, then greedy complete if needed."""
#         node = root
#         path_text = ""

#         while node.children:
#             node = max(node.children, key=lambda c: c.visits)
#             path_text += "\n" + node.action

#             if node.is_terminal():
#                 break

#         if not node.is_terminal():
#             completion = self.llm.greedy_complete(root.state + path_text)
#             path_text += completion

#         return path_text



In [None]:
%%writefile src/mcts.py
import math
import numpy as np
from collections import Counter

class MCTSNode:
    def __init__(self, state, parent=None, action=None, prior=0.0):
        self.state, self.parent, self.action = state, parent, action
        self.children, self.visits, self.value_sum, self.prior = [], 0, 0.0, prior
        self._cached_value = None  # Cache value to avoid re-computing

    @property
    def value(self):
        return self.value_sum / (self.visits + 1e-6)

    def is_fully_expanded(self):
        return len(self.children) > 0

    def is_terminal(self):
        return "\\boxed" in (self.action or "") or "boxed{" in (self.action or "")

class MCTSSearch:
    def __init__(self, engine, config):
        self.llm = engine
        self.config = config
        self.c_puct = config.get("c_puct", 1.25)

    def search(self, root_state, simulations):
        root = MCTSNode(state=root_state, prior=1.0)

        for _ in range(simulations):
            node = root

            # Selection
            while node.is_fully_expanded() and not node.is_terminal():
                node = self._select_child(node)

            # Expansion
            if not node.is_terminal():
                node = self._expand(node)

            # Evaluation using Best-of-N
            value = self._evaluate_bon(node)

            # Backpropagation
            self._backpropagate(node, value)

        return self._get_best_path(root)

    def _select_child(self, node):
        sqrt_n = math.sqrt(max(1, node.visits))

        def ucb_score(c):
            exploitation = c.value
            exploration = self.c_puct * c.prior * sqrt_n / (1 + c.visits)
            return exploitation + exploration

        return max(node.children, key=ucb_score)

    def _expand(self, node):
        candidates = self.llm.generate_steps(node.state, n=self.config.get("n_expand", 3))
        if not candidates:
            return node

        priors = np.array([c[1] for c in candidates])
        priors = priors / priors.sum() if priors.sum() > 0 else np.ones(len(candidates)) / len(candidates)

        for i, (text, _) in enumerate(candidates):
            child = MCTSNode(node.state + "\n" + text, node, text, priors[i])
            node.children.append(child)

        return max(node.children, key=lambda c: c.prior)

    def _evaluate_bon(self, node):
        """
        Best-of-N value function:
        1. Generate N complete solutions from this state
        2. Extract and normalize the answer from each
        3. Value = (count of majority answer) / (number of solutions with an extractable answer)

        This measures: "How likely is this state to lead to a consistent answer?"
        """
        from src.utils import extract_answer

        # Terminal nodes: check whether this step's own \boxed{} answer is extractable.
        # Use node.action, not node.state: the state also contains the system prompt's
        # literal \boxed{answer}, which extract_answer returns when the step writes
        # "boxed{" without a backslash.
        if node.is_terminal():
            ans = extract_answer(node.action)
            return 1.0 if ans else 0.5  # penalize malformed \boxed{}

        # Use cached value if available (avoid redundant LLM calls)
        if node._cached_value is not None:
            return node._cached_value

        n_rollouts = self.config["value_function"]["n_rollouts"]

        # Generate N complete solutions from this state
        completions = self.llm._generate(node.state, n_rollouts, 0.7, 512)

        # Extract and normalize answers
        answers = []
        for comp in completions:
            ans = extract_answer(comp)
            if ans:
                # Normalize answer (same rules as normalize_answer in src/best_of_n.py)
                ans_norm = str(ans).replace(",", "").replace(" ", "").strip()
                try:
                    val = float(ans_norm)
                    ans_norm = str(int(val)) if val == int(val) else str(val)
                except (ValueError, OverflowError):
                    pass
                answers.append(ans_norm)

        if not answers:
            node._cached_value = 0.0
            return 0.0

        # Value = majority ratio (how consistent are the answers?)
        counts = Counter(answers)
        majority_count = counts.most_common(1)[0][1]
        value = majority_count / len(answers)

        node._cached_value = value
        return value

    def _backpropagate(self, node, value):
        while node:
            node.visits += 1
            node.value_sum += value
            node = node.parent

    def _get_best_path(self, root):
        """Follow most-visited path, then greedy complete if needed."""
        node = root
        path_text = ""

        while node.children:
            node = max(node.children, key=lambda c: c.visits)
            path_text += "\n" + node.action

            if node.is_terminal():
                break

        if not node.is_terminal():
            completion = self.llm.greedy_complete(root.state + path_text)
            path_text += completion

        return path_text

Overwriting src/mcts.py
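`_select_child` scores children with a PUCT-style rule: mean value plus `c_puct * prior * sqrt(N_parent) / (1 + N_child)`. The toy tree below uses hypothetical visit counts, value sums, and priors (only `c_puct=1.5` comes from `configs/gsm8k_config.yaml`) to show the exploration term letting an unvisited, high-prior child overtake a visited child with a better mean value.

```python
import math
from dataclasses import dataclass


@dataclass
class Child:
    name: str
    value_sum: float
    visits: int
    prior: float

    @property
    def value(self) -> float:
        # Same epsilon trick as MCTSNode.value: an unvisited child scores 0
        return self.value_sum / (self.visits + 1e-6)


def puct_score(child: Child, parent_visits: int, c_puct: float) -> float:
    """Same formula as ucb_score inside MCTSSearch._select_child."""
    sqrt_n = math.sqrt(max(1, parent_visits))
    return child.value + c_puct * child.prior * sqrt_n / (1 + child.visits)


children = [
    Child("a", value_sum=2.4, visits=3, prior=0.30),  # mean value 0.8
    Child("b", value_sum=0.0, visits=0, prior=0.45),  # never visited
    Child("c", value_sum=0.6, visits=1, prior=0.25),  # mean value 0.6
]
parent_visits = 4
scores = {c.name: round(puct_score(c, parent_visits, c_puct=1.5), 3) for c in children}
print(scores)
print(max(children, key=lambda c: puct_score(c, parent_visits, 1.5)).name)  # b
```

As a child's visit count grows, its exploration bonus shrinks as `1 / (1 + N_child)` while `sqrt(N_parent)` grows more slowly, so selection gradually shifts toward the children with the best mean value.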


In [None]:
%%writefile src/best_of_n.py
from collections import Counter
from src.utils import extract_answer

def normalize_answer(s):
    if s is None: return None
    s = str(s).replace(",", "").replace(" ", "").strip()
    try:
        v = float(s)
        return str(int(v)) if v == int(v) else str(v)
    except (ValueError, OverflowError): return s

class BestOfNSearch:
    def __init__(self, engine, config):
        self.llm, self.config = engine, config

    def search(self, prompt, n=1):
        if n == 1: return self.llm.greedy_complete(prompt)
        completions = self.llm._generate(prompt, n, 0.7, 512)
        if not completions: return self.llm.greedy_complete(prompt)

        ans_to_comp = {}
        all_ans = []
        for c in completions:
            a = normalize_answer(extract_answer(c))
            if a:
                all_ans.append(a)
                if a not in ans_to_comp: ans_to_comp[a] = c

        if all_ans:
            best = Counter(all_ans).most_common(1)[0][0]
            return ans_to_comp.get(best, completions[0])
        return completions[0]

Overwriting src/best_of_n.py
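Both `BestOfNSearch.search` (for N > 1) and `MCTSSearch._evaluate_bon` reduce a batch of sampled completions to a majority vote over normalized answers, which is self-consistency voting rather than scoring with a reward model. Completions with no extractable answer are dropped before voting, so the MCTS value is the majority count divided by the number of *parseable* answers, not by N. The sketch reproduces that arithmetic on a hypothetical list of already-extracted answers (`None` marks a completion without `\boxed{}` or `####`); the `vote` helper is an illustrative name, not part of the package.

```python
from collections import Counter
from typing import List, Optional, Tuple


def normalize_answer(s: Optional[str]) -> Optional[str]:
    """Same normalization as src/best_of_n.py: '1,000.0' -> '1000'."""
    if s is None:
        return None
    s = str(s).replace(",", "").replace(" ", "").strip()
    try:
        v = float(s)
        return str(int(v)) if v == int(v) else str(v)
    except (ValueError, OverflowError):
        return s


def vote(extracted: List[Optional[str]]) -> Tuple[Optional[str], float]:
    """Return (majority answer, majority count / number of parseable answers)."""
    answers = [a for a in (normalize_answer(x) for x in extracted) if a]
    if not answers:
        return None, 0.0
    best, count = Counter(answers).most_common(1)[0]
    return best, count / len(answers)


# Five rollouts (n_rollouts: 5), one of which produced no extractable answer
print(vote(["18", "18.0", None, "20", "18"]))
```

Here three of the four parseable answers agree, so the value is 0.75 rather than 3/5. On ties, `Counter.most_common` keeps first-seen order, so `BestOfNSearch` picks whichever tied answer was sampled first.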


## 4. Create Configuration

In [None]:
%%writefile configs/gsm8k_config.yaml
experiment:
  name: "mcts_vs_bon_vllm"
  seed: 42
  num_samples: 50
  output_dir: "outputs"
  log_dir: "logs"

model:
  model_id: "Qwen/Qwen2.5-Math-1.5B-Instruct"
  dtype: "float16"
  gpu_memory_utilization: 0.85
  max_model_len: 2048

mcts:
  simulations: [1, 5, 10]
  c_puct: 1.5              # More exploration than MCTSSearch's default of 1.25
  n_expand: 3
  value_function:
    n_rollouts: 5          # Increased from 3 for more stable consensus

generation:
  temperature: 0.7
  stop_tokens: ["\n"]

Overwriting configs/gsm8k_config.yaml


## 5. Initialize

In [None]:
import sys, os, json
sys.path.insert(0, os.getcwd())

from src.config_parser import load_config
from src.dataset import GSM8KDataset
from src.llm_engine import LLMEngine
from src.mcts import MCTSSearch
from src.best_of_n import BestOfNSearch
from src.utils import setup_logging, seed_everything, extract_answer, is_correct

config = load_config("configs/gsm8k_config.yaml")
logger = setup_logging(config['experiment']['name'], config['experiment']['log_dir'])
seed_everything(config['experiment']['seed'])

print(f"Model: {config['model']['model_id']}")
print(f"Samples: {config['experiment']['num_samples']}")
print(f"Budgets: {config['mcts']['simulations']}")

Model: Qwen/Qwen2.5-Math-1.5B-Instruct
Samples: 50
Budgets: [1, 5, 10]


In [None]:
print("Loading dataset...")
dataset = GSM8KDataset(config['experiment'])
print(f"Dataset size: {len(dataset)}")

print("\nInitializing vLLM...")
engine = LLMEngine(config)
print("Ready!")

Loading dataset...
Dataset size: 50

Initializing vLLM...
INFO 12-11 00:29:02 [utils.py:253] non-default args: {'trust_remote_code': True, 'dtype': 'float16', 'seed': None, 'max_model_len': 2048, 'gpu_memory_utilization': 0.85, 'disable_log_stats': True, 'model': 'Qwen/Qwen2.5-Math-1.5B-Instruct'}


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


INFO 12-11 00:29:03 [model.py:637] Resolved architecture: Qwen2ForCausalLM
INFO 12-11 00:29:03 [model.py:1750] Using max model len 2048
INFO 12-11 00:29:03 [scheduler.py:228] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 12-11 00:30:09 [llm.py:343] Supported tasks: ['generate']
Ready!


## 6. Run Best-of-N

In [None]:
from tqdm.notebook import tqdm

budgets = config['mcts']['simulations']
best_of_n = BestOfNSearch(engine, config['mcts'])
bon_results = []

for budget in budgets:
    print(f"\n--- Best-of-N: N={budget} ---")
    correct = 0
    for i in tqdm(range(len(dataset))):
        data = dataset[i]
        out = best_of_n.search(data['prompt'], n=budget)
        pred = extract_answer(out)
        truth = extract_answer(data['ground_truth'])
        c = is_correct(pred, truth)
        if c: correct += 1
        bon_results.append({"method": "Best-of-N", "budget": budget, "id": i, "correct": c})
    print(f"Accuracy: {correct/len(dataset):.2%}")


--- Best-of-N: N=1 ---
Accuracy: 80.00%

--- Best-of-N: N=5 ---
Accuracy: 86.00%

--- Best-of-N: N=10 ---
Accuracy: 86.00%


## 7. Run MCTS

In [None]:
mcts = MCTSSearch(engine, config['mcts'])
mcts_results = []

for budget in budgets:
    print(f"\n--- MCTS: N={budget} ---")
    correct = 0
    for i in tqdm(range(len(dataset))):
        data = dataset[i]
        out = mcts.search(data['prompt'], simulations=budget)
        pred = extract_answer(out)
        truth = extract_answer(data['ground_truth'])
        c = is_correct(pred, truth)
        if c: correct += 1
        mcts_results.append({"method": "MCTS", "budget": budget, "id": i, "correct": c})
    print(f"Accuracy: {correct/len(dataset):.2%}")


--- MCTS: N=1 ---
Accuracy: 82.00%

--- MCTS: N=5 ---
Accuracy: 80.00%

--- MCTS: N=10 ---
Accuracy: 76.00%
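The budget N does not mean matched compute across the two methods. Reading the code, Best-of-N samples N completions per problem (a single greedy completion at N = 1), while each MCTS simulation that expands a non-terminal node samples `n_expand` one-line steps and then `n_rollouts` full completions for the value estimate, and the search ends with one greedy completion when the most-visited path has no `\boxed{}`. The count below is a rough upper bound under the assumption that every simulation expands and evaluates a fresh non-terminal node; terminal nodes make the real number lower. It counts sampled sequences, not tokens, and the step expansions are much shorter than the full rollouts.

```python
def mcts_sequences_upper_bound(simulations: int, n_expand: int, n_rollouts: int) -> int:
    """Upper bound on sampled sequences per problem for MCTSSearch.search.

    Assumes every simulation expands a new non-terminal node (n_expand one-line
    steps) and rolls it out (n_rollouts full completions), plus one final
    greedy completion in _get_best_path.
    """
    return simulations * (n_expand + n_rollouts) + 1


for n in [1, 5, 10]:  # budgets from configs/gsm8k_config.yaml
    bon = n  # Best-of-N: n sampled completions (one greedy completion when n == 1)
    mcts = mcts_sequences_upper_bound(n, n_expand=3, n_rollouts=5)
    print(f"N={n}: Best-of-N={bon} sequences, MCTS<={mcts} sequences")
```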


## 8. Save & Visualize

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

all_results = bon_results + mcts_results
with open("outputs/results.jsonl", "w") as f:
    for r in all_results: f.write(json.dumps(r) + "\n")

df = pd.DataFrame(all_results)
summary = df.groupby(["method", "budget"])["correct"].mean().reset_index()
summary["accuracy"] = summary["correct"] * 100

print("\nResults:")
print(summary.pivot(index="budget", columns="method", values="accuracy"))

plt.figure(figsize=(10, 6))
for method, color in [("Best-of-N", "#3498db"), ("MCTS", "#e74c3c")]:
    d = summary[summary["method"] == method]
    plt.plot(d["budget"], d["accuracy"], marker="o", label=method, color=color, linewidth=2.5, markersize=10)

plt.title("MCTS vs Best-of-N (vLLM with Logprobs)", fontsize=14, fontweight='bold')
plt.xlabel("Budget N (Best-of-N: samples, MCTS: simulations)")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.xticks(budgets)
plt.savefig("outputs/comparison.png", dpi=300)
plt.show()
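With `num_samples: 50`, each problem moves accuracy by 2 percentage points, so the differences in the table come down to a handful of problems. A Wilson score interval is one quick way to gauge that uncertainty; this sketch assumes problems are independent and uses z = 1.96 for roughly 95% coverage. The counts plugged in are the ones implied by the accuracies printed above (86% and 76% of 50 problems).

```python
import math


def wilson_interval(correct: int, total: int, z: float = 1.96):
    """Wilson score interval for a binomial proportion."""
    if total == 0:
        return (0.0, 0.0)
    p = correct / total
    denom = 1 + z**2 / total
    center = (p + z**2 / (2 * total)) / denom
    half = z * math.sqrt(p * (1 - p) / total + z**2 / (4 * total**2)) / denom
    return (center - half, center + half)


# Correct counts implied by the printed accuracies on 50 problems
for label, correct in [("Best-of-N N=5", 43), ("MCTS N=10", 38)]:
    lo, hi = wilson_interval(correct, 50)
    print(f"{label}: {correct / 50:.0%} (95% CI {lo:.0%} to {hi:.0%})")
```

The intervals for these two runs overlap. Overlapping intervals are not a formal test; because both methods were scored on the same problems, a paired comparison on the per-problem `correct` flags in `outputs/results.jsonl` would be the more direct check.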

## 9. Download Results

In [None]:
from google.colab import files
import shutil
shutil.make_archive('results_vllm', 'zip', 'outputs')
files.download('results_vllm.zip')