diff --git a/llama_inference_offload.py b/llama_inference_offload.py
index e10d4a2..8acdcbf 100644
--- a/llama_inference_offload.py
+++ b/llama_inference_offload.py
@@ -213,9 +213,12 @@ def noop(*args, **kwargs):
     load_checkpoint_in_model(model, checkpoint, dtype='float16')
     model.seqlen = 2048
 
-    quant.make_quant_attn(model)
-    if fused_mlp:
-        quant.make_fused_mlp(model)
+    if eval:
+        quant.make_quant_attn(model)
+        quant.make_quant_norm(model)
+        if fused_mlp:
+            quant.make_fused_mlp(model)
+
     if warmup_autotune:
         quant.autotune_warmup_linear(model)
 