Commit
adds function argument to FIM and FIM_MC
tfjgeorge committed Sep 18, 2020
1 parent 4497db9 commit e43229e
Showing 2 changed files with 31 additions and 14 deletions.
40 changes: 28 additions & 12 deletions nngeometry/metrics.py
@@ -10,6 +10,7 @@ def FIM_MonteCarlo(model,
                    variant='classif_logits',
                    trials=1,
                    device='cpu',
+                   function=None,
                    layer_collection=None):
     """
     Helper that creates a matrix computing the Fisher Information
@@ -35,28 +36,36 @@ def FIM_MonteCarlo(model,
         Number of trials for Monte Carlo sampling
     device : string, optional (default='cpu')
         Target device for the returned matrix
+    function : function, optional (default=None)
+        An optional function to use instead of `model(input)`. If
+        not None, it overrides the `device` parameter.
     layer_collection : layercollection.LayerCollection, optional
             (default=None)
         An optional layer collection
     """

+    if function is None:
+        def function(*d):
+            return model(d[0].to(device))
+
     if layer_collection is None:
         layer_collection = LayerCollection.from_model(model)

     if variant == 'classif_logits':

-        def function(input, target):
-            log_softmax = torch.log_softmax(model(input.to(device)), dim=1)
+        def fim_function(*d):
+            log_softmax = torch.log_softmax(function(*d), dim=1)
             probabilities = torch.exp(log_softmax)
             sampled_targets = torch.multinomial(probabilities, trials,
                                                 replacement=True)
             return trials ** -.5 * torch.gather(log_softmax, 1,
                                                 sampled_targets)
     elif variant == 'classif_logsoftmax':

-        def function(input, target):
-            log_softmax = model(input.to(device))
+        def fim_function(*d):
+            log_softmax = function(*d)
             probabilities = torch.exp(log_softmax)
             sampled_targets = torch.multinomial(probabilities, trials,
                                                 replacement=True)
@@ -68,7 +77,7 @@ def function(input, target):
     generator = Jacobian(layer_collection=layer_collection,
                          model=model,
                          loader=loader,
-                         function=function,
+                         function=fim_function,
                          n_output=trials)
     return representation(generator)

@@ -79,6 +88,7 @@ def FIM(model,
         n_output,
         variant='classif_logits',
         device='cpu',
+        function=None,
         layer_collection=None):
     """
     Helper that creates a matrix computing the Fisher Information
@@ -104,34 +114,40 @@ def FIM(model,
         - 'regression' when using a Gaussian regression model
     device : string, optional (default='cpu')
         Target device for the returned matrix
+    function : function, optional (default=None)
+        An optional function to use instead of `model(input)`. If
+        not None, it overrides the `device` parameter.
     layer_collection : layercollection.LayerCollection, optional
             (default=None)
         An optional layer collection
     """

+    if function is None:
+        def function(*d):
+            return model(d[0].to(device))
+
     if layer_collection is None:
         layer_collection = LayerCollection.from_model(model)

     if variant == 'classif_logits':

-        def function(*d):
-            inputs = d[0].to(device)
-            log_probs = torch.log_softmax(model(inputs), dim=1)
+        def function_fim(*d):
+            log_probs = torch.log_softmax(function(*d), dim=1)
             probs = torch.exp(log_probs).detach()
             return (log_probs * probs**.5)

     elif variant == 'regression':

-        def function(*d):
-            inputs = d[0].to(device)
-            estimates = model(inputs)
+        def function_fim(*d):
+            estimates = function(*d)
             return estimates
     else:
         raise NotImplementedError

     generator = Jacobian(layer_collection=layer_collection,
                          model=model,
                          loader=loader,
-                         function=function,
+                         function=function_fim,
                          n_output=n_output)
     return representation(generator)
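
The new `function` hook makes these helpers usable with models whose forward pass does not follow the default `model(input.to(device))` pattern, e.g. models taking several inputs. Below is a minimal sketch of the intended usage; the two-input model and the data wiring are hypothetical, not part of this commit, and the import paths are assumed from the tests below.

import torch
from torch.utils.data import DataLoader, TensorDataset
from nngeometry.metrics import FIM
from nngeometry.object import PMatDense  # same representation as in the tests

# Hypothetical model taking two inputs instead of one
class TwoInputNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 3)

    def forward(self, x1, x2):
        return self.linear(x1 + x2)

model = TwoInputNet()
loader = DataLoader(
    TensorDataset(torch.randn(100, 10),            # x1
                  torch.randn(100, 10),            # x2
                  torch.randint(0, 3, (100,))),    # targets
    batch_size=25)

# `function` receives the full minibatch tuple, so the two inputs can
# be unpacked explicitly; device placement is now the caller's job.
F = FIM(model=model,
        loader=loader,
        representation=PMatDense,
        n_output=3,
        variant='classif_logits',
        function=lambda *d: model(d[0], d[1]))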
5 changes: 3 additions & 2 deletions tests/test_metrics.py
@@ -7,6 +7,7 @@
                    get_fullyconnect_task, get_fullyconnect_bn_task,
                    get_batchnorm_nonlinear_task,
                    get_conv_task, get_conv_bn_task, get_conv_gn_task)
+from tasks import to_device
 from nngeometry.object.map import (PushForwardDense, PushForwardImplicit,
                                    PullBackDense)
 from nngeometry.object.fspace import FMatDense
@@ -38,7 +39,7 @@ def test_FIM_MC_vs_linearization():
                             variant='classif_logits',
                             representation=PMatDense,
                             trials=10,
-                            device=device)
+                            function=lambda *d: model(to_device(d[0])))

     dw = random_pvector(lc, device=device)
     dw = step / dw.norm() * dw
@@ -74,7 +75,7 @@ def test_FIM_vs_linearization():
                 variant='classif_logits',
                 representation=PMatDense,
                 n_output=n_output,
-                device=device)
+                function=lambda *d: model(to_device(d[0])))

     dw = random_pvector(lc, device=device)
     dw = step / dw.norm() * dw
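
In both tests the removed `device=device` argument is replaced by the new hook, which moves device placement into the lambda. The `to_device` helper imported from `tasks` is not shown in this diff; presumably it is a thin wrapper roughly like the following sketch (an assumption, not the repository's actual code):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def to_device(tensor):
    # move a tensor to the test device chosen above
    return tensor.to(device)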
