# Experiments 5: CNN on MNIST

In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
%matplotlib inline
import tensorflow as tf

from data import Datafile, load_data
from influence.emp_risk_optimizer import EmpiricalRiskOptimizer
from influence.plot_utils import compare_with_loo, show_graph
from influence.closed_forms import I_loss_RidgeCf
from models.neural_nets import ConvNet

  from ._conv import register_converters as _register_converters


In [2]:
X_train, X_test, y_train, y_test, test_indices = load_data(
    Datafile.BinaryMNIST17, test_config=10)
n_tr, p = X_train.shape
n_te, _ = X_test.shape
y_train_onehot = np.eye(2)[y_train.reshape(-1)]
y_test_onehot = np.eye(2)[y_test.reshape(-1)]
print(n_tr, p)

X_train shape: (9075, 784)
y_train shape: (9075, 1)
X_test shape: (10, 784)
y_test shape: (10, 1)
9075 784


In [3]:
init_eta = 0.01
batch_size = 1000
train_iter = 10
traceback_checkpoint = 800
loo_extra_iter = 200
decay_epochs = (10000, 20000)
checkpoint_iter = traceback_checkpoint - 1
iter_to_switch_off_minibatch = np.inf
iter_to_switch_to_sgd = np.inf
# LOO a on random set of training indices, otherwise too slow
leave_indices = np.random.choice(n_tr, size=150, replace=False)

if hasattr(test_indices, '__iter__') and hasattr(leave_indices, '__iter__'):
    assert not set(test_indices) & set(leave_indices)
    print(test_indices)
    print(leave_indices)

## Configure and Fit CNN

In [12]:
model = ConvNet(
    model_name='CNN-MNIST',
    init_eta=init_eta,
    decay_epochs=decay_epochs,
    batch_size=batch_size,
    input_side=28,
    n_channels=1,
    filter_size1=5,
    n_filters1=1,
    filter_size2=5,
    n_filters2=1,
    fc_size=10,
    down_sample=2
)

In [13]:
tf.reset_default_graph()
model.fit(
    X_train, y_train_onehot,
    n_iter=train_iter,
    verbose=1000,
    iter_to_switch_off_minibatch=iter_to_switch_off_minibatch,
    iter_to_switch_to_sgd=iter_to_switch_to_sgd,
    traceback_checkpoint=traceback_checkpoint,
    show_eval=False
)

Step 0, Epoch 0: loss = 9567.10644531 (0.572 sec)
Step 1, Epoch 0: loss = 8296.79296875 (0.334 sec)
Step 2, Epoch 0: loss = 6595.02197266 (0.359 sec)
Step 3, Epoch 0: loss = 5581.01318359 (0.348 sec)
Step 4, Epoch 0: loss = 4101.72998047 (0.353 sec)
Step 5, Epoch 0: loss = 3761.17138672 (0.344 sec)
Step 6, Epoch 0: loss = 2922.09497070 (0.429 sec)
Step 7, Epoch 0: loss = 2314.22460938 (0.393 sec)
Step 8, Epoch 0: loss = 2001.82824707 (0.439 sec)
Step 9, Epoch 0: loss = 1577.54846191 (0.382 sec)


CNN-MNIST(init_eta=0.01,batch_size=1000,decay_epochs=(10000, 20000),filter_size1=5,n_filters1=1,filter_size2=5,n_filters2=1,fc_size=10)

In [14]:
print("Train accuracy:", np.sum(
    model.predict(X_train).reshape(n_tr,1) == y_train)/n_tr)
print("Test accuracy:", np.sum(
    model.predict(X_test).reshape(n_te,1) == y_test)/n_te)

Train accuracy: 0.6369146005509642
Test accuracy: 0.6


In [15]:
model.get_eval(items=['params'])

