
Commit

[feat] Add some 2d patterns (facebookresearch#75)
* Add some 2d-specific attention patterns
* Add notebook with examples
fmassa committed Apr 30, 2021
1 parent c3ddb0a commit 2816f1b
Showing 2 changed files with 466 additions and 0 deletions.
412 changes: 412 additions & 0 deletions docs/source/2d_attention_patterns.ipynb

Large diffs are not rendered by default.

54 changes: 54 additions & 0 deletions xformers/components/attention/attention_patterns.py
@@ -32,3 +32,57 @@ def random_pattern(attn_size: int, sparsity: float) -> torch.Tensor:
def causal_1d_pattern(attn_size: int) -> torch.Tensor:
    mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool))
    return mask


# 2d-specific cases
def _generate_2d_grid(H, W):
    i = torch.arange(H)
    j = torch.arange(W)
    i, j = torch.meshgrid(i, j)
    return i, j


def horizontal_axial_2d_distance(H, W, p=2.0):
    i, _ = _generate_2d_grid(H, W)
    ij = i.reshape(-1, 1).float()
    d = torch.cdist(ij, ij, p=p)
    return d


def vertical_axial_2d_distance(H, W, p=2.0):
    _, j = _generate_2d_grid(H, W)
    ij = j.reshape(-1, 1).float()
    d = torch.cdist(ij, ij, p=p)
    return d


def local_2d_distance(H, W, p=2.0):
    # axial is a special case with p=0 and distance=2
    i, j = _generate_2d_grid(H, W)
    ij = torch.stack([i.flatten(), j.flatten()], 1).float()
    d = torch.cdist(ij, ij, p=p)
    return d


def local_2d_gausian_distribution(H, W, sigma=1):
    d = local_2d_distance(H, W, p=2.0)
    d = torch.exp(-0.5 * sigma ** (-0.5) * d)
    return d


def local_2d_pattern(H, W, distance, p=2.0):
    d = local_2d_distance(H, W, p=p)
    return d < distance


def axial_2d_pattern(H, W):
    # axial is a special case with p=0 and distance=2
    d = local_2d_distance(H, W, p=0)
    return d < 2


def random_pattern_from_probability_matrix(dist_matrix, nnz):
    att = torch.zeros_like(dist_matrix, dtype=torch.bool)
    idxs = torch.multinomial(dist_matrix.flatten(), nnz, replacement=False)
    att.view(-1)[idxs] = True
    return att
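For reference (not part of the commit), here is a minimal usage sketch of the helpers added above. The import path and function names come from the diff; the grid size, distance threshold, sigma, and nnz values are illustrative assumptions.

import torch
from xformers.components.attention.attention_patterns import (
    axial_2d_pattern,
    local_2d_gausian_distribution,
    local_2d_pattern,
    random_pattern_from_probability_matrix,
)

H, W = 8, 8  # illustrative 8x8 token grid -> masks of shape (H*W, H*W)

# Boolean masks: attend along the same row/column, or within an L2 radius.
axial_mask = axial_2d_pattern(H, W)
local_mask = local_2d_pattern(H, W, distance=2.5, p=2.0)

# Sample a fixed number of attended positions, biased towards nearby tokens.
probs = local_2d_gausian_distribution(H, W, sigma=2)
sparse_mask = random_pattern_from_probability_matrix(probs, nnz=512)

print(axial_mask.shape, local_mask.sum().item(), sparse_mask.sum().item())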
