<a href="https://colab.research.google.com/github/softmurata/colab_notebooks/blob/main/computervision/dreamSim.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Installation

In [None]:
!pip install dreamsim

Download weights

In [None]:
!mkdir models/
!wget -O models/open_clip_vitb32_pretrain.pth.tar https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/open_clip_vitb32_pretrain.pth.tar

Demo Images

In [None]:
!mkdir /content/images
!wget https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/sample_images.zip -O images/sample_images.zip
!wget https://github.com/ssundaram21/dreamsim/releases/download/v0.1.0/retrieval_images.zip -O images/retrieval_images.zip
!unzip images/sample_images.zip
!unzip images/retrieval_images.zip

Load model

In [None]:
import sys
import torch
from dreamsim import dreamsim

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Pass the chosen device explicitly; without it the library falls back to its own
# default device regardless of the availability check above.
model, preprocess = dreamsim(pretrained=True, device=device)

Utility functions

In [5]:
import matplotlib.pyplot as plt

def show_imgs(ims, captions=None):
    """Display images side by side in a single row.

    Args:
        ims: Sequence of images, each something ``Axes.imshow`` accepts.
            May contain a single image.
        captions: Optional sequence of titles, indexed in parallel with
            ``ims``. When ``None`` no titles are drawn.

    Returns:
        None. The figure is created through ``plt.subplots``.
    """
    # squeeze=False makes subplots always return a 2-D grid of axes. With the
    # default squeezing, a single column yields a bare Axes object and
    # indexing it fails.
    fig, axes = plt.subplots(nrows=1, ncols=len(ims), figsize=(10, 5),
                             squeeze=False)
    for i, im in enumerate(ims):
        panel = axes[0][i]
        panel.imshow(im)
        panel.axis('off')
        if captions is not None:
            panel.set_title(captions[i], fontweight="bold")

Similarity Search

In [6]:
from PIL import Image
import torch
import matplotlib.pyplot as plt

In [None]:
# Open the first sample triplet: the reference image and its two candidates.
sample_paths = (
    "sample_images/ref_1.png",
    "sample_images/img_a_1.png",
    "sample_images/img_b_1.png",
)
ref_pil, img_a_pil, img_b_pil = map(Image.open, sample_paths)

# The reference is shown in the middle, flanked by candidates A and B.
triplet = [img_a_pil, ref_pil, img_b_pil]
labels = ["A", "Reference", "B"]
show_imgs(ims=triplet, captions=labels)

In [None]:
# Preprocess each PIL image and move the result onto the model's device.
ref, img_a, img_b = (
    preprocess(pil).to(device) for pil in (ref_pil, img_a_pil, img_b_pil)
)

# Evaluate the model on (reference, A) and (reference, B).
dist_a, dist_b = model(ref, img_a), model(ref, img_b)

# Round the model outputs to three decimals for the panel titles.
score_a = round(float(dist_a.cpu()), 3)
score_b = round(float(dist_b.cpu()), 3)

show_imgs(
    ims=[img_a_pil, ref_pil, img_b_pil],
    captions=[f"A, Score: {score_a}", "Reference", f"B, Score: {score_b}"],
)

Inference on your own data

In [9]:
#@markdown Required:
ref_path = "sample_images/ref_2.png" #@param {type:"string"}
img_a_path = "sample_images/img_a_2.png" #@param {type:"string"}
#@markdown Optional:
# Leave img_b_path as an empty string to compare only the reference with image A;
# the next cell checks len(img_b_path) > 0 before loading a second candidate.
img_b_path = "sample_images/img_b_2.png" #@param {type:"string"}

In [None]:
# The reference and image A are always required: load, preprocess and score them.
ref_pil, img_a_pil = Image.open(ref_path), Image.open(img_a_path)
ref, img_a = preprocess(ref_pil).to(device), preprocess(img_a_pil).to(device)
dist_a = model(ref, img_a)
score_a = round(float(dist_a.cpu()), 3)

if len(img_b_path) == 0:
  # No second candidate was given: show the reference next to image A only.
  ims = [ref_pil, img_a_pil]
  captions = ["Reference", f"Score: {score_a}"]
else:
  # A second candidate was given: score it against the same reference and
  # show all three with the reference in the middle.
  img_b_pil = Image.open(img_b_path)
  img_b = preprocess(img_b_pil).to(device)
  dist_b = model(ref, img_b)
  score_b = round(float(dist_b.cpu()), 3)
  ims = [img_a_pil, ref_pil, img_b_pil]
  captions = [f"A, Score: {score_a}", "Reference", f"B, Score: {score_b}"]

