# lib

In [1]:
!pip3 install torch==1.4 torchvision==0.5 numpy==1.18



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
#import mon
import numpy as np
import time
#
def cuda(tensor):
    """Return ``tensor`` moved to the GPU when CUDA is available, else unchanged.

    Works for anything exposing a ``.cuda()`` method (tensors and modules).
    """
    return tensor.cuda() if torch.cuda.is_available() else tensor

# Pruner

In [38]:
class Pruner:
    r"""Base class for score-based pruning of masked parameters.

    ``masked_parameters`` is an iterable of ``(mask, param)`` tensor pairs.
    Subclasses implement :meth:`score`, which fills ``self.scores`` with one
    score tensor per parameter, keyed by ``id(param)``. The masks are then
    updated in place from those scores.
    """

    def __init__(self, masked_parameters):
        # Materialise the iterable so it can be traversed many times.
        self.masked_parameters = list(masked_parameters)
        self.scores = {}

    def score(self, model, loss, dataloader, device):
        r"""Compute ``self.scores``; must be provided by subclasses."""
        raise NotImplementedError

    def _global_mask(self, sparsity):
        r"""Updates masks of model with scores by sparsity level globally.

        A single threshold (the k-th smallest score over all parameters, with
        ``k = int((1 - sparsity) * total)``) is used for every mask: entries
        scoring at or below it are set to 0, the rest to 1. Nothing happens
        when ``k < 1``.
        """
        # NOTE: scores of already-masked entries are not forced to -inf here,
        # so previously pruned entries can be revived by a later call.
        global_scores = torch.cat([torch.flatten(v) for v in self.scores.values()])
        k = int((1.0 - sparsity) * global_scores.numel())
        if not k < 1:
            threshold, _ = torch.kthvalue(global_scores, k)
            for mask, param in self.masked_parameters:
                score = self.scores[id(param)]
                zero = torch.tensor([0.]).to(mask.device)
                one = torch.tensor([1.]).to(mask.device)
                mask.copy_(torch.where(score <= threshold, zero, one))

    def _local_mask(self, sparsity):
        r"""Updates masks of model with scores by sparsity level parameter-wise.

        Same rule as :meth:`_global_mask`, but the threshold is computed
        separately for every parameter tensor.
        """
        for mask, param in self.masked_parameters:
            score = self.scores[id(param)]
            k = int((1.0 - sparsity) * score.numel())
            if not k < 1:
                threshold, _ = torch.kthvalue(torch.flatten(score), k)
                zero = torch.tensor([0.]).to(mask.device)
                one = torch.tensor([1.]).to(mask.device)
                mask.copy_(torch.where(score <= threshold, zero, one))

    def mask(self, sparsity, scope):
        r"""Updates masks of model with scores by sparsity according to scope.

        ``scope`` is ``'global'`` or ``'local'``; any other value is a no-op.
        """
        if scope == 'global':
            self._global_mask(sparsity)
        if scope == 'local':
            self._local_mask(sparsity)

    @torch.no_grad()
    def apply_mask(self):
        r"""Applies mask to prunable parameters.
        """
        for mask, param in self.masked_parameters:
            param.mul_(mask)

    def alpha_mask(self, alpha):
        r"""Set all masks to alpha in model.
        """
        for mask, _ in self.masked_parameters:
            mask.fill_(alpha)

    # Based on https://github.com/facebookresearch/open_lth/blob/master/utils/tensor_utils.py#L43
    @torch.no_grad()
    def shuffle(self):
        r"""Randomly permutes the entries of every mask in place.

        Each mask keeps its shape and its multiset of values (hence its
        sparsity); only the positions of the entries change.
        """
        for mask, _ in self.masked_parameters:
            perm = torch.randperm(mask.nelement(), device=mask.device)
            # Indexing with ``perm`` yields a new tensor; copy it back so the
            # mask held in ``masked_parameters`` is actually updated (merely
            # rebinding the loop variable would leave it untouched).
            mask.copy_(mask.reshape(-1)[perm].reshape(mask.shape))

    def invert(self):
        r"""Replaces every score ``v`` by ``v / v**2`` in place.

        This is the element-wise reciprocal; zero-valued scores produce NaN.
        """
        for v in self.scores.values():
            v.div_(v**2)

    def stats(self):
        r"""Returns remaining and total number of prunable parameters.

        "Remaining" is the sum of all mask entries; "total" is their count.
        """
        remaining_params, total_params = 0, 0
        for mask, _ in self.masked_parameters:
            remaining_params += mask.detach().cpu().numpy().sum()
            total_params += mask.numel()
        return remaining_params, total_params


class Rand(Pruner):
    """Pruner whose scores are independent standard-normal noise."""

    def __init__(self, masked_parameters):
        super().__init__(masked_parameters)

    def score(self, model, loss, dataloader, device):
        """Draw a fresh random score tensor, shaped like each parameter.

        ``model``, ``loss``, ``dataloader`` and ``device`` are unused.
        """
        for _, param in self.masked_parameters:
            noise = torch.randn_like(param)
            self.scores[id(param)] = noise


class Mag(Pruner):
    """Pruner that scores each weight by its absolute value."""

    def __init__(self, masked_parameters):
        super().__init__(masked_parameters)

    def score(self, model, loss, dataloader, device):
        """Store ``|param|`` (a detached copy) as the score of every parameter.

        ``model``, ``loss``, ``dataloader`` and ``device`` are unused.
        """
        for _, param in self.masked_parameters:
            magnitude = param.data.clone().detach()
            magnitude.abs_()
            self.scores[id(param)] = magnitude


