In [1]:
import torch
import torchvision
import fastai


import torch.nn as nn
from torch.nn.modules.module import Module

from collections import OrderedDict
from pathlib import Path

assert(torch.__version__ == '1.1.0')
assert(torchvision.__version__== '0.3.0')
assert(fastai.__version__ == '1.0.55')


AssertionError: 

In [None]:
from torchvision import models
model = models.resnet34(pretrained=True)

# Freeze all the layers
for param in model.parameters():
    param.requires_grad = False

In [86]:
def hook_func0(self, input, output):
    print(f"Inside {self.__class__.__name__}")
    return input,output

    
class hook_context_manager():
    def __init__(self, module, hook_func = hook_func0):
        self.handle = module.register_forward_hook(self.hook_func_wrapper)
        self.hook_func = hook_func
        self.remove_status = False
        
    
    def hook_func_wrapper(self, module, input, output):
        self.input, self.output  = self.hook_func(module,input,output)
         

    def remove(self):
        if not self.remove_status:
            self.handle.remove()
            self.remove_status = True
    
    @property
    def output_shape(self): return self.output.data.shape

    
    @property
    def input_shape(self):return self.input[0].shape
   

    def __enter__(self): return self
        
    def __exit__(self, type, value, traceback): 
        self.remove()


In [87]:

class Pooling_Layers(Module):
    def __init__(self):
        super().__init__()
        self.AdpAvgPool = nn.AdaptiveAvgPool2d((1,1))
        self.AdpMaxPool = nn.AdaptiveMaxPool2d((1,1))

    def forward(self, x):
        return torch.cat((self.AdpAvgPool(x), self.AdpMaxPool(x)), dim = 1)
    
class Flatten_4D_2D(Module):
    def __init__(self):
        super().__init__()
        
    def forward(self,x):
        return x.view(x.size(0),-1)
    
def Linear(in_features, inter_features = 512, p = 0.5, number_classes = 2):
   
    layers = OrderedDict()
    layers['BatchNorm'] =  nn.BatchNorm1d(num_features=in_features)
    layers['Dropout'] = nn.Dropout(p = p/2)
  
    layers['FC1'] = nn.Linear(in_features, inter_features)
    layers['FC1 RELU']  = nn.ReLU(inplace=True)
    layers['FC1 BatchNorm']  = nn.BatchNorm1d(inter_features)
    layers['FC1 Dropout'] = nn.Dropout(p)
    
    layers['FC2'] = nn.Linear(inter_features, number_classes)
    return nn.Sequential(layers)
    
    

    
    

In [88]:
resnet_base = nn.Sequential(*list(model.children())[:-2])

In [89]:
resnet_base_pooling_flatten = nn.Sequential(resnet_base,Pooling_Layers(), Flatten_4D_2D())

In [90]:
model = resnet_base_pooling_flatten
with hook_context_manager(resnet_base_pooling_flatten) as hook:
    dummy_img = torch.randn(2,3,64,64)
    _ = model(dummy_img)

output_features_size = hook.output_shape[1]
output_features_size

Inside Sequential


1024

In [91]:
resnet_transfer = nn.Sequential(resnet_base,Pooling_Layers(), Flatten_4D_2D(), 
                                Linear(in_features=output_features_size, number_classes=2))

In [92]:
resnet_transfer

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (rel

In [105]:
model_file_name = 'resnet34_transfer_model.pt'
torch.save(resnet_transfer.state_dict(), Path.joinpath(Path("../Models"), model_file_name))

In [106]:
model = torch.load(Path.joinpath(Path("../Models"), model_file_name))

In [107]:
model

OrderedDict([('0.0.weight',
              tensor([[[[ 5.4109e-03, -6.9092e-03,  7.8839e-03,  ...,  4.9072e-02,
                          3.0660e-02,  2.5398e-02],
                        [ 4.1081e-02,  3.1296e-02,  3.2265e-02,  ...,  3.3145e-02,
                          2.9754e-02,  4.1735e-02],
                        [ 4.9519e-03, -3.1705e-02, -6.1310e-02,  ..., -9.7493e-02,
                         -1.1601e-01, -1.2191e-01],
                        ...,
                        [-1.2287e-02, -2.4841e-02, -9.3052e-03,  ...,  1.7113e-02,
                          2.4631e-03,  1.6726e-02],
                        [ 3.9117e-03,  4.4537e-03,  3.6315e-02,  ...,  1.0371e-01,
                          7.3973e-02,  5.9085e-02],
                        [ 1.6784e-02,  8.8902e-03,  3.1312e-02,  ...,  9.6964e-02,
                          8.3749e-02,  9.6970e-02]],
              
                       [[-7.7192e-03, -8.7711e-03,  1.4143e-02,  ...,  3.3901e-02,
                          2.5483e-