In this notebook you will learn how to use `torch.nn.utils.prune` to sparsify your networks, and how to extend it to implement your own custom pruning technique.

* Requirement: torch>=1.4.0a0+8e8a5e0

1. Create a model

In [6]:
import torch
import torch.nn.functional as F

from torch import nn


class LeNet(nn.Module):
    """Small CNN: two 3x3 convolutions followed by three linear layers.

    Each convolution is followed by ReLU and 2x2 max pooling.  The first
    linear layer expects 16 feature maps of size 5x5 per sample.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 input image channel -> 6 -> 16 channels,
        # 3x3 square kernels throughout.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3)
        # Classifier head: 16 * 5 * 5 flattened features -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the 10 output scores for every sample in the batch ``x``."""
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)

        # Flatten everything except the batch dimension.
        features_per_sample = int(x.nelement() / x.shape[0])
        flat = x.view(-1, features_per_sample)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device=device)

# Count only the trainable parameters (those with requires_grad=True).
print(f"Number of parameters: {sum((p.numel() for p in model.parameters() if p.requires_grad))}")

Number of parameters: 60074


2. Inspect a Module

Let's inspect the (unpruned) `conv1` layer of the LeNet model. It contains two parameters, `weight` and `bias`, and no buffers for now.

In [8]:
# Keep a handle on the first convolution; the cells below inspect and prune it.
module = model.conv1
print(list(module.named_parameters()))  # List[Tuple[name of parameters, tensor of parameters]]

[('weight', Parameter containing:
tensor([[[[ 0.2210,  0.2963, -0.1385],
          [-0.2852,  0.1365, -0.1321],
          [-0.2140, -0.2728,  0.3242]]],


        [[[-0.2939,  0.2436, -0.2873],
          [-0.2664, -0.1056, -0.1238],
          [-0.3064,  0.2746, -0.1016]]],


        [[[ 0.1405,  0.1296, -0.2949],
          [-0.2292,  0.0637, -0.3108],
          [ 0.0451, -0.2222, -0.0419]]],


        [[[-0.1592, -0.0800, -0.2572],
          [-0.0362,  0.0274, -0.0213],
          [ 0.0568,  0.0709,  0.3153]]],


        [[[ 0.2424,  0.1485, -0.2695],
          [ 0.0359,  0.0299,  0.2123],
          [-0.0808,  0.0516, -0.1402]]],


        [[[ 0.1840, -0.0456,  0.2920],
          [-0.1896, -0.3016, -0.0747],
          [-0.0215, -0.1728, -0.0849]]]], requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.1951,  0.1546, -0.1809,  0.2764, -0.2611, -0.2624],
       requires_grad=True))]


In [9]:
print(list(module.named_buffers()))  # the unpruned layer has no buffers yet

[]


3. Pruning a Module

In [10]:
# `prune` is not imported by any earlier cell, so bring it into scope here;
# without this import the call below raises NameError in a fresh kernel.
import torch.nn.utils.prune as prune

# Randomly prune a fraction `amount` (here 30%) of the entries of `weight`.
prune.random_unstructured(module, name="weight", amount=0.3)

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

In [11]:
# After pruning, the original tensor is listed as `weight_orig`; `weight`
# itself no longer appears among the parameters.
print(list(module.named_parameters()))

[('bias', Parameter containing:
tensor([ 0.1951,  0.1546, -0.1809,  0.2764, -0.2611, -0.2624],
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.2210,  0.2963, -0.1385],
          [-0.2852,  0.1365, -0.1321],
          [-0.2140, -0.2728,  0.3242]]],


        [[[-0.2939,  0.2436, -0.2873],
          [-0.2664, -0.1056, -0.1238],
          [-0.3064,  0.2746, -0.1016]]],


        [[[ 0.1405,  0.1296, -0.2949],
          [-0.2292,  0.0637, -0.3108],
          [ 0.0451, -0.2222, -0.0419]]],


        [[[-0.1592, -0.0800, -0.2572],
          [-0.0362,  0.0274, -0.0213],
          [ 0.0568,  0.0709,  0.3153]]],


        [[[ 0.2424,  0.1485, -0.2695],
          [ 0.0359,  0.0299,  0.2123],
          [-0.0808,  0.0516, -0.1402]]],


        [[[ 0.1840, -0.0456,  0.2920],
          [-0.1896, -0.3016, -0.0747],
          [-0.0215, -0.1728, -0.0849]]]], requires_grad=True))]


