## wd-swinv2-tagger-v3

References:
- Original model: https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3
- Old training repo: https://github.com/SmilingWolf/SW-CV-ModelZoo
- v3 repo: https://github.com/SmilingWolf/JAX-CV
- Conversion script (adapted below): https://github.com/huggingface/transformers/blob/main/src/transformers/models/swinv2/convert_swinv2_timm_to_pytorch.py

In [None]:
# clone the swinv2 model
!git clone https://huggingface.co/SmilingWolf/wd-swinv2-tagger-v3 ./swinv2-v3

## imports

In [1]:
import re

import pandas as pd
import torch
from PIL import Image
from safetensors.torch import load_file
from transformers import AutoImageProcessor, Swinv2Config, Swinv2ForImageClassification

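## config

The base `Swinv2Config` (and, later, the image processor) is loaded from the local `swinv2-v3-config` directory. This notebook does not create that directory; it must exist before running.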

In [2]:
# Architecture config read from a local directory that must already exist
# (nothing in this notebook creates "swinv2-v3-config").
config = Swinv2Config.from_pretrained(pretrained_model_name_or_path="swinv2-v3-config")

## labels

In [3]:
# Tag table shipped with the checkpoint; row i is used as classifier output i below.
df = pd.read_csv("./swinv2-v3/selected_tags.csv")
df.head(n=5)

Unnamed: 0,tag_id,name,category,count
0,9999999,general,9,1589178
1,9999998,sensitive,9,3994361
2,9999997,questionable,9,892496
3,9999996,explicit,9,706509
4,470575,1girl,0,5113288


In [4]:
def convert_tag_name(tag: str, category: int) -> str:
    """Return the label used for ``tag`` in the HF config's ``id2label``.

    Category codes handled explicitly:
      * 4 -> prefixed with ``character:``
      * 9 -> prefixed with ``rating:`` (e.g. ``general`` -> ``rating:general``)
      * 0 -> kept as-is (e.g. ``1girl``)

    Any other category also keeps the bare tag name. Previously such
    categories fell through and implicitly returned ``None``, which would put
    ``None`` labels into ``id2label`` and collapse them in ``label2id``.

    Args:
        tag: tag name from the ``name`` column of selected_tags.csv.
        category: integer code from the ``category`` column.

    Returns:
        The (possibly prefixed) label string; never ``None``.
    """
    if category == 4:
        return f"character:{tag}"
    if category == 9:
        return f"rating:{tag}"
    return tag

In [5]:
# Classifier output index -> (prefixed) tag name, walking the CSV rows in order.
id2label = {
    idx: convert_tag_name(name, category)
    for idx, (name, category) in enumerate(zip(df["name"], df["category"]))
}
id2label

{0: 'rating:general',
 1: 'rating:sensitive',
 2: 'rating:questionable',
 3: 'rating:explicit',
 4: '1girl',
 5: 'solo',
 6: 'long_hair',
 7: 'breasts',
 8: 'looking_at_viewer',
 9: 'blush',
 10: 'smile',
 11: 'open_mouth',
 12: 'short_hair',
 13: 'blue_eyes',
 14: 'simple_background',
 15: 'shirt',
 16: 'large_breasts',
 17: 'skirt',
 18: 'blonde_hair',
 19: 'multiple_girls',
 20: 'brown_hair',
 21: 'black_hair',
 22: 'long_sleeves',
 23: 'hair_ornament',
 24: 'white_background',
 25: '1boy',
 26: 'gloves',
 27: 'red_eyes',
 28: 'dress',
 29: 'thighhighs',
 30: 'hat',
 31: 'holding',
 32: 'bow',
 33: 'navel',
 34: 'animal_ears',
 35: 'ribbon',
 36: 'hair_between_eyes',
 37: 'closed_mouth',
 38: '2girls',
 39: 'cleavage',
 40: 'jewelry',
 41: 'bare_shoulders',
 42: 'very_long_hair',
 43: 'sitting',
 44: 'twintails',
 45: 'medium_breasts',
 46: 'brown_eyes',
 47: 'standing',
 48: 'nipples',
 49: 'green_eyes',
 50: 'underwear',
 51: 'blue_hair',
 52: 'jacket',
 53: 'school_uniform',
 54:

In [6]:
# Attach the tag vocabulary to the config so it is written out with it.
config.id2label = id2label
config.label2id = {label: idx for idx, label in id2label.items()}
# A repeated label name would silently overwrite an entry in label2id.
assert len(config.label2id) == len(id2label), "duplicate label names in id2label"

In [7]:
# Persist the config (now including id2label / label2id) to the output directory.
config.save_pretrained(save_directory="swinv2-v3-hf")

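## Load the model and convert the weights

The timm state dict is renamed to HF parameter names, and each fused `qkv` tensor is split into separate query/key/value tensors.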

In [8]:
def _shift_downsample_stage(match: "re.Match") -> str:
    """Rewrite a ``layers.{i}.downsample`` match to ``layers.{i-1}.downsample`` (i >= 1)."""
    stage = int(match.group(1))
    if stage == 0:
        # Nothing to shift into; leave the name untouched.
        return match.group(0)
    return f"layers.{stage - 1}.downsample"


def rename_key(name):
    """Map a timm SwinV2 parameter name to its HF ``Swinv2ForImageClassification`` name.

    Adapted from transformers' ``convert_swinv2_timm_to_pytorch.py``. Fused
    ``qkv`` tensors are split by ``convert_state_dict``; this function only
    renames.

    On top of the upstream mapping, the patch-merging (``downsample``) weights
    are moved back one stage. The timm checkpoint names them
    ``layers.{i}.downsample``, while the HF model expects
    ``layers.{i-1}.downsample``. The shift is applied for every stage index
    >= 1, not only stages 1-3. The previous substring checks missed stages
    such as 4 or 10.

    Args:
        name: parameter name from the timm state dict.

    Returns:
        The corresponding key in the HF model's state dict.
    """
    name = name.replace("patch_embed.proj", "embeddings.patch_embeddings.projection")
    name = name.replace("patch_embed.norm", "embeddings.norm")
    if "layers" in name:
        name = "encoder." + name

    # Applied in order: "attn.proj" must be rewritten before the generic "attn" rule.
    for old, new in (
        ("attn.proj", "attention.output.dense"),
        ("attn", "attention.self"),
        ("norm1", "layernorm_before"),
        ("norm2", "layernorm_after"),
        ("mlp.fc1", "intermediate.dense"),
        ("mlp.fc2", "output.dense"),
        ("q_bias", "query.bias"),
        ("k_bias", "key.bias"),
        ("v_bias", "value.bias"),
        ("cpb_mlp", "continuous_position_bias_mlp"),
    ):
        name = name.replace(old, new)

    # Only the exact top-level final norm is renamed (not e.g. downsample.norm).
    name = {"norm.weight": "layernorm.weight", "norm.bias": "layernorm.bias"}.get(name, name)

    if "head.fc" in name:
        name = name.replace("head.fc", "classifier")
    else:
        name = "swinv2." + name

    return re.sub(r"layers\.(\d+)\.downsample", _shift_downsample_stage, name)

In [9]:
def _split_qkv(key, val, model):
    """Split one fused timm ``qkv`` tensor into HF query/key/value state-dict entries."""
    # Expected timm key layout: layers.{stage}.blocks.{block}.attn.qkv.{weight|bias}
    key_split = key.split(".")
    layer_num = int(key_split[1])
    block_num = int(key_split[3])
    dim = (
        model.swinv2.encoder.layers[layer_num]
        .blocks[block_num]
        .attention.self.all_head_size
    )
    suffix = "weight" if "weight" in key else "bias"
    prefix = f"swinv2.encoder.layers.{layer_num}.blocks.{block_num}.attention.self"
    # Slicing along dim 0 works for both the 2-D weight and the 1-D bias.
    return {
        f"{prefix}.query.{suffix}": val[:dim],
        f"{prefix}.key.{suffix}": val[dim : dim * 2],
        f"{prefix}.value.{suffix}": val[-dim:],
    }


def convert_state_dict(orig_state_dict, model):
    """Convert a timm SwinV2 state dict to HF ``Swinv2ForImageClassification`` keys.

    - keys containing ``"mask"`` are dropped;
    - fused ``qkv`` weights/biases are split into query/key/value tensors,
      sized by the target model's ``all_head_size`` for that block;
    - every other key is renamed with ``rename_key``.

    The input mapping is left untouched and a new dict is returned. The
    previous version popped every entry out of ``orig_state_dict``, so the
    loaded checkpoint was destroyed and the conversion cell could not be
    re-run.

    Args:
        orig_state_dict: mapping of timm parameter names to tensors.
        model: the HF model the weights are destined for; only consulted for
            the per-block attention width.

    Returns:
        A new dict keyed by HF parameter names.
    """
    new_state_dict = {}
    for key, val in orig_state_dict.items():
        if "mask" in key:
            continue
        if "qkv" in key:
            new_state_dict.update(_split_qkv(key, val, model))
        else:
            new_state_dict[rename_key(key)] = val
    return new_state_dict

In [10]:
# Original timm-format weights, loaded as a dict of CPU tensors.
timm_file = load_file("./swinv2-v3/model.safetensors", device="cpu")

In [11]:
# timm parameter names, for comparison with the HF names listed further down.
list(timm_file)

['head.fc.bias',
 'head.fc.weight',
 'layers.0.blocks.0.attn.cpb_mlp.0.bias',
 'layers.0.blocks.0.attn.cpb_mlp.0.weight',
 'layers.0.blocks.0.attn.cpb_mlp.2.weight',
 'layers.0.blocks.0.attn.logit_scale',
 'layers.0.blocks.0.attn.proj.bias',
 'layers.0.blocks.0.attn.proj.weight',
 'layers.0.blocks.0.attn.q_bias',
 'layers.0.blocks.0.attn.qkv.weight',
 'layers.0.blocks.0.attn.v_bias',
 'layers.0.blocks.0.mlp.fc1.bias',
 'layers.0.blocks.0.mlp.fc1.weight',
 'layers.0.blocks.0.mlp.fc2.bias',
 'layers.0.blocks.0.mlp.fc2.weight',
 'layers.0.blocks.0.norm1.bias',
 'layers.0.blocks.0.norm1.weight',
 'layers.0.blocks.0.norm2.bias',
 'layers.0.blocks.0.norm2.weight',
 'layers.0.blocks.1.attn.cpb_mlp.0.bias',
 'layers.0.blocks.1.attn.cpb_mlp.0.weight',
 'layers.0.blocks.1.attn.cpb_mlp.2.weight',
 'layers.0.blocks.1.attn.logit_scale',
 'layers.0.blocks.1.attn.proj.bias',
 'layers.0.blocks.1.attn.proj.weight',
 'layers.0.blocks.1.attn.q_bias',
 'layers.0.blocks.1.attn.qkv.weight',
 'layers.0.block

In [12]:
# Freshly initialised HF model built from the config; real weights are loaded below.
model = Swinv2ForImageClassification(config).eval()
model

Swinv2ForImageClassification(
  (swinv2): Swinv2Model(
    (embeddings): Swinv2Embeddings(
      (patch_embeddings): Swinv2PatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Swinv2Encoder(
      (layers): ModuleList(
        (0): Swinv2Stage(
          (blocks): ModuleList(
            (0-1): 2 x Swinv2Layer(
              (attention): Swinv2Attention(
                (self): Swinv2SelfAttention(
                  (continuous_position_bias_mlp): Sequential(
                    (0): Linear(in_features=2, out_features=512, bias=True)
                    (1): ReLU(inplace=True)
                    (2): Linear(in_features=512, out_features=4, bias=False)
                  )
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_featur

In [13]:
# Parameter names the HF model expects.
list(model.state_dict())

['swinv2.embeddings.patch_embeddings.projection.weight',
 'swinv2.embeddings.patch_embeddings.projection.bias',
 'swinv2.embeddings.norm.weight',
 'swinv2.embeddings.norm.bias',
 'swinv2.encoder.layers.0.blocks.0.attention.self.logit_scale',
 'swinv2.encoder.layers.0.blocks.0.attention.self.continuous_position_bias_mlp.0.weight',
 'swinv2.encoder.layers.0.blocks.0.attention.self.continuous_position_bias_mlp.0.bias',
 'swinv2.encoder.layers.0.blocks.0.attention.self.continuous_position_bias_mlp.2.weight',
 'swinv2.encoder.layers.0.blocks.0.attention.self.query.weight',
 'swinv2.encoder.layers.0.blocks.0.attention.self.query.bias',
 'swinv2.encoder.layers.0.blocks.0.attention.self.key.weight',
 'swinv2.encoder.layers.0.blocks.0.attention.self.value.weight',
 'swinv2.encoder.layers.0.blocks.0.attention.self.value.bias',
 'swinv2.encoder.layers.0.blocks.0.attention.output.dense.weight',
 'swinv2.encoder.layers.0.blocks.0.attention.output.dense.bias',
 'swinv2.encoder.layers.0.blocks.0.laye

In [14]:
# Rename/split the timm weights and load them; strict loading raises on any
# missing or unexpected key.
new_state_dict = convert_state_dict(timm_file, model)
model.load_state_dict(new_state_dict, strict=True)

<All keys matched successfully>

In [17]:
# Image processor from the same local config directory. trust_remote_code=True
# is passed in case that directory ships a custom processor class — TODO confirm
# it is required.
image_processor = AutoImageProcessor.from_pretrained(
    "swinv2-v3-config", trust_remote_code=True
)
# Open the sample in a context manager so the file handle is released once the
# pixels have been preprocessed.
with Image.open("./sample.jpg") as image:
    inputs = image_processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Sigmoid gives an independent score per tag (multi-label head). The result is
# kept under the name `logits` because the next cell refers to it.
logits = torch.sigmoid(outputs.logits[0])
logits

tensor([9.0890e-01, 8.8995e-02, 9.5065e-04,  ..., 1.8553e-06, 5.5421e-07,
        1.3558e-06])

In [18]:
# Pair every output index with its tag and rank by score, highest first.
results = {
    model.config.id2label[idx]: score.float() for idx, score in enumerate(logits)
}
results = dict(sorted(results.items(), key=lambda kv: kv[1], reverse=True))
results  # rating tags and character tags are also included

{'1girl': tensor(0.9963),
 'solo': tensor(0.9634),
 'school_uniform': tensor(0.9550),
 'short_hair': tensor(0.9407),
 'skirt': tensor(0.9385),
 'outdoors': tensor(0.9270),
 'serafuku': tensor(0.9171),
 'rating:general': tensor(0.9089),
 'sky': tensor(0.8537),
 'cloud': tensor(0.8494),
 'sailor_collar': tensor(0.7718),
 'bottle': tensor(0.7138),
 'pleated_skirt': tensor(0.7047),
 'black_skirt': tensor(0.6987),
 'long_sleeves': tensor(0.6412),
 'shirt': tensor(0.6234),
 'neckerchief': tensor(0.6087),
 'black_hair': tensor(0.5551),
 'water': tensor(0.5268),
 'railing': tensor(0.5240),
 'sunset': tensor(0.5207),
 'scenery': tensor(0.5109),
 'standing': tensor(0.5085),
 'black_sailor_collar': tensor(0.5049),
 'black_serafuku': tensor(0.4917),
 'ocean': tensor(0.4829),
 'cowboy_shot': tensor(0.4727),
 'cloudy_sky': tensor(0.4689),
 'profile': tensor(0.4669),
 'closed_mouth': tensor(0.3929),
 'from_behind': tensor(0.3805),
 'lighthouse': tensor(0.3789),
 'brown_hair': tensor(0.3671),
 'black_

In [None]:
# Cast the floating-point weights to bfloat16 before saving (2 bytes per value
# instead of 4). NOTE: any later inference on this in-memory model must feed
# bfloat16 inputs.
model.to(dtype=torch.bfloat16)

Swinv2ForImageClassification(
  (swinv2): Swinv2Model(
    (embeddings): Swinv2Embeddings(
      (patch_embeddings): Swinv2PatchEmbeddings(
        (projection): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Swinv2Encoder(
      (layers): ModuleList(
        (0): Swinv2Stage(
          (blocks): ModuleList(
            (0-1): 2 x Swinv2Layer(
              (attention): Swinv2Attention(
                (self): Swinv2SelfAttention(
                  (continuous_position_bias_mlp): Sequential(
                    (0): Linear(in_features=2, out_features=512, bias=True)
                    (1): ReLU(inplace=True)
                    (2): Linear(in_features=512, out_features=4, bias=False)
                  )
                  (query): Linear(in_features=128, out_features=128, bias=True)
                  (key): Linear(in_features=128, out_featur

In [None]:
# Write the converted weights (and config) next to the config saved earlier.
model.save_pretrained(save_directory="swinv2-v3-hf")

In [None]:
# Upload to the Hugging Face Hub (needs an authenticated token).
# NOTE(review): only the model and its config are pushed here; the image
# processor loaded from "swinv2-v3-config" is not — confirm it is uploaded
# separately.
model.push_to_hub(repo_id="p1atdev/wd-swinv2-tagger-v3-hf")