In [None]:
from LNN import *

from scipy.stats import lognorm

# matplotlib settings
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib import rc
plt.style.use('ggplot')
rc('font',**{'family':'serif','serif':['Computer Modern Roman']})
rc('text', usetex = True)
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Hex colour palette used as the default `colors` argument of the plotting
# helpers defined below; entries are looked up by scale index.
# NOTE(review): the "CB" prefix presumably stands for colour-blind friendly -- confirm.
CB_color_cycle = ['#377eb8', '#ff7f00', '#4daf4a',
                  '#f781bf', '#a65628', '#984ea3',
                  '#999999', '#e41a1c', '#dede00']

# Functions

In [None]:
def struct_weight_plot_nonlinear_N(N_max, ks, v=None, s=1., typ=1, colors=CB_color_cycle,
                           sigmaM=1., sigmaC=1., ax=None, linestyle='-'):
    """Plot ``FI_squared_nonlin(s)`` against population size ``N``, one curve per scale.

    For every ``k`` in ``ks`` the population sizes ``N = k, 2k, ...`` (strictly
    below ``N_max``) are scanned. For each ``N`` an ``LNN`` with
    ``nonlinearity='squared'`` is built and the value of
    ``lnn.FI_squared_nonlin(s)`` is recorded and plotted as one line.

    Parameters
    ----------
    N_max : int
        Exclusive upper bound on the population sizes scanned.
    ks : iterable of int
        Scales; one curve is drawn per entry.
    v : optional
        Accepted for interface compatibility but currently unused: every
        network is built with ``v=np.ones(N)`` because ``N`` changes from
        point to point.
    s : float
        Argument forwarded to ``FI_squared_nonlin``.
    typ : int
        ``1`` builds the weights with ``LNN.struct_weight_maker(N, k)`` and
        labels the curve ``k=<k>``; any other value uses
        ``LNN.struct_weight_maker(N, N/k)`` and labels it ``k=N/<k>``.
    colors : sequence
        Palette indexed by curve number. The index wraps around, so more
        curves than palette entries reuse colours instead of failing.
    sigmaM, sigmaC : float
        Forwarded unchanged to the ``LNN`` constructor.
    ax : matplotlib axes, optional
        Axes to draw on; a new 10x8 figure is created when omitted.
    linestyle : str
        Line style applied to every curve drawn by this call.

    Returns
    -------
    The axes that were drawn on.
    """
    # create plot
    if ax is None:
        ax = plt.figure(figsize=(10, 8)).add_subplot(111)
    n_colors = len(colors)
    # iterate over scales
    for k_idx, k in enumerate(ks):
        Ns = np.arange(k, N_max, k)
        data = np.zeros(Ns.shape)
        for N_idx, N in enumerate(Ns):
            if typ == 1:
                w = LNN.struct_weight_maker(N, k)
            else:
                # NOTE(review): true division -- under Python 3 this hands a float
                # to struct_weight_maker; confirm that is what it expects.
                w = LNN.struct_weight_maker(N, N/k)
            lnn = LNN(v=np.ones(N), w=w,
                      sigmaM=sigmaM, sigmaC=sigmaC, nonlinearity='squared')
            data[N_idx] = lnn.FI_squared_nonlin(s)
        if typ == 1:
            # typ 1 walks the palette backwards (0, -1, -2, ...); the modulo keeps
            # the same picks as plain negative indexing while never running off
            # the end of the palette.
            color = colors[(-k_idx) % n_colors]
            label = r'$k=%s$' % k
        else:
            # typ 2 walks the palette forwards, wrapping around when exhausted.
            color = colors[k_idx % n_colors]
            label = r'$k=N/%s$' % k
        ax.plot(Ns, data, label=label, linewidth=4, color=color, linestyle=linestyle)
    ax.set_facecolor('white')
    ax.set_xlabel(r'$N$', fontsize = 30)
    ax.tick_params(labelsize=20)
    lgd = ax.legend(loc=2, ncol=2, facecolor='white', prop={'size' : 15})
    lgd.get_frame().set_edgecolor('k')
    for spine in ax.spines.values():
        spine.set_edgecolor('k')
    return ax

