In [62]:
from pruneshift.prune import simple_prune
from pruneshift.prune import L1Unstructured
from pruneshift.topologies import network_topology

import torch
from torch.nn.utils.prune import Identity
import tempfile
from pathlib import Path

# Scratch directory for the checkpoints written below. Unlike a bare
# ``tempfile.mkdtemp()`` (which is never removed), ``TemporaryDirectory``
# deletes the directory and its checkpoints when the handle is finalized
# (re-running this cell or shutting down the kernel), so repeated runs do not
# accumulate checkpoint files in the system temp folder.
_tmp_dir_handle = tempfile.TemporaryDirectory()
tmp_dir = Path(_tmp_dir_handle.name)
path_pruned = tmp_dir/"pruned.chpt"
path_unpruned = tmp_dir/"unpruned.chpt"

# Seed torch's RNG so the randomly initialised networks (pretrained=False)
# are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [55]:
def same_network(net, input_shape=(1, 3, 32, 32)):
    """Heuristic to check whether we have the same net.

    Runs ``net`` once on a constant all-ones probe input and remembers the
    output. The returned generator is already primed, so each
    ``check.send(other_net)`` runs ``other_net`` on the same probe input and
    prints whether its output is exactly identical to the reference. If it
    is not, it additionally prints the norm of the difference, or the shape
    change when the output shapes no longer match.

    Args:
        net: Reference network (any callable taking a tensor).
        input_shape: Shape of the all-ones probe input. Defaults to
            ``(1, 3, 32, 32)``, the shape used for the CIFAR-10 ResNet here.

    Returns:
        A generator advanced to its first ``yield``; drive it with ``send``.

    Note:
        The forward passes run in whatever mode the nets are in. If a net is
        in training mode and has layers with running statistics (e.g.
        BatchNorm -- presumably present in the ResNet used here, verify),
        each check may update those statistics; call ``net.eval()``
        beforehand if that matters.
    """
    def _same_network(ref_net):
        probe = torch.ones(input_shape)
        # Only outputs are compared; disabling autograd avoids keeping the
        # reference forward graph (and its activations) alive inside the
        # generator for as long as the generator exists.
        with torch.no_grad():
            orig_y = ref_net(probe)
        while True:
            # The yield stays outside the no_grad block so the caller's grad
            # mode is not altered while the generator is suspended.
            candidate = yield
            with torch.no_grad():
                y = candidate(probe)
            # torch.equal returns a plain bool and is False on a shape
            # mismatch instead of broadcasting or raising.
            is_equal = torch.equal(orig_y, y)
            print(f"The output has not changed: {is_equal}")
            if not is_equal:
                if orig_y.shape != y.shape:
                    print(f"The output shape changed: {tuple(orig_y.shape)} -> {tuple(y.shape)}")
                else:
                    norm_diff = torch.norm(orig_y - y)
                    print(f"The norm of the difference is: {norm_diff}")
    gen = _same_network(net)
    gen.send(None)
    return gen

In [68]:
# We still would need to recreate the net with the corresponding pruning method.
# Experiment: prune in two rounds, save/reload the pruned state, and compare
# against a single pruning round with the equivalent combined amount.
FIRST_AMOUNT = 0.1   # fraction pruned in the first round
SECOND_AMOUNT = 0.2  # fraction pruned in the second round

net = network_topology("cifar10_resnet18", pretrained=False)
torch.save(net.state_dict(), path_unpruned)
print("Doing some pruning.")
simple_prune(net, L1Unstructured, amount=FIRST_AMOUNT)
simple_prune(net, L1Unstructured, amount=SECOND_AMOUNT)
check = same_network(net)
print("Saving the state of the network.")
torch.save(net.state_dict(), path_pruned)
del net
pruned_state = torch.load(path_pruned)
unpruned_state = torch.load(path_unpruned)
print("Deleted the network.")
print("Recreating the default network.")
net = network_topology("cifar10_resnet18", pretrained=False)
# Presumably a fresh random init (pretrained=False), so this should differ.
check.send(net)
print("Reloading the network with custom mask.")
# Apply a dummy pruning method first so the pruned state dict can be loaded
# into the freshly created network.
simple_prune(net, Identity)
print("Reloading the network state.")
net.load_state_dict(pruned_state)
check.send(net)
del net
print("Check whether the network pruning is smart")
net = network_topology("cifar10_resnet18", pretrained=False)
net.load_state_dict(unpruned_state)
# Total fraction pruned by the two rounds above, assuming the second round
# prunes SECOND_AMOUNT of the *remaining* weights: a + (1 - a) * b.
# Derived from the constants so the comparison stays valid if they change.
iter_amount = FIRST_AMOUNT + (1 - FIRST_AMOUNT) * SECOND_AMOUNT
simple_prune(net, L1Unstructured, amount=iter_amount)
check.send(net)

Doing some pruning.
Saving the state of the network.
Deleted the network.
Recreating the default network.
The output has not changed: False
The norm of the difference is: 2.707881450653076
Reloading the network with custom mask.
Reloading the network state.
The output has not changed: True
Check whether the network pruning is smart
The output has not changed: False
The norm of the difference is: 0.00023480765230488032


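Question: Does iterative pruning take already-pruned weights into account?
 - Yes, the pruning method is indeed smart :) : it respects that some weights might already be pruned and then computes the next pruning step relative to the already pruned network. Pruning 10% and then 20% of the remaining weights yields (almost) the same network as pruning 0.1 + 0.9 * 0.2 = 28% in one shot; the small remaining difference (norm ≈ 2.3e-4) is presumably due to rounding of the number of weights pruned per layer.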

Situation:
 - The mask is computed once, at the beginning, in `apply` (via `compute_mask`); afterwards the forward hook only looks up the mask stored in the corresponding buffer. Hence, it is probably not possible to learn the masks with this mechanism as-is...

Problem:
 - Reloading a pruned state dict into a freshly created network can be solved by first applying a dummy pruning method such as `Identity` (as done in the cell above).
 - Reloading i