In [1]:
import torch
import timm
import torch.nn as nn 


class TimmVision(nn.Module):
    """
    Wrap a timm model so it can be used as a backbone, optionally exposing intermediate feature maps.

    The constructor flags select one of three modes:

    * ``split=True``: the model is created with timm's ``features_only=True`` and ``forward`` returns
      whatever that feature extractor returns (a list of feature maps). ``unwrap`` and ``truncate``
      are ignored in this mode.
    * ``split=False, unwrap=True``: the model's children are flattened into an ``nn.Sequential``
      (a leading nested ``nn.Sequential`` is expanded one level) and the last ``truncate`` entries
      are dropped.
    * ``split=False, unwrap=False``: the full model is kept, and any of its ``classifier``, ``fc`` or
      ``head`` attributes that exist are replaced by ``nn.Identity``.

    Attributes:
        m (nn.Module): The loaded timm model, feature extractor, or unwrapped Sequential.
        split (bool): Whether the model was built in ``features_only`` mode.
    """

    def __init__(
        self,
        model: str,
        pretrained: bool = True,
        unwrap: bool = True,
        truncate: int = 2,
        split: bool = False,
    ):
        """
        Build the wrapped timm model.

        Args:
            model (str): Name of the timm model to load.
            pretrained (bool): Whether to load pretrained weights.
            unwrap (bool): Whether to unwrap into Sequential layers. Ignored when ``split=True``.
            truncate (int): Number of layers to remove from the end if ``unwrap=True``. A falsy value
                (0 or None) keeps every layer. Ignored when ``split=True``.
            split (bool): If True, returns intermediate feature maps using timm's features_only.

        Raises:
            ValueError: If ``truncate`` is negative while the unwrap path is active.
        """
        import timm  # local import: timm is only needed when a model is actually built

        super().__init__()
        self.split = split

        if split:
            # timm's features_only wrapper already yields the intermediate feature maps.
            self.m = timm.create_model(model, pretrained=pretrained, features_only=True)
            return

        # Reject a negative count before any weights are fetched: `layers[:-truncate]` with a
        # negative value would silently keep only the first layers instead of trimming the tail.
        if unwrap and truncate and truncate < 0:
            raise ValueError(f"truncate must be >= 0, got {truncate}")

        backbone = timm.create_model(model, pretrained=pretrained)
        if unwrap:
            layers = list(backbone.children())
            if isinstance(layers[0], nn.Sequential):  # expand a nested leading Sequential one level
                layers = [*layers[0].children(), *layers[1:]]
            # `layers[:-0]` would be empty, so a falsy truncate must keep the full list.
            self.m = nn.Sequential(*(layers[:-truncate] if truncate else layers))
        else:
            # Keep the model intact but neutralise whichever classifier attribute it exposes.
            for attr in ("classifier", "fc", "head"):
                if hasattr(backbone, attr):
                    setattr(backbone, attr, nn.Identity())
            self.m = backbone

    def forward(self, x: torch.Tensor) -> "torch.Tensor | list[torch.Tensor]":
        """
        Run the wrapped model on ``x``.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor | list[torch.Tensor]): Output of the wrapped module; in split mode this is
                the list of feature maps produced by timm's features_only model.
        """
        # Both modes simply delegate; the output type is decided by how `self.m` was built.
        return self.m(x)


In [6]:
from pprint import pprint

# Collect every pretrained timm checkpoint whose name matches "*resnet*",
# then report how many there are before listing them.
model_names = timm.list_models("*resnet*", pretrained=True)
print(len(model_names))
pprint(model_names)

