NaNs in sequence classifier output #106

Closed
timo-obrecht opened this issue Jun 21, 2024 · 3 comments

timo-obrecht commented Jun 21, 2024

Hello,

I get an enormous number of nan and inf values in the outputs of quantized models used for sequence classification. This does not happen with non-quantized models, which never output nans regardless of how the sequence classification head is initialised. I conclude that the problem must come from the quantization of the classification head.

This is not caused by the base model itself outputting nans: I tested the same quantized model for text generation without any problem.

I observe this behavior with several quantized models.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_id = "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf"
# model_id = "ISTA-DASLab/Mistral-7B-v0.1-AQLM-PV-2Bit-1x16-hf"
# model_id = "ISTA-DASLab/Meta-Llama-3-8B-AQLM-PV-1Bit-1x16"
# model_id = "ISTA-DASLab/Meta-Llama-3-70B-AQLM-PV-2Bit-1x16"


# load the quantized checkpoint with a (randomly initialised) 2-label classification head
quantized_model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map=device,
    num_labels=2,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# these models have no pad token, so reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token
quantized_model.config.pad_token_id = tokenizer.eos_token_id


input_tokens = tokenizer([
    "There is a problem somewhere",
    "This is not related to the num_labels argument",
    "There are nan or inf way to many times to be normal"
    ],
    padding='longest'
    )

output = quantized_model(
    input_ids=torch.tensor(input_tokens['input_ids']).to(device),
    attention_mask=torch.tensor(input_tokens['attention_mask']).to(device)
    )
output.logits
>>>
# tensor([[nan, nan],
#         [nan, nan],
#         [nan, nan]], device='cuda:0', dtype=torch.bfloat16,
#        grad_fn=<IndexBackward0>)

For now, replacing the last layer with a non-quantized one solves the issue, but it is probably not an ideal solution.

# Replace the quantized classification head with a plain (non-quantized) linear layer
# of the same shape (note that it is randomly initialised).
new_head = torch.nn.Linear(
    quantized_model.score.in_features,
    quantized_model.score.out_features,
    device=device,
    dtype=torch.bfloat16
    )
quantized_model.score = new_head
type(quantized_model.score)
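
Re-running the forward pass from above with the replaced head then gives finite logits (a minimal check, reusing input_tokens and device from the snippet above):

# Minimal check: run the same forward pass again with the non-quantized head
# and verify that every logit is finite.
output = quantized_model(
    input_ids=torch.tensor(input_tokens['input_ids']).to(device),
    attention_mask=torch.tensor(input_tokens['attention_mask']).to(device)
    )
print(torch.isfinite(output.logits).all())
# prints tensor(True, device='cuda:0') here, i.e. the nan/inf values are gone

Note that the new head is randomly initialised, so it still needs to be fine-tuned before its logits are meaningful.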

aqlm version : 1.1.6
torch : 2.3.1+cu121
nvcc version : 12.1
gcc version : 10.5.0
GPU : NVIDIA L40S

@justheuristic
Collaborator

Hi!

As of right now, we only support quantized models with an LM head (i.e. AutoModelForCausalLM).
For the classifier, my guess is that the model created an empty classification head that did not have any (quantized) weights associated with it. When initialized with empty data, quantized layers have undefined behavior.
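
One quick way to check this guess (a minimal sketch, reusing quantized_model and device from the reproduction above; the exact class of the quantized head depends on the installed aqlm version):

# Diagnostic sketch: inspect the score head and run it in isolation.
head = quantized_model.score
print(type(head))  # a plain torch.nn.Linear here would rule this explanation out

# Push a random hidden state through the head alone: if its quantized weights were
# never trained, the output is typically already nan/inf at this stage.
dummy = torch.randn(1, head.in_features, device=device, dtype=torch.bfloat16)
print(head(dummy))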

Your workaround of explicitly initializing the classification head looks like a good (perhaps the best) solution to this problem.

Unfortunately, we currently do not have the time to explicitly support non-LM models (i.e. token classification, sequence classification, QA, etc.). If you have the bandwidth to implement such support without breaking LM models, we would welcome it as a pull request.


This issue is stale because it has been open for 30 days with no activity.

github-actions bot added the stale label Jul 25, 2024

github-actions bot commented Aug 8, 2024

This issue was closed because it has been inactive for 14 days since being marked as stale.

github-actions bot closed this as completed Aug 8, 2024