Skip to content

Commit

Permalink
change: add set_schema to trs.input.base._Inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jasper430 committed Oct 21, 2019
1 parent fce660b commit 902cee6
Show file tree
Hide file tree
Showing 15 changed files with 164 additions and 88 deletions.
21 changes: 16 additions & 5 deletions torecsys/inputs/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from collections import namedtuple
import torch.nn as nn
from typing import List, Union


class _Inputs(nn.Module):
Expand All @@ -20,11 +21,21 @@ def __len__(self) -> int:
"""
return self.length

def get_schema(self) -> namedtuple:
return self.schema

def set_schema(self, *args):
raise NotImplementedError("set_schema cannot be called in the base class.")
def set_schema(self, inputs: Union[str, List[str]]):
    r"""Initialize input layer's schema.

    The field names are stored in a ``Schema`` namedtuple (single field ``inputs``)
    which is bound to ``self.schema``.

    Args:
        inputs (Union[str, List[str]]): Field name, or list of field names, of the inputs.
    """
    # a single field name is wrapped into a one-element list; any other sequence is
    # copied into a new list, so mutating the caller's object afterwards cannot
    # silently change the stored schema, and the stored value is always a list
    if isinstance(inputs, str):
        inputs = [inputs]
    else:
        inputs = list(inputs)

    # create a namedtuple of schema
    schema = namedtuple("Schema", ["inputs"])

    # initialize self.schema with the namedtuple
    self.schema = schema(inputs=inputs)


# from .audio_inp import AudioInputs
Expand Down
86 changes: 56 additions & 30 deletions torecsys/inputs/base/concat_inputs.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,66 @@
from . import _Inputs
from torecsys.utils.decorator import jit_experimental
import torch
from torecsys.utils.decorator import jit_experimental, no_jit_experimental_by_namedtensor
from typing import Dict, List


class ConcatInputs(_Inputs):
r"""Base Inputs class for concatenation of list of Base Inputs class in rowwise. The shape of output
is :math:`(B, 1, E_{1} + ... + E_{k})`, where :math:`E_{i}` is embedding size of :math:`i-th` field.
"""
@jit_experimental
def __init__(self, schema: List[tuple]):
@no_jit_experimental_by_namedtensor
def __init__(self, inputs: List[_Inputs]):
r"""Initialize ConcatInputs.
Args:
schema (List[tuple]): Schema of ConcatInputs. List of Tuple of Inputs class (i.e. class in
trs.inputs.base) and list of string of input fields. e.g.
inputs (List[_Inputs]): List of input's layers (trs.inputs.base._Inputs),
i.e. class of trs.inputs.base. e.g.
.. code-block:: python
import torecsys as trs
schema = [
(trs.inputs.base.SingleIndexEmbedding(4, 10), ["userId"]),
(trs.inputs.base.SingleIndexEmbedding(4, 10), ["movieId"])
]
# initialize embedding layers used in ConcatInputs
single_index_emb_0 = trs.inputs.base.SingleIndexEmbedding(2, 8)
single_index_emb_1 = trs.inputs.base.SingleIndexEmbedding(2, 8)
# set schema, including field names etc
single_index_emb_0.set_schema(["userId"])
single_index_emb_1.set_schema(["movieId"])
# create ConcatInputs embedding layer
inputs = [single_index_emb_0, single_index_emb_1]
concat_emb = trs.inputs.base.ConcatInputs(inputs=inputs)
Attributes:
schema (List[tuple]): Schema of ConcatInputs.
length (int): Sum of length of inputs (i.e. number of fields of inputs, or embedding size of
embedding) in schema.
inputs (List[_Inputs]): List of input's layers.
length (int): Sum of length of input's layers,
i.e. number of fields of inputs, or embedding size of embedding.
"""
# refer to parent class
super(ConcatInputs, self).__init__()

# bind schema to schema
self.schema = schema
# bind inputs to inputs
self.inputs = inputs

# add schemas and modules from inputs to this module
inputs = []
for idx, inp in enumerate(self.inputs):
# add module
self.add_module("input_%d" % idx, inp)

# add modules in schema to the Module
for i, tup in enumerate(schema):
self.add_module("embedding_%d" % i, tup[0])
# append fields name to the list `inputs`
schema = inp.schema
for arguments in schema:
if isinstance(arguments, list):
inputs.extend(arguments)
elif isinstance(arguments, str):
inputs.append(arguments)

self.set_schema(inputs=list(set(inputs)))

# bind length to sum of lengths of inputs (i.e. number of fields of inputs, or embedding
# size of embedding)
self.length = sum([len(tup[0]) for tup in self.schema])
# bind length to sum of lengths of inputs,
# i.e. number of fields of inputs, or embedding size of embedding.
self.length = sum([len(inp) for inp in self.inputs])

def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
r"""Forward calculation of ConcatInputs.
Expand All @@ -56,24 +76,30 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
# initialize list to store tensors temporarily
outputs = list()

