In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import models

In [None]:
# `optim.Adam()` cannot be called bare: `params` is a required argument, so the
# original call raised TypeError and stopped a top-to-bottom run of the notebook.
# Reference the class here instead, and construct the optimizer only once a model
# exists, e.g. `optim.Adam(p for p in model.parameters() if p.requires_grad)`.
optim.Adam

In [2]:
densenet121 = models.densenet121()

# The original cell carried a stray, indented `nn.init.kaiming_normal(m.weight.data)`
# line: an IndentationError on its own, with `m` never defined. It is restored here
# as a complete loop, using the in-place `kaiming_normal_` that supersedes the
# deprecated `kaiming_normal` spelling.
# NOTE(review): restricting the re-initialisation to Conv2d layers is an assumption
# about the original intent — confirm.
for m in densenet121.modules():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)


In [3]:
# Standard-normal random tensor with shape (16, 3, 256, 256).
x = torch.randn((16, 3, 256, 256))

In [4]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one.

    Input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``.
    """

    def forward(self, x):
        """Return ``x`` as a 2-D tensor, keeping dimension 0 intact.

        ``reshape`` is used instead of ``view``: it returns a view whenever
        ``view`` would succeed, and otherwise copies, so non-contiguous inputs
        (e.g. after a transpose) no longer raise RuntimeError.
        """
        return x.reshape(x.size(0), -1)
    
class PlanetDenseNet(nn.Module):
    """Backbone feature extractor topped with a new sigmoid classification head.

    Args:
        M: Backbone model exposing a ``features`` sub-module. If it also has a
            ``classifier`` containing an ``nn.Linear``, that layer's
            ``in_features`` sets the width of the new head.
        c: Number of outputs produced by the new head.
    """

    def __init__(self, M, c):
        super().__init__()
        self.features = M.features
        # Layer order is kept (pool, flatten, linear) so state_dict keys such as
        # ``classifier.2.weight`` stay the same.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(in_features=self._head_width(M), out_features=c)
        )

    @staticmethod
    def _head_width(M, default=1024):
        """Return the channel count feeding the backbone's own linear head.

        Falls back to ``default`` (the previously hard-coded 1024) when the
        backbone has no ``classifier`` or it contains no ``nn.Linear``.
        """
        head = getattr(M, 'classifier', None)
        if isinstance(head, nn.Module):
            # ``modules()`` yields ``head`` itself first, so this covers both a
            # bare Linear head and a Sequential wrapping one.
            for layer in head.modules():
                if isinstance(layer, nn.Linear):
                    return layer.in_features
        return default

    def forward(self, x):
        """Return per-output probabilities of shape ``(N, c)`` in ``[0, 1]``."""
        x = self.features(x)
        # torchvision's DenseNet.forward applies a ReLU between ``features`` and
        # the pooling step; only ``M.features`` is reused here, so that
        # activation has to be applied explicitly.
        x = torch.relu(x)
        x = self.classifier(x)

        return torch.sigmoid(x)
    
class PlanetWrapper:
    """Helper for toggling ``requires_grad`` on parts of a model.

    The wrapped model must expose ``features`` and ``classifier`` sub-modules.

    Args:
        M: The model to wrap; stored as ``self.model``.
    """

    def __init__(self, M):
        self.model = M

    @staticmethod
    def freeze(layers):
        """Set ``requires_grad`` to False on every parameter of ``layers``."""
        for param in layers.parameters():
            param.requires_grad_(False)

    @staticmethod
    def unfreeze(layers):
        """Set ``requires_grad`` to True on every parameter of ``layers``."""
        for param in layers.parameters():
            param.requires_grad_(True)

    def freeze_features(self, arg=True):
        """Freeze ``model.features`` when ``arg`` is truthy, otherwise unfreeze it."""
        if arg:
            self.freeze(self.model.features)
        else:
            self.unfreeze(self.model.features)

    def partial_freeze_features(self, pct=0.2):
        """Freeze the leading children of ``model.features``, unfreeze the rest.

        Args:
            pct: Fraction of direct children, counted from the end, that is left
                trainable. The first ``floor(n * (1 - pct))`` children are frozen.
                Values above 1 freeze nothing; values below 0 freeze everything.
        """
        children = list(self.model.features.children())
        # ``1 - pct`` is inexact in floating point (1 - 0.9 == 0.0999...98), so a
        # product that should be a whole number can land just below it and be
        # truncated one too low. The small epsilon absorbs that error.
        freeze_point = int(len(children) * (1 - pct) + 1e-9)

        for idx, child in enumerate(children):
            if idx < freeze_point:
                self.freeze(child)
            else:
                self.unfreeze(child)

    def freeze_classifier(self, arg=True):
        """Freeze ``model.classifier`` when ``arg`` is truthy, otherwise unfreeze it."""
        if arg:
            self.freeze(self.model.classifier)
        else:
            self.unfreeze(self.model.classifier)

    @staticmethod
    def _print_children(module):
        """Print each direct child of ``module`` and its parameters' ``requires_grad``."""
        for idx, (name, child) in enumerate(module.named_children()):
            print(f'{idx}: {name}-{child}')
            for param in child.parameters():
                print(f'{param.requires_grad}')

    def summary(self):
        """Print the children of ``features`` then ``classifier`` with their flags."""
        print('\n\n')
        self._print_children(self.model.features)
        self._print_children(self.model.classifier)
        print('\n\n')

In [40]:
# Wrap the backbone's feature extractor in PlanetDenseNet with 17 outputs.
# The name `densenet121` is rebound from the backbone to the wrapped model.
densenet121 = PlanetDenseNet(M=densenet121, c=17)

In [41]:
# Hand the model to PlanetWrapper so its layers can be frozen and unfrozen.
wrapper = PlanetWrapper(M=densenet121)

In [42]:
# Bare expression: the notebook renders the wrapped model's repr as the cell output.
wrapper.model

PlanetDenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(

In [47]:
# Freeze the leading children of `model.features`, leaving the trailing share
# given by `pct` trainable, then print every child with its requires_grad flags.
wrapper.partial_freeze_features(pct=0.6)
wrapper.summary()




0: conv0-Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
False
1: norm0-BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
False
False
2: relu0-ReLU(inplace)
3: pool0-MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
4: denseblock1-_DenseBlock(
  (denselayer1): _DenseLayer(
    (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace)
    (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace)
    (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (denselayer2): _DenseLayer(
    (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace)
    (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [53]:
# Build DenseNet-121 with `pretrained=True` and a drop rate of 0.3.
# NOTE(review): this rebinds `densenet121`, a name an earlier cell assigns to a
# PlanetDenseNet; `wrapper` keeps whatever model it was constructed with. The
# execution counts (53 here, 52 below, 40-47 above) show the cells were run out
# of order — confirm the intended order before relying on Run All.
densenet121 = models.densenet121(pretrained=True, drop_rate=0.3)

In [52]:
# Bare expression: the notebook renders the model's repr as the cell output.
densenet121

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplac