@@ -104,12 +104,21 @@ def main():
             raise ValueError("Please specify dtype: --fp16 or --bf16")
     else:
         dtype = "float32"
-    quantization_config = dict(
-        weight_quantize_algo=model_args.weight_quantize_algo,
-        qlora_weight_blocksize=model_args.qlora_weight_blocksize,
-        qlora_weight_double_quant=model_args.qlora_weight_double_quant,
-        qlora_weight_double_quant_block_size=model_args.qlora_weight_double_quant_block_size,
-    )
+
+    if hasattr(model_args, "qlora_weight_blocksize"):
+        quantization_config = dict(
+            weight_quantize_algo=model_args.weight_quantize_algo,
+            qlora_weight_blocksize=model_args.qlora_weight_blocksize,
+            qlora_weight_double_quant=model_args.qlora_weight_double_quant,
+            qlora_weight_double_quant_block_size=model_args.qlora_weight_double_quant_block_size,
+        )
+    else:
+        quantization_config = dict(
+            weight_quantize_algo=model_args.weight_quantize_algo,
+            weight_blocksize=model_args.weight_blocksize,
+            weight_double_quant=model_args.weight_double_quant,
+            weight_double_quant_block_size=model_args.weight_double_quant_block_size,
+        )
 
 
     model_config = AutoConfig.from_pretrained(
         model_args.model_name_or_path,
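The branch added above appears to keep the script compatible with two generations of the model-argument schema: a newer one exposing `qlora_weight_*` fields and an older one using unprefixed `weight_*` names, with `hasattr` used to detect which set is present. A minimal standalone sketch of that fallback pattern, using hypothetical dataclasses in place of the real `ModelArguments` (only the field names are taken from the diff; everything else is illustrative):

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the two generations of ModelArguments;
# field names mirror the diff, default values are illustrative.
@dataclass
class NewModelArguments:
    weight_quantize_algo: str = "nf4"
    qlora_weight_blocksize: int = 64
    qlora_weight_double_quant: bool = False
    qlora_weight_double_quant_block_size: int = 256


@dataclass
class OldModelArguments:
    weight_quantize_algo: str = "nf4"
    weight_blocksize: int = 64
    weight_double_quant: bool = False
    weight_double_quant_block_size: int = 256


def build_quantization_config(model_args):
    """Build the config dict using whichever argument names model_args provides."""
    if hasattr(model_args, "qlora_weight_blocksize"):
        # Newer argument set: qlora_-prefixed field names.
        return dict(
            weight_quantize_algo=model_args.weight_quantize_algo,
            qlora_weight_blocksize=model_args.qlora_weight_blocksize,
            qlora_weight_double_quant=model_args.qlora_weight_double_quant,
            qlora_weight_double_quant_block_size=model_args.qlora_weight_double_quant_block_size,
        )
    # Older argument set: unprefixed field names.
    return dict(
        weight_quantize_algo=model_args.weight_quantize_algo,
        weight_blocksize=model_args.weight_blocksize,
        weight_double_quant=model_args.weight_double_quant,
        weight_double_quant_block_size=model_args.weight_double_quant_block_size,
    )


print(build_quantization_config(NewModelArguments()))
print(build_quantization_config(OldModelArguments()))
```

Probing with `hasattr` keeps the script working against both argument schemas without requiring a version check, at the cost of silently taking the old path if the new field is ever renamed again.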