In [1]:
import torch
import torch.nn as nn
import torch.autograd as autograd

In [62]:
def percentile(t, q):
    k = 1 + round(.01 * float(q) * (t.numel() - 1))
    return t.view(-1).kthvalue(k).values.item()


class Maskarade(torch.autograd.Function):
    """ Pretends to be a param!"""
    @staticmethod
    def forward(ctx, scores, weight, sparsity):
        k_val = percentile(scores, sparsity*100)
        # return torch.where(scores < k_val, zeros.to(scores.device), ones.to(scores.device))
        mask = scores < k_val
        hydrated_weight = weight * mask
        return hydrated_weight, mask

    @staticmethod
    def backward(ctx, g1, g2):
        return g1, None, None


def register_hooks(module, name, amount):
    param = getattr(module, name)
    # Register buffers the same way as the pruning model.
    del module._parameters[name]
    module.register_parameter(name + "_orig", param)  # TODO: Should this be a buffer?
    module.register_buffer(name + "_mask", torch.zeros_like(param))
    module.register_parameter(name + "_score", nn.Parameter(torch.normal(mean=0, std=1, size=param.shape)))
    module.register_buffer(name, torch.zeros_like(param))

    def hydra_mask(module_, inputs):
        score = getattr(module, name + "_score")
        weight = getattr(module, name + "_orig")
        hydrated_weight, mask = Maskarade.apply(score, weight, amount)
        setattr(module, name, hydrated_weight)
        setattr(module, name + "_mask", mask)

    module.register_forward_pre_hook(hydra_mask)

def hydrate(network: nn.Module, amount: float, init: str):
    pass

def dehydrate(network: nn.Module):
    """ Change the hydrated network to a normal pruned network."""
    # 1. Delete pre_forward hooks.
    
    # 2. Delete the scoring params and make the _orig buffer a parameter.

    # 3. Than delete each mask and call the CustomMask method/class from pruned module.
    raise NotImplementedError

In [63]:
m = nn.Sequential(nn.Linear(1, 4), nn.Linear(4, 1))
register_hooks(m[0], "weight")

In [64]:
list(m.named_parameters())

[('0.bias',
  Parameter containing:
  tensor([ 0.5925,  0.0846, -0.9099, -0.6454], requires_grad=True)),
 ('0.weight_orig',
  Parameter containing:
  tensor([[-0.7632],
          [ 0.2147],
          [-0.1644],
          [ 0.4362]], requires_grad=True)),
 ('0.weight_score',
  Parameter containing:
  tensor([[-0.5151],
          [-0.5813],
          [-0.9485],
          [ 0.3997]], requires_grad=True)),
 ('1.weight',
  Parameter containing:
  tensor([[-0.4225,  0.3531,  0.3221,  0.3617]], requires_grad=True)),
 ('1.bias',
  Parameter containing:
  tensor([-0.1517], requires_grad=True))]

In [65]:
m(torch.tensor([1.])).backward()

False


In [60]:
list(m.named_parameters())

[('0.bias',
  Parameter containing:
  tensor([0.7637, 0.3764, 0.4258, 0.5208], requires_grad=True)),
 ('0.weight_orig',
  Parameter containing:
  tensor([[-0.4998],
          [-0.9608],
          [ 0.1656],
          [ 0.5035]], requires_grad=True)),
 ('0.weight_score',
  Parameter containing:
  tensor([[ 0.4618],
          [ 1.3859],
          [ 0.7635],
          [-1.1754]], requires_grad=True)),
 ('1.weight',
  Parameter containing:
  tensor([[-0.2975,  0.1069, -0.2973, -0.3809]], requires_grad=True)),
 ('1.bias',
  Parameter containing:
  tensor([0.1533], requires_grad=True))]

In [61]:
m[0].weight_score.grad

tensor([[-0.2975],
        [ 0.1069],
        [-0.2973],
        [-0.3809]])