In [1]:
import shutil
from pathlib import Path

import torch
from torchvision.transforms import Compose, InterpolationMode, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from joblib import Parallel, delayed
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.jit import ScriptModule
from tqdm.notebook import tqdm

# Download the labels to the path read below (one label per line):
# !wget -O "labels/ImageNetLabels.txt" https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt

# Keys are shifted by one (first line -> -1, second line -> 0, ...), so model output index k maps to
# line k + 1 of the file. Presumably the first line is a non-class "background" entry that the
# classifier never predicts — confirm against the downloaded file.
CLASS_MAP = dict(enumerate(Path('labels/ImageNetLabels.txt').read_text(encoding='utf-8').splitlines(), start=-1))
IMAGE_FILES_PATH = 'images/'

# Evaluation preprocessing: resize the shorter side to 256 px (Lanczos), center-crop to 224x224,
# convert to a float tensor and normalise with timm's default ImageNet mean/std.
transform = Compose([
    Resize(256, interpolation=InterpolationMode.LANCZOS),
    CenterCrop(224),
    ToTensor(),
    Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

In [2]:
# Load the TorchScript export onto the CPU (the tensors produced by `transform` are CPU tensors)
# and switch to eval mode so training-only behaviour such as dropout is disabled for inference.
model = torch.jit.load("deit_small_distilled_patch16_224_sq.pt", map_location="cpu").eval()

In [3]:
def run(model: ScriptModule, filename: Path, classes: dict) -> None:
    """Classify a single image and move it into a sub-folder named after the predicted class.

    The sub-folder is ``filename.parent / <label>``, with spaces in the label replaced by
    underscores; it is created if it does not exist yet.

    Args:
        model: TorchScript classifier, called on a batch containing one preprocessed image.
        filename: Path to the image file. The file is moved, not copied.
        classes: Mapping from predicted class index to a human-readable label.
    """
    # Force 3-channel RGB: grayscale, palette, RGBA or CMYK inputs would otherwise produce a
    # channel count that does not match the RGB mean/std used by `transform`'s Normalize.
    # The `with` block also closes the file handle before the file is renamed below.
    with Image.open(filename) as image:
        batch = transform(image.convert('RGB')).unsqueeze(0)  # add a batch dimension of 1
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        scores = model(batch)
    # argmax over the flattened output; assuming the model returns (1, num_classes) scores,
    # this is the predicted class index — TODO confirm output shape of the export.
    top_pred = torch.argmax(scores)
    pred_label = Path(classes[top_pred.item()].replace(' ', '_'))
    target_dir = filename.parent / pred_label
    # exist_ok: label() runs this from several threads, which may create the same folder at once.
    target_dir.mkdir(exist_ok=True)
    filename.rename(target_dir / filename.name)


def label(model: ScriptModule, root_dir: Path, class_map: dict) -> None:
    """Classify every file directly inside ``root_dir`` and move it into a per-class sub-folder.

    Args:
        model: TorchScript classifier passed through to ``run``.
        root_dir: Directory containing the images (a ``str`` or ``Path``).
        class_map: Mapping from predicted class index to label, passed through to ``run``.
    """
    # Only plain files: class sub-folders left over from a previous run would otherwise be
    # handed to Image.open and abort the whole batch.
    images = [p for p in Path(root_dir).iterdir() if p.is_file()]
    # Use the thread backend: the original author notes the default backend freezes here
    # (presumably because the ScriptModule would have to be shipped to worker processes).
    Parallel(n_jobs=-1, prefer='threads')(
        delayed(run)(model, img, class_map) for img in tqdm(images))

In [4]:
def unlabel(root_dir: str) -> None:
    """Undo ``label``: move every entry of each sub-folder of ``root_dir`` back into ``root_dir``,
    then remove those (now empty) sub-folders.

    Args:
        root_dir: Directory whose immediate sub-folders are flattened.

    Raises:
        FileExistsError: if an entry would overwrite something already in ``root_dir``.
        OSError: if a sub-folder is unexpectedly not empty when it is removed.
    """
    root = Path(root_dir)
    # Snapshot the sub-folders BEFORE moving anything. Re-scanning afterwards would also pick up
    # directories that were nested inside a class folder and have just been moved up, and
    # deleting those would destroy their contents.
    class_dirs = [p for p in root.iterdir() if p.is_dir()]
    for class_dir in class_dirs:
        # Materialise the listing before renaming entries out of the directory being listed.
        for entry in list(class_dir.iterdir()):
            target = root / entry.name
            # Path.rename silently replaces an existing file on POSIX; refuse instead of losing data.
            if target.exists():
                raise FileExistsError(f'{target} already exists; not overwriting it with {entry}')
            entry.rename(target)
    for class_dir in class_dirs:
        # rmdir rather than rmtree: each folder should be empty now, and if it is not,
        # failing loudly is better than silently deleting files.
        class_dir.rmdir()

In [5]:
# Sort every image in IMAGE_FILES_PATH into a sub-folder named after its predicted class (files are moved).
label(model, root_dir=IMAGE_FILES_PATH, class_map=CLASS_MAP)

In [7]:
# Uncomment to undo the labelling: moves the images back into IMAGE_FILES_PATH and removes the class sub-folders.
# unlabel(IMAGE_FILES_PATH)