Skip to content

Commit 941b7fc

Browse files
Avoid creating tensor in CosmosAttnProcessor2_0 (#11761) (#11763)
* Avoid creating tensor in CosmosAttnProcessor2_0 (#11761) * up --------- Co-authored-by: yiyixuxu <yixu310@gmail.com>
1 parent 76a62ac commit 941b7fc

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,15 @@ def __call__(
187187
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
188188

189189
# 4. Prepare for GQA
190-
query_idx = torch.tensor(query.size(3), device=query.device)
191-
key_idx = torch.tensor(key.size(3), device=key.device)
192-
value_idx = torch.tensor(value.size(3), device=value.device)
190+
if torch.onnx.is_in_onnx_export():
191+
query_idx = torch.tensor(query.size(3), device=query.device)
192+
key_idx = torch.tensor(key.size(3), device=key.device)
193+
value_idx = torch.tensor(value.size(3), device=value.device)
194+
195+
else:
196+
query_idx = query.size(3)
197+
key_idx = key.size(3)
198+
value_idx = value.size(3)
193199
key = key.repeat_interleave(query_idx // key_idx, dim=3)
194200
value = value.repeat_interleave(query_idx // value_idx, dim=3)
195201

0 commit comments

Comments
 (0)