Merge pull request #5 from tfjgeorge/mb
moves loader out of generator instance to be passed when calling meth…
tfjgeorge committed Feb 24, 2021
2 parents 27aa8c5 + c1d66ed commit 3cf8abb
Showing 8 changed files with 207 additions and 140 deletions.
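In summary, this commit removes the DataLoader from the Jacobian constructor; the data is now passed to each method call instead. A minimal before/after sketch of the calling convention (model, loader and n_output=10 are placeholders, not part of this diff):

    # before this commit: the loader was bound to the generator
    jac = Jacobian(model, loader, n_output=10)
    F = jac.get_covariance_matrix()

    # after this commit: the generator is data-agnostic
    jac = Jacobian(model, n_output=10)
    F = jac.get_covariance_matrix(examples=loader)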
101 changes: 59 additions & 42 deletions nngeometry/generator/jacobian.py
@@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F
+ from torch.utils.data import DataLoader, TensorDataset
from ..utils import per_example_grad_conv
from ..object.vector import PVector, FVector
from ..layercollection import LayerCollection
@@ -18,8 +19,6 @@ class Jacobian:
:type layer_collection: :class:`.layercollection.LayerCollection`
:param model:
:type model: Pytorch `nn.Module`
- :param loader:
- :type loader: Pytorch `utils.data.DataLoader`
:param function: A function :math:`f\left(X,Y,Z\\right)` where :math:`X,Y,Z` are minibatches
returned by the dataloader (Note that in some cases :math:`Y,Z` are not required). If None,
it defaults to `function = lambda *x: model(x[0])`
@@ -29,10 +28,9 @@ class Jacobian:
:type n_output: integer
"""
- def __init__(self, model, loader, function=None, n_output=1,
+ def __init__(self, model, function=None, n_output=1,
centering=False, layer_collection=None):
self.model = model
- self.loader = loader
self.handles = []
self.xs = dict()
self.n_output = n_output
@@ -53,16 +51,17 @@ def __init__(self, model, loader, function=None, n_output=1,
def get_device(self):
return next(self.model.parameters()).device

- def get_covariance_matrix(self):
+ def get_covariance_matrix(self, examples):
# add hooks
self.handles += self._add_hooks(self._hook_savex,
self._hook_compute_flat_grad,
self.l_to_m.values())

device = next(self.model.parameters()).device
- n_examples = len(self.loader.sampler)
+ loader = self._get_dataloader(examples)
+ n_examples = len(loader.sampler)
n_parameters = self.layer_collection.numel()
- bs = self.loader.batch_size
+ bs = loader.batch_size
G = torch.zeros((n_parameters, n_parameters), device=device)
self.grads = torch.zeros((1, bs, n_parameters), device=device)
if self.centering:
@@ -71,7 +70,7 @@ def get_covariance_matrix(self):

self.start = 0
self.i_output = 0
- for d in self.loader:
+ for d in loader:
inputs = d[0]
inputs.requires_grad = True
bs = inputs.size(0)
@@ -98,7 +97,7 @@ def get_covariance_matrix(self):

return G

- def get_covariance_diag(self):
+ def get_covariance_diag(self, examples):
if self.centering:
raise NotImplementedError
# add hooks
@@ -107,11 +106,12 @@ def get_covariance_diag(self):
self.l_to_m.values())

device = next(self.model.parameters()).device
- n_examples = len(self.loader.sampler)
+ loader = self._get_dataloader(examples)
+ n_examples = len(loader.sampler)
n_parameters = self.layer_collection.numel()
self.diag_m = torch.zeros((n_parameters,), device=device)
self.start = 0
- for d in self.loader:
+ for d in loader:
inputs = d[0]
inputs.requires_grad = True
bs = inputs.size(0)
@@ -131,7 +131,7 @@ def get_covariance_diag(self):

return diag_m

- def get_covariance_quasidiag(self):
+ def get_covariance_quasidiag(self, examples):
if self.centering:
raise NotImplementedError
# add hooks
@@ -140,7 +140,8 @@ def get_covariance_quasidiag(self):
self.l_to_m.values())

device = next(self.model.parameters()).device
- n_examples = len(self.loader.sampler)
+ loader = self._get_dataloader(examples)
+ n_examples = len(loader.sampler)
self._blocks = dict()
for layer_id, layer in self.layer_collection.layers.items():
s = layer.numel()
@@ -152,7 +153,7 @@ def get_covariance_quasidiag(self):
self._blocks[layer_id] = (torch.zeros((s, ), device=device),
torch.zeros(cross_s, device=device))

- for d in self.loader:
+ for d in loader:
inputs = d[0]
inputs.requires_grad = True
bs = inputs.size(0)
@@ -177,7 +178,7 @@ def get_covariance_quasidiag(self):

return blocks

- def get_covariance_layer_blocks(self):
+ def get_covariance_layer_blocks(self, examples):
if self.centering:
raise NotImplementedError
# add hooks
@@ -186,13 +187,14 @@ def get_covariance_layer_blocks(self):
self.l_to_m.values())

device = next(self.model.parameters()).device
- n_examples = len(self.loader.sampler)
+ loader = self._get_dataloader(examples)
+ n_examples = len(loader.sampler)
self._blocks = dict()
for layer_id, layer in self.layer_collection.layers.items():
s = layer.numel()
self._blocks[layer_id] = torch.zeros((s, s), device=device)

- for d in self.loader:
+ for d in loader:
inputs = d[0]
inputs.requires_grad = True
bs = inputs.size(0)
@@ -212,14 +214,15 @@ def get_covariance_layer_blocks(self):

return blocks

- def get_kfac_blocks(self):
+ def get_kfac_blocks(self, examples):
# add hooks
self.handles += self._add_hooks(self._hook_savex,
self._hook_compute_kfac_blocks,
self.l_to_m.values())

device = next(self.model.parameters()).device
- n_examples = len(self.loader.sampler)
+ loader = self._get_dataloader(examples)
+ n_examples = len(loader.sampler)
self._blocks = dict()
for layer_id, layer in self.layer_collection.layers.items():
layer_class = layer.__class__.__name__
Expand All @@ -235,7 +238,7 @@ def get_kfac_blocks(self):
self._blocks[layer_id] = (torch.zeros((sA, sA), device=device),
torch.zeros((sG, sG), device=device))

- for d in self.loader:
+ for d in loader:
inputs = d[0]
inputs.requires_grad = True
bs = inputs.size(0)
@@ -265,19 +268,20 @@ def get_kfac_blocks(self):

return blocks

- def get_jacobian(self):
+ def get_jacobian(self, examples):
# add hooks
self.handles += self._add_hooks(self._hook_savex,
self._hook_compute_flat_grad,
self.l_to_m.values())

device = next(self.model.parameters()).device
- n_examples = len(self.loader.sampler)
+ loader = self._get_dataloader(examples)
+ n_examples = len(loader.sampler)
n_parameters = self.layer_collection.numel()
self.grads = torch.zeros((self.n_output, n_examples, n_parameters),
device=device)
self.start = 0
- for d in self.loader:
+ for d in loader:
inputs = d[0]
inputs.requires_grad = True
bs = inputs.size(0)
@@ -303,20 +307,21 @@ def get_jacobian(self):

return grads

- def get_gram_matrix(self):
+ def get_gram_matrix(self, examples):
# add hooks
self.handles += self._add_hooks(self._hook_savex_io, self._hook_kxy,
self.l_to_m.values())

device = next(self.model.parameters()).device
- n_examples = len(self.loader.sampler)
+ loader = self._get_dataloader(examples)
+ n_examples = len(loader.sampler)
self.G = torch.zeros((self.n_output, n_examples,
self.n_output, n_examples), device=device)
self.x_outer = dict()
self.x_inner = dict()
self.gy_outer = dict()
self.e_outer = 0
- for i_outer, d in enumerate(self.loader):
+ for i_outer, d in enumerate(loader):
# used in hooks to switch between store/compute
inputs_outer = d[0]
inputs_outer.requires_grad = True
@@ -332,7 +337,7 @@ def get_gram_matrix(self):
self.outerloop_switch = False

self.e_inner = 0
- for i_inner, d in enumerate(self.loader):
+ for i_inner, d in enumerate(loader):
if i_inner > i_outer:
break
inputs_inner = d[0]
@@ -381,14 +386,15 @@ def get_gram_matrix(self):

return G

- def get_kfe_diag(self, kfe):
+ def get_kfe_diag(self, kfe, examples):
# add hooks
self.handles += self._add_hooks(self._hook_savex,
self._hook_compute_kfe_diag,
self.l_to_m.values())

device = next(self.model.parameters()).device
- n_examples = len(self.loader.sampler)
+ loader = self._get_dataloader(examples)
+ n_examples = len(loader.sampler)
self._diags = dict()
self._kfe = kfe
for layer_id, layer in self.layer_collection.layers.items():
@@ -404,7 +410,7 @@ def get_kfe_diag(self, kfe):
sA += 1
self._diags[layer_id] = torch.zeros((sG * sA), device=device)

- for d in self.loader:
+ for d in loader:
inputs = d[0]
inputs.requires_grad = True
bs = inputs.size(0)
@@ -427,7 +433,7 @@ def get_kfe_diag(self, kfe):

return diags

- def implicit_mv(self, v):
+ def implicit_mv(self, v, examples):
# add hooks
self.handles += self._add_hooks(self._hook_savex,
self._hook_compute_Jv,
@@ -448,11 +454,12 @@ def implicit_mv(self, v):
output[mod.bias] = torch.zeros_like(mod.bias)

device = next(self.model.parameters()).device
- n_examples = len(self.loader.sampler)
+ loader = self._get_dataloader(examples)
+ n_examples = len(loader.sampler)

self.i_output = 0
self.start = 0
- for d in self.loader:
+ for d in loader:
inputs = d[0]
inputs.requires_grad = True
bs = inputs.size(0)
@@ -495,7 +502,7 @@ def implicit_mv(self, v):
return PVector(layer_collection=self.layer_collection,
dict_repr=output_dict)

- def implicit_vTMv(self, v):
+ def implicit_vTMv(self, v, examples):
# add hooks
self.handles += self._add_hooks(self._hook_savex,
self._hook_compute_Jv,
@@ -504,7 +511,8 @@ def implicit_vTMv(self, v):
self._v = v.get_dict_representation()

device = next(self.model.parameters()).device
- n_examples = len(self.loader.sampler)
+ loader = self._get_dataloader(examples)
+ n_examples = len(loader.sampler)

for layer_id, layer in self.layer_collection.layers.items():
mod = self.l_to_m[layer_id]
Expand All @@ -516,7 +524,7 @@ def implicit_vTMv(self, v):
self.start = 0
norm2 = 0
self.compute_switch = True
- for d in self.loader:
+ for d in loader:
inputs = d[0]
inputs.requires_grad = True
bs = inputs.size(0)
@@ -542,16 +550,17 @@ def implicit_vTMv(self, v):

return norm

- def implicit_trace(self):
+ def implicit_trace(self, examples):
# add hooks
self.handles += self._add_hooks(self._hook_savex,
self._hook_compute_trace,
self.l_to_m.values())

- n_examples = len(self.loader.sampler)
+ loader = self._get_dataloader(examples)
+ n_examples = len(loader.sampler)

self._trace = 0
- for d in self.loader:
+ for d in loader:
inputs = d[0]
inputs.requires_grad = True
bs = inputs.size(0)
@@ -571,7 +580,7 @@ def implicit_trace(self):

return trace

- def implicit_Jv(self, v):
+ def implicit_Jv(self, v, examples):
# add hooks
self.handles += self._add_hooks(self._hook_savex,
self._hook_compute_Jv,
Expand All @@ -580,11 +589,12 @@ def implicit_Jv(self, v):
self._v = v.get_dict_representation()

device = next(self.model.parameters()).device
- n_examples = len(self.loader.sampler)
+ loader = self._get_dataloader(examples)
+ n_examples = len(loader.sampler)
self._Jv = torch.zeros((self.n_output, n_examples), device=device)
self.start = 0
self.compute_switch = True
- for d in self.loader:
+ for d in loader:
inputs = d[0]
inputs.requires_grad = True
bs = inputs.size(0)
@@ -1101,3 +1111,10 @@ def _hook_compute_trace(self, mod, grad_input, grad_output):
self._trace += (gy.sum(dim=(2, 3))**2).sum()
else:
raise NotImplementedError

+ def _get_dataloader(self, examples):
+     if isinstance(examples, DataLoader):
+         return examples
+     else:
+         return DataLoader(TensorDataset(*examples),
+                           batch_size=len(examples[0]))
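As the new _get_dataloader helper above shows, each method's examples argument may be either an existing DataLoader (used as-is) or a tuple of tensors, which is wrapped in a single full-batch DataLoader. A short sketch under that reading, with placeholder tensors x and y:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    x = torch.randn(100, 10)
    y = torch.randint(0, 2, (100,))

    # option 1: pass a DataLoader; it is returned unchanged
    loader = DataLoader(TensorDataset(x, y), batch_size=25)

    # option 2: pass a tuple of tensors; it is wrapped as
    # DataLoader(TensorDataset(x, y), batch_size=len(x)), i.e. one batch
    examples = (x, y)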
6 changes: 2 additions & 4 deletions nngeometry/metrics.py
@@ -76,10 +76,9 @@ def fim_function(input, target):

generator = Jacobian(layer_collection=layer_collection,
model=model,
- loader=loader,
function=fim_function,
n_output=trials)
- return representation(generator)
+ return representation(generator=generator, examples=loader)


def FIM(model,
@@ -147,7 +146,6 @@ def function_fim(*d):

generator = Jacobian(layer_collection=layer_collection,
model=model,
- loader=loader,
function=function_fim,
n_output=n_output)
- return representation(generator)
+ return representation(generator=generator, examples=loader)
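For callers of the metrics module, the visible change is that the loader is now forwarded to the representation rather than to the Jacobian generator. A hedged usage sketch; PMatDense is assumed to be one of nngeometry's parameter-space representations, and the full FIM signature is not shown in this diff:

    from nngeometry.metrics import FIM
    from nngeometry.object import PMatDense

    F = FIM(model=model,              # placeholder model
            loader=loader,            # placeholder DataLoader
            representation=PMatDense,
            n_output=10)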
4 changes: 2 additions & 2 deletions nngeometry/object/fspace.py
@@ -11,12 +11,12 @@ def __init__(self, generator):


class FMatDense(FMatAbstract):
- def __init__(self, generator, data=None):
+ def __init__(self, generator, data=None, examples=None):
self.generator = generator
if data is not None:
self.data = data
else:
- self.data = generator.get_gram_matrix()
+ self.data = generator.get_gram_matrix(examples)

def compute_eigendecomposition(self, impl='symeig'):
# TODO: test
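Correspondingly, function-space objects now take the data at construction time. A minimal sketch using the classes touched by this diff (model and loader are placeholders; import paths assumed from the file layout):

    from nngeometry.generator import Jacobian
    from nngeometry.object.fspace import FMatDense

    jac = Jacobian(model=model, n_output=10)
    K = FMatDense(generator=jac, examples=loader)  # calls get_gram_matrix(examples)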
