# Demo of pre-trained anime character identification

In [None]:
! pip install git+https://github.com/kosuke1701/AnimeCV.git

In [None]:
!wget https://github.com/kosuke1701/AnimeCV/releases/download/0111_best_randaug/0111_best_randaug.zip
!unzip 0111_best_randaug

In [None]:
# Face detection module
import torch

from animecv.object_detection import FaceDetector_EfficientDet
from animecv.util import load_image

# EfficientDet-based face detector. coef=2 is kept from the original demo
# (presumably the EfficientDet compound coefficient, i.e. D2 — confirm in AnimeCV).
# Use the GPU only when one is available: the previous hard-coded
# use_cuda=True failed outright on a CPU-only runtime.
detector = FaceDetector_EfficientDet(coef=2, use_cuda=torch.cuda.is_available())

In [None]:
# Character face encoder
import torch
import animecv
from animecv.module import ImageBBEncoder, Similarity
from torchvision import transforms

# Embedding network loaded from the weights extracted into ./0111_best_randaug.
torch_model = animecv.general.OML_ImageFolder_Pretrained("0111_best_randaug")

# Preprocessing passed to the encoder as `post_trans` (presumably applied to each
# cropped face before embedding — confirm in AnimeCV): resize to 224x224,
# convert to a tensor, and normalize with the usual ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

encoder = ImageBBEncoder(torch_model, post_trans=transform, scale=1.0)
# Fall back to the CPU instead of failing on a runtime without a GPU.
encoder.to("cuda" if torch.cuda.is_available() else "cpu")

# Threshold on the dot product of two face embeddings. Pairs scoring above it are
# labelled "SAME" character. According to the original author, it was chosen so
# that the model's false-positive rate (FPR) is 0.22.
threshold = 0.65

In [None]:
from google.colab import files
import IPython

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

Upload your images using the two cells below, one image per cell (two images in total).

ここで画像をアップロードしてください。合計２枚の画像をアップロードします。

In [None]:
# Upload the first image. files.upload() returns a dict keyed by filename and
# writes the file into the working directory, so the name can be used as a path.
uploaded = list(files.upload())
if not uploaded:
    raise ValueError("No file was uploaded; please run this cell again and choose an image.")
if len(uploaded) > 1:
    print(f"Several files were uploaded; using only {uploaded[0]!r}.")
image1 = uploaded[0]
IPython.display.Image(image1, width=300)

In [None]:
# Upload the second image. files.upload() returns a dict keyed by filename and
# writes the file into the working directory, so the name can be used as a path.
uploaded = list(files.upload())
if not uploaded:
    raise ValueError("No file was uploaded; please run this cell again and choose an image.")
if len(uploaded) > 1:
    print(f"Several files were uploaded; using only {uploaded[0]!r}.")
image2 = uploaded[0]
IPython.display.Image(image2, width=300)

In [None]:
# Detect faces in both uploaded images, embed every detected face, and compare
# all pairs of faces against `threshold`.
images = [load_image(image1), load_image(image2)]
face_bbox = detector.detect(images)
face_embs, lst_i_img, lst_i_bbox = encoder.encode(images, face_bbox)
face_embs = face_embs.detach().cpu().numpy()

# Assumes row k of `face_embs` is the embedding of box lst_i_bbox[k] in image
# lst_i_img[k] (the previous loop relied on the same alignment). Crops and
# embedding rows are filtered together. Otherwise, skipping a zero-area box would
# shift every later face onto the wrong embedding row.
cropped_images = []
face_sources = []  # 1-based number of the image each kept face came from
kept_rows = []
for row, (i_img, i_bbox) in enumerate(zip(lst_i_img, lst_i_bbox)):
    xmin, ymin, xmax, ymax = face_bbox[i_img][i_bbox]["coordinates"]
    crop_img = images[i_img].crop((xmin, ymin, xmax, ymax))

    if min(crop_img.size) == 0:
        continue

    cropped_images.append(crop_img)
    face_sources.append(i_img + 1)
    kept_rows.append(row)
face_embs = face_embs[np.asarray(kept_rows, dtype=int)]

n_img = len(cropped_images)
print(f"Detected {n_img} faces.")

if n_img == 0:
    # plt.subplots(1, 0) would raise, and there is nothing to compare.
    print("No faces to compare.")
else:
    # squeeze=False keeps `axes` 2-D even when only one face was found.
    fig, axes = plt.subplots(1, n_img, squeeze=False)
    for k, (ax, img) in enumerate(zip(axes[0], cropped_images)):
        ax.imshow(np.array(img))
        ax.set_title(f"face {k} (image {face_sources[k]})")
        ax.axis("off")
    plt.show()

    print("Similarity of each face pair. Rows and columns correspond to the faces shown above (face 0, face 1, ...).")
    # Pairwise dot products: entry [i, j] equals np.dot(face_embs[i], face_embs[j]).
    sims = face_embs @ face_embs.T
    for sim_row in sims:
        print("\t".join(f"{sim:.3f}/{'SAME' if sim > threshold else 'DIFF'}" for sim in sim_row))