In [1]:
import asyncio
import os
import warnings

import numpy as np
import pandas as pd
from transformers.utils import logging

from nemo_skills.inference.model import get_tool_calling_model
from nemo_skills.inference.model.base import EndpointType
from nemo_skills.prompt.utils import get_prompt

# Keep cell output readable by turning off library chatter for the rest of the session.
# Setting TOKENIZERS_PARALLELISM to "false" presumably avoids the tokenizers fork warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# NOTE: this silences every Python warning, not only known-noisy ones.
warnings.filterwarnings("ignore")
# `logging` on this line is transformers.utils.logging (imported above), not the stdlib module.
logging.set_verbosity_error()

# List the GPUs; `cut` keeps only the text before the first "(" on each line.
!nvidia-smi -L | cut -d '(' -f 1



GPU 0: NVIDIA RTX A6000 
GPU 1: NVIDIA RTX A6000 


In [2]:
def build_table(bench, tokenizer):
    """Summarise benchmark runs as a table with one column per run.

    Args:
        bench: Mapping of run name -> record. Each record must provide
            ``"gens"`` (an iterable of generations; entries that are not
            ``str`` are skipped) and ``"total_time"`` (the elapsed time of the
            whole batch, presumably in seconds given the "Tok/sec" labels).
        tokenizer: Any object whose ``encode(text)`` returns a sized sequence;
            its length is used as the token count of a generation.
            NOTE(review): if ``encode`` adds special tokens they are counted
            too — confirm that is intended.

    Returns:
        A ``pandas.DataFrame`` with a ``"Metric"`` column followed by one
        column of pre-formatted strings per run.
    """
    data = {
        "Metric": [
            "Num Generations",
            "Total Generation Time",
            "Batch Throughput (Tok/sec)",
            "Avg Request Throughput (Tok/sec)",
        ]
    }
    for name, rec in bench.items():
        gens = [g for g in rec["gens"] if isinstance(g, str)]
        # Clamp so a zero (or negative) timing can never cause a division by zero.
        tt = max(rec["total_time"], 1e-9)
        toks = np.array([len(tokenizer.encode(g)) for g in gens], dtype=float)
        batch_tp = toks.sum() / tt
        if toks.size:
            # No per-request timing is recorded, so every request is divided by
            # the batch wall time.
            per_req_tp = toks / tt
            per_req_cell = f"{per_req_tp.mean():.1f} ± {per_req_tp.std():.1f}"
        else:
            # mean()/std() of an empty array would produce "nan ± nan" plus a
            # RuntimeWarning; report the absence of data explicitly instead.
            per_req_cell = "n/a"
        data[name] = [
            f"{len(gens)}",
            f"{tt:.1f}",
            f"{batch_tp:.1f}",
            per_req_cell,
        ]
    return pd.DataFrame(data)

In [3]:
# Load the "generic/math" prompt; run_generation later calls prompt_obj.fill({"problem": ...}) on it.
prompt_obj = get_prompt("generic/math")

In [4]:
# Model identifier handed to the client; the token-counting cell also uses it as the
# AutoTokenizer fallback name.
model = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16"

# Build a tool-calling client for a vLLM server with the PythonTool module registered.
# NOTE(review): no host/port is passed, so this presumably relies on library defaults and an
# already-running server — confirm.
llm = get_tool_calling_model(
    model,
    server_type="vllm",
    tool_modules=["nemo_skills.mcp.servers.python_tool::PythonTool"],
    additional_config={"sandbox": {}},  # empty sandbox config; presumably means "use defaults" — confirm
)

In [5]:
# Benchmark problem statement. A raw string keeps the LaTeX backslashes (e.g. \geq) literal;
# note the text starts and ends with a newline because of the triple-quote layout.
problem = r"""
The Fibonacci numbers are defined as follows: $F_0 = 0$, $F_1 = 1$, and $F_{n+1} = F_n + F_{n-1}$ for $n \geq 1$.
There are $N$ positive integers $n$ strictly less than $10^{101}$ such that $n^2 + (n+1)^2$ is a multiple of 5 but $F_{n-1}^2 + F_n^2$ is not.
How many prime factors does $N$ have, counted with multiplicity?
"""

In [6]:
# Benchmark records keyed by run name; build_table reads "gens" and "total_time" from each value.
# NOTE(review): none of the cells visible here populate this dict — confirm where it is filled.
bench = {}

