ERROR: NamedTensor does not work with jit for now!!!
p768lwy3 committed Oct 18, 2019
1 parent 33fd598 commit f3f561f
Showing 3 changed files with 24 additions and 8 deletions.
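
Context: PyTorch's named tensor API (experimental since PyTorch 1.3) is not yet supported by torch.jit.trace, so this commit swaps the jit tracing decorator for a warning-only one and strips dimension names around ops that do not accept them. A minimal, hedged repro of the incompatibility follows; uses_names and its dimension names are illustrative, not repository code:

import torch

# Named-tensor ops run fine eagerly but are rejected by the tracer
# (as of PyTorch 1.3, current when this commit landed).
def uses_names(x):
    return x.refine_names("B", "N").flatten(["B", "N"], "BN")

x = torch.randn(2, 3)
print(uses_names(x).names)  # ('BN',) -- eager mode works

try:
    torch.jit.trace(uses_names, (x,))
except RuntimeError as err:
    print("jit.trace failed:", err)  # named-tensor ops are expected to fail here
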
4 changes: 2 additions & 2 deletions torecsys/inputs/base/single_index_emb.py
@@ -1,13 +1,13 @@
 from . import _Inputs
-from torecsys.utils.decorator import jit_experimental
+from torecsys.utils.decorator import jit_experimental, no_jit_experimental_by_namedtensor
 import torch
 import torch.nn as nn
 
 
 class SingleIndexEmbedding(_Inputs):
     r"""Base Inputs class for embedding a single index of an input field.
     """
-    @jit_experimental
+    @no_jit_experimental_by_namedtensor
     def __init__(self,
                  embed_size : int,
                  field_size : int,
17 changes: 11 additions & 6 deletions torecsys/layers/ctr/compose_excitation_network.py
@@ -56,16 +56,21 @@ def forward(self, field_emb_inputs: torch.Tensor) -> torch.Tensor:
             T, shape = (B, N * N, E), dtype = torch.long: Output of ComposeExcitationNetworkLayer.
         """
         # pooling with inputs' shape = (B, N * N, E) to output's shape = (B, N * N, 1)
-        pooled_inputs = self.pooling(field_emb_inputs)
+        pooled_inputs = self.pooling(field_emb_inputs.rename(None))
+        pooled_inputs.names = ("B", "N", "E")
 
         # squeeze pooled_inputs into shape = (B, N * N)
-        pooled_inputs = pooled_inputs.squeeze()
+        ## pooled_inputs = pooled_inputs.squeeze()
+        pooled_inputs = pooled_inputs.flatten(["N", "E"], "N")
 
-        # output's shape of attn_weights = (B, N * N)
-        attn_weights = self.fc(pooled_inputs)
+        # output's shape of attn_w = (B, N * N)
+        attn_w = self.fc(pooled_inputs.rename(None))
+        attn_w.names = ("B", "N")
 
         # unsqueeze to (B, N * N, 1) and expand as x's shape = (B, N * N, E)
-        attn_weights = attn_weights.unsqueeze(-1)
-        outputs = field_emb_inputs * attn_weights.expand_as(field_emb_inputs)
+        ## attn_w = attn_w.unsqueeze(-1)
+        attn_w = attn_w.unflatten("N", (("N", attn_w.size("N")), ("E", 1)))
+        ## outputs = field_emb_inputs * attn_w.expand_as(field_emb_inputs)
+        outputs = field_emb_inputs * attn_w
 
         return outputs
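
The forward pass above works around the tracer limitation by dropping names via .rename(None) before each name-unaware op, re-attaching them afterwards, and replacing squeeze/unsqueeze (which lose names) with the named flatten/unflatten. A standalone sketch of the same pattern, with assumed toy sizes (B=2, N*N=4, E=8) and an assumed AdaptiveAvgPool1d pooling op, not taken from the layer itself:

import torch
import torch.nn as nn

emb = torch.randn(2, 4, 8, names=("B", "N", "E"))  # (batch, field pairs, embed)

pooling = nn.AdaptiveAvgPool1d(1)     # name-unaware op
pooled = pooling(emb.rename(None))    # drop names before the call
pooled.names = ("B", "N", "E")        # re-attach names; E is now size 1

# named replacement for squeeze(): (B, N, 1) -> (B, N)
pooled = pooled.flatten(["N", "E"], "N")

# named replacement for unsqueeze(-1): (B, N) -> (B, N, 1)
attn_w = pooled.unflatten("N", (("N", pooled.size("N")), ("E", 1)))

print((emb * attn_w).names)           # broadcasts back to ('B', 'N', 'E')
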
11 changes: 11 additions & 0 deletions torecsys/utils/decorator/decorator.py
@@ -31,6 +31,17 @@ def wrapper(*args, **kwargs):
         return func(*args, **kwargs)
     return wrapper
 
+def no_jit_experimental_by_namedtensor(func: callable):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        warnings.warn(
+            "This module has been checked and is not compatible with torch.jit.trace " +
+            "due to its use of NamedTensor. It will be made compatible " +
+            "when PyTorch adds support.", UserWarning
+        )
+        return func(*args, **kwargs)
+    return wrapper
+
 
 def jit_experimental(func: callable):
     r"""a decorator to write a message in a layer or an estimator where they have been checked
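
Usage of the new decorator mirrors the first hunk: it wraps a module's __init__ so the warning fires at construction time, then the wrapped call proceeds unchanged. A hypothetical minimal example (MyEmbedding is illustrative, not repository code):

import torch.nn as nn
from torecsys.utils.decorator import no_jit_experimental_by_namedtensor

class MyEmbedding(nn.Module):
    @no_jit_experimental_by_namedtensor
    def __init__(self, field_size: int, embed_size: int):
        super().__init__()
        self.embedding = nn.Embedding(field_size, embed_size)

layer = MyEmbedding(10, 4)  # emits the UserWarning, then initializes normally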
