Skip to content

Commit

Permalink
Refact: FORCE class split in two: RLS and LMS
Browse files Browse the repository at this point in the history
  • Loading branch information
nTrouvain committed May 17, 2022
1 parent cec1c2f commit a156103
Show file tree
Hide file tree
Showing 13 changed files with 709 additions and 115 deletions.
21 changes: 16 additions & 5 deletions reservoirpy/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,25 @@
NVAR - Non-linear Vector Autoregressive machine (NG-RC)
IPReservoir - Reservoir with intrinsic plasticy learning rule
Readouts
========
Offline readouts
================
.. autosummary::
:toctree: generated/
:template: autosummary/class.rst
Ridge - Layer of neurons connected through offline linear regression.
FORCE - Layer of neurons connected through online FORCE learning.
Online readouts
===============
.. autosummary::
:toctree: generated/
:template: autosummary/class.rst
LMS - Layer of neurons connected through least mean squares learning rule.
RLS - Layer of neurons connected through recursive least squares learning rule.
FORCE - Layer of neurons connected through online learning rules.
Optimized ESN
=============
Expand Down Expand Up @@ -74,17 +84,18 @@
from .activations import Identity, ReLU, Sigmoid, Softmax, Softplus, Tanh
from .concat import Concat
from .esn import ESN
from .force import FORCE
from .io import Input, Output
from .readouts import FORCE, LMS, RLS, Ridge
from .reservoirs import NVAR, IPReservoir, Reservoir
from .ridge import Ridge

