<a href="https://colab.research.google.com/github/nakamura196/ndl_ocr/blob/main/%5Bresnet%5D_Image_Similarity_Search_in_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#  Image Similarity Search with PyTorch and ResNet

In [None]:
# Run the nvidia-smi command-line tool to show which GPU (if any) is attached to this runtime.
!nvidia-smi

In [None]:
# Notebook-wide configuration; every other cell reads these constants.
HOME_DIR = "/content"
# Folder holding the images to index. Derived from HOME_DIR so the two
# paths cannot drift apart when HOME_DIR is changed.
IMG_DIR = f"{HOME_DIR}/dataset"
# Side length in pixels that every image is resized to before it is fed to
# the network (256 is a previously tried alternative).
IMG_RESIZE_SIZE = 224

`dataset`フォルダが存在しない場合は、サンプルデータをダウンロードします。

サンプルデータとして、以下のデータを使用させていただきます。

https://idealo.github.io/imageatm/examples/cats_and_dogs/

In [None]:
import os
if not os.path.exists(IMG_DIR):
  !wget --no-check-certificate \
      https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip \
      -O cats_and_dogs_filtered.zip
  !unzip -q cats_and_dogs_filtered.zip
  !mkdir dataset

  import glob
  from tqdm import tqdm
  files = glob.glob("/content/cats_and_dogs_filtered/*/*/*.jpg")
  for file in tqdm(files):
    filename = file.split("/")[-1]
    !mv $file $IMG_DIR/$filename

  !rm -rf cats_and_dogs_filtered*

## ベクトル化

In [None]:
import torch
from torch import optim, nn
from torchvision import models, transforms
from torchvision.models import ResNet152_Weights
from torchvision.models import ResNet18_Weights

# Load a ResNet-152 with its default pretrained weights.
model_ft = models.resnet152(weights=ResNet152_Weights.DEFAULT)
# Lighter alternative:
# model_ft = models.resnet18(weights=ResNet18_Weights.DEFAULT)

# Keep every child module except the last one, so the network outputs
# features instead of class scores.
feature_extractor = torch.nn.Sequential(*list(model_ft.children())[:-1])

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

# Modules are created in training mode. Switch to eval mode so BatchNorm
# layers use their learned running statistics (and stop updating them)
# instead of statistics computed from each single-image batch.
feature_extractor = feature_extractor.to(device).eval()

import cv2
from tqdm import tqdm
import numpy as np

# Preprocessing applied to every image before it is fed to the network:
# ndarray -> PIL image -> fixed square resize -> float tensor -> per-channel
# normalization.
# The mean/std values are the ImageNet statistics that torchvision's
# pretrained classification weights expect their inputs to be normalized with.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
  transforms.ToPILImage(),
  transforms.Resize((IMG_RESIZE_SIZE, IMG_RESIZE_SIZE)),
  transforms.ToTensor(),
  transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

def getVector(path):
  """Compute the feature vector of the image stored at `path`.

  The image is read with OpenCV, converted to RGB, run through `transform`
  and `feature_extractor`, and the network output is flattened.

  Args:
    path: Path of the image file to read.

  Returns:
    A 1-D NumPy array holding the flattened output of `feature_extractor`.

  Raises:
    OSError: If OpenCV cannot read the file (missing or undecodable).
  """
  img = cv2.imread(path)
  # cv2.imread signals failure by returning None instead of raising.
  if img is None:
    raise OSError(f"Could not read image: {path}")

  # OpenCV returns channels in BGR order, while ToPILImage interprets a
  # 3-channel array as RGB; convert so colors are not swapped.
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

  # Add the leading batch dimension the network expects: (1, 3, H, W).
  batch = transform(img).unsqueeze(0).to(device)

  # Only features are extracted, so no gradients are needed.
  with torch.no_grad():
    feature = feature_extractor(batch)

  return feature.cpu().numpy().reshape(-1)


In [None]:
import glob
import json

