In [41]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import EfficientNet_B0_Weights

class EfficientNetFeatureExtractor(nn.Module):
    def __init__(self, model_name='efficientnet_b0'):
        super(EfficientNetFeatureExtractor, self).__init__()
        # Load the pre-trained EfficientNet
        self.model = models.efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)

        # final layer feature
        self.feature_blocks = self.model.features

    def forward(self, x):
        features = []

        # Stages 1 to 8
        for block in self.feature_blocks:
            x = block(x)
            features.append(x)

        return features

# Example usage
model = EfficientNetFeatureExtractor()
input_tensor = torch.rand(1, 3, 224, 224)  # Example input tensor
features = model(input_tensor)

# features is a list where each element is the output of a stage
print(len(features))


The shape is: torch.Size([1, 1280, 7, 7])
9


In [39]:
for ll in features:
    print(ll.shape)

torch.Size([1, 32, 112, 112])
torch.Size([1, 16, 112, 112])
torch.Size([1, 24, 56, 56])
torch.Size([1, 40, 28, 28])
torch.Size([1, 80, 14, 14])
torch.Size([1, 112, 14, 14])
torch.Size([1, 192, 7, 7])
torch.Size([1, 320, 7, 7])
torch.Size([1, 1280, 7, 7])


In [9]:
print('Total params for model: %2.fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) 

Total params for model:  5M
