In [2]:
import torch
import torch.nn as nn
import timm.models.swin_transformer_v2 as swinv2
from timm.models.helpers import load_state_dict
import timm

class testmodel(nn.Module):
    """Multi-task model: a timm backbone, an optional projector, and one head per task.

    Args:
        num_classes_list: Number of output classes for each task head. A head
            whose ``num_classes`` is ``<= 0`` is an ``nn.Identity`` that passes
            the features through unchanged.
        model_name: Name of the timm model used as the backbone. It is created
            with ``pretrained=True`` and ``num_classes=0`` so that it emits
            features rather than class logits.
        projector_features: If truthy, output width of a projector inserted
            between the backbone and the heads. Otherwise the heads consume the
            backbone's ``num_features`` directly.
        use_mlp: When a projector is built, use Linear-ReLU-Linear instead of a
            single Linear layer.

    Raises:
        ValueError: If ``num_classes_list`` is None.
    """

    def __init__(self, num_classes_list, model_name, projector_features=None, use_mlp=False):
        super().__init__()
        if num_classes_list is None:
            raise ValueError("num_classes_list must not be None")
        self.backbone = timm.create_model(model_name, pretrained=True, num_classes=0)
        encoder_features = self.backbone.num_features
        self.projector = None
        if projector_features:
            self.num_features = projector_features
            if use_mlp:
                self.projector = nn.Sequential(
                    nn.Linear(encoder_features, self.num_features),
                    nn.ReLU(inplace=True),
                    nn.Linear(self.num_features, self.num_features),
                )
            else:
                self.projector = nn.Linear(encoder_features, self.num_features)
        else:
            # No projector: the heads sit directly on the backbone output, so
            # their input width is the backbone's feature width.
            self.num_features = encoder_features

        # Multi-task heads. ModuleList registers their parameters with the model.
        self.omni_heads = nn.ModuleList(
            nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
            for num_classes in num_classes_list
        )

    def forward(self, x, head_n=None):
        """Run the backbone (and projector, if any) and then the task head(s).

        Args:
            x: Input batch passed straight to the backbone.
            head_n: Index of a single head to evaluate, or None for all heads.

        Returns:
            ``(features, head_output)`` when ``head_n`` is given; otherwise a
            list with one output per head, in ``omni_heads`` order.
        """
        x = self.backbone(x)
        if self.projector is not None:
            x = self.projector(x)
        if head_n is not None:
            return x, self.omni_heads[head_n](x)
        return [head(x) for head in self.omni_heads]

    def generate_embeddings(self, x, after_proj=True):
        """Return the feature embedding for ``x``.

        Args:
            x: Input batch passed to the backbone.
            after_proj: If True and a projector exists, return the projected
                features. Otherwise return the raw backbone features.
        """
        x = self.backbone(x)
        if after_proj and self.projector is not None:
            x = self.projector(x)
        return x

    def __str__(self):
        # Customize the representation of your model here
        return "Custom Test Model:\n" + str(self.backbone) + "\nProjector: " + str(self.projector) + "\nOmni Heads: " + str(self.omni_heads)

In [4]:
# Show which SwinV2-Large variants the installed timm version provides.
print(timm.list_models("*swinv2_large*"))

# Example usage (disabled: creating these models downloads pretrained weights):
# model = timm.create_model('swinv2_base_window12_192', pretrained=True)
# print(model)
# model = testmodel([14, 14, 2, 14, 2], model_name="swinv2_base_window12_192", projector_features=1376)
# print(model)

['swinv2_large_window12_192', 'swinv2_large_window12to16_192to256', 'swinv2_large_window12to24_192to384']
