Merge pull request #275 from pyt-team/allsettransformer_efficient
efficient implementation of attention
ninamiolane committed May 14, 2024
2 parents 2267768 + d702faf commit b70ece9
Showing 2 changed files with 15 additions and 19 deletions.
.pre-commit-config.yaml (2 changes: 1 addition & 1 deletion)
@@ -1,5 +1,5 @@
 default_language_version:
-  python: python3.10
+  python: python3.11

 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
topomodelx/nn/hypergraph/allset_transformer_layer.py (32 changes: 14 additions & 18 deletions)
@@ -4,6 +4,8 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch_geometric.utils import softmax
+from torch_scatter import scatter

 from topomodelx.base.message_passing import MessagePassing

@@ -257,8 +259,10 @@ def forward(self, x_source, neighborhood):

         # Obtain Y from Eq(8) in AllSet paper [1]
         # Skip-connection (broadcased)
-        x_message_on_target = x_message_on_target + self.multihead_att.Q_weight
-
+        x_message_on_target = x_message_on_target + self.multihead_att.Q_weight.permute(
+            1, 0, 2
+        )
+        x_message_on_target = x_message_on_target.unsqueeze(2)
         # Permute: n,h,q,c -> n,q,h,c
         x_message_on_target = x_message_on_target.permute(0, 2, 1, 3)
         x_message_on_target = self.ln0(
@@ -368,21 +372,10 @@ def attention(self, x_source, neighborhood):
         torch.Tensor, shape = (n_target_cells, heads, number_queries, n_source_cells)
             Attention weights: one scalar per message between a source and a target cell.
         """
-        x_K = torch.matmul(x_source, self.K_weight)
-        alpha = torch.matmul(self.Q_weight, x_K.transpose(1, 2))
-        expanded_alpha = torch.sparse_coo_tensor(
-            indices=neighborhood.indices(),
-            values=alpha.permute(*torch.arange(alpha.ndim - 1, -1, -1))[
-                self.source_index_j
-            ],
-            size=[
-                neighborhood.shape[0],
-                neighborhood.shape[1],
-                alpha.shape[1],
-                alpha.shape[0],
-            ],
-        )
-        return torch.sparse.softmax(expanded_alpha, dim=1).to_dense().transpose(1, 3)
+        x_K = torch.matmul(x_source, self.K_weight).permute(1, 0, 2)
+        alpha = (x_K * self.Q_weight.permute(1, 0, 2)).sum(-1)
+        alpha = F.leaky_relu(alpha, 0.2)
+        return softmax(alpha[self.source_index_j], index=self.target_index_i)

     def forward(self, x_source, neighborhood):
         """Forward pass.
@@ -410,8 +403,11 @@ def forward(self, x_source, neighborhood):
         attention_values = self.attention(x_source, neighborhood)

         x_message = torch.matmul(x_source, self.V_weight)
-        return torch.matmul(attention_values, x_message)
+        x_message = x_message.permute(1, 0, 2)[
+            self.source_index_j
+        ] * attention_values.unsqueeze(-1)
+        return scatter(x_message, self.target_index_i, dim=0, reduce="sum")


 class MLP(nn.Sequential):
     """MLP Module.
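For context on why this is more efficient: the old attention() materialized a dense (n_target_cells, heads, number_queries, n_source_cells) tensor via torch.sparse.softmax, while the new code scores only the (source, target) pairs that actually appear in the neighborhood, normalizes them per target cell with torch_geometric.utils.softmax, and sums the weighted messages onto their targets with torch_scatter.scatter. The standalone sketch below (not part of the commit; the index tensors, head count, and channel size are made up for illustration) shows that segment-softmax plus scatter-sum pattern:

import torch
from torch_geometric.utils import softmax
from torch_scatter import scatter

# Hypothetical incidence: 4 messages sent from 3 source cells into 2 target cells.
source_index_j = torch.tensor([0, 1, 2, 2])  # source cell of each message
target_index_i = torch.tensor([0, 0, 0, 1])  # target cell of each message

heads, channels = 2, 8
alpha = torch.randn(3, heads)                 # one raw score per (source cell, head)
x_message = torch.randn(3, heads, channels)   # one value vector per (source cell, head)

# Per-target (segment) softmax over only the messages that exist,
# instead of a softmax over an expanded sparse tensor made dense.
att = softmax(alpha[source_index_j], index=target_index_i)      # shape (4, heads)

# Weight each message, then scatter-sum the results onto their target cells.
weighted = x_message[source_index_j] * att.unsqueeze(-1)        # shape (4, heads, channels)
out = scatter(weighted, target_index_i, dim=0, reduce="sum")    # shape (2, heads, channels)

Memory and compute now scale with the number of incidences rather than with n_target_cells * n_source_cells, which is the point of the "efficient implementation of attention" in this pull request.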

0 comments on commit b70ece9
