In [None]:
import sys
import matplotlib.pyplot as plt
import torch
import numpy as np
sys.path.append('..')
from topographic.utils.plotting.EI import compare_jj_lambdas
from topographic.config import SAVE_DIR

# Comparing models with different spatial penalties
Here, we compare models of a given architecture trained with a range of spatial-penalty strengths $\lambda_w$.

In [None]:
architectures = ['FNN', 'RNN', 'EFF-FNN', 'EFF-RNN', 'E/I-FNN', 'E/I-RNN', 'E/I-EFF-RNN']
# Only a subset of `architectures` is plotted here; replace with `architectures` to sweep all of them.
archs_to_plot = ['E/I-EFF-RNN']
lambda_values = [0, 0.01, 0.1, 0.5, 1.0, 10.0]

for arch in archs_to_plot:
    plt.close('all')
    models = compare_jj_lambdas(
        seed=1,
        mod_id='jjd2',
        mod_type=arch,
        wrap=False,
        lambdas=lambda_values,
        # analyses to run / plot
        do_performance=True,
        do_generic=False,
        do_domains=False,
        do_combined=True,
        do_summary=True,
        show=True,
    )
    plt.show()

# Comparing models with different architectures
Here, we compare architectures at the value of $\lambda_w$ that produced the largest domain-level topographic summary statistic, using models trained from a new random seed.

In [None]:
from topographic.utils.plotting.EI import ei_ablations
plt.close('all')

models = ei_ablations(
    seed=2,
    jj=True,
    ln=True,
    wrap=False,
    subset=True,
    # analyses to run / plot
    do_performance=True,
    do_generic=False,
    do_domains=False,
    do_combined=True,
    do_summary=True,
    show=True,
)

# Comparing models with different architectures and spatial penalties
This section plots a wider range of summary metrics than before, including wiring cost, but does not produce the per-model visualizations shown above.

Note: the first cell below may take a while to run, so at its end it caches the results for all models to disk; later runs load that cache instead of recomputing unless `overwrite_df = True`.

In [None]:
from topographic.config import SAVE_DIR, FIGS_DIR
from topographic.utils.commons import exp_times
from topographic.utils.experiments import run_wiring_cost_experiment


In [None]:
from topographic.utils.plotting.EI import compute_domain_summary_stat, compute_generic_summary_stat, get_binary_selectivity, load_for_tests
import os
from topographic.config import SAVE_DIR
import pandas as pd
import pickle

seed = 1
ln = True
wrap = False
mod_types = ['E/I-EFF-RNN', 'EFF-RNN', 'E/I-RNN', 'RNN', 'EFF-FNN', 'E/I-FNN', 'FNN']
lambdas = [0, 0.001, 0.01, 0.1, 0.5, 1.0, 10.0]
layers = ['pIT', 'cIT', 'aIT']

stat_types = [1, 2]
time = 'exp_end'        # attribute of the model's `exp_times` entry used as the analysis timestep
celltypes = ['', '_I']  # selectivity-column suffixes; '_I' is only evaluated for 'EI' cells (see below)
binary_map_smoothing = 0
neglogp_thresh = 3
overwrite_df = False    # set True to recompute even when the cached results file already exists
final_lr = 1e-05        # a model counts as finished once its scheduler's last LR equals this value
# Sparsities at which the binary wiring cost is computed. The same list creates the result
# columns AND fills them, so every column ends up with exactly one entry per model
# (pd.DataFrame(model_dict) requires all lists to have equal length).
sparsity_vals = [0.01]

model_dict = {'arch': [], 'lambda': [], 'base_fn': [], 'generic_stat': [], 'final accuracy': [], 'final top5 accuracy': [], 'jj_ff': [], 'jj_rec': []}
for stat_type in stat_types:
    model_dict[f'domain_stat_{stat_type}'] = []
for layer in layers:
    model_dict[f'{layer}_topography'] = []
for sparsity in sparsity_vals:
    for prefix in ('cost', 'rec_cost', 'ff_cost', 'acc'):
        model_dict[f'{prefix}_at_{sparsity}'] = []

wrap_tag = '_nwr-1' if not wrap else ''
ln_tag = 'noln' if not ln else ''
mod_id = 'jjd2' + ln_tag

