In [1]:
import time
import os
from PIL import Image
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets

import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cpu')

In [32]:
# Class labels, looked up by the index of the winning model output.
# NOTE(review): assumes this order matches the label indices used in training -- confirm.
categories = [
    'not_rated',
    'rated',
]

In [4]:
# Inference-time preprocessing: resize to a fixed 640x640 input, convert the PIL
# image to a tensor, then normalise each of the three channels.
# NOTE(review): the mean/std values look like the usual ImageNet statistics and the
# 640x640 size presumably mirrors the training pipeline -- confirm against training.
transform = transforms.Compose(
    [
        transforms.Resize(size=(640, 640)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [5]:
# Number of classifier outputs: one per entry in `categories`.
output_shape = len(categories)
output_shape

2

In [6]:
# Build EfficientNet-B0 initialised from the ImageNet checkpoint and move it to the
# selected device. `weights=` replaces the deprecated `pretrained=True` flag
# (torchvision >= 0.13); IMAGENET1K_V1 is the checkpoint `pretrained=True` resolved to,
# so the resulting model is the same.
model = torchvision.models.efficientnet_b0(
    weights=torchvision.models.EfficientNet_B0_Weights.IMAGENET1K_V1
).to(device)
model



EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [7]:
# Replace the stock classification head with one sized for this task's classes.
# The Linear layer's input width is read from the existing head instead of being
# hard-coded, so the cell keeps working if the backbone variant changes. Re-running
# the cell is safe: the replacement head ends in a Linear layer of the same width.
head_in_features = model.classifier[-1].in_features
model.classifier = nn.Sequential(
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(in_features=head_in_features, out_features=output_shape, bias=True),
).to(device)

In [8]:
# Load the saved state dict into the model, mapping stored tensors onto `device`.
# NOTE(review): the file is named "mobilenet_..." but is loaded into an EfficientNet;
# presumably a leftover name -- confirm this is the intended checkpoint.
# NOTE(security): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load('mobilenet_transfer_learning.pth', map_location=device))
# Switch the model to evaluation mode before running predictions.
model.eval()

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [33]:
def predict_rating(image_path):
    """Classify a single image file and return its category label.

    Relies on the notebook-level ``transform``, ``model``, ``device`` and
    ``categories`` objects.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        The entry of ``categories`` at the index of the largest model output.
    """
    # Open through a context manager so the underlying file handle is released
    # as soon as the pixels have been converted; convert() returns a separate
    # in-memory image that stays valid after the file is closed.
    with Image.open(image_path) as img:
        rgb_img = img.convert('RGB')

    # Add a leading batch dimension of 1 and move the input next to the model.
    img_tensor = transform(rgb_img).unsqueeze(0).to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        output = model(img_tensor)

    predicted_class = torch.argmax(output, dim=1).item()
    return categories[predicted_class]

In [34]:
# Spot-check the classifier on a single image before running the whole folder.
image_path = '../data/not_rated (11).jpg'

predicted_rating = predict_rating(image_path)

print(f'Predicted rating: {predicted_rating}')

Predicted rating: not_rated


In [35]:
# os.listdir returns entries in arbitrary order; sort them so the listing (and
# anything printed from it) is reproducible across runs and machines.
images = sorted(os.listdir('../data'))

In [36]:
# Keep only the .jpg entries before numbering, so the printed index counts the
# images that are actually classified and has no gaps when the folder also
# contains other files or sub-directories.
jpg_images = [name for name in images if name.endswith('.jpg')]
for i, image in enumerate(jpg_images, start=1):
    image_path = os.path.join('../data', image)
    predicted_rating = predict_rating(image_path)
    print(f'Image {i}: image: {image} pred: {predicted_rating}')

Image 1: image: not_rated (10).jpg pred: not_rated
Image 2: image: not_rated (11).jpg pred: not_rated
Image 3: image: not_rated (12).jpg pred: not_rated
Image 4: image: not_rated (13).jpg pred: not_rated
Image 5: image: not_rated (4).jpg pred: not_rated
Image 6: image: not_rated (5).jpg pred: not_rated
Image 7: image: not_rated (6).jpg pred: not_rated
Image 8: image: not_rated (7).jpg pred: not_rated
Image 9: image: not_rated (8).jpg pred: not_rated
Image 10: image: not_rated (9).jpg pred: not_rated
Image 11: image: not_rated_1.jpg pred: not_rated
Image 12: image: not_rated_2.jpg pred: not_rated
Image 13: image: not_rated_3.jpg pred: not_rated
Image 14: image: rated_1.jpg pred: rated
Image 15: image: rated_2.jpg pred: rated
Image 16: image: rated_3.jpg pred: rated
Image 17: image: rated_4.jpg pred: rated
Image 18: image: rated_5.jpg pred: rated
Image 19: image: rated_6.jpg pred: rated
