From 3108d10ed9ced905abe053ba93b1329fee7e7e31 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Thu, 20 Jan 2022 10:22:39 +0100 Subject: [PATCH] Documentation: `Linear` (#3893) * update linear doc * update linear doc * fix lint * fix lint --- torch_geometric/nn/dense/linear.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/dense/linear.py b/torch_geometric/nn/dense/linear.py index cbf02d4fbb9b..4fdaef0092bf 100644 --- a/torch_geometric/nn/dense/linear.py +++ b/torch_geometric/nn/dense/linear.py @@ -37,6 +37,10 @@ class Linear(torch.nn.Module): (:obj:`"zeros"` or :obj:`None`). If set to :obj:`None`, will match default bias initialization of :class:`torch.nn.Linear`. (default: :obj:`None`) + + Shapes: + - **input:** features :math:`(*, F_{in})` + - **output:** features :math:`(*, F_{out})` """ def __init__(self, in_channels: int, out_channels: int, bias: bool = True, weight_initializer: Optional[str] = None, @@ -105,7 +109,10 @@ def reset_parameters(self): f"'{self.bias_initializer}' is not supported") def forward(self, x: Tensor) -> Tensor: - """""" + r""" + Args: + x (Tensor): The features. + """ return F.linear(x, self.weight, self.bias) @torch.no_grad() @@ -167,6 +174,12 @@ class HeteroLinear(torch.nn.Module): num_types (int): The number of types. **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.Linear`. + + Shapes: + - **input:** + features :math:`(*, F_{in})`, + type vector :math:`(*)` + - **output:** features :math:`(*, F_{out})` """ def __init__(self, in_channels: int, out_channels: int, num_types: int, **kwargs): @@ -187,7 +200,11 @@ def reset_parameters(self): lin.reset_parameters() def forward(self, x: Tensor, type_vec: Tensor) -> Tensor: - """""" + r""" + Args: + x (Tensor): The input features. + type_vec (LongTensor): A vector that maps each entry to a type. + """ out = x.new_empty(x.size(0), self.out_channels) for i, lin in enumerate(self.lins): mask = type_vec == i