# Based on https://github.com/mi-lad/snip/blob/master/snip.py#L18
class SNIP(Pruner):
    """Pruner scoring each entry by the magnitude of the loss gradient w.r.t. its mask."""

    def __init__(self, masked_parameters):
        super().__init__(masked_parameters)

    def score(self, model, loss, dataloader, device):
        """Accumulate mask gradients over ``dataloader`` and store normalised scores.

        Scores are ``|dL/dmask|`` summed over all batches, divided by the sum
        of all scores across parameters. The per-parameter score norm is
        printed before normalisation.
        """
        # Masks must track gradients for the duration of the scoring pass.
        for mask, _ in self.masked_parameters:
            mask.requires_grad = True

        # Gradients accumulate across batches (no zeroing in between).
        for data, target in dataloader:
            inputs = data.to(device)
            labels = target.to(device)
            loss(model(inputs), labels).backward()

        # Score = |gradient w.r.t. the mask|; then clear gradients and freeze the mask.
        for mask, param in self.masked_parameters:
            mask_grad = mask.grad
            self.scores[id(param)] = mask_grad.clone().detach().abs_()
            param.grad.data.zero_()
            mask_grad.data.zero_()
            mask.requires_grad = False

        for i, v in enumerate(self.scores.values()):
            print(f'norm of layer {i+1} is {torch.norm(v)}')

        # Normalise so that all scores together sum to one.
        total = torch.sum(torch.cat([s.flatten() for s in self.scores.values()]))
        for _, param in self.masked_parameters:
            self.scores[id(param)].div_(total)


# Based on https://github.com/alecwangcq/GraSP/blob/master/pruner/GraSP.py#L49
class GraSP(Pruner):
    """Pruner scoring entries by a Hessian-gradient product times the weight."""

    def __init__(self, masked_parameters):
        super().__init__(masked_parameters)
        self.temp = 200   # model outputs are divided by this before the loss
        self.eps = 1e-10  # keeps the normalisation denominator non-zero

    def score(self, model, loss, dataloader, device):
        """Fill ``self.scores`` with ``(Hg) * theta`` divided by ``|sum| + eps``.

        Two passes over ``dataloader`` are made: the first accumulates a
        graph-free gradient vector, the second back-propagates the inner
        product of that vector with a differentiable gradient.
        """
        params = [p for _, p in self.masked_parameters]

        def flat_grads(data, target, create_graph):
            # Gradient of the temperature-scaled loss w.r.t. all prunable params, flattened.
            data, target = data.to(device), target.to(device)
            batch_loss = loss(model(data) / self.temp, target)
            grads = torch.autograd.grad(batch_loss, params, create_graph=create_graph)
            return torch.cat([g.reshape(-1) for g in grads if g is not None])

        # Pass 1: gradient vector without computational graph.
        stopped_grads = 0
        for data, target in dataloader:
            stopped_grads += flat_grads(data, target, create_graph=False)

        # Pass 2: gradient with graph; backprop of <g_stopped, g> leaves Hg in p.grad.
        for data, target in dataloader:
            gnorm = (stopped_grads * flat_grads(data, target, create_graph=True)).sum()
            gnorm.backward()

        # Score is Hg * theta (stored without negation); gradients are then cleared.
        for _, param in self.masked_parameters:
            self.scores[id(param)] = (param.grad * param.data).clone().detach()
            param.grad.data.zero_()

        # Normalise by the absolute value of the total score.
        flat_scores = torch.cat([s.flatten() for s in self.scores.values()])
        denom = torch.abs(torch.sum(flat_scores)) + self.eps
        for _, param in self.masked_parameters:
            self.scores[id(param)].div_(denom)


class SynFlow(Pruner):
    """Data-free pruner: scores are ``|grad * param|`` on an all-ones input."""

    def __init__(self, masked_parameters):
        super().__init__(masked_parameters)

    def score(self, model, loss, dataloader, device):
        """Compute scores from one backward pass through the absolute-valued model.

        ``dataloader`` is only used to read the shape of a single sample;
        ``loss`` is unused. Every ``state_dict`` entry is replaced by its
        absolute value for the pass and multiplied by its saved sign afterwards.
        """

        @torch.no_grad()
        def take_abs(net):
            # Remember each entry's sign, then make it non-negative in place.
            saved = {}
            for key, tensor in net.state_dict().items():
                saved[key] = torch.sign(tensor)
                tensor.abs_()
            return saved

        @torch.no_grad()
        def restore_signs(net, saved):
            for key, tensor in net.state_dict().items():
                tensor.mul_(saved[key])

        signs = take_abs(model)

        # A single all-ones sample with the same shape as one element of the data.
        batch, _ = next(iter(dataloader))
        sample_shape = list(batch[0, :].shape)
        ones_input = torch.ones([1] + sample_shape).to(device)
        torch.sum(model(ones_input)).backward()

        for _, param in self.masked_parameters:
            self.scores[id(param)] = (param.grad * param).clone().detach().abs_()
            param.grad.data.zero_()

        restore_signs(model, signs)



In [25]:
# Scratch cell: draw a 10x10 tensor from torch.rand and shift it by -0.5 so the
# values are centred on zero, then display it.
init_mask_weight = torch.rand(10,10)-0.5
init_mask_weight

