In [1]:
import asyncio
import contextlib
import os
import signal
import subprocess
import time
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import psutil
import requests
import transformers
from transformers.utils import logging

from nemo_skills.code_execution.sandbox import get_sandbox
from nemo_skills.inference.model import get_code_execution_model
from nemo_skills.prompt.utils import get_prompt

# Set TOKENIZERS_PARALLELISM=false before any tokenizer is created
# (presumably to avoid tokenizers' fork/parallelism warnings once subprocesses are spawned -- confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Silence all Python warnings and reduce transformers logging to errors only,
# keeping the notebook output limited to the benchmark results.
warnings.filterwarnings("ignore")
logging.set_verbosity_error()

# List the visible GPUs; `cut` strips everything from the first "(" on each line.
!nvidia-smi -L | cut -d '(' -f 1

GPU 0: NVIDIA H100 80GB HBM3 
GPU 1: NVIDIA H100 80GB HBM3 


In [2]:
# Model locations used throughout the notebook.
# MODEL_DIR_HF is passed to trtllm-serve as --tokenizer; the other three are the
# directories served (and benchmarked) one after another.
# NOTE(review): BASE_DIR already ends in "/", so every path below becomes ".//<name>".
# That is harmless on POSIX but shows up in the start-up log lines.
BASE_DIR = "./"
MODEL_DIR_HF = f"{BASE_DIR}/OpenMath-Nemotron-14B-kaggle"
MODEL_DIR_BF16 = f"{BASE_DIR}/OpenMath-Nemotron-14B-kaggle-bf16-trtllm"
MODEL_DIR_FP8 = f"{BASE_DIR}/OpenMath-Nemotron-14B-kaggle-fp8-trtllm"
MODEL_DIR_FP8_DRAFT = f"{BASE_DIR}/OpenMath-Nemotron-14B-kaggle-fp8-redrafter-trtllm"
# NOTE(review): `benchmark` looks unused in the visible cells (results are stored in
# `bench`); confirm nothing else reads it before removing.
benchmark = {}

In [3]:
def wait_for_server(host, port, timeout=300, interval=1):
    """Block until the HTTP server at ``host:port`` answers a probe request.

    The probe is a ``PUT`` to the server root. Any response whose status code
    is not 403 counts as "ready" (403 is treated as "not ready yet";
    presumably that is what the server returns while still starting -- confirm).

    Args:
        host: Host name or IP address of the server.
        port: TCP port of the server.
        timeout: Maximum number of seconds to wait overall.
        interval: Seconds to sleep between two probes.

    Returns:
        True once the server answered with a non-403 status.

    Raises:
        TimeoutError: If the server is not ready within ``timeout`` seconds,
            whether it refused connections or kept answering 403.
    """
    url = f"http://{host}:{port}"
    # Monotonic clock: immune to wall-clock adjustments during a long wait.
    deadline = time.monotonic() + timeout
    # Per-request cap so a half-open connection cannot hang the probe forever.
    request_timeout = 5
    while True:
        try:
            response = requests.put(url, timeout=request_timeout)
            if response.status_code != 403:
                return True
        except requests.RequestException:
            # Connection refused / reset / timed out: server is not up yet.
            pass
        # The deadline check and the sleep apply to BOTH failure modes.
        # Previously a persistent 403 skipped them and busy-looped forever.
        if time.monotonic() >= deadline:
            raise TimeoutError("Server did not respond within timeout period")
        time.sleep(interval)


def start_server(model_dir, port=5000):
    """Launch ``trtllm-serve`` for ``model_dir`` and wait until it responds.

    The server runs in its own session/process group so that ``kill_server``
    can signal the whole group. Its stdout and stderr are written to
    ``<model directory name>_server_logs.log`` in the current working directory.

    Args:
        model_dir: Directory of the model/engine to serve.
        port: Port the server binds on 127.0.0.1.

    Returns:
        The ``subprocess.Popen`` handle of the launched shell.

    Raises:
        TimeoutError: Propagated from ``wait_for_server`` if the server never
            becomes ready; the launched process group is terminated first.
    """
    host = "127.0.0.1"
    cmd = (
        f"trtllm-serve serve {model_dir} "
        f"    --tokenizer {MODEL_DIR_HF}"
        f"    --backend trt "
        f"    --tp_size 2 "
        f"    --kv_cache_free_gpu_memory_fraction 0.92 "
        f"    --max_batch_size 12 "
        f"    --host {host} "
        f"    --port {port}"
    )
    print(f"Starting server from {model_dir} at {host}:{port}")
    # Path(...).name also copes with a trailing slash, where split("/")[-1] yields "".
    model_name = Path(model_dir).name
    log_path = Path(f"{model_name}_server_logs.log").resolve()
    # The child inherits its own copy of the descriptor, so the parent's handle
    # can be closed right after the spawn instead of leaking for the kernel's lifetime.
    with open(log_path, "w", buffering=1) as log_file:
        proc = subprocess.Popen(
            cmd,
            shell=True,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            # Same effect as preexec_fn=os.setsid, but safe in a multi-threaded
            # process such as a Jupyter kernel.
            start_new_session=True,
        )
    print("Waiting for server to be ready (might take a while) ...")
    try:
        wait_for_server(host, port)
    except BaseException:
        # Timeout or user interrupt: do not leave an orphaned server holding the port.
        with contextlib.suppress(ProcessLookupError):
            os.killpg(proc.pid, signal.SIGTERM)
        raise
    print("Server ready!")
    return proc


