# Using the Python API

Lighteval can be used from a custom Python script. To evaluate a model, you will need to set up an [`~logging.evaluation_tracker.EvaluationTracker`], [`~pipeline.PipelineParameters`], a model or a `model_config`, and a [`~pipeline.Pipeline`].

After that, simply run the pipeline and save the results.

```python
import lighteval
from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.models.vllm.vllm_model import VLLMModelConfig
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
from lighteval.utils.utils import EnvConfig
from lighteval.utils.imports import is_accelerate_available

if is_accelerate_available():
    from datetime import timedelta
    from accelerate import Accelerator, InitProcessGroupKwargs

    accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
else:
    accelerator = None


def main():
    # Tracks, saves, and optionally pushes the evaluation results
    evaluation_tracker = EvaluationTracker(
        output_dir="./results",
        save_details=True,
        push_to_hub=True,
        hub_results_org="your user name",
    )

    pipeline_params = PipelineParameters(
        launcher_type=ParallelismManager.ACCELERATE,
        env_config=EnvConfig(cache_dir="tmp/"),
        # Remove the 2 parameters below once your configuration is tested
        override_batch_size=1,
        max_samples=10,
    )

    model_config = VLLMModelConfig(
        model_name="HuggingFaceH4/zephyr-7b-beta",
        dtype="float16",
        use_chat_template=True,
    )

    # Task format: suite|task|num_few_shot|{0 or 1 to automatically reduce
    # the number of few-shot examples if the prompt is too long}
    task = "helm|mmlu|5|1"

    pipeline = Pipeline(
        tasks=task,
        pipeline_parameters=pipeline_params,
        evaluation_tracker=evaluation_tracker,
        model_config=model_config,
        custom_task_directory=None,  # if using a custom task
    )

    pipeline.evaluate()
    pipeline.save_and_push_results()
    pipeline.show_results()


if __name__ == "__main__":
    main()
```
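
Since the example uses `launcher_type=ParallelismManager.ACCELERATE`, you can run the script directly with `python your_script.py`, or distribute the evaluation across several GPUs with `accelerate launch your_script.py` (the file name here is only a placeholder for wherever you saved the snippet above).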