
Commit

Merge pull request #25 from tfjgeorge/multiclass_semantic_segm
adds MonteCarlo Fisher helper for segmentation
tfjgeorge committed Jun 10, 2021
2 parents bdcf8a7 + eca86b6 commit d6cbe15
Showing 3 changed files with 114 additions and 25 deletions.
18 changes: 18 additions & 0 deletions nngeometry/metrics.py
@@ -31,6 +31,8 @@ def FIM_MonteCarlo(model,
Variant to use depending on how you interpret your function.
Possible choices are:
- 'classif_logits' when using logits for classification
- 'classif_logsoftmax' when using log_softmax values for classification
- 'segmentation_logits' when using logits in a segmentation task
- 'regression' when using a gaussian regression model
trials : int, optional (default=1)
Number of trials for Monte Carlo sampling
@@ -71,6 +73,22 @@ def fim_function(*d):
replacement=True)
return trials ** -.5 * torch.gather(log_softmax, 1,
sampled_targets)
elif variant == 'segmentation_logits':

def fim_function(*d):
log_softmax = torch.log_softmax(function(*d), dim=1)
s_mb, s_c, s_h, s_w = log_softmax.size()
log_softmax = log_softmax.permute(0, 2, 3, 1).contiguous() \
.view(s_mb * s_h * s_w, s_c)
probabilities = torch.exp(log_softmax)
sampled_indices = torch.multinomial(probabilities, trials,
replacement=True)
sampled_targets = torch.gather(log_softmax, 1,
sampled_indices)
sampled_targets = sampled_targets.view(s_mb, s_h * s_w, trials) \
.sum(dim=1)
return trials ** -.5 * sampled_targets

else:
raise NotImplementedError

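For reference, the Monte Carlo Fisher estimator that the new 'segmentation_logits' variant implements can be written as follows (a sketch of one reading of the code above, where $T$ is the number of trials and the per-pixel labels $y_u^{(t)}$ are sampled from the model's own predictive distribution):

$$\hat{F} = \frac{1}{T} \sum_{t=1}^{T} \nabla_\theta \ell_t \, (\nabla_\theta \ell_t)^\top, \qquad \ell_t = \sum_{\text{pixels } u} \log p_\theta\big(y_u^{(t)} \mid x\big)$$

The function returns each $\ell_t$ scaled by $T^{-1/2}$, so that once the generator takes outer products of per-trial gradients and sums over the trials dimension, the result is the $1/T$ average above.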
57 changes: 50 additions & 7 deletions tests/tasks.py
@@ -15,6 +15,9 @@

def to_device(tensor):
return tensor.to(device)

def to_device_model(model):
model.to('cuda')
else:
device = 'cpu'

@@ -23,6 +26,9 @@ def to_device(tensor):
def to_device(tensor):
return tensor.double()

def to_device_model(model):
model.double()

class FCNet(nn.Module):
def __init__(self, out_size=10, normalization='none'):
super(FCNet, self).__init__()
@@ -45,6 +51,25 @@ def forward(self, x):
return self.net(x)


class FCNetSegmentation(nn.Module):
def __init__(self, out_size=10):
super(FCNetSegmentation, self).__init__()
layers = []
self.out_size = out_size
sizes = [18*18, 10, 10, 4*4*out_size]
for s_in, s_out in zip(sizes[:-1], sizes[1:]):
layers.append(nn.Linear(s_in, s_out))
layers.append(nn.ReLU())
# remove last nonlinearity:
layers.pop()
self.net = nn.Sequential(*layers)

def forward(self, x):
x = x[:, :, 5:-5, 5:-5].contiguous()
x = x.view(x.size(0), -1)
return self.net(x).view(-1, self.out_size, 4, 4)
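
A quick shape check for this module (a sketch; the 28x28 input size matches the MNIST images used by the tasks below): the forward pass crops each image to 18x18, flattens it to 18*18 = 324 features, and reshapes the last layer's 4*4*out_size outputs into per-pixel class logits.

import torch

net = FCNetSegmentation(out_size=3)
x = torch.randn(2, 1, 28, 28)    # a batch of two MNIST-sized images
y = net(x)                       # crop -> flatten -> MLP -> reshape
assert y.shape == (2, 3, 4, 4)   # (batch, classes, height, width)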


class ConvNet(nn.Module):
def __init__(self, normalization='none'):
super(ConvNet, self).__init__()
@@ -95,7 +120,7 @@ def get_linear_fc_task():
batch_size=300,
shuffle=False)
net = LinearFCNet()
net.to(device)
to_device_model(net)
net.eval()

def output_fn(input, target):
@@ -128,7 +153,7 @@ def get_linear_conv_task():
batch_size=300,
shuffle=False)
net = LinearConvNet()
net.to(device)
to_device_model(net)
net.eval()

def output_fn(input, target):
@@ -163,7 +188,7 @@ def get_batchnorm_fc_linear_task():
batch_size=300,
shuffle=False)
net = BatchNormFCLinearNet()
net.to(device)
to_device_model(net)
net.eval()

def output_fn(input, target):
@@ -205,7 +230,7 @@ def get_batchnorm_conv_linear_task():
batch_size=300,
shuffle=False)
net = BatchNormConvLinearNet()
net.to(device)
to_device_model(net)
net.eval()

def output_fn(input, target):
@@ -258,7 +283,7 @@ def get_batchnorm_nonlinear_task():
batch_size=1000,
shuffle=False)
net = BatchNormNonLinearNet()
net.to(device)
to_device_model(net)
net.eval()

def output_fn(input, target):
Expand All @@ -284,7 +309,7 @@ def get_fullyconnect_task(normalization='none'):
batch_size=300,
shuffle=False)
net = FCNet(out_size=3, normalization=normalization)
net.to(device)
to_device_model(net)
net.eval()

def output_fn(input, target):
Expand All @@ -307,7 +332,7 @@ def get_conv_task(normalization='none'):
batch_size=300,
shuffle=False)
net = ConvNet(normalization=normalization)
net.to(device)
to_device_model(net)
net.eval()

def output_fn(input, target):
Expand Down Expand Up @@ -335,3 +360,21 @@ def get_fullyconnect_onlylast_task():
parameters = net.net[-1].parameters()

return train_loader, layer_collection, parameters, net, output_fn, n_output

def get_fullyconnect_segm_task():
train_set = get_mnist()
train_set = Subset(train_set, range(1000))
train_loader = DataLoader(
dataset=train_set,
batch_size=300,
shuffle=False)
net = FCNetSegmentation(out_size=3)
to_device_model(net)
net.eval()

def output_fn(input, target):
return net(to_device(input))

layer_collection = LayerCollection.from_model(net)
return (train_loader, layer_collection, net.parameters(),
net, output_fn, 3)
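
Putting the pieces together, a minimal sketch of how this task feeds the new variant (it mirrors the test added below and assumes the same imports as tests/test_metrics.py):

loader, lc, parameters, net, output_fn, n_output = get_fullyconnect_segm_task()

F = FIM_MonteCarlo(layer_collection=lc,
                   model=net,
                   loader=loader,
                   variant='segmentation_logits',
                   representation=PMatDense,
                   trials=10,
                   function=lambda *d: net(to_device(d[0])))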
64 changes: 46 additions & 18 deletions tests/test_metrics.py
@@ -1,27 +1,15 @@
import torch
import torch.nn.functional as tF
from tasks import (get_linear_fc_task, get_linear_conv_task,
get_batchnorm_fc_linear_task,
get_batchnorm_conv_linear_task,
get_fullyconnect_onlylast_task,
get_fullyconnect_task, get_fullyconnect_bn_task,
get_batchnorm_nonlinear_task,
get_conv_task, get_conv_bn_task, get_conv_gn_task)
from tasks import to_device
from nngeometry.object.map import (PushForwardDense, PushForwardImplicit,
PullBackDense)
from nngeometry.object.fspace import FMatDense
from nngeometry.object.pspace import (PMatDense, PMatDiag, PMatBlockDiag,
PMatImplicit, PMatLowRank, PMatQuasiDiag)
from nngeometry.generator import Jacobian
from tasks import (get_fullyconnect_task, get_conv_task, get_conv_gn_task,
get_fullyconnect_segm_task)
from tasks import to_device, device
from nngeometry.object.pspace import PMatDense
from nngeometry.metrics import FIM, FIM_MonteCarlo
from nngeometry.object.vector import random_pvector, random_fvector, PVector
from utils import check_ratio, check_tensors, check_angle
from test_jacobian import update_model, get_output_vector, device
from nngeometry.object.vector import random_pvector
from test_jacobian import update_model, get_output_vector

nonlinear_tasks = [get_conv_gn_task, get_fullyconnect_task, get_conv_task]

import numpy as np
import pytest


@@ -110,6 +98,7 @@ def test_FIM_vs_linearization_classif_logits():
mean_quotient = sum(quots) / len(quots)
assert mean_quotient > 1 - 5e-2 and mean_quotient < 1 + 5e-2


def test_FIM_vs_linearization_regression():
step = 1e-2

@@ -143,3 +132,42 @@ def test_FIM_vs_linearization_regression():

mean_quotient = sum(quots) / len(quots)
assert mean_quotient > 1 - 5e-2 and mean_quotient < 1 + 5e-2


def test_FIM_MC_vs_linearization_segmentation():
step = 1e-2
variant = 'segmentation_logits'
for get_task in [get_fullyconnect_segm_task]:
quots = []
for i in range(10): # repeat to kill statistical fluctuations
loader, lc, parameters, model, function, n_output = get_task()
model.train()

f = lambda *d: model(to_device(d[0]))

F = FIM_MonteCarlo(layer_collection=lc,
model=model,
loader=loader,
variant=variant,
representation=PMatDense,
trials=10,
function=f)

dw = random_pvector(lc, device=device)
dw = step / dw.norm() * dw

output_before = get_output_vector(loader, function)
update_model(parameters, dw.get_flat_representation())
output_after = get_output_vector(loader, function)
update_model(parameters, -dw.get_flat_representation())

KL = tF.kl_div(tF.log_softmax(output_before, dim=1),
tF.log_softmax(output_after, dim=1),
log_target=True, reduction='batchmean')

quot = (KL / F.vTMv(dw) * 2) ** .5

quots.append(quot.item())

mean_quotient = sum(quots) / len(quots)
assert mean_quotient > 1 - 5e-2 and mean_quotient < 1 + 5e-2
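
For the record, the final assertion rests on the second-order expansion of the KL divergence: for a small parameter step $\delta\theta$,

$$\mathrm{KL}\big(p_\theta \,\|\, p_{\theta+\delta\theta}\big) \approx \tfrac{1}{2}\, \delta\theta^\top F \, \delta\theta,$$

so the quotient $(2\,\mathrm{KL} / \delta\theta^\top F \delta\theta)^{1/2}$ computed in the loop should concentrate around 1 whenever the Monte Carlo estimate of $F$ is accurate, which is what the 10-repetition average checks to within 5%.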
