This notebook demonstrates how to generate Python code by running inference with the [CodeLlama-7b-Instruct-hf](https://huggingface.co/meta-llama/CodeLlama-7b-Instruct-hf) model, and then validate the generated code by executing it in a sandbox on [Tracto.ai](https://tracto.ai/).

In [1]:
import yt.wrapper as yt
from yt import type_info
import uuid

In [2]:
# Configure the environment used by this notebook: a per-run working directory
# on the distributed file system and the pool trees (VM presets) to run jobs in.
import uuid
import yt.wrapper as yt

username = yt.get_user_name()
if yt.exists(f"//sys/users/{username}/@user_info/home_path"):
    # Prepare a working directory under the user's home on the distributed file system.
    user_info = yt.get(f"//sys/users/{username}/@user_info")
    homedir = user_info["home_path"]
    # Find available VM presets; pool trees are selected by their name suffix.
    available_pool_trees = user_info["available_pool_trees"]
    cpu_pool_trees = [tree for tree in available_pool_trees if tree.endswith("cpu")] or ["default"]
    h100_pool_trees = [tree for tree in available_pool_trees if tree.endswith("h100")]
    h100_8_pool_trees = [tree for tree in available_pool_trees if tree.endswith("h100-8")]
    workdir = f"{homedir}/tmp/demo_workdir/{uuid.uuid4().hex}"
else:
    # No home directory: fall back to fixed pool-tree names and a shared //tmp location.
    cpu_pool_trees = ["default"]
    h100_pool_trees = ["gpu_h100"]
    h100_8_pool_trees = ["gpu_h100"]
    workdir = f"//tmp/examples/{uuid.uuid4().hex}"

yt.create("map_node", workdir, recursive=True, ignore_existing=True)
print("Current working directory:", workdir)

Current working directory: //home/equal_amethyst_vulture/tmp/demo_workdir/ff6d2b9b3c4046cd9b7097b27d256a51


Upload the dataset from Hugging Face to a YTsaurus table: task descriptions, test inputs, and expected outputs.

In [4]:
from datasets import load_dataset

# Keep at most this many test cases (private first, then generated) per problem.
MAX_SAMPLES = 50

dataset = load_dataset("deepmind/code_contests")

dataset_path = f"{workdir}/dataset"


def _test_cases(record, field):
    """Return up to MAX_SAMPLES values of `field` ("input" or "output") for one problem.

    Private tests come first, followed by generated tests.
    """
    private_and_generated = record["private_tests"][field] + record["generated_tests"][field]
    return list(private_and_generated)[:MAX_SAMPLES]


# Lazily build one table row per problem of the train split.
table_data = (
    {
        "index": problem_id,
        "description": problem["description"],
        "input": _test_cases(problem, "input"),
        "output": _test_cases(problem, "output"),
    }
    for problem_id, problem in enumerate(dataset["train"])
)

schema = yt.schema.TableSchema(strict=True)
for column_name, column_type in (
    ("index", type_info.Int32),
    ("description", type_info.String),
    ("input", type_info.List[type_info.String]),
    ("output", type_info.List[type_info.String]),
):
    schema.add_column(column_name, column_type)

yt.create("table", dataset_path, force=True, attributes={"schema": schema.to_yson_type()})
# Allow rows of up to 128 MiB when writing.
yt.write_table(dataset_path, table_data, table_writer={"max_row_weight": 128 * 1024 * 1024})

README.md:   0%|          | 0.00/13.0k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/39 [00:00<?, ?it/s]

dataset_infos.json:   0%|          | 0.00/4.52k [00:00<?, ?B/s]

Downloading data:   0%|          | 0/39 [00:00<?, ?files/s]

(…)-00000-of-00039-e991a271dbfa9925.parquet:   0%|          | 0.00/180M [00:00<?, ?B/s]

(…)-00001-of-00039-e092fe56fda18715.parquet:   0%|          | 0.00/209M [00:00<?, ?B/s]

(…)-00002-of-00039-9cea23812e920e41.parquet:   0%|          | 0.00/227M [00:00<?, ?B/s]

(…)-00003-of-00039-e3822fccad6e083a.parquet:   0%|          | 0.00/181M [00:00<?, ?B/s]

(…)-00004-of-00039-cefe355b4667b27e.parquet:   0%|          | 0.00/195M [00:00<?, ?B/s]

(…)-00005-of-00039-b7580d2d846c2136.parquet:   0%|          | 0.00/174M [00:00<?, ?B/s]

(…)-00006-of-00039-65184bb9f7d61fde.parquet:   0%|          | 0.00/186M [00:00<?, ?B/s]

(…)-00007-of-00039-05785de21e8b8429.parquet:   0%|          | 0.00/172M [00:00<?, ?B/s]

(…)-00008-of-00039-7246e6b7423b404f.parquet:   0%|          | 0.00/200M [00:00<?, ?B/s]

(…)-00009-of-00039-b8c920f6629b57b2.parquet:   0%|          | 0.00/205M [00:00<?, ?B/s]

(…)-00010-of-00039-6de28ba20654f69b.parquet:   0%|          | 0.00/178M [00:00<?, ?B/s]

(…)-00011-of-00039-5de236be5188959d.parquet:   0%|          | 0.00/164M [00:00<?, ?B/s]

(…)-00012-of-00039-da9476a39a1bdbb7.parquet:   0%|          | 0.00/200M [00:00<?, ?B/s]

(…)-00013-of-00039-30b8c3829ee3b962.parquet:   0%|          | 0.00/197M [00:00<?, ?B/s]

(…)-00014-of-00039-dc3ebb07a3cba8e4.parquet:   0%|          | 0.00/211M [00:00<?, ?B/s]

(…)-00015-of-00039-19ccd7331d695677.parquet:   0%|          | 0.00/179M [00:00<?, ?B/s]

(…)-00016-of-00039-bf38b0908b322307.parquet:   0%|          | 0.00/202M [00:00<?, ?B/s]

(…)-00017-of-00039-ae5533a2f822e6ef.parquet:   0%|          | 0.00/169M [00:00<?, ?B/s]

(…)-00018-of-00039-8c793837880f5507.parquet:   0%|          | 0.00/185M [00:00<?, ?B/s]

(…)-00019-of-00039-d688fad5ee604390.parquet:   0%|          | 0.00/191M [00:00<?, ?B/s]

(…)-00020-of-00039-5d59387098675b73.parquet:   0%|          | 0.00/211M [00:00<?, ?B/s]

(…)-00021-of-00039-b257bf03d6876780.parquet:   0%|          | 0.00/181M [00:00<?, ?B/s]

(…)-00022-of-00039-1cfd39fa43c1917c.parquet:   0%|          | 0.00/194M [00:00<?, ?B/s]

(…)-00023-of-00039-d078bcb55e45cbf0.parquet:   0%|          | 0.00/176M [00:00<?, ?B/s]

(…)-00024-of-00039-f4e3da0e5661e6d1.parquet:   0%|          | 0.00/181M [00:00<?, ?B/s]

(…)-00025-of-00039-3f6ebfbaba5f4c70.parquet:   0%|          | 0.00/206M [00:00<?, ?B/s]

(…)-00026-of-00039-7d4898300894cbbe.parquet:   0%|          | 0.00/189M [00:00<?, ?B/s]

(…)-00027-of-00039-f8196766547533a2.parquet:   0%|          | 0.00/217M [00:00<?, ?B/s]

(…)-00028-of-00039-79a302af3c924863.parquet:   0%|          | 0.00/179M [00:00<?, ?B/s]

(…)-00029-of-00039-2b6615897d038115.parquet:   0%|          | 0.00/198M [00:00<?, ?B/s]

(…)-00030-of-00039-4135cc54050afc22.parquet:   0%|          | 0.00/223M [00:00<?, ?B/s]

(…)-00031-of-00039-40309dd907c042b7.parquet:   0%|          | 0.00/181M [00:00<?, ?B/s]

(…)-00032-of-00039-7b7d2068a3d9c359.parquet:   0%|          | 0.00/186M [00:00<?, ?B/s]

(…)-00033-of-00039-53b0f749aacff9c1.parquet:   0%|          | 0.00/204M [00:00<?, ?B/s]

(…)-00034-of-00039-a36ff0bff7d2a76f.parquet:   0%|          | 0.00/188M [00:00<?, ?B/s]

(…)-00035-of-00039-d28f9be60314601f.parquet:   0%|          | 0.00/151M [00:00<?, ?B/s]

(…)-00036-of-00039-146e1a11c054aeab.parquet:   0%|          | 0.00/204M [00:00<?, ?B/s]

(…)-00037-of-00039-995207c374a4e6f2.parquet:   0%|          | 0.00/231M [00:00<?, ?B/s]

(…)-00038-of-00039-96a59dd6a98cd075.parquet:   0%|          | 0.00/204M [00:00<?, ?B/s]

(…)-00000-of-00001-9c49eeff30aacaa8.parquet:   0%|          | 0.00/63.1M [00:00<?, ?B/s]

(…)-00000-of-00001-5e672c5751f060d3.parquet:   0%|          | 0.00/51.8M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/13328 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/165 [00:00<?, ? examples/s]

Generating valid split:   0%|          | 0/117 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

In [5]:
import os

# Hugging Face token taken from the kernel's secrets; the inference jobs export it as HF_TOKEN.
hf_token = os.getenv("YT_SECURE_VAULT_HF_TOKEN", "")
assert hf_token, "set HF token in kernel's secrets to use llama"

Generate solutions with CodeLlama-7b-Instruct-hf.

In [7]:
from typing import Iterable
import logging
import sys
import random


# Number of conversations sent to vLLM in a single `llm.chat` call.
BATCH_SIZE = 150
yt.config["pickling"]["safe_stream_mode"] = False  # important to run vllm


@yt.aggregator
def bulk_inference(records: Iterable[dict]) -> Iterable[dict]:
    """Generate a Python solution for every problem with CodeLlama via vLLM.

    Decorated with ``yt.aggregator``, so a single call receives the whole input
    stream of a job and the model is loaded once per job. Records are sent to
    the model in batches of ``BATCH_SIZE`` conversations.

    Args:
        records: rows of the dataset table (``index``, ``description``,
            ``input``, ``output``).

    Yields:
        The same four columns plus ``code``: the model's answer with a Markdown
        code fence unwrapped (if present), dedented and stripped of surrounding
        whitespace.
    """
    import re
    import textwrap

    from vllm import LLM, SamplingParams

    os.environ["HF_TOKEN"] = hf_token

    # YT jobs have to write all logs to stderr, so route vLLM's logger there.
    vllm_logger = logging.getLogger("vllm")
    vllm_logger.handlers.clear()
    vllm_logger.addHandler(logging.StreamHandler(sys.stderr))

    llm = LLM(model="meta-llama/CodeLlama-7b-Instruct-hf", tensor_parallel_size=1, trust_remote_code=True)
    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.9,
        max_tokens=5000,
    )
    system_prompt = "You are an AI assistant that generates Python code. Always write Python code that reads input from stdin and writes output to stdout. Ensure that all indentation is correct. Your responses must contain only ready-to-run Python code, without any explanations, comments, or extra text."

    # Matches a fenced block such as "```python\n...\n```". The closing fence is
    # optional because generation can be cut off by `max_tokens`.
    fence_re = re.compile(r"```[^\n]*\n(.*?)(?:```|\Z)", re.DOTALL)

    def extract_code(text: str) -> str:
        """Return the program from a model answer, unwrapping a Markdown fence if present."""
        match = fence_re.search(text)
        if match:
            text = match.group(1)
        # Leading newlines or a uniformly indented answer would otherwise make
        # `python3 -c` fail with an IndentationError.
        return textwrap.dedent(text).strip()

    def generate(records_batch: list[dict]) -> Iterable[dict]:
        """Run one batched `llm.chat` call and yield one output row per input record."""
        conversations = [
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": r["description"]},
            ]
            for r in records_batch
        ]
        outputs = llm.chat(
            messages=conversations,
            sampling_params=sampling_params,
        )
        for r, output in zip(records_batch, outputs):
            yield {
                "index": r["index"],
                "description": r["description"],
                "input": r["input"],
                "output": r["output"],
                "code": extract_code(output.outputs[0].text),
            }

    batch = []
    for record in records:
        batch.append(record)
        if len(batch) >= BATCH_SIZE:
            yield from generate(batch)
            batch = []
    if batch:  # flush the final, partially filled batch
        yield from generate(batch)