tensor([[ 0.2783,  0.1714, -0.4783,  0.0742,  0.0288,  0.4450,  0.1079,  0.4279,
         -0.1227, -0.3381],
        [ 0.2989, -0.0272, -0.0189,  0.2187,  0.0531,  0.3550, -0.2004,  0.2271,
         -0.4213,  0.2937],
        [-0.2428,  0.1546,  0.3518,  0.4355,  0.2819, -0.1301,  0.1143,  0.0624,
         -0.3870,  0.3932],
        [-0.3901, -0.4208,  0.2806, -0.0792, -0.1617, -0.2824, -0.2249,  0.2954,
         -0.0652, -0.4695],
        [-0.3162,  0.4881,  0.1006, -0.4630,  0.1849, -0.0221, -0.3098,  0.4377,
          0.3013,  0.1811],
        [ 0.2289, -0.4438,  0.4827,  0.4618, -0.2580, -0.4855, -0.1030,  0.3854,
          0.3921, -0.0597],
        [ 0.4564, -0.4202, -0.3125, -0.4881,  0.0912,  0.0602, -0.3711,  0.1108,
         -0.4060, -0.0427],
        [-0.3145, -0.1864,  0.4832, -0.0598, -0.3233, -0.2183, -0.2146,  0.0414,
         -0.4638, -0.4166],
        [-0.1902,  0.3340, -0.2034, -0.1202,  0.2628,  0.1837, -0.2106,  0.2514,
          0.1065,  0.0827],
        [ 0.0393, -

In [28]:

# Scratch cell: binarise the zero-centred random tensor (1 where positive, else 0)
# and count how many entries are kept.
init_mask_weight = torch.rand(10,10)-0.5
init_mask_weight = torch.where(init_mask_weight > 0, 1, 0) 
init_mask_weight.sum()

tensor(55)

# mon

In [29]:
class Linear(nn.Linear):
    """``nn.Linear`` whose weight (and bias) are multiplied by a binary mask.

    The masks are registered as buffers ``weight_mask`` / ``bias_mask`` and are
    initialised randomly: each entry is 1 where a ``torch.rand`` draw exceeds
    0.5 and 0 otherwise. They are stored in the parameter's floating dtype so
    that they can be given ``requires_grad=True`` or filled with fractional
    values later on.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__(in_features, out_features, bias)
        self.register_buffer('weight_mask', self._random_mask(self.weight))
        if self.bias is not None:
            self.register_buffer('bias_mask', self._random_mask(self.bias))

    @staticmethod
    def _random_mask(param):
        """Return a random 0/1 mask with the shape and dtype of ``param``."""
        # A boolean comparison cast to the parameter dtype avoids the integer
        # mask produced by ``torch.where(cond, 1, 0)``.
        return (torch.rand(param.shape) > 0.5).to(param.dtype)

    def get_masked_weight(self):
        """Return the element-wise product of the weight and its mask."""
        return self.weight_mask * self.weight

    def forward(self, input):
        """Apply the linear map using the masked weight and masked bias."""
        W = self.get_masked_weight()
        # Without a bias there is no bias mask either; pass None through.
        b = self.bias_mask * self.bias if self.bias is not None else self.bias
        return F.linear(input, W, b)

In [30]:
class MONSingleFc(nn.Module):
    """Single fully-connected MON linear operator.

    The operator is ``W = (1 - m) I - A^T A + B - B^T``; ``U`` is the input
    injection (with bias). All methods take and return 1-tuples of tensors.
    """

    def __init__(self, in_dim, out_dim, m=1.0):
        super().__init__()
        self.U = nn.Linear(in_dim, out_dim)
        self.A = nn.Linear(out_dim, out_dim, bias=False)
        self.B = nn.Linear(out_dim, out_dim, bias=False)
        self.m = m
        self.dropout = nn.Dropout(p=0.8)

    def x_shape(self, n_batch):
        """Shape of an input batch of size ``n_batch``."""
        return (n_batch, self.U.in_features)

    def z_shape(self, n_batch):
        """Tuple holding the hidden-state shape for a batch of size ``n_batch``."""
        hidden = (n_batch, self.A.in_features)
        return (hidden,)

    def forward(self, x, *z):
        """Return ``dropout(U(x) + W z)`` as a 1-tuple."""
        pre_activation = self.U(x) + self.multiply(*z)[0]
        return (self.dropout(pre_activation),)

    def bias(self, x):
        """Return the input injection ``U(x)`` as a 1-tuple."""
        return (self.U(x),)

    def multiply(self, *z):
        """Apply ``W`` to the (row-vector) batch ``z[0]``."""
        state = z[0]
        ata_z = self.A(state) @ self.A.weight
        z_out = (1 - self.m) * state - ata_z + self.B(state) - state @ self.B.weight
        return (z_out,)

    def multiply_transpose(self, *g):
        """Apply ``W^T`` to the (row-vector) batch ``g[0]``."""
        grad = g[0]
        ata_g = self.A(grad) @ self.A.weight
        g_out = (1 - self.m) * grad - ata_g - self.B(grad) + grad @ self.B.weight
        return (g_out,)

    def init_inverse(self, alpha, beta):
        """Cache ``(alpha I + beta W)^{-1}`` in ``self.Winv``."""
        a_w = self.A.weight
        b_w = self.B.weight
        eye = torch.eye(a_w.shape[0], dtype=a_w.dtype, device=a_w.device)
        W = (1 - self.m) * eye - a_w.T @ a_w + b_w - b_w.T
        self.Winv = torch.inverse(alpha * eye + beta * W)

    def inverse(self, *z):
        """Multiply ``z[0]`` by the transpose of the cached inverse."""
        return (z[0] @ self.Winv.transpose(0, 1),)

    def inverse_transpose(self, *g):
        """Multiply ``g[0]`` by the cached inverse."""
        return (g[0] @ self.Winv,)

In [31]:
class MaskedMONSingleFc(nn.Module):
    """Single fully-connected MON linear operator built from masked layers.

    Same structure as the unmasked variant, ``W = (1 - m) I - A^T A + B - B^T``,
    except that ``U``, ``A`` and ``B`` are masked ``Linear`` layers and every
    weight is read through ``get_masked_weight()``. No dropout is applied.
    """

    def __init__(self, in_dim, out_dim, m=1.0):
        super().__init__()
        self.U = Linear(in_dim, out_dim)
        self.A = Linear(out_dim, out_dim, bias=False)
        self.B = Linear(out_dim, out_dim, bias=False)
        self.m = m

    def x_shape(self, n_batch):
        """Shape of an input batch of size ``n_batch``."""
        return (n_batch, self.U.in_features)

    def z_shape(self, n_batch):
        """Tuple holding the hidden-state shape for a batch of size ``n_batch``."""
        hidden = (n_batch, self.A.in_features)
        return (hidden,)

    def forward(self, x, *z):
        """Return ``U(x) + W z`` as a 1-tuple."""
        return (self.U(x) + self.multiply(*z)[0],)

    def bias(self, x):
        """Return the input injection ``U(x)`` as a 1-tuple."""
        return (self.U(x),)

    def multiply(self, *z):
        """Apply the masked ``W`` to the (row-vector) batch ``z[0]``."""
        state = z[0]
        ata_z = self.A(state) @ self.A.get_masked_weight()
        z_out = (1 - self.m) * state - ata_z + self.B(state) - state @ self.B.get_masked_weight()
        return (z_out,)

    def multiply_transpose(self, *g):
        """Apply the transpose of the masked ``W`` to ``g[0]``."""
        grad = g[0]
        ata_g = self.A(grad) @ self.A.get_masked_weight()
        g_out = (1 - self.m) * grad - ata_g - self.B(grad) + grad @ self.B.get_masked_weight()
        return (g_out,)

    def init_inverse(self, alpha, beta):
        """Cache ``(alpha I + beta W)^{-1}`` for the masked ``W`` in ``self.Winv``."""
        a_w = self.A.get_masked_weight()
        b_w = self.B.get_masked_weight()
        eye = torch.eye(a_w.shape[0], dtype=a_w.dtype, device=a_w.device)
        W = (1 - self.m) * eye - a_w.T @ a_w + b_w - b_w.T
        self.Winv = torch.inverse(alpha * eye + beta * W)

    def inverse(self, *z):
        """Multiply ``z[0]`` by the transpose of the cached inverse."""
        return (z[0] @ self.Winv.transpose(0, 1),)

    def inverse_transpose(self, *g):
        """Multiply ``g[0]`` by the cached inverse."""
        return (g[0] @ self.Winv,)

In [4]:
class MONReLU(nn.Module):
    """ReLU non-linearity operating on tuples of tensors."""

    def forward(self, *z):
        """Apply ReLU to every tensor in ``z`` and return them as a tuple."""
        return tuple(map(F.relu, z))

    def derivative(self, *z):
        """Return the 0/1 ReLU derivative of every tensor, cast to the type of ``z[0]``."""
        return tuple([(component > 0).type_as(z[0]) for component in z])

# utils

In [5]:
class Meter(object):
    """Tracks the latest value, weighted sum, count, average, min and max."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to their initial state."""
        self.val = self.avg = self.sum = self.count = 0
        self.max = -float("inf")
        self.min = float("inf")

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the statistics."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        if val > self.max:
            self.max = val
        if val < self.min:
            self.min = val

In [6]:
class SplittingMethodStats(object):
    """Iteration-count and wall-time meters for forward and backward passes."""

    def __init__(self):
        self.fwd_iters = Meter()
        self.bkwd_iters = Meter()
        self.fwd_time = Meter()
        self.bkwd_time = Meter()

    def reset(self):
        """Reset all four meters."""
        for meter in (self.fwd_iters, self.fwd_time, self.bkwd_iters, self.bkwd_time):
            meter.reset()

# Splitting

In [7]:
from torch.autograd import Function

In [8]:
class MONForwardBackwardSplitting(nn.Module):
    """Equilibrium layer solved by a damped fixed-point (forward-backward) iteration.

    ``linear_module`` must provide ``z_shape``, ``bias``, ``multiply``,
    ``multiply_transpose`` and be callable as ``linear_module(x, *z)``;
    ``nonlin_module`` must be callable on a tuple of tensors and provide
    ``derivative``. ``alpha`` is the damping/step size, ``tol`` the relative
    stopping tolerance and ``max_iter`` the iteration cap. ``verbose`` is stored
    but not read inside this class.
    """

    def __init__(self, linear_module, nonlin_module, alpha=1.0, tol=1e-5, max_iter=50, verbose=False):
        super().__init__()
        self.linear_module = linear_module
        self.nonlin_module = nonlin_module
        self.alpha = alpha
        self.tol = tol
        self.max_iter = max_iter
        self.verbose = verbose
        self.stats = SplittingMethodStats()
        # When True, the stopping error is the fixed-point residual and is logged in ``errs``.
        self.save_abs_err = False

    def forward(self, x):
        """ Forward pass of the MON, find an equilibrium with forward-backward splitting"""

        start = time.time()
        # Run the forward pass _without_ tracking gradients
        with torch.no_grad():
            # Start from an all-zero hidden state.
            z = tuple(torch.zeros(s, dtype=x.dtype, device=x.device)
                      for s in self.linear_module.z_shape(x.shape[0]))
            n = len(z)
            bias = self.linear_module.bias(x)

            err = 1.0
            it = 0
            errs = []
            while (err > self.tol and it < self.max_iter):
                # Damped update: z <- nonlin((1 - alpha) z + alpha (W z + bias)).
                zn = self.linear_module.multiply(*z)
                zn = tuple((1 - self.alpha) * z[i] + self.alpha * (zn[i] + bias[i]) for i in range(n))
                zn = self.nonlin_module(*zn)
                if self.save_abs_err:
                    # Residual between zn and one full application of the layer.
                    # NOTE(review): no epsilon in the denominator here, unlike the branch below.
                    fn = self.nonlin_module(*self.linear_module(x, *zn))
                    err = sum((zn[i] - fn[i]).norm().item() / (zn[i].norm().item()) for i in range(n))
                    errs.append(err)
                else:
                    # Relative change between consecutive iterates.
                    err = sum((zn[i] - z[i]).norm().item() / (1e-6 + zn[i].norm().item()) for i in range(n))
                z = zn
                it = it + 1

        # Run the forward pass one more time, tracking gradients, then backward placeholder
        zn = self.linear_module(x, *z)
        zn = self.nonlin_module(*zn)
        zn = self.Backward.apply(self, *zn)
        self.stats.fwd_iters.update(it)
        self.stats.fwd_time.update(time.time() - start)
        # Only populated when save_abs_err is True; otherwise an empty list.
        self.errs = errs
        return zn

    class Backward(Function):
        """Identity in the forward direction; iteratively solves for the gradient in backward."""

        @staticmethod
        def forward(ctx, splitter, *z):
            # Keep the owning module and the equilibrium point for the backward solve.
            ctx.splitter = splitter
            ctx.save_for_backward(*z)
            return z

        @staticmethod
        def backward(ctx, *g):
            start = time.time()
            sp = ctx.splitter
            n = len(g)
            z = ctx.saved_tensors
            j = sp.nonlin_module.derivative(*z)
            # Positions where the non-linearity's derivative is exactly zero.
            I = [j[i] == 0 for i in range(n)]
            # Divides by j, so entries flagged in I are not finite; they are overwritten in the loop.
            d = [(1 - j[i]) / j[i] for i in range(n)]
            v = tuple(j[i] * g[i] for i in range(n))
            u = tuple(torch.zeros(s, dtype=g[0].dtype, device=g[0].device)
                      for s in sp.linear_module.z_shape(g[0].shape[0]))

            err = 1.0
            it = 0
            errs = []
            while (err > sp.tol and it < sp.max_iter):
                un = sp.linear_module.multiply_transpose(*u)
                un = tuple((1 - sp.alpha) * u[i] + sp.alpha * un[i] for i in range(n))
                un = tuple((un[i] + sp.alpha * (1 + d[i]) * v[i]) / (1 + sp.alpha * d[i]) for i in range(n))
                # Replace the entries with zero derivative by the corresponding entries of v.
                for i in range(n):
                    un[i][I[i]] = v[i][I[i]]

                err = sum((un[i] - u[i]).norm().item() / (1e-6 + un[i].norm().item()) for i in range(n))
                errs.append(err)
                u = un
                it = it + 1

            dg = sp.linear_module.multiply_transpose(*u)
            dg = tuple(g[i] + dg[i] for i in range(n))

            sp.stats.bkwd_iters.update(it)
            sp.stats.bkwd_time.update(time.time() - start)
            # Overwrites the error log written by the forward pass.
            sp.errs = errs
            # The leading None is the gradient for the non-tensor ``splitter`` argument.
            return (None,) + dg

In [9]:
class MONPeacemanRachford(nn.Module):
    """Equilibrium layer solved with an inverse-based (Peaceman-Rachford) iteration.

    In addition to the interface used by the forward-backward variant,
    ``linear_module`` must provide ``init_inverse``, ``inverse`` and
    ``inverse_transpose``. ``alpha`` is the splitting parameter, ``tol`` the
    relative stopping tolerance and ``max_iter`` the iteration cap. ``verbose``
    is stored but not read inside this class.
    """

    def __init__(self, linear_module, nonlin_module, alpha=1.0, tol=1e-5, max_iter=50, verbose=False):
        super().__init__()
        self.linear_module = linear_module
        self.nonlin_module = nonlin_module
        self.alpha = alpha
        self.tol = tol
        self.max_iter = max_iter
        self.verbose = verbose
        self.stats = SplittingMethodStats()
        # When True, the stopping error is the fixed-point residual and is logged in ``errs``.
        self.save_abs_err = False

    def forward(self, x):
        """ Forward pass of the MON, find an equilibrium with Peaceman-Rachford splitting"""

        start = time.time()
        # Run the forward pass _without_ tracking gradients
        # The cached inverse is rebuilt on every call (outside the no_grad block).
        self.linear_module.init_inverse(1 + self.alpha, -self.alpha)
        with torch.no_grad():
            z = tuple(torch.zeros(s, dtype=x.dtype, device=x.device)
                      for s in self.linear_module.z_shape(x.shape[0]))
            u = tuple(torch.zeros(s, dtype=x.dtype, device=x.device)
                      for s in self.linear_module.z_shape(x.shape[0]))

            n = len(z)
            bias = self.linear_module.bias(x)

            err = 1.0
            it = 0
            errs = []
            while (err > self.tol and it < self.max_iter):
                # Reflect, apply the cached inverse, reflect again, then the non-linearity.
                u_12 = tuple(2 * z[i] - u[i] for i in range(n))
                z_12 = self.linear_module.inverse(*tuple(u_12[i] + self.alpha * bias[i] for i in range(n)))
                u = tuple(2 * z_12[i] - u_12[i] for i in range(n))
                zn = self.nonlin_module(*u)

                if self.save_abs_err:
                    # Residual between zn and one full application of the layer.
                    # NOTE(review): no epsilon in the denominator here, unlike the branch below.
                    fn = self.nonlin_module(*self.linear_module(x, *zn))
                    err = sum((zn[i] - fn[i]).norm().item() / (zn[i].norm().item()) for i in range(n))
                    errs.append(err)
                else:
                    # Relative change between consecutive iterates.
                    err = sum((zn[i] - z[i]).norm().item() / (1e-6 + zn[i].norm().item()) for i in range(n))
                z = zn
                it = it + 1

        # One more application of the layer with gradients tracked, then attach the custom backward.
        zn = self.linear_module(x, *z)
        zn = self.nonlin_module(*zn)

        zn = self.Backward.apply(self, *zn)
        self.stats.fwd_iters.update(it)
        self.stats.fwd_time.update(time.time() - start)
        # Only populated when save_abs_err is True; otherwise an empty list.
        self.errs = errs
        return zn

    class Backward(Function):
        """Identity in the forward direction; iteratively solves for the gradient in backward."""

        @staticmethod
        def forward(ctx, splitter, *z):
            # Keep the owning module and the equilibrium point for the backward solve.
            ctx.splitter = splitter
            ctx.save_for_backward(*z)
            return z

        @staticmethod
        def backward(ctx, *g):
            start = time.time()
            sp = ctx.splitter
            n = len(g)
            z = ctx.saved_tensors
            j = sp.nonlin_module.derivative(*z)
            # Positions where the non-linearity's derivative is exactly zero.
            I = [j[i] == 0 for i in range(n)]
            # Divides by j, so entries flagged in I are not finite; they are overwritten in the loop.
            d = [(1 - j[i]) / j[i] for i in range(n)]
            v = tuple(j[i] * g[i] for i in range(n))

            # From here on ``z`` is reused as the iterate of the backward solve.
            z = tuple(torch.zeros(s, dtype=g[0].dtype, device=g[0].device)
                      for s in sp.linear_module.z_shape(g[0].shape[0]))
            u = tuple(torch.zeros(s, dtype=g[0].dtype, device=g[0].device)
                      for s in sp.linear_module.z_shape(g[0].shape[0]))

            err = 1.0
            errs=[]
            it = 0
            while (err >sp.tol and it < sp.max_iter):
                u_12 = tuple(2 * z[i] - u[i] for i in range(n))
                z_12 = sp.linear_module.inverse_transpose(*u_12)
                u = tuple(2 * z_12[i] - u_12[i] for i in range(n))
                zn = tuple((u[i] + sp.alpha * (1 + d[i]) * v[i]) / (1 + sp.alpha * d[i]) for i in range(n))
                # Replace the entries with zero derivative by the corresponding entries of v.
                for i in range(n):
                    zn[i][I[i]] = v[i][I[i]]

                err = sum((zn[i] - z[i]).norm().item() / (1e-6 + zn[i].norm().item()) for i in range(n))
                errs.append(err)
                z = zn
                it = it + 1

            # NOTE(review): ``zn`` is only bound inside the loop above, so this line
            # relies on at least one iteration having run.
            dg = sp.linear_module.multiply_transpose(*zn)
            dg = tuple(g[i] + dg[i] for i in range(n))

            sp.stats.bkwd_iters.update(it)
            sp.stats.bkwd_time.update(time.time() - start)
            # Overwrites the error log written by the forward pass.
            sp.errs = errs
            # The leading None is the gradient for the non-tensor ``splitter`` argument.
            return (None,) + dg

# Train

In [10]:
def train(trainLoader, testLoader, model, epochs=15, max_lr=1e-3,
          print_freq=10, change_mo=True, model_path=None, lr_mode='step',
          step=10,tune_alpha=False,max_alpha=1.):
    """Train ``model`` with Adam and report test loss/error after every epoch.

    Args:
        trainLoader: iterable of ``(data, target)`` batches; must support ``len``.
        testLoader: iterable of ``(data, target)`` batches with a ``dataset``
            attribute supporting ``len``.
        model: module to train; moved to the GPU when one is available.
        epochs: number of passes over ``trainLoader``.
        max_lr: Adam learning rate (peak learning rate in ``'1cycle'`` mode).
        print_freq: accepted for backward compatibility; no per-batch training
            statistics are printed, so it currently has no effect.
        change_mo: when True, Adam's first beta follows a piecewise-linear
            schedule between 0.95 and 0.85.
        model_path: if given, the model ``state_dict`` is saved there after
            every epoch.
        lr_mode: ``'constant'``, ``'step'`` (StepLR, gamma 0.1 every ``step``
            epochs) or ``'1cycle'`` (piecewise-linear schedule).
        step: period, in epochs, of the step schedule.
        tune_alpha: when True, ``run_tune_alpha`` is called at batch 30 and at
            the middle batch of every epoch (this requires ``model.mon``).
        max_alpha: starting alpha passed to ``run_tune_alpha``.

    Raises:
        ValueError: if ``lr_mode`` is not one of the supported modes.
    """
    # Validate before any work is done.
    if lr_mode not in ('constant', 'step', '1cycle'):
        raise ValueError('lr mode one of constant, step, 1cycle')

    optimizer = optim.Adam(model.parameters(), lr=max_lr)

    if lr_mode == '1cycle':
        def lr_schedule(t):
            return np.interp([t],
                             [0, (epochs-5)//2, epochs-5, epochs],
                             [1e-3, max_lr, 1e-3, 1e-3])[0]
    elif lr_mode == 'step':
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step, gamma=0.1, last_epoch=-1)

    if change_mo:
        max_mo = 0.85

        def momentum_schedule(t):
            return np.interp([t],
                             [0, (epochs - 5) // 2, epochs - 5, epochs],
                             [0.95, max_mo, 0.95, 0.95])[0]

    model = cuda(model)

    # The loss modules are stateless, so build them once.
    train_criterion = nn.CrossEntropyLoss()
    test_criterion = nn.CrossEntropyLoss(reduction='sum')

    for epoch in range(1, 1 + epochs):
        model.train()
        n_batches = len(trainLoader)
        for batch_idx, batch in enumerate(trainLoader):
            if tune_alpha and (batch_idx == 30 or batch_idx == int(n_batches / 2)):
                run_tune_alpha(model, cuda(batch[0]), max_alpha)

            # Fractional epoch index used by the per-batch schedules.
            progress = epoch - 1 + batch_idx / n_batches
            if lr_mode == '1cycle':
                lr = lr_schedule(progress)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
            if change_mo:
                beta1 = momentum_schedule(progress)
                for param_group in optimizer.param_groups:
                    param_group['betas'] = (beta1, optimizer.param_groups[0]['betas'][1])

            data, target = cuda(batch[0]), cuda(batch[1])
            optimizer.zero_grad()
            preds = model(data)
            ce_loss = train_criterion(preds, target)
            ce_loss.backward()
            optimizer.step()

        if lr_mode == 'step':
            lr_scheduler.step()

        if model_path is not None:
            torch.save(model.state_dict(), model_path)

        # Evaluate on the test set.
        test_loss = 0
        incorrect = 0
        model.eval()
        with torch.no_grad():
            for batch in testLoader:
                data, target = cuda(batch[0]), cuda(batch[1])
                preds = model(data)
                ce_loss = test_criterion(preds, target)
                test_loss += ce_loss
                incorrect += preds.float().argmax(1).ne(target.data).sum()
            test_loss /= len(testLoader.dataset)
            nTotal = len(testLoader.dataset)
            err = 100. * incorrect.float() / float(nTotal)
            print('\n\n\n Epoch: {:d},Test set: Average loss: {:.4f}, Error: {}/{} ({:.2f}%)'.format(
                epoch, test_loss, incorrect, nTotal, err))

In [11]:
def run_tune_alpha(model, x, max_alpha):
    """Search for a splitting parameter ``model.mon.alpha`` by repeated halving.

    Starting from ``max_alpha``, alpha is halved while the forward iteration
    count on ``x`` does not increase (and alpha stays above 1e-4). The last
    halving is then undone. If the reference run hit ``max_iter``, the
    original alpha is restored instead.
    """
    mon = model.mon
    orig_alpha = mon.alpha

    def forward_iters():
        # Run one gradient-free forward pass and read (then clear) its iteration count.
        with torch.no_grad():
            model(x)
        count = mon.stats.fwd_iters.val
        mon.stats.reset()
        return count

    mon.stats.reset()
    mon.alpha = max_alpha
    iters = forward_iters()
    iters_n = iters
    while mon.alpha > 1e-4 and iters_n <= iters:
        mon.alpha = mon.alpha / 2
        iters, iters_n = iters_n, forward_iters()

    if iters == mon.max_iter:
        print("none converged, resetting to current")
        mon.alpha = orig_alpha
    else:
        mon.alpha = mon.alpha * 2
        

In [12]:
def mnist_loaders(train_batch_size, test_batch_size=None):
    """Build the MNIST train and test ``DataLoader`` pair.

    Data lives under ``'data'``; the training split is downloaded if missing.
    Images are converted to tensors and normalised with mean 0.1307 and std
    0.3081. The training loader shuffles, the test loader does not.
    ``test_batch_size`` defaults to ``train_batch_size``.
    """
    if test_batch_size is None:
        test_batch_size = train_batch_size

    def make_transform():
        # A fresh preprocessing pipeline for each dataset.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    train_set = dset.MNIST('data', train=True, download=True,
                           transform=make_transform())
    trainLoader = torch.utils.data.DataLoader(
        train_set, batch_size=train_batch_size, shuffle=True)

    test_set = dset.MNIST('data', train=False, transform=make_transform())
    testLoader = torch.utils.data.DataLoader(
        test_set, batch_size=test_batch_size, shuffle=False)

    return trainLoader, testLoader

In [13]:
def expand_args(defaults, kwargs):
    """Return a copy of ``defaults`` with the entries of ``kwargs`` laid on top.

    Neither input mapping is modified.
    """
    merged = defaults.copy()
    merged.update(kwargs.items())
    return merged

In [14]:
# Default solver settings. SingleFcNet merges these with its keyword arguments
# (via expand_args) and passes the result to the splitting-method constructor.
MON_DEFAULTS = {
    'alpha': 1.0,  # splitting / damping parameter
    'tol': 1e-5,  # relative-error tolerance that stops the iteration
    'max_iter': 50  # cap on the number of iterations
}

In [32]:
class SingleFcNet(nn.Module):
    """Classifier made of one masked MON layer followed by a 10-way linear read-out."""

    def __init__(self, splittingMethod, in_dim=784, out_dim=100, m=0.1, **kwargs):
        """Build the network.

        ``splittingMethod`` is the solver class wrapping the MON layer. Extra
        keyword arguments override ``MON_DEFAULTS`` and are forwarded to it.
        """
        super().__init__()
        solver_kwargs = expand_args(MON_DEFAULTS, kwargs)
        self.mon = splittingMethod(
            MaskedMONSingleFc(in_dim, out_dim, m=m),
            MONReLU(),
            **solver_kwargs,
        )
        self.Wout = nn.Linear(out_dim, 10)

    def forward(self, x):
        """Flatten each sample, solve for the equilibrium and classify its last component."""
        flat = x.view(x.shape[0], -1)
        equilibrium = self.mon(flat)
        return self.Wout(equilibrium[-1])


# Prune

In [None]:
import os
# from copyreg import pickle
from tqdm import tqdm
import torch
import numpy as np
import pickle

def prune_loop(model, loss, pruner, dataloader, device, sparsity, schedule, scope, epochs,
               reinitialize=False, train_mode=False, shuffle=False, invert=False, 
               store_mask=False, pruner_name='', compression='', dataset='mnist', args=None):
    r"""Applies score mask loop iteratively to a final sparsity level.

    Args:
        model: network being pruned; put in train or eval mode per ``train_mode``.
        loss, dataloader, device: forwarded to ``pruner.score``.
        pruner: object providing ``score``, ``invert``, ``mask``, ``shuffle``,
            ``stats`` and ``masked_parameters``.
        sparsity: target fraction of parameters to keep after the last epoch.
        schedule: ``'exponential'`` or ``'linear'`` interpolation of the
            per-epoch sparsity towards ``sparsity``.
        scope: forwarded to ``pruner.mask``.
        epochs: number of score/mask rounds.
        reinitialize: call ``model._initialize_weights()`` after pruning.
        train_mode: keep the model in training mode while scoring.
        shuffle: shuffle the masks after pruning.
        invert: invert the scores before masking.
        store_mask: pickle the masks and parameters to disk; requires ``args``
            and a numeric ``compression``.
        pruner_name: base name used in the stored-mask path.
        compression: number used in the stored-mask file name.
        dataset: unused.
        args: namespace read only when ``store_mask`` is True (``shuffle``,
            ``pruner``, ``prune_epochs``, ``model``, ``dataset``, ``init_type``,
            ``pre_epochs``, ``compression``, and for the ``'fc'`` model also
            ``n_layers`` and ``n_neurons``).

    Raises:
        ValueError: for an unknown ``schedule``, or when ``store_mask`` is True
            but ``args`` is None.
    """
    # Set model to train or eval mode
    if train_mode:
        model.train()
    else:
        model.eval()

    # Prune model
    for epoch in tqdm(range(epochs)):
        if schedule == 'exponential':
            sparse = sparsity**((epoch + 1) / epochs)
        elif schedule == 'linear':
            sparse = 1.0 - (1.0 - sparsity)*((epoch + 1) / epochs)
        else:
            # Fail clearly instead of hitting an unbound ``sparse`` below.
            raise ValueError("schedule must be 'exponential' or 'linear', got {!r}".format(schedule))
        pruner.score(model, loss, dataloader, device)
        # Invert scores
        if invert:
            pruner.invert()
        pruner.mask(sparse, scope)

    # Reinitialize weights
    if reinitialize:
        model._initialize_weights()

    # Shuffle masks
    if shuffle:
        pruner.shuffle()

    # Confirm sparsity level
    remaining_params, total_params = pruner.stats()
    if np.abs(remaining_params - total_params*sparsity) >= 5:
        print("ERROR: {} prunable parameters remaining, expected {}".format(remaining_params, total_params*sparsity))
        quit()

    if store_mask:
        if args is None:
            raise ValueError('store_mask=True requires args')
        if args.shuffle:
            pruner_name = 'shuffled_' + args.pruner
        if args.pruner in ['grasp', 'snip'] and args.prune_epochs == 100:
            pruner_name = 'iterative_' + pruner_name
        if args.pruner == 'synflow' and args.prune_epochs == 1:
            pruner_name = 'oneshot_' + pruner_name

        # Compare by value: identity comparison against a string literal is unreliable.
        if args.model != 'fc':
            file_path = f'./Reproduced_Results/Masks/{args.dataset}_{args.model}/{args.init_type}/pre_epoch_{args.pre_epochs}/compression_{int(args.compression*100)}/{pruner_name}'
        else:
            file_path = f'./Reproduced_Results/Masks/{args.dataset}_{args.model}/{args.init_type}/pre_epoch_{args.pre_epochs}/MLP_{args.n_layers}_layers_{args.n_neurons}/compression_{int(args.compression*100)}/{pruner_name}'

        # exist_ok avoids the check-then-create race.
        os.makedirs(file_path, exist_ok=True)
        file_name = f'{file_path}/{pruner_name}_{int(compression*100)}.pkl'

        stored_data = {
            'mask': [m.detach().cpu().numpy() for m, _ in pruner.masked_parameters],
            'param': [p.detach().cpu().numpy() for _, p in pruner.masked_parameters],
        }

        with open(file_name, 'wb') as f:
            pickle.dump(stored_data, f)


# Running

In [33]:
# Experiment cell: train a single-layer masked MON on MNIST with the
# Peaceman-Rachford solver.
trainLoader, testLoader = mnist_loaders(train_batch_size=128, test_batch_size=400)

# alpha, max_iter and tol override the MON_DEFAULTS solver settings.
model = SingleFcNet(MONPeacemanRachford,
                        in_dim=28**2,
                        out_dim=20,
                        alpha=1.0,
                        max_iter=300,
                        tol=1e-2,
                        m=1.0) #parameter which controls the strong monotonicity of W

train(trainLoader, testLoader,
        model,
        max_lr=1e-3,
        lr_mode='step',  #use step decay learning rate
        step=10,      
        change_mo=False, #do not adjust momentum during training
        epochs=20,
        print_freq=200,
        tune_alpha=False)




 Epoch: 1,Test set: Average loss: 0.3020, Error: 866/10000 (8.66%)



 Epoch: 2,Test set: Average loss: 0.2534, Error: 743/10000 (7.43%)



 Epoch: 3,Test set: Average loss: 0.2309, Error: 671/10000 (6.71%)



 Epoch: 4,Test set: Average loss: 0.2085, Error: 615/10000 (6.15%)



 Epoch: 5,Test set: Average loss: 0.1931, Error: 542/10000 (5.42%)



 Epoch: 6,Test set: Average loss: 0.1844, Error: 534/10000 (5.34%)



 Epoch: 7,Test set: Average loss: 0.1820, Error: 542/10000 (5.42%)



 Epoch: 8,Test set: Average loss: 0.1724, Error: 502/10000 (5.02%)



 Epoch: 9,Test set: Average loss: 0.1706, Error: 489/10000 (4.89%)



 Epoch: 10,Test set: Average loss: 0.1589, Error: 479/10000 (4.79%)



 Epoch: 11,Test set: Average loss: 0.1541, Error: 452/10000 (4.52%)



 Epoch: 12,Test set: Average loss: 0.1540, Error: 447/10000 (4.47%)



 Epoch: 13,Test set: Average loss: 0.1525, Error: 438/10000 (4.38%)



 Epoch: 14,Test set: Average loss: 0.1528, Error: 440/10000 (4.40%)



 Epoch: 15,T

In [36]:
# Number of entries of A's weight mask that are set (sum of the mask).
model.mon.linear_module.A.weight_mask.sum()

tensor(214, device='cuda:0')

In [37]:
# Shape of A's weight matrix, for comparison with the mask count.
model.mon.linear_module.A.weight.shape

torch.Size([20, 20])

In [35]:
# Display the module hierarchy of the trained network.
model

SingleFcNet(
  (mon): MONPeacemanRachford(
    (linear_module): MaskedMONSingleFc(
      (U): Linear(in_features=784, out_features=20, bias=True)
      (A): Linear(in_features=20, out_features=20, bias=False)
      (B): Linear(in_features=20, out_features=20, bias=False)
    )
    (nonlin_module): MONReLU()
  )
  (Wout): Linear(in_features=20, out_features=10, bias=True)
)