__all__ = [
"Reservoir",
"Input",
"Output",
"Ridge",
"FORCE",
"LMS",
"RLS",
"Tanh",
"Softmax",
"Softplus",
Expand Down
2 changes: 0 additions & 2 deletions reservoirpy/nodes/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from ..node import Node

# from ..utils.validation import check_node_io


def concat_forward(concat: Node, data):
axis = concat.axis
Expand Down
2 changes: 1 addition & 1 deletion reservoirpy/nodes/esn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from ..utils.parallel import get_joblib_backend
from ..utils.validation import is_mapping
from .io import Input
from .readouts import Ridge
from .reservoirs import NVAR, Reservoir
from .ridge import Ridge

_LEARNING_METHODS = {"ridge": Ridge}

Expand Down
6 changes: 6 additions & 0 deletions reservoirpy/nodes/readouts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .force import FORCE
from .lms import LMS
from .ridge import Ridge
from .rls import RLS

__all__ = ["FORCE", "RLS", "LMS", "Ridge"]
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# Copyright: Xavier Hinaut (2018) <xavier.hinaut@inria.fr>
import numpy as np

from ..node import Node
from ..utils.validation import add_bias, check_vector
from ...node import Node
from ...utils.validation import add_bias, check_vector


def _initialize_readout(
Expand Down Expand Up @@ -106,3 +106,10 @@ def _split_and_save_wout(node, wo):
node.set_param("bias", bias)
else:
node.set_param("Wout", wo)


def _compute_error(node, x, y=None):
"""Error between target and prediction."""
prediction = node.state()
error = prediction - y
return error, x.T
121 changes: 27 additions & 94 deletions reservoirpy/nodes/force.py → reservoirpy/nodes/readouts/force.py
Original file line number Diff line number Diff line change
@@ -1,108 +1,31 @@
# Author: Nathan Trouvain at 16/08/2021 <nathan.trouvain@inria.fr>
# Licence: MIT License
# Copyright: Xavier Hinaut (2018) <xavier.hinaut@inria.fr>
import warnings
from functools import partial
from numbers import Number
from typing import Iterable

import numpy as np

from ..mat_gen import zeros
from ..node import Node
from .utils import (
_assemble_wout,
_initialize_readout,
_prepare_inputs_for_learning,
_split_and_save_wout,
readout_forward,
)
from ...mat_gen import zeros
from ...node import Node
from .base import readout_forward
from .lms import initialize as initialize_lms
from .lms import train as lms_like_train
from .rls import initialize as initialize_rls
from .rls import train as rls_like_train

RULES = ("lms", "rls")


def _rls_like_rule(P, r, e):
"""Recursive Least Squares learning rule."""
k = np.dot(P, r)
rPr = np.dot(r.T, k)
c = float(1.0 / (1.0 + rPr))
P = P - c * np.outer(k, k)

dw = -c * np.outer(e, k)

return dw, P


def _lms_like_rule(alpha, r, e):
"""Least Mean Squares learning rule."""
# learning rate is a generator to allow scheduling
dw = -next(alpha) * np.outer(e, r)
return dw


def _compute_error(node, x, y=None):
"""Error between target and prediction."""
prediction = node.state()
error = prediction - y
return error, x.T


def rls_like_train(node: "FORCE", x, y=None):
"""Train a readout using RLS learning rule."""
x, y = _prepare_inputs_for_learning(x, y, bias=node.input_bias, allow_reshape=True)

error, r = _compute_error(node, x, y)

P = node.P
dw, P = _rls_like_rule(P, r, error)
wo = _assemble_wout(node.Wout, node.bias, node.input_bias)
wo = wo + dw.T

_split_and_save_wout(node, wo)

node.set_param("P", P)


def lms_like_train(node: "FORCE", x, y=None):
"""Train a readout using LMS learning rule."""
x, y = _prepare_inputs_for_learning(x, y, bias=node.input_bias, allow_reshape=True)

error, r = _compute_error(node, x, y)

alpha = node._alpha_gen
dw = _lms_like_rule(alpha, r, error)
wo = _assemble_wout(node.Wout, node.bias, node.input_bias)
wo = wo + dw.T

_split_and_save_wout(node, wo)


def initialize_rls(
readout: "FORCE", x=None, y=None, init_func=None, bias_init=None, bias=None
):

_initialize_readout(readout, x, y, init_func, bias_init, bias)

if x is not None:
input_dim, alpha = readout.input_dim, readout.alpha

if readout.input_bias:
input_dim += 1

P = np.eye(input_dim) / alpha

readout.set_param("P", P)


def initialize_lms(
readout: "FORCE", x=None, y=None, init_func=None, bias_init=None, bias=None
):

_initialize_readout(readout, x, y, init_func, bias_init, bias)


class FORCE(Node):
"""Single layer of neurons learning connections through online learning rules.
Warning
-------
This class is deprecated since v0.3.4 and will be removed in future versions.
Please use :py:class:`~reservoirpy.LMS` or :py:class:`~reservoirpy.RLS` instead.
The learning rules involved are similar to Recursive Least Squares (``rls`` rule)
as described in [1]_ or Least Mean Squares (``lms`` rule, similar to Hebbian
learning) as described in [2]_.
Expand All @@ -120,7 +43,8 @@ class FORCE(Node):
:py:attr:`FORCE.hypers` **list**
================== =================================================================
``alpha`` Learning rate (:math:`\\alpha`) (:math:`1\\cdot 10^{-6}` by default).
``alpha`` Learning rate (:math:`\\alpha`) (:math:`1\\cdot 10^{-6}` by
default).
``input_bias`` If True, learn a bias term (True by default).
``rule`` One of RLS or LMS rule ("rls" by default).
================== =================================================================
Expand All @@ -136,12 +60,14 @@ class FORCE(Node):
at each timestep.
rule : {"rls", "lms"}, default to "rls"
Learning rule applied for online training.
Wout : callable or array-like of shape (units, targets), default to :py:func:`~reservoirpy.mat_gen.zeros`
Wout : callable or array-like of shape (units, targets), default to
:py:func:`~reservoirpy.mat_gen.zeros`
Output weights matrix or initializer. If a callable (like a function) is
used, then this function should accept any keywords
parameters and at least two parameters that will be used to define the shape of
the returned weight matrix.
bias : callable or array-like of shape (units, 1), default to :py:func:`~reservoirpy.mat_gen.zeros`
bias : callable or array-like of shape (units, 1), default to
:py:func:`~reservoirpy.mat_gen.zeros`
Bias weights vector or initializer. If a callable (like a function) is
used, then this function should accept any keywords
parameters and at least two parameters that will be used to define the shape of
Expand Down Expand Up @@ -175,6 +101,13 @@ def __init__(
name=None,
):

warnings.warn(
"'FORCE' is deprecated since v0.3.4 and will be removed "
"in "
"future versions. Consider using 'RLS' or 'LMS'.",
DeprecationWarning,
)

params = {"Wout": None, "bias": None}

if rule not in RULES:
Expand Down

0 comments on commit a156103

Please sign in to comment.