In [1]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import pandas as pd
from torchinfo import summary

In [16]:
class VGG16MultiLabel(nn.Module):
    """Poster classifier: frozen VGG16 convolutional backbone plus a new dense head.

    Args:
        num_classes: Width of the final linear layer, i.e. the number of
            logits produced per image. Defaults to 28.
    """

    def __init__(self, num_classes=28):
        super().__init__()

        # Start from torchvision's default pretrained VGG16 and keep only its
        # convolutional stack; the stock VGG classifier is not reused.
        backbone = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        # Freeze every backbone parameter so that only the head is trainable.
        backbone.features.requires_grad_(False)
        self.features = backbone.features

        # 512 * 6 * 4 is the flattened backbone output for a 200x150 input
        # (feature map of shape [512, 6, 4]); images of any other size will
        # not fit the first linear layer.
        # Layer order is significant: the positions inside the Sequential
        # become the parameter names stored in saved checkpoints.
        head_layers = [
            nn.Flatten(),
            nn.Linear(512 * 6 * 4, 2056),
            nn.ReLU(),
            nn.Linear(2056, 1024),
            nn.ReLU(),
            nn.Dropout(0.6),
            nn.Linear(1024, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw (un-activated) class scores of shape (batch, num_classes)."""
        return self.classifier(self.features(x))

In [17]:
# Build the 9-class network and print its layer summary for a single
# 3x200x150 image before the trained weights are loaded.
model = VGG16MultiLabel(num_classes=9)
print(summary(model, input_size=(1, 3, 200, 150)))

# Restore the trained parameters onto the CPU. weights_only=True restricts
# unpickling to tensor data rather than arbitrary Python objects.
model.load_state_dict(
    torch.load(
        'vgg16_genre_model.pth',
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)

Layer (type:depth-idx)                   Output Shape              Param #
VGG16MultiLabel                          [1, 9]                    --
├─Sequential: 1-1                        [1, 512, 6, 4]            --
│    └─Conv2d: 2-1                       [1, 64, 200, 150]         (1,792)
│    └─ReLU: 2-2                         [1, 64, 200, 150]         --
│    └─Conv2d: 2-3                       [1, 64, 200, 150]         (36,928)
│    └─ReLU: 2-4                         [1, 64, 200, 150]         --
│    └─MaxPool2d: 2-5                    [1, 64, 100, 75]          --
│    └─Conv2d: 2-6                       [1, 128, 100, 75]         (73,856)
│    └─ReLU: 2-7                         [1, 128, 100, 75]         --
│    └─Conv2d: 2-8                       [1, 128, 100, 75]         (147,584)
│    └─ReLU: 2-9                         [1, 128, 100, 75]         --
│    └─MaxPool2d: 2-10                   [1, 128, 50, 37]          --
│    └─Conv2d: 2-11                      [1, 256, 50, 37]    

<All keys matched successfully>

In [18]:
# Put the network in evaluation mode (this disables the dropout layer in the
# classifier head) before running any predictions.
model.eval()

# Preprocessing pipeline for a poster image:
#   1. resize to height 200 x width 150, the size the classifier head expects,
#   2. convert the PIL image to a float tensor,
#   3. normalise each channel with the standard ImageNet mean / std.
transform = transforms.Compose(
    [
        transforms.Resize((200, 150)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [19]:
# Pick the poster to classify by its IMDb id: set `imdb_id` to any id that has
# a matching "<id>.jpg" file in the posters folder.
model.cpu()
imdb_id = "tt0002797"
img_path = f"posters/{imdb_id}.jpg"
print(img_path)

# Force three colour channels so greyscale / palette / RGBA posters all work.
image = Image.open(img_path)
image = image.convert('RGB')

# Preprocess, then add a leading batch dimension of size 1.
input_tensor = torch.unsqueeze(transform(image), 0)

posters/tt0002797.jpg


In [20]:
# Inference only: run the forward pass without recording gradients.
with torch.no_grad():
    outputs = model(input_tensor)

# Take the index of the highest-scoring class for the single image in the batch.
# NOTE(review): the model is named "MultiLabel" but only the top-1 class is
# reported here - confirm that a single predicted genre is what is intended.
predicted_class = torch.argmax(outputs, dim=1).item()

print(f'Predicted class index: {predicted_class}')

Predicted class index: 2


{'Action': 0,
 'Adventure': 1,
 'Comedy': 2,
 'Crime': 3,
 'Drama': 4,
 'Horror': 5,
 'Romance': 6,
 'Sci-Fi': 7,
 'Thriller': 8}

The mapping above gives the class index that corresponds to each genre.

In [21]:
# Look up the recorded genres for the same movie so the prediction can be
# compared against the dataset's label.
df = pd.read_csv("movies_with_posters.csv")
genre_row = df.loc[df['movie_id'].eq(imdb_id), 'genres']
# Show the first matching row's value (raises IndexError if the id is absent).
print(genre_row.iloc[0])

['Comedy']
