Commit

Add files via upload
wgao9 committed Nov 21, 2018
1 parent d0ca7c6 commit c7162c2
Showing 3 changed files with 393 additions and 0 deletions.
231 changes: 231 additions & 0 deletions hessian_utils.py
@@ -0,0 +1,231 @@
from itertools import product
import numpy as np
import torch
from torch.autograd import grad

import torch.multiprocessing as mp
import functools


def all_indices(a):
return product(*[range(s) for s in a.shape])


def flat_total_derivative(output, input_, create_graph=True):
derivatives = []

    for ind in all_indices(output):
        z = torch.zeros_like(output)
        z[ind] = 1.0

        # retain_graph=True lets the loop call grad() repeatedly even when
        # create_graph is False
        d, = grad(output, input_, grad_outputs=z, create_graph=create_graph, retain_graph=True)

# d, = grad(output[ind], input_, create_graph=create_graph)

derivatives.append(d)

return torch.stack(derivatives, dim=0).reshape(output.shape + input_.shape)
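

# Hypothetical usage sketch (illustrative only): flat_total_derivative returns
# the full Jacobian with shape output.shape + input_.shape.
def _example_flat_total_derivative():
    x = torch.arange(3.0, requires_grad=True)
    y = x ** 2
    J = flat_total_derivative(y, x)  # Jacobian, shape (3, 3)
    # the Jacobian of the elementwise square is diag(2 * x)
    assert torch.allclose(J, torch.diag(2.0 * x))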


def total_derivative(output, input_, create_graph=True):
    return flat_total_derivative(output, input_, create_graph=create_graph)


def diagonal_contraction(a, b, dims=2):
# contracts the first dims dimensions
assert a.shape[:dims] == b.shape[:dims]
assert a.shape[dims:] == b.shape[dims:]

m = torch.mul(a, b)

    result = torch.sum(m, dim=tuple(range(dims)))
return result
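

# Hypothetical sketch (illustrative only): contracting the first two axes is an
# element-wise product summed over those axes, leaving the trailing shape.
def _example_diagonal_contraction():
    a = torch.ones(2, 3, 4)
    b = torch.full((2, 3, 4), 2.0)
    out = diagonal_contraction(a, b, dims=2)  # shape (4,); each entry sums 2 * 3 copies of 2.0, i.e. 12
    assert torch.allclose(out, torch.full((4,), 12.0))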

# FIXME: missing a term
# FIXME: working only for 2d tensors in path, need to fix wherever tensordot_pytorch or torch.permute is used
# FIXME: divide by the batch size
def flat_hessian(f, path, w, diagonal=False):
# path is a list of intermediate tensors from f to w
# it should not include f or w
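    # Chain rule for f(w) = f(z(w)) with z = path[0]:
    #     d2f/dw2 = (dz/dw)^T (d2f/dz2) (dz/dw) + (df/dz) . (d2z/dw2)
    # Only the first (Gauss-Newton-style) term is assembled below; the second
    # term is the one the FIXME above calls missing.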
if len(path) == 0:
raise NotImplementedError

# the middle factor
z = path[0]
fz = total_derivative(f, z)
fzz = total_derivative(fz, z)

# the right factor
pathw = path + [w] # length >= 2
zw = total_derivative(pathw[0], pathw[1])
    for i in range(1, len(path)):
next_ = total_derivative(pathw[i], pathw[i+1])
zw = tensordot_pytorch(zw, next_)


# FIXME: divide by batch_size
fzw = tensordot_pytorch(fzz, zw)

if diagonal:
# FIXME: divide by batch_size
flattened = diagonal_contraction(fzw, zw)

return flattened.reshape(w.shape)
else:
# full hessian
a = range(z.dim())
fww = tensordot_pytorch(fzw, zw, axes=[a, a])

return fww


def diagonal_hessian(f, path, w):
return flat_hessian(f, path, w, diagonal=True)
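

# Hypothetical usage sketch (illustrative names): per the FIXME above, the
# intermediate tensor in `path` must be 2-d (batch x features), as in the
# accompanying test scripts.
def _example_diagonal_hessian():
    w = torch.randn(3, 4, requires_grad=True)
    x = torch.randn(7, 3)
    z = x.matmul(w)                  # 2-d intermediate, shape (7, 4)
    f = (z ** 2).sum()
    diag = diagonal_hessian(f, [z], w)
    assert diag.shape == w.shape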


# FIXME: missing a term
def _diagonal_hessian_multi_inner(fzz, z, w):
with torch.no_grad():
zw = total_derivative(z, w, create_graph=False)
fzw = tensordot_pytorch(fzz, zw)
fww_diagonal = diagonal_contraction(fzw, zw)

# TODO: add fz * diagonal(zww)

assert fww_diagonal.shape == w.shape

return fww_diagonal


def diagonal_hessian_multi(f, z, ws):
fz = total_derivative(f, z)
with torch.no_grad():
fzz = total_derivative(fz, z, create_graph=False).detach()

fww_diagonals = []

for w in ws:
fww_diagonal = _diagonal_hessian_multi_inner(fzz, z, w)

assert fww_diagonal.shape == w.shape

fww_diagonals.append(fww_diagonal.detach())

return fww_diagonals
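

# Hypothetical usage sketch (illustrative names, mirroring the test scripts):
# request the Hessian diagonal of every parameter of a small model in one call.
def _example_diagonal_hessian_multi():
    model = torch.nn.Linear(4, 3)
    x = torch.randn(8, 4)
    target = torch.randint(0, 3, (8,))
    z = model(x)                     # the intermediate ("output") tensor
    f = torch.nn.CrossEntropyLoss()(z, target)
    diags = diagonal_hessian_multi(f, z, model.parameters())
    for d, p in zip(diags, model.parameters()):
        assert d.shape == p.shape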


def _diagonal_hessian_multi(f, z, ws, trained=True):
fz = total_derivative(f, z)
with torch.no_grad():
fzz = total_derivative(fz, z, create_graph=False).detach()

fww_diagonals = []

for w in ws:
# fww_diagonal = _diagonal_hessian_multi_inner(fz, z, w)
zw = total_derivative(z, w)
with torch.no_grad():
zw_detached = zw.detach()
fzw = tensordot_pytorch(fzz, zw_detached)
fww_diagonal = diagonal_contraction(fzw, zw_detached)

# if trained then fz is assumed to be negligible
# this is VERY slow
# NOTE: if z(w) factors through a ReLU then this additional term is not needed.
if not trained:
# calculate fz * diagonal(zww)
zww_diagonal = []
w_dim = w.dim()
for zw_ind in all_indices(zw):
w_ind = zw_ind[-w_dim:]

zz = torch.zeros_like(zw)
zz[zw_ind] = 1.0

# shape = w.shape
zww_slice, = grad(zw, w, grad_outputs=zz, create_graph=True)

# import ipdb; ipdb.set_trace()

zww_entry = zww_slice[w_ind]
zww_diagonal.append(zww_entry.item())

            zww_diagonal = torch.tensor(zww_diagonal, device=zw.device).reshape(zw.shape)

            # add fz * diagonal(zww): contract fz (shape z.shape) with the
            # diagonal slice of zww over the z indices
            a = range(z.dim())
            fww_diagonal = fww_diagonal + tensordot_pytorch(fz, zww_diagonal, axes=[a, a]).detach()

assert fww_diagonal.shape == w.shape

fww_diagonals.append(fww_diagonal.detach())

return fww_diagonals


# from: https://gist.github.com/deanmark/9aec75b7dc9fa71c93c4bc85c5438777
def tensordot_pytorch(a, b, axes=2):
# code adapted from numpy
try:
iter(axes)
except Exception:
axes_a = list(range(-axes, 0))
axes_b = list(range(0, axes))
else:
axes_a, axes_b = axes
try:
na = len(axes_a)
axes_a = list(axes_a)
except TypeError:
axes_a = [axes_a]
na = 1
try:
nb = len(axes_b)
axes_b = list(axes_b)
except TypeError:
axes_b = [axes_b]
nb = 1

# uncomment in pytorch >= 0.5
# a, b = torch.as_tensor(a), torch.as_tensor(b)
as_ = a.shape
nda = a.dim()
bs = b.shape
ndb = b.dim()
equal = True
if na != nb:
equal = False
else:
for k in range(na):
if as_[axes_a[k]] != bs[axes_b[k]]:
equal = False
break
if axes_a[k] < 0:
axes_a[k] += nda
if axes_b[k] < 0:
axes_b[k] += ndb
if not equal:
raise ValueError("shape-mismatch for sum")

# Move the axes to sum over to the end of "a"
# and to the front of "b"
notin = [k for k in range(nda) if k not in axes_a]
newaxes_a = notin + axes_a
N2 = 1
for axis in axes_a:
N2 *= as_[axis]
newshape_a = (int(np.multiply.reduce([as_[ax] for ax in notin])), N2)
olda = [as_[axis] for axis in notin]

notin = [k for k in range(ndb) if k not in axes_b]
newaxes_b = axes_b + notin
N2 = 1
for axis in axes_b:
N2 *= bs[axis]
newshape_b = (N2, int(np.multiply.reduce([bs[ax] for ax in notin])))
oldb = [bs[axis] for axis in notin]

at = a.permute(newaxes_a).reshape(newshape_a)
bt = b.permute(newaxes_b).reshape(newshape_b)

res = at.matmul(bt)
return res.reshape(olda + oldb)
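

# Hypothetical sanity-check sketch: the adapted routine should agree with
# numpy.tensordot for the default axes=2 contraction.
def _example_tensordot_pytorch():
    a = torch.randn(2, 3, 4)
    b = torch.randn(3, 4, 5)
    out = tensordot_pytorch(a, b)    # contracts the trailing (3, 4) of a with the leading (3, 4) of b
    ref = np.tensordot(a.numpy(), b.numpy())
    assert out.shape == (2, 5)
    assert np.allclose(out.numpy(), ref, atol=1e-5)
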
59 changes: 59 additions & 0 deletions pytorch_cifar10_test_cuda.py
@@ -0,0 +1,59 @@
import torch
from torch.autograd import Variable
from utee import selector

import numpy as np

# from pytorch_hessian_test import hessian_diagonal, hessian, total_derivative

# from tensordot import tensordot_pytorch, contraction_pytorch

from hessian_utils import *

from itertools import product

import datetime

# NOTE: Even with cuda=False the selector still tried to map the variables to GPUs. I had to change code in mnist/model.py to force mapping to CPU.
cuda = torch.device('cuda')
model_raw, ds_fetcher, is_imagenet = selector.select('cifar10', cuda=True)

# ps = list(model_raw.parameters())

# batch_size = 13 caused CUDA OOM on a K80

batch_size = 15

ds_val = ds_fetcher(batch_size=batch_size, train=False, val=True)

for idx, (data, target) in enumerate(ds_val):
print(idx)

z = data.to(device=cuda)
target = target.to(device=cuda)

for layer in model_raw.features:
z = layer(z)

z = z.reshape([batch_size, 1024])

output = model_raw.classifier(z)
loss = torch.nn.CrossEntropyLoss()
f = loss(output, target) ** 2

# break

print('start: {}'.format(datetime.datetime.now()))

dhs = diagonal_hessian_multi(f, output, model_raw.parameters())

print('end: {}'.format(datetime.datetime.now()))
print('memory: {}, {}'.format(torch.cuda.memory_allocated(), torch.cuda.memory_cached()))

torch.cuda.empty_cache()

if idx > 10:
break



103 changes: 103 additions & 0 deletions pytorch_mnist_test_cuda.py
@@ -0,0 +1,103 @@
import torch
from torch.autograd import Variable
from utee import selector

import numpy as np

# from pytorch_hessian_test import hessian_diagonal, hessian, total_derivative

# from tensordot import tensordot_pytorch, contraction_pytorch

from hessian_utils import *

from itertools import product

import datetime

batch_size = 200

# NOTE: Even with cuda=False the selector still tried to map the variables to GPUs. I had to change code in mnist/model.py to force mapping to CPU.
cuda = torch.device('cuda')

model_raw, ds_fetcher, is_imagenet = selector.select('mnist', cuda=True)

# ps = list(model_raw.parameters())

layer1 = model_raw.model[:3]
layer2 = model_raw.model[3:6]
layer3 = model_raw.model[6]

ds_val = ds_fetcher(batch_size=batch_size, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
# data = Variable(torch.FloatTensor(data))
# output = model_raw(data)

data = data.to(device=cuda)
target = target.to(device=cuda)

flattened = data.reshape((-1, 28*28))
z1 = layer1(flattened)
z2 = layer2(z1)
z3 = layer3(z2)

output = z3

loss = torch.nn.CrossEntropyLoss()

f = loss(output, target) ** 2

print('start: {}'.format(datetime.datetime.now()))

dhs = diagonal_hessian_multi(f, output, model_raw.parameters())

print('end: {}'.format(datetime.datetime.now()))
print('memory: {}, {}'.format(torch.cuda.memory_allocated(), torch.cuda.memory_cached()))

torch.cuda.empty_cache()

if idx > 5:
break

# # the kernels
# w0 = ps[0]
# w1 = ps[2]
# w2 = ps[4]

# print(w0.shape, w1.shape, w2.shape)
# print(w0.device, w1.device, w2.device)


# on one V100: (also, not really faster on 4 K80s)
# In [9]: %time fw2w2_diagonal = diagonal_hessian(f, [z3], w2)
# CPU times: user 340 ms, sys: 52 ms, total: 392 ms
# Wall time: 385 ms
# In [10]: %time fw1w1_diagonal = diagonal_hessian(f, [z3, z2], w1)
# CPU times: user 5.96 s, sys: 936 ms, total: 6.89 s
# Wall time: 6.81 s
# In [11]: %time fw0w0_diagonal = diagonal_hessian(f, [z3, z2, z1], w0)
# CPU times: user 16 s, sys: 2.68 s, total: 18.7 s
# Wall time: 18.5 s

# not really faster!
# In [12]: %time fw0w0_diagonal = diagonal_hessian(f, [z3], w0)
# CPU times: user 15.1 s, sys: 2.29 s, total: 17.4 s
# Wall time: 17.4 s


# fw2w2_diagonal = diagonal_hessian(f, [z3], w2)

# print(fw2w2_diagonal.shape, np.max(fw2w2_diagonal.cpu().data.numpy()), np.min(fw2w2_diagonal.cpu().data.numpy()))


# # on one V100 with the slow implementation
# # RuntimeError: CUDA error: out of memory
# fw1w1_diagonal = diagonal_hessian(f, [z3, z2], w1)

# print(fw1w1_diagonal.shape, np.max(fw1w1_diagonal.cpu().data.numpy()), np.min(fw1w1_diagonal.cpu().data.numpy()))



# fw0w0_diagonal = diagonal_hessian(f, [z3, z2, z1], w0)

# print(fw0w0_diagonal.shape, np.max(fw0w0_diagonal.cpu().data.numpy()), np.min(fw0w0_diagonal.cpu().data.numpy()))
