In [1]:
import torchvision.models as models

In [4]:
# Inspect the EfficientNet class exposed by torchvision.models (shown as the cell output).
getattr(models, "EfficientNet")

torchvision.models.efficientnet.EfficientNet

In [6]:
class EfficientNetCustom(models.EfficientNet):
    """Placeholder subclass of torchvision's ``EfficientNet``.

    It currently inherits everything from the parent class unchanged.

    TODO: implement the intended customisation (e.g. an ``n_bands`` input stem
    and inter-stage dropout, analogous to the ``Resnet`` wrapper in this notebook).
    """

<class 'torchvision.models.efficientnet.EfficientNet'>


In [7]:
import re
import torch
import torch.nn as nn
from torchvision.models import ResNet
from torchvision.models.resnet import BasicBlock

class Resnet(nn.Module):
    """ResNet-style backbone wrapper with a configurable input stem, dropout
    between residual stages, optional parameter freezing and an optional
    headless (feature-extractor) mode.

    ``base_model`` must expose the attribute layout used below: ``conv1``,
    ``bn1``, ``relu``, ``maxpool``, ``layer1``-``layer4``, ``avgpool`` and
    ``fc`` (as torchvision ResNets do).
    """

    def __init__(self,
                 base_model,
                 n_classes,
                 n_bands = 4,
                 p_dropout = 0.25,
                 unfreeze = False,
                 headless = False):
        """
        Parameters
        ----------
        base_model: nn.Module
            A variant of ResNet, preferably from torchvision.models, or similar.
            NOTE: it is modified in place (``conv1``/``fc`` may be replaced or
            deleted) and stored as ``self.model``; pass a fresh instance.
        n_classes: int
            Number of output classes. Ignored when ``headless`` is True.
        n_bands: int
            Number of input bands (channels). If the stem convolution expects a
            different number of channels, it is replaced by a freshly
            initialised convolution that keeps the original output channels,
            kernel size, stride, padding, dilation and bias setting.
        p_dropout: float
            Dropout probability applied after each of ``layer1``-``layer4``.
        unfreeze: None, False, or list of str
            ``None``/``False`` (default): every parameter stays trainable.
            Otherwise an allow-list of parameter names, as reported by
            ``self.model.named_parameters()`` (e.g. ``"fc.weight"``). Listed
            parameters are trainable and all others are frozen. An empty list
            freezes everything. Freezing runs after the stem/head are replaced,
            so list ``"conv1.weight"`` / ``"fc.weight"`` / ``"fc.bias"`` if
            those new layers should be trained.
        headless: bool
            If True, the classification head is removed and ``forward`` returns
            the ``layer4`` feature map (after dropout). If False, ``forward``
            returns class logit scores.
        """

        super().__init__()

        self.headless = headless
        resnet = base_model
        self.dropout = nn.Dropout(p=p_dropout)

        new_stem = self._build_stem(resnet.conv1, n_bands)
        if new_stem is not None:
            resnet.conv1 = new_stem

        if self.headless:
            # Headless: drop the classification head entirely.
            del resnet.fc
        else:
            # Re-size the head to the requested number of classes.
            resnet.fc = nn.Linear(
                in_features=resnet.fc.in_features,
                out_features=n_classes,
            )

        self.model = resnet

        # None/False mean "train everything"; any other value (even an empty
        # list) is an explicit allow-list of trainable parameter names.
        if unfreeze is not None and unfreeze is not False:
            self.set_parameter_requires_grad(unfreeze)

    @staticmethod
    def _build_stem(old_stem, n_bands):
        """Return a stem conv accepting ``n_bands`` channels, or None if ``old_stem`` can be kept."""
        if isinstance(old_stem, nn.Conv2d):
            if old_stem.in_channels == n_bands:
                return None
            # Mirror the existing stem so the following bn1/layer1 still match.
            return nn.Conv2d(n_bands,
                             old_stem.out_channels,
                             kernel_size=old_stem.kernel_size,
                             stride=old_stem.stride,
                             padding=old_stem.padding,
                             dilation=old_stem.dilation,
                             bias=old_stem.bias is not None)
        # Non-Conv2d stem: keep the original hard-coded 64-channel 7x7/stride-2 stem.
        if n_bands == 3:
            return None
        return nn.Conv2d(n_bands, 64, kernel_size=7, stride=2, padding=3, bias=False)

    def set_parameter_requires_grad(self, unfreeze):
        """Freeze every parameter of ``self.model`` except those named in ``unfreeze``.

        Parameters
        ----------
        unfreeze: str or iterable of str
            Fully-qualified parameter names (as from ``named_parameters()``).
            Listed parameters get ``requires_grad=True``; all others ``False``.
            Names that match no parameter trigger a ``UserWarning`` (likely typo).
        """
        import warnings

        if isinstance(unfreeze, str):
            # A bare string would otherwise be split into characters.
            unfreeze = [unfreeze]
        keep = set(unfreeze)
        matched = set()
        for name, param in self.model.named_parameters():
            trainable = name in keep
            param.requires_grad = trainable
            if trainable:
                matched.add(name)

        unknown = keep - matched
        if unknown:
            warnings.warn(
                f"unfreeze contains names that match no parameter: {sorted(unknown)}",
                stacklevel=2,
            )

    def forward(self, x):
        """Run the backbone; return logits, or the layer4 feature map if headless.

        ``x`` is fed straight to ``self.model.conv1``; for a Conv2d stem this is
        presumably (batch, n_bands, H, W) — TODO confirm against the data loader.
        """
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        # Dropout after every residual stage.
        for stage in (self.model.layer1, self.model.layer2,
                      self.model.layer3, self.model.layer4):
            x = self.dropout(stage(x))

        if self.headless:
            return x

        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        return self.model.fc(x)
    
class HeadlessResnet(Resnet):
    """``Resnet`` with ``headless=True``: a feature extractor without a classification head.

    ``forward`` returns the ``layer4`` feature map (after dropout) instead of logits.
    """

    def __init__(self,
                 base_model,
                 n_classes = None,
                 n_bands = 4,
                 p_dropout = 0.3,
                 unfreeze = None):
        """
        Parameters
        ----------
        base_model: nn.Module
            A variant of ResNet; modified in place by the parent constructor.
        n_classes: int or None
            Unused in headless mode; kept only for backward compatibility.
        n_bands: int
            Number of input bands.
        p_dropout: float
            Dropout probability (note: default 0.3 here, vs 0.25 in ``Resnet``).
        unfreeze: None, or list of str
            Forwarded unchanged to ``Resnet``: parameter names that should stay
            trainable. ``None`` leaves every parameter trainable.
        """

        super().__init__(base_model=base_model,
                         n_classes=n_classes,
                         n_bands=n_bands,
                         p_dropout=p_dropout,
                         unfreeze=unfreeze,
                         headless=True)

In [8]:
# Example instantiation: a randomly initialised ResNet-18 adapted to 4-band input.
# N_CLASSES is a placeholder — set it to the number of classes in your dataset.
N_CLASSES = 10
model = Resnet(base_model=models.resnet18(), n_classes=N_CLASSES, n_bands=4)