## test model


In [20]:
import cv2
import numpy as np
import pandas as pd
from PIL import Image

import torch
from transformers import Swinv2ForImageClassification
from transformers.image_transforms import to_channel_dimension_format

In [2]:
# Load the Swin V2 image-classification model from the local "./tagger-hf" directory.
model = Swinv2ForImageClassification.from_pretrained("./tagger-hf")

In [3]:
# Open the sample picture that the rest of the notebook runs the model on.
image = Image.open("sample.jpg")

In [4]:
def load_labels(
    csv_path: str = "./tagger/selected_tags.csv",
) -> tuple[list[str], list[int], list[int], list[int]]:
    """Read the tag table and split its rows by tag category.

    Args:
        csv_path: Path of the CSV file to read. It must provide a ``name``
            column and a numeric ``category`` column. Defaults to the
            location this notebook has always used.

    Returns:
        A 4-tuple ``(tag_names, rating_indexes, general_indexes,
        character_indexes)``:

        * ``tag_names`` - every value of the ``name`` column, in file order.
        * ``rating_indexes`` - positions of the rows whose category is 9.
        * ``general_indexes`` - positions of the rows whose category is 0.
        * ``character_indexes`` - positions of the rows whose category is 4.

        The indexes are 0-based row positions returned as plain Python ints.
    """
    df = pd.read_csv(csv_path)

    tag_names = df["name"].tolist()

    # Work on the raw array so the result is positional (row order), and use
    # .tolist() so callers get built-in ints instead of numpy scalars.
    category = df["category"].to_numpy()
    rating_indexes = np.flatnonzero(category == 9).tolist()
    general_indexes = np.flatnonzero(category == 0).tolist()
    character_indexes = np.flatnonzero(category == 4).tolist()
    return tag_names, rating_indexes, general_indexes, character_indexes

# Load the label table once; `general_indexes` is used further down to keep only general tags.
tag_names, rating_indexes, general_indexes, character_indexes = load_labels()

In [5]:
# ref: https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py#L124
def set_white_bg(img: Image.Image):
    """Flatten ``img`` onto a white background and return the result as RGB.

    The input is converted to RGBA, pasted onto a white canvas of the same
    size using itself as the paste mask, and the canvas is then converted
    to RGB. The image passed in is not modified.
    """
    rgba = img.convert("RGBA")
    canvas = Image.new("RGBA", rgba.size, "WHITE")
    # The RGBA image doubles as its own mask, so the white canvas shows
    # through wherever the image is not fully opaque.
    canvas.paste(rgba, mask=rgba)
    return canvas.convert("RGB")

In [21]:
# Flatten the sample onto white. NOTE: this rebinds `image`, so the originally opened
# picture is no longer reachable under that name after this cell runs.
image = set_white_bg(image)

In [22]:
# ref: https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py#L130
def preprocess(img: Image.Image, size: int):
    """Convert an image into a normalised, batched, channels-first array.

    Steps, in order:

    1. Reverse the channel axis (PIL RGB -> OpenCV-style BGR).
    2. Pad with white to a square whose side is the larger of the image's
       longest edge and ``size``, keeping the picture centred.
    3. Resize the square to ``size`` x ``size`` if it is not already.
    4. Map pixel values with ``(x - 127.5) / 127.5`` as float32
       (0..255 becomes -1..1, assuming 8-bit input).
    5. Move channels first and prepend a batch axis of length 1.

    Args:
        img: Image to convert; indexed as height x width x channels.
        size: Side length, in pixels, of the square output.

    Returns:
        A float32 array with a leading batch axis of 1 and channels first.
    """
    # PIL RGB to OpenCV BGR
    bgr = np.asarray(img)[:, :, ::-1]

    height, width = bgr.shape[:2]
    # Never pad to less than `size`; small images get a white frame.
    side = max(height, width, size)

    pad_h = side - height
    pad_w = side - width
    top = pad_h // 2
    left = pad_w // 2

    # Any odd leftover pixel goes to the bottom / right edge.
    squared = cv2.copyMakeBorder(
        bgr,
        top,
        pad_h - top,
        left,
        pad_w - left,
        cv2.BORDER_CONSTANT,
        value=[255, 255, 255],
    )

    current = squared.shape[0]
    if current != size:
        # Area interpolation when shrinking, cubic when enlarging. Because the
        # padded square is at least `size` wide, the cubic case is not expected
        # to trigger here; it is kept to mirror the referenced implementation.
        interpolation = cv2.INTER_AREA if current > size else cv2.INTER_CUBIC
        squared = cv2.resize(squared, (size, size), interpolation=interpolation)

    # cast
    scaled = (squared.astype(np.float32) - 127.5) * (1 / 127.5)
    channels_first = to_channel_dimension_format(scaled, "channels_first")

    return np.expand_dims(channels_first, 0)

