In [None]:
import os
import cv2
import json
import time
import torch
import open_clip
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from rich import print as rprint
from sentence_transformers import util

In [None]:
# Show the pretrained checkpoints that open_clip reports, as a reference when
# choosing the model_name / pretrained values used to load the model.
open_clip.list_pretrained()

### Load model

In [None]:
load_start_time = time.time()
model_name = "ViT-H-14"
# model_name="ViT-B-32"
pretrained = "laion2b_s32b_b79k"
# pretrained = "laion400m_e32"

# Prefer CUDA, then Apple's MPS backend, and finally fall back to the CPU so the
# notebook still runs on machines without any accelerator.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

model, _, preprocess = open_clip.create_model_and_transforms(
    model_name, pretrained
)
model = model.to(device)
# The model is only used for inference here; eval mode disables training-only
# behaviour (e.g. dropout / batch-norm updates) so embeddings are deterministic.
model.eval()
rprint(f"Model loaded in {time.time() - load_start_time:.2f} seconds")

### Load images
- base images: images of the product we're trying to generate a listing for, taken by us.
- item images: images of eBay listings, taken from the eBay dataset, grouped into folders named by the item ID.

In [None]:
base_image_folder = "./data/base_images"
# Sort for a reproducible ordering (os.listdir returns entries in arbitrary
# order) and skip hidden files such as .DS_Store as well as sub-directories,
# neither of which can be decoded as an image.
base_image_paths = [
    os.path.join(base_image_folder, img_name)
    for img_name in sorted(os.listdir(base_image_folder))
    if not img_name.startswith(".")
    and os.path.isfile(os.path.join(base_image_folder, img_name))
]

