## Overview
This notebook contains the source code for generating the figures used in ```NSCMasterThesis_XinyueYao.pdf``` to demonstrate the results of a recurrent neural network. Before creating any figure, install the required libraries and run the setup cells below (imports, loading `hps.json`, and initializing the model).

In all functions:
* $\beta=0.0$ indicates an $L^p$ regularized model
* $\beta=1.0$ indicates a distance-constrained model

Relevant figures are:

* [Fig. 3.3](#acc)
* [Fig. 3.4 & A.2](#trial)
* [Fig. 3.5](#ed)
* [Fig. 3.6](#conn)
* [Fig. 3.7](#connd)
* [Fig. A.3](#wd)
* [Fig. A.4](#cdf)

### Libraries required to run the script

In [None]:
import torch
import numpy as np
from glob import glob
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import os
%matplotlib notebook
from matplotlib.ticker import (MultipleLocator)

from pathlib import Path

import networkx as nx
import collections

from sklearn.cluster import SpectralClustering
import utils
from model import ConstrainedModel
from generate_input import get_data
from scipy import stats

from matplotlib.ticker import (MultipleLocator, FormatStrFormatter, AutoMinorLocator)
import matplotlib.patches as mpatches
from scipy.optimize import curve_fit

### Load the json file and initialize the model setup

In [None]:
# Load the hyper-parameters from hps.json and build a model instance.
# Several plotting helpers below read these module-level `hps` and `model`
# objects directly, so this cell must run before them.
json_path = "./hps.json"
hps = utils.Params(json_path)
model = ConstrainedModel(hps.n_bits, hps.hidden_size, hps.n_bits, hps.n_spatial_dims, hps.norm)
%matplotlib inline

In [None]:
def load_alpha_path(path):
    """Return the paths of all ``alpha_*`` entries directly under ``path``, sorted.

    Args:
        path: Directory containing one entry per alpha value
            (e.g. ``alpha_0.001_beta_0.0``).

    Returns:
        Lexicographically sorted list of matching paths. ``glob`` yields
        entries in arbitrary filesystem order, so the result is sorted to be
        reproducible across machines and runs.

    Raises:
        ValueError: if re-globbing an entry does not yield exactly one path.
    """
    file_paths = []
    for alpha_path in sorted(glob(os.path.join(path, 'alpha_*'))):
        # Sanity check that the entry resolves to exactly one path.
        matches = glob(alpha_path)
        if len(matches) != 1:
            raise ValueError(f'File path should receive one element: {matches}')
        file_paths.append(matches[0])
    return file_paths

In [None]:
def get_wd(model, ckpt_file):
    """Load a checkpoint into ``model`` and return its flattened weights and distances.

    Args:
        model: Model exposing ``layers_weight`` and ``distances`` (parallel
            sequences of tensors). It is expected to be updated in place by
            ``utils.load_checkpoint``; it is read immediately afterwards.
        ckpt_file: Path to a checkpoint readable by ``utils.load_checkpoint``.

    Returns:
        ``(model, weight, dist)`` where ``weight`` and ``dist`` are 1-D numpy
        arrays built by flattening and concatenating each ``layers_weight`` /
        ``distances`` tensor pair in order. Callers index ``dist`` with masks
        derived from ``weight``, so both are assumed to have equal length.
    """
    # The returned checkpoint dict is not needed here, only its effect on `model`.
    utils.load_checkpoint(ckpt_file, model)
    weight = []
    dist = []
    for w, d in zip(model.layers_weight, model.distances):
        # .cpu() lets this work for models whose tensors live on a GPU
        # (trim_wt already does the same); it is a no-op for CPU tensors.
        weight.append(torch.flatten(w.data).cpu().numpy())
        dist.append(torch.flatten(d.data).cpu().numpy())
    return model, np.concatenate(weight), np.concatenate(dist)

In [None]:
def average_trial(trial, fn, alpha, beta, **kwargs):
    """Apply ``fn`` to the weights of every seed's final checkpoint and summarize.

    Checkpoints are located with the pattern
    ``./trial{trial}/seed*trained_model/alpha_{alpha}_beta_{beta}/checkpoints/last.pth``.
    A fresh model is built from ``./hps.json`` and reused for every file.

    Args:
        trial: Trial identifier used in the directory name.
        fn: Callable ``fn(weights, **kwargs)`` returning a scalar metric for
            the flattened weight vector of one checkpoint.
        alpha: Alpha value used in the directory name.
        beta: Beta value used in the directory name.
        **kwargs: Forwarded unchanged to ``fn``.

    Returns:
        ``(median, std)`` of the per-seed metric values. Despite the
        function name, the central value is the median, not the mean.
    """
    params = utils.Params("./hps.json")
    net = ConstrainedModel(params.n_bits, params.hidden_size, params.n_bits,
                           params.n_spatial_dims, params.norm)
    pattern = f'./trial{trial}/seed*trained_model/alpha_{alpha}_beta_{beta}/checkpoints/last.pth'
    # get_wd returns (model, weights, distances); only the weights are scored.
    metrics = [fn(get_wd(net, ckpt)[1], **kwargs) for ckpt in glob(pattern)]
    return np.median(metrics), np.std(metrics)

In [None]:
def ks_test(a, b, thres=0.05):
    """Run a two-sample Kolmogorov-Smirnov test on samples ``a`` and ``b``.

    Returns:
        ``(statistic, p_value, significant)`` where ``significant`` is the
        boolean ``p_value < thres``.
    """
    result = stats.ks_2samp(a, b)
    return result.statistic, result.pvalue, result.pvalue < thres

In [None]:
def trim_wt(ckpt_path, hps, threshold=1e-8):
    """Return a checkpoint's flattened weights with isolated self-loop nodes removed.

    Loads ``ckpt_path`` into a fresh ``ConstrainedModel`` and zeroes every
    entry of ``model.layers_weight`` whose magnitude is <= ``threshold``.
    It then zeroes the recurrent self-connection ``weight_hh_l0[i][i]`` of
    every hidden node ``i`` whose only nonzero weight (among its input,
    recurrent and output weights in the checkpoint) is that self-connection.

    Args:
        ckpt_path: Path to a checkpoint readable by ``utils.load_checkpoint``.
        hps: Hyper-parameters providing ``n_bits``, ``hidden_size``,
            ``n_spatial_dims`` and ``norm``.
        threshold: Magnitude at or below which a weight is set to zero.

    Returns:
        1-D numpy array: the concatenation of every flattened tensor in
        ``model.layers_weight`` after trimming.
    """
    n_bits = hps.n_bits
    # Use the configured hidden size (previously hard-coded to 64), matching
    # how the model is built everywhere else from the same hps.
    model = ConstrainedModel(hps.n_bits, hps.hidden_size, hps.n_bits, hps.n_spatial_dims, hps.norm)
    ckpt = utils.load_checkpoint(ckpt_path, model)

    # Zero out weights whose magnitude does not exceed the threshold.
    for x in model.layers_weight:
        x.data = x.data * (torch.abs(x.data) > threshold)

    state = ckpt['state_dict']
    ih = state['rnn.weight_ih_l0'].cpu().numpy()
    hh = state['rnn.weight_hh_l0'].cpu().numpy()
    oh = state['output_layer.weight'].cpu().numpy()
    # Row i = [weight_ih row i | weight_hh row i | output_layer column i].
    h_w = np.concatenate((ih, hh, oh.T), axis=1)

    n_rows, n_cols = h_w.shape
    for i in range(n_rows):
        j = i + n_bits  # column holding weight_hh[i][i]
        if j >= n_cols - n_bits:
            # Never inspect the output-weight columns.
            break
        row = h_w[i]
        # Exact "only the self-connection is nonzero" test; the former
        # sum(abs(row)) - abs(val) == 0 could also pass when other weights
        # were tiny but nonzero due to float absorption.
        if row[j] != 0 and np.count_nonzero(row) == 1:
            model.rnn.weight_hh_l0.data[i][i] = 0

    return np.concatenate([torch.flatten(layer.data).cpu().numpy() for layer in model.layers_weight])

<a id="ed"></a>

### [Plot network edge density vs. $\alpha$ values (averaged across 20 trials)](#ed)
```plot_ED_alpha``` produces the edge density vs. alphas for both types of constrained models.  

The $\alpha$ values to plot are set inside the function (`alphas`); edit them there to change the range.

In [None]:
def plot_ED_alpha(trial=hps.trial, threshold=0., print_graph=True):
    """Plot network edge density against alpha for both constrained-model types.

    For each beta in (0.0, 1.0) and each alpha in ``np.arange(0.0, 0.009, 0.001)``,
    the edge density of every seed's final checkpoint is computed through
    ``average_trial`` and drawn as median +/- standard deviation. The figure
    is saved to ``EdgeDensityAlpha_trial{trial}.png``.

    Args:
        trial: Trial identifier; defaults to ``hps.trial`` as it was when this
            function was defined.
        threshold: Weights with magnitude <= ``threshold`` count as absent.
        print_graph: Currently unused; kept for backward compatibility.
    """
    betas = [0.0, 1.0]
    labels = [f'L{hps.norm} Regularization', 'Distance Constrained']
    alphas = np.arange(0.0, 0.009, 0.001)

    def fn(h_w, threshold):
        # Edge density = nonzero_weights / total_weights, rounded to 4 decimals.
        a_sum = (np.abs(h_w) > threshold).astype(int).sum()
        return np.round(a_sum / len(h_w), 4)

    fig, ax = plt.subplots(figsize=(9, 4))
    for beta, label in zip(betas, labels):
        ed = []
        ed_std = []
        for alpha in alphas:
            # Round (as plot_acc does) so the directory name reads e.g.
            # 'alpha_0.003' even if np.arange produced floating-point noise.
            median_conn, std_conn = average_trial(trial, fn, round(alpha, 5), beta, threshold=threshold)
            ed.append(median_conn)
            ed_std.append(std_conn)
        ax.errorbar(alphas, ed, yerr=ed_std, label=f'{label}')
    ax.set_xticks(alphas)
    ax.set(xlabel=r'$\alpha$', ylabel='Edge Density', title=r'Network Edge Density vs. $\alpha$')
    ax.legend(loc=0)
    fig.savefig(f'EdgeDensityAlpha_trial{trial}.png', bbox_inches='tight')


<a id="trial"></a>

### [Plot Flip-flop task trial figures](#trial)
Function `plot_trial` generates a visual representation of the model's performance with MSE denoted.
Users need to specify the seed of the experiment and the type of constrained model (`beta`) to be plotted; one figure is saved for each $\alpha$ checkpoint found.

In [None]:
def plot_trial(seed, beta):
    """Plot inputs, outputs and targets of the first sample for each alpha checkpoint.

    For every ``alpha_*_beta_{beta}`` checkpoint of ``seed`` in trial
    ``hps.trial``, one figure is drawn with each bit's traces offset
    vertically. The title shows the stored ``accuracy`` value (labelled mse).
    The figure is saved as
    ``trial{trial}_seed{seed}_alpha_{alpha}_beta_{beta}_trialplot.png``.
    """
    pattern = f'./trial{hps.trial}/seed{seed}trained_model/alpha_*_beta_{beta}/checkpoints/last.pth'
    for ckpt_path in list(glob(pattern)):
        alpha = ckpt_path.split('alpha_')[-1].split('_beta')[0]
        ckpt = torch.load(ckpt_path, map_location='cpu')
        inputs, outputs, targets = ckpt["inputs"], ckpt["outputs"], ckpt["targets"]
        mse = ckpt["accuracy"][0]
        n_bits = hps.n_bits
        spacing = 2.5

        fig, ax = plt.subplots(figsize=(8,6))

        for bit in range(n_bits):
            offset = spacing * bit
            # Inputs as a step trace, then outputs and targets as lines on top.
            ax.step(range(inputs.shape[1]), inputs[0, :, bit].detach() + offset,
                    color="#9C3D17", label="Inputs", linewidth=2.5)
            for series, colour, name in ((outputs, "#00119E", "Outputs"),
                                         (targets, "#E8C23A", "Targets")):
                ax.plot(range(series.shape[1]), series[0, :, bit].detach() + offset,
                        color=colour, label=name, linewidth=3.5)

        ax.set_yticks([spacing * bit for bit in range(n_bits)])
        tick_names = ["Bit %d" % (n_bits - bit) for bit in range(n_bits)]
        ax.set_yticklabels(tick_names, fontweight="bold")
        ax.set_title("Trial Plot with mse = {:0.2f}".format(mse), fontweight="bold")
        ax.set_xlabel(f"Time Step (alpha_{alpha})", fontweight="bold")

        fig.savefig(f'trial{hps.trial}_seed{seed}_alpha_{alpha}_beta_{beta}_trialplot.png', bbox_inches='tight')

<a id="conn"></a>

### [Plot connectivity vs. distance for a specific alpha](#conn)
The ```plot_dist_conn_hres``` function outputs a histogram of the number of connections (excluding nodes with only self-loops) vs. connection distance, pooled over all seeds.

In [None]:
def plot_dist_conn_hres(alpha, hps, threshold=0.):
    """Histogram of connection distances for both model types at one alpha.

    For each seed pair (beta=0.0, beta=1.0) of trial ``hps.trial``, this
    collects the distances of connections whose trimmed weight (see
    ``trim_wt``) exceeds ``threshold`` in magnitude, excluding zero
    distances. They are pooled over seeds and drawn as a 50-bin histogram.
    A dashed line marks the median over seeds of each seed's mean distance.
    Distances are read through the module-level ``model``. The figure is
    saved as ``trial{trial}_alpha_{alpha}averaged_dc.png``.
    """
    reg_files = sorted(glob(f'./trial{hps.trial}/seed*trained_model/alpha_{alpha}_beta_0.0/checkpoints/last.pth'))
    dc_files = sorted(glob(f'./trial{hps.trial}/seed*trained_model/alpha_{alpha}_beta_1.0/checkpoints/last.pth'))
    pooled = ([], [])      # all connection distances, per model type
    seed_means = ([], [])  # mean connection distance of each seed, per model type

    fig, ax = plt.subplots(figsize=(10, 7))
    for pair in zip(reg_files, dc_files):
        for k, ckpt_path in enumerate(pair):
            dist = get_wd(model, ckpt_path)[2]
            trimmed = trim_wt(ckpt_path, hps)
            conn_dist = dist[np.abs(trimmed) > threshold]
            conn_dist = conn_dist[conn_dist != 0.]
            seed_means[k].append(conn_dist.mean())
            pooled[k].extend(conn_dist)
    centers = [np.median(seed_means[0]), np.median(seed_means[1])]

    styles = [(f'L{hps.norm} Regularization', 'lightblue'), ('Distance Constrained', 'orange')]
    for k, (name, color) in enumerate(styles):
        ax.hist(pooled[k], bins=50, label=f'{name}', color=color, alpha=0.8)
        ax.axvline(centers[k], linestyle='dashed', c=color)

    ax.set(title=f'Averaged Distance Distributions vs. #Connections (a={alpha})',
           ylabel='#Connections', xlabel='Distance Distribution')
    ax.set_xlim([0, 3])
    ax.legend()
    plt.show()
    fig.savefig(f'trial{hps.trial}_alpha_{alpha}averaged_dc.png', bbox_inches='tight')

<a id="connd"></a>

### [Plot the connection distribution at each range of distances](#connd)
Function `plot_dist_conn` plots, for each range of distances (bins D1–D5), the median number of connections across all experiments.
Users can optionally fit an exponential curve to the distance-constrained model's distribution (`plot_curve=True`); its $R^2$ is shown in the legend.

In [None]:
def func(x, a, b, c):
    """Exponential-decay model ``a * exp(-b * x) + c`` used with ``curve_fit``."""
    decay = np.exp(-b * x)
    return a * decay + c

def linlaw(x, a, b):
    """Straight line ``a + b * x``; a power law once both axes are log10-scaled."""
    return a + x * b


def curve_fit_log(xdata, ydata):
    """Fit data to a power law by linear least squares in log10-log10 space.

    Fitting in log space weights the residuals by relative rather than
    absolute error, so small and large values contribute comparably.

    Args:
        xdata: Strictly positive x values.
        ydata: Strictly positive y values, same length as ``xdata``.

    Returns:
        ``(popt_log, pcov_log, ydatafit_log)``: ``popt_log = [a, b]`` such that
        ``log10(y) ~ a + b * log10(x)`` (i.e. ``y ~ 10**a * x**b``), its
        covariance matrix from ``curve_fit``, and the fitted y values mapped
        back to the original scale.

    Raises:
        ValueError: if any x or y value is not strictly positive (log10 is
            undefined there).
    """
    xdata = np.asarray(xdata, dtype=float)
    ydata = np.asarray(ydata, dtype=float)
    if np.any(xdata <= 0) or np.any(ydata <= 0):
        raise ValueError('curve_fit_log requires strictly positive xdata and ydata')
    logx = np.log10(xdata)
    logy = np.log10(ydata)
    # `linlaw` was previously referenced but never defined (NameError).
    popt_log, pcov_log = curve_fit(linlaw, logx, logy)
    # Map the fitted line back to the original y scale; x needs no inverse
    # transform because the original x values are already available.
    ydatafit_log = np.power(10, linlaw(logx, *popt_log))
    return popt_log, pcov_log, ydatafit_log

In [None]:
def plot_dist_conn(alpha, hps, plot_curve=False, threshold=0.):
    """Bar plot of the median number of connections per distance bin at one alpha.

    For each seed pair (beta=0.0, beta=1.0) of trial ``hps.trial``, the
    distances of connections whose trimmed weight (see ``trim_wt``) exceeds
    ``threshold`` in magnitude are counted in five 0.5-wide bins D1..D5
    covering [0, 2.5]. Distances outside that range are not counted by
    ``np.histogram``. The per-bin median over seeds is drawn for both model
    types, with a dashed line at the median over seeds of each seed's mean
    connection distance. Distances are read through the module-level
    ``model``. The figure is saved as
    ``trial{trial}_alpha{alpha}_averaged_leveldis_conn.png``.

    Args:
        alpha: Alpha value used in the checkpoint directory name.
        hps: Hyper-parameters providing ``trial`` and ``norm`` plus the
            fields needed by ``trim_wt``.
        plot_curve: If True, fit the module-level ``func``
            (``a * exp(-b x) + c``) to the distance-constrained medians and
            draw the fit with its R^2 in the legend.
        threshold: Weights with magnitude <= ``threshold`` count as absent.
    """
    files_seed = [np.sort(list(glob(f'./trial{hps.trial}/seed*trained_model/alpha_{alpha}_beta_0.0/checkpoints/last.pth'))).tolist(),
                  np.sort(list(glob(f'./trial{hps.trial}/seed*trained_model/alpha_{alpha}_beta_1.0/checkpoints/last.pth'))).tolist()]
    files = [*zip(files_seed[0], files_seed[1])]

    width = 0.5
    bins = np.arange(0, 3., width)  # edges 0.0, 0.5, ..., 2.5 -> five bins
    x_bar = (bins[1:] - bins[:-1]) / 2 + bins[:-1]  # bin centres
    x_ticks = ['D1', 'D2', 'D3', 'D4', 'D5']
    hist_f1 = []
    hist_f2 = []
    dist_f1 = []
    dist_f2 = []
    fig, ax = plt.subplots(figsize=(10, 7))
    for f1, f2 in files:
        for f, hist_save, dist_save in zip([f1, f2], [hist_f1, hist_f2], [dist_f1, dist_f2]):
            _, _, dist = get_wd(model, f)
            h_w = trim_wt(f, hps)
            h_d = dist[np.abs(h_w) > threshold]
            dist_save.append(h_d.mean())
            # Raw per-bin counts (density=False) of distances of existing connections.
            counts, _ = np.histogram(h_d, bins=bins, density=False)
            hist_save.append(counts)
    dist_f1 = np.median(dist_f1)
    dist_f2 = np.median(dist_f2)
    hist_f1 = np.median(np.stack(hist_f1), axis=0)
    hist_f2 = np.median(np.stack(hist_f2), axis=0)
    # Label with the configured norm, consistent with the other plots
    # (was hard-coded to 'L1').
    labels = [f'L{hps.norm} Regularization', 'Distance Constrained']
    for hist, d_mean, label, color in zip([hist_f1, hist_f2], [dist_f1, dist_f2], labels, ['lightblue', 'orange']):
        ax.bar(x_bar, hist, width=width, label=f'{label}', color=color, alpha=0.8)
        ax.axvline(d_mean, linestyle='dashed', c=color)

    ax.set(title=f'Levels of distance vs. #Connections (a={alpha})',
           ylabel='#Connections', xlabel='Distance Distribution')
    ax.set_xlim([0, 2.5])
    ax.set_ylim([0, max(max(hist_f1), max(hist_f2)) + 0.1])
    ax.set_xticks(x_bar)
    ax.set_xticklabels(x_ticks)

    if plot_curve:
        # Uses the module-level exponential model `func`.
        popt, pcov = curve_fit(func, x_bar, hist_f2)
        fitted = func(x_bar, *popt)
        ss_res = np.sum((hist_f2 - fitted) ** 2)
        ss_tot = np.sum((hist_f2 - np.mean(hist_f2)) ** 2)
        # R^2 is undefined for a flat histogram (ss_tot == 0).
        r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else float('nan')
        ax.plot(x_bar, fitted, 'r--', label=r'$R^2$=%4.3f' % r_squared)
    ax.legend()
    plt.show()
    fig.savefig(f'trial{hps.trial}_alpha{alpha}_averaged_leveldis_conn.png', bbox_inches='tight')

<a id="acc"></a>

### [Plot the median accuracy of all experiments for each $\alpha$](#acc)
Function `plot_acc` plots the distribution of accuracies (reported as mean square error) for each $\alpha$. It also prints the two-sample Kolmogorov–Smirnov statistic and p-value testing whether the two types of constrained model perform differently.

In [None]:
def plot_acc(trial):
    """Violin plots of the stored ``accuracy`` value per alpha for both model types.

    For beta in (0.0, 1.0) and alpha in ``np.arange(0.0, 0.009, 0.001)``
    (rounded to 5 decimals to match directory names), this collects
    ``ckpt['accuracy'][0]`` from every seed's final checkpoint under
    ``./trial{trial}``. It draws side-by-side violins with a diamond at each
    median and, for each alpha, prints the two-sample KS statistic and
    p-value comparing the two model types. The figure is saved as
    ``trial_{trial}_acc.png``.

    Args:
        trial: Trial identifier used in the checkpoint directory names.
    """
    betas = [0.0, 1.0]
    label = [f'L{hps.norm} Regularization', 'Distance Constrained']
    colors = ['lightblue', 'orange']
    alphas = np.arange(0.0, 0.009, 0.001)
    pos_a = [0.5 + i * 1.5 for i in range(len(alphas))]
    pos_b = [a + 0.5 for a in pos_a]
    x_t = [a + 0.5 / 2 for a in pos_a]  # tick midway between each violin pair
    position = [pos_a, pos_b]
    x_ticks = [str(a) for a in alphas]

    def load_accuracies(files):
        """Return ``ckpt['accuracy'][0]`` for every checkpoint path in ``files``."""
        return [torch.load(f, map_location='cpu')['accuracy'][0] for f in files]

    fig, ax = plt.subplots(figsize=(10, 6))
    l_acc = []
    d_acc = []
    legend_entries = []
    for beta, color, l, pos, accs in zip(betas, colors, label, position, [l_acc, d_acc]):
        for alpha in alphas:
            alpha = round(alpha, 5)
            files = list(glob(f'./trial{trial}/seed*trained_model/alpha_{alpha}_beta_{beta}/checkpoints/last.pth'))
            # Previously `acc * 100`: on a Python list that repeats every sample
            # 100 times instead of scaling it. The plotted values were unchanged,
            # but the KS test saw 100x the sample size and reported spuriously
            # small p-values.
            accs.append(load_accuracies(files))
        vp = ax.violinplot(accs, pos, points=20, widths=0.3, showmeans=False, showmedians=True, showextrema=False)
        for pc in vp['bodies']:
            pc.set_facecolor(f'{color}')
            pc.set_alpha(0.4)
        patch_color = vp["bodies"][0].get_facecolor().flatten()
        legend_entries.append((mpatches.Patch(color=patch_color), f'{l}'))
        medians = np.median(accs, axis=1)
        ax.scatter(pos, medians, marker='D', color=color, s=25, zorder=3)
    ax.set_xticks(x_t)
    ax.set_xticklabels(x_ticks)

    for a, b in zip(l_acc, d_acc):
        d, p, _ = ks_test(a, b)
        print(f'D-stat is {d}, p-value is {p}')
    ax.set(xlabel=r"$\alpha$", ylabel='mean square error')
    plt.legend(*zip(*legend_entries), loc=4)
    plt.show()
    fig.savefig(f'trial_{trial}_acc.png', dpi=300, bbox_inches='tight')


<a id="wd"></a>

### [Weight Distribution vs. distances](#wd)
Users need to specify the values of alphas inside the function.

In [None]:
def plot_w_d(seed, cmap=None):
    """Hexbin plots of weight vs. distance for both model types at several alphas.

    The grid has one row per alpha in ``alphas`` and one column per model
    type (beta=0.0 left, beta=1.0 right), for the given ``seed`` of trial
    ``hps.trial``. Only nonzero weights are shown. The figure is saved as
    ``trial{trial}seed{seed}_L{norm}_distance_weight.png``.

    Args:
        seed: Seed identifier used in the checkpoint directory name.
        cmap: Colormap for the hexbins. This previously referenced the
            undefined name ``color`` (NameError); ``None`` uses Matplotlib's
            default colormap.
    """
    alphas = [0.0, 0.001, 0.002, 0.005, 0.007]
    n_c = 2
    n_r = len(alphas)
    fig, axs = plt.subplots(figsize=(8, 10), nrows=n_r, ncols=n_c)
    xlim = 3
    fig.align_ylabels()
    plt.setp(axs, xlim=(0, xlim))
    hb = None
    for i, a in enumerate(alphas):
        lp = np.sort(list(glob(f'./trial{hps.trial}/seed{seed}trained_model/alpha_{a}_beta_0.0/checkpoints/last.pth')))
        dc = np.sort(list(glob(f'./trial{hps.trial}/seed{seed}trained_model/alpha_{a}_beta_1.0/checkpoints/last.pth')))
        di, wi, dl, wl = [], [], [], []
        model = ConstrainedModel(hps.n_bits, hps.hidden_size, hps.n_bits, hps.n_spatial_dims, hps.norm)
        for f1, f2 in zip(lp, dc):
            for f, d_save, w_save in zip([f1, f2], [di, dl], [wi, wl]):
                _, weight, dist = get_wd(model, f)
                nonzero = np.abs(weight) > 0.
                d_save.append(dist[nonzero])
                w_save.append(weight[nonzero])
        if not di:
            # No checkpoint pair for this alpha: leave the row empty.
            continue
        for ax, d_list, w_list in zip(axs[i], [di, dl], [wi, wl]):
            # Concatenate across files so hexbin receives 1-D arrays.
            d = np.concatenate(d_list)
            w = np.concatenate(w_list)
            y_lim = np.max(np.abs(w))
            hb = ax.hexbin(d, w, gridsize=20, cmap=cmap, norm=matplotlib.colors.LogNorm(),
                           extent=[0, xlim, -y_lim, y_lim], clim=[1, 10])
            ax.set_ylim(-y_lim, y_lim)
        axs[i, 0].set_ylabel('weight distributions\nalpha=' + f'{a}', multialignment='center')

    if hb is not None:
        # Shared colorbar from the last hexbin drawn.
        cax = fig.add_axes([0.92, 0.1, 0.03, 0.8])
        fig.colorbar(hb, cax=cax)
    axs[0, 0].set(title=f'l{hps.norm}-regularized')
    axs[0, 1].set(title='distance-constrained')
    axs[n_r - 1, 0].set(xlabel='distance distribution')
    axs[n_r - 1, 1].set(xlabel='distance distribution')

    fig.savefig(f'trial{hps.trial}seed{seed}_L{hps.norm}_distance_weight.png', dpi=600, bbox_inches='tight')
    

<a id="cdf"></a>

### [Plot the KS test as a CDF](#cdf)
Users need to specify the values of $\alpha$ inside the function. Pass `xlims` as a `(min, max)` pair chosen so that both tails of the distribution are visible.

In [None]:
def _split_tails(x):
    """Split a sequence into the parts before and after its middle position.

    For odd length the middle element is dropped; for even length the
    sequence is cut exactly in half. Works for any length, including 0 and 2.
    """
    half = len(x) // 2
    if len(x) % 2 == 1:
        return x[:half], x[half + 1:]
    return x[:half], x[half:]


def plot_tails_w_d(seed, pnorm, xlims=None, debug=False):
    """Compare the weight distributions of short- and long-distance connections.

    For each alpha in ``alphas`` and each model type (beta=0.0 left,
    beta=1.0 right) of ``seed`` in trial ``hps.trial``, weights with
    magnitude > 1e-5 are ordered by connection distance (first checkpoint
    pair only) and split at the median position. The empirical CDFs of the
    lower and upper halves are plotted, and the two-sample KS p-value is
    printed and shown in the y-label. The figure is saved as
    ``trial{trial}seed{seed}_L{norm}_wd_stat.png``.

    Args:
        seed: Seed identifier used in the checkpoint directory name.
        pnorm: Unused; kept for backward compatibility.
        xlims: Optional ``(min, max)`` x-limits applied to every panel.
        debug: If True, print the two weight halves of each panel.
    """
    alphas = [0.0, 0.001, 0.002, 0.005, 0.007]
    n_c = 2
    n_r = len(alphas)
    fig, axs = plt.subplots(figsize=(8, 10), nrows=n_r, ncols=n_c)
    for i, a in enumerate(alphas):
        lp = np.sort(list(glob(f'./trial{hps.trial}/seed{seed}trained_model/alpha_{a}_beta_0.0/checkpoints/last.pth')))
        dc = np.sort(list(glob(f'./trial{hps.trial}/seed{seed}trained_model/alpha_{a}_beta_1.0/checkpoints/last.pth')))
        di, wi, dl, wl = [], [], [], []
        model = ConstrainedModel(hps.n_bits, hps.hidden_size, hps.n_bits, hps.n_spatial_dims, hps.norm)
        for f1, f2 in zip(lp, dc):
            for f, d_save, w_save in zip([f1, f2], [di, dl], [wi, wl]):
                _, weight, dist = get_wd(model, f)
                keep = np.abs(weight) > 1e-5  # ignore (near-)zero weights
                d_save.append(dist[keep])
                w_save.append(weight[keep])
        if not di:
            # No checkpoint pair for this alpha: leave the row empty.
            continue

        for k, (d, w) in enumerate(zip([di, dl], [wi, wl])):
            ax = axs[i, k]
            # Sort by distance (first checkpoint pair) and split at the median distance.
            order = np.argsort(d[0])
            weights = w[0][order]
            down_w, up_w = _split_tails(weights)
            if debug:
                print(down_w, up_w)
            ax.plot(sorted(down_w), np.arange(len(down_w)) / len(down_w), label='Lower tail')
            ax.plot(sorted(up_w), np.arange(len(up_w)) / len(up_w), 'r', label='Upper tail')

            _, p, _ = ks_test(down_w, up_w)
            print(p)

            if xlims:
                ax.set_xlim(xlims[0], xlims[1])

            if k == 0:
                ax.set_ylabel('Probability\nalpha=' + f'{a}\np-val: {p:.6f}', multialignment='center')
            else:
                ax.set_ylabel(f'p-val: {p:.6f}', multialignment='center')
    fig.savefig(f'trial{hps.trial}seed{seed}_L{hps.norm}_wd_stat.png', dpi=600, bbox_inches='tight')