def kill_server(proc, port=5000):
    """Stop a server started by ``start_server`` and free its port.

    First sends SIGTERM to the process group led by ``proc``. After a grace
    period, every process that still has an inet socket bound to ``port`` is
    force-killed.

    Args:
        proc: Handle of the launched server; only ``proc.pid`` is used.
        port: Local port the server was listening on.
    """
    try:
        os.killpg(proc.pid, signal.SIGTERM)
    except ProcessLookupError:
        # The group is already gone (server crashed or was killed earlier).
        pass
    time.sleep(10)

    # Use a separate loop name so the `proc` argument is not shadowed.
    for candidate in psutil.process_iter(["pid", "name"]):
        try:
            connections = candidate.connections(kind="inet")
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            # Process vanished or belongs to another user: skip it rather
            # than abort the whole sweep.
            continue
        # laddr can be an empty tuple for sockets that are not bound yet.
        if any(conn.laddr and conn.laddr.port == port for conn in connections):
            print(f"Killing process {candidate.info['name']} (PID: {candidate.info['pid']}) running on port {port}")
            try:
                os.kill(candidate.info["pid"], signal.SIGKILL)
            except ProcessLookupError:
                # Exited between the scan and the kill.
                pass

    time.sleep(10)
    print("Server closed.")

In [4]:
async def run_generation(request_id, prompt_obj, llm, problem):
    """Stream one completion for ``problem`` and return its concatenated text.

    Args:
        request_id: Used as the sampling seed of this request.
        prompt_obj: Prompt object providing ``fill`` and ``get_code_execution_args``.
        llm: Model client exposing an awaitable ``generate_async``.
        problem: Problem statement inserted into the prompt.

    Returns:
        The concatenation of the ``"generation"`` field of every streamed chunk.

    Raises:
        asyncio.CancelledError: Re-raised after a best-effort close of the stream.
    """
    stream = None
    try:
        filled_prompt = prompt_obj.fill({"problem": problem})
        extra_args = prompt_obj.get_code_execution_args()
        stream = await llm.generate_async(
            prompt=filled_prompt,
            stream=True,
            random_seed=request_id,
            temperature=0.7,
            tokens_to_generate=20000,
            **extra_args,
        )
        text = ""
        async for chunk in stream:
            text += chunk["generation"]
        return text
    except asyncio.CancelledError:
        # On cancellation, close the stream (async close preferred) so the
        # server side is told to stop; errors while closing are ignored.
        if stream is not None:
            with contextlib.suppress(Exception):
                if hasattr(stream, "aclose"):
                    await stream.aclose()
                elif hasattr(stream, "close"):
                    stream.close()
        raise


async def main_loop(prompt_obj, llm, problem):
    """Run 12 concurrent generations and stop once 10 of them have finished.

    Each generation is a ``run_generation`` task seeded with its index. As soon
    as ``cancel_after_done`` results have arrived, the remaining tasks are
    cancelled and awaited so that their streams are closed before returning.

    Args:
        prompt_obj: Prompt object forwarded to ``run_generation``.
        llm: Model client forwarded to ``run_generation``.
        problem: Problem statement forwarded to ``run_generation``.

    Returns:
        List of generation results in completion order. Cancelled or failed
        stragglers contribute nothing.

    Raises:
        Exception: Whatever a generation raised before the quota was reached;
            the remaining tasks are cancelled and drained before it propagates.
    """
    num_generations = 12
    cancel_after_done = 10
    tasks = [asyncio.create_task(run_generation(i, prompt_obj, llm, problem)) for i in range(num_generations)]

    all_generations = []
    drained = []
    try:
        # Consume results in completion order until the quota is reached.
        for fut in asyncio.as_completed(tasks):
            all_generations.append(await fut)
            if len(all_generations) >= cancel_after_done:
                break
    finally:
        # Runs on the normal path and when a generation failed, so no task is
        # ever left running against the server.
        pending = [t for t in tasks if not t.done()]
        for t in pending:
            t.cancel()
        if pending:
            drained = await asyncio.gather(*pending, return_exceptions=True)

    # CancelledError derives from BaseException (not Exception), so the filter
    # must test BaseException; otherwise the cancelled tasks' error objects
    # would be returned as if they were generations.
    all_generations.extend(r for r in drained if not isinstance(r, BaseException))

    return all_generations