{'W_conv1': array([[[[ 0.74146223]],
 
         [[ 0.1709393 ]],
 
         [[-1.6046932 ]],
 
         [[-0.61648387]],
 
         [[-0.84862876]]],
 
 
        [[[-0.01115067]],
 
         [[ 0.6649913 ]],
 
         [[ 0.95795745]],
 
         [[-0.6536788 ]],
 
         [[ 1.3629535 ]]],
 
 
        [[[-0.02723947]],
 
         [[ 0.6491738 ]],
 
         [[-1.1423098 ]],
 
         [[-1.9197695 ]],
 
         [[ 1.5273995 ]]],
 
 
        [[[-0.1930811 ]],
 
         [[ 0.7445659 ]],
 
         [[-1.4942554 ]],
 
         [[-1.4145113 ]],
 
         [[ 0.35153612]]],
 
 
        [[[ 0.64099705]],
 
         [[-0.73602927]],
 
         [[ 0.00317261]],
 
         [[-0.44992003]],
 
         [[-0.7936328 ]]]], dtype=float32), 'W_conv2': array([[[[ 0.6994618 ]],
 
         [[-0.6585857 ]],
 
         [[ 0.5226747 ]],
 
         [[-0.391366  ]],
 
         [[-0.05562441]]],
 
 
        [[[-1.0970186 ]],
 
         [[-0.9105867 ]],
 
         [[ 0.05378523]],
 
         [[ 0.3594064 ]]

In [8]:
show_graph(tf.get_default_graph())

In [17]:
model.influence_loss(
    X_test, y_test_onehot,
    method='lissa',
    damping=0.001, 
    minibatch=True,
    leave_indices=leave_indices,
    batch_size=10,
    depth=10000,
    repeat=1,
    scale=1e6,
    verbose=10
)

Fetch training loss gradients (0.416 sec)
--- Lissa Sample 0 ---
Recursion depth: 0, hvp norm: 36954.7890625
Recursion depth: 10, hvp norm: 221420.59375
Recursion depth: 20, hvp norm: 405550.5
Recursion depth: 30, hvp norm: 586670.1875
Recursion depth: 40, hvp norm: 767865.6875
Recursion depth: 50, hvp norm: 947167.9375
Recursion depth: 60, hvp norm: 1128533.875
Recursion depth: 70, hvp norm: 1305372.25
Recursion depth: 80, hvp norm: 1462532.75
Recursion depth: 90, hvp norm: 1631816.625
Recursion depth: 100, hvp norm: 1811104.0
Recursion depth: 110, hvp norm: 1979793.375
Recursion depth: 120, hvp norm: 2154768.25
Recursion depth: 130, hvp norm: 2321442.25
Recursion depth: 140, hvp norm: 2486775.75
Recursion depth: 150, hvp norm: 2647298.5
Recursion depth: 160, hvp norm: 2815226.75
Recursion depth: 170, hvp norm: 2994566.0
Recursion depth: 180, hvp norm: 3183678.5
Recursion depth: 190, hvp norm: 3348259.25
Recursion depth: 200, hvp norm: 3515818.25
Recursion depth: 210, hvp norm: 366653

Recursion depth: 1880, hvp norm: 287965664.0
Recursion depth: 1890, hvp norm: 285179360.0
Recursion depth: 1900, hvp norm: 282255136.0
Recursion depth: 1910, hvp norm: 279748928.0
Recursion depth: 1920, hvp norm: 277530880.0
Recursion depth: 1930, hvp norm: 275309632.0
Recursion depth: 1940, hvp norm: 272546880.0
Recursion depth: 1950, hvp norm: 270285600.0
Recursion depth: 1960, hvp norm: 268255760.0
Recursion depth: 1970, hvp norm: 266120144.0
Recursion depth: 1980, hvp norm: 261930848.0
Recursion depth: 1990, hvp norm: 260027648.0
Recursion depth: 2000, hvp norm: 258406032.0
Recursion depth: 2010, hvp norm: 256109968.0
Recursion depth: 2020, hvp norm: 254059488.0
Recursion depth: 2030, hvp norm: 252027248.0
Recursion depth: 2040, hvp norm: 249898320.0
Recursion depth: 2050, hvp norm: 248152928.0
Recursion depth: 2060, hvp norm: 246512816.0
Recursion depth: 2070, hvp norm: 244817088.0
Recursion depth: 2080, hvp norm: 243437152.0
Recursion depth: 2090, hvp norm: 241772896.0
Recursion 

Recursion depth: 3690, hvp norm: 9511901184.0
Recursion depth: 3700, hvp norm: 9398560768.0
Recursion depth: 3710, hvp norm: 9296368640.0
Recursion depth: 3720, hvp norm: 9180270592.0
Recursion depth: 3730, hvp norm: 9070653440.0
Recursion depth: 3740, hvp norm: 8970588160.0
Recursion depth: 3750, hvp norm: 8878090240.0
Recursion depth: 3760, hvp norm: 8783551488.0
Recursion depth: 3770, hvp norm: 8690686976.0
Recursion depth: 3780, hvp norm: 8618057728.0
Recursion depth: 3790, hvp norm: 8588221952.0
Recursion depth: 3800, hvp norm: 8408559104.0
Recursion depth: 3810, hvp norm: 8367671296.0
Recursion depth: 3820, hvp norm: 8316149248.0
Recursion depth: 3830, hvp norm: 8300037632.0
Recursion depth: 3840, hvp norm: 8261448704.0
Recursion depth: 3850, hvp norm: 8232407040.0
Recursion depth: 3860, hvp norm: 8204686848.0
Recursion depth: 3870, hvp norm: 8181956608.0
Recursion depth: 3880, hvp norm: 8129021440.0
Recursion depth: 3890, hvp norm: 8074961408.0
Recursion depth: 3900, hvp norm: 8

Recursion depth: 5460, hvp norm: 23724302336.0
Recursion depth: 5470, hvp norm: 23651903488.0
Recursion depth: 5480, hvp norm: 23510044672.0
Recursion depth: 5490, hvp norm: 23405897728.0
Recursion depth: 5500, hvp norm: 23335714816.0
Recursion depth: 5510, hvp norm: 23261990912.0
Recursion depth: 5520, hvp norm: 23195305984.0
Recursion depth: 5530, hvp norm: 23143706624.0
Recursion depth: 5540, hvp norm: 23094358016.0
Recursion depth: 5550, hvp norm: 21883826176.0
Recursion depth: 5560, hvp norm: 21860704256.0
Recursion depth: 5570, hvp norm: 21776934912.0
Recursion depth: 5580, hvp norm: 21656428544.0
Recursion depth: 5590, hvp norm: 21542975488.0
Recursion depth: 5600, hvp norm: 21513383936.0
Recursion depth: 5610, hvp norm: 21455927296.0
Recursion depth: 5620, hvp norm: 21439551488.0
Recursion depth: 5630, hvp norm: 21468950528.0
Recursion depth: 5640, hvp norm: 21432899584.0
Recursion depth: 5650, hvp norm: 21360771072.0
Recursion depth: 5660, hvp norm: 21350674432.0
Recursion dep

Recursion depth: 7180, hvp norm: 768283705344.0
Recursion depth: 7190, hvp norm: 778823335936.0
Recursion depth: 7200, hvp norm: 5670941753344.0
Recursion depth: 7210, hvp norm: 5601505050624.0
Recursion depth: 7220, hvp norm: 5527082893312.0
Recursion depth: 7230, hvp norm: 5322529832960.0
Recursion depth: 7240, hvp norm: 5257503440896.0
Recursion depth: 7250, hvp norm: 5200592502784.0
Recursion depth: 7260, hvp norm: 5135366356992.0
Recursion depth: 7270, hvp norm: 5066025074688.0
Recursion depth: 7280, hvp norm: 4993167917056.0
Recursion depth: 7290, hvp norm: 4934837731328.0
Recursion depth: 7300, hvp norm: 4881667588096.0
Recursion depth: 7310, hvp norm: 4825918472192.0
Recursion depth: 7320, hvp norm: 4774283968512.0
Recursion depth: 7330, hvp norm: 4725958246400.0
Recursion depth: 7340, hvp norm: 4671118245888.0
Recursion depth: 7350, hvp norm: 4613335941120.0
Recursion depth: 7360, hvp norm: 4570733346816.0
Recursion depth: 7370, hvp norm: 4530286624768.0
Recursion depth: 7380,

Recursion depth: 8850, hvp norm: 20511148998656.0
Recursion depth: 8860, hvp norm: 20482506096640.0
Recursion depth: 8870, hvp norm: 20462394408960.0
Recursion depth: 8880, hvp norm: 20379544322048.0
Recursion depth: 8890, hvp norm: 20347692777472.0
Recursion depth: 8900, hvp norm: 20297434529792.0
Recursion depth: 8910, hvp norm: 20352715456512.0
Recursion depth: 8920, hvp norm: 19542667427840.0
Recursion depth: 8930, hvp norm: 19443042222080.0
Recursion depth: 8940, hvp norm: 19360349421568.0
Recursion depth: 8950, hvp norm: 19295639699456.0
Recursion depth: 8960, hvp norm: 19221673148416.0
Recursion depth: 8970, hvp norm: 19154759319552.0
Recursion depth: 8980, hvp norm: 19128511365120.0
Recursion depth: 8990, hvp norm: 19126279995392.0
Recursion depth: 9000, hvp norm: 19105614659584.0
Recursion depth: 9010, hvp norm: 19032621187072.0
Recursion depth: 9020, hvp norm: 18966575579136.0
Recursion depth: 9030, hvp norm: 18968702091264.0
Recursion depth: 9040, hvp norm: 18906257293312.0


Recursion depth: 670, hvp norm: 0.0
Recursion depth: 680, hvp norm: 0.0
Recursion depth: 690, hvp norm: 0.0
Recursion depth: 700, hvp norm: 0.0
Recursion depth: 710, hvp norm: 0.0
Recursion depth: 720, hvp norm: 0.0
Recursion depth: 730, hvp norm: 0.0
Recursion depth: 740, hvp norm: 0.0
Recursion depth: 750, hvp norm: 0.0
Recursion depth: 760, hvp norm: 0.0
Recursion depth: 770, hvp norm: 0.0
Recursion depth: 780, hvp norm: 0.0
Recursion depth: 790, hvp norm: 0.0
Recursion depth: 800, hvp norm: 0.0
Recursion depth: 810, hvp norm: 0.0
Recursion depth: 820, hvp norm: 0.0
Recursion depth: 830, hvp norm: 0.0
Recursion depth: 840, hvp norm: 0.0
Recursion depth: 850, hvp norm: 0.0
Recursion depth: 860, hvp norm: 0.0
Recursion depth: 870, hvp norm: 0.0
Recursion depth: 880, hvp norm: 0.0
Recursion depth: 890, hvp norm: 0.0
Recursion depth: 900, hvp norm: 0.0
Recursion depth: 910, hvp norm: 0.0
Recursion depth: 920, hvp norm: 0.0
Recursion depth: 930, hvp norm: 0.0
Recursion depth: 940, hvp no

Recursion depth: 2910, hvp norm: 0.0
Recursion depth: 2920, hvp norm: 0.0
Recursion depth: 2930, hvp norm: 0.0
Recursion depth: 2940, hvp norm: 0.0
Recursion depth: 2950, hvp norm: 0.0
Recursion depth: 2960, hvp norm: 0.0
Recursion depth: 2970, hvp norm: 0.0
Recursion depth: 2980, hvp norm: 0.0
Recursion depth: 2990, hvp norm: 0.0
Recursion depth: 3000, hvp norm: 0.0
Recursion depth: 3010, hvp norm: 0.0
Recursion depth: 3020, hvp norm: 0.0
Recursion depth: 3030, hvp norm: 0.0
Recursion depth: 3040, hvp norm: 0.0
Recursion depth: 3050, hvp norm: 0.0
Recursion depth: 3060, hvp norm: 0.0
Recursion depth: 3070, hvp norm: 0.0
Recursion depth: 3080, hvp norm: 0.0
Recursion depth: 3090, hvp norm: 0.0
Recursion depth: 3100, hvp norm: 0.0
Recursion depth: 3110, hvp norm: 0.0
Recursion depth: 3120, hvp norm: 0.0
Recursion depth: 3130, hvp norm: 0.0
Recursion depth: 3140, hvp norm: 0.0
Recursion depth: 3150, hvp norm: 0.0
Recursion depth: 3160, hvp norm: 0.0
Recursion depth: 3170, hvp norm: 0.0
R

Recursion depth: 5130, hvp norm: 0.0
Recursion depth: 5140, hvp norm: 0.0
Recursion depth: 5150, hvp norm: 0.0
Recursion depth: 5160, hvp norm: 0.0
Recursion depth: 5170, hvp norm: 0.0
Recursion depth: 5180, hvp norm: 0.0
Recursion depth: 5190, hvp norm: 0.0
Recursion depth: 5200, hvp norm: 0.0
Recursion depth: 5210, hvp norm: 0.0
Recursion depth: 5220, hvp norm: 0.0
Recursion depth: 5230, hvp norm: 0.0
Recursion depth: 5240, hvp norm: 0.0
Recursion depth: 5250, hvp norm: 0.0
Recursion depth: 5260, hvp norm: 0.0
Recursion depth: 5270, hvp norm: 0.0
Recursion depth: 5280, hvp norm: 0.0
Recursion depth: 5290, hvp norm: 0.0
Recursion depth: 5300, hvp norm: 0.0
Recursion depth: 5310, hvp norm: 0.0
Recursion depth: 5320, hvp norm: 0.0
Recursion depth: 5330, hvp norm: 0.0
Recursion depth: 5340, hvp norm: 0.0
Recursion depth: 5350, hvp norm: 0.0
Recursion depth: 5360, hvp norm: 0.0
Recursion depth: 5370, hvp norm: 0.0
Recursion depth: 5380, hvp norm: 0.0
Recursion depth: 5390, hvp norm: 0.0
R

KeyboardInterrupt: 