As a baseline for experiments with dense-to-sparse pruning algorithms, an implementation of global magnitude pruning is provided below ("prune_gmp"). 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
#from torch.nn.utils import prune
from models import VGG, resnet
from data_utils import get_dataloader
from prune_utils import _count_unmasked_weights, _count_all_weights, prune_weights_reparam, prune_weights_erk

def train(net, loader, optimizer, criterion, epoch, verbose=False):
    """Run one training epoch over ``loader``, updating ``net`` in place.

    Each batch is moved to the device that holds the model's parameters, so
    the function works for a model on any device. A model without parameters
    falls back to the default CUDA device.

    Args:
        net: Model to optimise; it is switched to training mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer whose gradients are zeroed and which is stepped
            once per batch.
        criterion: Callable ``criterion(outputs, targets)`` returning the
            loss that is backpropagated.
        epoch: Epoch index; only used in the verbose log line.
        verbose: When true, print the first parameter group's learning rate
            and the training accuracy (in percent) after the epoch.

    Returns:
        None.
    """
    net.train()

    # Follow the model's device instead of assuming CUDA is present.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    correct, total = 0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # The predicted class is the index of the largest output along dim 1.
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    if verbose:
        # An empty loader yields no samples; report 0.0 rather than dividing by zero.
        accuracy = 100*correct/total if total else 0.0
        print("epoch, learning rate", epoch, optimizer.param_groups[0]['lr'], "training accuracy", accuracy)

def test(net, loader, verbose=False):
    criterion = nn.CrossEntropyLoss()
    net.eval()
    test_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    acc = 100. * correct / total
    if verbose:
      print("testing accuracy is", acc)
    return acc


def prune_gmp(dataset="cifar100", network="resnet", num_epochs=240, sparsity=0.05, k=80, l=100, verbose=False):
    """Train a network while gradually pruning it with ``prune_weights_global``.

    The model is trained with SGD (lr 0.1, momentum 0.9, weight decay 5e-4)
    and a step schedule that multiplies the learning rate by 0.1 at epochs
    80 and 180. After every epoch ``e`` with ``k < e <= k + l`` a pruning
    step with amount ``alpha = 1 - sparsity ** (1 / l)`` is applied, i.e.
    ``l`` steps in total when ``num_epochs`` is large enough.

    NOTE(review): if each step removes a fraction ``alpha`` of the weights
    that are still unmasked, the surviving fraction after ``l`` steps is
    ``(1 - alpha) ** l == sparsity`` -- confirm against
    ``prune_weights_global``.

    Args:
        dataset: Dataset name forwarded to the model constructor and to
            ``get_dataloader``.
        network: ``"resnet"`` (depth 32) or ``"vgg"`` (depth 19 with batch
            norm).
        num_epochs: Number of training epochs.
        sparsity: Base of the per-step pruning amount; presumably the target
            fraction of weights left unpruned (see the note above).
        k: Last epoch index before pruning starts.
        l: Number of epochs over which pruning is spread. With ``l == 0`` no
            pruning step is executed.
        verbose: When true, print the model, the per-step pruning amount and
            per-epoch statistics.

    Raises:
        ValueError: If ``network`` is not one of the supported names.
    """
    if network == "resnet":
        net = resnet(depth=32, dataset=dataset).cuda()
    elif network == "vgg":
        net = VGG(depth=19, dataset=dataset, batchnorm=True).cuda()
    else:
        # Previously an unknown name fell through and failed later with an
        # unbound ``net``; fail early with a clear message instead.
        raise ValueError("unknown network {!r}; expected 'resnet' or 'vgg'".format(network))

    # Positional arguments are presumably batch sizes and worker count -- TODO confirm.
    trainloader, testloader = get_dataloader(dataset, 128, 128, 2, root='/content')
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 180], gamma=0.1)
    prune_weights_reparam(net)

    # With l == 0 the pruning window below is empty, so alpha is never used;
    # avoid the division by zero in the exponent.
    alpha = 1 - sparsity ** (1 / float(l)) if l else 0.0
    if verbose:
        print(net)
        print("pruning amount is", alpha)

    for epoch in range(num_epochs):
        train(net, trainloader, optimizer, criterion, epoch, verbose)
        test(net, testloader, verbose)
        unmasked_weights = _count_unmasked_weights(net)
        # Despite its name this is the unmasked count divided by the total count.
        sparsity_ratio = torch.sum(unmasked_weights) / _count_all_weights(net)
        scheduler.step()
        if verbose:
            print("unmasked weights", unmasked_weights, "sparsity ratio", sparsity_ratio)
        if k < epoch <= k + l:
            # NOTE(review): prune_weights_global is not among the names
            # imported at the top of this file; unless it is defined or
            # imported elsewhere, this call raises NameError.
            prune_weights_global(net, alpha)

# Run the baseline with its default arguments and verbose logging enabled.
prune_gmp(verbose=True)