In [1]:
import os
from dataclasses import dataclass

import torch
from torch import nn
from torchvision.models import (
    resnet50,
    ResNet50_Weights,
    resnet18,
    ResNet18_Weights,
)

In [9]:
class ModelFactory:
    """Base wrapper around an ``nn.Module`` plus a name-based factory.

    Any descendant class can be built from its class name via ``create``.
    Descendants are expected to call ``super().__init__`` and then replace
    ``self.model`` with the real network; the wrapper forwards calls and
    train/eval mode switches to it.
    """

    @staticmethod
    def create(model_name: str, n_classes: int):
        """Instantiate the descendant class whose ``__name__`` equals ``model_name``.

        The whole subclass tree is searched breadth-first, not only direct
        subclasses. If several classes share the name, the last one found
        wins (see the comment in the loop for why).

        Args:
            model_name: exact class name of the wanted subclass; it is also
                passed on to that class's constructor.
            n_classes: forwarded to the subclass constructor.

        Returns:
            The new subclass instance, or ``False`` when no descendant has that
            name (kept for backward compatibility; callers must check for it).
        """
        from collections import deque

        match = None
        seen = set()
        pending = deque(ModelFactory.__subclasses__())
        while pending:
            subclass = pending.popleft()
            if subclass in seen:  # diamond inheritance can list a class twice
                continue
            seen.add(subclass)
            if subclass.__name__ == model_name:
                # Keep scanning instead of returning: re-running a notebook cell
                # that redefines a subclass leaves the old class registered until
                # it is garbage-collected, and (CPython lists subclasses in
                # definition order) the redefinition appears after it.
                match = subclass
            # __subclasses__() only lists direct children; walk the whole tree.
            pending.extend(subclass.__subclasses__())

        if match is None:
            return False
        return match(model_name, n_classes)

    def __init__(self, model_name: str, n_classes: int):
        """Store the configuration and install a placeholder model.

        Args:
            model_name: name recorded on the instance.
            n_classes: number of output classes, for subclasses to size their head.
        """
        self.model_name = model_name
        self.n_classes = n_classes
        # Placeholder only: a bare nn.Module defines no forward, so calling the
        # base class fails until a subclass assigns the real network.
        self.model = nn.Module()

    def train(self, mode: bool = True):
        """Switch the wrapped model to training mode (eval mode if ``mode`` is False).

        Returns ``self`` so calls can be chained, mirroring ``nn.Module.train``.
        """
        self.model.train(mode)
        return self

    def eval(self):
        """Switch the wrapped model to evaluation mode and return ``self``."""
        self.model.eval()
        return self

    def __call__(self, *args, **kwargs):
        """Run the wrapped model's forward pass with the given arguments."""
        return self.model(*args, **kwargs)

    def get_model(self):
        """Return the wrapped ``nn.Module``."""
        return self.model

In [10]:
class ResNet18(ModelFactory):
    """torchvision ResNet-18 with its final ``fc`` layer resized to ``n_classes`` outputs."""

    def __init__(self, model_name: str, n_classes: int,
                 weights=ResNet18_Weights.IMAGENET1K_V1):
        """Build the backbone and attach a new classification head.

        Args:
            model_name: name recorded on the instance (``ModelFactory.create``
                passes the class name).
            n_classes: number of outputs of the new ``fc`` layer.
            weights: torchvision weights enum used to initialise the backbone.
                Defaults to the ImageNet-1k V1 checkpoint, as before. Pass
                ``None`` for random initialisation, which avoids fetching
                pretrained weights (useful offline / in tests).
        """
        super().__init__(model_name, n_classes)
        self.model = resnet18(weights=weights)
        # Replace the pretrained head with a freshly initialised one sized for
        # this task; in_features is read from the existing layer, not hard-coded.
        self.model.fc = nn.Linear(self.model.fc.in_features, self.n_classes)

In [11]:
model = ModelFactory.create("ResNet18", 778)

In [12]:
input = torch.rand(16, 3, 224, 224)

In [13]:
# print(model.model)

In [14]:
model(input)

tensor([[ 1.5250, -0.5373,  0.7756,  ...,  0.2224,  0.1816,  1.6705],
        [ 1.2608, -0.6349,  0.9953,  ...,  0.0521, -0.4448,  1.2121],
        [ 1.1375, -0.1234,  0.8351,  ...,  0.2446,  0.1314,  0.4164],
        ...,
        [ 1.1033, -0.5420,  0.1594,  ...,  0.4961,  0.1726,  0.7346],
        [ 1.2857, -0.5147,  0.3937,  ...,  0.3329,  0.0552,  1.0015],
        [ 0.9839, -0.2012,  0.6409,  ...,  0.6985, -0.3498,  1.0434]],
       grad_fn=<AddmmBackward0>)