I am getting an enormous number of nan and inf values in the outputs of quantized models for sequence classification. This does not happen with non-quantized models, which never output nans regardless of how the sequence classification head is initialised, so I conclude it must come from the quantization of the classification head.
It is not due to the base model outputting nans either, as I tested the quantized version for text generation without any problem.
This behavior is observed for several quantized models.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf"
# model_id = "ISTA-DASLab/Mistral-7B-v0.1-AQLM-PV-2Bit-1x16-hf"
# model_id = "ISTA-DASLab/Meta-Llama-3-8B-AQLM-PV-1Bit-1x16"
# model_id = "ISTA-DASLab/Meta-Llama-3-70B-AQLM-PV-2Bit-1x16"

quantized_model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map=device,
    num_labels=2,
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
quantized_model.config.pad_token_id = tokenizer.eos_token_id

input_tokens = tokenizer(
    [
        "There is a problem somewhere",
        "This is not related to the num_labels argument",
        "There are nan or inf way to many times to be normal",
    ],
    padding='longest',
)

output = quantized_model(
    input_ids=torch.tensor(input_tokens['input_ids']).to(device),
    attention_mask=torch.tensor(input_tokens['attention_mask']).to(device),
)

output.logits
# >>> tensor([[nan, nan],
#             [nan, nan],
#             [nan, nan]], device='cuda:0', dtype=torch.bfloat16,
#            grad_fn=<IndexBackward0>)
For now, replacing the last layer with a non-quantized one solves the issue, but it is probably not an ideal solution.
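For reference, a minimal sketch of that workaround, assuming the classification head is exposed as the score attribute (the name used by the Llama/Mistral/Mixtral *ForSequenceClassification classes) and reusing quantized_model and device from the snippet above:

import torch

# Swap the (quantized) classification head for a freshly initialized,
# non-quantized linear layer in the same dtype and on the same device.
quantized_model.score = torch.nn.Linear(
    quantized_model.config.hidden_size,
    quantized_model.config.num_labels,
    bias=False,
    dtype=torch.bfloat16,
    device=device,
)

Note that the replacement head is randomly initialized, so it still has to be fine-tuned before its logits mean anything; the point is only that its outputs are finite.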
As of right now, we only support quantized models with an LM head (i.e. AutoModelForCausalLM).
For classifiers, my guess is that the model created an empty classification head that had no (quantized) weights associated with it. When initialized with empty data, quantized layers have undefined behavior.
Your workaround of explicitly initializing the classifier head looks like a good (perhaps the best) solution to this problem.
Unfortunately, we currently do not have time to explicitly support non-LM models (i.e. token classification, sequence classification, QA, etc.). If you have the bandwidth to implement such support without breaking LM models, we would welcome it as a pull request.
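One way to sanity-check that diagnosis is to probe the head in isolation. A sketch, again assuming the head is exposed as score and reusing quantized_model and device from the reproduction above:

import torch

# If quantization wrapped the classification head, its type will not be
# torch.nn.Linear, and even a finite dummy hidden state should already
# come out as nan/inf.
head = quantized_model.score
print(type(head))

dummy = torch.randn(
    1, quantized_model.config.hidden_size, dtype=torch.bfloat16, device=device
)
print(torch.isfinite(head(dummy)).all())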
Environment:
aqlm version: 1.1.6
torch: 2.3.1+cu121
nvcc version: 12.1
gcc version: 10.5.0
GPU: NVIDIA L40S