# loop through schema
for args_tuple in self.schema:
# get basic args from tuple in schema
embedding = args_tuple[0]
inp_names = args_tuple[1]
# loop through inputs
for inp in self.inputs:
# get schema, i.e. input's field names, from input in list
inp_names = inp.schema.inputs

# convert list of inputs to tensor, with shape = (B, N, *)
inp_val = [inputs[i] for i in inp_names]
inp_val = torch.cat(inp_val, dim=1)
args = [inp_val]
inp_args = [inp_val]

# set args for specific input
if embedding.__class__.__name__ == "SequenceIndexEmbedding":
arg_name = args_tuple[2][0]
args.append(inputs[arg_name])
if inp.__class__.__name__ == "SequenceIndexEmbedding":
inp_names = inp.schema.lengths
inp_args.append(inputs[inp_names])

# calculate embedding values
output = inp(*inp_args)

# check if output dimension is less than 3, then .unsqueeze(1)
if output.dim() < 3:
output = output.unflatten("E", [("N", 1), ("E", output.size("E"))])

# append tensor to outputs
outputs.append(embedding(*args))
outputs.append(output)

# concat in the third dimension, and the shape of output = (B, 1, sum(E))
outputs = torch.cat(outputs, dim=2)
Expand Down
4 changes: 2 additions & 2 deletions torecsys/inputs/base/image_inp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import _Inputs
from torecsys.utils.decorator import jit_experimental
from torecsys.utils.decorator import jit_experimental, no_jit_experimental_by_namedtensor
import torch
import torch.nn as nn
from typing import List
Expand All @@ -9,7 +9,7 @@ class ImageInputs(_Inputs):
r"""Base Inputs class for image, which embeds an image by a stack of convolutional neural network (CNN)
layers and a fully-connected layer.
"""
@jit_experimental
@no_jit_experimental_by_namedtensor
def __init__(self,
embed_size : int,
in_channels : int,
Expand Down
1 change: 1 addition & 0 deletions torecsys/inputs/base/images_list_inp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from . import _Inputs
import torch


class ImagesListInputs(_Inputs):
r"""Base Inputs class for list of images
"""
Expand Down
4 changes: 2 additions & 2 deletions torecsys/inputs/base/list_indices_emb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from . import _Inputs
from torecsys.functional import show_attention, dummy_attention
from torecsys.utils.decorator import jit_experimental, no_jit_experimental
from torecsys.utils.decorator import jit_experimental, no_jit_experimental_by_namedtensor
from functools import partial
import numpy as np
import torch
Expand All @@ -12,7 +12,7 @@ class ListIndicesEmbedding(_Inputs):
r"""Base Inputs class for embedding of list of indices without order, which embed the
list by multihead attention and aggregate before return.
"""
@jit_experimental
@no_jit_experimental_by_namedtensor
def __init__(self,
embed_size : int,
field_size : int,
Expand Down
4 changes: 2 additions & 2 deletions torecsys/inputs/base/multi_indices_emb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import _Inputs
from torecsys.utils.decorator import jit_experimental
from torecsys.utils.decorator import jit_experimental, no_jit_experimental_by_namedtensor
import numpy as np
import torch
import torch.nn as nn
Expand All @@ -10,7 +10,7 @@ class MultiIndicesEmbedding(_Inputs):
r"""Base Inputs class for embedding indices in multiple fields of inputs, which is more
efficient than embedding with a number of SingleIndexEmbedding.
"""
@jit_experimental
@no_jit_experimental_by_namedtensor
def __init__(self,
embed_size : int,
field_sizes : List[int],
Expand Down
4 changes: 2 additions & 2 deletions torecsys/inputs/base/multi_indices_field_aware_emb.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import _Inputs
from torecsys.utils.decorator import jit_experimental
from torecsys.utils.decorator import jit_experimental, no_jit_experimental_by_namedtensor
import numpy as np
import torch
import torch.nn as nn
Expand All @@ -17,7 +17,7 @@ class MultiIndicesFieldAwareEmbedding(_Inputs):
#. `Yuchin Juan et al, 2016. Field-aware Factorization Machines for CTR Prediction <https://www.csie.ntu.edu.tw/~cjlin/papers/ffm.pdf>`_.
"""
@jit_experimental
@no_jit_experimental_by_namedtensor
def __init__(self,
embed_size : int,
field_sizes : List[int],
Expand Down
4 changes: 2 additions & 2 deletions torecsys/inputs/base/pretrained_image_inp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import _Inputs
from torecsys.utils.decorator import jit_experimental
from torecsys.utils.decorator import jit_experimental, no_jit_experimental_by_namedtensor
import torch
import torch.nn as nn
import torchvision
Expand All @@ -9,7 +9,7 @@ class PretrainedImageInputs(_Inputs):
r"""Base Inputs class for image, which embed by famous pretrained model in Computer
Vision.
"""
@jit_experimental
@no_jit_experimental_by_namedtensor
def __init__(self,
embed_size : int,
model_name : str,
Expand Down
1 change: 1 addition & 0 deletions torecsys/inputs/base/pretrained_text_inp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from . import _Inputs
import torch


class PretrainedTextInputs(_Inputs):
r"""Base Inputs class for text, which embed by famous pretrained model in NLP.
"""
Expand Down
15 changes: 13 additions & 2 deletions torecsys/inputs/base/sequence_indices_emb.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from . import _Inputs
from torecsys.utils.decorator import jit_experimental
from collections import namedtuple
from functools import partial
import torch
import torch.nn as nn
from torecsys.utils.decorator import jit_experimental, no_jit_experimental_by_namedtensor


class SequenceIndicesEmbedding(_Inputs):
r"""Base Inputs class for embedding of sequence of indices with order, which embed the
sequence by Recurrent Neural Network (RNN) and aggregate before return.
"""
@jit_experimental
@no_jit_experimental_by_namedtensor
def __init__(self,
embed_size : int,
field_size : int,
Expand Down Expand Up @@ -112,6 +113,16 @@ def __init__(self,
raise ValueError('output_method only allows ["avg_pooling", "max_pooling", "mean", "none", "sum"].')
self.output_method = output_method

def set_schema(self, inputs: str, lengths: str):
r"""Initialize input layer's schema of SequenceIndicesEmbedding.
Builds a ``Schema`` namedtuple with the fields ``inputs`` and ``lengths`` and binds
it to ``self.schema``.
Args:
inputs (str): Field name of the input; stored wrapped in a one-element list.
lengths (str): Field name of the lengths; stored as given, without wrapping.
"""
# create the namedtuple type holding both field-name entries
schema = namedtuple("Schema", ["inputs", "lengths"])
# `inputs` is wrapped in a list, whereas `lengths` is kept as the plain value
self.schema = schema(inputs=[inputs], lengths=lengths)

def forward(self, inputs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
r"""Forward calculation of SequenceIndicesEmbedding
Expand Down

0 comments on commit 902cee6

Please sign in to comment.