In [1]:
import torch
import timm


In [2]:
genres = ['Action', 'Adventure', 'Animation', 'Biography', 'Comedy',
          'Crime', 'Documentary', 'Drama', 'Family', 'Fantasy', 'History',
          'Horror', 'Music', 'Musical', 'Mystery', 'N/A', 'News', 'Romance',
          'Sci-Fi', 'Short', 'Sport', 'Thriller', 'War', 'Western'
         ]
DEVICE = 'cuda:0'

In [3]:
state_dict = torch.load('../weights/model.best.pth')

model = timm.create_model(model_name="resnet18", pretrained=False, num_classes=len(genres))
model.load_state_dict(state_dict)

<All keys matched successfully>

In [4]:
class ModelWrapper(torch.nn.Module):
    def __init__(self, model, classes, size, thresholds):
        super().__init__()
        self.model = model
        self.classes = classes
        self.size = size
        self.thresholds = thresholds
    
    def forward(self, image):
        return torch.sigmoid(self.model.forward(image))

In [5]:
wrapper = ModelWrapper(model, classes=genres, size=(224, 224), thresholds=(0.5,) * len(genres))

In [6]:
scripted_model = torch.jit.script(wrapper)

In [7]:
scripted_model.classes

['Action',
 'Adventure',
 'Animation',
 'Biography',
 'Comedy',
 'Crime',
 'Documentary',
 'Drama',
 'Family',
 'Fantasy',
 'History',
 'Horror',
 'Music',
 'Musical',
 'Mystery',
 'N/A',
 'News',
 'Romance',
 'Sci-Fi',
 'Short',
 'Sport',
 'Thriller',
 'War',
 'Western']

In [9]:
traced_model = torch.jit.trace(wrapper, torch.rand(1, 3, 224, 224))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [10]:
traced_model.classes

AttributeError: 'RecursiveScriptModule' object has no attribute 'classes'

In [11]:
dummy_input = torch.rand(1, 3, 224, 224)

In [12]:
with torch.no_grad():
    print(torch.sigmoid(model(dummy_input)))

tensor([[0.0939, 0.0566, 0.0058, 0.0276, 0.3554, 0.0946, 0.0369, 0.4863, 0.0247,
         0.0390, 0.0110, 0.0248, 0.0281, 0.0065, 0.0303, 0.0006, 0.0007, 0.1140,
         0.0237, 0.0019, 0.0183, 0.0577, 0.0069, 0.0029]])


In [13]:
with torch.no_grad():
    print(scripted_model(dummy_input))

tensor([[0.0939, 0.0566, 0.0058, 0.0276, 0.3554, 0.0946, 0.0369, 0.4863, 0.0247,
         0.0390, 0.0110, 0.0248, 0.0281, 0.0065, 0.0303, 0.0006, 0.0007, 0.1140,
         0.0237, 0.0019, 0.0183, 0.0577, 0.0069, 0.0029]])


In [27]:
torch.jit.save(scripted_model, '../weights/genre_classification_v2.pt')

In [28]:
model = torch.jit.load('../weights/genre_classification_v2.pt', map_location='cpu')

In [29]:
model.classes

['Action',
 'Adventure',
 'Animation',
 'Biography',
 'Comedy',
 'Crime',
 'Documentary',
 'Drama',
 'Family',
 'Fantasy',
 'History',
 'Horror',
 'Music',
 'Musical',
 'Mystery',
 'N/A',
 'News',
 'Romance',
 'Sci-Fi',
 'Short',
 'Sport',
 'Thriller',
 'War',
 'Western']