# Continuous RNN with full recurrence
Here we expand on the continuous recurrent network we developed previously and implement full recurrence between the hidden layer and the output layer. This type of structure is sometimes called an "attractor network" because of the time-varying nature of the simulated learning process and the "basins of attraction" that develop in the resulting error landscape.

The task in this model is the same as in the last module: produce a sequence of letters given a "word" as input (i.e., a one-hot representation of the word). The architecture here is based on Simulation 3 from Plaut, McClelland, Seidenberg, & Patterson (1996) in Psychological Review. Their model produced a phonological output given an orthographic input. The model we are working with here does something slightly different, though the architectures are otherwise very similar. In their paper they refer to the units that feed back into the output layer as "cleanup units" because the recurrent structure affords a "cleanup" process for the phonological outputs in their network.


In [None]:
# Install a pip package created for our course
!pip install connectionist

## Data

The data for this model is identical to the data we used in the last module. While the original PMSP model produced phonological output for an orthographic input, we will continue to simulate the production of a sequence of letters given the one-hot representation of each of 20 words.

- Input: word representations, fixed across time. 
- Output: letter representations, changing across time.
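
To make the two bullet points concrete, here is a minimal NumPy sketch of what such a dataset could look like. The three words, the number of ticks per letter, and the array names are made up for illustration; they are not taken from `ToyOP`, and the real data may be laid out differently.

```python
import numpy as np

# Hypothetical toy vocabulary (NOT the actual ToyOP contents)
toy_words = ["cat", "cab", "bat"]
toy_letters = sorted(set("".join(toy_words)))  # ['a', 'b', 'c', 't']
ticks_per_letter = 2  # assumed number of ticks each letter is held as the target
max_ticks = ticks_per_letter * 3

# Input: one-hot word code, repeated unchanged at every tick
x = np.zeros((len(toy_words), max_ticks, len(toy_words)), dtype=np.float32)
for i in range(len(toy_words)):
    x[i, :, i] = 1.0

# Output: one-hot letter code that changes as the word unfolds over time
y = np.zeros((len(toy_words), max_ticks, len(toy_letters)), dtype=np.float32)
for i, word in enumerate(toy_words):
    for pos, letter in enumerate(word):
        start = pos * ticks_per_letter
        y[i, start:start + ticks_per_letter, toy_letters.index(letter)] = 1.0

print(x.shape, y.shape)  # (words, ticks, word units) and (words, ticks, letter units)
```

Both arrays share the layout `(batch, time ticks, units)`; only the target changes along the time axis.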

In [None]:
from connectionist.data import ToyOP

# Data is identical to the last module, import from connectionist.data
data = ToyOP()

# Give a shorter name for easy access
letters = data.letters # this data can be interpreted as letters or sounds. In this example, there isn't a meaningful distinction.
words = data.words
x_train = data.x_train
y_train = data.y_train


## Model creation

We will start by calling the `PMSP` model from the `connectionist` package, but we will break it down into individual layers later so you can build it from scratch yourself. This lets you work with a higher-level wrapper around the code to get a feel for things. Note that we set the value for `tau` here in the call to the `PMSP` model; you were introduced to `tau` in the earlier module on multi-input time averaging.
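
As a reminder of what `tau` controls, below is a plain-Python sketch of the time-averaging update for a single unit whose instantaneous activation is held at a constant value. The function names `time_average` and `run` are ours and not part of `connectionist`, and the sketch leaves out the weights and the sigmoid.

```python
def time_average(prev, target, tau):
    """One time-averaging step: move a fraction tau from prev toward target."""
    return tau * target + (1 - tau) * prev


def run(tau, n_ticks, target=1.0, start=0.0):
    """Apply the update for n_ticks ticks and return the trajectory."""
    values = []
    state = start
    for _ in range(n_ticks):
        state = time_average(state, target, tau)
        values.append(state)
    return values


slow = run(tau=0.1, n_ticks=10)
fast = run(tau=0.5, n_ticks=10)
print([round(v, 3) for v in slow])
print([round(v, 3) for v in fast])
```

With a small `tau` the unit approaches its target slowly; a larger `tau` makes it respond faster.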

The `PMSP` model is a single-layer model that contains what we will call a `PMSPLayer` (i.e., `model = Sequential([PMSPLayer(...)])`). The `PMSPCell` is the building block of the `PMSPLayer`; it defines the computation for a single time step of the recurrent process.

In [None]:
import tensorflow as tf

# Instead of importing `connectionist.layers.PMSPLayer` and building the model with the Sequential API,
# we can directly import the entire PMSP model.
from connectionist.models import PMSP
model = PMSP(tau=0.1, h_units=10, p_units=9, c_units=5)

### Information flow in the model

In the model, one tick of time is required for information to propagate from one layer to the next. This is a result of the "continuous time" nature of the information flow in the network. For example, if we start at time $t$, the minimal travel distance from the orthographic input ($O$) via the hidden layer ($H$) to the phonological layer ($P$) is 2 ticks: $O$ (at $t$) -> $H$ (at $t+1$) -> $P$ (at $t+2$). The attractor cycle (the portion of the network consisting of $P$ and the cleanup layer, $C$) also requires 2 ticks (again starting at $t$): $P$ (at $t$) -> $C$ (at $t+1$) -> $P$ (at $t+2$).

![model architecture](https://drive.google.com/uc?id=1oUze7Mx-ue7QaaCe90-H-a--Ugsvl-dx)
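
The tick counts described above can be checked with a small plain-Python sketch that only tracks which layers have been reached by information from a given source. The connection list is read off the `PMSPCell` code in this module, and the function name `first_arrival` is hypothetical; no weights or activations are simulated.

```python
# Connections in the network as (source, target) pairs
connections = [
    ("O", "H"), ("P", "H"),              # inputs to the hidden layer
    ("H", "P"), ("P", "P"), ("C", "P"),  # inputs to the phonology layer
    ("P", "C"),                          # input to the cleanup layer
]


def first_arrival(source_layer, n_ticks=6):
    """Return the first tick at which information from source_layer reaches each layer."""
    arrival = {source_layer: 0}
    informed = {source_layer}
    for tick in range(1, n_ticks + 1):
        # Each layer only sees what its sources held at the previous tick
        newly_informed = {dst for src, dst in connections if src in informed}
        for layer in newly_informed - informed:
            arrival[layer] = tick
        informed = informed | newly_informed
    return arrival


print(first_arrival("O"))
print(first_arrival("P"))
```

Starting from `O`, the hidden layer is reached at tick 1 and the phonology layer at tick 2. Starting from `P`, the cleanup layer is reached at tick 1, so its feedback arrives back at `P` one tick later.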

Now, let's tease apart the structure of a `PMSPCell` so you can get a sense of how the layers are organized. As in other model APIs we've seen, the `PMSPCell` layer has a `build()` and a `call()` method, along with `reset_states()`. The layer structure within this cell is considerably more complex than in the cells we've worked with previously because the information flow is more complex. For example, the phonology layer receives inputs from the hidden layer, the cleanup layer, and itself through time.

In [None]:
from connectionist.layers import MultiInputTimeAveraging, TimeAveragedDense

class PMSPCell(tf.keras.layers.Layer):
    """RNN cell for PMSP model.
    
    See Plaut, McClelland, Seidenberg and Patterson (1996), simulation 3. 
    """
    def __init__(self, tau, h_units, p_units, c_units):
        super().__init__()
        self.tau = tau
        self.h_units = h_units
        self.p_units = p_units
        self.c_units = c_units

    def build(self, input_shape):
        # Hidden layer
        self.o2h = tf.keras.layers.Dense(self.h_units, activation=None, use_bias=False, name='o2h')  # w_oh
        self.p2h = tf.keras.layers.Dense(self.h_units, activation=None, use_bias=False, name='p2h')  # w_ph
        self.time_averaging_h = MultiInputTimeAveraging(tau=self.tau, average_at="after_activation", activation='sigmoid', name='ta_h')  # bias_h and the time averaging mechanism

        # Phonology layer
        self.h2p = tf.keras.layers.Dense(self.p_units, activation=None, use_bias=False, name='h2p')  # w_hp
        self.p2p = tf.keras.layers.Dense(self.p_units, activation=None, use_bias=False, name='p2p')  # w_pp
        self.c2p = tf.keras.layers.Dense(self.p_units, activation=None, use_bias=False, name='c2p')  # w_cp
        self.time_averaging_p = MultiInputTimeAveraging(tau=self.tau, average_at="after_activation", activation='sigmoid', name='ta_p') # bias_p and the time averaging mechanism

        # Cleanup layer
        self.p2c = TimeAveragedDense(tau=self.tau, average_at="after_activation", units=self.c_units, activation='sigmoid', name='p2c')  # w_pc, bias_c, and the time averaging mechanism
        self.built = True

    def call(self, last_o, last_h, last_p, last_c):
        # Hidden layer activation
        # h_t = tau * act(o_{t-1} @ w_oh + p_{t-1} @ w_ph + bias_h) + (1 - tau) * h_{t-1}
        oh = self.o2h(last_o)  
        ph = self.p2h(last_p)
        h = self.time_averaging_h([oh, ph])

        # Phonology layer activation
        # p_t = tau * act(h_{t-1} @ w_hp + p_{t-1} @ w_pp + c_{t-1} @ w_cp + bias_p) + (1 - tau) * p_{t-1}
        hp = self.h2p(last_h)
        pp = self.p2p(last_p)
        cp = self.c2p(last_c)
        p = self.time_averaging_p([hp, pp, cp])  
        
        # Cleanup layer activation
        # c_t = tau * act(p_{t-1} @ w_pc + bias_c) + (1 - tau) * c_{t-1}
        c = self.p2c(last_p)  

        return h, p, c

    def reset_states(self):  # TODO: This need another name?
        """Reset time averaging history."""
        self.time_averaging_p.reset_states()
        self.time_averaging_h.reset_states()
        self.p2c.reset_states()

- `PMSP` is a loop that unrolls the `PMSPCell` `n_steps` times.
- The input to `PMSP` is the orthographic representation of the word.
- The output of `PMSP` is typically the sequence of phonological representations, but in this example we use it to output the sequence of letter representations.

In [None]:
class PMSPLayer(tf.keras.layers.Layer):
    """PMSP sim 3 model.
    
    See Plaut, McClelland, Seidenberg and Patterson (1996), simulation 3. 
    """
    def __init__(self, tau, h_units, p_units, c_units) -> None:
        super().__init__()
        self.tau = tau
        self.h_units = h_units
        self.p_units = p_units
        self.c_units = c_units

    def build(self, input_shape):
        self.cell = PMSPCell(tau=self.tau, h_units=self.h_units, p_units=self.p_units, c_units=self.c_units)
        self.built = True

    def call(self, inputs):
        batch_size, max_ticks, o_units = inputs.shape

        # Initialize states
        h = tf.zeros((batch_size, self.cell.h_units))
        p = tf.zeros((batch_size, self.cell.p_units))
        c = tf.zeros((batch_size, self.cell.c_units))

        # Containers for outputs with shape (batch_size, max_ticks, units)
        outputs_h = tf.TensorArray(dtype=tf.float32, size=max_ticks)
        outputs_p = tf.TensorArray(dtype=tf.float32, size=max_ticks)
        outputs_c = tf.TensorArray(dtype=tf.float32, size=max_ticks)

        # Run RNN (Unrolling RNN Cell)
        for t in range(max_ticks):
            o = inputs[:, t]  # for next time tick
            h, p, c = self.cell(last_o=o, last_h=h, last_p=p, last_c=c)
            outputs_h = outputs_h.write(t, h)
            outputs_p = outputs_p.write(t, p)
            outputs_c = outputs_c.write(t, c)

        self.cell.reset_states()
        
        outputs_h = outputs_h.stack()
        outputs_p = outputs_p.stack()
        outputs_c = outputs_c.stack()
        
        outputs_h = tf.transpose(outputs_h, [1, 0, 2])
        outputs_p = tf.transpose(outputs_p, [1, 0, 2])
        outputs_c = tf.transpose(outputs_c, [1, 0, 2])
        return outputs_p

- Note that the source code in `connectionist.layers` is not identical to the code above; the differences are beyond the scope of this module. However, the core computational logic is the same.

## Model training

In [None]:
loss_fn = tf.keras.losses.BinaryCrossentropy()

# Using a newer optimizer for faster training speed
optimizer = tf.keras.optimizers.Adam(learning_rate=0.05)  

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(x_train, y_train, epochs=300, batch_size=10)  # batch_size must be provided in PMSP model

## Model evaluation

In [None]:
import numpy as np

def decode_prediction(y, idx=None):
    """Decode the prediction."""

    decoded = ''.join([letters[np.argmax(v)] for v in y])
    return decoded

In [None]:
y_pred = model(x_train)

for i, word in enumerate(words):
    print(f"word: {word}; pred: {decode_prediction(y_pred['phonology'][i])}")

- Notice that the first letter is the same for every word? This is because, at the first tick, P has not yet received any information from O. Information from O starts to reach P at the second time tick.
- The model may still struggle to predict the correct letter at tick 2, because the time-averaging mechanism significantly slows the arrival of new information.
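
The first point can be illustrated with hand-made numbers. In the sketch below, the letter inventory, the activation values, and the helper `decode` are all illustrative stand-ins (they are not model output); the only thing that matters is that both "words" share the same activations at the first tick.

```python
import numpy as np

toy_letters = ["a", "c", "t"]  # hypothetical letter inventory


def decode(y):
    """Pick the most active letter unit at every tick."""
    return "".join(toy_letters[np.argmax(v)] for v in y)


# Hand-made activations with shape (ticks, letter units).
# Tick 0: no word-specific input has arrived, so both words share the same row.
shared_first_tick = [0.05, 0.04, 0.04]
word_1 = np.array([shared_first_tick, [0.10, 0.30, 0.05], [0.60, 0.20, 0.10]])
word_2 = np.array([shared_first_tick, [0.10, 0.05, 0.30], [0.20, 0.10, 0.60]])

print(decode(word_1), decode(word_2))
```

Because the first row is identical, `np.argmax` selects the same letter for both words at the first tick, whatever happens at later ticks.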

### Temporal dynamics of phonology output

TODO: I think this is a good place to show more internal dynamics of the model. Discuss what to show here. e.g., hidden, cleanup. Relative input strength of PP, HP, and CP? 
TODO: Also, think about what is the best interface to show the internal dynamics. e.g., model(x, return_internal=True)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Get first prediction `cat` in proper type and shape
pred0 = y_pred['phonology'][0].numpy().squeeze()

# How the prediction changes over time
fig, ax = plt.subplots()
ax.plot(pred0)
ax.legend(letters)
ax.set_xlabel('Time ticks')
ax.set_ylabel('Unit activation')
ax.set_title('Predicted temporal dynamics of the word "cat"')

### Relative input strength to phonology

Remember that:

$p_t = \tau \times act(input_p) + (1-\tau) p_{t-1}$

$input_p = h_{t-1} \cdot w_{hp} + p_{t-1} \cdot w_{pp} + c_{t-1} \cdot w_{cp} + b_p$

We can break down the input to $p_t$ (ignoring the bias $b_p$) into 3 parts:

1. input from hidden to phonology = $h_{t-1} \cdot w_{hp}$
2. input from phonology to phonology = $p_{t-1} \cdot w_{pp}$
3. input from cleanup to phonology = $c_{t-1} \cdot w_{cp}$
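
The decomposition can be sketched in NumPy with random stand-ins for the activations and weights. The layer sizes match the `PMSP(...)` call earlier in this module, but none of the values below come from the trained model, and the variable names are ours.

```python
import numpy as np

rng = np.random.default_rng(0)
h_units, p_units, c_units = 10, 9, 5

# Random stand-ins for the previous-tick activations (one sample, no batch axis)
h_prev = rng.random(h_units)
p_prev = rng.random(p_units)
c_prev = rng.random(c_units)

# Random stand-ins for the weights and the bias
w_hp = rng.normal(size=(h_units, p_units))
w_pp = rng.normal(size=(p_units, p_units))
w_cp = rng.normal(size=(c_units, p_units))
b_p = rng.normal(size=p_units)

# The three sources of input to the phonology layer
hp = h_prev @ w_hp  # hidden to phonology
pp = p_prev @ w_pp  # phonology to phonology
cp = c_prev @ w_cp  # cleanup to phonology

input_p = hp + pp + cp + b_p

unit = 0  # inspect a single phonology unit
print(f"hp={hp[unit]:.3f}, pp={pp[unit]:.3f}, cp={cp[unit]:.3f}, bias={b_p[unit]:.3f}")
```

Each term has one value per phonology unit, so the three contributions can be compared unit by unit.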



Let's get the internal dynamics of the model

In [None]:
model(x_train, return_internals=True)

In [None]:
all_outputs = model(x_train, return_internals=True)
print(f"{all_outputs.keys() = }")

See the docs for more details on `all_outputs`.

In [None]:
print(PMSP.__doc__)

In [None]:
def plot_input_to_phonology_layer(internal_outputs, sample_idx, unit_idx):
    """Plotting the inputs to phonology in selected sample and unit."""

    hidden = internal_outputs['hp'][sample_idx, :, unit_idx].numpy()
    phonology = internal_outputs['pp'][sample_idx, :, unit_idx].numpy()
    cleanup = internal_outputs['cp'][sample_idx, :, unit_idx].numpy()

    fig, ax = plt.subplots()
    ax.plot(hidden, label='hidden to phonology')
    ax.plot(phonology, label='phonology to phonology')
    ax.plot(cleanup, label='cleanup to phonology')
    ax.set_xlabel('Time ticks')
    ax.set_ylabel('Input strength')
    ax.legend()
    ax.set_title('Relative input strength to phonology layer')



What is the main source of information that drives the model to produce the letter `c`?

In [None]:
plot_input_to_phonology_layer(all_outputs, sample_idx=words.index('cat'), unit_idx=letters.index('c'))

How about letter `a`?

In [None]:
plot_input_to_phonology_layer(all_outputs, sample_idx=words.index('cat'), unit_idx=letters.index('a'))

And the letter `t`?

In [None]:
plot_input_to_phonology_layer(all_outputs, sample_idx=words.index('cat'), unit_idx=letters.index('t'))

Perhaps mention division of labor and why it is important/interesting in general. 

- What main message we want to make... 
- Discuss other internals visualization, e.g., 2d projection of hidden layer activation over training by word
    - Which axis to reduce during the 2d projection
        - p_units? hidden_units?
        - time_axis? and how. 



## Illustrating hidden layer representation over time ticks

- tick by tick act in all word @ hidden
- illustrating the representation in the hidden layer


In [None]:
all_outputs['hidden'].shape  # Tim will help here.

Steps
1. TSNE to 2d at unit axis
2. Plot each word in TSNE space
3. Animate over timetick axis

In [None]:
!pip install scikit-learn

In [None]:
import plotly.express as px
import pandas as pd
import numpy as np
from sklearn.manifold import TSNE

### Helper functions

In [None]:
def tsne(x: tf.Tensor) -> np.ndarray:
    """Apply TSNE to output tensor on units axis."""

    batch_size, max_ticks, units = x.shape
    a = x.numpy().reshape((batch_size * max_ticks, units))

    def _apply_tsne(array: np.ndarray) -> np.ndarray:
        """Apply 2d TSNE."""
        return TSNE(n_components=2).fit_transform(array)

    return _apply_tsne(a).reshape((batch_size, max_ticks, 2))


def array2df(array: np.ndarray) -> pd.DataFrame:
    """Flatten, label, normalize and cast to dataframe."""

    df = pd.DataFrame()
    for i, word in enumerate(words):
        for t in range(array.shape[1]):  # iterate over all time ticks instead of a hard-coded 30
            case_df = pd.DataFrame(
                {
                    "word": word,
                    "timetick": t,
                    "tsne_a1": array[i, t, 0],
                    "tsne_a2": array[i, t, 1]
                },
                index = [0]
            )
            df = pd.concat([df, case_df], ignore_index=True)

    # Character labels
    df['char1'] = df.word.str[0]
    df['char2'] = df.word.str[1]
    df['char3'] = df.word.str[2]

    # Normalize
    def normalize(x:pd.Series) -> pd.Series:
        """Normalize to 0-1 range."""
        return (x-min(x))/(max(x)-min(x))

    df['tsne_a1'] = normalize(df.tsne_a1)
    df['tsne_a2'] = normalize(df.tsne_a2)

    return df



Test on phonological representation to make sure reshaping is somewhat correct? 
TODO: Need extra checking in reshape...
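
One way to check the reshaping, sketched here with NumPy only: replace TSNE with a simple row-wise stand-in (keeping the first two columns) and confirm that each `(word, tick)` pair ends up back in its original position. The sizes and variable names are made up for the check and this does not test TSNE itself.

```python
import numpy as np

batch_size, max_ticks, units = 4, 5, 3  # small made-up sizes
x = np.arange(batch_size * max_ticks * units, dtype=float)
x = x.reshape((batch_size, max_ticks, units))

# Flatten (batch, ticks) into one axis, as done before calling TSNE
flat = x.reshape((batch_size * max_ticks, units))

# Stand-in for TSNE: any row-wise map from `units` to 2 dimensions
reduced = flat[:, :2]

# Restore the (batch, ticks) axes
restored = reduced.reshape((batch_size, max_ticks, 2))

# Row i * max_ticks + t of the flat array must be word i at tick t
for i in range(batch_size):
    for t in range(max_ticks):
        assert np.array_equal(flat[i * max_ticks + t], x[i, t])
        assert np.array_equal(restored[i, t], x[i, t, :2])

print("reshape round trip preserves (word, tick) order")
```

Since TSNE returns one output row per input row in the same order, the same reshape should restore the `(batch, ticks)` axes in the `tsne` helper.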

In [None]:
# Run the tidying pipeline
p_tsne = tsne(all_outputs['phonology'])
df = array2df(p_tsne)
df

Plot the animated 2d tsne scatter 

In [None]:
px.scatter(
    df, x="tsne_a1", y="tsne_a2", 
    animation_frame="timetick", animation_group="word",
    color="char2", hover_name="word", 
    range_x=[0, 1], range_y=[0,1], 
    width=800, height=800,
    title="How predicted phonological representations cluster over time"
)

- Character 2 clustering is very good (group together between 10-20 ticks, where it should be)

Plot the same thing, but on hidden representation

In [None]:
df = array2df(tsne(all_outputs['hidden']))

px.scatter(
    df, x="tsne_a1", y="tsne_a2", 
    animation_frame="timetick", animation_group="word",
    color="char2", hover_name="word", 
    range_x=[0, 1], range_y=[0,1], 
    width=800, height=800,
    title="How hidden representations cluster over time"
)