# Geometry of abstraction - DNN for MNIST recognition and information structure

The geometry of abstraction in hippocampus and prefrontal cortex <br>
Silvia Bernardi, Marcus K Benna, Mattia Rigotti, Jérôme Munuera, Stefano Fusi, Daniel Salzman <br>
bioRxiv 2018

In [None]:
import pdb

import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
%matplotlib inline

from keras.models import Sequential
from keras.layers import Dense, Activation

from time import time

In [None]:
from keras.datasets import mnist

In [None]:
import os, sys
# Put the sibling `methods` directory at the front of the import path.
# Presumably this is where the project-local `models` and `data_tools`
# modules imported in the next cell live — verify against the repo layout.
lib_path = os.path.abspath('../methods')
sys.path.insert(0, lib_path)

In [None]:
from models import Model
from data_tools import ImageDataset
import data_tools as dt

In [None]:
# Load the raw MNIST train/test splits through Keras.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Labels 0..7; passed as `filt_labels`, presumably to keep only those eight
# digit classes — confirm against ImageDataset.
filt_labels = range(8)
# NOTE(review): the meaning of `spl=0.08` is not visible here (sampling
# fraction?) — confirm against ImageDataset.
mnist_8 = ImageDataset(x_train, y_train, x_test, y_test, filt_labels=filt_labels, spl=0.08)

Loading MNIST data, unfolding square to long representation:

In [None]:
# Currently dichotomies will only be binary
# Parity: even digits vs. odd digits among 0..7
mnist8_parity = [list(range(0, 8, 2)), list(range(1, 8, 2))]
# Smallness: digits below 4 vs. digits 4..7
mnist8_smallness = [range(0, 4), range(4, 8)]
# Every (smallness side) x (parity side) intersection, smallness varying slowest
mnist8_prod = [
    set(parity_side) & set(small_side)
    for small_side in mnist8_smallness
    for parity_side in mnist8_parity
]

In [None]:
# Register the two binary dichotomies on the dataset under explicit names.
mnist_8.build_dichLabels(mnist8_smallness, 'smaller_than_4')
mnist_8.build_dichLabels(mnist8_parity, 'parity')

# Combine both dichotomies; presumably this registers the stacked labels as
# 'parity_hstack_smaller_than_4', the name used by the fit/evaluate cells
# below — confirm against ImageDataset.hstack_dichs.
mnist_8.hstack_dichs('parity', 'smaller_than_4')

In [None]:
# Alright, let's start with a three layer NN
# Widths, in order: input, first hidden layer, second hidden layer, output
w_in, w_1, w_2, w_out = mnist_8.tot_dim, 100, 100, 4

# Epoch budget for training
max_epochs = 400

## 1) Let's reproduce the DNN that was used in the paper:

In [None]:
# Fully-connected network: two tanh hidden layers (w_1 and w_2 units) and a
# tanh read-out of w_out units. The output width uses w_out rather than a
# literal 4 so that it stays in sync with the widths defined above.
dnn_from_paper = Sequential([
    Dense(w_1, input_shape=(w_in,)),
    Activation('tanh'),
    Dense(w_2),
    Activation('tanh'),
    Dense(w_out),
    Activation('tanh'),
])

In [None]:
# Wrap the Keras network in the project-local Model class (from `models`).
model = Model(dnn_from_paper)

In [None]:
# Train the model, iterating on the data in batches of 32 samples
# NOTE(review): this trains for 10 epochs; `max_epochs` (400) defined earlier
# is not used by this call — confirm which value is intended.
model.fit(mnist_8, dich_name='parity_hstack_smaller_than_4', epochs=10, batch_size=32)

In [None]:
# Evaluate on the same stacked dichotomy the model was fitted on.
model.evaluate(mnist_8, dich_name='parity_hstack_smaller_than_4', batch_size=128)

In [None]:
# Presumably inspects predictions on 3 sample inputs — confirm against
# Model.sample_eval.
model.sample_eval(mnist_8, 3)

## 2) Dimensionality reduction on the representations of the layers

Ok, so now let's try to use dimensionality reduction to analyze the content of the different layers.

In [None]:
from numpy.random import shuffle

# Pick `spl_size` distinct training examples at random: shuffle all training
# indices in place and keep the leading ones.
spl_size = 75
spl_ids = np.arange(mnist_8.n_train)
shuffle(spl_ids)
spl_ids = spl_ids[:spl_size]

# Store the sampled inputs and labels on the dataset for the analyses below.
mnist_8.spl = {key: mnist_8.train[key][spl_ids] for key in ('x', 'y')}

In [None]:
from sklearn.decomposition import PCA

