<a href="https://colab.research.google.com/github/nakamura196/ndl_ocr/blob/main/CLIP%E3%82%92%E7%94%A8%E3%81%84%E3%81%9FText_to_Image%E3%81%A8Image_to_Image%E6%A4%9C%E7%B4%A2%E3%81%AE%E3%83%81%E3%83%A5%E3%83%BC%E3%83%88%E3%83%AA%E3%82%A2%E3%83%AB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CLIPを用いたText-to-ImageとImage-to-Image検索のチュートリアル

以下の記事を参考にしています。

[Text-to-Image and Image-to-Image Search Using CLIP](https://www.pinecone.io/learn/clip-image-search/)

In [None]:
# Install the libraries this notebook imports for the CLIP model (transformers, torch)
# and for loading the dataset (datasets). -qqq silences pip's output.
!pip -qqq install transformers torch datasets

In [None]:
# Additional packages. faiss-gpu presumably provides the `faiss` module imported in the next cell.
# NOTE(review): gdcm and pydicom are not imported by any cell visible in this notebook —
# confirm whether they are still required.
!pip -qqq install gdcm
!pip -qqq install pydicom
!pip -qqq install faiss-gpu

In [None]:
import os
import faiss
import torch
import skimage
import requests
# import pinecone
import numpy as np
import pandas as pd
from PIL import Image
from io import BytesIO
import IPython.display
import matplotlib.pyplot as plt
from datasets import load_dataset
from collections import OrderedDict
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer

In [None]:
# Load the "train" split of the Conceptual Captions dataset via the `datasets` library.
image_data = load_dataset(path="conceptual_captions", split="train")

URLのMD5ハッシュ値をファイル名にして、ダウンロードした画像を/tmpフォルダに保存する

In [None]:
import base64
import hashlib

def encode_and_get_path(url):
    """Return the local cache path used for the image at ``url``.

    The file name is the hexadecimal MD5 digest of the URL string, so the same
    URL always maps to the same ``/tmp/<digest>.jpg`` path.
    """
    digest = hashlib.md5(url.encode()).hexdigest()
    return f"/tmp/{digest}.jpg"

def check_valid_URLs(image_URL, timeout=10):
    """Download the image at ``image_URL`` into the local cache and report success.

    Despite its name this function has a side effect: on success the image is
    stored as a JPEG at the path returned by ``encode_and_get_path``. If that
    file already exists, nothing is downloaded.

    Args:
        image_URL: URL of the image to fetch.
        timeout: Seconds passed to ``requests.get`` as its timeout, so that an
            unresponsive server cannot stall the whole run.

    Returns:
        bool: True if the image is available in the cache afterwards, False if
        anything went wrong (the exception is printed, not raised).
    """
    try:
        path = encode_and_get_path(image_URL)

        # Cache hit: an earlier call already stored this image.
        if os.path.exists(path):
            return True

        response = requests.get(image_URL, timeout=timeout)
        # Treat HTTP error statuses (404, 500, ...) as failures explicitly.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        # JPEG cannot store alpha or palette modes; without this conversion
        # otherwise valid RGBA/P images fail in save() and are reported invalid.
        img = img.convert("RGB")

        os.makedirs(os.path.dirname(path), exist_ok=True)

        # Write to a temporary name and rename afterwards, so that a failed or
        # interrupted save never leaves a corrupt file at `path` that a later
        # call would mistake for a valid cached image.
        tmp_path = path + ".part"
        try:
            # The temporary name has no image extension, so name the format.
            img.save(tmp_path, format="JPEG")
            os.replace(tmp_path, path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        return True
    except Exception as e:
        # Deliberately best-effort: one bad URL must not abort the whole run.
        print(e)
        return False

def get_image(image_URL):
    """Open the cached file for ``image_URL`` and return it as an RGB PIL image.

    The file is looked up at the path given by ``encode_and_get_path``; it is
    not downloaded here.
    """
    cached_file = encode_and_get_path(image_URL)
    return Image.open(cached_file).convert("RGB")

In [None]:
# Number of rows, taken from the top of the dataset, that the rest of the notebook works on.
size = 500
# Echo the chosen size; IPython substitutes {size} into the shell command.
!echo ダウンロードしたデータのうち、{size}件を使用する

In [None]:
# Convert the loaded dataset into a pandas DataFrame.
image_data_df_all = image_data.to_pandas()
# Keep only the first `size` rows. .copy() makes this an independent frame:
# columns are assigned to it later, and assigning into a slice of
# `image_data_df_all` would raise pandas' SettingWithCopyWarning.
image_data_df = image_data_df_all.head(size).copy()

In [None]:
from tqdm import tqdm
# Register `progress_apply` on pandas objects so the applies below show a progress bar.
tqdm.pandas()

# Try to fetch (or find in the cache) every image and record the outcome.
image_data_df["is_valid"] = image_data_df["image_url"].progress_apply(check_valid_URLs)
# Keep only the rows whose image is available. The mask is used directly
# (no `== True`), and .copy() yields an independent frame so that the column
# assignment below does not write into a filtered slice
# (SettingWithCopyWarning).
image_data_df = image_data_df[image_data_df["is_valid"].astype(bool)].copy()
# Load the cached image file for each remaining row.
image_data_df["image"] = image_data_df["image_url"].progress_apply(get_image)

In [None]:
# Display the frame (as the cell's last expression) to inspect the result of the previous cell.
image_data_df

In [None]:
def get_model_info(model_ID, device):
    """Load the CLIP model, processor and tokenizer for ``model_ID``.

    Args:
        model_ID: Identifier passed to each ``from_pretrained`` call.
        device: Device the model is moved to.

    Returns:
        tuple: ``(model, processor, tokenizer)``.
    """
    # Load the weights first, then move the model onto the requested device.
    clip_model = CLIPModel.from_pretrained(model_ID)
    clip_model = clip_model.to(device)
    # The processor and the tokenizer are loaded from the same identifier.
    clip_processor = CLIPProcessor.from_pretrained(model_ID)
    clip_tokenizer = CLIPTokenizer.from_pretrained(model_ID)
    return clip_model, clip_processor, clip_tokenizer
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(device)

# Identifier of the CLIP checkpoint to load.
model_ID = "openai/clip-vit-base-patch32"

# Load the model (placed on `device`) together with its processor and tokenizer.
model, processor, tokenizer = get_model_info(model_ID=model_ID, device=device)

In [None]:
def get_single_text_embedding(text):
    """Embed ``text`` with the CLIP text encoder.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        text: The text to embed.

    Returns:
        numpy.ndarray: The text features copied to the CPU
        (presumably shape ``(1, dim)`` for a single string — TODO confirm).
    """
    # truncation=True clips over-long text to the tokenizer's maximum length
    # instead of letting the model fail on a sequence that is too long.
    inputs = tokenizer(text, truncation=True, return_tensors="pt").to(device)
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        text_embeddings = model.get_text_features(**inputs)
    # Move to the CPU and convert to a NumPy array.
    return text_embeddings.detach().cpu().numpy()

def get_all_text_embeddings(df, text_col):
    """Add a ``text_embeddings`` column computed from the ``text_col`` column.

    ``df`` is modified in place and also returned.
    """
    source_column = str(text_col)
    embeddings = df[source_column].apply(get_single_text_embedding)
    df["text_embeddings"] = embeddings
    return df

# Embed every caption; the result is stored in the `text_embeddings` column.
image_data_df = get_all_text_embeddings(df=image_data_df, text_col="caption")
image_data_df

In [None]:
def get_single_image_embedding(my_image):
    """Embed ``my_image`` with the CLIP image encoder.

    Uses the module-level ``processor``, ``model`` and ``device``.

    Args:
        my_image: The image to embed (the notebook passes PIL images).

    Returns:
        numpy.ndarray: The image features copied to the CPU
        (presumably shape ``(1, dim)`` for a single image — TODO confirm).
    """
    # Preprocess the image and keep only the pixel tensor the model needs.
    pixel_values = processor(
        text=None,
        images=my_image,
        return_tensors="pt",
    )["pixel_values"].to(device)
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        embedding = model.get_image_features(pixel_values)
    # Move to the CPU and convert to a NumPy array.
    return embedding.detach().cpu().numpy()

def get_all_images_embedding(df, img_column):
    """Add an ``img_embeddings`` column computed from the ``img_column`` column.

    ``df`` is modified in place and also returned.
    """
    source_column = str(img_column)
    embeddings = df[source_column].apply(get_single_image_embedding)
    df["img_embeddings"] = embeddings
    return df

# Embed every image; the result is stored in the `img_embeddings` column.
image_data_df = get_all_images_embedding(df=image_data_df, img_column="image")
image_data_df

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
def get_top_N_images(query, data, top_K=4, search_criterion="text", skip_first=None):
    """Return the ``top_K`` rows of ``data`` whose images are most similar to ``query``.

    Args:
        query: A text string when ``search_criterion`` is "text"
            (case-insensitive); otherwise an image.
        data: DataFrame with ``caption``, ``image`` and ``img_embeddings``
            columns. A ``cos_sim`` column is written to it in place.
        top_K: Number of rows to return.
        search_criterion: "text" for text-to-image search; any other value is
            treated as image-to-image search.
        skip_first: Whether to drop the single best match. The default
            ``None`` drops it for image queries only: there the best match is
            typically the query image itself when it comes from ``data``. For
            text queries the best match is a real result and is kept. Pass
            ``False`` for a query image that is not part of ``data``.

    Returns:
        DataFrame with the ``caption``, ``image`` and ``cos_sim`` columns of
        the selected rows, ordered by decreasing similarity, with the original
        index moved into a column by ``reset_index()``.
    """
    is_text_query = search_criterion.lower() == "text"
    if is_text_query:
        # Text-to-image search
        query_vect = get_single_text_embedding(query)
    else:
        # Image-to-image search
        query_vect = get_single_image_embedding(query)

    if skip_first is None:
        skip_first = not is_text_query

    relevant_cols = ["caption", "image", "cos_sim"]

    # cosine_similarity returns a 1x1 matrix for each pair; [0][0] extracts the scalar.
    data["cos_sim"] = data["img_embeddings"].apply(
        lambda emb: cosine_similarity(query_vect, emb)[0][0]
    )

    # Rank by similarity and take top_K rows, optionally skipping the best one.
    start = 1 if skip_first else 0
    ranked = data.sort_values(by="cos_sim", ascending=False)
    most_similar = ranked.iloc[start:start + top_K]
    return most_similar[relevant_cols].reset_index()

In [None]:
def plot_images_by_side(top_images):
    """Show the images in ``top_images`` in a two-column grid.

    Each subplot is titled with the row's caption and its similarity as a
    percentage.

    Args:
        top_images: DataFrame with ``image``, ``caption`` and ``cos_sim``
            columns.
    """
    # Read the columns in row order; this does not depend on the index labels.
    images = list(top_images["image"])
    captions = list(top_images["caption"])
    scores = list(top_images["cos_sim"])

    # Two columns and as many rows as needed (ceiling division), at least one.
    # Four images give the 2x2, 12x12-inch layout.
    n_col = 2
    n_row = max(1, -(-len(images) // n_col))
    _, axs = plt.subplots(n_row, n_col, figsize=(12, 6 * n_row), squeeze=False)
    axs = axs.flatten()

    for ax, img, caption, sim_score in zip(axs, images, captions, scores):
        ax.imshow(img)
        # Round to two decimals, then format the percentage with one decimal,
        # so float artefacts such as 28.999999999999996 never reach the title.
        percent = round(float(sim_score), 2) * 100
        ax.title.set_text(f"Caption: {caption}\nSimilarity: {percent:.1f}%")

    # Hide the axes of any grid cell that has no image.
    for ax in axs[len(images):]:
        ax.axis("off")
    plt.show()

In [None]:
# Free-text query. (A caption taken from the dataset, e.g.
# image_data_df.iloc[10].caption, can be used here instead.)
query_caption = "computer"
# Show the query text
print(f"Query: {query_caption}")
# Text-to-image search over the embedded images
top_images = get_top_N_images(query=query_caption, data=image_data_df)
# Display the retrieved images
plot_images_by_side(top_images=top_images)

In [None]:
# Use the image at position 30 of the frame as the query, and display it.
query_image = image_data_df["image"].iloc[30]
query_image

In [None]:
# Image-to-image search with the query image chosen above
top_images = get_top_N_images(
    query=query_image,
    data=image_data_df,
    search_criterion="image",
)
# Display the retrieved images
plot_images_by_side(top_images=top_images)