In [12]:
import os
import shutil

import pandas as pd
import torch
from PIL import Image, ImageOps
from torch import nn
from torchvision import models
from torchvision.transforms import v2

In [13]:
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("Using device:", device)

Using device: cuda


In [14]:
# All three locations sit under the same workflow directory, so the shared
# root is spelled once and the individual paths are derived from it.
_WORKFLOW_ROOT = r"D:\ASL_Project\workflow"

# Checkpoint loaded into the classifier.
MODEL_PATH = _WORKFLOW_ROOT + r"\RegNet\RegNet_model.pth"
# Folder scanned for images to classify.
TEST_IMAGE_FOLDER = _WORKFLOW_ROOT + r"\YOLOv8\Dataset\test_images"
# Parent folder that receives the sorted copies.
OUTPUT_FOLDER = _WORKFLOW_ROOT + r"\RegNet\Classified Images"

In [15]:
# Default pretrained-weights entry for RegNetY-800MF.
weights = models.RegNet_Y_800MF_Weights.DEFAULT
# Preprocessing preset bundled with those weights; it is reused below as the
# second stage of the inference transform pipeline.
transform = weights.transforms()

class LetterboxPad:
    """Callable transform that letterboxes an image onto a black square.

    The square's side is the longer edge of the input, and the original
    content is centered inside it via ``ImageOps.pad``.
    """

    def __call__(self, img):
        """Return ``img`` padded to ``(side, side)`` where ``side = max(width, height)``."""
        side = max(img.size)
        return ImageOps.pad(img, (side, side), color="black", centering=(0.5, 0.5))

# Inference preprocessing: square-pad first, then apply the weights' preset
# transform (presumably so the preset's resizing sees a square input and does
# not distort or cut off the subject -- confirm against the training pipeline).
test_transform = v2.Compose([
    LetterboxPad(),
    transform,
])

# Build the bare architecture only. Every parameter and buffer is replaced by
# the fine-tuned checkpoint below (load_state_dict is strict by default), so
# requesting pretrained weights here would just trigger a needless download.
model = models.regnet_y_800mf(weights=None)

# Replace the classification head with a single-logit output; the logit is
# turned into a probability with a sigmoid at prediction time.
model.fc = nn.Linear(model.fc.in_features, 1)

# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs, instead of executing arbitrary pickled objects.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

RegNet(
  (stem): SimpleStemIN(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (trunk_output): Sequential(
    (block1): AnyStage(
      (block1-0): ResBottleneckBlock(
        (proj): Conv2dNormActivation(
          (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (f): BottleneckTransform(
          (a): Conv2dNormActivation(
            (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (b): Conv2dNormActivation(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=4, bias=False)
            

In [16]:
# One destination folder per outcome, both located directly under OUTPUT_FOLDER.
approved_dir, rejected_dir = (
    os.path.join(OUTPUT_FOLDER, outcome) for outcome in ("approved", "rejected")
)

# Create both folders up front; exist_ok keeps re-running this cell harmless.
for _outcome_dir in (approved_dir, rejected_dir):
    os.makedirs(_outcome_dir, exist_ok=True)

In [17]:
def predict_image(image_path, model, transform, threshold=0.5):
    """Classify a single image file with a one-logit binary model.

    The model's train/eval mode is not changed here; callers are expected to
    have put it in the mode they want.

    Args:
        image_path: Path of the image file to open.
        model: ``torch.nn.Module`` producing a single logit for a batch of one.
        transform: Callable that maps an RGB PIL image to an unbatched tensor.
        threshold: Probability strictly above which the prediction is 1.
            Defaults to 0.5, the cut-off that used to be hard-coded.

    Returns:
        ``(prob, prediction)``: the sigmoid probability as a float and the
        thresholded class as an int (0 or 1).
    """
    # The context manager releases the file handle deterministically;
    # convert() yields an independent in-memory image that outlives it.
    with Image.open(image_path) as source_image:
        rgb_image = source_image.convert("RGB")

    # Run on whatever device the model lives on rather than relying on a
    # notebook-level global; a parameter-less model falls back to the CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    # Add the batch dimension the model expects.
    batch = transform(rgb_image).unsqueeze(0).to(target_device)

    with torch.no_grad():
        prob = torch.sigmoid(model(batch))
        prediction = (prob > threshold).int().item()

    return prob.item(), prediction

In [18]:
# File types treated as images; matched case-insensitively on the file name.
image_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".webp")

# Single lookup from predicted label to the folder that receives the copy.
output_dirs = {"approved": approved_dir, "rejected": rejected_dir}

# os.listdir returns entries in arbitrary order; sorting keeps the processing
# order and the printed log reproducible between runs and machines.
for filename in sorted(os.listdir(TEST_IMAGE_FOLDER)):
    image_path = os.path.join(TEST_IMAGE_FOLDER, filename)

    # Skip other file types, and directories that merely carry an image-like name.
    if not filename.lower().endswith(image_extensions) or not os.path.isfile(image_path):
        continue

    prob, pred = predict_image(image_path, model, test_transform)

    # A prediction of 1 is treated as "rejected"; anything else is "approved".
    label = "rejected" if pred == 1 else "approved"
    destination_path = os.path.join(output_dirs[label], filename)

    # Copy rather than move, so the test folder is left untouched.
    shutil.copy2(image_path, destination_path)

    print(f"{filename} --> {label} (prob={prob:.4f})")

print("\nAll images processed and sorted!")

cat_dog.png --> rejected (prob=0.9221)
test1.jpg --> approved (prob=0.0001)
test2.jpg --> rejected (prob=1.0000)
test3.jpg --> rejected (prob=0.9960)
test4.jpg --> rejected (prob=1.0000)
test5.jpg --> rejected (prob=0.9990)

All images processed and sorted!