In [14]:
# Draw a random permutation of the indices 0 .. batch_size - 1.
batch_size = 8
index = torch.randperm(batch_size)
index  # bare expression so the notebook displays the permutation

tensor([3, 1, 6, 0, 2, 4, 5, 7])

In [65]:
import numpy as np
import torch
import torch.nn as nn


class Anchors(nn.Module):
    """Build anchor boxes for every pyramid level of an input image batch.

    Any constructor argument left as ``None`` falls back to its default:

    * ``pyramid_levels``: ``[3, 4, 5, 6, 7]``
    * ``strides``: ``2 ** level`` for each pyramid level
    * ``sizes``: ``2 ** (level + 2)`` for each pyramid level
    * ``ratios``: ``np.array([0.5, 1, 2])``
    * ``scales``: ``np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])``

    Raises:
        ValueError: if ``strides`` or ``sizes`` does not have one entry per
            pyramid level.
    """

    def __init__(self, pyramid_levels=None, strides=None, sizes=None, ratios=None, scales=None):
        super(Anchors, self).__init__()

        # Always assign every attribute: previously a caller-supplied value was
        # silently dropped, leaving the attribute undefined.
        self.pyramid_levels = [3, 4, 5, 6, 7] if pyramid_levels is None else pyramid_levels
        # The derived defaults follow the (possibly user-supplied) pyramid levels.
        self.strides = [2 ** x for x in self.pyramid_levels] if strides is None else strides
        self.sizes = [2 ** (x + 2) for x in self.pyramid_levels] if sizes is None else sizes
        self.ratios = np.array([0.5, 1, 2]) if ratios is None else ratios
        self.scales = (
            np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]) if scales is None else scales
        )

        if not (len(self.strides) == len(self.sizes) == len(self.pyramid_levels)):
            raise ValueError("strides and sizes must have one entry per pyramid level.")

    def forward(self, image):
        """Return the anchors of all pyramid levels as a ``(1, N, 4)`` float32 tensor.

        Only ``image.shape[2:]`` is read; the pixel values are never used.
        The result is placed on the GPU whenever CUDA is available,
        independently of the device ``image`` lives on.
        """
        image_shape = np.array(image.shape[2:])
        # Ceil-divide the spatial size by 2 ** level to get each level's grid size.
        image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]
        print(image_shapes)

        # Compute anchors over all pyramid levels.  The empty seed keeps the
        # (0, 4) result shape when there are no pyramid levels at all.
        anchors_per_level = [np.zeros((0, 4), dtype=np.float32)]
        for idx, _ in enumerate(self.pyramid_levels):
            anchors = generate_anchors(base_size=self.sizes[idx], ratios=self.ratios, scales=self.scales)
            anchors_per_level.append(shift(image_shapes[idx], self.strides[idx], anchors))

        all_anchors = np.concatenate(anchors_per_level, axis=0)
        all_anchors = np.expand_dims(all_anchors, axis=0).astype(np.float32)

        result = torch.from_numpy(all_anchors)
        return result.cuda() if torch.cuda.is_available() else result

def generate_anchors(base_size=16, ratios=None, scales=None):
    """
    Build the reference anchor windows for one location.

    One box is produced for every (ratio, scale) pair, ordered with the
    ratio varying slowest.  Each box is centred on the origin and returned
    as (x1, y1, x2, y2); the result has shape (len(ratios) * len(scales), 4).
    """
    if ratios is None:
        ratios = np.array([0.5, 1, 2])
    if scales is None:
        scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])

    box_count = len(ratios) * len(scales)
    boxes = np.zeros((box_count, 4))

    # Start from square boxes whose side is base_size * scale.
    boxes[:, 2] = base_size * np.tile(scales, len(ratios))
    boxes[:, 3] = boxes[:, 2]

    # Reshape every square to its ratio while keeping the area unchanged.
    ratio_per_box = np.repeat(ratios, len(scales))
    box_areas = boxes[:, 2] * boxes[:, 3]
    boxes[:, 2] = np.sqrt(box_areas / ratio_per_box)
    boxes[:, 3] = boxes[:, 2] * ratio_per_box

    # Convert (x_ctr, y_ctr, w, h) -> (x1, y1, x2, y2).  The half extents are
    # captured first because columns 2 and 3 are overwritten below.
    half_width = boxes[:, 2] * 0.5
    half_height = boxes[:, 3] * 0.5
    boxes[:, 0] -= half_width
    boxes[:, 2] -= half_width
    boxes[:, 1] -= half_height
    boxes[:, 3] -= half_height

    return boxes

