diff --git a/torchao/_models/_eval.py b/torchao/_models/_eval.py
index de7f010035..29582e9b3d 100644
--- a/torchao/_models/_eval.py
+++ b/torchao/_models/_eval.py
@@ -59,11 +59,12 @@ def _model_call(self, inps):
         with torch.device(self._device):
             if hasattr(self._model, "setup_caches"):
                 self._model.setup_caches(self.batch_size, max_seq_length)
-            logits = self._model(*input)
-        from transformers.modeling_outputs import CausalLMOutputWithPast
-
-        if isinstance(logits, CausalLMOutputWithPast):
-            logits = logits.logits
+            output = self._model(*input)
+        # HF causal-LM output dataclasses (CausalLMOutputWithPast,
+        # Gemma3CausalLMOutputWithPast, and any other ModelOutput subclass)
+        # all expose `.logits`. Duck-type instead of importing model-specific
+        # output classes: an unconditional gemma3 import raises ImportError on
+        # transformers versions that do not ship that model, breaking eval for
+        # every model. Plain-tensor outputs pass through unchanged.
+        logits = getattr(output, "logits", output)
         return logits
 
     def run_eval(self, tasks, limit):