# Project the representations of the sampled inputs onto 2 principal
# components: once to get the numbers, once to plot them.
pca2 = PCA(n_components=2)
rprs1 = model.get_repr(mnist_8, mnist_8.spl, pca2)
fig1 = model.plot_reprs(mnist_8, mnist_8.spl, pca2)

In [None]:
# Per layer: total variance captured by the 2-d projection, then the share of
# each of the two components (all in percent).
for lay_id, rpr in enumerate(rprs1):
    expl_var = rpr['reduced']['expl_var']
    print('Layer {0:d} - {1:.1f}% 2d var - {2:.1f}% + {3:.1f}%'.format(
        lay_id,
        100*(expl_var[0]+expl_var[1]),
        100*expl_var[0],
        100*expl_var[1],
    ))

In [None]:
# Switch matplotlib backend via the IPython magic (presumably so the 3-d
# plots can be rotated interactively).
%matplotlib widget
# Same analysis as the 2-component cell, now with 3 principal components.
pca3 = PCA(n_components=3)
rprs2 = model.get_repr(mnist_8, mnist_8.spl, pca3)
fig2 = model.plot_reprs(mnist_8, mnist_8.spl, pca3)

In [None]:
# Per layer: total variance captured by the 3-d projection, then the share of
# each of the three components (all in percent).
for lay_id, rpr in enumerate(rprs2):
    expl_var = rpr['reduced']['expl_var']
    print('Layer {0:d} - {1:.1f}% 3d var - {2:.1f}% + {3:.1f}% + {4:.1f}%'.format(
        lay_id,
        100*(expl_var[0]+expl_var[1]+expl_var[2]),
        100*expl_var[0],
        100*expl_var[1],
        100*expl_var[2],
    ))

## 3) Analysis in terms of cross-condition generalization performance (CCGP) and parallelism score (PS)

## Parallelism Score (PS)

In [None]:
# Parallelism scores computed from the 2-component representations (rprs1).
# NOTE(review): lay_id=5 is presumably the last of the six layers stacked in
# dnn_from_paper — confirm how Model indexes layers.
PScores = model.get_all_PS(mnist_8, rprs1, lay_id=5)

In [None]:
import matplotlib.gridspec as gridspec

# Split PScores into parallel lists: the scores and the dichotomies they
# belong to, both in the mapping's own iteration order.
ps = list(PScores.values())
ps_top_dichs = list(PScores.keys())
# Score profile across dichotomies
plt.plot(ps, marker='.')

In [None]:
# List the first 8 dichotomies with their PS.
# NOTE(review): the printed "rank" is just the position in PScores, so this
# assumes get_all_PS returns its entries already sorted by score — confirm.
for rk, dich in enumerate(ps_top_dichs[:8]):
    print("Dich: {0:s} ranks {1:d} (PS={2:f})".format(str(dich), rk+1, ps[rk]))

## Cross-Condition Generalization Performance (CCGP)

In [None]:
# Layer widths for the network used in the CCGP analysis: input, two hidden
# layers, output.
(w_in, w_1, w_2, w_out) = (mnist_8.tot_dim, 100, 100, 4)
max_epochs = 400

# Same architecture as dnn_from_paper except for the softmax read-out.
dnn_hstack_classif = Sequential([
    Dense(w_1, input_shape=(w_in,)),
    Activation('tanh'),
    Dense(w_2),
    Activation('tanh'),
    Dense(w_out),
    Activation('softmax'),
])

# NOTE(review): no fit call for ccgp_model is visible in this notebook before
# CCGP is computed — confirm whether training happens inside Model or elsewhere.
ccgp_model = Model(dnn_hstack_classif)

In [None]:
import pdb
import types
from sklearn.svm import SVC, LinearSVC