In [8]:
inference_path = f"{workdir}/inference"

# Output schema: the dataset columns plus the generated `code`.
schema = yt.schema.TableSchema(strict=True)
for column_name, column_type in [
    ("index", type_info.Int32),
    ("description", type_info.String),
    ("input", type_info.List[type_info.String]),
    ("output", type_info.List[type_info.String]),
    ("code", type_info.String),
]:
    schema.add_column(column_name, column_type)

yt.create("table", inference_path, force=True, attributes={"schema": schema.to_yson_type()})

# One H100 GPU, 2 CPUs and 30 GiB of RAM per job; rows of up to 128 MiB.
inference_spec = {
    "pool_trees": h100_pool_trees,
    "mapper": {
        "gpu_limit": 1,
        "memory_limit": 30 * 1024 ** 3,
        "cpu_limit": 2,
    },
    "job_io": {"table_writer": {"max_row_weight": 128 * 1024 * 1024}},
}

yt.run_map(bulk_inference, dataset_path, inference_path, job_count=16, spec=inference_spec)

Let's run the generated code in parallel and validate the results. We limit the memory consumption of the tested code with `prlimit`.
In this example we do not isolate the LLM-generated code from the controlling system at the I/O level; we only limit its RAM consumption.

In [10]:
import subprocess
import sys