def load_and_preprocess_images(image_paths, batch_size=256):
    """Yield batches of preprocessed image tensors for the given paths.

    Images are read with OpenCV, converted from BGR to RGB, run through the
    module-level ``preprocess`` transform and stacked into one tensor per
    batch, which is moved to the module-level ``device``.

    Args:
        image_paths: Sequence of image file paths.
        batch_size: Maximum number of images per yielded batch (must be >= 1).

    Yields:
        One stacked tensor per batch, in the same order as ``image_paths``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 or an image cannot be
            read from disk.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    for start in range(0, len(image_paths), batch_size):
        batch = []
        for img_path in image_paths[start : start + batch_size]:
            bgr = cv2.imread(img_path)
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside cvtColor.
            if bgr is None:
                raise ValueError(f"Could not read image: {img_path}")
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            batch.append(preprocess(Image.fromarray(rgb)))
        yield torch.stack(batch).to(device)


def encode_images(image_paths, batch_size=256):
    """Encode images into embeddings with the module-level ``model``.

    Args:
        image_paths: A single path or any iterable of paths (str or
            os.PathLike) to the images to encode.
        batch_size: Number of images pushed through the model at once.

    Returns:
        A tensor with one row per input image, in input order.

    Raises:
        ValueError: If no image paths are provided.
    """
    # Accept a lone path or any iterable instead of silently swapping the
    # caller's input for a different set of images.
    if isinstance(image_paths, (str, os.PathLike)):
        image_paths = [image_paths]
    image_paths = [os.fspath(path) for path in image_paths]

    print(f"Encoding {len(image_paths)} images")
    if not image_paths:
        raise ValueError("No image paths were provided to encode")

    encoded_batches = []
    n_encoded = 0
    for batch_imgs in load_and_preprocess_images(image_paths, batch_size):
        with torch.no_grad():
            encoded_batches.append(model.encode_image(batch_imgs))
        # Count the real batch size so the final (possibly partial) batch is
        # not over-reported.
        n_encoded += batch_imgs.shape[0]
        print(f"Encoded {n_encoded} images")
    return torch.cat(encoded_batches)

In [None]:
# Embed the base images and report how many rows/columns the result has.
encoded_images = encode_images(base_image_paths)
n_images, n_features = encoded_images.shape[:2]
rprint(f"images: {n_images}, features/image: {n_features}")

### Example: cross-similarity for source images
Here we build a matrix from the base (source) images, comparing each image with every other image using cosine similarity. As expected, the diagonal is all 1s, since each image is being compared with itself.

In [None]:
# Cosine similarity between every pair of base-image embeddings, as a NumPy array.
similarity_matrix = util.pytorch_cos_sim(encoded_images, encoded_images).cpu().numpy()

n_base = len(base_image_paths)
tick_positions = np.arange(n_base)
tick_labels = list(range(n_base))

fig, ax = plt.subplots()
heatmap = ax.imshow(similarity_matrix, cmap='Blues', interpolation="nearest")
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels)
ax.set_yticks(tick_positions)
ax.set_yticklabels(tick_labels)
fig.colorbar(heatmap, ax=ax)
ax.set_title("Similarity Matrix")
# Write each score into its cell (row-major order).
for row, col in np.ndindex(n_base, n_base):
    ax.text(
        col, row, f"{similarity_matrix[row, col]:.2f}", ha="center", va="center", color="black"
    )
plt.show()

In [None]:
item_image_folders_path = "./data/item_images"
# Keep only real sub-folders: a stray file (e.g. .DS_Store) in the root would
# otherwise make the os.listdir() call below fail. Sorting keeps the ordering
# reproducible between runs.
item_image_folders = sorted(
    name
    for name in os.listdir(item_image_folders_path)
    if not name.startswith(".")
    and os.path.isdir(os.path.join(item_image_folders_path, name))
)
image_paths = []
for folder in item_image_folders:
    folder_path = os.path.join(item_image_folders_path, folder)
    # Skip hidden files and nested directories, which cannot be decoded as images.
    image_paths.extend(
        os.path.join(folder_path, img)
        for img in sorted(os.listdir(folder_path))
        if not img.startswith(".") and os.path.isfile(os.path.join(folder_path, img))
    )
print(len(image_paths))
encoded_item_images = encode_images(image_paths)

In [None]:
# The base images were already listed (base_image_paths) and embedded
# (encoded_images) earlier in the notebook; reuse those embeddings instead of
# re-listing the folder and running the model on the same images a second time.
encoded_base_images = encoded_images

In [None]:
# Sanity check: shape of the base-image embedding tensor.
encoded_base_images.shape

In [None]:
# Sanity check: shape of the item-image embedding tensor.
encoded_item_images.shape

In [None]:
# Cosine similarity of every item image (rows) against every base image (columns).
# NOTE: this rebinds `similarity_matrix`, which previously held the
# base-vs-base NumPy matrix used for the heat-map.
similarity_matrix = util.pytorch_cos_sim(encoded_item_images, encoded_base_images)

In [None]:
# Alias under a clearer name: these are the paths that were passed to
# encode_images to produce encoded_item_images.
item_image_paths = image_paths

In [None]:
# similarity_matrix.shape -> torch.Size([2263, 4])
# The first dimension is the item images we've provided. They will match the order of item_image_paths (i.e. item_image_paths[0] will match similarity_matrix[0])
# The second dimension is the base images we've provided. They will match the order of base_image_paths (i.e. base_image_paths[0] will match similarity_matrix[:, 0])

In [None]:
# For each base image (column), find the row index of the most similar item image.
max_indices = torch.argmax(similarity_matrix, dim=0)


def _read_rgb(image_path):
    """Read an image with OpenCV and return it as an RGB array for matplotlib."""
    bgr = cv2.imread(image_path)
    # cv2.imread returns None on failure; report the offending path instead of
    # failing inside cvtColor.
    if bgr is None:
        raise ValueError(f"Could not read image: {image_path}")
    # OpenCV loads images as BGR, while matplotlib expects RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


# Show every base image next to its closest item image.
for base_target, base_image_path in enumerate(base_image_paths):
    best_index = int(max_indices[base_target])
    best_item_path = item_image_paths[best_index]
    # Item images are grouped in folders named by item ID, so the parent
    # folder name identifies the matching listing (portable across OSes).
    item_id = os.path.basename(os.path.dirname(best_item_path))
    score = similarity_matrix[best_index, base_target].item()
    closest_label = f"Closest Image {item_id}"
    base_label = f"Base Image (score={score:.5f})"

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(_read_rgb(best_item_path))
    ax[0].set_title(closest_label)
    ax[1].imshow(_read_rgb(base_image_path))
    ax[1].set_title(base_label)
    plt.show()
    # Release the figure so looping over many base images does not leak memory.
    plt.close(fig)
    print(closest_label)
    print(base_label)
    