In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm

import os
import re
import glob
import shutil
import requests
from bs4 import BeautifulSoup

import torch
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from PIL import Image


ROOT_PATH = os.getcwd()
HERO_NAMES_PATH = f"{ROOT_PATH}/test_data/hero_names.txt"
# The *_DIR constants end with "/" because later cells append file names / glob patterns directly.
HERO_IMAGES_DIR = f"{ROOT_PATH}/test_data/hero_images/"
TEST_IMAGES_DIR = f"{ROOT_PATH}/test_data/test_images/"
TEST_LABELS_PATH = f"{ROOT_PATH}/test_data/test.txt"

# Hero names, one per line. Surrounding whitespace is stripped and blank lines are skipped,
# so a stray empty line cannot turn into a bogus "" hero name.
with open(HERO_NAMES_PATH, "r") as f:
    hero_names = [line.strip() for line in f.read().splitlines() if line.strip()]
len(hero_names)

64

# Download Hero Images

In [2]:
def get_champion_hero_image_links(hero_names):
    """Scrape the Wild Rift champion wiki page for each hero's portrait PNG URL.

    Args:
        hero_names: hero names as stored in hero_names.txt, with underscores where the
            wiki uses spaces and apostrophes dropped (e.g. "KaiSa").

    Returns:
        dict mapping each given hero name to the first ``https://...png`` URL found in the
        ``src`` (falling back to ``data-src``) attribute of the page's ``<img alt=...>``
        tag. Heroes with no matching tag or URL are reported and skipped.

    Raises:
        requests.HTTPError: if the wiki page itself cannot be fetched.
    """
    url = "https://leagueoflegends.fandom.com/wiki/Champion_(Wild_Rift)"
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    soup = BeautifulSoup(response.text, "html.parser")

    # These names lose their apostrophe in hero_names.txt; the wiki alt text keeps it.
    apostrophe_fixes = {"KaiSa": "Kai'Sa", "KhaZix": "Kha'Zix"}
    png_pattern = re.compile(r'https://.*?\.png')

    hero_image_links = {}

    for hero_name in tqdm(hero_names):
        # Replace underscores with spaces (wiki alt text) and restore the apostrophes.
        alt_text = hero_name.replace("_", " ")
        alt_text = apostrophe_fixes.get(alt_text, alt_text)

        img_tag = soup.find("img", attrs={'alt': alt_text})
        if img_tag is None:
            print("No match found for ", hero_name)
            continue

        # Presumably lazy-loaded images hold a placeholder in `src` and the real URL in
        # `data-src`; try both, tolerating either attribute being absent.
        match = (png_pattern.search(img_tag.get("src", ""))
                 or png_pattern.search(img_tag.get("data-src", "")))
        if match is None:
            print("No match found for ", hero_name)
            continue

        hero_image_links[hero_name] = match.group(0)

    return hero_image_links

hero_image_links = get_champion_hero_image_links(hero_names)
len(hero_image_links)

100%|██████████| 64/64 [00:00<00:00, 112.68it/s]


64

In [3]:
def download_hero_images(hero_image_links, path):
    """Download every hero portrait to ``<path>/<hero_name>.png``.

    Args:
        hero_image_links: dict of hero name -> image URL.
        path: destination directory; created if missing. A trailing separator is optional.

    Failed downloads (non-200 status) are reported and skipped.
    """
    os.makedirs(path, exist_ok=True)
    for hero_name, link in tqdm(hero_image_links.items()):
        # Without a timeout, a stalled server would hang the cell forever.
        response = requests.get(link, timeout=30)

        if response.status_code != 200:
            print(f"Failed to download image for {hero_name}.")
            continue

        # os.path.join works whether or not `path` already ends with a separator.
        with open(os.path.join(path, f"{hero_name}.png"), "wb") as f:
            f.write(response.content)

download_hero_images(hero_image_links, path=HERO_IMAGES_DIR)

100%|██████████| 64/64 [00:15<00:00,  4.07it/s]


In [4]:
def add_folder_each_hero_image(directory=HERO_IMAGES_DIR):
    """Move every ``<name>.png`` in ``directory`` into its own ``<name>/`` subfolder.

    This produces a one-subfolder-per-class layout (presumably for
    ``torchvision.datasets.ImageFolder``, imported above — TODO confirm). Safe to re-run:
    an already-existing subfolder is reused instead of raising FileExistsError.

    Args:
        directory: folder containing the downloaded hero PNGs.
    """
    for file in os.listdir(directory):
        # splitext strips only the final extension; str.replace('.png', '') would also
        # remove ".png" occurring elsewhere in the name.
        hero_name, ext = os.path.splitext(file)
        if ext != '.png':
            continue
        hero_dir = os.path.join(directory, hero_name)
        os.makedirs(hero_dir, exist_ok=True)
        shutil.move(os.path.join(directory, file), os.path.join(hero_dir, file))

# Dataset & Dataloader

## Prepare paths & labels

In [2]:
# Sorted so list positions (e.g. the query example further down) are the same on every
# machine; raw glob order depends on the filesystem.
test_images_path_list = sorted(glob.glob(TEST_IMAGES_DIR + "*"))
print(len(test_images_path_list))

# Get labels for test images: one "<file name>\t<hero label>" pair per line.
with open(TEST_LABELS_PATH, "r") as f:
    test_labels = f.read().splitlines()

test_file_2_labels = {}
for line in test_labels:
    if not line.strip():  # tolerate blank / trailing lines instead of raising IndexError
        continue
    fields = line.split("\t")
    test_file_2_labels[fields[0]] = fields[1]

def path_2_label(path):
    """Return the hero label of a test image, looked up by the path's bare file name.

    Raises:
        KeyError: if the file name has no entry in test.txt.
    """
    # os.path.basename handles "\\" and "/" on Windows and "/" on POSIX; splitting on
    # "\\" alone returned the whole path on Linux/macOS, so every lookup failed.
    return test_file_2_labels[os.path.basename(path)]

path_2_label(test_images_path_list[0])

98


'Ahri'

In [3]:
# Sorted because the Annoy item ids built below encode each hero's position in this list;
# filesystem-dependent glob order could silently mislabel a reloaded index.
# Rebuild the saved index after changing this ordering.
hero_images_path_list = sorted(glob.glob(HERO_IMAGES_DIR + "*"))
assert hero_images_path_list, f"no hero images found in {HERO_IMAGES_DIR}"

# Get labels for hero images: the file name without directory or extension.
# splitext keeps dots inside the name; split(".")[0] would cut e.g. "Dr._Mundo" to "Dr_".
hero_images_path_2_label = {
    path: os.path.splitext(os.path.basename(path))[0] for path in hero_images_path_list
}

# Load Model

In [6]:
# Define the model to be used for feature extraction: ResNet-18 with its last child
# (the fully connected classifier) removed. Its weights come from the checkpoint below,
# hence pretrained=False.
model = torchvision.models.resnet18(pretrained=False)
model = torch.nn.Sequential(*list(model.children())[:-1])

ckpt_path = "model/model.ckpt"
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine.
# The model above lives on CPU, so load_state_dict copies into CPU tensors either way.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
checkpoint = torch.load(ckpt_path, map_location="cpu")

# Checkpoint keys carry a "model." prefix (presumably the attribute name of the wrapped
# network in the training module — TODO confirm). Strip it only at the start of the key.
ckpt_key_prefix = "model."
ckpt_state_dict = {
    (key[len(ckpt_key_prefix):] if key.startswith(ckpt_key_prefix) else key): value
    for key, value in checkpoint['state_dict'].items()
}

model.load_state_dict(ckpt_state_dict)
model.eval()

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con

## Transforms

In [7]:
class CircleCrop(object):
    """Black out everything outside the image's inscribed circle, then resize + centre-crop.

    Args:
        size: forwarded to both ``transforms.Resize`` and ``transforms.CenterCrop``.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        """Apply the circular mask to a PIL image and return the resized PIL image."""
        # PIL image -> float tensor whose last two dims are (height, width).
        img = transforms.ToTensor()(img)
        height, width = img.shape[-2], img.shape[-1]

        # Circle centred on the image, radius = half the shorter side. The mask is built
        # with broadcasting rather than a per-pixel Python double loop. The inclusion test
        # is identical, and the mask is a 0/1 float64 ndarray like np.zeros would give.
        center_x, center_y = width / 2, height / 2
        radius = min(width, height) / 2
        rows, cols = np.ogrid[:height, :width]
        inside = (rows - center_y) ** 2 + (cols - center_x) ** 2 <= radius ** 2
        mask = inside.astype(np.float64)

        # Zero every pixel outside the circle; the (H, W) mask broadcasts over channels.
        img = img * mask

        # Back to PIL for the remaining torchvision transforms.
        img = transforms.ToPILImage()(img)

        resize_and_crop = transforms.Compose([
            transforms.Resize(self.size),
            transforms.CenterCrop(self.size),
        ])
        return resize_and_crop(img)


In [12]:
# Visual sanity check of the hero-portrait preprocessing on one reference image.
transform = transforms.Compose([
    # Square crop anchored at the left edge (width = image height).
    transforms.Lambda(lambda x: x.crop((0, 0, x.height, x.height))),
    # Shrink to 40x40 and blow back up to 256x256, deliberately losing detail
    # (presumably to resemble low-resolution in-game portraits — TODO confirm).
    transforms.Resize((40, 40)),
    transforms.Resize((256, 256)),
    # Mask to the inscribed circle, then resize / centre-crop to 256.
    CircleCrop(size=256),
    transforms.Resize((256, 256)),
])
with Image.open(hero_images_path_list[11]) as raw_image:
    image = transform(raw_image.convert('RGB'))
# A PIL image as the last expression renders inline in Jupyter;
# Image.show() would open an external viewer instead.
image

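# Retrieval System

Embed every hero portrait (at several downscale levels) with the CNN and index the embeddings with Annoy. Each test image is then labelled with the hero of its nearest indexed embeddings.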

In [8]:
def get_embedding(model, image_path, is_test=True, downsize=40):
    """Return the feature vector the model produces for a single image file.

    Args:
        model: feature extractor, called as ``model(batch)`` on a (1, 3, H, W) CPU tensor.
        image_path: path to an image file readable by PIL.
        is_test: True for query images (crop a left-anchored region 1.2x the image
            height wide, resize, centre-crop 175). False for hero portraits
            (downsize, upsample to 256, circle-mask).
        downsize: side length a hero portrait is shrunk to before being upsampled again;
            only used when ``is_test`` is False.

    Returns:
        numpy array from ``model(batch).squeeze()``. For the truncated ResNet-18 above it is
        presumably shape (512,) — TODO confirm if the model changes.
    """
    # ImageNet statistics, shared by both preprocessing branches.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if is_test:
        transform = transforms.Compose([
            # Keep a left-anchored region 1.2x as wide as the image is tall.
            transforms.Lambda(lambda x: x.crop((0, 0, int(x.height * 1.2), x.height))),
            transforms.Resize((256, 256)),
            transforms.CenterCrop(175),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        transform = transforms.Compose([
            # Lossy down/up-sampling of the portrait.
            transforms.Resize((downsize, downsize)),
            transforms.Resize((256, 256)),
            # Mask everything outside the inscribed circle.
            CircleCrop(size=256),
            transforms.ToTensor(),
            normalize,
        ])

    # The context manager releases the file handle; convert() returns an independent image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    batch = transform(image).unsqueeze(0)  # add the batch dimension

    with torch.no_grad():
        # Under no_grad the output needs no .detach() before .numpy().
        embedding = model(batch).squeeze().cpu().numpy()
    return embedding

In [9]:
import annoy

# Index every hero portrait several times, once per downsize level (presumably to cover
# query images of differing resolution — TODO confirm).
n_trees = 100  # Number of trees in the index (more trees: better recall, larger index)
embedding_dim = 512  # must equal the length of get_embedding()'s output vector
downsizes = [25, 30, 35]  # portrait side lengths used before upsampling to 256
annoy_seed = 42  # Annoy builds its trees randomly; a fixed seed makes rebuilds identical

annoy_index = annoy.AnnoyIndex(f=embedding_dim, metric='euclidean')
annoy_index.set_seed(annoy_seed)  # must be set before items are added
num_heroes = len(hero_images_path_list)

# `aug_idx` rather than `k`, so this loop no longer clobbers the global neighbour count.
for aug_idx, downsize in enumerate(downsizes):
    for i, hero_path in tqdm(enumerate(hero_images_path_list), total=num_heroes,
                             desc=f"downsize={downsize}"):
        embedding = get_embedding(model, hero_path, is_test=False, downsize=downsize)
        # item id = hero position + aug_idx * num_heroes, so `id % num_heroes` recovers the hero.
        annoy_index.add_item(i + aug_idx * num_heroes, embedding)

annoy_index.build(n_trees)
annoy_index.get_n_items()

100%|██████████| 64/64 [00:12<00:00,  5.32it/s]
 41%|████      | 26/64 [00:05<00:08,  4.38it/s]

In [17]:
# Persist the built index so it can be reloaded without re-embedding every hero.
# Item ids encode positions in hero_images_path_list, so the saved file is only meaningful
# together with the same hero list order it was built from.
annoy_index.save('model/heroes.ann')

True

In [19]:
# Reload the saved index. `f` (vector length) and `metric` must match the values used
# when the index was built and saved.
# Item ids are decoded modulo the hero count against hero_images_path_list, so the hero
# list must be in the same order as at build time.
annoy_index = annoy.AnnoyIndex(f=512, metric='euclidean')
annoy_index.load('model/heroes.ann')

True

In [20]:
k = 5  # Number of nearest neighbors to retrieve
query_index = 10  # which test image to inspect

# Position i in this list <-> hero_images_path_list[i]; dicts keep insertion order.
# Its length is used instead of `num_heroes`, so this cell also works after only
# reloading the saved index (num_heroes is set by the build cell).
hero_labels = list(hero_images_path_2_label.values())
n_hero_labels = len(hero_labels)

query_path = test_images_path_list[query_index]
print("Label: ", path_2_label(query_path))
query_embedding = get_embedding(model, query_path, is_test=True)

nn_indices, nn_scores = annoy_index.get_nns_by_vector(query_embedding, k, include_distances=True)
# Item ids repeat the hero list once per downsize level; the modulo maps back to the hero.
nn_names = [hero_labels[i % n_hero_labels] for i in nn_indices]
print(f"Top {k} nearest neighbors: ", nn_names)

Label:  Akali
Top 5 nearest neighbors:  ['Akali', 'Akali', 'Akali', 'Nami', 'Orianna']


In [16]:
# Top-k accuracy over the labelled test set: a hit when the true hero appears among the
# k nearest indexed items (k = 1 -> top-1 accuracy).
acc = 0  # number of hits
k = 1

assert test_images_path_list, "no test images found; accuracy is undefined"

# Decode item ids against the actual hero count rather than a hard-coded 64.
# The label list is built once instead of once per neighbour.
hero_labels = list(hero_images_path_2_label.values())
n_hero_labels = len(hero_labels)

for test_path in tqdm(test_images_path_list):
    query_label = path_2_label(test_path)
    query_embedding = get_embedding(model, test_path, is_test=True)
    nn_indices, nn_scores = annoy_index.get_nns_by_vector(query_embedding, k, include_distances=True)
    nn_labels = [hero_labels[i % n_hero_labels] for i in nn_indices]

    if query_label in nn_labels:
        acc += 1

print(f"Accuracy: {acc/len(test_images_path_list)}")

100%|██████████| 98/98 [00:02<00:00, 37.38it/s]

Accuracy: 0.6224489795918368