RUN_TIMEOUT = 10  # seconds allowed for a single run of a generated program
CODE_IO_LIMIT = 2000  # max characters of stdout/stderr stored per result row
MEM_LIMIT = 32212254720  # job memory limit in bytes (30 GiB); half of it caps each tested program


@yt.aggregator
def validate_code(records: Iterable[dict]) -> Iterable[dict]:
    """Execute every generated program on each of its test cases.

    Decorated with ``yt.aggregator``, so a single call consumes the whole input
    stream of a job. For every (input, expected output) pair of a record, the
    program is run as ``python3 -c <code>`` under ``prlimit`` (address space
    capped at half of ``MEM_LIMIT``) with a ``RUN_TIMEOUT``-second timeout.
    One row is yielded per run.

    Yields:
        The record's columns plus ``result_stdout``/``result_stderr`` (truncated
        to ``CODE_IO_LIMIT`` characters, or ``"<TIMEOUT>"``), ``exitcode``
        (-1 on timeout) and ``match_expected``.
    """
    # NOTE(review): rows do not record which test case produced them; the
    # whole `input`/`output` lists are copied into every row.

    def normalize(text: str) -> list:
        # Judge-style normalization: ignore trailing whitespace on each line and
        # trailing blank lines (e.g. a missing final newline is not a failure).
        lines = [line.rstrip() for line in text.splitlines()]
        while lines and not lines[-1]:
            lines.pop()
        return lines

    def result_row(record: dict, stdout: str, stderr: str, exitcode: int, match_expected: bool) -> dict:
        """Build one output row for a single run of `record["code"]`."""
        return {
            "index": record["index"],
            "description": record["description"],
            "input": record["input"],
            "output": record["output"],
            "code": record["code"],
            "result_stdout": stdout,
            "result_stderr": stderr,
            "exitcode": exitcode,
            "match_expected": match_expected,
        }

    for record in records:
        for inp, expected in zip(record["input"], record["output"]):
            print(f"Run code for {record['index']}", file=sys.stderr)
            try:
                process = subprocess.run(
                    ["prlimit", f"--as={int(MEM_LIMIT * 0.5)}", "python3", "-c", record["code"]],
                    input=inp,
                    text=True,
                    # Decode explicitly and replace invalid bytes: a program that
                    # prints non-UTF-8 data must not crash the whole job.
                    encoding="utf-8",
                    errors="replace",
                    capture_output=True,
                    timeout=RUN_TIMEOUT,
                )
            except subprocess.TimeoutExpired:
                print("Execution failed: timeout", file=sys.stderr)
                yield result_row(record, "<TIMEOUT>", "<TIMEOUT>", -1, False)
                continue
            if process.returncode:
                # Truncate to avoid flooding the job's stderr log.
                print(f"Execution failed: {process.stderr[:CODE_IO_LIMIT]}", file=sys.stderr)

            yield result_row(
                record,
                process.stdout[:CODE_IO_LIMIT],
                process.stderr[:CODE_IO_LIMIT],
                process.returncode,
                normalize(process.stdout) == normalize(expected),
            )

