Commit 665c0b1

Merge pull request #11 from tfjgeorge/fim_logsoftmax
fixes FIM MC for log_softmax
tfjgeorge committed Mar 3, 2021
2 parents 3f5361e + 327c0c7 commit 665c0b1
Showing 2 changed files with 39 additions and 31 deletions.
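
For context, FIM_MonteCarlo estimates the Fisher Information Matrix by sampling targets from the model's own predictive distribution; the 'classif_logsoftmax' variant fixed here expects `function` to return log-probabilities rather than raw logits. A minimal usage sketch, mirroring the updated test below (`model`, `loader`, and `lc` are assumed to already exist, and the import paths are taken to follow nngeometry's layout):

import torch
from nngeometry.metrics import FIM_MonteCarlo
from nngeometry.object import PMatDense

# 'classif_logsoftmax' samples targets from exp(function(*d)),
# so `function` must return log-probabilities.
F = FIM_MonteCarlo(layer_collection=lc,
                   model=model,
                   loader=loader,
                   variant='classif_logsoftmax',
                   representation=PMatDense,
                   trials=10,
                   function=lambda *d: torch.log_softmax(model(d[0]), dim=1))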
nngeometry/metrics.py (2 changes: 1 addition & 1 deletion)
@@ -64,7 +64,7 @@ def fim_function(*d):
                                                 sampled_targets)
     elif variant == 'classif_logsoftmax':
 
-        def fim_function(input, target):
+        def fim_function(*d):
             log_softmax = function(*d)
             probabilities = torch.exp(log_softmax)
             sampled_targets = torch.multinomial(probabilities, trials,
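
The one-line fix matters because the function body calls `function(*d)`: under the old signature `fim_function(input, target)`, the name `d` was never bound inside the function, so the 'classif_logsoftmax' variant would fail with a NameError (or silently capture a stale `d` from an enclosing scope) the moment it ran. A minimal sketch of the corrected pattern, using a hypothetical factory `make_fim_function` rather than the library's exact internals:

import torch

def make_fim_function(function, trials):
    # After the fix, *d forwards whatever tuple the data loader yields,
    # e.g. (inputs,) or (inputs, targets), straight to `function`.
    def fim_function(*d):
        log_softmax = function(*d)              # already log-probabilities
        probabilities = torch.exp(log_softmax)  # rows sum to 1
        sampled_targets = torch.multinomial(probabilities, trials,
                                            replacement=True)
        # The real code goes on to differentiate the sampled
        # log-likelihoods; returning them suffices for illustration.
        return torch.gather(log_softmax, 1, sampled_targets)
    return fim_function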
tests/test_metrics.py (68 changes: 38 additions & 30 deletions)
@@ -35,36 +35,44 @@ def test_FIM_MC_vs_linearization():
     step = 1e-2
 
     for get_task in nonlinear_tasks:
-        quots = []
-        for i in range(10):  # repeat to kill statistical fluctuations
-            loader, lc, parameters, model, function, n_output = get_task()
-            model.train()
-            F = FIM_MonteCarlo(layer_collection=lc,
-                               model=model,
-                               loader=loader,
-                               variant='classif_logits',
-                               representation=PMatDense,
-                               trials=10,
-                               function=lambda *d: model(to_device(d[0])))
-
-            dw = random_pvector(lc, device=device)
-            dw = step / dw.norm() * dw
-
-            output_before = get_output_vector(loader, function)
-            update_model(parameters, dw.get_flat_representation())
-            output_after = get_output_vector(loader, function)
-            update_model(parameters, -dw.get_flat_representation())
-
-            KL = tF.kl_div(tF.log_softmax(output_before, dim=1),
-                           tF.log_softmax(output_after, dim=1),
-                           log_target=True, reduction='batchmean')
-
-            quot = (KL / F.vTMv(dw) * 2) ** .5
-
-            quots.append(quot.item())
-
-        mean_quotient = sum(quots) / len(quots)
-        assert mean_quotient > 1 - 5e-2 and mean_quotient < 1 + 5e-2
+        for variant in ['classif_logits', 'classif_logsoftmax']:
+            quots = []
+            for i in range(10):  # repeat to kill statistical fluctuations
+                loader, lc, parameters, model, function, n_output = get_task()
+                model.train()
+
+                if variant == 'classif_logits':
+                    f = lambda *d: model(to_device(d[0]))
+                elif variant == 'classif_logsoftmax':
+                    f = lambda *d: torch.log_softmax(model(to_device(d[0])),
+                                                     dim=1)
+
+                F = FIM_MonteCarlo(layer_collection=lc,
+                                   model=model,
+                                   loader=loader,
+                                   variant=variant,
+                                   representation=PMatDense,
+                                   trials=10,
+                                   function=f)
+
+                dw = random_pvector(lc, device=device)
+                dw = step / dw.norm() * dw
+
+                output_before = get_output_vector(loader, function)
+                update_model(parameters, dw.get_flat_representation())
+                output_after = get_output_vector(loader, function)
+                update_model(parameters, -dw.get_flat_representation())
+
+                KL = tF.kl_div(tF.log_softmax(output_before, dim=1),
+                               tF.log_softmax(output_after, dim=1),
+                               log_target=True, reduction='batchmean')
+
+                quot = (KL / F.vTMv(dw) * 2) ** .5
+
+                quots.append(quot.item())
+
+            mean_quotient = sum(quots) / len(quots)
+            assert mean_quotient > 1 - 5e-2 and mean_quotient < 1 + 5e-2
 
 
 def test_FIM_vs_linearization_classif_logits():
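
Why the assertion targets 1: for a small parameter step dw, the KL divergence between the predictive distributions before and after the step admits the standard second-order expansion

\mathrm{KL}\!\left(p_{\theta} \,\big\|\, p_{\theta + dw}\right) \approx \tfrac{1}{2}\, dw^{\top} F(\theta)\, dw

so `quot = (KL / F.vTMv(dw) * 2) ** .5` estimates \sqrt{2\,\mathrm{KL} / dw^{\top} F\, dw} \approx 1. Averaging over 10 repetitions damps the Monte Carlo noise from the sampled targets, and the assertion allows a 5% tolerance on the mean quotient.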
