diff --git a/tests/e2e/e2e_utils.py b/tests/e2e/e2e_utils.py
index f052a7ca73..92a272737f 100644
--- a/tests/e2e/e2e_utils.py
+++ b/tests/e2e/e2e_utils.py
@@ -10,15 +10,12 @@
 from tests.testing_utils import process_dataset
 
 
-@log_time
-def _load_model_and_processor(
-    model: str,
-    model_class: str,
-):
+def load_model(model: str, model_class: str, device_map: str | None = None):
     pretrained_model_class = getattr(transformers, model_class)
-    loaded_model = pretrained_model_class.from_pretrained(model, torch_dtype="auto")
-    processor = AutoProcessor.from_pretrained(model)
-    return loaded_model, processor
+    loaded_model = pretrained_model_class.from_pretrained(
+        model, torch_dtype="auto", device_map=device_map
+    )
+    return loaded_model
 
 
 @log_time
@@ -41,9 +38,8 @@ def run_oneshot_for_e2e_testing(
 
     # Load model.
     oneshot_kwargs = {}
-    loaded_model, processor = _load_model_and_processor(
-        model=model, model_class=model_class
-    )
+    loaded_model = load_model(model=model, model_class=model_class)
+    processor = AutoProcessor.from_pretrained(model)
 
     if dataset_id:
         ds = load_dataset(dataset_id, name=dataset_config, split=dataset_split)
diff --git a/tests/lmeval/test_lmeval.py b/tests/lmeval/test_lmeval.py
index da99782c68..4e38d7643b 100644
--- a/tests/lmeval/test_lmeval.py
+++ b/tests/lmeval/test_lmeval.py
@@ -13,7 +13,7 @@
 from pydantic import BaseModel
 
 from llmcompressor.core import active_session
-from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing
+from tests.e2e.e2e_utils import load_model, run_oneshot_for_e2e_testing
 from tests.test_timer.timer_utils import get_singleton_manager, log_time
 from tests.testing_utils import requires_gpu
 
@@ -35,6 +35,10 @@ class LmEvalConfig(BaseModel):
 
 try:
     import lm_eval
+    import lm_eval.api.registry
+
+    # needed to populate model registry
+    import lm_eval.models  # noqa
 
     lm_eval_installed = True
 except ImportError:
@@ -120,7 +124,7 @@ def test_lm_eval(self, test_data_file: str):
 
         # Always evaluate base model for recovery testing
        logger.info("================= Evaluating BASE model ======================")
-        self.base_results = self._eval_base_model()
+        base_results = self._eval_base_model()
 
         if not self.save_dir:
             self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
@@ -145,22 +149,41 @@ def test_lm_eval(self, test_data_file: str):
         self._handle_recipe()
 
         logger.info("================= Running LM Eval on COMPRESSED model ==========")
-        self._run_lm_eval()
+        compressed_results = self._eval_compressed_model()
+
+        # Always use recovery testing
+        self._validate_recovery(base_results, compressed_results)
+
+        # If absolute metrics provided, show warnings (not failures)
+        if self.lmeval.metrics:
+            self._check_absolute_warnings(compressed_results)
 
         self.tear_down()
 
     @log_time
-    def _eval_base_model(self):
+    def _eval_base_model(self) -> dict:
         """Evaluate the base (uncompressed) model."""
-        model_args = {**self.lmeval.model_args, "pretrained": self.model}
+        return self._eval_model(self.model)
+
+    @log_time
+    def _eval_compressed_model(self) -> dict:
+        """Evaluate the compressed model."""
+        return self._eval_model(self.save_dir)
+
+    def _eval_model(self, model: str) -> dict:
+        # NOTE: pass in PreTrainedModel to avoid lm_eval's model-loading logic
+        # https://github.com/EleutherAI/lm-evaluation-harness/pull/3393
+        lm_eval_cls = lm_eval.api.registry.get_model(self.lmeval.model)
 
         results = lm_eval.simple_evaluate(
-            model=self.lmeval.model,
-            model_args=model_args,
+            model=lm_eval_cls(
+                pretrained=load_model(model, self.model_class, device_map="cuda:0"),
+                batch_size=self.lmeval.batch_size,
+                **self.lmeval.model_args,
+            ),
             tasks=[self.lmeval.task],
             num_fewshot=self.lmeval.num_fewshot,
             limit=self.lmeval.limit,
-            device="cuda:0",
             apply_chat_template=self.lmeval.apply_chat_template,
             batch_size=self.lmeval.batch_size,
         )
@@ -181,31 +204,9 @@ def _handle_recipe(self):
                 fp.write(recipe_yaml_str)
         session.reset()
 
-    @log_time
-    def _run_lm_eval(self):
-        model_args = {"pretrained": self.save_dir}
-        model_args.update(self.lmeval.model_args)
-        results = lm_eval.simple_evaluate(
-            model=self.lmeval.model,
-            model_args=model_args,
-            tasks=[self.lmeval.task],
-            num_fewshot=self.lmeval.num_fewshot,
-            limit=self.lmeval.limit,
-            device="cuda:0",
-            apply_chat_template=self.lmeval.apply_chat_template,
-            batch_size=self.lmeval.batch_size,
-        )
-
-        # Always use recovery testing
-        self._validate_recovery(results)
-
-        # If absolute metrics provided, show warnings (not failures)
-        if self.lmeval.metrics:
-            self._check_absolute_warnings(results)
-
-    def _validate_recovery(self, compressed_results):
+    def _validate_recovery(self, base_results, compressed_results):
         """Validate using recovery testing - compare against base model."""
-        base_metrics = self.base_results["results"][self.lmeval.task]
+        base_metrics = base_results["results"][self.lmeval.task]
         compressed_metrics = compressed_results["results"][self.lmeval.task]
         higher_is_better_map = compressed_results.get("higher_is_better", {}).get(
             self.lmeval.task, {}
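
For context on the new _eval_model path: rather than passing a model name plus model_args and letting lm_eval load the checkpoint itself, the test now resolves the registered wrapper class (e.g. "hf") and hands it an already-loaded PreTrainedModel. Below is a minimal standalone sketch of that pattern, assuming a recent lm_eval (0.4+) and a single CUDA device; the checkpoint name, model class, and task are placeholders, not values taken from this diff.

# Minimal sketch, not part of the diff above: placeholder checkpoint/task names,
# assumes lm_eval >= 0.4 and one CUDA device available.
import lm_eval
import lm_eval.api.registry

# needed to populate the model registry
import lm_eval.models  # noqa

from tests.e2e.e2e_utils import load_model

# Load the HF model up front so lm_eval skips its own model-loading logic.
hf_model = load_model(
    "meta-llama/Llama-3.2-1B-Instruct",  # placeholder checkpoint
    "AutoModelForCausalLM",
    device_map="cuda:0",
)

# Resolve the registered wrapper class ("hf" -> HFLM) and wrap the loaded model.
lm_cls = lm_eval.api.registry.get_model("hf")
results = lm_eval.simple_evaluate(
    model=lm_cls(pretrained=hf_model, batch_size=8),
    tasks=["gsm8k"],  # placeholder task
    num_fewshot=5,
    limit=100,
)
print(results["results"]["gsm8k"])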