show_imgs(ims=ims, captions=captions)

In [11]:
import gc
# Force a full garbage-collection pass before the retrieval section. The integer
# echoed under the cell is gc.collect()'s return value (unreachable objects found).
gc.collect()

24670

Image Retrieval

In [12]:
import os
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
import pandas as pd
import pickle

root = "retrieval_images/"
images = []
# Sort the listing: os.listdir() returns entries in arbitrary order, and later
# cells address `images` by position, so an unsorted listing makes results
# differ between machines and runs.
for path in sorted(os.listdir(root)):
  try:
    images.append(Image.open(os.path.join(root, path)))
  except Exception as err:
    # Best effort: entries that cannot be opened as images are skipped, but
    # reported instead of being swallowed silently.
    print(f"Skipping {path}: {err}")
# The first image is split off as `query`; the remainder is the search set.
# NOTE(review): `query` is not used again in the cells visible here, which pick
# the query via an index into `images` instead.
query, images = images[0], images[1:]

In [13]:
# comparison with 3 models
from dreamsim import PerceptualModel

dreamsim_model = model
# Build both baseline models on the notebook's `device` instead of a hard-coded
# "cuda", so they follow the same GPU/CPU choice as the inputs moved with .to(device).
dino_model = PerceptualModel(feat_type='cls', model_type='dino_vitb16', stride='16', baseline=True, device=device)
open_clip_model = PerceptualModel(feat_type='embedding', model_type='open_clip_vitb32', stride='32', baseline=True, device=device)

Using cache found in ./models/facebookresearch_dino_main


In [None]:
# prepare for embedding
def get_embeddings(model, name, images):
  """Embed every image with ``model`` and pickle the list of embeddings.

  Args:
    model: Object exposing ``embed(tensor)``.
    name: Prefix of the output file, written to ``images/{name}_embeds.pkl``.
    images: Iterable of images accepted by the notebook-level ``preprocess``.

  Relies on the notebook globals ``preprocess`` and ``device``.

  Returns:
    None. The embeddings are only written to disk.
  """
  embeddings = []
  # The embeddings are stored, never backpropagated through, so autograd
  # tracking is switched off while they are computed.
  with torch.no_grad():
    for img in tqdm(images):
      img = preprocess(img).to(device)
      embeddings.append(model.embed(img).detach().cpu())
  # Make sure the relative output directory exists before writing into it.
  os.makedirs("images", exist_ok=True)
  with open(f"images/{name}_embeds.pkl", "wb") as f:
    pickle.dump(embeddings, f)

# Write one embeddings pickle per model, keyed by the metric name.
for _embedder, _metric in (
    (dreamsim_model, "dreamsim"),
    (dino_model, "dino"),
    (open_clip_model, "open_clip"),
):
  get_embeddings(_embedder, _metric, images)

In [15]:
def nearest_neighbors(embeddings, query_index):
    """Rank every other embedding by cosine distance to the query embedding.

    Args:
        embeddings: Indexable collection of embedding tensors.
        query_index: Position of the query inside ``embeddings``; that
            position is left out of the results.

    Returns:
        A DataFrame with columns ``ids`` (position in ``embeddings``) and
        ``dists`` (1 - cosine similarity along the last dimension), sorted by
        ascending ``dists``.
    """
    anchor = embeddings[query_index]

    # Every position except the query itself is a candidate neighbour.
    neighbor_ids = [idx for idx in range(len(embeddings)) if idx != query_index]

    # Cosine distance between the query embedding and each candidate.
    distances = [
        (1 - F.cosine_similarity(anchor, embeddings[idx], dim=-1)).item()
        for idx in neighbor_ids
    ]

    # Closest candidates first.
    table = pd.DataFrame({"ids": neighbor_ids, "dists": distances})
    return table.sort_values(by="dists")

In [None]:
#@title query
query_index = 15 #@param {type:"number"}
n = 3
display_width = 11
display_height = 4

# Metrics to compare; each has a pickle written by get_embeddings.
metric_names = ["dreamsim", "open_clip", "dino"]

## Load embeddings for each metric and compute nearest neighbors to the query_index-th image
nn_dfs = {}
for metric_name in metric_names:
    with open(f"images/{metric_name}_embeds.pkl", "rb") as embed_file:
        embeddings = pickle.load(embed_file)
    nn_dfs[metric_name] = nearest_neighbors(embeddings, query_index)