In [None]:
async def _close_stream(stream, request_id):
    """Close a partially consumed response stream; ordinary close errors are reported, not raised."""
    if stream is None:
        return
    try:
        await stream.aclose()
    except Exception as close_error:
        # Cleanup is best-effort: a failing close must not replace the result
        # record the caller is waiting for.
        print(f"Generation {request_id}: closing the stream failed: {close_error}")


async def run_generation(request_id, prompt_obj, llm, problem, temperature=1.0, tokens_to_generate=5000):
    """Stream one generation and return whatever text was produced.

    Cancellation is deliberately absorbed: when the surrounding task is
    cancelled the text received so far is returned with status ``"partial"``
    instead of letting ``CancelledError`` propagate.

    Args:
        request_id: Identifier of this request; also used as ``random_seed``.
        prompt_obj: Object whose ``fill({"problem": ...})`` builds the prompt.
        llm: Client exposing ``generate_async(...)``; with ``stream=True`` the
            awaited result is iterated with ``async for`` and must yield
            mappings containing a ``"generation"`` text chunk.
        problem: Problem statement inserted into the prompt.
        temperature: Sampling temperature forwarded to the client.
        tokens_to_generate: Generation budget forwarded to the client.

    Returns:
        ``{"request_id", "status", "generation"}`` where status is
        ``"complete"``, ``"partial"`` (cancelled) or ``"error"``.
    """
    stream = None
    full_generation = ""
    try:
        stream = await llm.generate_async(
            prompt=prompt_obj.fill({"problem": problem}),
            endpoint_type=EndpointType.chat,
            stream=True,
            random_seed=request_id,
            temperature=temperature,
            tokens_to_generate=tokens_to_generate,
        )
        async for response in stream:
            full_generation += response["generation"]
        print("I'm done!", request_id)
        return {"request_id": request_id, "status": "complete", "generation": full_generation}
    except asyncio.CancelledError:
        await _close_stream(stream, request_id)
        print("I'm cancelled!", request_id)
        return {"request_id": request_id, "status": "partial", "generation": full_generation}
    except Exception as e:
        await _close_stream(stream, request_id)
        print(f"Generation {request_id} failed with error: {e}")
        return {"request_id": request_id, "status": "error", "generation": full_generation}


async def main_loop(prompt_obj, llm, problem, num_generations=20, cancel_after_done=1):
    """Race several generations and cancel the rest once enough have completed.

    One ``run_generation`` task is started per request id in
    ``range(num_generations)``. As soon as ``cancel_after_done`` of them have
    finished with status ``"complete"``, every task still running is cancelled
    exactly once and the loop keeps waiting so their partial text can be
    collected.

    Args:
        prompt_obj: Forwarded to ``run_generation``.
        llm: Forwarded to ``run_generation``.
        problem: Forwarded to ``run_generation``.
        num_generations: Number of concurrent generations to start.
        cancel_after_done: Number of complete generations that triggers
            cancellation of the remaining ones.

    Returns:
        ``(complete_generations, partial_generations)`` — two lists of
        generation strings. Results with any status other than ``"complete"``
        (including ``"error"``) land in the second list.
    """
    completed = 0
    complete_generations = []
    partial_generations = []
    cancel_sent = False

    pending = {
        asyncio.create_task(run_generation(i, prompt_obj, llm, problem))
        for i in range(num_generations)
    }

    try:
        while pending:
            done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
            for task in done:
                if task.cancelled():
                    # The task ended without returning a record (e.g. it was
                    # cancelled before it could start), so there is no text.
                    print("A generation task was cancelled before returning a result; skipping it.")
                    continue
                result = task.result()
                if result["status"] == "complete":
                    complete_generations.append(result["generation"])
                    completed += 1
                else:
                    partial_generations.append(result["generation"])

            # Cancel only once: a second cancel() can interrupt a task that is
            # still cleaning up after the first one, which would discard its
            # partial text and make task.result() raise CancelledError.
            if not cancel_sent and completed >= cancel_after_done and pending:
                for task in pending:
                    task.cancel()
                cancel_sent = True
    finally:
        # Non-empty only when leaving early (interrupt or unexpected error);
        # do not leave orphaned generations running in the background.
        for task in pending:
            task.cancel()

    return complete_generations, partial_generations


