Showing 3 changed files with 393 additions and 0 deletions.
@@ -0,0 +1,231 @@
from itertools import product
import numpy as np
import torch
from torch.autograd import grad

import torch.multiprocessing as mp
import functools


def all_indices(a):
    return product(*[range(s) for s in a.shape])


def flat_total_derivative(output, input_, create_graph=True):
    derivatives = []

    for i, ind in enumerate(all_indices(output)):
        z = torch.zeros_like(output)
        z[ind] = 1.0

        # retain_graph=True so the graph survives the repeated grad calls in
        # this loop even when create_graph=False
        d, = grad(output, input_, grad_outputs=z,
                  create_graph=create_graph, retain_graph=True)

        # d, = grad(output[ind], input_, create_graph=create_graph)

        derivatives.append(d)

    return torch.stack(derivatives, dim=0).reshape(output.shape + input_.shape)


def total_derivative(output, input_, create_graph=True):
    # pass create_graph through so callers can actually disable graph construction
    return flat_total_derivative(output, input_, create_graph=create_graph)


def diagonal_contraction(a, b, dims=2):
    # contracts the first `dims` dimensions
    assert a.shape[:dims] == b.shape[:dims]
    assert a.shape[dims:] == b.shape[dims:]

    m = torch.mul(a, b)

    # torch.sum expects an int or a tuple of ints for dim
    result = torch.sum(m, dim=tuple(range(dims)))
    return result

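# For reference, with dims=2 this computes the same contraction as an einsum
# over the two leading axes; an illustrative check (names below are only for
# this example):
#
#   a = torch.randn(2, 3, 4)
#   b = torch.randn(2, 3, 4)
#   assert torch.allclose(diagonal_contraction(a, b),
#                         torch.einsum('ij...,ij...->...', a, b))
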
# FIXME: missing a term
# FIXME: working only for 2d tensors in path, need to fix wherever tensordot_pytorch or torch.permute is used
# FIXME: divide by the batch size
def flat_hessian(f, path, w, diagonal=False):
    # path is a list of intermediate tensors from f to w
    # it should not include f or w
    if len(path) == 0:
        raise NotImplementedError

    # the middle factor
    z = path[0]
    fz = total_derivative(f, z)
    fzz = total_derivative(fz, z)

    # the right factor
    pathw = path + [w]  # length >= 2
    zw = total_derivative(pathw[0], pathw[1])
    for i in range(len(path))[1:]:
        next_ = total_derivative(pathw[i], pathw[i+1])
        zw = tensordot_pytorch(zw, next_)

    # FIXME: divide by batch_size
    fzw = tensordot_pytorch(fzz, zw)

    if diagonal:
        # FIXME: divide by batch_size
        flattened = diagonal_contraction(fzw, zw)

        return flattened.reshape(w.shape)
    else:
        # full hessian
        a = range(z.dim())
        fww = tensordot_pytorch(fzw, zw, axes=[a, a])

        return fww


def diagonal_hessian(f, path, w):
    return flat_hessian(f, path, w, diagonal=True)

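# These helpers are built around the second-order chain rule for a composition
# f(z(w)):
#
#     d^2 f / dw^2 = (dz/dw)^T (d^2 f / dz^2) (dz/dw) + (df/dz) . (d^2 z / dw^2)
#
# flat_hessian and _diagonal_hessian_multi_inner compute only the first
# (Gauss-Newton-like) term -- fzz contracted with zw on both sides -- which is
# the term the "missing a term" FIXMEs refer to. The second term vanishes when
# z is piecewise linear in w (e.g. when z(w) factors through a ReLU network),
# and _diagonal_hessian_multi adds it explicitly when trained=False.
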
# FIXME: missing a term
def _diagonal_hessian_multi_inner(fzz, z, w):
    with torch.no_grad():
        zw = total_derivative(z, w, create_graph=False)
        fzw = tensordot_pytorch(fzz, zw)
        fww_diagonal = diagonal_contraction(fzw, zw)

    # TODO: add fz * diagonal(zww)

    assert fww_diagonal.shape == w.shape

    return fww_diagonal


def diagonal_hessian_multi(f, z, ws):
    fz = total_derivative(f, z)
    with torch.no_grad():
        fzz = total_derivative(fz, z, create_graph=False).detach()

    fww_diagonals = []

    for w in ws:
        fww_diagonal = _diagonal_hessian_multi_inner(fzz, z, w)

        assert fww_diagonal.shape == w.shape

        fww_diagonals.append(fww_diagonal.detach())

    return fww_diagonals


def _diagonal_hessian_multi(f, z, ws, trained=True):
    fz = total_derivative(f, z)
    with torch.no_grad():
        fzz = total_derivative(fz, z, create_graph=False).detach()

    fww_diagonals = []

    for w in ws:
        # fww_diagonal = _diagonal_hessian_multi_inner(fz, z, w)
        zw = total_derivative(z, w)
        with torch.no_grad():
            zw_detached = zw.detach()
            fzw = tensordot_pytorch(fzz, zw_detached)
            fww_diagonal = diagonal_contraction(fzw, zw_detached)

        # if trained then fz is assumed to be negligible
        # this is VERY slow
        # NOTE: if z(w) factors through a ReLU then this additional term is not needed.
        if not trained:
            # calculate fz * diagonal(zww)
            zww_diagonal = []
            w_dim = w.dim()
            for zw_ind in all_indices(zw):
                w_ind = zw_ind[-w_dim:]

                zz = torch.zeros_like(zw)
                zz[zw_ind] = 1.0

                # shape = w.shape
                zww_slice, = grad(zw, w, grad_outputs=zz, create_graph=True)

                # import ipdb; ipdb.set_trace()

                zww_entry = zww_slice[w_ind]
                zww_diagonal.append(zww_entry.item())

            # keep the extra term on the same device as zw
            zww_diagonal = torch.tensor(zww_diagonal, device=zw.device).reshape(zw.shape)

            # contract fz with diagonal(zww) over the z axes so the result has w's shape
            a = range(z.dim())
            fww_diagonal += tensordot_pytorch(fz, zww_diagonal, axes=[a, a])

        assert fww_diagonal.shape == w.shape

        fww_diagonals.append(fww_diagonal.detach())

    return fww_diagonals


# from: https://gist.github.com/deanmark/9aec75b7dc9fa71c93c4bc85c5438777
def tensordot_pytorch(a, b, axes=2):
    # code adapted from numpy
    try:
        iter(axes)
    except Exception:
        axes_a = list(range(-axes, 0))
        axes_b = list(range(0, axes))
    else:
        axes_a, axes_b = axes
    try:
        na = len(axes_a)
        axes_a = list(axes_a)
    except TypeError:
        axes_a = [axes_a]
        na = 1
    try:
        nb = len(axes_b)
        axes_b = list(axes_b)
    except TypeError:
        axes_b = [axes_b]
        nb = 1

    # uncomment in pytorch >= 0.5
    # a, b = torch.as_tensor(a), torch.as_tensor(b)
    as_ = a.shape
    nda = a.dim()
    bs = b.shape
    ndb = b.dim()
    equal = True
    if na != nb:
        equal = False
    else:
        for k in range(na):
            if as_[axes_a[k]] != bs[axes_b[k]]:
                equal = False
                break
            if axes_a[k] < 0:
                axes_a[k] += nda
            if axes_b[k] < 0:
                axes_b[k] += ndb
    if not equal:
        raise ValueError("shape-mismatch for sum")

    # Move the axes to sum over to the end of "a"
    # and to the front of "b"
    notin = [k for k in range(nda) if k not in axes_a]
    newaxes_a = notin + axes_a
    N2 = 1
    for axis in axes_a:
        N2 *= as_[axis]
    newshape_a = (int(np.multiply.reduce([as_[ax] for ax in notin])), N2)
    olda = [as_[axis] for axis in notin]

    notin = [k for k in range(ndb) if k not in axes_b]
    newaxes_b = axes_b + notin
    N2 = 1
    for axis in axes_b:
        N2 *= bs[axis]
    newshape_b = (N2, int(np.multiply.reduce([bs[ax] for ax in notin])))
    oldb = [bs[axis] for axis in notin]

    at = a.permute(newaxes_a).reshape(newshape_a)
    bt = b.permute(newaxes_b).reshape(newshape_b)

    res = at.matmul(bt)
    return res.reshape(olda + oldb)
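
A quick way to sanity-check the two building blocks above is to compare them against cases with known answers: for a linear map, the stacked Jacobian from flat_total_derivative should reproduce the weight matrix, and tensordot_pytorch should agree with numpy.tensordot. The sketch below is illustrative only and assumes the file above is saved as hessian_utils.py, which is how the scripts below import it.

# sanity_check.py -- illustrative sketch; assumes the module above is hessian_utils.py
import numpy as np
import torch

from hessian_utils import flat_total_derivative, tensordot_pytorch

# the Jacobian of a linear map y = A x should be A itself
A = torch.randn(3, 4)
x = torch.randn(4, requires_grad=True)
y = A @ x
J = flat_total_derivative(y, x)          # shape (3, 4) = y.shape + x.shape
assert torch.allclose(J, A, atol=1e-6)

# tensordot_pytorch should match numpy.tensordot for explicit axes
a = torch.randn(2, 3, 4)
b = torch.randn(4, 3, 5)
expected = np.tensordot(a.numpy(), b.numpy(), axes=([1, 2], [1, 0]))
got = tensordot_pytorch(a, b, axes=([1, 2], [1, 0]))
assert np.allclose(got.numpy(), expected, atol=1e-5)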
@@ -0,0 +1,59 @@
import torch
from torch.autograd import Variable
from utee import selector

import numpy as np

# from pytorch_hessian_test import hessian_diagonal, hessian, total_derivative

# from tensordot import tensordot_pytorch, contraction_pytorch

from hessian_utils import *

from itertools import product

import datetime

# NOTE: Even with cuda=False the selector still tried to map the variables to GPUs. I had to change code in mnist/model.py to force mapping to CPU.
cuda = torch.device('cuda')
model_raw, ds_fetcher, is_imagenet = selector.select('cifar10', cuda=True)

# ps = list(model_raw.parameters())

# batch_size = 13 caused CUDA OOM on a K80

batch_size = 15

ds_val = ds_fetcher(batch_size=batch_size, train=False, val=True)

for idx, (data, target) in enumerate(ds_val):
    print(idx)

    z = data.to(device=cuda)
    target = target.to(device=cuda)

    for layer in model_raw.features:
        z = layer(z)

    z = z.reshape([batch_size, 1024])

    output = model_raw.classifier(z)
    loss = torch.nn.CrossEntropyLoss()
    f = loss(output, target) ** 2

    # break

    print('start: {}'.format(datetime.datetime.now()))

    dhs = diagonal_hessian_multi(f, output, model_raw.parameters())

    print('end: {}'.format(datetime.datetime.now()))
    print('memory: {}, {}'.format(torch.cuda.memory_allocated(), torch.cuda.memory_cached()))

    torch.cuda.empty_cache()

    if idx > 10:
        break
@@ -0,0 +1,103 @@
import torch
from torch.autograd import Variable
from utee import selector

import numpy as np

# from pytorch_hessian_test import hessian_diagonal, hessian, total_derivative

# from tensordot import tensordot_pytorch, contraction_pytorch

from hessian_utils import *

from itertools import product

import datetime

batch_size = 200

# NOTE: Even with cuda=False the selector still tried to map the variables to GPUs. I had to change code in mnist/model.py to force mapping to CPU.
cuda = torch.device('cuda')

model_raw, ds_fetcher, is_imagenet = selector.select('mnist', cuda=True)

# ps = list(model_raw.parameters())

layer1 = model_raw.model[:3]
layer2 = model_raw.model[3:6]
layer3 = model_raw.model[6]

ds_val = ds_fetcher(batch_size=batch_size, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
    # data = Variable(torch.FloatTensor(data))
    # output = model_raw(data)

    data = data.to(device=cuda)
    target = target.to(device=cuda)

    flattened = data.reshape((-1, 28*28))
    z1 = layer1(flattened)
    z2 = layer2(z1)
    z3 = layer3(z2)

    output = z3

    loss = torch.nn.CrossEntropyLoss()

    f = loss(output, target) ** 2

    print('start: {}'.format(datetime.datetime.now()))

    dhs = diagonal_hessian_multi(f, output, model_raw.parameters())

    print('end: {}'.format(datetime.datetime.now()))
    print('memory: {}, {}'.format(torch.cuda.memory_allocated(), torch.cuda.memory_cached()))

    torch.cuda.empty_cache()

    if idx > 5:
        break

||
# # the kernels | ||
# w0 = ps[0] | ||
# w1 = ps[2] | ||
# w2 = ps[4] | ||
|
||
# print(w0.shape, w1.shape, w2.shape) | ||
# print(w0.device, w1.device, w2.device) | ||
|
||
|
||
# on one V100: (also, not really faster on 4 K80s) | ||
# In [9]: %time fw2w2_diagonal = diagonal_hessian(f, [z3], w2) | ||
# CPU times: user 340 ms, sys: 52 ms, total: 392 ms | ||
# Wall time: 385 ms | ||
# In [10]: %time fw1w1_diagonal = diagonal_hessian(f, [z3, z2], w1) | ||
# CPU times: user 5.96 s, sys: 936 ms, total: 6.89 s | ||
# Wall time: 6.81 s | ||
# In [11]: %time fw0w0_diagonal = diagonal_hessian(f, [z3, z2, z1], w0) | ||
# CPU times: user 16 s, sys: 2.68 s, total: 18.7 s | ||
# Wall time: 18.5 s | ||
|
||
# not really faster! | ||
# In [12]: %time fw0w0_diagonal = diagonal_hessian(f, [z3], w0) | ||
# CPU times: user 15.1 s, sys: 2.29 s, total: 17.4 s | ||
# Wall time: 17.4 s | ||
|
||
|
||
# fw2w2_diagonal = diagonal_hessian(f, [z3], w2) | ||
|
||
# print(fw2w2_diagonal.shape, np.max(fw2w2_diagonal.cpu().data.numpy()), np.min(fw2w2_diagonal.cpu().data.numpy())) | ||
|
||
|
||
# # on one V100 with the slow implementation | ||
# # RuntimeError: CUDA error: out of memory | ||
# fw1w1_diagonal = diagonal_hessian(f, [z3, z2], w1) | ||
|
||
# print(fw1w1_diagonal.shape, np.max(fw1w1_diagonal.cpu().data.numpy()), np.min(fw1w1_diagonal.cpu().data.numpy())) | ||
|
||
|
||
|
||
# fw0w0_diagonal = diagonal_hessian(f, [z3, z2, z1], w0) | ||
|
||
# print(fw0w0_diagonal.shape, np.max(fw0w0_diagonal.cpu().data.numpy()), np.min(fw0w0_diagonal.cpu().data.numpy())) | ||
|
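
For experimentation without the utee dataset selector, the same entry point can be exercised on a toy model. The sketch below is a minimal, CPU-only illustration under the assumption that the utility file above is importable as hessian_utils; the model, seed, and file name are placeholders for this example only. The returned diagonals should have the same shapes as the corresponding parameters, mirroring the asserts inside diagonal_hessian_multi.

# toy_check.py -- illustrative sketch, not one of the committed scripts
import torch

from hessian_utils import diagonal_hessian_multi  # assumes the utility file is hessian_utils.py

torch.manual_seed(0)

model = torch.nn.Linear(4, 3, bias=False)   # a single dense layer as a stand-in network
x = torch.randn(8, 4)
target = torch.randint(0, 3, (8,))

z = model(x)                                # logits, used as the intermediate tensor
f = torch.nn.CrossEntropyLoss()(z, target)

dhs = diagonal_hessian_multi(f, z, list(model.parameters()))
for w, dh in zip(model.parameters(), dhs):
    print(dh.shape, dh.max().item(), dh.min().item())   # one Hessian diagonal per parameter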