## Plot results
# Grid layout: a thin header row, then one row per metric; the first column
# holds the metric label, the second the query, the rest the n nearest images.
f, ax = plt.subplots(4, n+2, figsize=(14,7), gridspec_kw={"height_ratios":[0.005,1,1,1]})

# Header row: blank corner, then "Query", "n1", ..., one title per column.
ax[0,0].axis('off')
headers = ["Query"] + [f"n{rank}" for rank in range(1, n + 1)]
for col, title in enumerate(headers, start=1):
    header_ax = ax[0, col]
    header_ax.set_title(title, fontweight="bold", fontsize=15)
    header_ax.axis('off')

for row, name in enumerate(metric_names, start=1):
    # Metric label in the first column.
    label_ax = ax[row, 0]
    label_ax.text(0.5, 0.5, name, fontsize=13)
    label_ax.axis('off')

    # The query image is repeated on every row for side-by-side comparison.
    query_ax = ax[row, 1]
    query_ax.imshow(images[query_index])
    query_ax.axis("off")

    # The n closest images according to this metric.
    for j in range(n):
        im_idx = nn_dfs[name]['ids'].iloc[j]
        neighbor_ax = ax[row, j + 2]
        neighbor_ax.imshow(images[im_idx])
        neighbor_ax.axis('off')
plt.tight_layout()

Perceptual Loss

In [21]:
# Path of the image loaded as the reference (`ref`) for the optimisation below.
ref_img_path = "sample_images/ref_1.png"

In [22]:
import numpy as np
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image

In [23]:
def crop(img, sizex, sizey, crop_margin_x, crop_margin_y):
  """Cut a randomly positioned window out of the last two dimensions of ``img``.

  Args:
    img: 4-D indexable array/tensor; the window is taken from dims 2 and 3.
    sizex: Window length along dim 2.
    sizey: Window length along dim 3.
    crop_margin_x: Exclusive upper bound for the random start offset along
      dim 2. A margin of 0 means "no jitter" and the window starts at 0.
    crop_margin_y: Same as ``crop_margin_x`` for dim 3.

  Returns:
    ``img[:, :, startx:startx + sizex, starty:starty + sizey]``.
  """
  # np.random.randint(0) raises ValueError, so a zero margin is handled
  # explicitly; positive margins draw offsets exactly as before (x, then y).
  startx = np.random.randint(crop_margin_x) if crop_margin_x > 0 else 0
  starty = np.random.randint(crop_margin_y) if crop_margin_y > 0 else 0
  endx = startx + sizex
  endy = starty + sizey
  return img[:, :, startx:endx, starty:endy]

In [None]:
sizex, sizey, crop_margin_x, crop_margin_y = 224, 224, 32, 32
ref = preprocess(Image.open(ref_img_path)).to(device)

# Random start image, larger than the crop size by the margins so each step can
# optimise a different random crop. It lives on `device` (not a hard-coded
# CUDA device) and is a leaf tensor, so the optimiser can update it directly.
pred = torch.rand([1, 3, sizex+crop_margin_x, sizey+crop_margin_y])
pred = pred.to(device).requires_grad_(True)

from tqdm import tqdm

# The DreamSim `model` output is used directly as the loss to minimise.
optimizer = torch.optim.Adam([pred], lr=1e-2)
plt.ion()
fig = plt.figure(1)
ax = fig.add_subplot(131)
ax.imshow(ref[0].permute(1,2,0).detach().cpu( ))
ax.axis('off')
ax.set_title('Reference')
ax = fig.add_subplot(132)
ax.imshow(pred[0].permute(1, 2, 0).detach().cpu())
ax.axis('off')
ax.set_title('Initialization')

for i in tqdm(range(2500)):
    optimizer.zero_grad()
    # A fresh random crop of the optimised image is compared with the reference.
    pred_inpt = crop(pred, sizex, sizey, crop_margin_x, crop_margin_y)
    dist = model(pred_inpt, ref)
    dist.backward()
    optimizer.step()

ax = fig.add_subplot(133)
# Adam does not keep pixel values inside [0, 1]; clamp for display only.
ax.imshow(pred[0].permute(1, 2, 0).detach().clamp(0, 1).cpu())
ax.axis('off')
ax.set_title('Output')