Skip to content

Commit

Permalink
Remove float64 cast for OwlVit and OwlV2 to support MPS device (huggi…
Browse files Browse the repository at this point in the history
…ngface#31071)

Remove float64
  • Loading branch information
qubvel authored May 28, 2024
1 parent 936ab7b commit c31473e
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 2 deletions.
1 change: 0 additions & 1 deletion src/transformers/models/owlv2/modeling_owlv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,6 @@ def forward(
if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2)

pred_logits = pred_logits.to(torch.float64)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
pred_logits = pred_logits.to(torch.float32)

Expand Down
1 change: 0 additions & 1 deletion src/transformers/models/owlvit/modeling_owlvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,6 @@ def forward(
if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2)

pred_logits = pred_logits.to(torch.float64)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
pred_logits = pred_logits.to(torch.float32)

Expand Down

0 comments on commit c31473e

Please sign in to comment.