def build_table(bench, tokenizer):
    """Summarise benchmark runs as a DataFrame with one column per run.

    Args:
        bench: Mapping of run name to a record with ``"gens"`` (list of
            generations; non-string entries are ignored) and ``"total_time"``.
        tokenizer: Object whose ``encode(text)`` returns a sized token sequence.

    Returns:
        A ``pandas.DataFrame`` with a ``"Metric"`` column followed by one
        column of formatted strings per run.
    """
    metric_labels = [
        "Num Generations",
        "Total Generation Time",
        "Batch Throughput (Tok/sec)",
        "Avg Request Throughput (Tok/sec)",
    ]
    table = {"Metric": metric_labels}

    for run_name, record in bench.items():
        texts = [item for item in record["gens"] if isinstance(item, str)]
        # Clamp the elapsed time so a zero duration cannot divide by zero.
        elapsed = max(record["total_time"], 1e-9)
        token_counts = np.array([len(tokenizer.encode(text)) for text in texts], dtype=float)
        total_rate = token_counts.sum() / elapsed
        # Every request is divided by the same overall wall-clock time.
        request_rates = token_counts / elapsed

        count_cell = f"{len(texts)}"
        time_cell = f"{elapsed:.1f}"
        batch_cell = f"{total_rate:.1f}"
        request_cell = f"{request_rates.mean():.1f} ± {request_rates.std():.1f}"
        table[run_name] = [count_cell, time_cell, batch_cell, request_cell]

    return pd.DataFrame(table)

In [5]:
# Launch the local code-execution sandbox server as a background process.
sandbox_cmd = "python -m nemo_skills.code_execution.local_sandbox.local_sandbox_server"
# Keep the handle (previously discarded) so the sandbox can be inspected or
# terminated later, e.g. with sandbox_process.terminate().
sandbox_process = subprocess.Popen(sandbox_cmd, shell=True)
# Short fixed grace period for the server to come up.
time.sleep(2)
# Surface an early exit (e.g. the port is already taken by a previous run).
# This only warns, so re-running the cell against a live sandbox still works.
if sandbox_process.poll() is not None:
    print(f"Warning: sandbox process exited early with code {sandbox_process.returncode}")

[worker unknown] 2025-08-29 09:23:53,871 INFO: Applied worker memory limit (RLIMIT_AS/RLIMIT_DATA): 21474836480 bytes


 * Serving Flask app 'local_sandbox_server'
 * Debug mode: off


In [6]:
# Sandbox client created with default settings; it is passed to every
# get_code_execution_model(...) call below.
sandbox = get_sandbox()

In [7]:
# Run 1 of 3: serve MODEL_DIR_FP8_DRAFT (the directory name suggests FP8 plus a
# ReDrafter draft head -- confirm). Blocks until the server answers.
server_process = start_server(MODEL_DIR_FP8_DRAFT)

Starting server from .//OpenMath-Nemotron-14B-kaggle-fp8-redrafter-trtllm at 127.0.0.1:5000
Waiting for server to be ready (might take a while) ...
Server ready!


In [8]:
# Sanity check: list the models exposed by the trtllm-serve OpenAI-style endpoint.
!curl -s http://127.0.0.1:5000/v1/models

{"object":"list","data":[{"id":"OpenMath-Nemotron-14B-kaggle-fp8-redrafter-trtllm","object":"model","created":1756459594,"owned_by":"tensorrt_llm"}]}

In [9]:
# Sanity check: ask the code-execution sandbox server for its current sessions.
!curl -s http://127.0.0.1:6000/sessions

{"backend":"ipython","sessions":{}}


[worker unknown] 2025-08-29 09:26:34,667 INFO: active_sessions=0