def _patch_get_dich_CCGP(self, ds, dich, n_labels_retained=2):
    """
    Compute the cross-condition generalization performance of one dichotomy.

    CCGP is defined as the capacity to generalize on an unseen condition
    from training on a subset of possible conditions.

    Parameters
    ----------
    ds : dataset exposing ``labels`` and ``generate_gnr_set``.
    dich : labels forming one side of the dichotomy; the other side is every
        remaining label of ``ds.labels``. Both sides must have equal size.
    n_labels_retained : number of labels, drawn at random without
        replacement from the complementary side, that are held out of
        training and used as the generalization condition.

    Returns
    -------
    A list with one dict per layer (``range(self.n_layers)``) holding
    ``train_labels``, ``retained_labels`` and the linear-classifier accuracy
    on the train, test and held-out splits (``train_score``, ``test_score``,
    ``gnr_score``).
    """
    dich_0 = dich
    dich_1 = np.setdiff1d(ds.labels, dich_0)
    assert len(dich_0) == len(dich_1)

    # For now, held-out labels come only from dich_1 and are picked at random
    # NOTE(review): the held-out split therefore appears to contain a single
    # side of the dichotomy — confirm this is the intended CCGP protocol.
    retained_labels = np.random.choice(dich_1, size=n_labels_retained, replace=False)
    train_labels = np.setdiff1d(ds.labels, retained_labels)

    ds_gnr = ds.generate_gnr_set(train_labels)

    split_names = ('train', 'test', 'gnr')

    # Name the dichotomy after its first side and keep only the first column
    # of the labels built for it, giving one target vector per split.
    dich_name = ''.join(map(str, dich_0))
    ds_gnr.build_dichLabels([dich_0, dich_1], dich_name)
    for name in split_names:
        split_dichs = getattr(ds_gnr, name)['dichs']
        split_dichs[dich_name] = split_dichs[dich_name][:, 0]

    # Linear read-out, refitted on every layer's unreduced representation
    svc = LinearSVC()
    rpr = {
        name: self.get_repr(ds_gnr, getattr(ds_gnr, name), dimRed=None)
        for name in split_names
    }
    targets = {name: getattr(ds_gnr, name)['dichs'][dich_name] for name in split_names}

    CCGP_across_layers = []

    for lay_id in range(self.n_layers):
        feats = {name: rpr[name][lay_id]['original']['repr'] for name in split_names}

        svc.fit(feats['train'], targets['train'])
        scores = {name: svc.score(feats[name], targets[name]) for name in split_names}

        # Evaluate performance
        CCGP_across_layers.append({
            'train_labels': train_labels,
            'retained_labels': retained_labels,
            'train_score': scores['train'],
            'test_score': scores['test'],
            'gnr_score': scores['gnr'],
        })

    return CCGP_across_layers
    
# Bind the patched function as a method of this one instance only, so that it
# overrides whatever get_dich_CCGP the Model class defines for ccgp_model.
ccgp_model.get_dich_CCGP  = types.MethodType(_patch_get_dich_CCGP, ccgp_model)

In [None]:
def _patch_generate_gnr_set(self, train_labels, normalize=True):
    """
    Build a dataset whose train/test splits only contain `train_labels`.

    Samples of the current train and test splits whose label is in
    `train_labels` form the new train and test splits. All remaining samples
    (from both splits, train first) are pooled into a third, held-out
    "generalization" split.

    Parameters
    ----------
    train_labels : labels kept in the new train and test splits.
    normalize : when True, each of the three splits is passed through
        `normalize_array` independently of the others.

    Returns
    -------
    ``ImageDataset(x_train, y_train, x_test, y_test, x_gnr, y_gnr)`` where
    every ``x_*`` has shape ``(n_samples, *self.axes_dim)``.
    """
    sample_shape = tuple(self.axes_dim)

    def _prepare(x_rows):
        # Optionally rescale, then restore the per-sample shape
        if normalize:
            x_rows = normalize_array(x_rows)
        return x_rows.reshape((x_rows.shape[0],) + sample_shape)

    # Membership masks over the existing splits
    in_filt_train = np.isin(self.train['y'], train_labels)
    in_filt_test = np.isin(self.test['y'], train_labels)

    x_train = _prepare(self.train['x'].take(np.where(in_filt_train)[0], axis=0))
    y_train = self.train['y'][in_filt_train]

    x_test = _prepare(self.test['x'].take(np.where(in_filt_test)[0], axis=0))
    y_test = self.test['y'][in_filt_test]

    # Everything left out above becomes the generalization split
    x_gnr = _prepare(np.vstack((
        self.train['x'].take(np.where(~in_filt_train)[0], axis=0),
        self.test['x'].take(np.where(~in_filt_test)[0], axis=0),
    )))
    y_gnr = np.hstack((self.train['y'][~in_filt_train], self.test['y'][~in_filt_test]))

    return ImageDataset(x_train, y_train, x_test, y_test, x_gnr, y_gnr)

# Bind the patched function as a method of the mnist_8 instance only; other
# ImageDataset instances are not affected.
mnist_8.generate_gnr_set = types.MethodType(_patch_generate_gnr_set, mnist_8)

