Closed
Description
diffusers/src/diffusers/models/transformers/transformer_cosmos.py
Lines 188 to 193 in 42077e6
# 4. Prepare for GQA
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
Suggested replacement, keeping the sizes as plain Python ints instead of wrapping them in tensors:
# 4. Prepare for GQA
query_idx = query.size(3)
key_idx = key.size(3)
value_idx = value.size(3)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
This change gives a ~10% speedup here in Cosmos2TextToImagePipeline and Cosmos2VideoToWorldPipeline.
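To make the equivalence of the two variants easy to check, here is a minimal sketch that runs both on dummy tensors and compares the results. The function names and tensor shapes are made up for illustration and are not taken from `transformer_cosmos.py`. A plausible reason for the speedup (an assumption, not measured here) is that the int version skips creating small device tensors, and that passing a tensor as `repeats` can force a host-device synchronization on GPU.

```python
import torch


def expand_with_tensor_repeats(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    # Original pattern: wrap the sizes in 0-dim tensors before dividing,
    # so `repeats` is passed to repeat_interleave as a tensor.
    query_idx = torch.tensor(query.size(3), device=query.device)
    key_idx = torch.tensor(key.size(3), device=key.device)
    return key.repeat_interleave(query_idx // key_idx, dim=3)


def expand_with_int_repeats(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    # Proposed pattern: size() already returns a Python int,
    # so `repeats` is a plain int and no extra tensors are created.
    query_idx = query.size(3)
    key_idx = key.size(3)
    return key.repeat_interleave(query_idx // key_idx, dim=3)


# Hypothetical shapes, chosen only so that dim 3 of `key` is half that of `query`.
query = torch.randn(1, 2, 4, 8)
key = torch.randn(1, 2, 4, 4)

out_tensor = expand_with_tensor_repeats(query, key)
out_int = expand_with_int_repeats(query, key)

# Both variants should produce the same expanded tensor.
print(out_tensor.shape, torch.equal(out_tensor, out_int))
```

The same comparison applies to `value`, since it is expanded with the same pattern.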
yiyixuxu commented on Jun 21, 2025
thanks!
would you like to share a PR so we can test it out on our end?
Avoid creating tensor in CosmosAttnProcessor2_0 (huggingface#11761)
Avoid creating tensor in CosmosAttnProcessor2_0 (#11761) (#11763)
Avoid creating tensor in CosmosAttnProcessor2_0 (huggingface#11761) (h…