# Smoke-test the preprocessing at 448 px, the same size passed to it when building the model inputs.
preprocess(image, 448)

array([[[[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]],

        [[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]],

        [[1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         ...,
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.],
         [1., 1., 1., ..., 1., 1., 1.]]]], dtype=float32)

In [23]:
# Keyword mapping that is unpacked into the model call; the array from
# preprocess() is copied into a torch tensor.
inputs = dict(pixel_values=torch.tensor(preprocess(image, 448)))
inputs["pixel_values"].shape

torch.Size([1, 3, 448, 448])

In [24]:
# Inference only: run the forward pass with gradient tracking disabled.
with torch.no_grad():
    outputs = model(**inputs)

In [25]:
# Show the raw model output; its `logits` field is what gets passed through a sigmoid below.
outputs

Swinv2ImageClassifierOutput(loss=None, logits=tensor([[  1.0366,  -1.4891,  -5.0560,  ..., -13.1720, -12.4366, -12.7293]]), hidden_states=None, attentions=None, reshaped_hidden_states=None)

In [26]:
# Take the first (only) item of the batch and squash each score into 0..1 independently.
# NOTE: despite the name, `logits` holds post-sigmoid values from here on.
logits = torch.sigmoid(outputs.logits[0])

In [30]:
# Pair every label with its score, then re-order the mapping from highest to lowest score.
results = {
    model.config.id2label[idx]: score.float()
    for idx, score in enumerate(logits)
}
results = dict(sorted(results.items(), key=lambda pair: pair[1], reverse=True))
results  # rating tags and character tags are also included

{'1girl': tensor(0.9524),
 'solo': tensor(0.8733),
 'general': tensor(0.7382),
 'school_uniform': tensor(0.6762),
 'outdoors': tensor(0.6675),
 'black_hair': tensor(0.6641),
 'short_hair': tensor(0.6368),
 'sky': tensor(0.5657),
 'skirt': tensor(0.5608),
 'serafuku': tensor(0.5393),
 'cloud': tensor(0.4181),
 'scenery': tensor(0.4173),
 'sailor_collar': tensor(0.4109),
 'blue_sky': tensor(0.3885),
 'shirt': tensor(0.3657),
 'long_sleeves': tensor(0.3433),
 'day': tensor(0.3323),
 'bangs': tensor(0.3179),
 'black_skirt': tensor(0.2557),
 'pleated_skirt': tensor(0.2379),
 'standing': tensor(0.2098),
 'brown_hair': tensor(0.1961),
 'sensitive': tensor(0.1841),
 'black_serafuku': tensor(0.1645),
 'black_shirt': tensor(0.1618),
 'from_behind': tensor(0.1547),
 'closed_mouth': tensor(0.1500),
 'blue_skirt': tensor(0.1413),
 'long_hair': tensor(0.1348),
 'from_side': tensor(0.1311),
 'neckerchief': tensor(0.1274),
 'wind': tensor(0.1244),
 'closed_eyes': tensor(0.1228),
 'black_sailor_collar'

In [31]:
# only general tags
general_tags_threshold = 0.35

# `general_indexes` is a list, so testing membership against it for every score
# re-scans the list each time. A set makes each lookup constant-time; int()
# normalises numpy integer scalars to plain ints.
general_index_set = {int(i) for i in general_indexes}

general_tag_results = {
    model.config.id2label[i]: score.float()
    for i, score in enumerate(logits)
    if i in general_index_set
}
# Highest score first, keeping only tags strictly above the threshold.
general_tag_results = {
    tag: score
    for tag, score in sorted(
        general_tag_results.items(), key=lambda item: item[1], reverse=True
    )
    if score > general_tags_threshold
}

general_tag_results

{'1girl': tensor(0.9524),
 'solo': tensor(0.8733),
 'school_uniform': tensor(0.6762),
 'outdoors': tensor(0.6675),
 'black_hair': tensor(0.6641),
 'short_hair': tensor(0.6368),
 'sky': tensor(0.5657),
 'skirt': tensor(0.5608),
 'serafuku': tensor(0.5393),
 'cloud': tensor(0.4181),
 'scenery': tensor(0.4173),
 'sailor_collar': tensor(0.4109),
 'blue_sky': tensor(0.3885),
 'shirt': tensor(0.3657)}