use Sequence to replace Tuple in type hinting
TsumiNa committed Jun 1, 2021
1 parent fd3857c commit 93a5c02
Showing 1 changed file with 30 additions and 23 deletions.
53 changes: 30 additions & 23 deletions xenonpy/model/sequential.py
@@ -3,7 +3,7 @@
 # license that can be found in the LICENSE file.
 
 import math
-from typing import Union, Tuple, Callable, Any, Optional
+from typing import Union, Sequence, Callable, Any, Optional
 
 from torch import nn
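
The motivation, in short: typing.Tuple only matches actual tuples, while typing.Sequence also accepts lists and other sequence types, so hidden-layer settings no longer have to be wrapped in tuples. A quick illustrative check, not part of the commit:

from typing import Sequence

# Lists and tuples both satisfy Sequence; only a real tuple satisfies tuple.
print(isinstance([128, 64], Sequence))  # True
print(isinstance((128, 64), Sequence))  # True
print(isinstance([128, 64], tuple))     # False

The same distinction holds at runtime, which is what the _check_input change further down relies on.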

@@ -16,11 +16,14 @@ class LinearLayer(nn.Module):
     See here for details: http://pytorch.org/docs/master/nn.html#
     """
 
-    def __init__(self, in_features: int, out_features: int, bias: bool = True, *,
+    def __init__(self,
+                 in_features: int,
+                 out_features: int,
+                 bias: bool = True,
+                 *,
                  dropout: float = 0.,
                  activation_func: Callable = nn.ReLU(),
-                 normalizer: Union[float, None] = .1
-                 ):
+                 normalizer: Union[float, None] = .1):
         """
         Parameters
         ----------
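
For orientation, a minimal usage sketch of the reformatted constructor. The sizes are invented, only parameters visible in the signature above are used, and the exact role of normalizer is defined in the collapsed body:

from torch import nn
from xenonpy.model.sequential import LinearLayer

# One fully connected block: 290 -> 128, with dropout and LeakyReLU;
# normalizer keeps its default float value from the signature.
layer = LinearLayer(in_features=290,
                    out_features=128,
                    dropout=0.2,
                    activation_func=nn.LeakyReLU(),
                    normalizer=0.1)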
@@ -59,13 +62,18 @@ class SequentialLinear(nn.Module):
     """
 
-    def __init__(self, in_features: int, out_features: int, bias: bool = True, *,
-                 h_neurons: Union[Tuple[float, ...], Tuple[int, ...]] = (),
-                 h_bias: Union[bool, Tuple[bool, ...]] = True,
-                 h_dropouts: Union[float, Tuple[float, ...]] = 0.1,
-                 h_normalizers: Union[float, None, Tuple[Optional[float], ...]] = 0.1,
-                 h_activation_funcs: Union[Callable, None, Tuple[Optional[Callable], ...]] = nn.ReLU(),
-                 ):
+    def __init__(
+            self,
+            in_features: int,
+            out_features: int,
+            bias: bool = True,
+            *,
+            h_neurons: Union[Sequence[float], Sequence[int]] = (),
+            h_bias: Union[bool, Sequence[bool]] = True,
+            h_dropouts: Union[float, Sequence[float]] = 0.1,
+            h_normalizers: Union[float, None, Sequence[Optional[float]]] = 0.1,
+            h_activation_funcs: Union[Callable, None, Sequence[Optional[Callable]]] = nn.ReLU(),
+    ):
         """
         Parameters
@@ -111,25 +119,24 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True, *,
             bias = (bias,) + self._check_input(h_bias)
 
             for i in range(self._h_layers):
-                setattr(self, f'layer_{i}', LinearLayer(
-                    in_features=neurons[i],
-                    out_features=neurons[i + 1],
-                    bias=bias[i],
-                    dropout=dropouts[i],
-                    activation_func=activation_funcs[i],
-                    normalizer=normalizers[i]
-                ))
+                setattr(
+                    self, f'layer_{i}',
+                    LinearLayer(in_features=neurons[i],
+                                out_features=neurons[i + 1],
+                                bias=bias[i],
+                                dropout=dropouts[i],
+                                activation_func=activation_funcs[i],
+                                normalizer=normalizers[i]))
 
             self.output = nn.Linear(neurons[-1], out_features, bias[-1])
         else:
             self.output = nn.Linear(in_features, out_features, bias)
 
     def _check_input(self, i):
-        if isinstance(i, Tuple):
+        if isinstance(i, Sequence):
             if len(i) != self._h_layers:
-                raise RuntimeError(
-                    f'number of parameter not consistent with number of layers, '
-                    f'input is {len(i)} but need to be {self._h_layers}')
+                raise RuntimeError(f'number of parameter not consistent with number of layers, '
+                                   f'input is {len(i)} but need to be {self._h_layers}')
             return tuple(i)
         else:
             return tuple([i] * self._h_layers)
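
Restated as a standalone sketch (the function name and message below are illustrative): _check_input broadcasts a scalar to one value per hidden layer and length-checks a sequence. Note that before this commit a list fell into the scalar branch and was broadcast as a whole; with the Sequence check it is validated element-wise, as expected.

from typing import Sequence

def check_input(value, h_layers: int) -> tuple:
    # A sequence must provide exactly one value per hidden layer.
    if isinstance(value, Sequence):
        if len(value) != h_layers:
            raise RuntimeError(f'got {len(value)} values but need {h_layers}')
        return tuple(value)
    # A scalar is broadcast across all hidden layers.
    return tuple([value] * h_layers)

print(check_input(0.1, 3))              # (0.1, 0.1, 0.1)
print(check_input([0.1, 0.2, 0.3], 3))  # (0.1, 0.2, 0.3)
# check_input([0.1, 0.2], 3) raises RuntimeError

One caveat of the Sequence check: str is also a Sequence, so a string argument would now be length-checked rather than broadcast. That is harmless for the numeric and bool parameters used here.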
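The practical payoff for SequentialLinear: hidden-layer specs can now be plain lists. A hedged end-to-end sketch, assuming integer h_neurons entries are absolute layer widths (the float variant is resolved in the collapsed body):

from xenonpy.model.sequential import SequentialLinear

# Three hidden layers; a plain list now type-checks against Sequence[int].
# The scalar h_dropouts is broadcast by _check_input to (0.1, 0.1, 0.1).
model = SequentialLinear(in_features=290,
                         out_features=1,
                         h_neurons=[256, 128, 64],
                         h_dropouts=0.1)

# Hidden blocks are registered as model.layer_0 ... model.layer_2, and the
# final projection lives in model.output, per the setattr loop above.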
