CHANGE: change calculation to .einsum(...)
This should make the calculation logic inside the model much clearer.
p768lwy3 committed Oct 29, 2019
1 parent b357829 commit 64d3f25
Showing 1 changed file with 14 additions and 10 deletions.
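For context on the rewrite in the diff below: an einsum expression such as "ijk,ijk->ijk" keeps every index in the output, so nothing is summed away and the result is exactly the element-wise (Hadamard) product. A minimal sketch checking that equivalence (the shapes mirror the (B, NC2, E) pairwise-interaction tensors in the diff; the concrete sizes are illustrative, not from the commit):

import torch

# Illustrative sizes: B = 2 samples, NC2 = 6 field pairs, E = 4 embedding dims
a = torch.randn(2, 6, 4)
b = torch.randn(2, 6, 4)

# "ijk,ijk->ijk" retains all three indices, so no axis is summed over:
# the result is the element-wise (Hadamard) product a * b.
hadamard = torch.einsum("ijk,ijk->ijk", a, b)
assert torch.allclose(hadamard, a * b)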
24 changes: 14 additions & 10 deletions torecsys/layers/ctr/attentional_factorization_machine.py
@@ -71,29 +71,33 @@ def forward(self, emb_inputs: torch.Tensor) -> Tuple[torch.Tensor]:
         Returns:
             Tuple[T], shape = ((B, E), (B, NC2, 1)), dtype = torch.float: Output of AttentionalFactorizationMachineLayer and attention weights.
         """
-        # Calculate inner product
+        # Calculate hadamard product
         # inputs: emb_inputs, shape = (B, N, E)
-        # output: inner, shape = (B, NC2, E)
+        # output: products, shape = (B, NC2, E)
         emb_inputs = emb_inputs.rename(None)
-        inner = emb_inputs[:, self.rowidx] * emb_inputs[:, self.colidx]
-        inner.names = ("B", "N", "E")
+        ## products = emb_inputs[:, self.rowidx] * emb_inputs[:, self.colidx]
+        products = torch.einsum("ijk,ijk->ijk", [emb_inputs[:, self.rowidx], emb_inputs[:, self.colidx]])
+        ## products.names = ("B", "N", "E")

         # Calculate attention scores
-        # inputs: inner, shape = (B, NC2, E)
+        # inputs: products, shape = (B, NC2, E)
         # output: attn_scores, shape = (B, NC2, 1)
-        attn_scores = self.attention(inner.rename(None))
-        attn_scores.names = ("B", "N", "E")
+        attn_scores = self.attention(products.rename(None))
+        ## attn_scores.names = ("B", "N", "E")

         # Apply attention on inner product
-        # inputs: inner, shape = (B, NC2, E)
+        # inputs: products, shape = (B, NC2, E)
         # inputs: attn_scores, shape = (B, NC2, 1)
         # output: outputs, shape = (B, E)
-        outputs = (inner * attn_scores).sum(dim="N")
+        ## outputs = (products * attn_scores).sum(dim="N")
+        outputs = torch.einsum("ijk,ijh->ijk", [products, attn_scores])
+        outputs.names = ("B", "N", "E")
+        outputs = outputs.sum(dim="N")

         # Apply dropout on outputs
         # inputs: outputs, shape = (B, E)
         # output: outputs, shape = (B, E)
         outputs = self.dropout(outputs)

         return outputs, attn_scores
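The second einsum, "ijk,ijh->ijk", relies on the attention scores carrying a trailing singleton axis: h does not appear in the output subscript, so einsum sums over it, and with h of size 1 that sum reduces to a broadcast multiply. A minimal sketch of the same attention-weighting-and-pooling logic outside the layer (sizes illustrative, not from the commit):

import torch

# Illustrative sizes matching the diff's shapes: (B, NC2, E) and (B, NC2, 1)
products = torch.randn(2, 6, 4)
attn_scores = torch.randn(2, 6, 1)

# h is absent from the output, so einsum sums over it; since h has size 1,
# this equals the broadcast product products * attn_scores.
weighted = torch.einsum("ijk,ijh->ijk", products, attn_scores)
assert torch.allclose(weighted, products * attn_scores)

# Summing over the NC2 axis then gives the (B, E) output of the layer.
outputs = weighted.sum(dim=1)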

