[feat] Add basic AMP support to scaled_dot_product_attention (faceboo…
fmassa committed Aug 23, 2021
1 parent c339af1 commit 45ae8c5
Showing 2 changed files with 78 additions and 6 deletions.
64 changes: 64 additions & 0 deletions tests/test_core_attention.py
@@ -1,7 +1,11 @@
import pytest
import torch

from xformers.components.attention._sputnik_sparse import SparseCS
from xformers.components.attention.core import scaled_dot_product_attention

_devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


def test_core_attention():
    b, s, d = 8, 900, 32
@@ -16,3 +20,63 @@ def test_core_attention():
    r_dense = scaled_dot_product_attention(a, a, a, m.to_dense())

    assert torch.allclose(r_sparse, r_dense)


@pytest.mark.parametrize("device", _devices)
def test_amp_attention_dense_no_mask(device):
    b, s, d = 8, 64, 32

    a = torch.rand(b, s, d, device=device)

    with torch.cuda.amp.autocast():
        r = scaled_dot_product_attention(a, a, a, att_mask=None)

    expected_device = torch.float16 if device == "cuda" else torch.float32
    assert r.dtype == expected_device


@pytest.mark.parametrize("device", _devices)
def test_amp_attention_dense(device):
    b, s, d = 8, 64, 32
    prob = 0.9

    a = torch.rand(b, s, d, device=device)
    m = torch.rand(s, s, device=device) > prob

    with torch.cuda.amp.autocast():
        r = scaled_dot_product_attention(a, a, a, m)

    expected_device = torch.float16 if device == "cuda" else torch.float32
    assert r.dtype == expected_device


@pytest.mark.parametrize("device", _devices)
def test_amp_attention_sparse(device):
    b, s, d = 8, 64, 32
    prob = 0.9

    a = torch.rand(b, s, d, device=device)
    m = torch.rand(s, s, device=device) > prob
    m = m.to_sparse()

    with torch.cuda.amp.autocast():
        r = scaled_dot_product_attention(a, a, a, m)

    expected_device = torch.float32
    assert r.dtype == expected_device


@pytest.mark.parametrize("device", _devices)
def test_amp_attention_sparsecs(device):
    b, s, d = 8, 64, 32
    prob = 0.9

    a = torch.rand(b, s, d, device=device)
    m = torch.rand(s, s, device=device) > prob
    m = SparseCS(m, device)

    with torch.cuda.amp.autocast():
        r = scaled_dot_product_attention(a, a, a, m)

    expected_device = torch.float32
    assert r.dtype == expected_device
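
Taken together, the new tests pin down the dtype contract for AMP: under torch.cuda.amp.autocast() the dense path follows the autocast dtype (fp16 on CUDA, fp32 on CPU), while sparse masks (torch.sparse or SparseCS) force the computation back to fp32. A compact caller-side sketch of that contract, assuming a CUDA device and mirroring the shapes and mask threshold used in the tests above:

import torch

from xformers.components.attention._sputnik_sparse import SparseCS
from xformers.components.attention.core import scaled_dot_product_attention

# Illustrative shapes, mirroring the tests above; assumes a CUDA device.
a = torch.rand(8, 64, 32, device="cuda")
mask = torch.rand(64, 64, device="cuda") > 0.9

with torch.cuda.amp.autocast():
    y_dense = scaled_dot_product_attention(a, a, a, mask)
    y_sparse = scaled_dot_product_attention(a, a, a, SparseCS(mask, "cuda"))

assert y_dense.dtype == torch.float16   # dense path follows autocast
assert y_sparse.dtype == torch.float32  # sparse path falls back to fp32
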
20 changes: 14 additions & 6 deletions xformers/components/attention/core.py
@@ -1,4 +1,5 @@
 import math
+from contextlib import nullcontext
 from typing import Optional
 
 import torch
@@ -161,12 +162,19 @@ def scaled_dot_product_attention(
     att_mask: Optional[torch.Tensor],
     dropout: Optional[torch.nn.Module] = None,
 ) -> torch.Tensor:
-    att = scaled_query_key_softmax(q, k, att_mask)
+    autocast_disabled = isinstance(att_mask, SparseCS) or (
+        att_mask is not None and att_mask.is_sparse
+    )
+    with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext():
+        if autocast_disabled:
+            q, k, v = q.float(), k.float(), v.float()
 
-    # Optional dropout, could be part of the masking in the future
-    att = _apply_dropout(att, dropout)
+        att = scaled_query_key_softmax(q, k, att_mask)
 
-    # Get to the predicted values, for all heads
-    # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)
-    y = bmm(att, v)
+        # Optional dropout, could be part of the masking in the future
+        att = _apply_dropout(att, dropout)
+
+        # Get to the predicted values, for all heads
+        # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)
+        y = bmm(att, v)
     return y
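
The core.py change hinges on a single pattern: pick the context manager up front, entering torch.cuda.amp.autocast(enabled=False) only when the mask is sparse and a no-op nullcontext() otherwise, then cast the operands back to fp32 whenever autocast is bypassed. A self-contained sketch of that pattern, where the attention_scores helper and its fp32_only flag are illustrative rather than xformers API:

from contextlib import nullcontext

import torch


def attention_scores(q: torch.Tensor, k: torch.Tensor, fp32_only: bool) -> torch.Tensor:
    # Choose the context manager up front: bypass autocast only when a
    # full-precision path is required, otherwise run under a no-op context.
    ctx = torch.cuda.amp.autocast(enabled=False) if fp32_only else nullcontext()
    with ctx:
        if fp32_only:
            # Inputs produced inside an outer autocast region may already be
            # fp16, so cast them back explicitly before the fp32-only math.
            q, k = q.float(), k.float()
        return q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)

Called under an outer torch.cuda.amp.autocast() block on CUDA, fp32_only=False returns fp16 and fp32_only=True returns fp32, which is the dense vs. sparse behavior the tests above assert.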