import logging

# Suppress tool-call cancellation noise from anyio TaskGroup shutdowns.
try:
    import anyio
except Exception:
    anyio = None


class _ToolCallCancelFilter(logging.Filter):
    def filter(self, record):
        exc = record.exc_info[1] if record.exc_info else None
        if exc is None:
            return True
        if isinstance(exc, asyncio.CancelledError):
            return False
        if isinstance(exc, BaseExceptionGroup):
            for sub_exc in exc.exceptions:
                if isinstance(sub_exc, asyncio.CancelledError):
                    return False
                if anyio and isinstance(sub_exc, (anyio.BrokenResourceError, anyio.ClosedResourceError)):
                    return False
            return True
        if anyio and isinstance(exc, (anyio.BrokenResourceError, anyio.ClosedResourceError)):
            return False
        return True


class _ProcessKillNoiseFilter(logging.Filter):
    def filter(self, record):
        if record.name != "mcp.os.posix.utilities":
            return True
        if not record.getMessage().startswith("Failed to kill process"):
            return True
        exc = record.exc_info[1] if record.exc_info else None
        if exc is None:
            return False
        if isinstance(exc, ProcessLookupError):
            return False
        if isinstance(exc, BaseExceptionGroup):
            return not any(isinstance(sub_exc, ProcessLookupError) for sub_exc in exc.exceptions)
        return True


# Attach the noise filters to the two loggers they target.
# NOTE: `logging` is the stdlib module here (re-imported in this cell). The first cell bound the
# same name to transformers.utils.logging, so what `logging` means depends on cell execution order.
tool_logger = logging.getLogger("nemo_skills.inference.model.tool_call")
tool_logger.addFilter(_ToolCallCancelFilter())

posix_logger = logging.getLogger("mcp.os.posix.utilities")
posix_logger.addFilter(_ProcessKillNoiseFilter())

# Set the root logger level to WARNING.
logging.getLogger().setLevel(logging.WARNING)

# Top-level await (notebook cell): start the generations and collect complete vs. partial texts.
complete_generations, partial_generations = await main_loop(prompt_obj, llm, problem)

I'm done! 3
I'm cancelled! 12
I'm cancelled! 4
I'm cancelled! 1
I'm cancelled! 8
I'm cancelled! 18
I'm cancelled! 11
I'm cancelled! 9
I'm cancelled! 0
I'm cancelled! 2
I'm cancelled! 6
I'm cancelled! 19
I'm cancelled! 5
I'm cancelled! 17
I'm cancelled! 10
I'm cancelled! 13
I'm cancelled! 14
I'm cancelled! 15
I'm cancelled! 16
I'm cancelled! 7


In [35]:
from transformers import AutoTokenizer

# Prefer the tokenizer attached to the client; otherwise load one by model name with AutoTokenizer.
# NOTE(review): the broad `except Exception` hides errors other than a missing attribute, and a
# `llm.tokenizer` that exists but is None would not trigger the fallback — confirm which is expected.
try:
    tokenizer = llm.tokenizer
except Exception:
    tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)


def _tok_count(text: str) -> int:
    """Return how many items the module-level ``tokenizer`` produces when encoding ``text``."""
    encoded = tokenizer.encode(text)
    return len(encoded)


# Print one header per group followed by the token count of each generation in it.
_report_sections = (
    ("Complete generations token counts:", "complete", complete_generations),
    ("\nPartial generations token counts:", "partial", partial_generations),
)
for _header, _label, _generations in _report_sections:
    print(_header)
    for _idx, _text in enumerate(_generations):
        print(f"  {_label}[{_idx}]: {_tok_count(_text)}")

Complete generations token counts:
  complete[0]: 5000

Partial generations token counts:
  partial[0]: 4410
  partial[1]: 4751
  partial[2]: 4666
  partial[3]: 3898
  partial[4]: 3988
  partial[5]: 2722
  partial[6]: 3824
  partial[7]: 4998
  partial[8]: 4997
  partial[9]: 4761
  partial[10]: 3769
  partial[11]: 4607
  partial[12]: 3611
  partial[13]: 4702
  partial[14]: 4998
  partial[15]: 4251
  partial[16]: 4496
  partial[17]: 4430
  partial[18]: 4536