In [10]:
# Tokenizer that build_table uses to count the tokens of each generation.
# NOTE(review): loaded from MODEL_DIR_FP8, while the server is started with
# --tokenizer MODEL_DIR_HF; presumably both directories hold the same tokenizer -- confirm.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_DIR_FP8)

In [11]:
# Code-execution settings and prompt shared by all three benchmark runs.
code_execution = {"max_code_executions": 2, "sandbox_traceback_verbosity": "plain"}
prompt_obj = get_prompt("generic/math", tokenizer=MODEL_DIR_FP8, code_tags="openmath")

In [12]:
# Client for the server started above (MODEL_DIR_FP8_DRAFT).
# NOTE(review): `model` is MODEL_DIR_FP8 here although the running server hosts
# MODEL_DIR_FP8_DRAFT; the later runs pass the directory that matches their server.
# Confirm how `model` is used by the client before changing it.
llm = get_code_execution_model(
    server_type="trtllm", model=MODEL_DIR_FP8, sandbox=sandbox, code_execution=code_execution
)

In [13]:
# The single problem sent to every server configuration; raw string so the
# LaTeX backslashes are kept verbatim.
problem = r"""
The Fibonacci numbers are defined as follows: $F_0 = 0$, $F_1 = 1$, and $F_{n+1} = F_n + F_{n-1}$ for $n \geq 1$.
There are $N$ positive integers $n$ strictly less than $10^{101}$ such that $n^2 + (n+1)^2$ is a multiple of 5 but $F_{n-1}^2 + F_n^2$ is not.
How many prime factors does $N$ have, counted with multiplicity?
"""

In [14]:
# Results store: run name -> {"gens": list of generations, "total_time": seconds}.
# Filled by the timing cells below and consumed by build_table.
bench = {}

In [15]:
# Time the whole batch (wall clock) against the MODEL_DIR_FP8_DRAFT server.
t0 = time.perf_counter()
gens_fp8_draft = await main_loop(prompt_obj, llm, problem)
bench["fp8_draft"] = {
    "gens": gens_fp8_draft,
    # Elapsed seconds from just before main_loop started until it returned.
    "total_time": time.perf_counter() - t0,
}

In [16]:
# Stop the current server so the next one can bind the same default port.
kill_server(server_process)

Server closed.


In [17]:
# Run 2 of 3: serve MODEL_DIR_FP8 and recreate the client for it.
server_process = start_server(MODEL_DIR_FP8)
llm = get_code_execution_model(
    server_type="trtllm", model=MODEL_DIR_FP8, sandbox=sandbox, code_execution=code_execution
)

Starting server from .//OpenMath-Nemotron-14B-kaggle-fp8-trtllm at 127.0.0.1:5000
Waiting for server to be ready (might take a while) ...
Server ready!


In [18]:
# Time the whole batch (wall clock) against the MODEL_DIR_FP8 server.
t0 = time.perf_counter()
gens_fp8 = await main_loop(prompt_obj, llm, problem)
bench["fp8"] = {
    "gens": gens_fp8,
    # Elapsed seconds from just before main_loop started until it returned.
    "total_time": time.perf_counter() - t0,
}

In [19]:
# Stop the current server so the next one can bind the same default port.
kill_server(server_process)

Server closed.


In [20]:
# Run 3 of 3: serve MODEL_DIR_BF16 and recreate the client for it.
server_process = start_server(MODEL_DIR_BF16)
llm = get_code_execution_model(
    server_type="trtllm", model=MODEL_DIR_BF16, sandbox=sandbox, code_execution=code_execution
)

Starting server from .//OpenMath-Nemotron-14B-kaggle-bf16-trtllm at 127.0.0.1:5000
Waiting for server to be ready (might take a while) ...
Server ready!


In [21]:
# Time the whole batch (wall clock) against the MODEL_DIR_BF16 server.
t0 = time.perf_counter()
gens_bf16 = await main_loop(prompt_obj, llm, problem)
bench["bf16"] = {
    "gens": gens_bf16,
    # Elapsed seconds from just before main_loop started until it returned.
    "total_time": time.perf_counter() - t0,
}

In [22]:
# Render the side-by-side comparison of all recorded runs.
build_table(bench, tokenizer)

Unnamed: 0,Metric,fp8_draft,fp8,bf16
0,Num Generations,10,10,10
1,Total Generation Time,30.5,64.7,144.2
2,Batch Throughput (Tok/sec),1385.4,751.7,346.0
3,Avg Request Throughput (Tok/sec),138.5 ± 24.6,75.2 ± 11.0,34.6 ± 6.1
