# Imports

In [3]:
!pip install torchsummary


Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
import torchvision.transforms as transforms
from torchsummary import summary
import os
from zipfile import ZipFile
import urllib.request
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from collections import OrderedDict


# Model

In [32]:
class ZFNet(nn.Module):
  """ZFNet-style CNN with an attached deconvnet for visualizing conv activations.

  ``forward`` runs only the convolutional ``features`` stack. The
  ``classifier`` head is built but not applied there. For every feature layer,
  ``forward`` caches its output, the size of its input and, for max-pool
  layers, the pooling switch indices. ``forward_deconv`` replays those
  switches in reverse through max-unpooling, ReLU and transposed
  convolutions, projecting a pool5 map back toward input space.

  NOTE(review): the ``deconv_conv*`` layers own separate weights
  (``bias=False``) that are not tied to the forward conv filters inside this
  class. A Zeiler & Fergus reconstruction normally reuses the forward filters,
  so presumably they are copied in elsewhere -- verify.
  """

  def __init__(self):
    super().__init__()

    # CONV PART
    self.features = nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(3, 96, kernel_size=7, stride=2, padding=1)),
        ('act1', nn.ReLU()),
        ('pool1', nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)),
        ('conv2', nn.Conv2d(96, 256, kernel_size=5, stride=2, padding=0)),
        ('act2', nn.ReLU()),
        ('pool2', nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)),
        ('conv3', nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)),
        ('act3', nn.ReLU()),
        ('conv4', nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1)),
        ('act4', nn.ReLU()),
        ('conv5', nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1)),
        ('act5', nn.ReLU()),
        ('pool5', nn.MaxPool2d(kernel_size=3, stride=2, padding=0, return_indices=True)),
    ]))

    # Layer name -> position in self.features, so forward_deconv does not
    # depend on hard-coded indices into the Sequential.
    self._layer_index = {
        name: i for i, (name, _) in enumerate(self.features.named_children())
    }

    # Caches filled by forward(), keyed by position in self.features.
    self.feature_outputs = [0] * len(self.features)  # output of features[i]
    self.switch_indices = dict()  # max-pool argmax indices of features[i]
    self.sizes = dict()  # size of the tensor fed into features[i]

    # fc6 expects 9216 = 256 * 6 * 6 inputs, i.e. the pool5 map of a 224x224 image.
    self.classifier = nn.Sequential(OrderedDict([
        ('fc6', nn.Linear(9216, 4096)),
        ('act6', nn.ReLU()),
        ('fc7', nn.Linear(4096, 4096)),
        ('act7', nn.ReLU()),
        ('fc8', nn.Linear(4096, 1000)),
    ]))

    # DECONV PART (mirrors the conv stack in reverse order)
    self.deconv_pool5 = nn.MaxUnpool2d(kernel_size=3, stride=2, padding=0)
    self.deconv_act5 = nn.ReLU()
    self.deconv_conv5 = nn.ConvTranspose2d(256, 384, kernel_size=3, stride=1, padding=1, bias=False)

    self.deconv_act4 = nn.ReLU()
    self.deconv_conv4 = nn.ConvTranspose2d(384, 384, kernel_size=3, stride=1, padding=1, bias=False)

    self.deconv_act3 = nn.ReLU()
    self.deconv_conv3 = nn.ConvTranspose2d(384, 256, kernel_size=3, stride=1, padding=1, bias=False)

    self.deconv_pool2 = nn.MaxUnpool2d(kernel_size=3, stride=2, padding=1)
    self.deconv_act2 = nn.ReLU()
    self.deconv_conv2 = nn.ConvTranspose2d(256, 96, kernel_size=5, stride=2, padding=0, bias=False)

    self.deconv_pool1 = nn.MaxUnpool2d(kernel_size=3, stride=2, padding=1)
    self.deconv_act1 = nn.ReLU()
    self.deconv_conv1 = nn.ConvTranspose2d(96, 3, kernel_size=7, stride=2, padding=1, bias=False)

  def forward(self, x):
    """Run the conv feature stack and cache what the deconvnet needs.

    Args:
      x: input batch, assumed (N, 3, H, W). A 224x224 input yields the
        256x6x6 pool5 map that ``classifier`` is sized for.

    Returns:
      The output of ``pool5``. The classifier head is not applied.
    """
    for i, layer in enumerate(self.features):
      # Record every layer's input size. Max-unpooling needs it, and so do the
      # strided transposed convs, to recover the exact spatial size that the
      # forward conv's floor division would otherwise lose.
      self.sizes[i] = x.size()
      if isinstance(layer, nn.MaxPool2d):
        x, indices = layer(x)
        self.switch_indices[i] = indices
      else:
        x = layer(x)
      self.feature_outputs[i] = x
    return x

  def forward_deconv(self, x, layer):
    """Project a pool5-shaped feature map back through the deconvnet.

    Uses the switches and sizes cached by the most recent ``forward`` call,
    so ``x`` must come from, or match the shape of, that call's output.

    Args:
      x: tensor shaped like the output of ``forward`` (pool5 output).
      layer: number of conv stages to invert, counted from the top:
        1 -> after inverting conv5 (384 channels),
        2 -> after inverting conv4 (384 channels),
        3 -> after inverting conv3 (256 channels),
        4 -> after inverting pool2/conv2 (96 channels),
        5 -> after inverting pool1/conv1 (3 channels, same size as the input).

    Returns:
      The partial or full reconstruction described above.

    Raises:
      ValueError: if ``layer`` is not one of 1, 2, 3, 4, 5.
      RuntimeError: if ``forward`` has not been run yet.
    """
    if layer not in (1, 2, 3, 4, 5):
      raise ValueError("ZFnet -> forward_deconv(): layer value should be between [1,5]")

    idx = self._layer_index
    pool5, pool2, pool1 = idx['pool5'], idx['pool2'], idx['pool1']
    if pool5 not in self.switch_indices:
      raise RuntimeError("ZFNet -> forward_deconv(): call forward() before forward_deconv()")

    x = self.deconv_pool5(x, self.switch_indices[pool5], output_size=self.sizes[pool5])
    x = self.deconv_act5(x)
    x = self.deconv_conv5(x)
    if layer == 1:
      return x

    x = self.deconv_act4(x)
    x = self.deconv_conv4(x)
    if layer == 2:
      return x

    x = self.deconv_act3(x)
    x = self.deconv_conv3(x)
    if layer == 3:
      return x

    x = self.deconv_pool2(x, self.switch_indices[pool2], output_size=self.sizes[pool2])
    x = self.deconv_act2(x)
    # A stride-2 transposed conv can map to up to `stride` different sizes.
    # Pin it to the forward conv's input size.
    x = self.deconv_conv2(x, output_size=self.sizes[idx['conv2']])
    if layer == 4:
      return x

    x = self.deconv_pool1(x, self.switch_indices[pool1], output_size=self.sizes[pool1])
    x = self.deconv_act1(x)
    # Without output_size a 224x224 input would come back as 223x223.
    x = self.deconv_conv1(x, output_size=self.sizes[idx['conv1']])
    return x











In [3]:
model = ZFNet()
# torchsummary defaults to device="cuda". On a GPU machine that would build
# CUDA dummy inputs for this CPU-resident model and fail, so pass the model's device.
summary(model, (3, 224, 224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 96, 110, 110]          14,208
              ReLU-2         [-1, 96, 110, 110]               0
         MaxPool2d-3  [[-1, 96, 55, 55], [-1, 96, 55, 55]]               0
            Conv2d-4          [-1, 256, 26, 26]         614,656
              ReLU-5          [-1, 256, 26, 26]               0
         MaxPool2d-6  [[-1, 256, 13, 13], [-1, 256, 13, 13]]               0
            Conv2d-7          [-1, 384, 13, 13]         885,120
              ReLU-8          [-1, 384, 13, 13]               0
            Conv2d-9          [-1, 384, 13, 13]       1,327,488
             ReLU-10          [-1, 384, 13, 13]               0
           Conv2d-11          [-1, 256, 13, 13]         884,992
             ReLU-12          [-1, 256, 13, 13]               0
        MaxPool2d-13  [[-1, 256, 6, 6], [-1, 256, 6, 6]]               0
Total 