In [3]:
from PIL import Image
import os

image_folder = "images"
image_files = [os.path.join(image_folder, f) for f in os.listdir(image_folder)]

images = [Image.open(f).convert("RGB") for f in image_files]

In [4]:
import torch
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

input_tensors = torch.stack([preprocess(img) for img in images])

In [5]:
from torchvision import models

model = models.resnet50(pretrained=True)
model.eval()



Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/sibouzitoun/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:28<00:00, 3.57MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [6]:
with torch.no_grad():
    outputs = model(input_tensors)
    probabilities = torch.nn.functional.softmax(outputs, dim=1)

In [7]:
import json
import urllib.request

# Download ImageNet class labels
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
imagenet_classes = [line.strip() for line in urllib.request.urlopen(url)]

In [8]:
topk = 5

for i, probs in enumerate(probabilities):
    print(f"\nImage {i+1}: {image_files[i]}")
    top_probs, top_idxs = torch.topk(probs, topk)
    for prob, idx in zip(top_probs, top_idxs):
        print(f"  {imagenet_classes[idx]}: {prob.item():.4f}")


Image 1: images/cat.jpg
  b'weasel': 0.1682
  b'Egyptian cat': 0.1134
  b'mink': 0.0740
  b'tabby': 0.0427
  b'lynx': 0.0312

Image 2: images/plane.jpeg
  b'projectile': 0.1927
  b'space shuttle': 0.1431
  b'missile': 0.1133
  b'warplane': 0.0956
  b'balloon': 0.0490

Image 3: images/dog.jpg
  b'whiptail': 0.1158
  b'Saluki': 0.1007
  b'Labrador retriever': 0.0844
  b'banded gecko': 0.0623
  b'golden retriever': 0.0522

Image 4: images/flower.jpeg
  b'daisy': 0.3939
  b'hair slide': 0.0232
  b'vase': 0.0220
  b'ant': 0.0216
  b'sea anemone': 0.0164
