In [1]:
import os
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.autograd as autograd

In [2]:
class GetSubnet(autograd.Function):
    """Binary top-k mask over a score tensor with a straight-through gradient.

    The forward pass returns a tensor shaped like ``scores`` holding 1 for the
    highest-scoring fraction ``k`` of the entries and 0 for the rest. The
    backward pass hands the incoming gradient to ``scores`` unchanged.
    """

    @staticmethod
    def forward(ctx, scores, k):
        """Build the 0/1 mask.

        Args:
            ctx: autograd context (unused).
            scores: tensor of per-entry scores; any shape or memory layout.
            k: fraction of entries to keep (set to 1). ``int((1 - k) * numel)``
                of the lowest-scoring entries are set to 0.

        Returns:
            A contiguous tensor with the shape, dtype and device of ``scores``
            that contains only zeros and ones.
        """
        total = scores.numel()
        # Number of lowest-scoring entries to drop. Clamp so that k outside
        # [0, 1] cannot turn into a negative (wrap-around) slice bound.
        cut = min(max(int((1 - k) * total), 0), total)

        _, order = scores.flatten().sort()

        # Allocate the mask contiguously so that view(-1) is guaranteed to alias
        # it. Writing through `clone().flatten()` silently loses the writes when
        # `scores` is not contiguous, because flatten() then returns a copy.
        mask = torch.zeros_like(scores, memory_format=torch.contiguous_format)
        mask.view(-1)[order[cut:]] = 1

        return mask

    @staticmethod
    def backward(ctx, g):
        """Pass the gradient straight through to ``scores``; ``k`` gets none."""
        return g, None

