In [3]:
import torch.nn as nn
import torch

class TorchVision(nn.Module):
    """
    TorchVision module to allow loading any torchvision model.

    This class provides a way to load a model from the torchvision library, optionally load pre-trained weights, and
    customize the model by truncating or unwrapping layers.

    Attributes:
        m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
        split (bool): Whether `forward` returns the output of every child of `m` as a list. Always False when the
            model is not unwrapped.

    Args:
        model (str): Name of the torchvision model to load.
        weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
        unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last `truncate`
            layers. Default is True.
        truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. 0 keeps every
            layer. Default is 2.
        split (bool, optional): Returns output from intermediate child modules as list. Only honoured when `unwrap`
            is True. Default is False.
    """

    def __init__(
        self, model: str, weights: str = "DEFAULT", unwrap: bool = True, truncate: int = 2, split: bool = False
    ):
        """
        Load the model and weights from torchvision.

        Args:
            model (str): Name of the torchvision model to load.
            weights (str): Pre-trained weights to load.
            unwrap (bool): Whether to unwrap the model.
            truncate (int): Number of layers to truncate.
            split (bool): Whether to split the output (ignored when `unwrap` is False).
        """
        import torchvision  # scope for faster 'import ultralytics'

        super().__init__()
        if hasattr(torchvision.models, "get_model"):
            # Newer torchvision releases expose a name-based builder that accepts a `weights` spec.
            self.m = torchvision.models.get_model(model, weights=weights)
        else:
            # Older torchvision: fall back to the legacy boolean `pretrained` flag.
            self.m = torchvision.models.__dict__[model](pretrained=bool(weights))
        if unwrap:
            layers = list(self.m.children())
            if isinstance(layers[0], nn.Sequential):  # Second-level for some models like EfficientNet, Swin
                layers = [*list(layers[0].children()), *layers[1:]]
            # Guard truncate == 0: `layers[:-0]` would be an empty list.
            self.m = nn.Sequential(*(layers[:-truncate] if truncate else layers))
            self.split = split
        else:
            # The wrapped model is not iterated child-by-child, so per-layer output is unavailable.
            self.split = False
            # Replace `head`/`heads` with a pass-through. Classifiers stored under other attribute names
            # (e.g. `fc`, `classifier`) are left in place.
            self.m.head = self.m.heads = nn.Identity()

    def forward(self, x: torch.Tensor) -> "torch.Tensor | list[torch.Tensor]":
        """
        Forward pass through the model.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor | list[torch.Tensor]): Output tensor, or, when `split` is True, a list whose first element
                is `x` followed by the output of each child of `m` in order.
        """
        if not self.split:
            return self.m(x)
        # Each child consumes the previous child's output; the input itself is kept as element 0.
        outputs = [x]
        for layer in self.m:
            outputs.append(layer(outputs[-1]))
        return outputs

In [5]:
# Keywords make the configuration readable; these values all equal TorchVision's defaults.
# truncate=2 drops the last two children of the unwrapped model (presumably pooling + classifier — confirm).
model_convnext_tiny = TorchVision("convnext_tiny", weights="DEFAULT", unwrap=True, truncate=2, split=False)

In [7]:
torch.manual_seed(0)  # make the dummy input reproducible across Restart & Run All
x = torch.randn(1, 3, 224, 224)  # dummy batch, assumed NCHW: one 3-channel 224x224 image

# Inference only: eval() switches train-mode layers (dropout, batch-norm, ...) to inference behaviour,
# and no_grad() avoids building an autograd graph we never use.
model_convnext_tiny.eval()
with torch.no_grad():
    y = model_convnext_tiny(x)

y.shape

torch.Size([1, 768, 7, 7])

In [8]:
# Same configuration as the ConvNeXt backbone, spelled out with keywords; the first run downloads the checkpoint.
model_resnext50_32x4d = TorchVision("resnext50_32x4d", weights="DEFAULT", unwrap=True, truncate=2, split=False)

Downloading: "https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth" to /home/srs/.cache/torch/hub/checkpoints/resnext50_32x4d-1a0047aa.pth
100%|██████████| 95.8M/95.8M [02:36<00:00, 641kB/s] 


In [9]:
# eval() matters here: in train mode any BatchNorm layers would update their running statistics from this
# random input, silently altering the pretrained weights. no_grad() skips autograd bookkeeping.
model_resnext50_32x4d.eval()
with torch.no_grad():
    y1 = model_resnext50_32x4d(x)

