In [1]:
import logging
import os
import re
import shutil
from enum import Enum
from pathlib import Path
from typing import Optional

import timm
import torch
from PIL import Image
from joblib import Parallel, delayed
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.jit import ScriptModule
from torchvision.transforms import Compose, InterpolationMode, Resize, CenterCrop, ToTensor, Normalize
from tqdm.notebook import tqdm

from labels.labels import IMAGENET_1K, IMAGENET_21K


# !wget -O "../../labels/ImageNetLabels.txt" https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt
# !wget -O "../../labels/imagenet21k_wordnet_lemmas.txt" https://storage.googleapis.com/bit_models/imagenet21k_wordnet_lemmas.txt

class LogConstants(Enum):
    """Format strings used when configuring the notebook's log handlers.

    Members are plain strings; read them through ``.value``.
    """

    # Record layout: timestamp with milliseconds, level name, a tab, then the message.
    FMT = "%(asctime)s.%(msecs)03d %(levelname)s:\t%(message)s"
    # strftime pattern used to render the ``asctime`` field of FMT.
    DT_FMT = "%Y-%m-%d %H:%M:%S"


# Create the log directory if needed; exist_ok avoids the check-then-create race.
os.makedirs('logs', exist_ok=True)

# force=True removes any handlers already attached to the root logger before
# configuring it. Without it, basicConfig is a silent no-op whenever the root
# logger already has handlers (e.g. when this cell is re-run), and every re-run
# would stack one more console handler below, duplicating each log line.
logging.basicConfig(
    filename='logs/tested_models.log',
    filemode='a',
    level=logging.INFO,
    format=LogConstants.FMT.value,
    datefmt=LogConstants.DT_FMT.value,
    force=True)

# Mirror the log records to the notebook output using the same layout as the file.
console = logging.StreamHandler()
console.setLevel(logging.INFO)
formatter = logging.Formatter(fmt=LogConstants.FMT.value,
                              datefmt=LogConstants.DT_FMT.value)
console.setFormatter(formatter)
logging.getLogger().addHandler(console)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)
torch.backends.cudnn.benchmark = True

# Directory holding the images to label, and directory holding the label files.
IMAGE_FILES_PATH = Path('../../data/images/')
LABELS_PATH = Path('../../labels/')


# timm.list_models('*21k*')

cuda:0


In [2]:
def run(model: ScriptModule,
        filename: Path,
        classes: dict,
        target_classes: set,
        torchscript: bool,
        input_dim: Optional[int] = 224) -> None:
    """Classify one image and move it into a per-label folder if the label is wanted.

    The image is resized and centre-cropped to ``input_dim``, normalised with the
    ImageNet default mean/std, and passed through ``model``. The index of the
    highest output score is looked up in ``classes``; if that label is in
    ``target_classes`` the file is moved to ``<image dir>/<sanitised label>/``.

    Args:
        model: Callable model taking a batch of one image tensor.
        filename: Path of the image file to classify.
        classes: Mapping from predicted class index to label string.
        target_classes: Labels that should trigger a move; other images stay put.
        torchscript: When true the input batch is left on the CPU; otherwise it is
            moved to the module-level ``device``.
        input_dim: Side length used for both the resize and the centre crop.

    Raises:
        KeyError: If the predicted index is missing from ``classes``.
        OSError: If the image cannot be opened or the file cannot be moved.
    """
    transform = Compose([
        Resize(input_dim, InterpolationMode.LANCZOS),
        CenterCrop(input_dim),
        ToTensor(),
        Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])
    # Close the file handle deterministically, and force three channels: the
    # normalisation above uses 3-channel statistics, so grayscale, palette or
    # RGBA files would otherwise reach it with the wrong channel count.
    with Image.open(filename) as image:
        batch = transform(image.convert('RGB'))[None,]
    if not torchscript:
        batch = batch.to(device)

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        scores = model(batch)
    pred_label = classes[torch.argmax(scores).item()]
    if pred_label not in target_classes:
        return

    # Folder name: first comma-separated part of the label, with every run of
    # non-word characters collapsed to an underscore.
    folder_name = re.sub(r'[\W\s]+', '_', pred_label.split(',')[0])
    destination = filename.parent / folder_name
    # Several images usually share a label, so the folder may already exist.
    destination.mkdir(exist_ok=True)
    filename.rename(destination / filename.name)


def label(model: ScriptModule,
          root_dir: Path,
          class_map: dict,
          n_jobs: Optional[int] = -1,
          target_classes: Optional[str] = '',
          parallel: Optional[bool] = False,
          torchscript: Optional[bool] = False,
          input_dim: Optional[int] = 224) -> None:
    """Run ``run`` on every non-directory entry directly inside ``root_dir``.

    Args:
        model: Model handed to ``run`` for each image.
        root_dir: Directory whose immediate non-directory entries are classified.
        class_map: Mapping from class index to label, handed to ``run``.
        n_jobs: Worker count passed to joblib when ``parallel`` is true.
        target_classes: Path to a text file with one wanted label per line. When
            empty, every label in ``class_map`` is treated as wanted.
        parallel: Dispatch the images through a joblib thread pool instead of a
            plain loop.
        torchscript: Forwarded to ``run``.
        input_dim: Forwarded to ``run``.
    """
    images = [entry for entry in Path(root_dir).iterdir() if not entry.is_dir()]

    # Always hand ``run`` a real set: the membership test happens once per image,
    # and a dict_values view would be scanned linearly each time.
    if target_classes:
        wanted = set(Path(target_classes).read_text().splitlines())
    else:
        wanted = set(class_map.values())

    progress = tqdm(images, total=len(images))
    if parallel:
        Parallel(n_jobs=n_jobs, prefer='threads')(
            delayed(run)(model, img, class_map, wanted, torchscript, input_dim)
            for img in progress)
    else:
        # Plain loop: the calls are made for their side effects only.
        for img in progress:
            run(model, img, class_map, wanted, torchscript, input_dim)


def unlabel(root_dir: str) -> None:
    """Move files out of the sub-directories of ``root_dir`` back into ``root_dir``.

    Each immediate sub-directory is emptied of its files, which are moved one
    level up, and the sub-directory is removed once it is empty.

    Nothing is ever deleted or overwritten:
      * an entry that is itself a directory is left where it is;
      * a file whose name already exists in ``root_dir`` is left where it is;
      * a sub-directory that still has content afterwards is kept.
    Each of these cases is reported with a warning.

    Args:
        root_dir: Directory containing the per-label sub-directories.
    """
    root = Path(root_dir)
    # Snapshot the sub-directories first so that nothing moved during the loop
    # can be mistaken for a label directory.
    label_dirs = [entry for entry in root.iterdir() if entry.is_dir()]
    for label_dir in label_dirs:
        for entry in list(label_dir.iterdir()):
            target = root / entry.name
            if entry.is_dir():
                logging.warning('Skipping nested directory %s', entry)
                continue
            if target.exists():
                logging.warning('Not moving %s: %s already exists', entry, target)
                continue
            entry.rename(target)
        try:
            # rmdir only succeeds on an empty directory, so leftovers survive.
            label_dir.rmdir()
        except OSError:
            logging.warning('Keeping %s: it is not empty', label_dir)


def targets_found(root_dir: Optional[Path] = None) -> dict:
    """Count the entries inside each sub-directory of an image directory.

    Args:
        root_dir: Directory to inspect. Defaults to the module-level
            ``IMAGE_FILES_PATH`` when omitted, so existing zero-argument calls
            behave as before.

    Returns:
        A dict mapping each sub-directory name to the number of entries directly
        inside it, ordered by sorted sub-directory path.
    """
    # Resolve the default lazily so an explicit root_dir never touches the global.
    root = Path(IMAGE_FILES_PATH if root_dir is None else root_dir)
    label_dirs = sorted(entry for entry in root.iterdir() if entry.is_dir())
    return {label_dir.name: sum(1 for _ in label_dir.iterdir()) for label_dir in label_dirs}

# Load a single model

### Load built-in model

In [3]:
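# NOTE: `models` is not imported in this notebook; add `from torchvision import models` before enabling this line.
# model = models.efficientnet_b7(pretrained=True).to(device).eval()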

### Load `timm` model

In [4]:
# MODEL_NAME = 'regnetz_e8'
# model = timm.create_model(MODEL_NAME, pretrained=True).to(device).eval()

### Load scripted/quantized model

In [5]:
# MODEL_NAME = 'vit_base_patch8_224_scripted_quantized.pt'
# model = torch.jit.load(MODEL_NAME).eval()

# Load multiple models

### Sample 1K Models

In [12]:
# Each entry pairs a timm model name with the square input size it is fed at.
# A single hard-coded 384 would also be applied to the 224-pixel model below
# once it is uncommented, which is presumably not what that model expects
# (TODO confirm against the timm model config).
for model_name, input_dim in [
    ('cait_s24_384', 384),
    # ('regnetz_e8', 384),
    # ('vit_base_patch16_384', 384),
    # ('vit_base_patch8_224', 224),
]:
    model = timm.create_model(model_name, pretrained=True)
    label(model.to(device).eval(), IMAGE_FILES_PATH, IMAGENET_1K, input_dim=input_dim,
          target_classes='target_classes.txt')
    # Record how many images ended up in each label folder after this model ran.
    logging.info({model_name: targets_found()})

In [7]:
# unlabel(IMAGE_FILES_PATH)

### Sample 21K Models

In [11]:
# 'mixer_l16_224_in21k' ## keeps downloading
# ImageNet-21k checkpoints to try; uncomment an entry to include it in the run.
for model_name in (
    # 'resnetv2_101x1_bitm_in21k',
    # 'resnetv2_50x3_bitm_in21k',
    # 'resnetv2_50x1_bitm_in21k',
    # 'vit_base_patch8_224_in21k',
    # 'vit_small_r26_s32_224_in21k',
    # 'vit_base_patch16_224_in21k',
    # 'vit_tiny_patch16_224_in21k',
    # 'vit_tiny_r_s16_p8_224_in21k',
    'vit_small_patch32_224_in21k',
    # 'vit_base_patch32_224_in21k',
    # 'vit_small_patch16_224_in21k'
):
    model = timm.create_model(model_name, pretrained=True)
    prepared_model = model.to(device).eval()
    label(prepared_model, IMAGE_FILES_PATH, IMAGENET_21K, target_classes='target_classes.txt')
    # Record how many images ended up in each label folder after this model ran.
    counts = targets_found()
    logging.info({model_name: counts})

In [10]:
# unlabel(IMAGE_FILES_PATH)