18 changes: 7 additions & 11 deletions tests/e2e/e2e_utils.py
@@ -10,15 +10,12 @@
from tests.testing_utils import process_dataset


@log_time
def _load_model_and_processor(
model: str,
model_class: str,
):
def load_model(model: str, model_class: str, device_map: str | None = None):
pretrained_model_class = getattr(transformers, model_class)
loaded_model = pretrained_model_class.from_pretrained(model, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(model)
return loaded_model, processor
loaded_model = pretrained_model_class.from_pretrained(
model, torch_dtype="auto", device_map=device_map
)
return loaded_model


@log_time
@@ -41,9 +38,8 @@ def run_oneshot_for_e2e_testing(
# Load model.
oneshot_kwargs = {}

loaded_model, processor = _load_model_and_processor(
model=model, model_class=model_class
)
loaded_model = load_model(model=model, model_class=model_class)
processor = AutoProcessor.from_pretrained(model)

if dataset_id:
ds = load_dataset(dataset_id, name=dataset_config, split=dataset_split)
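For reference, a minimal usage sketch of the reworked helper, assuming the module path shown in this diff (tests.e2e.e2e_utils.load_model); the checkpoint name and model class below are illustrative placeholders, not values from this PR:

# Usage sketch only; the checkpoint and model class are placeholders.
from transformers import AutoProcessor

from tests.e2e.e2e_utils import load_model

# load_model resolves the class name via getattr(transformers, model_class)
# and forwards device_map to from_pretrained (None keeps the default placement).
loaded_model = load_model(
    model="Qwen/Qwen2.5-0.5B-Instruct",   # hypothetical checkpoint
    model_class="AutoModelForCausalLM",
    device_map="cuda:0",
)

# The processor is no longer bundled with the helper; load it at the call site.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")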
65 changes: 33 additions & 32 deletions tests/lmeval/test_lmeval.py
@@ -13,7 +13,7 @@
from pydantic import BaseModel

from llmcompressor.core import active_session
from tests.e2e.e2e_utils import run_oneshot_for_e2e_testing
from tests.e2e.e2e_utils import load_model, run_oneshot_for_e2e_testing
from tests.test_timer.timer_utils import get_singleton_manager, log_time
from tests.testing_utils import requires_gpu

@@ -35,6 +35,10 @@ class LmEvalConfig(BaseModel):

try:
import lm_eval
import lm_eval.api.registry

# needed to populate model registry
import lm_eval.models # noqa

lm_eval_installed = True
except ImportError:
@@ -120,7 +124,7 @@ def test_lm_eval(self, test_data_file: str):

# Always evaluate base model for recovery testing
logger.info("================= Evaluating BASE model ======================")
self.base_results = self._eval_base_model()
base_results = self._eval_base_model()

if not self.save_dir:
self.save_dir = self.model.split("/")[1] + f"-{self.scheme}"
@@ -145,22 +149,41 @@ def test_lm_eval(self, test_data_file: str):
self._handle_recipe()

logger.info("================= Running LM Eval on COMPRESSED model ==========")
self._run_lm_eval()
compressed_results = self._eval_compressed_model()

# Always use recovery testing
self._validate_recovery(base_results, compressed_results)

# If absolute metrics provided, show warnings (not failures)
if self.lmeval.metrics:
self._check_absolute_warnings(compressed_results)

self.tear_down()

@log_time
def _eval_base_model(self):
def _eval_base_model(self) -> dict:
"""Evaluate the base (uncompressed) model."""
model_args = {**self.lmeval.model_args, "pretrained": self.model}
return self._eval_model(self.model)

@log_time
def _eval_compressed_model(self) -> dict:
"""Evaluate the compressed model."""
return self._eval_model(self.save_dir)

def _eval_model(self, model: str) -> dict:
# NOTE: pass in PreTrainedModel to avoid lm_eval's model-loading logic
# https://github.com/EleutherAI/lm-evaluation-harness/pull/3393
lm_eval_cls = lm_eval.api.registry.get_model(self.lmeval.model)

results = lm_eval.simple_evaluate(
model=self.lmeval.model,
model_args=model_args,
model=lm_eval_cls(
pretrained=load_model(model, self.model_class, device_map="cuda:0"),
batch_size=self.lmeval.batch_size,
**self.lmeval.model_args,
),
tasks=[self.lmeval.task],
num_fewshot=self.lmeval.num_fewshot,
limit=self.lmeval.limit,
device="cuda:0",
apply_chat_template=self.lmeval.apply_chat_template,
batch_size=self.lmeval.batch_size,
)
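The hunk above resolves the lm_eval wrapper class from the registry and hands it an already-loaded PreTrainedModel, so lm_eval's own model-loading path is skipped (per the linked lm-evaluation-harness PR #3393). A standalone sketch of that pattern, with placeholder checkpoint, task, and batch size rather than values from this PR:

# Standalone sketch of the registry pattern used above; checkpoint, task,
# and batch size are placeholder values, not taken from this PR.
import lm_eval
import lm_eval.api.registry
import lm_eval.models  # noqa: F401  (populates the model registry)
from transformers import AutoModelForCausalLM

# Resolve the lm_eval wrapper class (e.g. "hf" -> HFLM) from the registry.
lm_cls = lm_eval.api.registry.get_model("hf")

# Load the model ourselves and pass the PreTrainedModel instance to lm_eval,
# bypassing its internal model-loading logic.
hf_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", torch_dtype="auto", device_map="cuda:0"
)

results = lm_eval.simple_evaluate(
    model=lm_cls(pretrained=hf_model, batch_size=8),
    tasks=["gsm8k"],
    num_fewshot=5,
    limit=100,
)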
@@ -181,31 +204,9 @@ def _handle_recipe(self):
fp.write(recipe_yaml_str)
session.reset()

@log_time
def _run_lm_eval(self):
model_args = {"pretrained": self.save_dir}
model_args.update(self.lmeval.model_args)
results = lm_eval.simple_evaluate(
model=self.lmeval.model,
model_args=model_args,
tasks=[self.lmeval.task],
num_fewshot=self.lmeval.num_fewshot,
limit=self.lmeval.limit,
device="cuda:0",
apply_chat_template=self.lmeval.apply_chat_template,
batch_size=self.lmeval.batch_size,
)

# Always use recovery testing
self._validate_recovery(results)

# If absolute metrics provided, show warnings (not failures)
if self.lmeval.metrics:
self._check_absolute_warnings(results)

def _validate_recovery(self, compressed_results):
def _validate_recovery(self, base_results, compressed_results):
"""Validate using recovery testing - compare against base model."""
base_metrics = self.base_results["results"][self.lmeval.task]
base_metrics = base_results["results"][self.lmeval.task]
compressed_metrics = compressed_results["results"][self.lmeval.task]
higher_is_better_map = compressed_results.get("higher_is_better", {}).get(
self.lmeval.task, {}
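The body of _validate_recovery is truncated in this view. Purely as an illustration of recovery-style checking in general, not this PR's actual comparison, here is a sketch with an assumed recovery_threshold and a hypothetical helper name:

# Illustrative only; the threshold value, helper name, and comparison details
# are assumptions, not this PR's implementation.
def check_recovery(
    base_metrics: dict,
    compressed_metrics: dict,
    higher_is_better: dict,
    recovery_threshold: float = 0.95,
):
    for name, base_value in base_metrics.items():
        if name not in compressed_metrics or not isinstance(base_value, float):
            continue
        compressed_value = compressed_metrics[name]
        if higher_is_better.get(name, True):
            # The compressed model should retain at least recovery_threshold
            # of the base model's score on higher-is-better metrics.
            assert compressed_value >= recovery_threshold * base_value
        else:
            # For lower-is-better metrics, allow a bounded regression instead.
            assert compressed_value <= base_value / recovery_threshold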