y1.shape

torch.Size([1, 2048, 7, 7])

In [11]:
import torch
import torch.nn as nn
import timm


class TimmVision(nn.Module):
    """
    TimmVision module to allow loading any timm model with optional features_only output.

    Attributes:
        m (nn.Module): The loaded timm model or feature extractor.
        split (bool): Whether `m` was created with timm's `features_only=True`.

    Args:
        model (str): Name of the timm model to load.
        pretrained (bool, optional): Whether to load pretrained weights. Default is True.
        unwrap (bool, optional): If True (and `split` is False), unwraps the model into an nn.Sequential of its
            children minus the last `truncate` entries. Default is True.
        truncate (int, optional): Number of children to drop from the end when unwrapping; 0 keeps all.
            Default is 2.
        split (bool, optional): If True, builds the model with `features_only=True` so `forward` returns timm's
            intermediate feature maps; `unwrap` and `truncate` are then ignored. Default is False.
    """

    def __init__(
        self,
        model: str,
        pretrained: bool = True,
        unwrap: bool = True,
        truncate: int = 2,
        split: bool = False,
    ):
        """
        Create the timm model and optionally reshape it into a feature extractor.

        Args:
            model (str): Name of the timm model to load.
            pretrained (bool): Whether to load pretrained weights.
            unwrap (bool): Whether to unwrap into Sequential layers (ignored when `split` is True).
            truncate (int): Number of layers to remove from the end if unwrap=True; 0 keeps every layer.
            split (bool): If True, returns intermediate feature maps using timm's features_only.
        """
        super().__init__()

        self.split = split

        if split:
            # timm builds a dedicated feature extractor; unwrap/truncate do not apply.
            self.m = timm.create_model(model, pretrained=pretrained, features_only=True)
            return

        self.m = timm.create_model(model, pretrained=pretrained)
        if unwrap:
            layers = list(self.m.children())
            if isinstance(layers[0], nn.Sequential):  # flatten a leading nested Sequential one level
                layers = [*layers[0].children(), *layers[1:]]
            # Drop the last `truncate` children (guard 0: `layers[:-0]` would be empty).
            self.m = nn.Sequential(*(layers[:-truncate] if truncate else layers))
        else:
            # Replace any classifier stored under one of these attribute names with a pass-through.
            for attr in ("classifier", "fc", "head"):
                if hasattr(self.m, attr):
                    setattr(self.m, attr, nn.Identity())

    def forward(self, x: torch.Tensor) -> "torch.Tensor | list[torch.Tensor]":
        """
        Forward pass through the wrapped model.

        Both configurations call `self.m` directly: with `split=True` timm's features_only extractor returns a
        list of feature maps; otherwise the result is whatever the (possibly truncated) model returns.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor | list[torch.Tensor]): Model output or list of feature maps.
        """
        return self.m(x)


  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# `timm` is already imported in the TimmVision cell; list every registered model matching the wildcard.
model_names = timm.list_models("*convnext*")
model_names  # last expression -> rich display instead of pprint

['convnext_atto',
 'convnext_atto_ols',
 'convnext_atto_rms',
 'convnext_base',
 'convnext_femto',
 'convnext_femto_ols',
 'convnext_large',
 'convnext_large_mlp',
 'convnext_nano',
 'convnext_nano_ols',
 'convnext_pico',
 'convnext_pico_ols',
 'convnext_small',
 'convnext_tiny',
 'convnext_tiny_hnf',
 'convnext_xlarge',
 'convnext_xxlarge',
 'convnext_zepto_rms',
 'convnext_zepto_rms_ols',
 'convnextv2_atto',
 'convnextv2_base',
 'convnextv2_femto',
 'convnextv2_huge',
 'convnextv2_large',
 'convnextv2_nano',
 'convnextv2_pico',
 'convnextv2_small',
 'convnextv2_tiny',
 'test_convnext',
 'test_convnext2',
 'test_convnext3']


In [13]:
# Pretrained weights are TimmVision's default; spelled out so the cell reads without the class signature.
timm_model_convnext_tiny = TimmVision("convnext_tiny", pretrained=True)

In [15]:
# Inference only: eval() puts train-mode layers into inference behaviour; no_grad() skips autograd bookkeeping.
timm_model_convnext_tiny.eval()
with torch.no_grad():
    y2 = timm_model_convnext_tiny(x)

y2.shape

torch.Size([1, 768, 7, 7])