def compute_shape(image_shape, pyramid_levels):
    """Compute the grid shape of every pyramid level.

    :param image_shape: image shape; only its first two entries are used.
    :param pyramid_levels: iterable of levels; level ``x`` divides the image
        shape by ``2 ** x``, rounding up.
    :return: list with one 2-element array per pyramid level.
    """
    base_shape = np.array(image_shape[:2])

    level_shapes = []
    for level in pyramid_levels:
        divisor = 2 ** level
        # Adding (divisor - 1) before the floor division rounds up.
        level_shapes.append((base_shape + divisor - 1) // divisor)
    return level_shapes


def anchors_for_shape(
    image_shape,
    pyramid_levels=None,
    ratios=None,
    scales=None,
    strides=None,
    sizes=None,
    shapes_callback=None,
):
    """Generate the anchors of all pyramid levels for a given image shape.

    Arguments left as ``None`` fall back to defaults instead of failing:
    ``pyramid_levels`` -> ``[3, 4, 5, 6, 7]``, ``strides`` -> ``2 ** level``,
    ``sizes`` -> ``2 ** (level + 2)``.  ``ratios`` and ``scales`` are passed
    through to ``generate_anchors``, which applies its own defaults.

    :param image_shape: image shape, forwarded to ``compute_shape``.
    :param shapes_callback: accepted but not used by this implementation.
        NOTE(review): confirm whether it is meant to replace ``compute_shape``.
    :return: array of shape (N, 4) with the anchors of every level, in
        pyramid-level order.
    """
    if pyramid_levels is None:
        pyramid_levels = [3, 4, 5, 6, 7]
    if strides is None:
        strides = [2 ** x for x in pyramid_levels]
    if sizes is None:
        sizes = [2 ** (x + 2) for x in pyramid_levels]

    image_shapes = compute_shape(image_shape, pyramid_levels)

    # Compute anchors over all pyramid levels.  The empty seed keeps the
    # (0, 4) result shape when there are no pyramid levels at all.
    anchors_per_level = [np.zeros((0, 4))]
    for idx, _ in enumerate(pyramid_levels):
        anchors = generate_anchors(base_size=sizes[idx], ratios=ratios, scales=scales)
        anchors_per_level.append(shift(image_shapes[idx], strides[idx], anchors))

    return np.concatenate(anchors_per_level, axis=0)


def shift(shape, stride, anchors):
    """Replicate ``anchors`` at the centre of every cell of a feature grid.

    ``shape`` is indexed as (rows, columns); cell centres sit at
    ``(index + 0.5) * stride``.  The result has shape (K * A, 4), where K
    is the number of cells and A the number of anchors, with all anchors
    of one cell stored consecutively.
    """
    centers_x = (np.arange(0, shape[1]) + 0.5) * stride
    centers_y = (np.arange(0, shape[0]) + 0.5) * stride

    grid_x, grid_y = np.meshgrid(centers_x, centers_y)
    flat_x = grid_x.ravel()
    flat_y = grid_y.ravel()

    # One (x, y, x, y) offset per grid cell: shape (K, 4).
    offsets = np.stack([flat_x, flat_y, flat_x, flat_y], axis=1)

    num_anchors = anchors.shape[0]
    num_cells = offsets.shape[0]

    # Broadcast cell offsets (K, 1, 4) against anchors (1, A, 4) -> (K, A, 4),
    # then flatten to (K * A, 4).
    shifted = offsets.reshape((num_cells, 1, 4)) + anchors.reshape((1, num_anchors, 4))
    return shifted.reshape((num_cells * num_anchors, 4))

In [66]:
anchor = Anchors()  # no arguments: every setting falls back to its built-in default

In [67]:
x = torch.FloatTensor(1, 3, 512, 512)  # dummy batch; Anchors.forward only reads its spatial shape

In [69]:
anchor(x)[:,:10,:]  # display the first 10 anchors of the (1, N, 4) result

[array([64, 64]), array([32, 32]), array([16, 16]), array([8, 8]), array([4, 4])]


tensor([[[-18.6274,  -7.3137,  26.6274,  15.3137],
         [-24.5088, -10.2544,  32.5088,  18.2544],
         [-31.9188, -13.9594,  39.9188,  21.9594],
         [-12.0000, -12.0000,  20.0000,  20.0000],
         [-16.1587, -16.1587,  24.1587,  24.1587],
         [-21.3984, -21.3984,  29.3984,  29.3984],
         [ -7.3137, -18.6274,  15.3137,  26.6274],
         [-10.2544, -24.5088,  18.2544,  32.5088],
         [-13.9594, -31.9188,  21.9594,  39.9188],
         [-10.6274,  -7.3137,  34.6274,  15.3137]]])

In [None]:
tensor([[[-18.6274,  -7.3137,  26.6274,  15.3137],
         [-12.0000, -12.0000,  20.0000,  20.0000],
         [ -7.3137, -18.6274,  15.3137,  26.6274],
         [-24.5088, -10.2544,  32.5088,  18.2544],
         [-16.1587, -16.1587,  24.1587,  24.1587],
         [-10.2544, -24.5088,  18.2544,  32.5088],
         [-31.9188, -13.9594,  39.9188,  21.9594],
         [-21.3984, -21.3984,  29.3984,  29.3984],
         [-13.9594, -31.9188,  21.9594,  39.9188],
         [-10.6274,  -7.3137,  34.6274,  15.3137]]])

In [71]:
import itertools
import math
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn


class Anchors(nn.Module):
    """Generate multi-scale anchor boxes from input and feature-map sizes.

    Args:
        anchor_scale: factor applied to each stride to get the base anchor
            size.  NOTE!!: anchor_scale = 4. if compound_coef != 7 else 5.
        scales: multipliers applied to the base anchor size.  Defaults to
            ``[2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]``.
        aspect_ratios: height_box / width_box of each anchor (the width is
            scaled by ``sqrt(1 / ratio)``, the height by ``sqrt(ratio)``).
            Defaults to ``[0.5, 1., 2.]``.
    """

    def __init__(
        self,
        anchor_scale: float = 4.,
        scales: Optional[List[float]] = None,
        aspect_ratios: Optional[List[float]] = None,
    ):
        super(Anchors, self).__init__()
        # Build the default lists per instance so that instances never share
        # (and cannot accidentally mutate) one default list object.
        self.scales = [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)] if scales is None else scales
        self.aspect_ratios = [0.5, 1., 2.] if aspect_ratios is None else aspect_ratios
        self.anchor_scale = anchor_scale

    def forward(self, inputs: torch.Tensor, features: Tuple[torch.Tensor]) -> torch.Tensor:
        """Generates multiscale anchor boxes.
        Args:
            inputs: Tensor (B x N x H x W): H = W = 128*compound_coef + 512
            features: Tuple (Tensor[B x N' x H' x W']): tuple of tensors get from output of biFPN
        Output:
            anchors: Tensor[1 x all_anchors x 4]: all anchors of all pyramid features,
                as (x1, y1, x2, y2), with the dtype and device of ``inputs``
        Raises:
            ValueError: if the input size is not a multiple of a feature stride.
        """

        dtype, device = inputs.dtype, inputs.device
        _, _, image_height, image_width = inputs.shape   # inputs: B x N x H x W

        # stride of anchors on input size
        features_sizes = [feature.shape[2:] for feature in features]   # List[[H_feature, W_feature]]
        strides = [
            (image_height // feature_height, image_width // feature_width)
            for feature_height, feature_width in features_sizes
        ]

        anchors_over_all_pyramid_features = []
        for stride_height, stride_width in strides:
            # Neither the check nor the centers depend on scale / ratio, so
            # they are computed once per pyramid level.
            if (image_width % stride_width != 0) or (image_height % stride_height != 0):
                raise ValueError('input size must be divided by the stride.')

            # center of anchors
            cx = torch.arange(
                start=stride_width / 2, end=image_width, step=stride_width, device=device, dtype=dtype
            )
            cy = torch.arange(
                start=stride_height / 2, end=image_height, step=stride_height, device=device, dtype=dtype
            )

            # Enumerate the grid row by row (x varies fastest).
            num_rows, num_cols = cy.numel(), cx.numel()
            cx = cx.repeat(num_rows)
            cy = cy.repeat_interleave(num_cols)

            # anchor base size
            base_anchor_width = self.anchor_scale * stride_width
            base_anchor_height = self.anchor_scale * stride_height

            anchors_per_pyramid_feature = []
            for scale, ratio in itertools.product(self.scales, self.aspect_ratios):
                # anchor size
                half_width = base_anchor_width * scale * math.sqrt(1 / ratio) / 2.
                half_height = base_anchor_height * scale * math.sqrt(ratio) / 2.

                # coordinates of each anchor: x1, y1, x2, y2 -> num_locations x 4
                anchors_per_pyramid_feature.append(
                    torch.stack(
                        (cx - half_width, cy - half_height, cx + half_width, cy + half_height),
                        dim=1,
                    )
                )

            # num_locations x (scales * aspect_ratios) x 4
            anchors_per_pyramid_feature = torch.stack(anchors_per_pyramid_feature, dim=1)
            # (num_locations * scales * aspect_ratios) x 4
            anchors_over_all_pyramid_features.append(anchors_per_pyramid_feature.reshape(-1, 4))

        # [(num_locations * scales * aspect_ratios) over all pyramid levels] x 4
        anchors = torch.cat(anchors_over_all_pyramid_features, dim=0)

        return anchors.unsqueeze(dim=0)

In [72]:
anchor = Anchors()  # default anchor_scale, scales and aspect_ratios

In [73]:
# Dummy input batch plus one feature map per pyramid level.  Anchors.forward
# only reads their shapes (and the dtype/device of `inputs`), so the values
# do not matter.  Relative to the 512x512 input the feature maps correspond
# to strides 8, 16, 32, 64 and 128.
inputs = torch.FloatTensor(1, 3, 512, 512)
features = [
    torch.FloatTensor(1, 256, 64, 64),
    torch.FloatTensor(1, 256, 32, 32),
    torch.FloatTensor(1, 256, 16, 16),
    torch.FloatTensor(1, 256, 8, 8),
    torch.FloatTensor(1, 256, 4, 4),
]

In [75]:
anchor(inputs, features).shape  # 9 anchors (3 scales x 3 ratios) per location, over all five levels

torch.Size([1, 49104, 4])

In [81]:
# Two 1-D tensors of length 5; only their shapes matter for the next cell.
a = torch.FloatTensor(5)
b = torch.FloatTensor(5)

In [84]:
# stack -> (2, 5); transpose -> (5, 2), i.e. one (a_i, b_i) pair per row.
c = torch.stack((a, b)).t()
c.shape

torch.Size([5, 2])

In [87]:
# `a` is never filled in, so the printed scores are arbitrary values.
a = torch.FloatTensor(5, 10, 3)
# max over the last dim returns (values, indices); keep only the values -> (5, 10).
scores, _ = a.max(dim=2)
print(scores.shape)
print(scores)

torch.Size([5, 10])
tensor([[5.9969e-22, 4.5785e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41,
         4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41],
        [4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41,
         4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41],
        [4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41,
         4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41],
        [4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41,
         4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41],
        [4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41, 4.5783e-41,
         4.5783e-41, 4.5783e-41, 0.0000e+00, 0.0000e+00]])


In [88]:
# Boolean mask with the same shape as `scores`: True where the score is above the threshold.
batch_scores_over_threshold = (scores > 1e-50)
print(batch_scores_over_threshold.shape)
print(batch_scores_over_threshold)

torch.Size([5, 10])
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False]])


