# Implementing dual input/output units and error masking

This module introduces two additional concepts common in cognitive neural network models. First, visible units (whose activation patterns are directly specified by the environment) should be able to encode both input from the environment *and* target values for learning. Second, we often do not want every output unit to contribute to the model loss all the time; on a given training example and timestep, we may want only a subset of output units to receive error. Neither functionality is straightforwardly implemented in the development environment; the current module illustrates how the toolbox supports both functions, and how they can be built from scratch with tools introduced earlier.

## The hub-and-spokes model as an example

The hub-and-spokes model is a model of cross-modal semantic representation and processing that has been useful for understanding healthy and disordered semantic processing (see [Rogers et al., 2004](http://concepts.psych.wisc.edu/papers/RogersETAL04_PR.pdf)). The model can take either visual or verbal inputs. Units in the visual layer locally encode visual features of objects, while those in the verbal layer locally encode propositional statements about objects (e.g. *is big*, *is furry*, *can fly*, *is a bird*, etc.). The model receives direct visual or verbal input about an object and must then generate other, unspecified visual or verbal information about the item.

Like PMSP, this model is fully continuous and recurrent. In PMSP, however, the model always takes an orthographic input and returns a phonological output. In hub-and-spokes, both the visual and the verbal layers sometimes serve as inputs (the environment directly stimulates these units) and at other times as outputs (the environment provides targets for learning on these units). Many other models in cognitive neuroscience employ architectures of this kind, which are not straightforward to implement in Keras/TensorFlow. This module illustrates one approach, using the hub-and-spokes model as an example.

The figure below shows how the model is typically depicted (top) and how the computation can be unrolled in time, using multi-input time-averaging as the activation function. The orange squares are visible layers that get direct input from the training/testing environment and also have target values specified. The blue rectangles are input layers that can be hard-clamped. Each sends a single connection with a fixed positive weight to the corresponding output unit, so when an input unit is active it essentially provides positive input to just the corresponding output unit. The output unit takes this along with other inputs and sets its activation according to the usual activation update function. Only the output units get targets.

The unfolded illustration makes it clear that target values must be applied to layers that are *not* the final layer of the model (i.e., they are not the last to compute activations within a time step). In each time step, units in the orange layers should be updated before the hidden units (white circle). Otherwise the model is similar to those developed in earlier modules. A key question is how to specify target values for layers in an RNN that are _not_ the last layer specified.
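The update order within one time step can be sketched in plain NumPy. This is only an illustration under stated assumptions: the layer sizes, the random weights in `w`, and the helpers `time_average` and `one_tick` are hypothetical, biases are omitted, and none of it is the toolbox's own implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
tau = 0.1
n_s1, n_s2, n_hub = 3, 4, 5  # toy layer sizes (hypothetical)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def time_average(prev, net, tau):
    """a_t = tau * sigmoid(net) + (1 - tau) * a_{t-1}."""
    return tau * sigmoid(net) + (1.0 - tau) * prev


def rand_w(n_in, n_out):
    """Small random weight matrix (hypothetical initialisation)."""
    return rng.normal(0.0, 0.05, size=(n_in, n_out))


w = {
    "s1->s1": rand_w(n_s1, n_s1), "h->s1": rand_w(n_hub, n_s1),
    "s2->s2": rand_w(n_s2, n_s2), "h->s2": rand_w(n_hub, n_s2),
    "s1->h": rand_w(n_s1, n_hub), "s2->h": rand_w(n_s2, n_hub),
    "h->h": rand_w(n_hub, n_hub),
}


def one_tick(clamp1, clamp2, s1, s2, h):
    # Step 1: spokes update from the clamped input plus the *previous* tick's activations.
    new_s1 = time_average(s1, clamp1 + s1 @ w["s1->s1"] + h @ w["h->s1"], tau)
    new_s2 = time_average(s2, clamp2 + s2 @ w["s2->s2"] + h @ w["h->s2"], tau)
    # Step 2: the hub updates last, reading the spokes' *current* activations.
    new_h = time_average(h, new_s1 @ w["s1->h"] + new_s2 @ w["s2->h"] + h @ w["h->h"], tau)
    return new_s1, new_s2, new_h


# Start from zero activations; clamp spoke 1 only and give spoke 2 no input.
s1, s2, h = np.zeros((1, n_s1)), np.zeros((1, n_s2)), np.zeros((1, n_hub))
s1, s2, h = one_tick(np.ones((1, n_s1)), np.zeros((1, n_s2)), s1, s2, h)
print(s1.shape, s2.shape, h.shape)  # (1, 3) (1, 4) (1, 5)
```

Because the spokes are computed before the hub, any of the three returned activations can be compared against targets, even though the spokes are not the last layer to update within the tick.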


Model summary:

Forward pass: take whatever input is available and output the hub and spoke activations at every time step.

Backward pass: similarly, take whatever targets are available as the training signal, inject loss at the corresponding units (ignoring masked values), then sum across all axes.

![model architecture](https://drive.google.com/uc?id=1mN62Dn_ugWSEYF007FPGRRmEk-QVZRWN)

### Create Hub-and-Spokes model using high-level API


In [None]:
!pip install connectionist

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

from connectionist.models import HubAndSpokes
from connectionist.losses import MaskedBinaryCrossEntropy

In [None]:
model = HubAndSpokes(
    tau = 0.1, 
    hub_name = 'semantics',
    hub_units = 64, 
    spoke_names = ['visual_descriptors', 'visual_features'],  # you can have more than 2 spokes
    spoke_units = [152, 64]  # must match the no. of spoke_names
    )

### A convenient function to get data

In [None]:
def get_data(x_names, y_names, 
             batch_size = 100, max_ticks = 10, 
             visual_descriptor_units = 152, visual_feature_units = 64, semantic_units = 64):
    """Get dummy data for hub-and-spokes model."""

    data = {
        'semantics': tf.random.uniform((batch_size, max_ticks, semantic_units)),
        'visual_descriptors': tf.random.uniform((batch_size, max_ticks, visual_descriptor_units)),
        'visual_features': tf.random.uniform((batch_size, max_ticks, visual_feature_units)),
    }

    return (
        {k: v for k, v in data.items() if k in x_names},
        {k: v for k, v in data.items() if k in y_names}
    )


x_train, y_train = get_data(x_names=['visual_descriptors', 'visual_features'], y_names=['semantics'])
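The name-based split used in `get_data` is what lets one layer act as an input, a target, or both. The sketch below isolates that logic with plain Python lists standing in for tensors; `split_by_name` and the toy values are hypothetical, and whether a particular model accepts the same layer in both roles is not shown here.

```python
def split_by_name(data, x_names, y_names):
    """Return (inputs, targets); a layer name may appear in both dicts."""
    inputs = {k: v for k, v in data.items() if k in x_names}
    targets = {k: v for k, v in data.items() if k in y_names}
    return inputs, targets


# Toy stand-ins for the (batch, ticks, units) tensors.
data = {
    "semantics": [0.1, 0.9],
    "visual_descriptors": [1.0, 0.0],
    "visual_features": [0.0, 1.0],
}

# `visual_features` is clamped as input *and* receives a target.
x, y = split_by_name(
    data,
    x_names=["visual_features"],
    y_names=["visual_features", "semantics"],
)
print(sorted(x))  # ['visual_features']
print(sorted(y))  # ['semantics', 'visual_features']
```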


In [None]:
y_pred = model(x_train)
[f"{k}: {v.shape}" for k, v in y_pred.items()]

### Building an entire Hub-and-Spokes (HNS) layer

Now that we know we can just call the `HubAndSpokes()` model at a high level, let's look at what is inside the model in detail. As before we will first consider the RNN cell (what happens in one time step), then how a cell is used to comprise an RNN layer (the loop over time steps).

Note that, for simplicity, the HNS layer example below supports only two spokes, whereas `connectionist.layers.HNSLayer` and `connectionist.models.HubAndSpokes` support any number of spokes; their input arguments are slightly different.

### Architecture

![model architecture](https://drive.google.com/uc?id=1V_poXKLZ5udEI4Eh0pgOz4zoYEyN-lbO)

### Recap on MultiInputTimeAveraging layer (MITA)

Before using MITA as a building block in the hub-and-spokes model, let's recap how it works. 

See the [MITA docs](https://jasonlo.github.io/connectionist/layers/MultiInputTimeAveraging/).

In [None]:
from typing import List, Dict, Union, Optional
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from connectionist.layers import MultiInputTimeAveraging

Using `MultiInputTimeAveraging` to create $a_t = \tau \cdot act(\sum_i x_i + b) + (1-\tau) \cdot a_{t-1}$

In [None]:
mita = MultiInputTimeAveraging(tau=0.1, average_at='after_activation', activation='sigmoid', use_bias=True)

Suppose we have three RNN units receiving net inputs from two different sources (say, two other layers). The RNN units should add the net inputs together and update their activations according to the MITA function. To see this in action, we create two tensors, each capturing the net inputs to the three units from one source. Applying `mita` to a list of those two tensors then updates the activations of the three units:

In [None]:
x1 = tf.ones((1, 3))  # net inputs for three units from one source
x2 = tf.ones((1, 3)) - 0.5  # net inputs to the same three units from another source

mita([x1, x2])

Calling the object on a list containing the two net input tensors generates an output tensor containing the updated activations of units in the `mita` object. 


The only weight this layer contains is its bias:

In [None]:
mita.weights

...and stores a vector of states (i.e., activations), which are re-used as the 'prior state' $a_{t-1}$ in future updates:

In [None]:
mita.states

Why is this the activation? Since this is the first time the units were updated, the prior state is set to zero by default. Each unit receives an input of +1 from `x1` and +0.5 from `x2`. The bias weight is 0.0 by default, so these are the only inputs, and the total net input is +1.5. Passing this through the sigmoid function gives $\sigma(1.5) \approx 0.8175$; this is the activation the unit would adopt with instantaneous updating. With time-averaging, however, the unit activation updates to the weighted average $a_t = (1 - \tau) a_{t-1} + \tau \sigma(1.5)$. As noted, the prior activation $a_{t-1}$ is 0.0, and when we created `mita` we set $\tau=0.1$. So the new activation is $\tau \sigma(1.5) \approx 0.1 \times 0.8175 = 0.08175$.
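The arithmetic can be checked by hand with the standard library alone. This is a scalar sketch of the update for a single unit, not a call into `MultiInputTimeAveraging`; the variable names are chosen here for illustration.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


tau = 0.1
prev_activation = 0.0        # prior state on the first update
net_input = 1.0 + 0.5 + 0.0  # x1 + x2 + bias

instantaneous = sigmoid(net_input)
new_activation = tau * instantaneous + (1 - tau) * prev_activation

print(f"{instantaneous:.6f}")   # 0.817574
print(f"{new_activation:.6f}")  # 0.081757
```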

By default `mita` will retain this activation vector for use as the prior state in computations for new updates. To clear these states (for instance, at the end of an RNN loop processing a full sequence, if you do not want them to persist to a new sequence), use the `mita` reset method:

In [None]:
mita.reset_states()
print(mita.states)

The following code block uses [list comprehension](https://www.w3schools.com/python/python_lists_comprehension.asp) to simulate repeatedly calling `mita` in an RNN loop, keeping the inputs static for the whole time. It also plots how the activation changes with successive updates in the loop:

In [None]:
y = np.stack([mita([x1, x2]).numpy()[0] for _ in range(30)])  
plt.plot(y)
plt.title("Activation over time.")


Notice that `mita` automatically keeps its current activation, then uses this as the "prior state" when updating its new activation. That is, the activation state persists over multiple calls, unless the reset method is called. Also, recall that the net inputs are specifying that the unit should eventually adopt an activation value of 0.8175. With the time-averaged updating function, you can see that the unit activation is approaching that limit.
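With a constant net input and a zero starting state, the time-averaged update has a simple closed form: after $n$ updates, $a_n = \sigma(net)\,(1 - (1-\tau)^n)$. The sketch below iterates the scalar update for one unit and compares it with that expression; it mirrors the settings used above ($\tau = 0.1$, net input 1.5, 30 updates) but does not call the layer itself.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


tau, net_input, n_steps = 0.1, 1.5, 30
limit = sigmoid(net_input)  # value the activation converges to

a = 0.0
trajectory = []
for _ in range(n_steps):
    a = tau * limit + (1 - tau) * a  # one time-averaged update
    trajectory.append(a)

closed_form = limit * (1 - (1 - tau) ** n_steps)
print(f"simulated: {trajectory[-1]:.4f}, closed form: {closed_form:.4f}")
```

After 30 updates the activation has covered roughly 96% of the distance to its limit, which is why the curve sits a little below the asymptote at the end of the plot.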

What if the input changes? Let's increase it to a very high number by adding 1000 to one of the input sources, then continue updating the RNN activation:

In [None]:
y = np.stack([mita([x1+1000, x2]).numpy()[0] for _ in range(30)])  
plt.plot(y)
plt.title("Activation caps at 1.")

Notice the activation begins near 0.8, as that was the last activation value computed in the cell above. Rather than continuing to flatten out, the activation instead leaps up and approaches 1.0, the upper limit of the sigmoid function.

Now let's return to the original input and continue updating the RNN activation to see how the activation ramps down:

In [None]:
y = np.stack([mita([x1, x2]).numpy()[0] for _ in range(30)])  
plt.plot(y)
plt.title("Ramping down is also slow due to the time-averaging mechanism.")

Finally, let's give the cell a highly negative input:

In [None]:
y = np.stack([mita([x1 - 1000, x2]).numpy()[0] for _ in range(30)])  
plt.plot(y)
plt.title("A highly negative input pushes the activation toward 0.")

You can see it now gets pushed toward 0.0, the lower limit of the sigmoid function. For further details, you can refer to the source code.

### Building a Spoke from MultiInputTimeAveraging

With this refresher, let's look at constructing the hub-and-spokes model, beginning with a single "spoke" (i.e., a single input/output layer).

In [None]:
class Spoke(tf.keras.layers.Layer):
    """A spoke in the hub-and-spokes model."""

    def __init__(self, tau: float, units: int) -> None:
        super().__init__()
        self.tau = tau
        # `units` could be inferred from the input, but storing it lets the spoke build a zero input when `inputs` is None.
        self.units = units
        self.time_averaging = MultiInputTimeAveraging(tau=self.tau, average_at='after_activation', activation='sigmoid', use_bias=True)

    # `cross_tick_states` is named differently from `self.time_averaging.states` (a_{t-1}) to avoid confusion:
    # it holds the net inputs arriving via the red arrows (cross-tick projections) in the figure.
    def call(self, inputs: Optional[tf.Tensor] = None, cross_tick_states: Optional[List[tf.Tensor]] = None) -> tf.Tensor:
        """Call the spoke.

        Args:
            inputs: clamped input (blue node)
            cross_tick_states: net inputs from the red arrows (cross-tick projections), a_i w_{ij}.
        """
        if inputs is None:
            inputs = tf.zeros((1, self.units))

        if cross_tick_states is None:
            # Green only
            return self.time_averaging([inputs])

        # Red, red, green
        # `inputs` and `cross_tick_states` end up merged into one list; they are separate arguments only for clarity.
        return self.time_averaging([inputs, *cross_tick_states])

    def reset_states(self) -> None:
        self.time_averaging.reset_states()


- A `Spoke` is just a thin wrapper around `MultiInputTimeAveraging`

In [None]:
spoke = Spoke(tau=0.1, units=3)

Test the spoke with no input at all

In [None]:
spoke.reset_states()
spoke(None)

In [None]:
[w for w in spoke.weights]

Test spoke with random input

In [None]:
spoke.reset_states()
x = tf.random.uniform((1, 3))
y = [spoke(x).numpy()[0] for _ in range(30)]
y = np.stack(y)
plt.plot(y)
plt.title("Activation over time from random inputs without cross-tick projection.")

Test spoke with random input and random cross-tick projection

In [None]:
spoke.reset_states()
cross_tick_states = [tf.random.uniform((1, 3)) for _ in range(2)]
y = [spoke(x, cross_tick_states).numpy()[0] for _ in range(30)]  # Using the same input as above
plt.plot(y)
plt.title("Activation over time from random inputs with cross-tick projection.")


### Hub-and-spokes cell (HNSCell)

Now that we have a spoke, we can define the computation in one time step of the hub-and-spokes model.

In [None]:
class HNSCell(tf.keras.layers.Layer):

    def __init__(self, tau: float, hub_units: int, spoke1_units: int, spoke2_units:int) -> None:
        super().__init__()
        self.tau = tau
        self.hub_units = hub_units
        self.spoke1_units = spoke1_units
        self.spoke2_units = spoke2_units


    def build(self, input_shape) -> None:

        # Hub
        self.hub = MultiInputTimeAveraging(tau=self.tau, average_at='after_activation', activation='sigmoid', use_bias=True)
        self.w_hh = self.add_weight(shape=(self.hub_units, self.hub_units), initializer="random_normal", trainable=True)  # red
        self.w_s1h = self.add_weight(shape=(self.spoke1_units, self.hub_units), initializer="random_normal", trainable=True)  # blue
        self.w_s2h = self.add_weight(shape=(self.spoke2_units, self.hub_units), initializer="random_normal", trainable=True)  # blue

        # Spoke 1
        self.spoke1 = Spoke(tau=self.tau, units=self.spoke1_units)
        self.w_s1s1 = self.add_weight(shape=(self.spoke1_units, self.spoke1_units), initializer="random_normal", trainable=True) # red
        self.w_hs1 = self.add_weight(shape=(self.hub_units, self.spoke1_units), initializer="random_normal", trainable=True)  # red
        
        # Spoke 2
        self.spoke2 = Spoke(tau=self.tau, units=self.spoke2_units)
        self.w_s2s2 = self.add_weight(shape=(self.spoke2_units, self.spoke2_units), initializer="random_normal", trainable=True)  # red
        self.w_hs2 = self.add_weight(shape=(self.hub_units, self.spoke2_units), initializer="random_normal", trainable=True)  # red

        self.built = True
        
    def call(self, inputs1=None, inputs2=None, last_act_hub = None, last_act_spoke1=None, last_act_spoke2=None) -> List[tf.Tensor]:
        """Returns a list of activations: [act_spoke1, act_spoke2, act_hub]."""

        # Calculate net inputs via red arrows (cross-tick projections)
        hs1 = None if last_act_hub is None else last_act_hub @ self.w_hs1
        s1s1 = None if last_act_spoke1 is None else last_act_spoke1 @ self.w_s1s1
        hs2 = None if last_act_hub is None else last_act_hub @ self.w_hs2
        s2s2 = None if last_act_spoke2 is None else last_act_spoke2 @ self.w_s2s2
        hh = None if last_act_hub is None else last_act_hub @ self.w_hh

        # Calculate spoke activations
        cross_tick_spoke1 = [s for s in [hs1, s1s1] if s is not None]
        act_spoke1 = self.spoke1(inputs1, cross_tick_states=cross_tick_spoke1)

        cross_tick_spoke2 = [s for s in [hs2, s2s2] if s is not None]
        act_spoke2 = self.spoke2(inputs2, cross_tick_states=cross_tick_spoke2)

        # Calculate net inputs via blue arrows (within-tick projection)
        s1h = act_spoke1 @ self.w_s1h  
        s2h = act_spoke2 @ self.w_s2h

        # calculate hub activation
        inputs_to_hub = [s for s in [s1h, s2h, hh] if s is not None]
        act_hub = self.hub(inputs_to_hub)

        return [act_spoke1, act_spoke2, act_hub]

    def reset_states(self) -> None:
        self.hub.reset_states()
        self.spoke1.reset_states()
        self.spoke2.reset_states()


Test cell

In [None]:
cell = HNSCell(tau=0.1, hub_units=5, spoke1_units=3, spoke2_units=3)

No input at all

In [None]:
cell(inputs1=None, inputs2=None)

Having inputs

In [None]:
cell(
    inputs1=tf.random.uniform((1, 3)),
    inputs2=tf.random.uniform((1, 3)),
)

Having inputs and cross-tick projection

In [None]:
cell(
    inputs1=tf.random.uniform((1, 3)),
    inputs2=tf.random.uniform((1, 3)),
    last_act_hub=tf.random.uniform((1, 5)),
    last_act_spoke1=tf.random.uniform((1, 3)),
    last_act_spoke2=tf.random.uniform((1, 3)),
)

### Unrolling the HNSCell to create an HNS layer

In [None]:
class HNSLayer(tf.keras.layers.Layer):

    def __init__(self, tau: float, hub_units: int, spoke1_units: int, spoke2_units:int) -> None:
        super().__init__()
        self.tau = tau
        self.hub_units = hub_units
        self.spoke1_units = spoke1_units
        self.spoke2_units = spoke2_units
        
    def build(self, input_shape) -> None:
        self.cell = HNSCell(tau=self.tau, hub_units=self.hub_units, spoke1_units=self.spoke1_units, spoke2_units=self.spoke2_units)
        self.built = True

    def call(self, inputs1, inputs2) -> Dict[str, tf.Tensor]:
        """Returns a dict of activations keyed by 'hub', 'spoke1' and 'spoke2', each shaped (batch_size, max_ticks, units)."""

        batch_size = inputs1.shape[0]
        max_ticks = inputs1.shape[1]

        # Initialize activations in hub and spokes
        h = tf.zeros((batch_size, self.hub_units))
        s1 = tf.zeros((batch_size, self.spoke1_units))
        s2 = tf.zeros((batch_size, self.spoke2_units))

        # Make containers for outputs
        output_spoke1 = tf.TensorArray(tf.float32, size=max_ticks)
        output_spoke2 = tf.TensorArray(tf.float32, size=max_ticks)
        output_hub = tf.TensorArray(tf.float32, size=max_ticks)

        for t in range(max_ticks):
            s1, s2, h = self.cell(
                inputs1=inputs1[:, t],
                inputs2=inputs2[:, t],
                last_act_hub=h,
                last_act_spoke1=s1,
                last_act_spoke2=s2,
            )

            output_spoke1 = output_spoke1.write(t, s1)
            output_spoke2 = output_spoke2.write(t, s2)
            output_hub = output_hub.write(t, h)

        self.cell.reset_states()

        output_spoke1 = tf.transpose(output_spoke1.stack(), [1, 0, 2])
        output_spoke2 = tf.transpose(output_spoke2.stack(), [1, 0, 2])
        output_hub = tf.transpose(output_hub.stack(), [1, 0, 2])

        return {
            "hub": output_hub,
            "spoke1": output_spoke1,
            "spoke2": output_spoke2,
        }

In [None]:
manual_model = HNSLayer(tau=0.1, hub_units=5, spoke1_units=3, spoke2_units=4)

manual_model(
    inputs1=tf.random.uniform((1, 10, 3)),
    inputs2=tf.random.uniform((1, 10, 4))
)
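The stack-then-transpose step at the end of `HNSLayer.call` is easier to see on a toy array. The NumPy sketch below uses a Python list in place of `tf.TensorArray` and fills each tick with its own index; the sizes are hypothetical.

```python
import numpy as np

batch_size, max_ticks, units = 2, 4, 3

# One (batch_size, units) array per tick, filled with the tick index for easy checking.
per_tick = [np.full((batch_size, units), float(t)) for t in range(max_ticks)]

stacked = np.stack(per_tick)                    # (max_ticks, batch_size, units)
batch_major = np.transpose(stacked, (1, 0, 2))  # (batch_size, max_ticks, units)

print(stacked.shape, batch_major.shape)  # (4, 2, 3) (2, 4, 3)
print(batch_major[0, :, 0])              # [0. 1. 2. 3.]
```

Stacking writes ticks along the first axis, so the transpose is needed to recover the `(batch_size, max_ticks, units)` layout that the targets use.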


The sections above illustrated how to build the hub-and-spokes model architecture.

## Model Training

Because the representations contain some unused filler slots, those slots should not provide a training signal to the model. We therefore need to mask those values when calculating the loss. We will use `connectionist.losses.MaskedBinaryCrossEntropy` to do that.

### Using high-level API

In [None]:
from connectionist.losses import MaskedBinaryCrossEntropy

x_train, y_train = get_data(x_names=['visual_descriptors', 'visual_features'], y_names=['semantics'])

loss_fn = MaskedBinaryCrossEntropy(mask_value=9)  
optimizer = tf.keras.optimizers.Adam(learning_rate=0.05)

model = HubAndSpokes(
    tau = 0.1, 
    hub_name = 'semantics',
    hub_units = 64, 
    spoke_names = ['visual_descriptors', 'visual_features'], 
    spoke_units = [152, 64]
    )

model.compile(optimizer=optimizer, loss=loss_fn)
history = model.fit(x_train, y_train, epochs=10, batch_size=100)

In [None]:
plt.plot(history.history['loss'])

What is inside the loss function, and what does the training loop look like if we write it manually?

### Custom loss function: Binary cross-entropy with masking

Binary cross-entropy formula

$$H_p(q) = - \frac{1}{N} \sum_{i=1}^{N} \left[ y_i \cdot \log(p(y_i)) + (1-y_i) \cdot \log(1-p(y_i)) \right]$$

where $y_i$ is the target, and $p(y_i)$ is the prediction.
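A small worked example may help before adding masking to the TensorFlow version. The NumPy sketch below applies the formula only to unmasked units, so $N$ becomes the number of units that actually carry a target. The function name `masked_bce`, the sentinel value 9, and the toy arrays are assumptions made for illustration; this is not the toolbox's implementation.

```python
import numpy as np


def masked_bce(y_true, y_pred, mask_value=None, eps=1e-7):
    """Mean binary cross-entropy per item over unmasked (tick, unit) positions."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ce = y_true * np.log(y_pred + eps) + (1 - y_true) * np.log(1 - y_pred + eps)
    if mask_value is None:
        mask = np.ones_like(ce)
    else:
        mask = (y_true != mask_value).astype(float)  # 0 where the target is the sentinel
    # Sum over ticks (axis 1) and units (axis 2), then divide by the count of unmasked entries.
    return -(mask * ce).sum(axis=(1, 2)) / (mask.sum(axis=(1, 2)) + eps)


# One item, one tick, four units; the last two units are masked with 9.
y_true = [[[1, 0, 9, 9]]]
y_pred = [[[0.5, 0.5, 0.1, 0.9]]]

loss = masked_bce(y_true, y_pred, mask_value=9)
print(loss.shape)  # (1,)
print(f"{loss[0]:.4f}")
```

Both unmasked units are predicted at 0.5, so each contributes $-\log(0.5) \approx 0.693$, and the masked predictions (0.1 and 0.9) have no effect on the result.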

In [None]:
class MaskedBinaryCrossEntropy(tf.keras.losses.Loss):
    """Compute Binary Cross-Entropy with mask.

    Args:
        y_true: target y with shape (batch_size, seq_len, feature)
        y_pred: predicted y with shape (batch_size, seq_len, feature)
        mask_value: value in y_true to be masked
    Returns:
        Loss value with shape (batch_size)
    """

    def __init__(
        self,
        mask_value: Optional[int] = None,
        name="masked_binary_crossentropy",
        reduction="none",
        **kwargs
    ) -> None:
        super().__init__(name=name, reduction=reduction, **kwargs)
        self.mask_value = mask_value

    def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
        epsilon = tf.keras.backend.epsilon()  # very small value to avoid log(0)
        cross_entropy = y_true * tf.math.log(y_pred + epsilon)
        cross_entropy = cross_entropy + (1 - y_true) * tf.math.log(1 - y_pred + epsilon)

        if self.mask_value is not None:  # explicit None check, so a mask value of 0 is not silently ignored
            mask = tf.cast(
                tf.where(y_true == self.mask_value, 0, 1), tf.float32
            )  # create mask
        else:
            mask = tf.ones_like(cross_entropy)  # All inclusive mask if value is none

        cross_entropy = mask * cross_entropy  # zero out the masked values
        cross_entropy = tf.reduce_sum(
            cross_entropy, axis=[1, 2]
        )  # sum over all units (axis 2) and time steps (axis 1)
        return -cross_entropy / (
            epsilon + tf.reduce_sum(mask, axis=[1, 2])
        )  # - (1/N) sum(y * log(p(y)) + (1-y) * log(1-p(y)))

### Optimizer and training function

In [None]:
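x_train = {
    'spoke1': tf.random.uniform((1, 2, 3)),  # clamped input to spoke1: (batch_size, max_ticks, spoke1_units)
    'spoke2': tf.random.uniform((1, 2, 4)),  # clamped input to spoke2: (batch_size, max_ticks, spoke2_units)
}

# Targets for hub activations, shape (2, 2, 5); 9 marks units that are masked out of the loss.
# Note: the targets hold two patterns while the inputs hold one, so the loss appears to rely on
# broadcasting the single prediction across both target patterns.
y_train = {
    'hub': tf.convert_to_tensor(
        [[[1, 0, 9, 9, 9], [0, 9, 9, 9, 9]], [[1, 1, 0, 0, 0], [0, 1, 1, 1, 1]]],
        dtype=tf.float32
    )
}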

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.05)
masked_cross_entropy = MaskedBinaryCrossEntropy(mask_value=9)

@tf.function
def train_step(x_train: Dict[str, tf.Tensor], y_train: Dict[str, tf.Tensor]) -> tf.Tensor:
    """Custom training loop for HNS model."""

    with tf.GradientTape() as tape:
        y_pred = manual_model(inputs1=x_train['spoke1'], inputs2=x_train['spoke2'], training=True)

        all_losses = []
        for y_name, y_target in y_train.items():  # Only inject error according to y_train (if y_train = {'spoke1': tf.Tensor}, it will only inject error in spoke1)
            sum_loss = tf.reduce_sum(masked_cross_entropy(y_true=y_target, y_pred=y_pred[y_name]))  # reduce over batch_size axis
            all_losses.append(sum_loss)

        loss_value = tf.reduce_sum(all_losses)  # Final loss value is the grand sum over every axis (output layer, batch size, time ticks, units).
        
    grads = tape.gradient(loss_value, manual_model.trainable_weights)  # compute gradients dL/dw
    optimizer.apply_gradients(zip(grads, manual_model.trainable_weights))  # update weights using stock optimizer
    return loss_value
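The loop over `y_train.items()` is what decides where error is injected: layers without an entry in the target dict contribute nothing to the loss. The NumPy sketch below isolates that idea with an unmasked cross-entropy; `bce_sum` and the toy predictions are hypothetical stand-ins, not the model's outputs.

```python
import numpy as np


def bce_sum(y_true, y_pred, eps=1e-7):
    """Plain (unmasked) binary cross-entropy, summed over every element."""
    ce = y_true * np.log(y_pred + eps) + (1 - y_true) * np.log(1 - y_pred + eps)
    return float(-np.sum(ce))


# The model produces activations for every layer...
y_pred = {
    "hub": np.full((1, 2, 5), 0.5),
    "spoke1": np.full((1, 2, 3), 0.5),
    "spoke2": np.full((1, 2, 4), 0.5),
}
# ...but only the layers named in the target dict receive error.
y_target = {"hub": np.ones((1, 2, 5))}

per_layer = {name: bce_sum(target, y_pred[name]) for name, target in y_target.items()}
total_loss = sum(per_layer.values())

print(list(per_layer))  # ['hub']
print(f"{total_loss:.3f}")
```

Adding a `'spoke1'` entry to `y_target` would add that layer's error to `total_loss` without any change to the loop.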

### Train the model manually

In [None]:
loss_history = [train_step(x_train, y_train) for _ in range(100)]
loss_history = [loss.numpy() for loss in loss_history]  # convert each loss tensor to a NumPy scalar
plt.plot(loss_history)