In [11]:
code_validation_path = f"{workdir}/code_validation"

# Output schema: the inference columns plus the outcome of one test run.
schema = yt.schema.TableSchema(strict=True)
schema.add_column("index", type_info.Int32)
schema.add_column("description", type_info.String)
schema.add_column("input", type_info.List[type_info.String])
schema.add_column("output", type_info.List[type_info.String])
schema.add_column("code", type_info.String)
schema.add_column("result_stdout", type_info.String)
schema.add_column("result_stderr", type_info.String)
# Int32 rather than Int8: exit statuses range over 0..255 (e.g. `sys.exit(200)`),
# which overflows a signed 8-bit column. Negative values (signals, -1 for
# timeouts) fit as well.
schema.add_column("exitcode", type_info.Int32)
schema.add_column("match_expected", type_info.Bool)

yt.create("table", code_validation_path, force=True, attributes={"schema": schema.to_yson_type()})

# No GPU is requested. Each job gets MEM_LIMIT bytes of memory and 2 CPUs.
yt.run_map(
    validate_code,
    inference_path,
    code_validation_path,
    job_count=256,
    spec={
        "mapper": {
            "memory_limit": MEM_LIMIT,
            "cpu_limit": 2,
        },
        "job_io": {
            "table_writer": {
                "max_row_weight": 128 * 1024 * 1024,
            },
        },
    },
)