os.makedirs(f'{SAVE_DIR}/dfs', exist_ok=True)
save_file = f'{SAVE_DIR}/dfs/full_ablations{ln_tag}_results.pkl'

# Architecture -> (cell-type tag, is recurrent). Apart from these, saved-model filenames only
# differ between recurrent and feedforward models (see `_model_template`).
_ARCH_SPECS = {
    'E/I-EFF-RNN': ('EI4', True),
    'EFF-RNN': ('SRNEFF', True),
    'E/I-RNN': ('EI5', True),
    'RNN': ('SRN', True),
    'EFF-FNN': ('SRNEFF', False),
    'E/I-FNN': ('EI5', False),
    'FNN': ('SRN', False),
}


def _model_template(mod_type):
    """Return the saved-model base-filename template for architecture `mod_type`.

    The template contains four `{}` slots, filled in order with
    (mod_id, jj_tag, wrap_tag, seed). Recurrent models use `a-0.2` and `t-v3`;
    feedforward models use `a-1.0`, an extra `_norec-1` field, and `t-ff`.

    Raises:
        ValueError: if `mod_type` is not one of the known architectures.
    """
    if mod_type not in _ARCH_SPECS:
        raise ValueError(f'unknown architecture: {mod_type!r}')
    cell, recurrent = _ARCH_SPECS[mod_type]
    alpha = '0.2' if recurrent else '1.0'
    norec = '' if recurrent else '_norec-1'
    t_tag = 'v3' if recurrent else 'ff'
    return (
        f'a-{alpha}_cIT-1_cell-{cell}_conn-full_ar-1.0_enc-resnet50_err-1_fsig-5.0_gc-10.0_i2e-1.0'
        '_id-{}_imd-112{}_lr0-0.01_me-300_ms-32_nl-relu_noi-uniform_nret-1'
        f'{norec}{{}}_optim-sgd_rv4-1_rs-{{}}_rsig-5.0_sch-plateau_sq-1_t-{t_tag}_tr-miniOFS'
    )


