In [1]:
#!/usr/bin/env python3
from pathlib import Path

import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

In [2]:
# Preprocessing pipeline applied to every PIL image before inference:
#   1. resize to exactly 224x224 pixels,
#   2. convert the image to a tensor,
#   3. normalize each of the three channels with a fixed mean / std.
# NOTE(review): the mean/std values match the ImageNet statistics commonly
# used with torchvision models; presumably the checkpoint was trained with
# the same preprocessing — confirm before changing any of these numbers.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [3]:
class OrientationDetection:
    angles = [0, 90, 180, 270]  # classes
    _model = None

    def __init__(self, checkpoint_path: str, device: str | None = None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.checkpoint_path = checkpoint_path

    @staticmethod
    def read_image(image_path: str) -> Image.Image:
        return Image.open(image_path).convert("RGB")
    
    @property
    def model(self):
        if self._model is None:
            num_classes = len(self.angles)
            # define the model architecture (ResNet152) 
            # and update the final layer for classes
            model = models.resnet152(weights=None)
            model.fc = nn.Linear(model.fc.in_features, num_classes)
            state_dict = torch.load(self.checkpoint_path, map_location=self.device)
            model.load_state_dict(state_dict)
            # upload the model to device
            model.to(self.device)
            # set the model to  mode
            model.eval()
            self._model = model
        return self._model

    def to_tensor(self, image: Image):
        image = image.convert("RGB")
        tensor_image = transform(image).unsqueeze(0).to(self.device)
        return tensor_image

    def get_angles(self, tensor_image: torch.Tensor) -> dict:
        # Perform inference and compute probabilities
        with torch.no_grad():
            # Define the rotation angles corresponding to each class
            outputs = self.model(tensor_image)
            probabilities = nn.functional.softmax(outputs, dim=1).cpu().numpy()[0]

            angles = {angle: score
                      for angle, score in zip(self.angles, probabilities)}
            return angles

    def get_best_angle(self, tensor_image: torch.Tensor):
        angles = self.get_angles(tensor_image)
        best_angle = max(angles, key=angles.get)
        return best_angle

    def __call__(self, image_path: Path):
        return self.get_best_angle(self.to_tensor(self.read_image(image_path)))

In [4]:
def recursive_iterdir(path: Path):
    """Walk ``path`` recursively and yield every entry beneath it.

    ``path`` may be anything ``Path`` accepts. Both files and directories
    are yielded; the contents of a directory are produced before the
    directory itself. The root ``path`` is never yielded.
    """
    root = Path(path)
    for entry in root.iterdir():
        if entry.is_dir():
            # Descend first so children come out ahead of their parent.
            for nested in recursive_iterdir(entry):
                yield nested
        yield entry

In [5]:
if __name__ == "__main__":
    # Sanity check: rotate every image under ``images`` through each known
    # angle and report whether the model recovers the applied rotation.
    checkpoint_path = "resnet152-ixion-epoch-1.pth"
    od = OrientationDetection(checkpoint_path=checkpoint_path)

    for path in recursive_iterdir("images"):
        # recursive_iterdir also yields the directories themselves, which
        # cannot be opened as images.
        if not path.is_file():
            continue

        try:
            source = Image.open(path)
        except OSError as err:
            # Unreadable or non-image file: report it and keep going rather
            # than aborting the whole run.
            print(f"Skipping {path}: {err}\n")
            continue

        # Keep the file open while its pixels may still be lazily loaded and
        # close it deterministically once all rotations have been checked.
        with source:
            img = source
            print(f"Checking {path}\n")

            for a in od.angles:
                # Each step after the first adds one more 90° turn, so the
                # image is cumulatively rotated to match ``a``. This relies
                # on ``od.angles`` being consecutive multiples of 90.
                if a != 0:
                    img = img.transpose(Image.Transpose.ROTATE_90)

                tensor_image = od.to_tensor(img)
                angles = od.get_angles(tensor_image)

                # Derived from the scores already computed, so the model is
                # run only once per rotation.
                best_angle = max(angles, key=angles.get)
                best_score = angles[best_angle] * 100

                print(f"Orientation probabilities of {path}:")
                for angle, prob in angles.items():
                    prob = prob * 100
                    print(f"{angle}°: {prob:.2f}%")

                if best_angle == a:
                    print(f"Image's orientation is correct ({a}°): {best_score:.2f}%")
                else:
                    print(f"Failure. Model detected {best_angle}° for {best_score:.2f}%, but the correct orientation is {a}°")
                print()

Checking images/image_1.jpg
Orientation probabilities of images/image_1.jpg:
0°: 100.00%
90°: 0.00%
180°: 0.00%
270°: 0.00%
Image's orientation is correct (0°): 100.00%

Orientation probabilities of images/image_1.jpg:
0°: 0.03%
90°: 99.93%
180°: 0.03%
270°: 0.01%
Image's orientation is correct (90°): 99.93%

Orientation probabilities of images/image_1.jpg:
0°: 0.00%
90°: 0.00%
180°: 99.99%
270°: 0.00%
Image's orientation is correct (180°): 99.99%

Orientation probabilities of images/image_1.jpg:
0°: 0.00%
90°: 0.00%
180°: 0.01%
270°: 99.99%
Image's orientation is correct (270°): 99.99%

Checking images/image_2.jpg

Orientation probabilities of images/image_2.jpg:
0°: 99.41%
90°: 0.05%
180°: 0.09%
270°: 0.45%
Image's orientation is correct (0°): 99.41%

Orientation probabilities of images/image_2.jpg:
0°: 1.32%
90°: 95.95%
180°: 2.19%
270°: 0.54%
Image's orientation is correct (90°): 95.95%

Orientation probabilities of images/image_2.jpg:
0°: 0.13%
90°: 0.02%
180°: 99.81%
270°: 0.04%
