evaluation tracker implementation (EleutherAI#1766)
* evaluation tracker implementation

* OVModelForCausalLM test fix

* typo fix

* moved methods args

* multiple args in one flag

* loggers moved to dedicated dir

* improved filename sanitization
KonradSzafer authored and notrichardren committed May 31, 2024
1 parent 5d2aa6b commit 0fe6d8c
Showing 9 changed files with 528 additions and 174 deletions.
13 changes: 13 additions & 0 deletions README.md
@@ -301,10 +301,23 @@ lm_eval --model hf \

We support wildcards in task names, for example you can run all of the machine-translated lambada tasks via `--task lambada_openai_mt_*`.

## Saving Results

To save evaluation results, provide an `--output_path`. We also support logging model responses with the `--log_samples` flag for post-hoc analysis.

Additionally, you can provide a directory with `--use_cache` to cache the results of prior runs, avoiding repeated execution of the same (model, task) pairs when re-scoring.

To push results and samples to the Hugging Face Hub, first ensure that an access token with write access is set in the `HF_TOKEN` environment variable. Then, use the `--hf_hub_log_args` flag to specify the organization, repository name, repository visibility, and whether to push results and samples to the Hub. For example:

```bash
lm_eval --model hf \
--model_args pretrained=model-name-or-path,autogptq=model.safetensors,gptq_use_triton=True \
--tasks hellaswag \
--log_samples \
--output_path results \
--hf_hub_log_args hub_results_org=EleutherAI,hub_repo_name=lm-eval-results,push_results_to_hub=True,push_samples_to_hub=True,public_repo=False
```

For a full list of supported arguments, check out the [interface](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md) guide in our documentation!

## Visualizing Results
94 changes: 39 additions & 55 deletions lm_eval/__main__.py
@@ -2,31 +2,16 @@
import json
import logging
import os
import re
import sys
from argparse import Namespace
from functools import partial
from pathlib import Path
from typing import Union

import numpy as np

from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.logging_utils import WandbLogger
from lm_eval.logging import EvaluationTracker, WandbLogger
from lm_eval.tasks import TaskManager
from lm_eval.utils import make_table, simple_parse_args_string


DEFAULT_RESULTS_FILE = "results.json"


def _handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32):
return int(o)
elif isinstance(o, set):
return list(o)
else:
return str(o)
from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string


def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
@@ -203,6 +188,12 @@ def setup_parser() -> argparse.ArgumentParser:
default="",
help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval",
)
parser.add_argument(
"--hf_hub_log_args",
type=str,
default="",
help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
)
parser.add_argument(
"--predict_only",
"-x",
@@ -228,7 +219,6 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
)

return parser
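
The value of `--hf_hub_log_args` is later parsed by `simple_parse_args_string` in `cli_evaluate`. As a rough, stand-alone illustration of that kind of comma-separated `key=value` parsing (a sketch, not the harness's actual implementation), consider:

```python
# Illustrative only: minimal parsing of a comma-separated key=value flag string,
# approximating what simple_parse_args_string does with --hf_hub_log_args.
def parse_kv_string(arg_string: str) -> dict:
    kwargs = {}
    for pair in filter(None, arg_string.split(",")):
        key, value = pair.split("=", 1)
        # crude literal conversion so boolean flags arrive as real bools
        kwargs[key] = value == "True" if value in ("True", "False") else value
    return kwargs


print(parse_kv_string("hub_results_org=EleutherAI,push_results_to_hub=True"))
# {'hub_results_org': 'EleutherAI', 'push_results_to_hub': True}
```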


@@ -251,6 +241,15 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
eval_logger.info(f"Verbosity set to {args.verbosity}")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# update the evaluation tracker args with the output path and the HF token
args.hf_hub_log_args = f"output_path={args.output_path},token={os.environ.get('HF_TOKEN')},{args.hf_hub_log_args}"
evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
evaluation_tracker.general_config_tracker.log_experiment_args(
model_source=args.model,
model_args=args.model_args,
)

if args.predict_only:
args.log_samples = True
if (args.log_samples or args.predict_only) and not args.output_path:
@@ -262,6 +261,19 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
eval_logger.info(f"Including path: {args.include_path}")
task_manager = TaskManager(args.verbosity, include_path=args.include_path)

evaluation_tracker_args = Namespace(**evaluation_tracker_args)
if (
getattr(evaluation_tracker_args, "push_results_to_hub", False)
or getattr(evaluation_tracker_args, "push_samples_to_hub", False)
) and not getattr(evaluation_tracker_args, "hub_results_org", ""):
raise ValueError(
"If push_results_to_hub or push_samples_to_hub is set, hub_results_org must be specified."
)
if getattr(evaluation_tracker_args, "push_samples_to_hub", False) and not args.log_samples:
eval_logger.warning(
"Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
)

if args.limit:
eval_logger.warning(
" --limit SHOULD ONLY BE USED FOR TESTING."
@@ -306,24 +318,6 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
)

if args.output_path:
path = Path(args.output_path)
# check if file or 'dir/results.json' exists
if path.is_file():
raise FileExistsError(f"File already exists at {path}")
output_path_file = path.joinpath(DEFAULT_RESULTS_FILE)
if output_path_file.is_file():
eval_logger.warning(
f"File {output_path_file} already exists. Results will be overwritten."
)
# if path json then get parent dir
elif path.suffix in (".json", ".jsonl"):
output_path_file = path
path.parent.mkdir(parents=True, exist_ok=True)
path = path.parent
else:
path.mkdir(parents=True, exist_ok=True)

# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
if args.trust_remote_code:
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code)
@@ -365,7 +359,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.log_samples:
samples = results.pop("samples")
dumped = json.dumps(
results, indent=2, default=_handle_non_serializable, ensure_ascii=False
results, indent=2, default=handle_non_serializable, ensure_ascii=False
)
if args.show_config:
print(dumped)
@@ -382,23 +376,13 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
except Exception as e:
eval_logger.info(f"Logging to Weights and Biases failed due to {e}")

if args.output_path:
output_path_file.open("w", encoding="utf-8").write(dumped)

if args.log_samples:
for task_name, config in results["configs"].items():
output_name = "{}_{}".format(
re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", args.model_args),
task_name,
)
filename = path.joinpath(f"{output_name}.jsonl")
samples_dumped = json.dumps(
samples[task_name],
indent=2,
default=_handle_non_serializable,
ensure_ascii=False,
)
filename.write_text(samples_dumped, encoding="utf-8")
evaluation_tracker.save_results_aggregated(results=results, samples=samples)

if args.log_samples:
for task_name, config in results["configs"].items():
evaluation_tracker.save_results_samples(
task_name=task_name, samples=samples[task_name]
)

print(
f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
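Taken together, the changes above replace the manual `results.json`/`.jsonl` writing with calls into `EvaluationTracker`. A hedged sketch of the resulting flow, using only the names visible in this diff (exact constructor defaults and required arguments may differ):

```python
# Sketch based on the API surface shown in this commit; not a verified, complete example.
from lm_eval.logging import EvaluationTracker

tracker = EvaluationTracker(
    output_path="results",      # directory for aggregated results and sample files
    hub_results_org="my-org",   # hypothetical Hub organization
    push_results_to_hub=False,  # set True (with HF_TOKEN available) to push to the Hub
)
tracker.general_config_tracker.log_experiment_args(
    model_source="hf",
    model_args="pretrained=gpt2",
)

# After evaluation, results/samples come back from the evaluator:
# tracker.save_results_aggregated(results=results, samples=samples)
# for task_name in results["configs"]:
#     tracker.save_results_samples(task_name=task_name, samples=samples[task_name])
```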
55 changes: 43 additions & 12 deletions lm_eval/evaluator.py
@@ -1,4 +1,5 @@
import itertools
import json
import logging
import random
import time
@@ -20,9 +21,15 @@
print_writeout,
run_task_tests,
)
from lm_eval.logging_utils import add_env_info, get_git_commit_hash
from lm_eval.logging.utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string
from lm_eval.utils import (
eval_logger,
handle_non_serializable,
hash_string,
positional_deprecated,
simple_parse_args_string,
)


if TYPE_CHECKING:
@@ -272,16 +279,24 @@ def simple_evaluate(
results["config"] = {
"model": model_name,
"model_args": model_args,
"batch_size": batch_size,
"batch_sizes": (
list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
),
"device": device,
"use_cache": use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"gen_kwargs": gen_kwargs,
}
# add more detailed model info if available
if isinstance(lm, lm_eval.models.huggingface.HFLM):
results["config"].update(lm.get_model_info())
# add info about execution
results["config"].update(
{
"batch_size": batch_size,
"batch_sizes": (
list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
),
"device": device,
"use_cache": use_cache,
"limit": limit,
"bootstrap_iters": bootstrap_iters,
"gen_kwargs": gen_kwargs,
}
)
results["git_hash"] = get_git_commit_hash()
results["date"] = start_date
add_env_info(results) # additional environment info to results
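
For illustration, after these updates the merged `results["config"]` might look roughly like the dict below; the keys are the ones listed above, the values are placeholders, and `get_model_info()` may contribute additional model metadata for HF models:

```python
# Hypothetical example of the merged config dict; all values are placeholders.
results_config = {
    "model": "hf",
    "model_args": "pretrained=gpt2",
    # ...fields from lm.get_model_info(), when the model is an HFLM instance...
    "batch_size": 8,
    "batch_sizes": [],
    "device": "cuda:0",
    "use_cache": None,
    "limit": None,
    "bootstrap_iters": 100000,
    "gen_kwargs": None,
}
```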
@@ -349,7 +364,6 @@ def evaluate(
eval_logger.debug(
f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
)

if write_out:
print_writeout(task)
# aggregate Instances by LM method requested to get output.
@@ -435,6 +449,16 @@ def evaluate(
"filtered_resps": [
req.filtered_resps[filter_key] for req in requests
],
"doc_hash": hash_string(
json.dumps(
requests[0].doc,
indent=2,
default=handle_non_serializable,
ensure_ascii=False,
)
),
"prompt_hash": hash_string(requests[0].arguments[0]),
"target_hash": hash_string(str(target)),
}
example.update(metrics)
task_output.logged_samples.append(example)
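
`hash_string` is imported from `lm_eval.utils`; a minimal equivalent, shown only to clarify what these fields contain (an assumption about the helper's behavior, not a copy of it), is a stable content digest:

```python
import hashlib


def hash_string(s: str) -> str:
    # Deterministic digest, so identical docs/prompts/targets map to the same
    # hash across runs (illustrative; the harness's own helper may differ).
    return hashlib.sha256(s.encode("utf-8")).hexdigest()
```

Recording `doc_hash`, `prompt_hash`, and `target_hash` with each logged sample gives every sample a compact fingerprint, which is useful when comparing logged samples across runs.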
@@ -565,6 +589,13 @@ def evaluate(
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())),
"n-samples": {
task_output.task_name: {
"original": len(task_output.task.eval_docs),
"effective": min(limit, len(task_output.task.eval_docs)),
}
for task_output in eval_tasks
},
}
if log_samples:
results_dict["samples"] = dict(samples)
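As an example of the new `n-samples` field, evaluating a task with 1,000 evaluation documents under `--limit 100` would be reported roughly as (illustrative task name and counts):

```python
# Illustrative shape of the new "n-samples" entry in the results dict.
results_dict["n-samples"] = {
    "my_task": {
        "original": 1000,  # size of the task's evaluation set
        "effective": 100,  # after applying --limit
    }
}
```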
2 changes: 2 additions & 0 deletions lm_eval/logging/__init__.py
@@ -0,0 +1,2 @@
from .evaluation_tracker import EvaluationTracker
from .wandb_logger import WandbLogger
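
The new package `__init__` re-exports both loggers, so callers can import them from a single location, matching the updated import in `__main__.py`:

```python
# Both loggers now live under the lm_eval.logging package
# (WandbLogger previously came from lm_eval.logging_utils).
from lm_eval.logging import EvaluationTracker, WandbLogger
```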