-
Notifications
You must be signed in to change notification settings - Fork 0
/
kfac_small_generalization.py
455 lines (384 loc) · 15.4 KB
/
kfac_small_generalization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
# Mac iteration time: 1606 ms
# Linux 1080 TI iteration time: 132 ms
# iteration times in ms on gtx titan x: min: 134.06, median: 155.68

# Naming conventions:
#   "_op" suffix: things that are ops
#   "x0" (trailing 0): numpy value
#   "_live" suffix: tensor used to update a variable value

# experiment prefixes
# prefix = "small_final" # for checkin
prefix = "generalize"

import util
import util as u

drop_l2 = True  # drop L2 term
drop_sparsity = True  # drop KL term
use_gpu = True
do_line_search = False  # line-search and dump values at each iter
import sys
run_small = False  # overrides dsize and num_steps
# 0 for gradient, 4 for full whitening; intermediate levels whiten
# progressively more factors (1: input activations, 2: + A[2],
# 3: + B2[2], 4: + B2[1]).
whitening_mode = 4
whiten_every_n_steps = 1  # how often to whiten
report_frequency = 3  # how often to print loss
num_steps = 10000
util.USE_MKL_SVD = True  # Tensorflow vs MKL SVD
purely_linear = False  # convert sigmoids into linear nonlinearities
use_tikhonov = True  # use Tikhonov reg instead of Moore-Penrose pseudo-inv
Lambda = 1e-3  # magic lambda value from Jimmy Ba for Tikhonov

# adaptive line search
adaptive_step = False  # adjust step length based on predicted decrease
adaptive_step_frequency = 1  # how often to adjust
adaptive_step_burn_in = 0  # let optimization go for a bit before adjusting
local_quadratics = False  # use quadratic approximation to predict loss drop
measure_validation = True  # also evaluate loss on held-out images

import networkx as nx
import load_MNIST
import numpy as np
import scipy.io  # for loadmat
from scipy import linalg  # for svd
import math
import time
import os, sys

# Device selection happens before the TensorFlow import below.
# NOTE(review): `util` is imported above; if it imports TensorFlow itself,
# this setting may come too late to take effect — confirm.
if use_gpu:
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
else:
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

import tensorflow as tf
from util import t  # transpose
def W_uniform(s1, s2, num_matrices=2):  # uniform weight init from Ng UFLDL
    """Return a flat vector of uniform initial weights (UFLDL-style init).

    Draws ``num_matrices * s1 * s2`` values uniformly from ``[-r, r)`` with
    ``r = sqrt(6) / sqrt(s1 + s2 + 1)``. The default ``num_matrices=2``
    provides one s1 x s2 matrix plus one s2 x s1 matrix in a single flat
    vector, which is how the script below unflattens it.

    Args:
        s1, s2: layer sizes that determine the range ``r`` and matrix size.
        num_matrices: how many s1*s2-sized matrices to draw values for.

    Returns:
        1-D numpy array of length ``num_matrices * s1 * s2``.
    """
    r = np.sqrt(6) / np.sqrt(s1 + s2 + 1)
    # np.random.random samples [0, 1); rescale to [-r, r).
    result = np.random.random(num_matrices * s2 * s1) * 2 * r - r
    return result
