Skip to content

Commit

Permalink
Merge branch 'master' into pyg_hetero_dict_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jan 20, 2022
2 parents 2775502 + 3108d10 commit 2b09ff9
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions torch_geometric/nn/dense/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class Linear(torch.nn.Module):
(:obj:`"zeros"` or :obj:`None`).
If set to :obj:`None`, will match default bias initialization of
:class:`torch.nn.Linear`. (default: :obj:`None`)
Shapes:
- **input:** features :math:`(*, F_{in})`
- **output:** features :math:`(*, F_{out})`
"""
def __init__(self, in_channels: int, out_channels: int, bias: bool = True,
weight_initializer: Optional[str] = None,
Expand Down Expand Up @@ -105,7 +109,10 @@ def reset_parameters(self):
f"'{self.bias_initializer}' is not supported")

def forward(self, x: Tensor) -> Tensor:
""""""
r"""
Args:
x (Tensor): The features.
"""
return F.linear(x, self.weight, self.bias)

@torch.no_grad()
Expand Down Expand Up @@ -167,6 +174,12 @@ class HeteroLinear(torch.nn.Module):
num_types (int): The number of types.
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.Linear`.
Shapes:
- **input:**
features :math:`(*, F_{in})`,
type vector :math:`(*)`
- **output:** features :math:`(*, F_{out})`
"""
def __init__(self, in_channels: int, out_channels: int, num_types: int,
**kwargs):
Expand All @@ -187,7 +200,11 @@ def reset_parameters(self):
lin.reset_parameters()

def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
""""""
r"""
Args:
x (Tensor): The input features.
type_vec (LongTensor): A vector that maps each entry to a type.
"""
out = x.new_empty(x.size(0), self.out_channels)
for i, lin in enumerate(self.lins):
mask = type_vec == i
Expand Down

0 comments on commit 2b09ff9

Please sign in to comment.