In [1]:
import os
from tqdm import tqdm
from PIL import Image

import torch
from torch.nn import Conv2d, MaxPool2d, ReLU, Linear
from torch.utils.data import DataLoader, random_split

from torchvision.models import regnet_y_32gf, RegNet_Y_32GF_Weights
from torchvision.datasets import ImageFolder, ImageNet

In [2]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# `classes` holds one tuple of synonym strings per ImageNet category; a
# category's position in this list is used below as its logit index.
classes = ImageNet(root='~/PycharmProjects/pbnn/data/torch/ImageNet').classes

# Dog categories: every category with at least one synonym containing the
# substring 'dog', minus the final match.
# NOTE(review): the trailing [:-1] presumably discards a non-animal category
# that happens to contain 'dog' — confirm by printing the list. A substring
# match may also keep other non-dog categories and cannot find breeds whose
# names lack the word 'dog'; verify the selection is what is intended.
dog_classes = list(filter(lambda names: any('dog' in name for name in names), classes))
dog_classes = dog_classes[:-1]

# Cat categories are listed by hand. Each tuple must equal the corresponding
# entry of `classes` exactly, otherwise the index lookup below raises ValueError.
# NOTE(review): the last entry lists 'ring-tailed lemur' among its synonyms —
# confirm it is meant to count as a cat.
cat_classes = [
    ('tabby', 'tabby cat'),
    ('tiger cat',),
    ('Persian cat',),
    ('Siamese cat', 'Siamese'),
    ('Egyptian cat',),
    ('cougar', 'puma', 'catamount', 'mountain lion', 'painter', 'panther', 'Felis concolor'),
    ('lynx', 'catamount'),
    ('Madagascar cat', 'ring-tailed lemur', 'Lemur catta'),
]

# Translate the selected synonym tuples into positions within `classes`.
dog_indices = list(map(classes.index, dog_classes))
cat_indices = list(map(classes.index, cat_classes))

In [4]:
# Build RegNet-Y-32GF with the IMAGENET1K_SWAG_E2E_V1 pretrained weights, then
# switch it to evaluation mode and move it onto the selected device.
model = regnet_y_32gf(weights=RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1)
model = model.eval().to(device)

In [5]:
# Measure how well the pretrained ImageNet classifier separates the labelled
# images under 'train': an image is predicted "dog" when its best dog-class
# logit exceeds its best cat-class logit.
dataset = ImageFolder(root='train', transform=RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1.transforms())
dataloader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=8, pin_memory=True)
correct = 0  # number of correctly classified images so far
seen = 0     # number of images processed so far
td = tqdm(dataloader)
# Pure inference: disable autograd so no graph is built or kept in memory.
with torch.no_grad():
    for x, y in td:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        # True -> dog, False -> cat.
        # NOTE(review): comparing against `y` assumes ImageFolder assigned
        # label 1 to dogs and 0 to cats — verify with dataset.class_to_idx.
        y_pred = logits[:, dog_indices].max(dim=1).values > logits[:, cat_indices].max(dim=1).values
        correct += torch.sum(y_pred == y).item()
        # Count actual samples so the running figure stays right even when the
        # final batch is smaller than batch_size.
        seen += x.shape[0]
        td.set_description(f'Accuracy: {correct / seen}')
# Final accuracy as a fraction of the whole dataset.
accuracy = correct / len(dataloader.dataset)

In [8]:
# Sort the unlabelled '*jpg' files directly under `path` into `path`/dog and
# `path`/cat according to the model's prediction (best dog-class logit versus
# best cat-class logit).
path = 'test'
transform = RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1.transforms()
dog_dir = os.path.join(path, 'dog')
cat_dir = os.path.join(path, 'cat')
os.makedirs(dog_dir, exist_ok=True)
os.makedirs(cat_dir, exist_ok=True)
# Start the tallies from whatever is already in the target folders, so that
# resuming after an interrupted run reports correct totals without hand-edited
# starting numbers.
dogs = sum(name.endswith('jpg') for name in os.listdir(dog_dir))
cats = sum(name.endswith('jpg') for name in os.listdir(cat_dir))
# The listing is taken once, before any file is moved.
files = tqdm(os.listdir(path))
# Pure inference: disable autograd so no graph is built or kept in memory.
with torch.no_grad():
    for img_file in files:
        # Skips the 'dog'/'cat' sub-directories and any non-jpg entries.
        if not img_file.endswith('jpg'):
            continue
        src = os.path.join(path, img_file)
        # Close the file handle before renaming, and force 3-channel RGB so
        # grayscale / palette / RGBA images go through the same preprocessing.
        with Image.open(src) as img:
            x = transform(img.convert('RGB'))
        x = x.unsqueeze(0).to(device)  # add the batch dimension
        logits = model(x)
        y_pred = logits[:, dog_indices].max(dim=1).values > logits[:, cat_indices].max(dim=1).values
        if y_pred.item():
            os.rename(src, os.path.join(dog_dir, img_file))
            dogs += 1
        else:
            os.rename(src, os.path.join(cat_dir, img_file))
            cats += 1
        files.set_description(f'Dogs: {dogs}, Cats: {cats}')

Dogs: 6192, Cats: 6308: 100%|██████████| 9967/9967 [06:41<00:00, 24.83it/s]
