From bc9e5516a13efa80975ae706890fc4a05291d154 Mon Sep 17 00:00:00 2001
From: Yelaman Abdullin
Date: Wed, 3 Apr 2024 23:13:53 +1100
Subject: [PATCH] fix GemmaModel_fast_forward_inference

---
 unsloth/models/gemma.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/unsloth/models/gemma.py b/unsloth/models/gemma.py
index 0be28da9..0259ab78 100644
--- a/unsloth/models/gemma.py
+++ b/unsloth/models/gemma.py
@@ -134,7 +134,7 @@ def GemmaModel_fast_forward_inference(
     position_ids,
     attention_mask = None,
 ):
-    out_weight = torch.empty_like(self.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda")
+    out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda")
     input_ids = input_ids[:,:self.max_seq_length]
     hidden_states = self.model.embed_tokens(input_ids)
     hidden_states = hidden_states.to(self.config.torch_dtype)