The output tensor's data type is not torch.long when the input text is empty. #36277

wangzhen0518 opened this issue Feb 19, 2025 · 7 comments · May be fixed by #36555

@wangzhen0518

System Info

  • transformers version: 4.48.1
  • Platform: Linux-5.15.0-130-generic-x86_64-with-glibc2.35
  • Python version: 3.12.8
  • Huggingface_hub version: 0.27.1
  • Safetensors version: 0.5.2
  • Accelerate version: 1.3.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): 0.10.2 (cpu)
  • Jax version: 0.5.0
  • JaxLib version: 0.5.0
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: No
  • GPU type: NVIDIA GeForce RTX 3060 Ti

Who can help?

@ArthurZucker and @itazap

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The output tensor's data type is not torch.long when the input text is empty.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-1.5B-Instruct')
t = tokenizer('', return_tensors='pt')
print(t['input_ids'].dtype)
# torch.float32

Expected behavior

t = tokenizer('', return_tensors='pt')
print(t['input_ids'].dtype)
# torch.int64
@Rocketknight1 (Member)

Hi @wangzhen0518, does this happen with all tokenizer classes, or just a specific one you tested?

@wangzhen0518 (Author) commented Feb 19, 2025

I have only tested it with the tokenizer of the Qwen series models. Here is the complete code.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-1.5B-Instruct')
t = tokenizer('', return_tensors='pt')
print(t['input_ids'].dtype)  # torch.float32

@Rocketknight1 (Member)

I've investigated further and I believe this is caused by the behaviour of torch.tensor():

torch.tensor([]).dtype # float32

torch.tensor([1, 2, 3]).dtype # int64

When torch converts a list to a tensor, the output dtype is int64 if every element is an int; an empty list has no elements to infer from, so torch falls back to its default dtype, torch.float32. This creates a strange edge case when the input is empty, but an empty input is not a valid model input regardless of dtype, so I'm not sure this bug is worth fixing!
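
For reference, passing an explicit dtype sidesteps the inference entirely:

torch.tensor([], dtype=torch.long).dtype  # int64, even for an empty list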

@wangzhen0518 (Author)

Yes, you are right. I also noticed that torch.tensor's output dtype is determined by the contents of its input.

Actually, I’m currently working with an LLM integrated into an environment. When concatenating the environment’s response (which is sometimes empty) to the LLM’s output tensor of type torch.long, the concatenation unexpectedly changes the resulting tensor’s dtype to torch.float32 when the response is empty. This dtype mismatch subsequently causes an error when feeding the concatenated tensor back into the LLM.
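
For concreteness, a minimal sketch of that dtype drift (the variable names here are illustrative):

import torch

generated_ids = torch.tensor([15339, 1917])  # int64 ids from the model
env_ids = torch.tensor([])                   # tokenizing an empty response -> float32
merged = torch.cat([generated_ids, env_ids]) # torch.cat promotes int64 + float32 to float32
print(merged.dtype)                          # torch.float32, no longer valid input ids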

I know this can easily be resolved by adding a type conversion when concatenating tensors, but shouldn't the tokenizer's behavior be consistent? That is, shouldn't the default data type of the returned values stay the same whenever tokenization succeeds without error?

@Rocketknight1 (Member)

Hmm, I think it might make sense, but I'm unsure if the expected dtypes are stored anywhere, so this could be a tricky PR.

I suspect the only way we could do this is to have a dict of common tokenizer output names like input_ids, and map those to the appropriate dtype like torch.long. Other outputs might still exhibit this edge case behaviour.
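
A rough sketch of what that mapping could look like (the dict and helper below are hypothetical, not existing transformers code):

import torch

# Hypothetical map from well-known integer-valued tokenizer outputs to their expected dtype.
EXPECTED_DTYPES = {
    "input_ids": torch.long,
    "attention_mask": torch.long,
    "token_type_ids": torch.long,
    "special_tokens_mask": torch.long,
}

def as_tensor_for_key(key, value):
    # Force the expected dtype for known keys; fall back to torch's inference otherwise.
    return torch.tensor(value, dtype=EXPECTED_DTYPES.get(key))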

cc @ArthurZucker for tokenizers - would you support a PR to add that, or does it add complexity for not enough gain?

@wangzhen0518 (Author)

Thanks! Can we just modify the convert_to_tensors function?

# Get a function reference for the correct framework
if tensor_type == TensorType.TENSORFLOW:
    if not is_tf_available():
        raise ImportError(
            "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
        )
    import tensorflow as tf

    as_tensor = tf.constant
    is_tensor = tf.is_tensor
elif tensor_type == TensorType.PYTORCH:
    if not is_torch_available():
        raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
    import torch

    is_tensor = torch.is_tensor

    def as_tensor(value, dtype=None):
        if isinstance(value, list) and isinstance(value[0], np.ndarray):
            return torch.from_numpy(np.array(value))
        return torch.tensor(value)

elif tensor_type == TensorType.JAX:
    if not is_flax_available():
        raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
    import jax.numpy as jnp  # noqa: F811

    as_tensor = jnp.array
    is_tensor = is_jax_tensor
elif tensor_type == TensorType.MLX:
    if not is_mlx_available():
        raise ImportError("Unable to convert output to MLX tensors format, MLX is not installed.")
    import mlx.core as mx

    as_tensor = mx.array

    def is_tensor(obj):
        return isinstance(obj, mx.array)

else:

    def as_tensor(value, dtype=None):
        if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
            value_lens = [len(val) for val in value]
            if len(set(value_lens)) > 1 and dtype is None:
                # we have a ragged list so handle explicitly
                value = as_tensor([np.asarray(val) for val in value], dtype=object)
        return np.asarray(value, dtype=dtype)

    is_tensor = is_numpy_array

# Do the tensor conversion in batch
for key, value in self.items():
    try:
        if prepend_batch_axis:
            value = [value]

        if not is_tensor(value):
            tensor = as_tensor(value)

For example, we could explicitly specify the dtype when creating tensors in the as_tensor function, as follows.

        elif tensor_type == TensorType.PYTORCH:
            if not is_torch_available():
                raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
            import torch

            is_tensor = torch.is_tensor

            def as_tensor(value, dtype=None):
                if isinstance(value, list) and isinstance(value[0], np.ndarray):
                    # Cast explicitly so an empty batch still comes out as torch.long
                    return torch.from_numpy(np.array(value)).to(torch.long)
                return torch.tensor(value, dtype=torch.long)
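
With that change, the empty-input edge case comes out as an integer tensor. A quick check of the patched as_tensor in isolation (outside transformers):

import numpy as np
import torch

def as_tensor(value, dtype=None):
    if isinstance(value, list) and isinstance(value[0], np.ndarray):
        return torch.from_numpy(np.array(value)).to(torch.long)
    return torch.tensor(value, dtype=torch.long)

# [[]] mimics an empty encoding after the batch axis is prepended
print(as_tensor([[]]).dtype)  # torch.int64 (was torch.float32 before the change)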

@Rocketknight1 (Member)

Hi @wangzhen0518, the problem is that some tokenizers might return float32 values for some keys! I'm not sure which, so I'd have to scan the codebase. We could try making your change in a PR and seeing which tests fail, though - feel free to open a PR for that and ping me!

wangzhen0518 added commits to wangzhen0518/transformers referencing this issue on Mar 5, 2025, and linked pull request #36555, which will close this issue.