<a href="https://colab.research.google.com/github/trendinafrica/Comp_Neuro-ML_course/blob/main/FPF_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! git clone https://github.com/mattgolub/fixed-point-finder.git
! git clone https://github.com/mattgolub/recurrent-whisperer.git

Cloning into 'fixed-point-finder'...
remote: Enumerating objects: 782, done.
remote: Counting objects: 100% (175/175), done.
remote: Compressing objects: 100% (30/30), done.
remote: Total 782 (delta 156), reused 149 (delta 145), pack-reused 607
Receiving objects: 100% (782/782), 501.80 KiB | 2.20 MiB/s, done.
Resolving deltas: 100% (474/474), done.
Cloning into 'recurrent-whisperer'...
remote: Enumerating objects: 974, done.
remote: Counting objects: 100% (43/43), done.
remote: Compressing objects: 100% (33/33), done.
remote: Total 974 (delta 19), reused 29 (delta 10), pack-reused 931
Receiving objects: 100% (974/974), 437.36 KiB | 2.06 MiB/s, done.
Resolving deltas: 100% (636/636), done.


In [None]:
! pip install numpy==1.24.3 scipy==1.10.1 scikit-learn==1.2.2 matplotlib==3.7.1 PyYAML==6.0 tensorflow==2.8.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting numpy==1.24.3
  Downloading numpy-1.24.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)
Collecting tensorflow==2.8.0
  Downloading tensorflow-2.8.0-cp310-cp310-manylinux2010_x86_64.whl (497.6 MB)
Collecting keras-preprocessing>=1.1.1 (from tensorflow==2.8.0)
  Downloading Keras_Preprocessing-1.1.2-py2.py3-none-any.whl (42 kB)
Collecting tensorboard<2.9,>=2.8 (from tensorflow==2.8.0)
  Downloading tensorboard-2.8.0-py3-none-any.whl (5.8 MB)

In [None]:
'''
run_flipflop.py
Written for Python 3.6.9 and TensorFlow 2.8.0
@ Matt Golub, October 2018
Please direct correspondence to mgolub@cs.washington.edu
'''

import sys, os
import numpy as np

In [None]:
def addpath(subdir):
    # Prepend a cloned repo directory to the import path
    sys.path.insert(0, os.path.join('/content/', subdir))
addpath('recurrent-whisperer')
addpath('fixed-point-finder')
addpath('fixed-point-finder/example')

In [None]:
from FlipFlop import FlipFlop
from FixedPointFinder import FixedPointFinder
from FixedPoints import FixedPoints
from plot_utils import plot_fps

No display found. Using non-interactive Agg backend.
No display found. Using non-interactive Agg backend.
No display found. Using non-interactive Agg backend.
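
The cell below trains an RNN on the N-bit memory (flip-flop) task. To make the task concrete, here is a minimal NumPy sketch of what such trials could look like: each bit channel receives sparse +1/-1 pulses, and the target for each channel holds the sign of that channel's most recent pulse. The helper name `make_flipflop_trials`, the forced pulse at the first time step, and the exact sampling scheme are assumptions made for illustration; this is not the code behind `model.generate_data()`.

```python
import numpy as np

def make_flipflop_trials(n_batch, n_time, n_bits, p_flip, seed=0):
    """Hypothetical helper: returns (inputs, targets), each shaped
    [n_batch, n_time, n_bits]."""
    rng = np.random.default_rng(seed)

    # Each bit channel receives a pulse at each time step with probability p_flip.
    has_pulse = rng.random((n_batch, n_time, n_bits)) < p_flip
    # Assumption: force a pulse at t = 0 so every target is defined from the start.
    has_pulse[:, 0, :] = True

    # Pulses are +1 or -1 with equal probability; no pulse means input 0.
    signs = rng.choice([-1.0, 1.0], size=has_pulse.shape)
    inputs = has_pulse * signs

    # The target holds the sign of the most recent pulse on the same channel.
    targets = np.empty_like(inputs)
    memory = inputs[:, 0, :].copy()
    for t in range(n_time):
        memory = np.where(has_pulse[:, t, :], inputs[:, t, :], memory)
        targets[:, t, :] = memory
    return inputs, targets

inputs, targets = make_flipflop_trials(n_batch=4, n_time=64, n_bits=3, p_flip=0.5)
print(inputs.shape, targets.shape)
```

With `n_bits` = 3, the target at any time step is one of 2^3 = 8 sign patterns, which is why a network that solves the task needs 8 distinct memory states.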


In [None]:
def train_FlipFlop(train_mode):
    ''' Train an RNN to solve the N-bit memory task.

        Args:
            train_mode: 1, 2, or 3.

                1.  Generate on-the-fly training data (new data for each
                    gradient step)
                2.  Provide a single, fixed set of training data.
                3.  Provide a single, fixed set of training data (as in 2) and
                    a single, fixed set of validation data.

                (see docstring to RecurrentWhisperer.train() for more detail)

        Returns:
            model: FlipFlop object.

                The trained RNN model.

            valid_predictions: dict.

                The model's predictions on a set of held-out validation trials.
    '''

    assert train_mode in [1, 2, 3], \
        ('train_mode must be 1, 2, or 3, but was %s' % str(train_mode))

    # Hyperparameters for FlipFlop
    # See FlipFlop.py for detailed descriptions.
    hps = {
            'rnn_type': 'lstm',
            'n_hidden': 16,
            'min_loss': 1e-4,
            'log_dir': './logs/',
            'do_generate_pretraining_visualizations': True,

            'data_hps': {
                'n_batch': 512,
                'n_time': 64,
                'n_bits': 3,
                'p_flip': 0.5
                },

            # Hyperparameters for AdaptiveLearningRate
            'alr_hps': {
                'initial_rate': 1.0,
                'min_rate': 1e-5
                }
            }

    model = FlipFlop(**hps)

    train_data = model.generate_data()
    valid_data = model.generate_data()

    if train_mode == 1:
        model.train()
    elif train_mode == 2:
        # This runs much faster, at the risk of overfitting
        model.train(train_data)
    elif train_mode == 3:
        # This requires some changes to hps to fully leverage validation
        model.train(train_data, valid_data)

    # Get example state trajectories from the network
    # Visualize inputs, outputs, and RNN predictions from example trials
    valid_predictions, valid_summary = model.predict(valid_data)
    model.plot_trials(valid_data, valid_predictions)

    return model, valid_predictions

def find_fixed_points(model, valid_predictions):
    ''' Find, analyze, and visualize the fixed points of the trained RNN.

    Args:
        model: FlipFlop object.

            Trained RNN model, as returned by train_FlipFlop().

        valid_predictions: dict.

            Model predictions on validation trials, as returned by
            train_FlipFlop().

    Returns:
        None.
    '''

    '''Initial states are sampled from states observed during realistic
    behavior of the network. Because a well-trained network transitions
    almost instantaneously from one stable state to another, observed network
    states spend little if any time near the unstable fixed points. In order
    to identify ALL fixed points, noise must be added to the initial states
    before handing them to the fixed point finder. In this example, the noise
    needed is rather large, which can lead to identifying fixed points well
    outside of the domain of states observed in realistic behavior of the
    network. Such fixed points can be safely ignored when interpreting the
    dynamical landscape (but they can distort visualizations).'''

    NOISE_SCALE = 0.5 # Standard deviation of noise added to initial states
    N_INITS = 1024 # The number of initial states to provide

    n_bits = model.hps.data_hps['n_bits']
    is_lstm = model.hps.rnn_type == 'lstm'

    '''Fixed point finder hyperparameters. See FixedPointFinder.py for detailed
    descriptions of available hyperparameters.'''
    fpf_hps = {}

    # Setup the fixed point finder
    fpf = FixedPointFinder(model.rnn_cell, model.session, **fpf_hps)

    # Study the system in the absence of input pulses (i.e., all inputs are 0)
    inputs = np.zeros([1,n_bits])

    '''Draw random, noise corrupted samples of those state trajectories
    to use as initial states for the fixed point optimizations.'''
    initial_states = fpf.sample_states(valid_predictions['state'],
        n_inits=N_INITS,
        noise_scale=NOISE_SCALE)

    # Run the fixed point finder
    unique_fps, all_fps = fpf.find_fixed_points(initial_states, inputs)

    # Visualize identified fixed points with overlaid RNN state trajectories
    # All visualized in the 3D PCA space fit to the example RNN states.
    fig = plot_fps(unique_fps, valid_predictions['state'],
        plot_batch_idx=list(range(30)),
        plot_start_time=10)

In [None]:
train_mode = 1

# Step 1: Train an RNN to solve the N-bit memory task
model, valid_predictions = train_FlipFlop(train_mode)

  self.rnn_cell = tf1.nn.rnn_cell.LSTMCell(n_hidden)
  self._kernel = self.add_variable(
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor



Creating run directory: ./logs/06862fd7da.
Attempting to build TF model on gpu:0

Placing CPU-only ops on cpu:0



  self._bias = self.add_variable(




Initializing new run (06862fd7da).

Trainable variables:
	RecurrentWhisperer/lstm_cell/kernel:0: (19, 64)
	RecurrentWhisperer/lstm_cell/bias:0: (64,)
	RecurrentWhisperer/W_out:0: (16, 3)
	RecurrentWhisperer/b_out:0: (3,)


Total run time time: 3.46s. 
	0.0% (986us): setup_hps
	0.1% (4.21ms): _setup_run_dir
	0.0% (1.15ms): set_random_seed
	0.0% (172us): init AdaptiveLearningRate
	0.0% (15.5us): init AdaptiveGradNormClip
	1.8% (62.8ms): _setup_records
	33.2% (1.15s): _setup_model
	49.9% (1.73s): _setup_optimizer
	0.0% (77.0us): _setup_visualizations
	5.1% (177ms): _setup_tensorboard
	2.4% (81.5ms): _setup_savers
	0.0% (1.51ms): _setup_session
	7.3% (254ms): initialize_or_restore

	Updating Tensorboard images.
Entering training loop.
Epoch 1 (step 1):
	Learning rate: 1.00e+00
	Training loss: 1.03e+00
	Improvement: nan
	Logging to: ./logs/06862fd7da
	Epoch time: 3.78s. [ prep data: 4.1% (155ms); batching: 0.0% (16.2us); train: 94.8% (3.59s); ltl: 1.1% (39.9ms); lvl: 0.0% (12.6us); visual

Instructions for updating:
Use standard file APIs to delete files with this prefix.



Stopping optimization: loss meets convergence criteria.
	Training loss: 9.56e-05
	Improvement: 1.44e-04
	Logging to: ./logs/06862fd7da
	Epoch time: 188ms. [ prep data: 73.8% (139ms); batching: 0.0% (11.2us); train: 25.0% (47.0ms); ltl: 0.3% (570us); lvl: 0.0% (8.11us); visualize: 0.1% (118us); terminate: 0.4% (824us); ]

	Saving SESO checkpoint.
	Saving .done file.

Closing training:
	Updating Tensorboard images.
	Saving SESO visualizations.
	Saving LVL summary (train).
	Saving LVL predictions (train).
	Updating Tensorboard images.
	Saving LVL visualizations.
	Saving LTL summary (train).
	Saving LTL predictions (train).
	Updating Tensorboard images.
	Saving LTL visualizations.

Total run time time: 36.8s. 
	0.0% (986us): setup_hps
	0.0% (4.21ms): _setup_run_dir
	0.0% (1.15ms): set_random_seed
	0.0% (172us): init AdaptiveLearningRate
	0.0% (15.5us): init AdaptiveGradNormClip
	0.2% (62.8ms): _setup_records
	3.1% (1.15s): _setup_model
	4.7% (1.73s): _setup_optimizer
	0.0% (77.0us): _setu
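
The trainable-variable shapes printed in the training log above follow directly from `n_bits` = 3 and `n_hidden` = 16. An LSTM cell stacks the weights of its four gates into one kernel that acts on the concatenated input and hidden state, so the kernel has `n_bits + n_hidden` rows and `4 * n_hidden` columns. The short check below reproduces those shapes and counts the parameters; the variable names are chosen for illustration only.

```python
n_bits, n_hidden = 3, 16

# One kernel for all four LSTM gates, acting on [input, hidden] concatenated.
lstm_kernel = (n_bits + n_hidden, 4 * n_hidden)
lstm_bias = (4 * n_hidden,)

# Linear readout from the hidden state to one output per bit.
w_out = (n_hidden, n_bits)
b_out = (n_bits,)

n_params = (lstm_kernel[0] * lstm_kernel[1] + lstm_bias[0]
            + w_out[0] * w_out[1] + b_out[0])
print(lstm_kernel, lstm_bias, w_out, b_out, n_params)
```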


In [None]:
# STEP 2: Find, analyze, and visualize the fixed points of the trained RNN
find_fixed_points(model, valid_predictions)


Searching for fixed points from 1024 initial states.

	Finding fixed points via joint optimization.
	Optimization complete to desired tolerance.
		526 iters
		q = 1.37e-14 +/- 2.34e-14
		dq = 1.58e-14 +/- 2.79e-13
		learning rate = 1.96e+01
		avg iter time = 3.27e-03 sec
	Identified 27 unique fixed points.
		initial_states: 0 outliers detected (of 1024).
		fixed points: 0 outliers detected (of 27).
	Computing recurrent Jacobian at 27 unique fixed points.
	Computing input Jacobian at 27 unique fixed points.
	Decomposing Jacobians in a single batch.
	Sorting by Eigenvalue magnitude.
	Fixed point finding complete.
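
The log above reports a quantity `q` driven to roughly 1e-14, followed by Jacobian eigenvalue decompositions. The sketch below shows the same kind of procedure on a toy system, under stated assumptions: the state update is x_{t+1} = tanh(W x_t) with a hand-picked `W`, the objective is taken to be q(x) = 0.5 * ||F(x) - x||^2, and plain gradient descent stands in for the library's optimizer. None of this is the FixedPointFinder code; it only illustrates the technique.

```python
import numpy as np

W = 2.0 * np.eye(2)  # toy recurrent weights (assumption, not the trained model)

def step(X):
    """One update of the toy RNN, x_{t+1} = tanh(W x_t), applied row-wise."""
    return np.tanh(X @ W.T)

def q_and_grad(X):
    """q = 0.5 * ||F(x) - x||^2 per row, and its gradient with respect to x."""
    F = step(X)
    R = F - X
    q = 0.5 * np.sum(R**2, axis=1)
    # dq/dx = (J - I)^T r, where J = diag(1 - F^2) W is the Jacobian of F.
    grad = ((1.0 - F**2) * R) @ W - R
    return q, grad

# Minimize q jointly from many random initial states.
rng = np.random.default_rng(0)
X = rng.uniform(-1.5, 1.5, size=(256, 2))
for _ in range(2000):
    _, grad = q_and_grad(X)
    X -= 0.2 * grad

# Keep only candidates where q is numerically zero, then merge duplicates.
q, _ = q_and_grad(X)
converged = X[q < 1e-12]
unique_fps = []
for x in converged:
    if not any(np.linalg.norm(x - u) < 1e-3 for u in unique_fps):
        unique_fps.append(x)

def jacobian(x):
    """Jacobian of the update map at state x."""
    return np.diag(1.0 - np.tanh(W @ x)**2) @ W

# In discrete time, a fixed point is stable if every eigenvalue of the
# Jacobian has magnitude below 1.
max_eig = [np.max(np.abs(np.linalg.eigvals(jacobian(x)))) for x in unique_fps]
n_stable = sum(m < 1.0 for m in max_eig)
print(len(unique_fps), "unique fixed points,", n_stable, "stable")
```

For this toy system each coordinate has three fixed points (two stable, one unstable at 0), giving 3^2 = 9 in total, of which 2^2 = 4 are stable. The 27 unique fixed points found above for the 3-bit network are consistent with the analogous count 3^3, with 2^3 = 8 expected to be the stable memory states, although the log itself does not report stability and that reading is an interpretation.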

