diff --git a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py
index c5b46b82a..0216734d1 100644
--- a/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/lorax_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -363,6 +363,7 @@ class FlashGPT2ForCausalLM(FlashGPT2PreTrainedModel):
     def __init__(self, config, weights):
         super().__init__(config)
         self.transformer = FlashGPT2Model(config, weights)
+        self.wte_t = self.transformer.wte.weight.T.contiguous()
 
     def forward(
         self,
@@ -393,6 +394,6 @@ def forward(
 
         # lm_head reuses the weights of the embedding layer
        # https://github.com/huggingface/transformers/issues/6291
-        logits = hidden_states @ self.transformer.wte.weight.T
+        logits = hidden_states @ self.wte_t
         logits = logits[:, :self.transformer.config.vocab_size]
         return logits
diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py
index 7ed60ec72..4c883c214 100644
--- a/server/lorax_server/utils/layers.py
+++ b/server/lorax_server/utils/layers.py
@@ -218,7 +218,7 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False):
     # https://huggingface.co/docs/peft/package_reference/tuners#peft.LoraConfig.fan_in_fan_out
     # Set to True if replacing a Conv1D layer with a Linear layer
     if fan_in_fan_out:
-        weight = weight.T
+        weight = weight.T.contiguous()
 
     if quantize is None:
         linear = FastLinear(weight, bias)
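
Note on the change (not part of the patch itself): in PyTorch, `Tensor.T` on a 2-D tensor returns a strided view with swapped strides, so `wte.weight.T` is non-contiguous. Recomputing that transpose inside every forward pass makes the lm_head matmul operate on a strided operand each call, which can be slower depending on the kernel. Caching `self.wte_t = self.transformer.wte.weight.T.contiguous()` once in `__init__` pays the copy cost a single time, at the price of holding an extra copy of the embedding matrix in memory. The sketch below is illustrative only; the shapes are invented and not taken from the real model config.

```python
# Illustrative sketch only: shapes are made up, not the model's real config.
import torch

hidden_states = torch.randn(512, 768)   # [num_tokens, hidden_size]
wte_weight = torch.randn(50304, 768)    # [padded_vocab_size, hidden_size]

# .T is a zero-copy view with swapped strides, so it is not contiguous.
assert not wte_weight.T.is_contiguous()

# Old path: derive the transposed operand on every call.
logits_old = hidden_states @ wte_weight.T

# New path: pay for one contiguous copy up front (done in __init__ in the
# patch), then reuse it on every forward pass.
wte_t = wte_weight.T.contiguous()
logits_new = hidden_states @ wte_t

# Both forms are equivalent up to kernel-dependent floating-point rounding.
torch.testing.assert_close(logits_old, logits_new)
```

The same reasoning presumably motivates the `get_linear` change: once a `fan_in_fan_out` weight has been transposed, making it contiguous means `FastLinear` receives a densely laid-out matrix rather than a strided view.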