class NonAffineBatchNorm(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` constructed with ``affine=False``."""

    def __init__(self, dim):
        # `dim` is forwarded as BatchNorm2d's first positional argument.
        super().__init__(dim, affine=False)

class SubnetConv(nn.Conv2d):
    """``nn.Conv2d`` whose weights are multiplied by a score-derived 0/1 mask.

    Each weight has a learnable score. On every forward pass the absolute
    scores are turned into a binary mask by ``GetSubnet`` and the convolution
    is run with ``weight * mask``.
    """

    def __init__(self, *args, prune_rate=None, **kwargs):
        """Create the convolution and its score tensor.

        Args:
            *args, **kwargs: forwarded unchanged to ``nn.Conv2d``.
            prune_rate: value later passed to ``GetSubnet`` as ``k``. When
                omitted, the module-level ``prune_rate`` variable is used,
                which is what this class did before the keyword existed.

        Raises:
            NameError: if ``prune_rate`` is omitted and no module-level
                ``prune_rate`` is defined.
        """
        super().__init__(*args, **kwargs)

        # empty_like keeps the scores on the same dtype/device as the weights;
        # the values are overwritten by the initialiser on the next line.
        self.scores = nn.Parameter(torch.empty_like(self.weight))
        nn.init.kaiming_uniform_(self.scores, a=math.sqrt(5))

        if prune_rate is None:
            # Fallback to the notebook-level setting.
            try:
                prune_rate = globals()["prune_rate"]
            except KeyError:
                raise NameError(
                    "name 'prune_rate' is not defined; pass prune_rate=... "
                    "or define it before building SubnetConv layers"
                ) from None
        self.prune_rate = prune_rate

    def set_prune_rate(self, prune_rate):
        """Replace the value passed to ``GetSubnet`` on later forward passes."""
        self.prune_rate = prune_rate

    @property
    def clamped_scores(self):
        """Element-wise absolute value of the scores."""
        return self.scores.abs()

    def forward(self, x):
        """Convolve ``x`` with the masked weights."""
        subnet = GetSubnet.apply(self.clamped_scores, self.prune_rate)
        w = self.weight * subnet
        x = F.conv2d(
            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        return x

In [5]:
# Notebook-wide settings. The classes and functions in the other cells read
# several of these as module-level globals, so this cell has to run before the
# model is built.

# Mini-batch sizes handed to the train / test DataLoader.
batch_size = 256
test_batch_size = 1000
# Number of training epochs; also used as T_max for the cosine LR schedule.
epochs = 20
# SGD hyper-parameters.
lr = 0.1
momentum = 0.9
weight_decay = 0.0005
# NOTE(review): not referenced by any cell visible here.
log_interval = 10
# Root directory under which the MNIST files are stored.
data_path = "data"
# NOTE(review): not referenced by any cell visible here; see prune_rate below.
sparsity = 0.5
# When true, the final cell writes the model's state_dict to disk.
save_model = True
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ResNet: use a plain nn.Conv2d (rather than a builder-made conv) as classifier.
last_layer_dense = True
# NOTE(review): not referenced by any cell visible here; the final cell passes
# a literal 10 to ResNet instead.
num_class = 10
# Builder: activation type, weight-initialisation scheme and its fan mode.
nonlinearity = "relu"
init = "kaiming_normal"
mode = "fan_in"
# Builder: when true, the fan used for initialisation is scaled by (1 - prune_rate).
scale_fan = False
# Handed to GetSubnet as `k`, which zeroes the int((1 - k) * numel) lowest scores.
# NOTE(review): with 0 that is every entry, so each SubnetConv mask is all
# zeros. Confirm whether `sparsity` (0.5) was meant to be used here.
prune_rate = 0
# Layer classes given to Builder by get_builder().
conv_type = SubnetConv
bn_type = NonAffineBatchNorm
# Optional separate layer type for the first convolution (None = same as conv_type).
first_layer_type = None
# ResNet: use a plain nn.Conv2d (rather than a builder-made conv) as the stem.
first_layer_dense = True
# NOTE(review): device-selection flag left at None; confirm which code paths
# still consult it.
gpu = None

In [6]:
# Both splits live in the same directory and share the same preprocessing:
# convert to a tensor, then normalise with the constants below.
_mnist_root = os.path.join(data_path, "mnist")
_mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training split (downloaded on demand), reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(_mnist_root, train=True, download=True, transform=_mnist_transform),
    batch_size=batch_size,
    shuffle=True,
)

# Test split, read from the same directory.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(_mnist_root, train=False, transform=_mnist_transform),
    batch_size=test_batch_size,
    shuffle=True,
)

In [7]:
class Builder(object):
    """Factory for the convolution, batch-norm and activation layers of a network.

    The layer classes are injected through the constructor. The activation
    type and the weight-initialisation scheme are read from the module-level
    settings ``nonlinearity``, ``init``, ``mode``, ``scale_fan`` and
    ``prune_rate``.
    """

    # (kernel size, padding) pairs for the supported convolutions.
    _KERNEL_PADDING = ((1, 0), (3, 1), (5, 2), (7, 3))

    def __init__(self, conv_layer, bn_layer, first_layer=None):
        """Store the layer classes; ``first_layer`` defaults to ``conv_layer``."""
        self.conv_layer = conv_layer
        self.bn_layer = bn_layer
        self.first_layer = first_layer or conv_layer

    def conv(self, kernel_size, in_planes, out_planes, stride=1, first_layer=False):
        """Create and initialise a bias-free convolution.

        Args:
            kernel_size: 1, 3, 5 or 7.
            in_planes: number of input channels.
            out_planes: number of output channels.
            stride: convolution stride.
            first_layer: build with ``self.first_layer`` instead of
                ``self.conv_layer`` and print which class is used.

        Returns:
            The initialised layer, or ``None`` for an unsupported kernel size.
        """
        conv_layer = self.first_layer if first_layer else self.conv_layer

        if first_layer:
            print(f"==> Building first layer with {str(self.first_layer)}")

        spec = next(
            (pair for pair in self._KERNEL_PADDING if kernel_size == pair[0]), None
        )
        if spec is None:
            return None
        size, padding = spec

        conv = conv_layer(
            in_planes,
            out_planes,
            kernel_size=size,
            stride=stride,
            padding=padding,
            bias=False,
        )

        self._init_conv(conv)

        return conv

    def conv3x3(self, in_planes, out_planes, stride=1, first_layer=False):
        """3x3 convolution with padding 1"""
        return self.conv(3, in_planes, out_planes, stride=stride, first_layer=first_layer)

    def conv1x1(self, in_planes, out_planes, stride=1, first_layer=False):
        """1x1 convolution (no padding)"""
        return self.conv(1, in_planes, out_planes, stride=stride, first_layer=first_layer)

    def conv7x7(self, in_planes, out_planes, stride=1, first_layer=False):
        """7x7 convolution with padding 3"""
        return self.conv(7, in_planes, out_planes, stride=stride, first_layer=first_layer)

    def conv5x5(self, in_planes, out_planes, stride=1, first_layer=False):
        """5x5 convolution with padding 2"""
        return self.conv(5, in_planes, out_planes, stride=stride, first_layer=first_layer)

    def batchnorm(self, planes, last_bn=False, first_layer=False):
        """Create a normalisation layer; ``last_bn`` and ``first_layer`` are ignored."""
        return self.bn_layer(planes)

    def activation(self):
        """Create the activation selected by ``nonlinearity``.

        Raises:
            ValueError: if ``nonlinearity`` is anything other than ``"relu"``.
        """
        if nonlinearity == "relu":
            return nn.ReLU(inplace=True)
        raise ValueError(f"{nonlinearity} is not a supported nonlinearity!")

    def _fan_std(self, conv):
        """Return gain / sqrt(fan), scaling fan by (1 - prune_rate) if ``scale_fan``."""
        fan = nn.init._calculate_correct_fan(conv.weight, mode)
        if scale_fan:
            fan = fan * (1 - prune_rate)
        gain = nn.init.calculate_gain(nonlinearity)
        return gain / math.sqrt(fan)

    def _init_conv(self, conv):
        """Initialise ``conv.weight`` in place according to ``init``.

        Raises:
            ValueError: if ``init`` names an unknown scheme.
        """
        if init == "signed_constant":
            # Keep the sign of the existing weights, give all the same magnitude.
            conv.weight.data = conv.weight.data.sign() * self._fan_std(conv)

        elif init == "unsigned_constant":
            conv.weight.data = torch.ones_like(conv.weight.data) * self._fan_std(conv)

        elif init == "kaiming_normal":
            if scale_fan:
                std = self._fan_std(conv)
                with torch.no_grad():
                    conv.weight.data.normal_(0, std)
            else:
                nn.init.kaiming_normal_(
                    conv.weight, mode=mode, nonlinearity=nonlinearity
                )

        elif init == "kaiming_uniform":
            nn.init.kaiming_uniform_(
                conv.weight, mode=mode, nonlinearity=nonlinearity
            )

        elif init == "xavier_normal":
            nn.init.xavier_normal_(conv.weight)

        elif init == "xavier_constant":
            fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(conv.weight)
            std = math.sqrt(2.0 / float(fan_in + fan_out))
            conv.weight.data = conv.weight.data.sign() * std

        elif init == "standard":
            nn.init.kaiming_uniform_(conv.weight, a=math.sqrt(5))

        else:
            raise ValueError(f"{init} is not an initialization option!")


def get_builder():
    """Create a ``Builder`` from the notebook-level layer settings.

    Reads the module-level ``conv_type``, ``bn_type`` and ``first_layer_type``.
    ``first_layer_type`` may be ``None`` (the first layer uses ``conv_type``),
    a layer class, or the name of a layer class. A name is looked up in this
    module's namespace first, then as an attribute of ``conv_type``, which was
    the only lookup performed previously.

    Returns:
        The configured ``Builder``.
    """
    print("==> Conv Type: {}".format(conv_type))
    print("==> BN Type: {}".format(bn_type))

    first_layer = None
    if first_layer_type is not None:
        if isinstance(first_layer_type, str):
            namespace = globals()
            if first_layer_type in namespace:
                first_layer = namespace[first_layer_type]
            else:
                # conv_type is a class in this notebook, so this only succeeds
                # when it really has such an attribute; otherwise it raises
                # AttributeError exactly as before.
                first_layer = getattr(conv_type, first_layer_type)
        else:
            first_layer = first_layer_type
        print(f"==> First Layer Type: {first_layer_type}")

    return Builder(conv_layer=conv_type, bn_layer=bn_type, first_layer=first_layer)

In [8]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``, where the
    shortcut is ``x`` itself or ``downsample(x)`` when a downsample module is
    given. All layers come from ``builder``.
    """

    expansion = 1

    def __init__(self, builder, inplanes, planes, stride=1, downsample=None, base_width=64):
        super().__init__()
        if base_width / 64 > 1:
            raise ValueError("Base width >64 does not work for BasicBlock")

        # Only the first convolution is strided.
        self.conv1 = builder.conv3x3(inplanes, planes, stride)
        self.bn1 = builder.batchnorm(planes)
        self.relu = builder.activation()
        self.conv2 = builder.conv3x3(planes, planes)
        self.bn2 = builder.batchnorm(planes, last_bn=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # The normalisation layers are skipped when they are None.
        h = self.conv1(x)
        if self.bn1 is not None:
            h = self.bn1(h)
        h = self.relu(h)

        h = self.conv2(h)
        if self.bn2 is not None:
            h = self.bn2(h)

        shortcut = x if self.downsample is None else self.downsample(x)

        h += shortcut
        return self.relu(h)

class Bottleneck(nn.Module):
    """Residual block: 1x1 reduce, 3x3, 1x1 expand, then add the shortcut.

    The inner width is ``int(planes * base_width / 64)`` and the block outputs
    ``planes * expansion`` channels. The shortcut is ``x`` itself or
    ``downsample(x)`` when a downsample module is given.
    """

    expansion = 4

    def __init__(self, builder, inplanes, planes, stride=1, downsample=None, base_width=64):
        super().__init__()
        inner = int(planes * base_width / 64)
        outplanes = planes * self.expansion

        self.conv1 = builder.conv1x1(inplanes, inner)
        self.bn1 = builder.batchnorm(inner)
        # The stride is applied by the 3x3 convolution.
        self.conv2 = builder.conv3x3(inner, inner, stride=stride)
        self.bn2 = builder.batchnorm(inner)
        self.conv3 = builder.conv1x1(inner, outplanes)
        self.bn3 = builder.batchnorm(outplanes, last_bn=True)
        self.relu = builder.activation()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # (convolution, normalisation, apply activation afterwards?)
        stages = (
            (self.conv1, self.bn1, True),
            (self.conv2, self.bn2, True),
            (self.conv3, self.bn3, False),
        )

        h = x
        for conv, norm, activate in stages:
            h = norm(conv(h))
            if activate:
                h = self.relu(h)

        shortcut = x if self.downsample is None else self.downsample(x)

        h += shortcut
        return self.relu(h)


class ResNet(nn.Module):
    """ResNet backbone whose layers are produced by a ``Builder``.

    The network is a 7x7 stem, a max-pool, four residual stages of 64, 128, 256
    and 512 base channels, global average pooling and a 1x1-convolution
    classifier whose output is flattened to ``(batch, num_classes)``.

    The module-level flags ``first_layer_dense`` and ``last_layer_dense``
    select a plain ``nn.Conv2d`` instead of a builder-made convolution for the
    stem and the classifier respectively.
    """

    def __init__(self, builder, block, layers, num_classes=1000, base_width=64, in_channels=1):
        """Build the network.

        Args:
            builder: object providing ``conv1x1``, ``conv7x7``, ``batchnorm``
                and ``activation``; it is also handed to every block.
            block: residual block class exposing an ``expansion`` attribute.
            layers: number of blocks in each of the four stages.
            num_classes: number of classifier outputs.
            base_width: forwarded to every block.
            in_channels: number of channels of the input images. It is used by
                both stem variants, which previously hard-coded different
                values (1 for the dense stem, 3 for the builder stem).
        """
        super(ResNet, self).__init__()
        self.inplanes = 64

        self.base_width = base_width
        if self.base_width // 64 > 1:
            print(f"==> Using {self.base_width // 64}x wide model")

        if first_layer_dense:
            self.conv1 = nn.Conv2d(
                in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
            )
        else:
            self.conv1 = builder.conv7x7(
                in_channels, self.inplanes, stride=2, first_layer=True
            )

        self.bn1 = builder.batchnorm(self.inplanes)
        self.relu = builder.activation()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(builder, block, 64, layers[0])
        self.layer2 = self._make_layer(builder, block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(builder, block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(builder, block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # The classifier is a 1x1 convolution applied to the pooled 1x1 map.
        if last_layer_dense:
            self.fc = nn.Conv2d(512 * block.expansion, num_classes, 1)
        else:
            self.fc = builder.conv1x1(512 * block.expansion, num_classes)

    def _make_layer(self, builder, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first one is strided.

        A 1x1 projection (followed by a normalisation layer when the builder
        returns one) is placed on the first block's shortcut whenever the
        stride or the channel count changes. ``self.inplanes`` is updated to
        the stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            dconv = builder.conv1x1(
                self.inplanes, planes * block.expansion, stride=stride
            )
            dbn = builder.batchnorm(planes * block.expansion)
            if dbn is not None:
                downsample = nn.Sequential(dconv, dbn)
            else:
                downsample = dconv

        layers = [
            block(builder, self.inplanes, planes, stride, downsample, base_width=self.base_width)
        ]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(builder, self.inplanes, planes, base_width=self.base_width))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class scores for ``x``, shaped ``(batch, num_classes)``."""
        x = self.conv1(x)

        if self.bn1 is not None:
            x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.fc(x)
        # Drop the trailing 1x1 spatial dimensions.
        x = x.view(x.size(0), -1)

        return x

In [47]:
import time
from tqdm import tqdm

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor indexed as (sample, class).
        target: tensor holding one class index per sample.
        topk: the values of k to evaluate.

    Returns:
        A list with one single-element tensor per k, holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        deepest = max(topk)
        n_samples = target.size(0)

        # Class indices of the `deepest` largest scores per sample, best first.
        ranked = torch.topk(output, deepest, dim=1, largest=True, sorted=True).indices
        # Row r flags the samples whose (r + 1)-th ranked class is the target.
        hits = ranked.t() == target.view(1, -1)

        scale = 100.0 / n_samples
        return [
            hits[:k].flatten().float().sum(0, keepdim=True) * scale
            for k in topk
        ]

def train(train_loader, model, criterion, optimizer, epoch, writer = None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        train_loader: iterable of ``(images, target)`` batches supporting ``len()``.
        model: network to train; it is switched to training mode.
        criterion: loss function called as ``criterion(output, target)``.
        optimizer: optimiser stepped once per batch.
        epoch: accepted for interface compatibility; not used.
        writer: accepted for interface compatibility; not used.

    Returns:
        The detached loss of the last batch, or ``None`` if the loader yielded
        no batches.
    """
    model.train()

    loss = None
    for images, target in tqdm(train_loader, ascii=True, total=len(train_loader)):
        # Inputs and targets must both be on the model's device. Moving the
        # images only when `gpu` was set left them on the CPU while the model
        # and the targets were on CUDA.
        images = images.to(device)
        target = target.to(device)

        output = model(images)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return None if loss is None else loss.detach()


def validate(val_loader, model, criterion, writer = None, epoch = None):
    """Evaluate top-1 and top-5 accuracy over the whole of ``val_loader``.

    Args:
        val_loader: iterable of ``(images, target)`` batches supporting ``len()``.
        model: network to evaluate; it is switched to evaluation mode.
        criterion: accepted for interface compatibility; not used.
        writer: accepted for interface compatibility; not used.
        epoch: accepted for interface compatibility; not used.

    Returns:
        ``(acc1, acc5)``: single-element tensors with the top-1 and top-5
        accuracy in percent, averaged over every sample seen (weighted by
        batch size). Both are zero when the loader yields no batches.
    """
    model.eval()

    seen = 0
    top1_weighted = torch.zeros(1, device=device)
    top5_weighted = torch.zeros(1, device=device)

    with torch.no_grad():
        for images, target in tqdm(val_loader, ascii=True, total=len(val_loader)):
            # Inputs and targets must both be on the model's device.
            images = images.to(device)
            target = target.to(device)

            output = model(images)

            # Weight each batch's percentage by its size so that the result is
            # the accuracy of the whole set, not just of the final batch.
            count = target.size(0)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1_weighted += acc1 * count
            top5_weighted += acc5 * count
            seen += count

    denom = max(seen, 1)
    return top1_weighted / denom, top5_weighted / denom


In [None]:
# Build the network from Bottleneck blocks with 3/4/23/3 blocks per stage and
# 10 output classes, then move it to the selected device.
# NOTE(review): [3, 4, 23, 3] is the layout usually associated with ResNet-101
# (ResNet-50 is [3, 4, 6, 3]), while the checkpoint below is named
# "mnist_res50.pt" - confirm which one is intended.
model = ResNet(get_builder(), Bottleneck, [3, 4, 23, 3], 10).to(device)

# SGD over every parameter that requires gradients.
optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad],
                      lr=lr,
                      momentum=momentum,
                      weight_decay=weight_decay)

criterion = nn.CrossEntropyLoss().to(device)
# Cosine learning-rate schedule spanning the whole run.
scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
# Accuracy of the untrained model, as a reference point.
acc1, acc5 = validate(test_loader, model, criterion)
print(acc1, acc5)

# One training pass and one evaluation per epoch; the scheduler is stepped
# once at the end of each epoch.
for epoch in range(1, epochs+1):
    loss = train(train_loader, model, criterion, optimizer, epoch)
    acc1, acc5 = validate(test_loader, model, criterion)
    print(loss, acc1, acc5)
    scheduler.step()

# Persist only the state_dict (parameters and buffers), not the module object.
if save_model:
    torch.save(model.state_dict(), "mnist_res50.pt")