if __name__ == '__main__':
    np.random.seed(0)
    tf.set_random_seed(0)

    dtype = np.float32
    # 64-bit doesn't help much, search for 64-bit in
    # https://www.wolframcloud.com/objects/5f297f41-30f7-4b1b-972c-cac8d1f8d8e4
    u.default_dtype = dtype
    machine_epsilon = np.finfo(dtype).eps  # 1e-7 or 1e-16

    train_images = load_MNIST.load_MNIST_images('data/train-images-idx3-ubyte')
    dsize = 10000
    if run_small:
        dsize = 1000
        num_steps = 100

    # Columns are examples: training uses the first dsize columns and
    # validation the last dsize columns of the same file. These are disjoint
    # only if the file holds at least 2*dsize images; the assert below is
    # presumably meant to guarantee that — confirm against the dataset size.
    patches = train_images[:, :dsize]
    test_patches = train_images[:, -dsize:]
    assert dsize < 25000

    # fs[0] is the number of examples, the rest are layer widths
    fs = [dsize, 28*28, 196, 28*28]

    # values from deeplearning.stanford.edu/wiki/index.php/UFLDL_Tutorial
    X0 = patches
    lambda_ = 3e-3
    rho = tf.constant(0.1, dtype=dtype)
    beta = 3  # sparsity (KL) weight
    W0f = W_uniform(fs[2], fs[3])

    def f(i):
        return fs[i+1]  # W[i] has shape f[i] x f[i-1]

    dsize = f(-1)
    n = len(fs) - 2  # number of weight matrices

    # helper to create variables with numpy or TF initial value
    init_dict = {}  # {var_placeholder: init_value}
    vard = {}  # {var: util.VarInfo}

    def init_var(val, name, trainable=False, noinit=False):
        """Create a tf.Variable from a TF tensor or a numeric (numpy) value.

        Tensor values become the variable's initial value directly
        (``noinit=True`` keeps it out of the default collections, so the
        global initializer skips it). Numeric values are fed through a
        placeholder registered in ``init_dict``. Every variable also gets a
        setter op + placeholder stored in ``vard``.
        """
        if isinstance(val, tf.Tensor):
            collections = [] if noinit else None
            var = tf.Variable(val, name=name, collections=collections)
        else:
            val = np.array(val)
            # Check the converted array's dtype; asserting on the bare
            # function object `u.is_numeric` could never fail.
            assert np.issubdtype(val.dtype, np.number), "Unknown type"
            holder = tf.placeholder(dtype, shape=val.shape, name=name+"_holder")
            var = tf.Variable(holder, name=name, trainable=trainable)
            init_dict[holder] = val
        var_p = tf.placeholder(var.dtype, var.shape)
        var_setter = var.assign(var_p)
        vard[var] = u.VarInfo(var_setter, var_p)
        return var

    # need lower LR without sigmoids
    lr = init_var(.02 if purely_linear else 0.2, "lr")
    Wf = init_var(W0f, "Wf", True)
    Wf_copy = init_var(W0f, "Wf_copy", True)
    W = u.unflatten(Wf, fs[1:])  # perftodo: this creates transposes
    X = init_var(X0, "X")
    W.insert(0, X)

    def sigmoid(x):
        """Nonlinearity: sigmoid, or identity when purely_linear."""
        if purely_linear:
            return tf.identity(x)
        return tf.sigmoid(x)

    def d_sigmoid(y):
        """Derivative of `sigmoid` expressed in terms of its output y."""
        if purely_linear:
            return 1
        return y*(1-y)

    def kl(x, y):
        """Bernoulli KL divergence KL(x || y), elementwise."""
        return x * tf.log(x / y) + (1 - x) * tf.log((1 - x) / (1 - y))

    def d_kl(x, y):
        """Derivative of kl(x, y) with respect to y."""
        return (1-x)/(1-y) - x/y

    # A[i] = activations needed to compute gradient of W[i]
    # A[n+1] = network output
    A = [None]*(n+2)
    # A[0] is just for shape checks, assert fail on run
    with tf.control_dependencies([tf.assert_equal(1, 0, message="too huge")]):
        A[0] = u.Identity(dsize, dtype=dtype)
    A[1] = W[0]
    for i in range(1, n+1):
        A[i+1] = sigmoid(W[i] @ A[i])

    # reconstruction error and sparsity error
    err = (A[3] - A[1])
    rho_hat = tf.reduce_sum(A[2], axis=1, keep_dims=True)/dsize

    # B[i] = backprops needed to compute gradient of W[i]
    # B2[i] = backprops from sampled labels needed for natural gradient
    B = [None]*(n+1)
    B2 = [None]*(n+1)
    B[n] = err*d_sigmoid(A[n+1])
    sampled_labels_live = tf.random_normal((f(n), f(-1)), dtype=dtype, seed=0)
    sampled_labels = init_var(sampled_labels_live, "sampled_labels", noinit=True)
    B2[n] = sampled_labels*d_sigmoid(A[n+1])
    for i in range(n-1, -1, -1):
        backprop = t(W[i+1]) @ B[i+1]
        backprop2 = t(W[i+1]) @ B2[i+1]
        if i == 1 and not drop_sparsity:
            backprop += beta*d_kl(rho, rho_hat)
            backprop2 += beta*d_kl(rho, rho_hat)
        B[i] = backprop*d_sigmoid(A[i+1])
        B2[i] = backprop2*d_sigmoid(A[i+1])

    # dW[i] = gradient of W[i]
    dW = [None]*(n+1)
    pre_dW = [None]*(n+1)  # preconditioned dW
    pre_dW_stable = [None]*(n+1)  # preconditioned stable dW
    cov_A = [None]*(n+1)  # covariance of activations[i]
    cov_B2 = [None]*(n+1)  # covariance of synthetic backprops[i]
    vars_svd_A = [None]*(n+1)
    vars_svd_B2 = [None]*(n+1)
    for i in range(1, n+1):
        cov_A[i] = init_var(A[i]@t(A[i])/dsize, "cov_A%d"%(i,))
        cov_B2[i] = init_var(B2[i]@t(B2[i])/dsize, "cov_B2%d"%(i,))
        vars_svd_A[i] = u.SvdWrapper(cov_A[i], "svd_A_%d"%(i,))
        vars_svd_B2[i] = u.SvdWrapper(cov_B2[i], "svd_B2_%d"%(i,))
        # The curvature factor comes from the sampled backprops (B2), but the
        # quantity being whitened is the true backprop B[i].
        if use_tikhonov:
            whitened_A = u.regularized_inverse2(vars_svd_A[i], L=Lambda) @ A[i]
            whitened_B2 = u.regularized_inverse2(vars_svd_B2[i], L=Lambda) @ B[i]
        else:
            whitened_A = u.pseudo_inverse2(vars_svd_A[i]) @ A[i]
            whitened_B2 = u.pseudo_inverse2(vars_svd_B2[i]) @ B[i]
        whitened_A_stable = u.pseudo_inverse_sqrt2(vars_svd_A[i]) @ A[i]
        whitened_B2_stable = u.pseudo_inverse_sqrt2(vars_svd_B2[i]) @ B[i]
        pre_dW[i] = (whitened_B2 @ t(whitened_A))/dsize
        pre_dW_stable[i] = (whitened_B2_stable @ t(whitened_A_stable))/dsize
        dW[i] = (B[i] @ t(A[i]))/dsize

    # Loss function
    reconstruction = u.L2(err) / (2 * dsize)
    sparsity = beta * tf.reduce_sum(kl(rho, rho_hat))
    # weight decay over every weight matrix W[1..n]
    L2 = (lambda_ / 2) * sum(u.L2(W[i]) for i in range(1, n+1))
    loss = reconstruction
    if not drop_l2:
        loss = loss + L2
    if not drop_sparsity:
        loss = loss + sparsity

    grad_live = u.flatten(dW[1:])
    pre_grad_live = u.flatten(pre_dW[1:])  # fisher preconditioned gradient
    pre_grad_stable_live = u.flatten(pre_dW_stable[1:])  # sqrt fisher preconditioned grad
    grad = init_var(grad_live, "grad")
    pre_grad = init_var(pre_grad_live, "pre_grad")
    pre_grad_stable = init_var(pre_grad_stable_live, "pre_grad_stable")

    update_params_op = Wf.assign(Wf-lr*pre_grad).op
    update_params_stable_op = Wf.assign(Wf-lr*pre_grad_stable).op
    save_params_op = Wf_copy.assign(Wf).op
    pre_grad_dot_grad = tf.reduce_sum(pre_grad*grad)
    pre_grad_stable_dot_grad = tf.reduce_sum(pre_grad_stable*grad)
    grad_norm = tf.reduce_sum(grad*grad)  # squared norm
    pre_grad_norm = u.L2(pre_grad)
    pre_grad_stable_norm = u.L2(pre_grad_stable)

    def dump_svd_info(step):
        """Dump singular values and gradient values in those coordinates."""
        for i in range(1, n+1):
            svd = vars_svd_A[i]
            s0, u0, v0 = sess.run([svd.s, svd.u, svd.v])
            util.dump(s0, "A_%d_%d"%(i, step))
            A0 = A[i].eval()
            At0 = v0.T @ A0
            util.dump(A0 @ A0.T, "Acov_%d_%d"%(i, step))
            util.dump(At0 @ At0.T, "Atcov_%d_%d"%(i, step))
            util.dump(s0, "As_%d_%d"%(i, step))
        for i in range(1, n+1):
            svd = vars_svd_B2[i]
            s0, u0, v0 = sess.run([svd.s, svd.u, svd.v])
            util.dump(s0, "B2_%d_%d"%(i, step))
            B0 = B[i].eval()
            Bt0 = v0.T @ B0
            util.dump(B0 @ B0.T, "Bcov_%d_%d"%(i, step))
            util.dump(Bt0 @ Bt0.T, "Btcov_%d_%d"%(i, step))
            util.dump(s0, "Bs_%d_%d"%(i, step))

    def advance_batch():
        """Resample the synthetic labels that drive B2."""
        sess.run(sampled_labels.initializer)  # new labels for next call

    def update_covariances():
        """Recompute cov_A and cov_B2 for every layer from current values."""
        ops_A = [cov_A[i].initializer for i in range(1, n+1)]
        ops_B2 = [cov_B2[i].initializer for i in range(1, n+1)]
        sess.run(ops_A+ops_B2)

    def update_svds():
        """Refresh the SVDs selected by whitening_mode.

        vars_svd_A[1] (input activations) is not refreshed here; it is
        updated once before the training loop.
        """
        if whitening_mode > 1:
            vars_svd_A[2].update()
        if whitening_mode > 2:
            vars_svd_B2[2].update()
        if whitening_mode > 3:
            vars_svd_B2[1].update()

    def init_svds():
        """Initialize our SVD to identity matrices."""
        ops = []
        for i in range(1, n+1):
            ops.extend(vars_svd_A[i].init_ops)
            ops.extend(vars_svd_B2[i].init_ops)
        sess = tf.get_default_session()
        sess.run(ops)

    # create validation loss eval; compare against the X_test variable rather
    # than embedding the numpy test set into the graph as a constant
    X_test = init_var(test_patches, "X_test")
    layer = X_test
    for i in range(1, n+1):
        layer = sigmoid(W[i] @ layer)
    verr = layer - X_test
    vloss = u.L2(verr) / (2 * dsize)

    init_op = tf.global_variables_initializer()
    # tf.get_default_graph().finalize()
    sess = tf.InteractiveSession()
    # Wf and X first: the covariance initializers run by update_covariances()
    # read activations computed from them, and cov_B2 also needs the
    # sampled_labels initialized by advance_batch().
    sess.run(Wf.initializer, feed_dict=init_dict)
    sess.run(X.initializer, feed_dict=init_dict)
    advance_batch()
    update_covariances()
    init_svds()
    sess.run(init_op, feed_dict=init_dict)  # initialize everything else

    print("Running training.")
    u.reset_time()

    step_lengths = []  # keep track of learning rates
    losses = []
    vlosses = []
    ratios = []  # actual loss decrease / expected decrease
    grad_norms = []
    pre_grad_norms = []  # preconditioned grad norm squared
    pre_grad_stable_norms = []  # sqrt preconditioned grad norms squared
    target_delta_list = []  # predicted decrease linear approximation
    target_delta2_list = []  # predicted decrease quadratic appromation
    actual_delta_list = []  # actual decrease

    # adaptive line search parameters (named ls_* so they do not shadow the
    # sparsity weight `beta` used while building the graph)
    ls_alpha = 0.3  # acceptable fraction of predicted decrease
    ls_beta = 0.8  # how much to shrink when violation
    growth_rate = 1.05  # how much to grow when too conservative

    def update_cov_A(i):
        sess.run(cov_A[i].initializer)

    def update_cov_B2(i):
        sess.run(cov_B2[i].initializer)

    # only update whitening matrix of input activations in the beginning
    if whitening_mode > 0:
        vars_svd_A[1].update()

    def hessian_quadratic(delta):
        """Return t(delta).H.delta/2 using the Kronecker-factored curvature.

        delta is unflattened into per-layer blocks D_l; the result is
        sum_l trace(t(D_l) @ cov_B2[l] @ D_l @ cov_A[l]) / 2, using whatever
        cov_A/cov_B2 currently hold (they are not refreshed here).
        """
        blocks = u.unflatten(delta, fs[1:])
        blocks.insert(0, None)
        total = 0
        for l in range(1, n+1):
            total += tf.trace(t(blocks[l]) @ cov_B2[l] @ blocks[l] @ cov_A[l])
        return (total/2).eval()

    def hessian_quadratic_inv(delta):
        """Return t(delta).H^-1.delta/2 using pseudo-inverses of the factors."""
        blocks = u.unflatten(delta, fs[1:])
        blocks.insert(0, None)
        total = 0
        for l in range(1, n+1):
            invB2 = u.pseudo_inverse2(vars_svd_B2[l])
            invA = u.pseudo_inverse2(vars_svd_A[l])
            total += tf.trace(t(blocks[l]) @ invB2 @ blocks[l] @ invA)
        return (total/2).eval()

    def line_search(initial_value, direction, step_size, num_points):
        """Return loss at initial_value - k*step_size*direction, k=0..num_points-1.

        Wf is restored to its value at call time before returning.
        NOTE(review): creates new graph ops/variables on every call.
        """
        saved_val = tf.Variable(Wf)
        sess.run(saved_val.initializer)
        pl = tf.placeholder(dtype, shape=(), name="linesearch_p")
        assign_op = Wf.assign(initial_value - direction*step_size*pl)
        vals = []
        for k in range(num_points):
            sess.run(assign_op, feed_dict={pl: k})
            vals.append(loss.eval())
        sess.run(Wf.assign(saved_val))  # restore original value
        return vals

    for step in range(num_steps):
        update_covariances()
        if step % whiten_every_n_steps == 0:
            update_svds()

        # Refresh all gradient variables, including pre_grad_stable, which the
        # stable update and the logged stable norm read.
        sess.run([grad.initializer, pre_grad.initializer,
                  pre_grad_stable.initializer])

        if measure_validation:
            lr0, loss0, vloss0 = sess.run([lr, loss, vloss])
        else:
            lr0, loss0 = sess.run([lr, loss])
            vloss0 = 0

        save_params_op.run()

        # regular inverse becomes unstable when grad norm exceeds 1
        # NOTE(review): the comment above conflicts with the condition below
        # (stabilized_mode is True when the *squared* norm is below 1) —
        # confirm which is intended.
        stabilized_mode = grad_norm.eval() < 1
        use_stable_update = stabilized_mode and not use_tikhonov
        if use_stable_update:
            update_params_stable_op.run()
        else:
            update_params_op.run()

        loss1 = loss.eval()
        advance_batch()

        # line search stuff: first-order prediction along the direction that
        # was actually applied above
        if use_stable_update:
            target_slope = -pre_grad_stable_dot_grad.eval()
        else:
            target_slope = -pre_grad_dot_grad.eval()
        target_delta = lr0*target_slope
        target_delta_list.append(target_delta)

        # second order prediction of target delta
        # TODO: the sign is wrong, debug this
        # https://www.wolframcloud.com/objects/8f287f2f-ceb7-42f7-a599-1c03fda18f28
        if local_quadratics:
            x0 = Wf_copy.eval()
            x_opt = x0 - pre_grad.eval()
            # computes t(x)@H^-1 @(x)/2
            y_opt = loss0 - hessian_quadratic_inv(grad)
            # computes t(x)@H @(x)/2
            y_expected = hessian_quadratic(Wf - x_opt) + y_opt
            target_delta2 = y_expected - loss0
            target_delta2_list.append(target_delta2)

        actual_delta = loss1 - loss0
        actual_slope = actual_delta/lr0
        slope_ratio = actual_slope/target_slope  # between 0 and 1.01
        actual_delta_list.append(actual_delta)

        if do_line_search:
            vals1 = line_search(Wf_copy, pre_grad, lr/100, 40)
            vals2 = line_search(Wf_copy, grad, lr/100, 40)
            u.dump(vals1, "line1-%d"%(step,))
            u.dump(vals2, "line2-%d"%(step,))

        losses.append(loss0)
        vlosses.append(vloss0)
        step_lengths.append(lr0)
        ratios.append(slope_ratio)
        grad_norms.append(grad_norm.eval())
        pre_grad_norms.append(pre_grad_norm.eval())
        pre_grad_stable_norms.append(pre_grad_stable_norm.eval())

        if step % report_frequency == 0:
            print("Step %d loss %.2f, vloss: %.2f, target decrease %.3f, actual decrease, %.3f"%(step, loss0, vloss0, target_delta, actual_delta))

        if (adaptive_step and adaptive_step_frequency
                and step % adaptive_step_frequency == 0
                and step > adaptive_step_burn_in):
            # shrink if wrong prediction, don't shrink if prediction is tiny
            if slope_ratio < ls_alpha and abs(target_delta) > 1e-6:
                print("%.2f %.2f %.2f"%(loss0, loss1, slope_ratio))
                print("Slope optimality %.2f, shrinking learning rate to %.2f"%(slope_ratio, lr0*ls_beta,))
                sess.run(vard[lr].setter, feed_dict={vard[lr].p: lr0*ls_beta})
            # grow learning rate, slope_ratio .99 worked best for gradient
            elif step > 0 and step % 50 == 0 and slope_ratio > 0.90:
                print("%.2f %.2f %.2f"%(loss0, loss1, slope_ratio))
                print("Growing learning rate to %.2f"%(lr0*growth_rate))
                sess.run(vard[lr].setter, feed_dict={vard[lr].p:
                                                     lr0*growth_rate})

        u.record_time()

    # NOTE(review): indentation was lost in this copy; these dumps are placed
    # after the loop — confirm they were not meant to run every iteration.
    u.dump(losses, "%s_losses.csv"%(prefix))
    u.dump(vlosses, "%s_vlosses.csv"%(prefix,))
    u.dump(step_lengths, "%s_step_lengths.csv"%(prefix,))
    u.dump(ratios, "%s_ratios.csv"%(prefix,))
    u.dump(grad_norms, "%s_grad_norms.csv"%(prefix,))
    u.dump(pre_grad_norms, "%s_pre_grad_norms.csv"%(prefix,))
    u.dump(pre_grad_stable_norms, "%s_pre_grad_stable_norms.csv"%(prefix,))
    u.dump(target_delta_list, "%s_target_delta.csv"%(prefix,))
    u.dump(target_delta2_list, "%s_target_delta2.csv"%(prefix,))
    u.dump(actual_delta_list, "%s_actual_delta.csv"%(prefix,))
    u.summarize_time()