Skip to content

Commit

Permalink
Merge pull request #18 from tfjgeorge/pickle
Browse files Browse the repository at this point in the history
Adds pickle for PMat objects and PVectors
  • Loading branch information
tfjgeorge committed Apr 7, 2021
2 parents ba043fc + 4432fb1 commit 8ae7664
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 0 deletions.
11 changes: 11 additions & 0 deletions nngeometry/generator/dummy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class DummyGenerator:
"""
This dummy generator is used for pickled objects
"""

def __init__(self, layer_collection, device):
self.layer_collection = layer_collection
self.device = device

def get_device(self):
return self.device
26 changes: 26 additions & 0 deletions nngeometry/layercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ def parameters(self, layerid_to_module):
elif layer.bias:
yield layerid_to_module[layer_id].bias

def __eq__(self, other):
for layer_id in set(self.layers.keys()).union(set(other.layers.keys())):
if (layer_id not in other.layers.keys()
or layer_id not in self.layers.keys()
or self.layers[layer_id] != other.layers[layer_id]):
return False
return True


class AbstractLayer(ABC):
pass
Expand All @@ -144,6 +152,11 @@ def numel(self):
else:
return self.weight.numel()

def __eq__(self, other):
return (self.in_channels == other.in_channels and
self.out_channels == other.out_channels and
self.kernel_size == other.kernel_size)


class LinearLayer(AbstractLayer):

Expand All @@ -162,6 +175,10 @@ def numel(self):
else:
return self.weight.numel()

def __eq__(self, other):
return (self.in_features == other.in_features and
self.out_features == other.out_features)


class BatchNorm1dLayer(AbstractLayer):

Expand All @@ -173,6 +190,9 @@ def __init__(self, num_features):
def numel(self):
return self.weight.numel() + self.bias.numel()

def __eq__(self, other):
return self.num_features == other.num_features


class BatchNorm2dLayer(AbstractLayer):

Expand All @@ -184,6 +204,9 @@ def __init__(self, num_features):
def numel(self):
return self.weight.numel() + self.bias.numel()

def __eq__(self, other):
return self.num_features == other.num_features


class GroupNormLayer(AbstractLayer):

Expand All @@ -195,6 +218,9 @@ def __init__(self, num_groups, num_channels):
def numel(self):
return self.weight.numel() + self.bias.numel()

def __eq__(self, other):
return self.num_channels == other.num_channels


class Parameter(object):

Expand Down
11 changes: 11 additions & 0 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from ..maths import kronecker
from .vector import PVector
from nngeometry.generator.dummy import DummyGenerator


class PMatAbstract(ABC):
Expand Down Expand Up @@ -103,6 +104,16 @@ def _check_data_examples(self, data, examples):
"""
assert (data is not None) ^ (examples is not None)

def __getstate__(self):
return {'layer_collection': self.generator.layer_collection,
'data': self.data,
'device': self.generator.get_device()}

def __setstate__(self, state_dict):
self.data = state_dict['data']
self.generator = DummyGenerator(state_dict['layer_collection'],
state_dict['device'])


class PMatDense(PMatAbstract):
def __init__(self, generator, data=None, examples=None):
Expand Down
66 changes: 66 additions & 0 deletions tests/test_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from tasks import get_conv_gn_task, get_conv_task
from nngeometry.layercollection import LayerCollection
import pickle as pkl
from utils import check_tensors

from nngeometry.object.pspace import (PMatDense, PMatDiag, PMatBlockDiag,
PMatLowRank, PMatQuasiDiag)
from nngeometry.generator import Jacobian
from nngeometry.object.vector import PVector

def test_layercollection_pkl():
_, lc, _, _, _, _ = get_conv_gn_task()

with open('/tmp/lc.pkl', 'wb') as f:
pkl.dump(lc, f)

with open('/tmp/lc.pkl', 'rb') as f:
lc_pkl = pkl.load(f)

assert lc == lc_pkl


def test_layercollection_eq():
_, lc, _, _, _, _ = get_conv_gn_task()
_, lc_same, _, _, _, _ = get_conv_gn_task()
_, lc_different, _, _, _, _ = get_conv_task()

assert lc == lc_same
assert lc != lc_different


def test_PMat_pickle():
loader, lc, parameters, model, function, n_output = get_conv_task()

generator = Jacobian(layer_collection=lc,
model=model,
function=function,
n_output=n_output)

for repr in [PMatDense, PMatDiag, PMatBlockDiag,
PMatLowRank, PMatQuasiDiag]:
PMat = repr(generator=generator,
examples=loader)

with open('/tmp/PMat.pkl', 'wb') as f:
pkl.dump(PMat, f)

with open('/tmp/PMat.pkl', 'rb') as f:
PMat_pkl = pkl.load(f)

check_tensors(PMat.get_dense_tensor(), PMat_pkl.get_dense_tensor())


def test_PVector_pickle():
_, _, _, model, _, _ = get_conv_task()

vec = PVector.from_model(model)

with open('/tmp/PVec.pkl', 'wb') as f:
pkl.dump(vec, f)

with open('/tmp/PVec.pkl', 'rb') as f:
vec_pkl = pkl.load(f)

check_tensors(vec.get_flat_representation(),
vec_pkl.get_flat_representation())

0 comments on commit 8ae7664

Please sign in to comment.