In [1]:
# import tensorflow as tf
import functools
import jax.numpy as jnp
import os
import json_lines
import numpy as np
import json

import jax
from jax.experimental import optimizers
from src import data, model_utils, optim_utils, measurements
from renn.rnn import cells, unroll, network
from renn import utils
import renn

from data_processing import analysis_utils as au
from renn import analysis_utils as renn_au

import tensorflow_datasets as tfds

from sklearn.decomposition import PCA

import matplotlib.pyplot as plt
%matplotlib inline

import jetpack

# Model and data loading

In [2]:
# Map each architecture name to the results directory that holds its
# `config.json` and `final_params` (both are read inside the loop below).
folders_to_use = {'GRU': 'results/yelp/five_class_fine_L2/GRU_eta_0.01_L2_0.1_run_om1jxhhv/',
                  'LSTM': 'results/yelp/five_class_fine_L2/LSTM_eta_0.01_L2_0.26826958_run_t9xq09dl/',
                  'UGRNN': 'results/yelp/five_class_fine_L2/UGRNN_eta_0.01_L2_0.26826958_run_eeh67rsj/'}

# Filled by the loop: architecture name -> `E` as returned by renn.eigsorted.
eigenvalues = {}

# del folders_to_use['GRU']

for arch, data_folder in folders_to_use.items():
    
    # load config
    with open(os.path.join(data_folder, 'config.json')) as f:
        config = json.load(f)
        
    # load data
    # NOTE(review): train_dset is not used anywhere in the visible notebook.
    vocab_size, train_dset, test_dset = data.get_dataset(config['data'])
    
    # Rebuild the cell and network from the settings stored in the run config.
    cell = model_utils.get_cell(config['model']['cell_type'],
                            num_units=config['model']['num_units'])
    
    # NOTE(review): only emb_apply is used below; init_fun, apply_fun and
    # readout_apply are unused in the visible notebook.
    init_fun, apply_fun, emb_apply, readout_apply = network.build_rnn(vocab_size,
                                                                  config['model']['emb_size'],
                                                                  cell,
                                                                  num_outputs=config['model']['num_outputs'])
    
    # Load the saved parameters and split them into their three parts.
    network_params = model_utils.load_params(os.path.join(data_folder, 'final_params'))
    emb_params, rnn_params, readout_params = network_params
    
    # Take only the first batch yielded by the test dataset.
    test_batch = next(iter(tfds.as_numpy(test_dset)))
    
    # presumably the hidden states reached while running the RNN over
    # test_batch -- TODO confirm against analysis_utils.
    visited_states = au.rnn_states(cell, test_batch, rnn_params, emb_params, emb_apply)
    # NOTE(review): final_states is computed but never read in the visible notebook.
    final_states = au.rnn_end_states(cell, test_batch, rnn_params, emb_params, emb_apply)
    
    # Search for fixed points using the visited states as input.
    # NOTE(review): the meaning of tolerance / noise_scale / decimation_factor
    # is defined in analysis_utils, not here; loss_hist and fp_losses are
    # unused in the visible notebook.
    fixed_points_, loss_hist, fp_losses = au.fixed_points(cell,
                                                    rnn_params,
                                                    visited_states,
                                                    tolerance=5e-6,
                                                    embedding_size=config['model']['emb_size'],
                                                    noise_scale=0.4,
                                                    decimation_factor=4
                                                    )

    # Linearize around the first returned fixed point.
    # NOTE(review): this indexing fails if fixed_points_ is empty, and the
    # ordering of fixed_points_ is not visible here -- confirm index 0 is the
    # intended point.
    linearization_point = fixed_points_[0]
    
    # Evaluate cell.rec_jac at the fixed point with an all-zeros vector of
    # length emb_size as the second argument (presumably the input embedding,
    # making this the recurrent Jacobian at zero input -- TODO confirm).
    J_hh = cell.rec_jac(rnn_params, 
             jnp.zeros(config['model']['emb_size']), 
             linearization_point)

    # Only E is kept; presumably R and L are the right/left eigenvectors --
    # verify against renn.eigsorted.
    R, E, L = renn.eigsorted(J_hh)
    
    eigenvalues[arch] = E

Instructions for updating:
`tf.batch_gather` is deprecated, please use `tf.gather` with `batch_dims=-1` instead.


In [3]:
import pickle

# Write the `eigenvalues` dict (architecture name -> eigenvalue array) to disk.
# The output directory is created first because open(..., 'wb') raises
# FileNotFoundError when 'paper_figures/' is missing, e.g. on a fresh checkout.
os.makedirs('paper_figures', exist_ok=True)
with open('paper_figures/zero_d_eigenvalues.pickle', 'wb') as f:
    pickle.dump(eigenvalues, f)

# Plotting the eigenvalue spectra

In [None]:
def plot_evals(evals, ax_obj):
    """Scatter eigenvalues in the complex plane on the given axes.

    Args:
        evals: array-like of (possibly complex) eigenvalues.
        ax_obj: axes-like object exposing a ``scatter`` method.
    """
    real_parts = np.real(evals)
    imag_parts = np.imag(evals)
    # Real part on the x-axis, imaginary part on the y-axis, black dots.
    ax_obj.scatter(real_parts, imag_parts, marker='.', c='k')
    
def plot_unit_circle(ax_obj):
    """Draw the unit circle on the given axes as a thin line.

    Args:
        ax_obj: axes-like object exposing a ``plot`` method.
    """
    # 100 angles covering [0, 2*pi], endpoints included, so the curve closes.
    theta = np.linspace(0, 2 * np.pi, 100)
    circle_x = np.cos(theta)
    circle_y = np.sin(theta)
    ax_obj.plot(circle_x, circle_y, linewidth=1)

def set_figure_settings(arch, arch_ind, ax_obj):
    """Apply the common formatting to one spectrum panel.

    Args:
        arch: architecture name, shown in the panel title.
        arch_ind: index of the panel; only panel 0 receives the y-axis
            label and y tick labels.
        ax_obj: axes-like object to configure.
    """
    axis_limits = (-1, 1)
    tick_positions = [-1, -0.5, 0, 0.5, 1]
    label_size = 20

    # Square panel spanning [-1, 1] on both axes.
    ax_obj.set_aspect('equal')
    ax_obj.set_xlim(*axis_limits)
    ax_obj.set_ylim(*axis_limits)
    ax_obj.set_title(f'{arch} spectrum', fontsize=24)
    ax_obj.grid()
    ax_obj.set_xlabel('Real[eigenvalue]', fontsize=label_size)

    # Same tick positions on both axes; x tick labels on every panel.
    ax_obj.set_xticks(tick_positions)
    ax_obj.set_yticks(tick_positions)
    ax_obj.set_xticklabels(tick_positions, fontsize=label_size)

    # Only the first panel carries the y-axis label and tick labels.
    if arch_ind != 0:
        return
    ax_obj.set_ylabel('Imag[eigenvalue]', fontsize=label_size)
    ax_obj.set_yticklabels(tick_positions, fontsize=label_size)
    
# One panel per architecture, left to right, sharing the y-axis.
architectures = ['LSTM', 'GRU', 'UGRNN']
fig, ax = plt.subplots(ncols=3, figsize=(14, 5), sharey=True)
for arch_ind, arch in enumerate(architectures):
    panel = ax[arch_ind]
    # Eigenvalues first, then the unit circle on top, then axis formatting.
    plot_evals(eigenvalues[arch], panel)
    plot_unit_circle(panel)
    set_figure_settings(arch, arch_ind, panel)