def normalize_array(ar):
    """
    Rescale every column of `ar` linearly so that it spans [-1, 1].

    Each row must be an entry, each column must be a dimension. Columns whose
    values are all equal have no range to rescale and come out as zeros.
    """
    centered = ar - np.mean(ar, axis=0)
    col_min = np.min(centered, axis=0)
    col_max = np.max(centered, axis=0)
    col_range = col_max - col_min
    has_range = col_max != col_min

    # Both divisions skip constant columns and leave the pre-filled zeros there
    scaled = np.divide(2*centered, col_range, out=np.zeros_like(centered), where=has_range)
    offset = np.divide(col_max+col_min, col_range, out=np.zeros_like(centered), where=has_range)
    return scaled - offset

In [None]:
# Time one CCGP evaluation of the parity dichotomy (even digits on one side)
# with a single label held out for generalization.
start = time()
CCGP1_across_layers = ccgp_model.get_dich_CCGP(mnist_8, (0, 2, 4, 6), n_labels_retained=1)
dur1 = time() - start
print("CCGP_1 computation on parity dichotomy took {0:.1f}s".format(dur1))

In [None]:
# Same timing, now with two labels held out for generalization.
start = time()
CCGP2_across_layers = ccgp_model.get_dich_CCGP(mnist_8, (0, 2, 4, 6), n_labels_retained=2)
dur2 = time() - start
print("CCGP_2 computation on parity dichotomy took {0:.1f}s".format(dur2))

In [None]:
# Always include the parity and smallness dichotomies, identified by the
# tuple of labels on their first side: (0, 2, 4, 6) and (0, 1, 2, 3).
dichs_to_include = [tuple(mnist8_parity[0]), tuple(mnist8_smallness[0])]
# Evaluated over at most 20 dichotomies; the result is later consumed as a
# mapping from dichotomy to a per-layer list of score dicts.
perfs = ccgp_model.get_all_CCGP(mnist_8, max_n_dichs=20, dichs_to_include=dichs_to_include)

In [None]:
# Per-layer accumulators: raw score dicts, their DataFrame view, and the
# dichotomies ordered by decreasing generalization score.
ccgp_scores = []
df_perf = []
ccgp_top_dichs = []

# One tall figure holding a stacked panel per layer
fig = plt.figure(figsize=(8, 8*ccgp_model.n_layers))
axes = []
gs1 = gridspec.GridSpec(ccgp_model.n_layers, 1)
gs1.update(wspace=0.025, hspace=0.2)

for lay_id, _ in enumerate(ccgp_model.layers):
    print("LAYER {0:d} - Representation CCGP score".format(lay_id))

    ax = plt.subplot(gs1[lay_id])

    # Results of every dichotomy restricted to the current layer
    layer_perfs = {k: v[lay_id] for k, v in perfs.items()}
    ccgp_scores.append(layer_perfs)
    layer_df = pd.DataFrame.from_dict(layer_perfs, orient='index')
    df_perf.append(layer_df)

    # Order dichotomies by decreasing generalization score
    raw_scores = list(layer_df['gnr_score'])
    sort_ids = np.argsort(raw_scores)[::-1].tolist()
    ccgp = [raw_scores[i] for i in sort_ids]
    ccgp_top_dichs.append([layer_df.index[i] for i in sort_ids])
    ax.plot(ccgp, linestyle='', marker='+')

    # Highlight where parity (red) and smallness (blue) land in the ranking
    parity_rank, greatness_rank = (
        ccgp_top_dichs[lay_id].index(marked) for marked in ((0, 2, 4, 6), (0, 1, 2, 3))
    )
    for rank, colour in ((parity_rank, 'red'), (greatness_rank, 'blue')):
        ax.plot(rank, ccgp[rank], marker='+', color=colour, markersize=10, markeredgewidth=4)

    print("Top dichotomies for CCGP\n________________________")
    for rk, dich in enumerate(ccgp_top_dichs[lay_id][:8]):
        print("Dich: {0:s} ranks {1:d} (CCGP={2:f})".format(str(dich), rk+1, ccgp[rk]))

In [None]:
# Top 8 dichotomies by CCGP for layer 5.
# NOTE(review): layer index 5 is taken from the `_l5` suffix and is presumably
# the last of the six layers of dnn_hstack_classif — confirm.
# `ccgp_top_dichs_l5` was not defined anywhere in the visible notebook, and
# the value printed as "CCGP" was `ps[rk]`, i.e. the parallelism scores; both
# are now derived from the per-layer CCGP results built in the previous cell.
ccgp_top_dichs_l5 = ccgp_top_dichs[5]
# Scores in the same (descending) order as the ranked dichotomies
ccgp_l5 = sorted(df_perf[5]['gnr_score'], reverse=True)
for rk, dich in enumerate(ccgp_top_dichs_l5[:8]):
    print("Dich: {0:s} ranks {1:d} (CCGP={2:f})".format(str(dich), rk+1, ccgp_l5[rk]))