Potential issue with load_model function from trainer checkpoints for all models #176

Open · 2 tasks done
AmazingK2k3 opened this issue Feb 26, 2025 · 6 comments
Labels: bug (Something isn't working), enhancement (New feature or request)

AmazingK2k3 commented Feb 26, 2025

Search before asking

  • I have searched the Multimodal Maestro issues and found no similar bug report.

Bug

Hello,

I was testing the zero-shot object detection Colab notebook in my AWS environment and noticed that the Qwen model was being loaded across multiple GPUs rather than only the device I specified, even after setting the CUDA device. The code below reproduces it:

from maestro.trainer.models.qwen_2_5_vl.checkpoints import load_model, OptimizationStrategy

MODEL_ID_OR_PATH = "Qwen/Qwen2.5-VL-7B-Instruct"
MIN_PIXELS = 512 * 28 * 28
MAX_PIXELS = 2048 * 28 * 28

processor, model = load_model(
    model_id_or_path=MODEL_ID_OR_PATH,
    device="cuda:0",
    optimization_strategy=OptimizationStrategy.NONE,
    min_pixels=MIN_PIXELS,
    max_pixels=MAX_PIXELS,
)

I browsed through the checkpoints file for the Qwen model and I believe this is the cause: regardless of the device parameter passed to the function, the model's device_map is hardcoded to "auto".

https://github.com/roboflow/maestro/blob/develop/maestro/trainer/models/qwen_2_5_vl/checkpoints.py#L81C2-L103C28

# maestro/trainer/models/qwen_2_5_vl/checkpoints.py (excerpt from load_model)

    if optimization_strategy in {OptimizationStrategy.LORA, OptimizationStrategy.QLORA}:
        ...
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id_or_path,
            revision=revision,
            trust_remote_code=True,
            device_map="auto",
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
    else:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id_or_path,
            revision=revision,
            trust_remote_code=True,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        model.to(device)

Because device_map="auto" is applied unconditionally, it overrides the device argument passed to load_model.

I am fairly confident this is a bug, but let me know if the issue is on my side.
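
For reference, a quick way to check where the weights actually ended up after load_model returns (model is the object returned by the snippet above; hf_device_map is the attribute transformers attaches when a device_map is used):

# Where did the weights end up?
print(getattr(model, "hf_device_map", None))  # set when the model was dispatched via device_map

# Devices of the actual parameters; more than one entry means the model is sharded across GPUs.
print({p.device for p in model.parameters()})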

Environment

  • maestro[qwen_2_5_vl]==1.1.0rc2
  • AWS environment with 4 x A10 GPUs
  • Without import os; os.environ["CUDA_VISIBLE_DEVICES"] = "0", I was unable to prevent the model from loading onto all 4 GPUs (see the workaround sketch below).

Minimal Reproducible Example

No response

Additional

No response

Are you willing to submit a PR?

  • Yes, I'd like to help by submitting a PR!
@AmazingK2k3 added the bug label Feb 26, 2025
@SkalskiP (Collaborator) commented:

Hi @AmazingK2k3, thank you for reporting this! It does appear that the hardcoded device_map="auto" causes the model to be distributed across all visible GPUs, even when a specific device is requested.

I'm not yet sure how to solve this problem. @AmazingK2k3 @probicheaux @Matvezy I'd love to hear your ideas. Here's my initial one:

Instead of always setting device_map="auto", allow users to pass in a custom device_map so they can control exactly which GPU(s) the model loads onto. For example:

def load_model(
    model_id_or_path: str = DEFAULT_QWEN2_5_VL_MODEL_ID,
    revision: str = DEFAULT_QWEN2_5_VL_MODEL_REVISION,
    device: str | torch.device = "auto",
    device_map: Optional[dict] = None,  # <-- new parameter
    optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
    cache_dir: Optional[str] = None,
    min_pixels: int = 256 * 28 * 28,
    max_pixels: int = 1280 * 28 * 28,
) -> tuple[Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration]:

    device = parse_device_spec(device)
    processor = Qwen2_5_VLProcessor.from_pretrained(
        ...
    )

    if optimization_strategy in {OptimizationStrategy.LORA, OptimizationStrategy.QLORA}:
        ...
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id_or_path,
            revision=revision,
            trust_remote_code=True,
            device_map=device_map if device_map else "auto",
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        ...
    else:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id_or_path,
            revision=revision,
            trust_remote_code=True,
            device_map=device_map if device_map else "auto",
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        if not device_map:
            model.to(device)

    return processor, model

That should give maximum flexibility with reasonable defaults:

  • Load on CPU
processor, model = load_model(
    model_id_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
    device="cpu"
)
  • Load on MPS
processor, model = load_model(
    model_id_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
    device="mps"
)
  • Load on single GPU machine
processor, model = load_model(
    model_id_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
    device="cuda:0"
)
  • Load model on all GPUs
processor, model = load_model(
    model_id_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
    device_map="auto"
)
  • Load model on specific subset of GPUs
processor, model = load_model(
    model_id_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
    device_map={"": "cuda:0"}
)

@Matvezy (Contributor) commented Feb 27, 2025

@SkalskiP I think your solution is generally good. The only small change that could help would be:

device_map: Optional[Union[str, dict]] = None  # support string values like "auto"

But overall this fix looks fine for giving users the flexibility to control GPU allocation while maintaining the existing behavior.
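
A small sketch of how the widened type could be resolved inside load_model, assuming the proposal above (the helper name here is illustrative, not something that exists in the codebase):

from typing import Optional, Union

def _resolve_device_map(device_map: Optional[Union[str, dict]]) -> Union[str, dict]:
    # Accept either a string ("auto", "balanced", ...) or an explicit
    # module-to-device mapping; keep today's default when nothing is given.
    return device_map if device_map is not None else "auto"

# _resolve_device_map(None)            -> "auto"
# _resolve_device_map("auto")          -> "auto"
# _resolve_device_map({"": "cuda:0"})  -> {"": "cuda:0"}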

@AmazingK2k3 (Author) commented Feb 27, 2025

@SkalskiP Thank you for following up on the issue so quickly. I think the solution you proposed, together with @Matvezy's suggestion, will resolve this. Will you update it, or should I submit a PR?

@SkalskiP (Collaborator) commented:

@AmazingK2k3 would you like to implement it?

@AmazingK2k3 (Author) commented:

@SkalskiP Yes, I'd be happy to. On an unrelated note, I was wondering whether it would also be useful for load_model to accept Flash Attention as a parameter. This would help save memory, especially when dealing with multiple images. It is already supported and recommended by the Hugging Face Transformers library:

# https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2.5-VL-7B-Instruct",
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",
#     device_map="auto",
# )

This would help models loaded via maestro handle multiple images efficiently. I'm not sure whether there could be any issues with this, though; let me know!
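
For illustration, a rough sketch of how this could look as an opt-in parameter (the function name and wiring here are assumptions, not the actual maestro implementation; attn_implementation itself is the standard Transformers argument, and "flash_attention_2" requires the flash-attn package):

from typing import Optional

import torch
from transformers import Qwen2_5_VLForConditionalGeneration

def load_qwen_with_attn(
    model_id_or_path: str = "Qwen/Qwen2.5-VL-7B-Instruct",
    attn_implementation: Optional[str] = None,  # e.g. "flash_attention_2"
    device_map: str = "auto",
) -> Qwen2_5_VLForConditionalGeneration:
    # Only forward attn_implementation when the caller asks for it,
    # so the default behavior stays unchanged.
    kwargs = {}
    if attn_implementation is not None:
        kwargs["attn_implementation"] = attn_implementation
    return Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id_or_path,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        **kwargs,
    )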

@SkalskiP (Collaborator) commented:

Let’s keep the flash attention enhancement as a separate PR so we can track and test it properly on its own. If you’d like to open a new issue or pull request for enabling attn_implementation="flash_attention_2", that would be very welcome. It definitely seems like a feature that could improve performance and memory usage, especially for multi-image setups.

Looking forward to your PR!