# Feature vectors, one per image, in the same order as the ids in `mappings`.
features = []

# Maps the integer position of an image in `features` to its metadata.
mappings = {}

# Folder where each image's vector is saved as "<image file name>.npy".
VECTOR_DIR = "data"

# When True, every vector is recomputed and its saved copy overwritten.
# Set to False to reuse vectors already saved in VECTOR_DIR. The saved files
# are keyed by image file name only, so keep this True after changing the
# model or the preprocessing.
FORCE_RECOMPUTE = True

# Sorted so that ids are assigned in a stable order.
files = sorted(glob.glob(f"{IMG_DIR}/*.jpg"))

os.makedirs(VECTOR_DIR, exist_ok=True)

for index, path in enumerate(tqdm(files)):
  filename = os.path.basename(path)
  npy_path = f"{VECTOR_DIR}/{filename}.npy"

  if FORCE_RECOMPUTE or not os.path.exists(npy_path):
    v = getVector(path)
    np.save(npy_path, v)
  else:
    v = np.load(npy_path)
  features.append(v)

  mappings[index] = {
      # File name without its extension; used later to rebuild the image path.
      "nconst": os.path.splitext(filename)[0],
      "name": "",
      "url": ""
  }

# Save the id -> metadata mapping.
with open('mappings.json', mode='wt', encoding='utf-8') as fp:
  json.dump(mappings, fp, ensure_ascii=False, indent=2)

## インデックスの構築

In [None]:
!pip install annoy

In [None]:
from annoy import AnnoyIndex

# Number of trees passed to the Annoy build step.
N_TREES = 1000

# Turn the collected vectors into a single NumPy array; its second axis is
# the vector dimensionality the index is created with.
features = np.array(features)
dims = features.shape[1]

print("dims", dims)

annoy_index = AnnoyIndex(dims, metric='angular')

# Register each vector under its position in `features`.
for item_id, vector in enumerate(features):
    annoy_index.add_item(item_id, vector)

# Build the index trees, then write the index to disk.
annoy_index.build(n_trees=N_TREES)
annoy_index.save("index.ann")

## 推論

In [None]:
# Default number of neighbours to retrieve per query.
n_matches = 10

from PIL import Image
import matplotlib.pyplot as plt

def plot_similar_images(TEST_IMAGE_PATH, n_results=None):
    """Show the indexed images closest to the query image.

    The query image is vectorized with `getVector`, its nearest neighbours
    are looked up in `annoy_index`, and each match is printed (path) and
    plotted with its distance as the title.

    Args:
        TEST_IMAGE_PATH: Path of the query image.
        n_results: Number of neighbours to show. Defaults to the
            module-level `n_matches` when omitted.
    """
    if n_results is None:
        n_results = n_matches

    v = getVector(TEST_IMAGE_PATH)
    # With include_distances=True the result is a pair (ids, distances).
    indices, distances = annoy_index.get_nns_by_vector(
        v, n_results, include_distances=True)

    for index, distance in zip(indices, distances):
        img_name = mappings[index]["nconst"] + ".jpg"
        img_path = os.path.join(IMG_DIR, img_name)
        print("-------------------------------------------------------------------")
        print(img_path)
        # The context manager closes the file handle once the pixels are drawn.
        with Image.open(img_path) as img:
            plt.imshow(img.convert("RGB"))
        plt.title(f"distance: {distance:.4f}")
        plt.show()

### 猫の例

In [None]:
# Query with a cat image from the dataset folder; the path is built from
# IMG_DIR so it follows the configured dataset location.
TEST_IMAGE_PATH = os.path.join(IMG_DIR, "cat.0.jpg")
plot_similar_images(TEST_IMAGE_PATH)

### 犬の例

In [None]:
# Query with a dog image from the dataset folder; the path is built from
# IMG_DIR so it follows the configured dataset location.
TEST_IMAGE_PATH = os.path.join(IMG_DIR, "dog.0.jpg")
plot_similar_images(TEST_IMAGE_PATH)