Potential issue with load_model function from trainer checkpoints for all models #176
Hi @AmazingK2k3, thank you for reporting this! It does appear that the hardcoded device_map="auto" is the culprit. I'm not yet sure how to solve this problem. @AmazingK2k3 @probicheaux @Matvezy I'd love to hear your ideas. Here's my initial one: instead of always setting device_map="auto", we could expose device_map as an optional parameter of load_model:

def load_model(
    model_id_or_path: str = DEFAULT_QWEN2_5_VL_MODEL_ID,
    revision: str = DEFAULT_QWEN2_5_VL_MODEL_REVISION,
    device: str | torch.device = "auto",
    device_map: Optional[dict | str] = None,  # <-- new parameter: a dict or a string such as "auto"
    optimization_strategy: OptimizationStrategy = OptimizationStrategy.NONE,
    cache_dir: Optional[str] = None,
    min_pixels: int = 256 * 28 * 28,
    max_pixels: int = 1280 * 28 * 28,
) -> tuple[Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration]:
    device = parse_device_spec(device)
    processor = Qwen2_5_VLProcessor.from_pretrained(
        ...
    )
    if optimization_strategy in {OptimizationStrategy.LORA, OptimizationStrategy.QLORA}:
        ...
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id_or_path,
            revision=revision,
            trust_remote_code=True,
            device_map=device_map if device_map else "auto",
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        ...
    else:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id_or_path,
            revision=revision,
            trust_remote_code=True,
            device_map=device_map if device_map else "auto",
            torch_dtype=torch.bfloat16,
            cache_dir=cache_dir,
        )
        if not device_map:
            model.to(device)
    return processor, model

That should give maximum flexibility with reasonable defaults:
processor, model = load_model(
    model_id_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
    device="cpu"
)

processor, model = load_model(
    model_id_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
    device="mps"
)

processor, model = load_model(
    model_id_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
    device="cuda:0"
)

processor, model = load_model(
    model_id_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
    device_map="auto"
)

processor, model = load_model(
    model_id_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
    device_map={"": "cuda:0"}
)
@SkalskiP I think your solution is generally good. The only small change that could help would be ... But overall this fix seems to do a good job of giving users flexibility to control GPU allocation while maintaining the existing behavior.
@AmazingK2k3 would you like to implement it?
@SkalskiP Yes, I'd be happy to. On an unrelated note, I was wondering whether it would be good for load_model to also accept flash attention as a hyperparameter. This would help with memory savings, especially when dealing with multi-image scenarios. This is already implemented and recommended by the Hugging Face Transformers library as well:
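A minimal sketch of that mechanism, using the standard attn_implementation kwarg in Transformers; how exactly this would be threaded through maestro's load_model signature is left open here:

import torch
from transformers import Qwen2_5_VLForConditionalGeneration

# Flash Attention 2 requires the flash-attn package and fp16/bf16 weights.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)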
This would let models loaded via maestro handle multiple images efficiently. I'm not sure whether there might be any issues with this, though, so let me know!
Let's keep the flash attention enhancement as a separate PR so we can track and test it properly on its own. If you'd like to open a new issue or pull request for enabling flash attention, please go ahead. Looking forward to your PR!
Search before asking
Bug
Hello,
I was testing the zero-shot object detection Colab notebook in my AWS environment and noticed that the Qwen model was loading across different GPUs with the code below, rather than staying on the specified GPU, even after setting the CUDA device:
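A sketch of the kind of call involved (the exact notebook cell isn't preserved in this report; the model ID matches the examples discussed below):

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # intended to pin everything to GPU 0

from maestro.trainer.models.qwen_2_5_vl.checkpoints import load_model

# Despite requesting a single device, the weights ended up sharded across all GPUs.
processor, model = load_model(
    model_id_or_path="Qwen/Qwen2.5-VL-7B-Instruct",
    device="cuda:0",
)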
I browsed through the checkpoints file for the Qwen model and figured this might be the issue: regardless of the device argument passed to the function, the model's device_map is hardcoded to "auto", which can override that setting.
https://github.com/roboflow/maestro/blob/develop/maestro/trainer/models/qwen_2_5_vl/checkpoints.py#L81C2-L103C28
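The pattern in question is roughly the following (paraphrased from the discussion, not copied verbatim from checkpoints.py):

# Paraphrased sketch of the reported behavior: device_map is fixed to "auto",
# so Accelerate shards the weights across every visible GPU regardless of the
# device argument passed to load_model.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id_or_path,
    trust_remote_code=True,
    device_map="auto",  # hardcoded
    torch_dtype=torch.bfloat16,
    cache_dir=cache_dir,
)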
I am fairly confident this is a bug, but let me know if the issue is on my side.
Environment
Even with os.environ["CUDA_VISIBLE_DEVICES"] = "0", I was unable to prevent the model from loading across all 4 GPUs.
Minimal Reproducible Example
No response
Additional
No response
Are you willing to submit a PR?