From a4d67ba40ad0273f7d24a8131e3812ca706af0ac Mon Sep 17 00:00:00 2001
From: Marcelo Diaz
Date: Fri, 14 Mar 2025 09:29:38 -0400
Subject: [PATCH 1/4] Add load_in_16bit parameter to FastBaseModel.from_pretrained

- Add load_in_16bit parameter with default value of False
- Add validation to prevent conflicting loading options
- Add support for loading models in 16-bit precision (float16/bfloat16)
- Update error messages to include the new 16-bit option
---
 unsloth/models/vision.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py
index 24015f82fe..5b1b55d817 100644
--- a/unsloth/models/vision.py
+++ b/unsloth/models/vision.py
@@ -158,6 +158,7 @@ def from_pretrained(
         dtype = None,
         load_in_4bit = True,
         load_in_8bit = False,
+        load_in_16bit = False,
         full_finetuning = False,
         token = None,
         device_map = "sequential",
@@ -240,6 +241,11 @@ def from_pretrained(
                 break
         pass
 
+        # Check for conflicting loading options
+        loading_options = sum([load_in_4bit, load_in_8bit, load_in_16bit, full_finetuning])
+        if loading_options > 1:
+            raise RuntimeError("Unsloth: Can only use one of load_in_4bit, load_in_8bit, load_in_16bit, or full_finetuning!")
+
         bnb_config = None
         if full_finetuning and (load_in_4bit or load_in_8bit):
             print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.")
@@ -247,8 +253,6 @@ def from_pretrained(
             load_in_4bit = False
             load_in_8bit = False
         pass
 
-        if load_in_4bit and load_in_8bit:
-            raise RuntimeError("Unsloth: Can only load in 4bit or 8bit, not both!")
         if load_in_4bit:
             bnb_config = BitsAndBytesConfig(
                 load_in_4bit = True,
@@ -262,8 +266,11 @@ def from_pretrained(
                 load_in_8bit = True,
                 llm_int8_skip_modules = SKIP_QUANTIZATION_MODULES.copy(),
             )
-        elif not load_in_4bit and not load_in_8bit and not full_finetuning:
-            print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.")
+        elif load_in_16bit:
+            print("Unsloth: Loading model in 16-bit precision.")
+            # No bnb_config needed for 16-bit, we'll use torch_dtype directly
+        elif not load_in_4bit and not load_in_8bit and not load_in_16bit and not full_finetuning:
+            print("Unsloth: LoRA, QLoRA, 16-bit, and full finetuning all not selected. Switching to QLoRA.")
             load_in_4bit = True
             bnb_config = BitsAndBytesConfig(
                 load_in_4bit = True,
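The validation added in this patch counts how many of the mutually exclusive loading flags are truthy and refuses to continue when more than one is set. A minimal standalone sketch of that behaviour, with hypothetical flag values that are not part of the patch:

# Illustrative only: mirrors the conflict check PATCH 1/4 adds to vision.py.
load_in_4bit, load_in_8bit, load_in_16bit, full_finetuning = True, False, True, False  # hypothetical inputs

# sum() counts each True flag as 1, so more than one enabled option trips the check.
loading_options = sum([load_in_4bit, load_in_8bit, load_in_16bit, full_finetuning])
try:
    if loading_options > 1:
        raise RuntimeError("Unsloth: Can only use one of load_in_4bit, load_in_8bit, load_in_16bit, or full_finetuning!")
except RuntimeError as error:
    print(error)  # fires here because load_in_4bit and load_in_16bit are both set

Because load_in_4bit defaults to True, a caller who wants 16-bit loading has to disable 4-bit explicitly, otherwise this check raises.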
Switching to QLoRA.") load_in_4bit = True bnb_config = BitsAndBytesConfig( load_in_4bit = True, From bf3ca8e91d98bb9ff9de2e6638e17c0e1c135a10 Mon Sep 17 00:00:00 2001 From: Marcelo Diaz Date: Fri, 14 Mar 2025 09:31:48 -0400 Subject: [PATCH 2/4] Fix quantization_config assignment for 8-bit loading Update condition to assign quantization_config to kwargs when either load_in_4bit or load_in_8bit is True --- unsloth/models/vision.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index 5b1b55d817..c6c297c509 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -294,7 +294,8 @@ def from_pretrained( kwargs.pop("attn_implementation", None); # No need since we auto call it # Cannot be None, since HF now checks for the config - if load_in_4bit: kwargs["quantization_config"] = bnb_config + if load_in_4bit or load_in_8bit: + kwargs["quantization_config"] = bnb_config model = auto_model.from_pretrained( model_name, From 38d2409e3cb1df34f33bad5b53cf274bcbed40c5 Mon Sep 17 00:00:00 2001 From: Marcelo Diaz Date: Fri, 14 Mar 2025 09:54:29 -0400 Subject: [PATCH 3/4] Remove load_in_16bit from kwargs as it's not a valid parameter for transformers --- unsloth/models/vision.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/unsloth/models/vision.py b/unsloth/models/vision.py index c6c297c509..4fe9fe032c 100644 --- a/unsloth/models/vision.py +++ b/unsloth/models/vision.py @@ -296,6 +296,9 @@ def from_pretrained( # Cannot be None, since HF now checks for the config if load_in_4bit or load_in_8bit: kwargs["quantization_config"] = bnb_config + + # Remove load_in_16bit from kwargs as it's not a valid parameter for transformers + kwargs.pop("load_in_16bit", None) model = auto_model.from_pretrained( model_name, From 5cb2d6d958d4f08f64e61039a45ab5d49e84526b Mon Sep 17 00:00:00 2001 From: Marcelo Diaz Date: Fri, 14 Mar 2025 10:29:20 -0400 Subject: [PATCH 4/4] Add load_in_16bit parameter for 16-bit precision vision model loading --- unsloth/models/loader.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py index 44475780af..e27061d1a9 100644 --- a/unsloth/models/loader.py +++ b/unsloth/models/loader.py @@ -74,6 +74,7 @@ def from_pretrained( dtype = None, load_in_4bit = True, load_in_8bit = False, + load_in_16bit = False, full_finetuning = False, token = None, device_map = "sequential", @@ -93,13 +94,14 @@ def from_pretrained( disable_log_stats = True, *args, **kwargs, ): - if load_in_8bit or full_finetuning: + if load_in_8bit or load_in_16bit or full_finetuning: return FastModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, dtype = dtype, load_in_4bit = load_in_4bit, load_in_8bit = load_in_8bit, + load_in_16bit = load_in_16bit, full_finetuning = full_finetuning, token = token, device_map = device_map, @@ -299,6 +301,7 @@ def from_pretrained( dtype = dtype, load_in_4bit = load_in_4bit, load_in_8bit = load_in_8bit, + load_in_16bit = load_in_16bit, full_finetuning = full_finetuning, token = token, device_map = device_map, @@ -445,6 +448,7 @@ def from_pretrained( dtype = None, load_in_4bit = True, load_in_8bit = False, + load_in_16bit = False, # Load model in 16-bit precision (float16/bfloat16) full_finetuning = False, token = None, device_map = "sequential", @@ -467,22 +471,23 @@ def from_pretrained( if use_gradient_checkpointing == "unsloth": patch_unsloth_smart_gradient_checkpointing(dtype = dtype) + # 
From 5cb2d6d958d4f08f64e61039a45ab5d49e84526b Mon Sep 17 00:00:00 2001
From: Marcelo Diaz
Date: Fri, 14 Mar 2025 10:29:20 -0400
Subject: [PATCH 4/4] Add load_in_16bit parameter for 16-bit precision vision model loading

---
 unsloth/models/loader.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py
index 44475780af..e27061d1a9 100644
--- a/unsloth/models/loader.py
+++ b/unsloth/models/loader.py
@@ -74,6 +74,7 @@ def from_pretrained(
         dtype = None,
         load_in_4bit = True,
         load_in_8bit = False,
+        load_in_16bit = False,
         full_finetuning = False,
         token = None,
         device_map = "sequential",
@@ -93,13 +94,14 @@ def from_pretrained(
         disable_log_stats = True,
         *args, **kwargs,
     ):
-        if load_in_8bit or full_finetuning:
+        if load_in_8bit or load_in_16bit or full_finetuning:
             return FastModel.from_pretrained(
                 model_name = model_name,
                 max_seq_length = max_seq_length,
                 dtype = dtype,
                 load_in_4bit = load_in_4bit,
                 load_in_8bit = load_in_8bit,
+                load_in_16bit = load_in_16bit,
                 full_finetuning = full_finetuning,
                 token = token,
                 device_map = device_map,
@@ -299,6 +301,7 @@ def from_pretrained(
             dtype = dtype,
             load_in_4bit = load_in_4bit,
             load_in_8bit = load_in_8bit,
+            load_in_16bit = load_in_16bit,
             full_finetuning = full_finetuning,
             token = token,
             device_map = device_map,
@@ -445,6 +448,7 @@ def from_pretrained(
         dtype = None,
         load_in_4bit = True,
         load_in_8bit = False,
+        load_in_16bit = False, # Load model in 16-bit precision (float16/bfloat16)
         full_finetuning = False,
         token = None,
         device_map = "sequential",
@@ -467,22 +471,23 @@ def from_pretrained(
         if use_gradient_checkpointing == "unsloth":
             patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
 
+        # Check for conflicting loading options
+        loading_options = sum([load_in_4bit, load_in_8bit, load_in_16bit, full_finetuning])
+        if loading_options > 1:
+            raise RuntimeError("Unsloth: Can only use one of load_in_4bit, load_in_8bit, load_in_16bit, or full_finetuning!")
+
         if full_finetuning and (load_in_4bit or load_in_8bit):
             print("Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA.")
             load_in_4bit = False
             load_in_8bit = False
         pass
 
-        if load_in_4bit and load_in_8bit:
-            raise RuntimeError(
-                "Unsloth: Can only load in 4bit or 8bit, not both!\n"\
-                "Also, we by default set `load_in_4bit = True`.\n"\
-                "If you want 8bit finetuning, set both `load_in_4bit = False` and `load_in_8bit = True`"
-            )
         if load_in_4bit: pass
         elif load_in_8bit: pass
-        elif not load_in_4bit and not load_in_8bit and not full_finetuning:
-            print("Unsloth: LoRA, QLoRA and full finetuning all not selected. Switching to QLoRA.")
+        elif load_in_16bit:
+            print("Unsloth: Loading model in 16-bit precision.")
+        elif not load_in_4bit and not load_in_8bit and not load_in_16bit and not full_finetuning:
+            print("Unsloth: LoRA, QLoRA, 16-bit, and full finetuning all not selected. Switching to QLoRA.")
             load_in_4bit = True
         pass
 
@@ -668,6 +673,7 @@ def from_pretrained(
             dtype = _get_dtype(dtype),
            load_in_4bit = load_in_4bit,
             load_in_8bit = load_in_8bit,
+            load_in_16bit = load_in_16bit,
             full_finetuning = full_finetuning,
             token = token,
             device_map = device_map,
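With all four patches applied, 16-bit loading can be requested through the public loader the same way as the existing 4-bit and 8-bit options; the loader routes such calls to FastModel.from_pretrained, which skips BitsAndBytesConfig and keeps the weights in float16/bfloat16. A hedged usage sketch, assuming unsloth exports FastModel and that from_pretrained returns a (model, tokenizer) pair; the model name and dtype are illustrative:

import torch
from unsloth import FastModel  # assumption: FastModel is exported at package level

# Illustrative call: load_in_16bit comes from this patch series, the model name is a placeholder.
model, tokenizer = FastModel.from_pretrained(
    model_name     = "unsloth/Llama-3.2-11B-Vision-Instruct",  # placeholder checkpoint
    max_seq_length = 2048,
    dtype          = torch.bfloat16,  # or torch.float16, or None to auto-detect
    load_in_4bit   = False,           # must be disabled explicitly, since it defaults to True
    load_in_8bit   = False,
    load_in_16bit  = True,            # new flag: 16-bit weights, no bitsandbytes quantization
)

Note that exactly one of load_in_4bit, load_in_8bit, load_in_16bit, or full_finetuning may be enabled; the validation added in PATCH 1/4 and PATCH 4/4 raises a RuntimeError otherwise.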