In [None]:
import os
import copy
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image, ImageDraw, ImageFont

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as v2
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForObjectDetection, AutoImageProcessor, SwinModel, SwinConfig

from huggingface_hub import PyTorchModelHubMixin

In [None]:
# Pick the fastest available accelerator: CUDA first, then Apple-silicon MPS, else CPU.
# (Previously CUDA was never considered, so CUDA machines silently ran on CPU.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Fashion object detector: the processor prepares inputs and post-processes boxes,
# the model itself is moved onto the selected device.
ckpt = 'yainage90/fashion-object-detection'
detector_image_processor = AutoImageProcessor.from_pretrained(ckpt)
detector = AutoModelForObjectDetection.from_pretrained(ckpt)
detector = detector.to(device)

# Feature-extractor checkpoint: only its processor (normalization stats) and the
# Swin architecture config are loaded here; the weights are loaded with the encoder.
ckpt = "yainage90/fashion-image-feature-extractor"
encoder_image_processor = AutoImageProcessor.from_pretrained(ckpt)
encoder_config = SwinConfig.from_pretrained(ckpt)

class ImageEncoder(nn.Module, PyTorchModelHubMixin):
    """Swin backbone followed by a linear projection to L2-normalized 128-d embeddings.

    The architecture is taken from the module-level ``encoder_config``. The attribute
    names ``swin`` and ``embedding_layer`` define the state-dict keys, so they must not
    be renamed or hub checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()
        self.swin = SwinModel(config=encoder_config)
        self.embedding_layer = nn.Linear(encoder_config.hidden_size, 128)

    def forward(self, image_tensor):
        """Embed a batch of preprocessed images.

        Assumes ``image_tensor`` is a (batch, channels, H, W) tensor — TODO confirm.
        Returns the projected pooler output, normalized to unit L2 norm along dim 1.
        """
        pooled = self.swin(image_tensor).pooler_output
        return F.normalize(self.embedding_layer(pooled), p=2, dim=1)

# `from_pretrained` is a classmethod: call it on the class so no throwaway, randomly
# initialised Swin model is built first. `.eval()` makes sure any dropout / drop-path
# layers in the backbone are inactive, so embeddings are deterministic.
encoder = ImageEncoder.from_pretrained('yainage90/fashion-image-feature-extractor').to(device).eval()

In [None]:
def calculate_iou(box1, box2):
    """Intersection-over-union of two axis-aligned boxes.

    Args:
        box1, box2: Boxes as ``(x_min, y_min, x_max, y_max)`` sequences.

    Returns:
        IoU as a float in ``[0, 1]`` for well-formed boxes. Returns ``0.0`` when the
        union area is not positive (e.g. both boxes are degenerate zero-area boxes),
        which previously raised ``ZeroDivisionError``.
    """
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2

    # Overlap extent on each axis; negative extents mean no overlap.
    inter_w = max(0, min(x2, x4) - max(x1, x3))
    inter_h = max(0, min(y2, y4) - max(y1, y3))
    intersection_area = inter_w * inter_h

    box1_area = (x2 - x1) * (y2 - y1)
    box2_area = (x4 - x3) * (y4 - y3)
    union_area = box1_area + box2_area - intersection_area

    if union_area <= 0:
        return 0.0
    return intersection_area / float(union_area)

def non_max_suppression(items, iou_threshold=0.7):
    """Greedy, class-agnostic non-maximum suppression.

    Args:
        items: Iterable of ``(score, label, box)`` tuples; ``box`` is passed to
            ``calculate_iou``.
        iou_threshold: A candidate is dropped when its IoU with an already-kept item
            is greater than or equal to this value.

    Returns:
        The surviving tuples, highest score first.
    """
    remaining = sorted(items, key=lambda entry: entry[0], reverse=True)

    survivors = []
    while remaining:
        best, *rest = remaining
        survivors.append(best)
        # Keep only candidates that do not overlap the newly kept box too much.
        remaining = [cand for cand in rest if calculate_iou(best[2], cand[2]) < iou_threshold]

    return survivors

def _show_detections(image, items, id2label):
    """Draw each kept detection on a copy of ``image`` and display it with matplotlib."""
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    font = ImageFont.load_default(size=20)
    for score, label, bbox in items:
        x0, y0, x1, y1 = bbox
        draw.rectangle((x0, y0, x1, y1), outline="green", width=3)
        # Clamp the caption so boxes touching the top edge still get a visible label.
        draw.text((x0, max(0, y0 - 25)), f"{id2label[label]} {score:.2f}", fill="red", font=font)

    fig, ax = plt.subplots(figsize=(12, 9))
    ax.imshow(annotated)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.show()
    # Release the figure; this function is called once per query image in a loop.
    plt.close(fig)


def detect_objects(image_path, score_threshold=0.3, iou_threshold=0.7):
    """Detect fashion items in an image, display them, and return one crop per item.

    Args:
        image_path: Path of the image file to analyse.
        score_threshold: Minimum detector confidence passed to
            ``post_process_object_detection`` (default 0.3, the previous hard-coded value).
        iou_threshold: IoU at or above which a lower-scoring box is suppressed by
            ``non_max_suppression`` (default 0.7, as before).

    Returns:
        A list of ``(crop, label_name, score)`` tuples in descending score order, where
        ``crop`` is a PIL image cut from the un-annotated input image.
    """
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    with torch.no_grad():
        inputs = detector_image_processor(images=[image], return_tensors="pt")
        outputs = detector(**inputs.to(device))
        # PIL reports (width, height); the post-processor wants (height, width).
        target_sizes = torch.tensor([[image.height, image.width]])
        results = detector_image_processor.post_process_object_detection(
            outputs, threshold=score_threshold, target_sizes=target_sizes
        )[0]

    id2label = detector.config.id2label
    items = []
    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        score = score.item()
        label = label.item()
        box = [coord.item() for coord in box]
        print(f"{id2label[label]}: {round(score, 3)} at {box}")
        items.append((score, label, box))

    items = non_max_suppression(items, iou_threshold=iou_threshold)

    # Crops are taken from `image`; annotations are drawn on a separate copy,
    # so boxes and labels never leak into the crops fed to the encoder.
    result = [(image.crop(bbox), id2label[label], score) for score, label, bbox in items]
    _show_detections(image, items, id2label)

    return result

In [None]:
# Query-side preprocessing for cropped detections: square resize to the encoder's
# configured input size, convert to a tensor, then normalize with the encoder
# processor's mean/std.
transform = v2.Compose(
    [
        v2.Resize(size=(encoder_config.image_size,) * 2),
        v2.ToTensor(),
        v2.Normalize(encoder_image_processor.image_mean, encoder_image_processor.image_std),
    ]
)

In [None]:
# Catalogue layout: THUMBNAIL_DIR/<category>/<image file>.
THUMBNAIL_DIR = '../crawl/kream_thumbnails'

target_image_paths = []

# Sort both levels: os.listdir order is arbitrary. The row order of the catalogue
# embeddings follows this list, so sorting makes results reproducible across runs.
categories = sorted(
    d for d in os.listdir(THUMBNAIL_DIR)
    if not d.startswith('.') and os.path.isdir(os.path.join(THUMBNAIL_DIR, d))
)
for category in categories:
    category_dir = os.path.join(THUMBNAIL_DIR, category)
    # Skip hidden entries and nested directories, which Image.open cannot read.
    fnames = sorted(
        f for f in os.listdir(category_dir)
        if not f.startswith('.') and os.path.isfile(os.path.join(category_dir, f))
    )
    target_image_paths.extend(os.path.join(category_dir, f) for f in fnames)

In [None]:
class ThumbnailDataset(Dataset):
    """Map-style dataset of catalogue thumbnails, preprocessed for the image encoder.

    Item ``i`` is the image at ``image_paths[i]`` converted to RGB, then passed through
    ``self.transform``:
      1. resize to ``(image_size, image_size)``;
      2. ``ToTensor``;
      3. normalize with the processor's ``image_mean`` / ``image_std``.
    """

    def __init__(self, image_paths, image_processor, image_size=224):
        """
        Args:
            image_paths: Indexable sequence of image file paths.
            image_processor: Object exposing ``image_mean`` and ``image_std``.
            image_size: Side length of the square resize. Defaults to 224, the previously
                hard-coded value. Pass the encoder's configured input size so that
                catalogue and query images are preprocessed identically.
        """
        super().__init__()

        self.image_paths = image_paths
        self.image_processor = image_processor
        self.transform = v2.Compose([
            v2.Resize((image_size, image_size)),
            v2.ToTensor(),
            v2.Normalize(mean=self.image_processor.image_mean, std=self.image_processor.image_std),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, i):
        # The context manager closes the file handle promptly, even if the transform raises.
        with Image.open(self.image_paths[i]) as image:
            return self.transform(image.convert("RGB"))

dataset = ThumbnailDataset(target_image_paths, encoder_image_processor)
# Pinned host memory only speeds up host-to-CUDA copies. On MPS/CPU, PyTorch warns and
# does not pin, so request it only when the target device is CUDA.
# shuffle=False keeps row k of the embeddings aligned with target_image_paths[k].
dataloader = DataLoader(
    dataset=dataset,
    batch_size=128,
    shuffle=False,
    pin_memory=device.type == "cuda",
)

In [None]:
# Embed every catalogue thumbnail once. The loader iterates in dataset order, so row k
# of `target_embeddings` corresponds to target_image_paths[k].
encoder.eval()  # make sure dropout-style layers are inactive so embeddings are deterministic
embedding_batches = []
with torch.inference_mode():
    for images in tqdm(dataloader, desc="embedding thumbnails"):
        # Keep each batch as a numpy array rather than nested Python float lists,
        # which use far more memory and must be re-converted on every query.
        embedding_batches.append(encoder(images.to(device)).cpu().numpy())

# (num_thumbnails, embedding_dim) array; an empty catalogue yields a (0, dim) array.
target_embeddings = (
    np.concatenate(embedding_batches, axis=0)
    if embedding_batches
    else np.empty((0, encoder.embedding_layer.out_features), dtype=np.float32)
)

In [None]:
def find_nearest_neighbors(query_embeddings, target_embeddings, top_k=5):
    """Indices of the ``top_k`` targets closest (Euclidean) to each query.

    Args:
        query_embeddings: Array-like of shape (Q, D), or a single (D,) vector.
        target_embeddings: Array-like of shape (N, D); lists are accepted.
        top_k: Number of neighbours per query. It is capped at N.

    Returns:
        Integer array of shape (Q, min(top_k, N)). Row q lists target indices sorted by
        increasing distance to query q.
    """
    queries = np.asarray(query_embeddings, dtype=np.float64)
    if queries.ndim == 1:
        queries = queries[np.newaxis, :]
    targets = np.asarray(target_embeddings, dtype=np.float64)

    k = min(top_k, len(targets))
    if k <= 0:
        return np.empty((queries.shape[0], 0), dtype=np.intp)

    # ||q - t||^2 = ||q||^2 + ||t||^2 - 2 q.t  -> a (Q, N) matrix, instead of the
    # (Q, N, D) intermediate that direct broadcasting would materialize.
    sq_dist = (
        np.sum(queries ** 2, axis=1)[:, np.newaxis]
        + np.sum(targets ** 2, axis=1)[np.newaxis, :]
        - 2.0 * (queries @ targets.T)
    )
    np.maximum(sq_dist, 0.0, out=sq_dist)  # clip tiny negatives from rounding

    n_targets = targets.shape[0]
    if k < n_targets:
        # O(N) selection of the k smallest per row; only those k are then sorted.
        candidates = np.argpartition(sq_dist, k - 1, axis=1)[:, :k]
    else:
        candidates = np.broadcast_to(np.arange(n_targets), (queries.shape[0], n_targets))

    candidate_dist = np.take_along_axis(sq_dist, candidates, axis=1)
    order = np.argsort(candidate_dist, axis=1, kind="stable")
    return np.take_along_axis(candidates, order, axis=1)

In [None]:
def visualize_query_and_results(query_images, target_image_paths, nearest_indices, n_results=5):
    """Plot each query crop followed by its nearest catalogue thumbnails.

    Args:
        query_images: Sequence of images (anything ``imshow`` accepts), one per row.
        target_image_paths: Paths indexed by the values in ``nearest_indices``.
        nearest_indices: 2-D integer array; row i holds ranked target indices for query i.
        n_results: Number of result columns to draw. Panels beyond the available
            neighbours are left blank instead of raising ``IndexError``.
    """
    n_queries = len(query_images)
    if n_queries == 0:
        return  # plt.subplots(0, ...) would raise; nothing to draw.

    nearest_indices = np.asarray(nearest_indices)
    n_available = nearest_indices.shape[1] if nearest_indices.ndim == 2 else 0

    # squeeze=False always yields a 2-D axes grid, so one- and many-query cases share code.
    fig, axes = plt.subplots(
        n_queries, n_results + 1,
        figsize=(3 * (n_results + 1), 3 * n_queries),
        squeeze=False,
    )

    for i, query_image in enumerate(query_images):
        row = axes[i]
        row[0].imshow(query_image)
        row[0].set_title(f"Query {i+1}")
        row[0].axis('off')

        for j in range(n_results):
            ax = row[j + 1]
            ax.axis('off')
            if j >= n_available:
                continue
            result_idx = nearest_indices[i, j]
            with Image.open(target_image_paths[result_idx]) as result_image:
                ax.imshow(result_image.convert('RGB'))
            ax.set_title(f"Result {j+1}")

    plt.tight_layout()

    if n_results > 0:
        # Vertical separator between the query column and the results, drawn in figure
        # coordinates. Previously an axvline was drawn on the first result's own axes,
        # which left a stray line over that thumbnail.
        x_sep = (axes[0, 0].get_position().x1 + axes[0, 1].get_position().x0) / 2
        fig.add_artist(plt.Line2D([x_sep, x_sep], [0, 1], transform=fig.transFigure,
                                  color='black', linewidth=2))

    plt.show()
    plt.close(fig)

In [None]:
query_image_fnames = [f for f in os.listdir('../media') if f.startswith('query_image')]
query_image_paths = sorted([f'../media/{f}' for f in query_image_fnames])

for query_image_path in query_image_paths:
    result = detect_objects(query_image_path)
    cropped_images = [r[0] for r in result]
    if not cropped_images:
        # torch.stack([]) would raise; there is nothing to search for this image.
        print(f"No objects detected in {query_image_path}; skipping.")
        continue

    cropped_image_tensors = torch.stack([transform(image) for image in cropped_images])
    with torch.inference_mode():
        # Go straight to numpy instead of via a Python list round-trip.
        cropped_image_embeddings = encoder(cropped_image_tensors.to(device)).cpu().numpy()

    nearest_indices = find_nearest_neighbors(cropped_image_embeddings, target_embeddings)
    visualize_query_and_results(cropped_images, target_image_paths, nearest_indices, 5)