In [None]:
def plot_fisher_nonlinear_2d(N, ratios, ks, v=None, s=1., typ=1, colors=CB_color_cycle,
                            ax=None):
    """Draw a heat map of row-normalised ``FI_squared_nonlin(s)`` over (ratio, k).

    For every ``ratio`` in ``ratios`` and every ``k`` in ``ks`` an ``LNN`` with
    ``nonlinearity='squared'``, ``sigmaC = 1`` and ``sigmaM = ratio * sigmaC``
    is built and ``lnn.FI_squared_nonlin(s)`` is evaluated. Each ratio's
    values are then divided by their maximum over ``ks``.

    The image shown is ``np.flip(fishers.T, axis=0)``: columns follow
    ``ratios`` in order, rows follow ``ks`` in reverse order (row 0 is
    ``ks[-1]``).

    Parameters
    ----------
    N : int
        Population size passed to ``LNN.struct_weight_maker`` and used for
        the default ``v``.
    ratios : array-like
        Values of ``sigmaM / sigmaC`` to scan.
    ks : array-like
        Scales to scan.
    v : array-like, optional
        Passed to the ``LNN`` constructor; defaults to ``np.ones(N)``.
    s : float
        Argument forwarded to ``FI_squared_nonlin``.
    typ : int
        ``1`` builds weights with ``LNN.struct_weight_maker(N, k)``; any
        other value uses ``LNN.struct_weight_maker(N, N/k)``.
    colors : sequence
        Unused; kept so the signature matches the other plotting helper.
    ax : matplotlib axes, optional
        Axes to draw on; a new 10x8 figure is created when omitted.

    Returns
    -------
    The image object returned by ``ax.imshow``.
    """
    # create plot
    if ax is None:
        ax = plt.figure(figsize=(10, 8)).add_subplot(111)
    if v is None:
        v = np.ones(N)
    # accept plain lists as well as arrays (``.size`` is needed below)
    ratios = np.asarray(ratios)
    ks = np.asarray(ks)
    fishers = np.zeros((ratios.size, ks.size))
    sigmaC = 1
    for ratio_idx, ratio in enumerate(ratios):
        sigmaM = ratio * sigmaC
        for k_idx, k in enumerate(ks):
            if typ == 1:
                w = LNN.struct_weight_maker(N, k)
            else:
                # NOTE(review): true division -- may yield a non-integer float when
                # k does not divide N; confirm struct_weight_maker accepts that.
                w = LNN.struct_weight_maker(N, N/k)
            # use the caller-supplied (or defaulted) v instead of a fresh np.ones(N)
            lnn = LNN(v=v, w=w,
                      sigmaM=sigmaM, sigmaC=sigmaC, nonlinearity='squared')
            fishers[ratio_idx, k_idx] = lnn.FI_squared_nonlin(s)
        # normalise each ratio's row so its best k has value 1
        fishers[ratio_idx, :] = fishers[ratio_idx, :]/np.max(fishers[ratio_idx, :])
    ax.grid(False)
    img = ax.imshow(np.flip(fishers.T, axis=0), interpolation='spline36')
    ax.tick_params(labelsize=20)
    return img

# Plot Figure 3

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(24,8))

# Panels A and B: the same pair of curve families (solid: k fixed, dashed: k = N/const)
# drawn for two values of sigmaM. Each tuple is
# (axes, sigmaM, panel title, y position of the annotation, annotation text).
panel_specs = [
    (axes[0], 1., r'\textbf{A}', 8, r'$\sigma_M=1$'),
    (axes[1], 5., r'\textbf{B}', 0.54, r'$\sigma_M=5$'),
]
for panel_ax, panel_sigmaM, panel_title, note_y, note_text in panel_specs:
    struct_weight_plot_nonlinear_N(1000, [1, 2, 3, 4], typ=1, ax=panel_ax, sigmaM=panel_sigmaM)
    struct_weight_plot_nonlinear_N(1000, [1, 2, 3, 4], typ=2, ax=panel_ax, linestyle='--',
                                   sigmaM=panel_sigmaM)
    panel_ax.set_xlim([0, 1000])
    panel_ax.set_ylabel('Fisher Information', fontsize=30)
    panel_ax.set_title(panel_title, fontsize=30)
    panel_ax.text(500, note_y, s=note_text, fontsize=30)

# Panel C: normalised Fisher information over (sigmaM, k).
ratios = np.linspace(0.1, 20, 1000)
ks = np.arange(1, 11)
img = plot_fisher_nonlinear_2d(1000, ratios, ks, s=1, typ=1, ax=axes[2])
# plot_fisher_nonlinear_2d shows np.flip(fishers.T, axis=0), so image row j holds
# ks[ks.size - 1 - j]. Derive tick positions and labels from ks so they agree
# (labelling every other k: 1, 3, 5, 7, 9 at rows 9, 7, 5, 3, 1).
labelled_k_idx = np.arange(0, ks.size, 2)
axes[2].set_yticks((ks.size - 1 - labelled_k_idx).tolist())
axes[2].set_yticklabels(ks[labelled_k_idx].tolist())
# Column indices into `ratios` closest to the labelled sigmaM values.
axes[2].set_xticks([0, 95, 246, 497, 748, 999])
axes[2].set_xticklabels([0.1, 2, 5, 10, 15, 20])
axes[2].set_aspect(ratios.size/ks.size)
axes[2].tick_params(labelsize=20)
axes[2].set_xlabel(r'$\sigma_M$', fontsize=30)
axes[2].set_ylabel(r'$k$', fontsize=30)
axes[2].set_title(r'\textbf{C}', fontsize=30)
# Attach the colorbar to panel C explicitly rather than relying on the current axes.
cb = plt.colorbar(img, ax=axes[2], fraction=0.046, pad=0.04)
cb.ax.set_ylabel(r'Normalized Fisher Information', fontsize=30)
cb.ax.tick_params(labelsize=20)
plt.tight_layout(rect=[0, 0.0, 1, 0.94])
plt.savefig('figure3.pdf')