
Started getting new warnings for gemma3 after upgrading from 4.49.0-gemma3 to 4.50.0 #36942

Open
HJJ256 opened this issue Mar 24, 2025 · 14 comments

Comments

@HJJ256

HJJ256 commented Mar 24, 2025

/opt/conda/lib/python3.10/site-packages/accelerate/utils/modeling.py:1569: UserWarning: Current model requires 33280 bytes of buffer for offloaded layers, which seems does not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using offload_buffers=True.
warnings.warn(
[2025-03-24 19:49:10,626-accelerate.utils.modeling] - Based on the current allocation process, no modules could be assigned to the following devices due to insufficient memory:

  - 0: 2024584448 bytes required
These minimum requirements are specific to this allocation attempt and may vary. Consider increasing the available memory for these devices to at least the specified minimum, or adjusting the model config.
Loading checkpoint shards: 100% 5/5 [00:00<00:00, 40.00it/s]

After the model loads, on running generate, I get the following error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

Note: I am moving my input_ids to "cuda".

When I load and test with 4.49.0-gemma3, everything works fine. Is there a specific change in 4.50.0 that affects model loading?

Model: Gemma3-12b-it
GPU: NVIDIA RTX 4000 Ada
device_map: auto
torch dtype: bfloat16

@zucchini-nlp
Member

zucchini-nlp commented Mar 24, 2025

@HJJ256 hey, can you include a short reproducer, or is this the inference script from the model docs?

From the warnings it seems that the model doesn't fit entirely in the GPU, but I am interested in what changed between v4.49 and v4.50.

@HJJ256
Author

HJJ256 commented Mar 24, 2025

@HJJ256 hey, can you include a short reproducer, or is this the inference script from the model docs?

From the warnings it seems that the model doesn't fit entirely in the GPU, but I am interested in what changed between v4.49 and v4.50.

The model doesn't fit completely in the GPU, but it was able to generate output with 4.49.0-gemma3. I am using the same code as in the model docs here, with the addition of the torch_dtype=torch.bfloat16 parameter in Gemma3ForConditionalGeneration.from_pretrained.
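For reference, here is a minimal sketch of the loading code (essentially the docs snippet plus the dtype argument; the model id and the final print are just for illustration):

import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

model_id = "google/gemma-3-12b-it"

processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",           # let accelerate split the model across GPU and CPU
    torch_dtype=torch.bfloat16,  # the only change w.r.t. the docs snippet
)

print(model.hf_device_map)  # resulting device placement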

This was the only output I used to get with 4.49.0-gemma3:
Loading checkpoint shards: 100% 5/5 [00:02<00:00, 2.02it/s]
[2025-03-24 20:56:56,855-accelerate.big_modeling] - Some parameters are on the meta device because they were offloaded to the cpu.

@zucchini-nlp
Member

zucchini-nlp commented Mar 25, 2025

I see. I got issues with caching_allocator_warmup while trying to reproduce, which I believe a PR is already fixing. I can confirm that on the model release branch it doesn't OOM with a 20GB GPU.

cc @SunMarc, seems related to accelerate

@SunMarc
Member

SunMarc commented Mar 25, 2025

After the model loads, on running generate, I get the following error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

Can you share the full traceback, @HJJ256? I'm able to run the model correctly on both branches with cpu/disk offload. Can you also share the device map of the model after loading it: model.hf_device_map?

@SunMarc
Member

SunMarc commented Mar 25, 2025

The only thing we need to fix is to specify torch_dtype=torch.bfloat16 when loading the model in v4.50 (contrary to the v4.49 gemma release). Otherwise, the model will be loaded in fp32 (taking 48GB). cc @zucchini-nlp, maybe we need to update the gemma model card?

@zucchini-nlp
Member

Yeah, we can update; let me open a PR.

From the issue description, it seems the user is already loading in bf16, so I'm not sure that's the cause unless confirmed by the user.

@SunMarc
Member

SunMarc commented Mar 25, 2025

Yeah, I'm not sure about the issue they're experiencing, as I didn't manage to reproduce the device mismatch.

@HJJ256
Author

HJJ256 commented Mar 25, 2025

After the model loads, on running generate, I get the following error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

Can you share the full traceback, @HJJ256? I'm able to run the model correctly on both branches with cpu/disk offload. Can you also share the device map of the model after loading it: model.hf_device_map?

Logs

TORCH DTYPE = torch.bfloat16
[2025-03-25 18:06:18,816-accelerate.utils.modeling] - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
/opt/conda/lib/python3.10/site-packages/accelerate/utils/modeling.py:1569: UserWarning: Current model requires 33280 bytes of buffer for offloaded layers, which seems does not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using offload_buffers=True.
  warnings.warn(
[2025-03-25 18:06:18,853-accelerate.utils.modeling] - Based on the current allocation process, no modules could be assigned to the following devices due to insufficient memory:
  - 0: 2024584448 bytes required
These minimum requirements are specific to this allocation attempt and may vary. Consider increasing the available memory for these devices to at least the specified minimum, or adjusting the model config.
Loading checkpoint shards: 100% 5/5 [00:00<00:00, 39.59it/s]
HF_DEVICE_MAP = {'': 'cpu'}

I think the device map is incorrect; however, I am loading the model as specified, with torch_dtype=torch.bfloat16.

Traceback, as requested:

  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2326, in generate
    result = self._sample(
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 3286, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1310, in forward
    inputs_embeds = self.get_input_embeddings()(llm_input_ids)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 104, in forward
    return super().forward(input_ids) * self.embed_scale
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 164, in forward
    return F.embedding(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/functional.py", line 2267, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

nvidia-smi output when the model is loaded. There is no other code running on the machine except this:

Tue Mar 25 19:11:46 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX 4000...  Off  | 00000000:01:00.0 Off |                  Off |
| 30%   40C    P8     5W /  70W |  19182MiB / 20475MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    273503      C   /opt/conda/bin/python           19180MiB |
+-----------------------------------------------------------------------------+

@HJJ256
Author

HJJ256 commented Mar 25, 2025

Also, this is the device map output when I run the code using the 4.49.0-gemma3 release:

HF_DEVICE_MAP = {'vision_tower': 0, 'multi_modal_projector': 0, 'language_model.model.embed_tokens': 0, 'language_model.lm_head': 0, 'language_model.model.layers.0': 0, 'language_model.model.layers.1': 0, 'language_model.model.layers.2': 0, 'language_model.model.layers.3': 0, 'language_model.model.layers.4': 0, 'language_model.model.layers.5': 0, 'language_model.model.layers.6': 0, 'language_model.model.layers.7': 0, 'language_model.model.layers.8': 0, 'language_model.model.layers.9': 0, 'language_model.model.layers.10': 0, 'language_model.model.layers.11': 0, 'language_model.model.layers.12': 0, 'language_model.model.layers.13': 0, 'language_model.model.layers.14': 0, 'language_model.model.layers.15': 0, 'language_model.model.layers.16': 0, 'language_model.model.layers.17': 0, 'language_model.model.layers.18': 0, 'language_model.model.layers.19': 0, 'language_model.model.layers.20': 0, 'language_model.model.layers.21': 0, 'language_model.model.layers.22': 0, 'language_model.model.layers.23': 0, 'language_model.model.layers.24': 0, 'language_model.model.layers.25': 0, 'language_model.model.layers.26': 0, 'language_model.model.layers.27': 0, 'language_model.model.layers.28': 0, 'language_model.model.layers.29': 0, 'language_model.model.layers.30': 0, 'language_model.model.layers.31': 0, 'language_model.model.layers.32': 0, 'language_model.model.layers.33': 'cpu', 'language_model.model.layers.34': 'cpu', 'language_model.model.layers.35': 'cpu', 'language_model.model.layers.36': 'cpu', 'language_model.model.layers.37': 'cpu', 'language_model.model.layers.38': 'cpu', 'language_model.model.layers.39': 'cpu', 'language_model.model.layers.40': 'cpu', 'language_model.model.layers.41': 'cpu', 'language_model.model.layers.42': 'cpu', 'language_model.model.layers.43': 'cpu', 'language_model.model.layers.44': 'cpu', 'language_model.model.layers.45': 'cpu', 'language_model.model.layers.46': 'cpu', 'language_model.model.layers.47': 'cpu', 'language_model.model.norm': 'cpu', 'language_model.model.rotary_emb': 'cpu', 'language_model.model.rotary_emb_local': 'cpu'}

This looks completely correct for Gemma3

nvidia-smi output

Tue Mar 25 19:29:29 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX 4000...  Off  | 00000000:01:00.0 Off |                  Off |
| 30%   40C    P8     5W /  70W |  17542MiB / 20475MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    277401      C   /opt/conda/bin/python           17540MiB |
+-----------------------------------------------------------------------------+

accelerate version: 1.5.2

@HJJ256
Author

HJJ256 commented Mar 25, 2025

I tried passing the above device_map directly with v4.50.0 and got a CUDA OOM error.
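Roughly like this (a sketch, assuming the map from my previous comment is stored in a dict; device_map_4_49 is a name I made up here):

import torch
from transformers import Gemma3ForConditionalGeneration

# device_map_4_49 = {'vision_tower': 0, ..., 'language_model.model.rotary_emb_local': 'cpu'}
model = Gemma3ForConditionalGeneration.from_pretrained(
    "google/gemma-3-12b-it",
    device_map=device_map_4_49,  # explicit map instead of "auto"
    torch_dtype=torch.bfloat16,
)

The OOM traceback: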

File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4455, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4884, in _load_pretrained_model
    disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py", line 852, in _load_state_dict_into_meta_model
    {param_type: param.to(param_device)},
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.88 GiB. GPU 0 has a total capacity of 19.68 GiB of which 970.88 MiB is free. Process 279374 has 18.73 GiB memory in use. Of the allocated memory 16.44 GiB is allocated by PyTorch, and 1.60 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables

I still don't understand how it was able to load in the previous version.
I set PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True in the environment, and now the model does load with the following device map; however, it looks like some layers are missing completely:

HF_DEVICE_MAP = {'vision_tower.vision_model.embeddings': 0, 'vision_tower.vision_model.encoder.layers.0': 0, 'vision_tower.vision_model.encoder.layers.1': 0, 'vision_tower.vision_model.encoder.layers.2': 0, 'vision_tower.vision_model.encoder.layers.3': 0, 'vision_tower.vision_model.encoder.layers.4': 0, 'vision_tower.vision_model.encoder.layers.5': 0, 'vision_tower.vision_model.encoder.layers.6': 0, 'vision_tower.vision_model.encoder.layers.7': 0, 'vision_tower.vision_model.encoder.layers.8': 0, 'vision_tower.vision_model.encoder.layers.9': 0, 'vision_tower.vision_model.encoder.layers.10': 0, 'vision_tower.vision_model.encoder.layers.11': 0, 'vision_tower.vision_model.encoder.layers.12': 0, 'vision_tower.vision_model.encoder.layers.13': 0, 'vision_tower.vision_model.encoder.layers.14': 'cpu', 'vision_tower.vision_model.encoder.layers.15': 'cpu', 'vision_tower.vision_model.encoder.layers.16': 'cpu', 'vision_tower.vision_model.encoder.layers.17': 'cpu', 'vision_tower.vision_model.encoder.layers.18': 'cpu', 'vision_tower.vision_model.encoder.layers.19': 'cpu', 'vision_tower.vision_model.encoder.layers.20': 'cpu', 'vision_tower.vision_model.encoder.layers.21': 'cpu', 'vision_tower.vision_model.encoder.layers.22': 'cpu', 'vision_tower.vision_model.encoder.layers.23': 'cpu', 'vision_tower.vision_model.encoder.layers.24': 'cpu', 'vision_tower.vision_model.encoder.layers.25': 'cpu', 'vision_tower.vision_model.encoder.layers.26': 'cpu', 'vision_tower.vision_model.post_layernorm': 'cpu', 'multi_modal_projector': 'cpu', 'language_model': 'cpu'}
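(For completeness, the allocator setting can also be applied from Python, as long as it is set before the CUDA allocator is initialized; a minimal sketch:)

import os

# Must be set before the first CUDA allocation, otherwise it has no effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # imported only after setting the variable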

I am getting this error during inference:

  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 2195, in generate
    self._prepare_cache_for_generation(
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1837, in _prepare_cache_for_generation
    model_kwargs[cache_name] = self._get_cache(
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1733, in _get_cache
    layer_device_map = self._get_layer_device_map_for_cache_init()
  File "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", line 1690, in _get_layer_device_map_for_cache_init
    raise RuntimeError(f"layer {idx} has not been mapped to a device.")
RuntimeError: layer 27 has not been mapped to a device.

@HJJ256
Author

HJJ256 commented Mar 25, 2025

@SunMarc @zucchini-nlp It seems the problem only occurs when the vision_tower is split across multiple devices, because in "/opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py", _get_layer_device_map_for_cache_init takes num_hidden_layers from the text_config only, and it ends up trying to map the vision_tower (SigLIP) layers onto the layer count of the gemma3_text model.
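To illustrate what I mean, here is a hypothetical simplification of that matching logic (not the actual transformers source), applied to the device map I posted above:

# Hypothetical simplification, NOT the real transformers code: it only
# illustrates why "layer 27 has not been mapped to a device" is raised
# with the device map shown above.
def get_layer_device_map(hf_device_map, num_text_layers):
    layer_device_map = {}
    for idx in range(num_text_layers):  # 48 decoder layers in the Gemma3-12B text config
        for module, device in hf_device_map.items():
            if module.endswith(f".{idx}"):  # also matches vision_tower...encoder.layers.<idx>!
                layer_device_map[idx] = device
                break
        if idx not in layer_device_map:
            raise RuntimeError(f"layer {idx} has not been mapped to a device.")
    return layer_device_map

With my device map, indices 0-26 are "resolved" against the SigLIP encoder layers, and index 27 fails because all the text layers are hidden behind the single 'language_model' entry.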

@zucchini-nlp
Member

Oh interesting! That is something I was going to work on; I've seen problems with multi-GPU inference a few times with Gemma3. I didn't think it could also affect single-GPU cases, where the device map simply contains "language_model".

We call the LM part "text_model" in some models (I can check if we have other names as well). cc @SunMarc if you have bandwidth to make a workaround.

Side note for cache: @gante, imo _get_layer_device_map_for_cache_init assumes a lot of things and will break with more modalities. For example, we are adding an omni-modal model where the LM backbone is nested, something like model.thinker_model.language_model. I know we can't use the meta device, so do you think we can get anything similar without inferring the per-layer device on the fly?

@SunMarc
Member

SunMarc commented Mar 26, 2025

In v4.50, for some reason, the device_map that you get is HF_DEVICE_MAP = {'': 'cpu'}, so the model is just loaded on the CPU, hence the device mismatch.

Can you share the values of the args that go into this function?
device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
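One way to capture them (a rough sketch, not an official API; depending on how transformers imports the function, you may also need to patch the reference inside transformers.modeling_utils):

import accelerate

_orig_infer_auto_device_map = accelerate.infer_auto_device_map

def logging_infer_auto_device_map(model, **kwargs):
    # Print whatever transformers passes (max_memory, no_split_module_classes, dtype, ...)
    print("infer_auto_device_map kwargs:", kwargs)
    device_map = _orig_infer_auto_device_map(model, **kwargs)
    print("resulting device_map:", device_map)
    return device_map

accelerate.infer_auto_device_map = logging_infer_auto_device_map

# If transformers already bound the original name at import time, patch that reference too:
import transformers.modeling_utils as _tmu
if hasattr(_tmu, "infer_auto_device_map"):
    _tmu.infer_auto_device_map = logging_infer_auto_device_map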

If you can find the faulty commit that triggered this device_map + OOM issue using git bisect, that would be even better.
Thanks for your help so far!

@gante
Member

gante commented Mar 26, 2025

In parallel to the comment above: I'm updating how we find the decoder layers' device mapping for cache initialization, to handle the case where a non-decoder module has the pattern (...).X (X being an integer).

This should fix the case @zucchini-nlp described, as well as the device map example above where the language model is on a single device (but the vision model is not).
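Very roughly, the direction is something like the sketch below (an illustration, not the actual PR): resolve each decoder layer against the device map by walking up the module path, so a single 'language_model' entry still gives every layer a device, and vision encoder keys are never matched by accident.

# Illustration only (not the actual PR code).
def get_decoder_layer_device_map(hf_device_map, decoder_prefix, num_layers):
    def device_for(module_name):
        # Walk up the module path until we find an entry in the device map.
        while module_name:
            if module_name in hf_device_map:
                return hf_device_map[module_name]
            module_name = module_name.rpartition(".")[0]
        return hf_device_map.get("", None)

    return {idx: device_for(f"{decoder_prefix}.layers.{idx}") for idx in range(num_layers)}

# With the single-entry map from this thread:
#   get_decoder_layer_device_map({..., 'language_model': 'cpu'}, "language_model.model", 48)
#   -> {0: 'cpu', 1: 'cpu', ..., 47: 'cpu'}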
