In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import models

### Resnet

In [2]:
resnet18 = models.resnet18()

In [1]:
class PlanetResNet(nn.Module):
    def __init__(self, M, c):
        super().__init__()
        self.features = nn.Sequential(*(list(M.children())[:-1]))
        self.classifier = nn.Linear(in_features=512, out_features=c)
        
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return torch.sigmoid(x)
    
    
class ModelWrapper:
    def __init__(self, M):
        self.model = M
        
    @staticmethod
    def freeze(layers):
        for param in layers.parameters(): param.requires_grad_(False)
            
    @staticmethod
    def unfreeze(layers):
        for param in layers.parameters(): param.requires_grad_(True)
        
    def freeze_features(self, arg=True):
        if arg: self.freeze(self.model.features)
        else:   self.unfreeze(self.model.features)
            
    def partial_freeze_features(self, pct=0.2):
        size = len(list(self.model.features.children()))
        freeze_point = int(size * (1 - pct))
        
        for idx, child in enumerate(model.features.children()):
            if idx < freeze_point: self.freeze(child)
            else: self.unfreeze(child)
        
    def freeze_classifier(self, arg=True):
        if arg: self.freeze(self.model.classifier)
        else:   self.unfreeze(self.model.classifier)
    
    def summary(self):
        print('\n\n')
        for idx, (name, child) in enumerate(self.model.features.named_children()):
            print(f'{idx}: {name}-{child}')
            for param in child.parameters():
                print(f'{param.requires_grad}')

        for idx, (name, child) in enumerate(self.model.classifier.named_children()):
            print(f'{idx}: {name}-{child}')
            for param in child.parameters():
                print(f'{param.requires_grad}')
        print('\n\n')

NameError: name 'nn' is not defined

In [73]:
model = PlanetResNet(resnet18, 17)
wrapper = ModelWrapper(model)

In [82]:
wrapper.freeze_features(False)

In [83]:
for param in wrapper.model.parameters():
    print('---', param.name, param.shape, param.requires_grad)

--- None torch.Size([64, 3, 7, 7]) True
--- None torch.Size([64]) True
--- None torch.Size([64]) True
--- None torch.Size([64, 64, 3, 3]) True
--- None torch.Size([64]) True
--- None torch.Size([64]) True
--- None torch.Size([64, 64, 3, 3]) True
--- None torch.Size([64]) True
--- None torch.Size([64]) True
--- None torch.Size([64, 64, 3, 3]) True
--- None torch.Size([64]) True
--- None torch.Size([64]) True
--- None torch.Size([64, 64, 3, 3]) True
--- None torch.Size([64]) True
--- None torch.Size([64]) True
--- None torch.Size([128, 64, 3, 3]) True
--- None torch.Size([128]) True
--- None torch.Size([128]) True
--- None torch.Size([128, 128, 3, 3]) True
--- None torch.Size([128]) True
--- None torch.Size([128]) True
--- None torch.Size([128, 64, 1, 1]) True
--- None torch.Size([128]) True
--- None torch.Size([128]) True
--- None torch.Size([128, 128, 3, 3]) True
--- None torch.Size([128]) True
--- None torch.Size([128]) True
--- None torch.Size([128, 128, 3, 3]) True
--- None torch.Si

### Explore different image size

In [34]:
import torch.nn.functional as F

In [81]:
x = torch.randn(16, 3, 224, 224)

In [82]:
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class PlanetResNet(nn.Module):
    def __init__(self, M, c):
        super().__init__()
        self.features = nn.Sequential(*(list(M.children())[:-2]))
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(in_features=512, out_features=c)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return torch.sigmoid(x)

model = PlanetResNet(resnet18, 17)

In [83]:
o = model.features(x)

In [84]:
o.shape

torch.Size([16, 512, 7, 7])

In [87]:
y = model.classifier(o)

In [88]:
y.shape

torch.Size([16, 17])

In [72]:
o = o.view(o.size(0), -1)

In [73]:
o.shape

torch.Size([16, 512])

In [74]:
o = F.linear(x, (17, 512))

AttributeError: 'tuple' object has no attribute 't'

In [33]:
o.shape

torch.Size([16, 3, 1, 1])