180
['cspresnet50.ra_in1k',
 'eca_resnet33ts.ra2_in1k',
 'ecaresnet26t.ra2_in1k',
 'ecaresnet50d.miil_in1k',
 'ecaresnet50d_pruned.miil_in1k',
 'ecaresnet50t.a1_in1k',
 'ecaresnet50t.a2_in1k',
 'ecaresnet50t.a3_in1k',
 'ecaresnet50t.ra2_in1k',
 'ecaresnet101d.miil_in1k',
 'ecaresnet101d_pruned.miil_in1k',
 'ecaresnet269d.ra2_in1k',
 'ecaresnetlight.miil_in1k',
 'gcresnet33ts.ra2_in1k',
 'gcresnet50t.ra2_in1k',
 'inception_resnet_v2.tf_ens_adv_in1k',
 'inception_resnet_v2.tf_in1k',
 'lambda_resnet26rpt_256.c1_in1k',
 'lambda_resnet26t.c1_in1k',
 'lambda_resnet50ts.a1h_in1k',
 'legacy_seresnet18.in1k',
 'legacy_seresnet34.in1k',
 'legacy_seresnet50.in1k',
 'legacy_seresnet101.in1k',
 'legacy_seresnet152.in1k',
 'nf_resnet50.ra2_in1k',
 'resnet10t.c3_in1k',
 'resnet14t.c3_in1k',
 'resnet18.a1_in1k',
 'resnet18.a2_in1k',
 'resnet18.a3_in1k',
 'resnet18.fb_ssl_yfcc100m_ft_in1k',
 'resnet18.fb_swsl_ig1b_ft_in1k',
 'resnet18.gluon_in1k',
 'resnet18.tv_in1k',
 'resnet18d.ra2_in1k',
 'resnet18d

resnet18.a1_in1k

In [7]:
# split=True builds the model through timm's features_only path, so the
# unwrap/truncate arguments passed here are not used by TimmVision.
model_resnet18 = TimmVision(
    "resnet18.a1_in1k",
    pretrained=True,
    unwrap=True,
    truncate=2,
    split=True,
)

model.safetensors:   0%|          | 0.00/46.8M [00:00<?, ?B/s]

In [10]:
# Push one random image-shaped batch through the feature extractor and
# print the shape of each returned feature map.
x = torch.randn(1, 3, 244, 244)
for idx, feat in enumerate(model_resnet18(x)):
    print(idx, feat.shape)

0 torch.Size([1, 64, 122, 122])
1 torch.Size([1, 64, 61, 61])
2 torch.Size([1, 128, 31, 31])
3 torch.Size([1, 256, 16, 16])
4 torch.Size([1, 512, 8, 8])


In [11]:
class TorchVision(nn.Module):
    """
    TorchVision module to allow loading any torchvision model.

    Loads a model from the torchvision library, optionally with pre-trained weights, and can unwrap it
    into an ``nn.Sequential`` with the trailing layers removed. When unwrapped with ``split=True``,
    ``forward`` returns the input followed by the output of every retained layer.

    Attributes:
        m (nn.Module): The loaded torchvision model, possibly truncated and unwrapped.
        split (bool): Whether ``forward`` returns the list of intermediate outputs. Always False when
            ``unwrap=False``.

    Args:
        model (str): Name of the torchvision model to load.
        weights (str, optional): Pre-trained weights to load. Default is "DEFAULT".
        unwrap (bool, optional): If True, unwraps the model to a sequential containing all but the last
            `truncate` layers. Default is True.
        truncate (int, optional): Number of layers to truncate from the end if `unwrap` is True. A falsy
            value (0 or None) keeps every layer. Default is 2.
        split (bool, optional): Returns output from intermediate child modules as list. Default is False.
    """

    def __init__(
        self, model: str, weights: str = "DEFAULT", unwrap: bool = True, truncate: int = 2, split: bool = False
    ):
        """
        Load the model and weights from torchvision.

        Args:
            model (str): Name of the torchvision model to load.
            weights (str): Pre-trained weights to load.
            unwrap (bool): Whether to unwrap the model.
            truncate (int): Number of layers to truncate.
            split (bool): Whether to split the output. Only honoured when ``unwrap`` is True.

        Raises:
            ValueError: If ``truncate`` is negative while ``unwrap`` is True.
        """
        # Reject a negative count before any weights are fetched: `layers[:-truncate]` with a
        # negative value would silently keep only the first layers instead of trimming the tail.
        if unwrap and truncate and truncate < 0:
            raise ValueError(f"truncate must be >= 0, got {truncate}")

        import torchvision  # local import: torchvision is only needed when a model is built

        super().__init__()
        if hasattr(torchvision.models, "get_model"):
            self.m = torchvision.models.get_model(model, weights=weights)
        else:
            # Fallback for torchvision builds without `get_model`: call the constructor by name.
            self.m = torchvision.models.__dict__[model](pretrained=bool(weights))
        if unwrap:
            layers = list(self.m.children())
            if isinstance(layers[0], nn.Sequential):  # Second-level for some models like EfficientNet, Swin
                layers = [*layers[0].children(), *layers[1:]]
            # `layers[:-0]` would be empty, so a falsy truncate must keep the full list.
            self.m = nn.Sequential(*(layers[:-truncate] if truncate else layers))
            self.split = split
        else:
            # Without unwrapping there is no layer sequence to iterate, so split output is disabled.
            self.split = False
            self.m.head = self.m.heads = nn.Identity()

    def forward(self, x: torch.Tensor) -> "torch.Tensor | list[torch.Tensor]":
        """
        Forward pass through the model.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor | list[torch.Tensor]): Output tensor, or in split mode a list whose first
                element is the input ``x`` followed by the output of each layer in order.
        """
        if not self.split:
            return self.m(x)
        outputs = [x]
        for layer in self.m:
            # Each layer consumes the previous layer's output (the input itself for the first one).
            outputs.append(layer(outputs[-1]))
        return outputs


In [14]:
# Build an unwrapped, truncated convnext_tiny in split mode and print the shape of
# every element it returns (the first element is the input itself).
x1 = torch.randn(1, 3, 244, 244)
model_convnext_tiny = TorchVision(
    "convnext_tiny",
    weights="DEFAULT",
    unwrap=True,
    truncate=2,
    split=True,
)

for idx, feat in enumerate(model_convnext_tiny(x1)):
    print(idx, feat.shape)

0 torch.Size([1, 3, 244, 244])
1 torch.Size([1, 96, 61, 61])
2 torch.Size([1, 96, 61, 61])
3 torch.Size([1, 192, 30, 30])
4 torch.Size([1, 192, 30, 30])
5 torch.Size([1, 384, 15, 15])
6 torch.Size([1, 384, 15, 15])
7 torch.Size([1, 768, 7, 7])
8 torch.Size([1, 768, 7, 7])


In [2]:
from pprint import pprint

# List every pretrained timm checkpoint whose name matches "*mobile*".
model_mobilenet_names = timm.list_models("*mobile*", pretrained=True)
pprint(model_mobilenet_names)

['mobilenet_edgetpu_v2_m.ra4_e3600_r224_in1k',
 'mobilenetv1_100.ra4_e3600_r224_in1k',
 'mobilenetv1_100h.ra4_e3600_r224_in1k',
 'mobilenetv1_125.ra4_e3600_r224_in1k',
 'mobilenetv2_050.lamb_in1k',
 'mobilenetv2_100.ra_in1k',
 'mobilenetv2_110d.ra_in1k',
 'mobilenetv2_120d.ra_in1k',
 'mobilenetv2_140.ra_in1k',
 'mobilenetv3_large_100.miil_in21k',
 'mobilenetv3_large_100.miil_in21k_ft_in1k',
 'mobilenetv3_large_100.ra4_e3600_r224_in1k',
 'mobilenetv3_large_100.ra_in1k',
 'mobilenetv3_large_150d.ra4_e3600_r256_in1k',
 'mobilenetv3_rw.rmsp_in1k',
 'mobilenetv3_small_050.lamb_in1k',
 'mobilenetv3_small_075.lamb_in1k',
 'mobilenetv3_small_100.lamb_in1k',
 'mobilenetv4_conv_aa_large.e230_r384_in12k',
 'mobilenetv4_conv_aa_large.e230_r384_in12k_ft_in1k',
 'mobilenetv4_conv_aa_large.e230_r448_in12k_ft_in1k',
 'mobilenetv4_conv_aa_large.e600_r384_in1k',
 'mobilenetv4_conv_blur_medium.e500_r224_in1k',
 'mobilenetv4_conv_large.e500_r256_in1k',
 'mobilenetv4_conv_large.e600_r384_in1k',
 'mobilenet

In [3]:
# Load MobileNetV4 (conv small) through timm's features_only path and print the
# shape of each feature map produced for one random input batch.
model_mobilenetv4 = TimmVision(
    "mobilenetv4_conv_small.e1200_r224_in1k",
    split=True,
)

x = torch.randn(1, 3, 244, 244)
for idx, feat in enumerate(model_mobilenetv4(x)):
    print(idx, feat.shape)

model.safetensors:   0%|          | 0.00/15.2M [00:00<?, ?B/s]

Unexpected keys (norm_head.num_batches_tracked, classifier.bias, classifier.weight, conv_head.weight, norm_head.bias, norm_head.running_mean, norm_head.running_var, norm_head.weight) found while loading pretrained weights. This may be expected if model is being adapted.


0 torch.Size([1, 32, 122, 122])
1 torch.Size([1, 32, 61, 61])
2 torch.Size([1, 64, 31, 31])
3 torch.Size([1, 96, 16, 16])
4 torch.Size([1, 960, 8, 8])
