diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py index 0be28da9..0259ab78 100644 --- a/unsloth/models/gemma.py +++ b/unsloth/models/gemma.py @@ -134,7 +134,7 @@ def GemmaModel_fast_forward_inference( position_ids, attention_mask = None, ): - out_weight = torch.empty_like(self.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda") + out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda") input_ids = input_ids[:,:self.max_seq_length] hidden_states = self.model.embed_tokens(input_ids) hidden_states = hidden_states.to(self.config.torch_dtype)