if not os.path.exists(save_file) or overwrite_df:

    for mod_type in mod_types:
        temp_fn = _model_template(mod_type)
        for lam in lambdas:
            jj_tag = f'_jj-{lam}' if lam else ''
            full_fn = temp_fn.format(mod_id, jj_tag, wrap_tag, seed)
            # first check it exists
            if not os.path.exists(f'{SAVE_DIR}/results/{full_fn}_losses.pkl'):
                print(f'{full_fn} does not exist')
                continue
            # Only use models that finished training, judged by the scheduler's final learning rate.
            # map_location='cpu' lets checkpoints saved from a GPU load on a CPU-only machine.
            sd = torch.load(f'{SAVE_DIR}/models/{full_fn}.pkl', map_location='cpu')
            # Tolerant comparison: an LR produced by repeated multiplicative decay may differ
            # from the literal 1e-05 in the last bits.
            if not np.isclose(sd[2]['_last_lr'][0], final_lr, rtol=1e-6, atol=0):
                print(f'{full_fn} not done')
                continue

            outputs = load_for_tests(full_fn, as_dict=True)
            timing = exp_times[outputs['opt']['timing']]
            t = getattr(timing, time)
            res = outputs['res']
            these_generic_stats = []

            for si, stat_type in enumerate(stat_types):
                these_domain_stats = []
                for layer in layers:
                    # Loop-invariant over celltype: distance matrix over the first ms**2 units, and
                    # the results rows for this layer at the analysis timestep.
                    dists = getattr(outputs['model'], layer).rec_distances[:outputs['opt'].ms**2, :outputs['opt'].ms**2]
                    layer_res = res.iloc[np.logical_and.reduce((res.t == t-1, res.layer == layer))]
                    for celltype in celltypes:
                        if 'EI' not in outputs['opt']['cell'] and celltype == '_I':
                            continue
                        selectivity_dict = {
                            domain: layer_res[f'{domain}_selectivity_neglogp{celltype}'].iloc[0]
                            for domain in ['object', 'face', 'scene']
                        }
                        domain_stat = compute_domain_summary_stat(selectivity_dict, dists, outputs['opt'].ms, celltype, stat_type=stat_type, binary_map_smoothing=binary_map_smoothing)
                        these_domain_stats.append(domain_stat)
                        activations = outputs['acts'][layer][:, t-1]  # don't select celltype units, as this is done inside helper function
                        if si == 0:
                            # the generic stat does not depend on stat_type, so compute it only on the first pass
                            generic_stat = compute_generic_summary_stat(activations, dists, outputs['opt']['ms'], celltype, plot=False)
                            these_generic_stats.append(generic_stat)
                model_dict[f'domain_stat_{stat_type}'].append(np.nanmean(these_domain_stats))

            model_dict['base_fn'].append(full_fn)
            model_dict['arch'].append(mod_type)
            model_dict['lambda'].append(lam)
            model_dict['generic_stat'].append(np.nanmean(these_generic_stats))
            for layer in layers:
                distances = getattr(outputs['model'], layer).rec_distances[:outputs['opt'].ms**2, :outputs['opt'].ms**2]
                colored_map, cmap = get_binary_selectivity(res, t, distances, layer, ms=outputs['opt'].ms, smoothing_p=binary_map_smoothing, neglogp_thresh=neglogp_thresh)
                model_dict[f'{layer}_topography'].append(colored_map)

            if os.path.exists(f'{SAVE_DIR}/results/{full_fn}_wiring_cost.pkl'):
                jj_losses = np.load(f'{SAVE_DIR}/results/{full_fn}_wiring_cost.pkl', allow_pickle=True)
                model_dict['jj_rec'].append(jj_losses['rec'])
                model_dict['jj_ff'].append(jj_losses['ff'])
            else:
                # Recompute the spatial (wiring) loss per connection type; the penalty form is
                # chosen from the last cell's modid.
                for conn_type in ['ff', 'rec']:
                    jj_loss = 0.0
                    with torch.no_grad():  # evaluation only: no autograd graph needed
                        for param, param_dists in outputs['model'].spatial_params(subcells=[conn_type]):
                            cell_modid = getattr(outputs['model'], outputs['model'].cells[-1]).cell.modid
                            if 'jjl2' in cell_modid:
                                jj_loss += torch.sum((param**2)*param_dists/(1+param**2))
                            elif 'jjd2' in cell_modid:
                                jj_loss += torch.sum((param**2)*(param_dists**2)/(1+param**2))
                            else:
                                jj_loss += torch.sum(torch.abs(param)*param_dists)
                    # float() works both for a 0-d tensor and when no params were yielded (still 0.0),
                    # whereas `.item()` raised AttributeError in the latter case.
                    model_dict[f'jj_{conn_type}'].append(float(jj_loss))

            # now compute sparse binary wiring cost
            wiring_results = run_wiring_cost_experiment(
                full_fn,
                sparsity_vals=sparsity_vals,
                analysis_times=np.unique([timing.stim_off, timing.exp_end, timing.exp_end+timing.jitter]),
                alg='l2',
                local_pruning=True,
                inputs_and_outputs=False,
                overwrite=False,
                overwrite_indiv_results=False, overwrite_indiv_exp=False,
            )
            wiring_results = pd.DataFrame(wiring_results)
            for sparsity in sparsity_vals:
                at_sparsity = wiring_results[wiring_results.sparsity == sparsity]
                model_dict[f'cost_at_{sparsity}'].append(at_sparsity.wiring_cost.mean())
                model_dict[f'rec_cost_at_{sparsity}'].append(at_sparsity.wiring_cost_rec.mean())
                model_dict[f'ff_cost_at_{sparsity}'].append(at_sparsity.wiring_cost_ff.mean())
                model_dict[f'acc_at_{sparsity}'].append(at_sparsity[at_sparsity.t == timing.exp_end].accuracy.mean())

            with open(f'{SAVE_DIR}/results/{full_fn}_losses.pkl', 'rb') as f:
                these_losses = pickle.load(f)
            model_dict['final accuracy'].append(these_losses['val']['acc'][-1])
            model_dict['final top5 accuracy'].append(these_losses['val']['top5'][-1])

    with open(save_file, 'wb') as f:
        pickle.dump(model_dict, f)
else:
    # NOTE: pickle.load executes arbitrary code; only load caches this project wrote itself.
    with open(save_file, 'rb') as f:
        model_dict = pickle.load(f)