In [93]:
# Mask row of the sample at index 4: one boolean per position along dim 1.
sample_scores_over_threshold = batch_scores_over_threshold[4, :]
print(sample_scores_over_threshold.shape)
print(sample_scores_over_threshold)

torch.Size([10])
tensor([ True,  True,  True,  True,  True,  True,  True,  True, False, False])


In [94]:
# Boolean-mask indexing along dim 1 keeps only the positions where the mask is True.
# NOTE(review): the mask was taken from sample 4 but is applied to sample 1 here.
a[1, sample_scores_over_threshold, :]

tensor([[-1.0862e-27,  4.5783e-41, -1.3418e-27],
        [ 4.5783e-41, -1.3416e-27,  4.5783e-41],
        [-1.3415e-27,  4.5783e-41, -1.3418e-27],
        [ 4.5783e-41, -1.3417e-27,  4.5783e-41],
        [-1.3415e-27,  4.5783e-41, -1.3417e-27],
        [ 4.5783e-41, -1.3415e-27,  4.5783e-41],
        [-1.3417e-27,  4.5783e-41, -1.3416e-27],
        [ 4.5783e-41, -1.3416e-27,  4.5783e-41]])

In [95]:
a[1, sample_scores_over_threshold, :].shape  # (number of True entries in the mask, 3)

torch.Size([8, 3])