Written by: Povilas Stašys (1812991)

1. Installing the required packages

In [16]:
pip install openimages torchmetrics

Note: you may need to restart the kernel to use updated packages.


2. Downloading the dataset

In [17]:
import os
from openimages.download import download_dataset
from math import ceil

# Download configuration: target folder, overall sample budget and the
# classes to fetch.
data_dir = "data"
total_samples = 1000
classes = ["Mushroom", "Strawberry", "Orange"]
# Round up so the per-class requests together cover at least total_samples.
samples_per_class = ceil(total_samples / len(classes))

# exist_ok=True makes directory creation idempotent and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(data_dir, exist_ok=True)

print("Download is starting...")
# Kept as the last expression of the cell so the returned mapping
# (class name -> images directory) is displayed as the cell output.
download_dataset(data_dir, classes, limit=samples_per_class)


Download is starting...


2023-02-27  23:03:08 INFO Downloading 325 train images for class 'mushroom'
100%|██████████| 325/325 [00:22<00:00, 14.68it/s]
2023-02-27  23:03:31 INFO Downloading 334 train images for class 'strawberry'
100%|██████████| 334/334 [00:21<00:00, 15.87it/s]
2023-02-27  23:03:52 INFO Downloading 334 train images for class 'orange'
100%|██████████| 334/334 [00:22<00:00, 14.99it/s]
2023-02-27  23:04:16 INFO Downloading 9 validation images for class 'mushroom'
100%|██████████| 9/9 [00:02<00:00,  4.27it/s]


{'mushroom': {'images_dir': 'data\\mushroom\\images'},
 'strawberry': {'images_dir': 'data\\strawberry\\images'},
 'orange': {'images_dir': 'data\\orange\\images'}}

3. Select the compute device (GPU if available, otherwise CPU)

In [18]:
import torch

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(device)

cuda


4. Custom dataset class definition

In [19]:
from torch.utils.data import Dataset
from glob import glob
import numpy as np
from PIL import Image
from torchvision.transforms import transforms

class CustomDataset(Dataset):
  """Labelled image dataset read from ``<images_dir>/<class>/images/*.jpg``.

  Each class name is lower-cased to find its folder, and its position in
  ``class_names`` becomes the integer label of every image in that folder.
  The samples are shuffled once, at construction time.

  Args:
    images_dir: Root directory that contains one sub-folder per class.
    transform: Optional callable applied to the uint8 image tensor produced
      by ``PILToTensor``.
    class_names: Class names in label order. Defaults to the notebook-level
      ``classes`` list, so existing ``CustomDataset(dir, transform)`` calls
      behave as before.
    seed: Optional seed for the shuffle. With ``None`` the global NumPy
      random state is used, as in the original implementation.
  """

  def __init__(self, images_dir, transform=None, class_names=None, seed=None):
    self.images_dir = images_dir
    self.transform = transform
    # Only touch the notebook global when the caller did not pass names.
    self.class_names = list(classes if class_names is None else class_names)

    files = []
    labels = []
    for label, name in enumerate(self.class_names):
      pattern = self.images_dir + "/{}/images/*.jpg".format(name.lower())
      # glob order is filesystem dependent; sorting makes a seeded shuffle
      # reproducible across machines.
      class_files = sorted(glob(pattern))
      files.extend(class_files)
      # Plain int labels (not float64) so they collate into an integer tensor.
      labels.extend([label] * len(class_files))

    if seed is None:
      permutation = np.random.permutation(len(files))
    else:
      permutation = np.random.default_rng(seed).permutation(len(files))

    # Apply the same permutation to paths and labels so pairs stay aligned.
    self.order = [int(x) for x in permutation]
    self.files = [files[x] for x in self.order]
    self.labels = [labels[x] for x in self.order]

  def __len__(self):
    """Return the number of images found across all classes."""
    return len(self.labels)

  def __getitem__(self, i):
    """Return ``(image_tensor, label)`` for the sample at position ``i``."""
    # The context manager closes the file handle as soon as the pixels have
    # been decoded by convert(), instead of leaving it to garbage collection.
    with Image.open(self.files[i]) as img:
      img_tensor = transforms.PILToTensor()(img.convert("RGB"))

    if self.transform is not None:
      img_tensor = self.transform(img_tensor)

    return (img_tensor, self.labels[i])

5. Initialize the model, dataset and dataloader

In [20]:
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import DataLoader

# Evaluation setup: a pretrained ResNet-50 and a loader over the downloaded
# images, preprocessed with the transforms bundled with the weights.
batch_size = 10
weights = ResNet50_Weights.DEFAULT

# Build the network, move it to the selected device and put it in eval mode.
model = resnet50(weights=weights).to(device).eval()

dataset = CustomDataset(images_dir="./data", transform=weights.transforms())
dataloader = DataLoader(dataset, batch_size=batch_size)

6. Run the images through the model and get predictions

In [21]:
# Position of each of our classes in the model's category list; these select
# the matching columns of the model output.
model_class_ids = [weights.meta["categories"].index(c.lower()) for c in classes]
# Row k of the identity matrix is the one-hot target vector for label k.
one_hot = np.eye(len(model_class_ids), dtype=int)

targets = []
predictions = []

# Inference only: without no_grad() every forward pass would record an
# autograd graph and keep activations alive for a backward pass that never runs.
with torch.no_grad():
  for X, y in dataloader:
    pred = model(X.to(device))

    # Slice out the scores of our classes and copy them to the host once per
    # batch, instead of one .item() device sync per value.
    predictions.extend(pred[:, model_class_ids].cpu().tolist())

    # Labels are only used on the CPU, so they are never moved to the device.
    targets.extend(one_hot[y.long().numpy()])

# Release cached GPU memory once at the end rather than after every batch.
torch.cuda.empty_cache()

7. Calculate statistics

In [22]:
import torchmetrics

predictions_tensor = torch.tensor(predictions)
# Stack the per-sample target arrays first: torch.tensor() on a list of
# ndarrays converts element by element, which is slow and emits a warning.
targets_tensor = torch.tensor(np.array(targets))

threshold = 0.9

# NOTE(review): predictions are raw model outputs, not probabilities.
# torchmetrics is documented to apply a sigmoid to float inputs outside
# [0, 1] before thresholding - confirm this is the intended scoring.
metric_kwargs = dict(
  num_labels=targets_tensor.shape[1],  # one label per evaluated class
  threshold=threshold,
  average="micro",
)

accuracy_metric = torchmetrics.classification.MultilabelAccuracy(**metric_kwargs)
accuracy = accuracy_metric(predictions_tensor, targets_tensor).item()

precision_metric = torchmetrics.classification.MultilabelPrecision(**metric_kwargs)
precision = precision_metric(predictions_tensor, targets_tensor).item()

recall_metric = torchmetrics.classification.MultilabelRecall(**metric_kwargs)
recall = recall_metric(predictions_tensor, targets_tensor).item()

f1_metric = torchmetrics.classification.MultilabelF1Score(**metric_kwargs)
f1 = f1_metric(predictions_tensor, targets_tensor).item()

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 score:", f1)

Accuracy: 0.9374583959579468
Precision: 0.8713503479957581
Recall: 0.9530938267707825
F1 score: 0.9103908538818359