In [None]:
import matplotlib.gridspec as gridspec
import seaborn as sns  # `sns.lineplot` is used below but `sns` was never imported in this notebook
from topographic.config import FIGS_DIR

# Also write a copy of the aggregated results into the current working directory.
with open(f'full_ablations{ln_tag}_results.pkl', 'wb') as f:
    pickle.dump(model_dict, f)

save_dir = f'{FIGS_DIR}/topographic/EI/experiments/arch_by_lambda{ln_tag}'
os.makedirs(save_dir, exist_ok=True)

# Legend / colour order shared by both panels of every figure.
HUE_ORDER = ['E/I-EFF-RNN', 'E/I-RNN', 'E/I-FNN', 'EFF-RNN', 'EFF-FNN', 'RNN', 'FNN']
# Wiring-cost-like stats: drawn with smaller markers and a log-scale y-axis.
COST_STATS = ['cost_at_0.01', 'jj_ff', 'jj_rec', 'jj']
# (column, y-axis label) for each figure.
STATS_TO_PLOT = [
    ('generic_stat', r'Generic topography ($T_g$)'),
    ('domain_stat_2', r'Domain topography ($T_d$)'),
    ('final accuracy', 'Accuracy'),
    ('acc_at_0.01', 'Accuracy at s=0.99'),
    ('cost_at_0.01', 'Average connection $d^2$ (s=0.99)'),
    ('jj', r'$\mathcal{L}_w$'),  # raw string: '\m' is not a valid escape sequence
]

# Keep only columns with one entry per model. A results cache written by an older version of the
# cell above can contain never-filled wiring-cost columns (e.g. `cost_at_0.05`), which would make
# pd.DataFrame raise on unequal list lengths.
n_models = len(model_dict['arch'])
df = pd.DataFrame({key: vals for key, vals in model_dict.items() if len(vals) == n_models})
# .copy() so the column assignments below act on an independent frame, not a view of a slice.
df = df[df['lambda'] != 0.001].copy()
df['jj'] = df['jj_rec'] + df['jj_ff']
# Mask the recurrent wiring loss of feedforward models (trained with `norec-1`). Note that `jj`
# above was computed before this masking. `.loc` avoids chained assignment, which is a silent
# no-op under pandas Copy-on-Write.
df.loc[df['arch'].str.contains('FNN', regex=False), 'jj_rec'] = np.nan

for stat, ylabel in STATS_TO_PLOT:
    # Broken x-axis: a narrow linear panel for lambda = 0 (a log axis cannot show 0) and a wide
    # log-scale panel for lambda > 0, sharing the y-axis.
    fig = plt.figure()
    gs = gridspec.GridSpec(1, 2, width_ratios=[1, 5])
    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1], sharey=ax1)
    axs = [ax1, ax2]
    line_kwargs = dict(markersize=7 if stat in COST_STATS else 10, linewidth=1)
    sns.lineplot(data=df[df['lambda'] == 0], ax=axs[0], legend=False, hue_order=HUE_ORDER,
                 x='lambda', y=stat, hue='arch', style='arch', markers=True, dashes=False,
                 **line_kwargs)
    sns.lineplot(data=df[df['lambda'] != 0], ax=axs[1], hue_order=HUE_ORDER,
                 x='lambda', y=stat, hue='arch', style='arch', markers=True, dashes=False,
                 **line_kwargs)
    axs[1].set_xscale('log')
    axs[0].set_xlim([-0.01, 0.01])
    axs[0].set_xticks([0])
    axs[0].spines['right'].set_visible(False)
    axs[0].set_xlabel('')
    axs[0].set_ylabel(ylabel)
    axs[1].spines['left'].set_visible(False)
    axs[1].set_ylabel('')
    axs[1].set_xlabel(r'$\lambda_w$')
    axs[1].legend(bbox_to_anchor=(1.05, 0.8), loc=2, borderaxespad=0.)
    if stat in COST_STATS:
        axs[1].set_yscale('log')
    plt.setp(ax2.get_yticklabels(), visible=False)

    plt.subplots_adjust(wspace=0.05)
    plt.tight_layout()
    plt.savefig(f'{save_dir}/{stat}.png', dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # release the figure; many are created in this loop