In [1]:
import numpy as np
import torch
import pandas as pd
import json
import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import typing as tp

from dataclasses import dataclass
from PIL import Image
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity

from deepfashion import read_splits as read_deepfashion_splits, Crop
from load_model import load_model
from utils import apk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
@dataclass
class Query:
    """One query crop together with the embedding the model produced for it.

    In this notebook, instances are created for crops whose ``source`` is
    ``"user"``.
    """
    # The crop record this entry was built from (as returned by read_deepfashion_splits).
    crop: Crop
    # Key of the item in `items_data` that this crop belongs to.
    item_id: int
    # Row of the model output corresponding to this crop.
    embedding: np.ndarray

@dataclass
class GalleryItem:
    """One gallery crop together with the embedding the model produced for it.

    In this notebook, instances are created for every crop whose ``source``
    is not ``"user"``.
    """
    # The crop record this entry was built from (as returned by read_deepfashion_splits).
    crop: Crop
    # Key of the item in `items_data` that this crop belongs to.
    item_id: int
    # Row of the model output corresponding to this crop.
    embedding: np.ndarray

In [3]:
def save_results(same_category_apks, same_item_apks, same_style_apks, filename):
    """Write the per-k mean of each task's AP values to ``filename`` as JSON.

    Args:
        same_category_apks: Mapping ``k -> list of per-query AP@k`` for the
            same-category task.
        same_item_apks: Same mapping for the same-item task.
        same_style_apks: Same mapping for the same-style task.
        filename: Destination path. Missing parent directories are created.

    The file holds one object with the keys ``"same_category"``,
    ``"same_item"`` and ``"same_style"``, each mapping ``k`` to the mean of
    its list. JSON object keys are strings, so integer ``k`` values are
    written as ``"1"``, ``"5"``, ...
    """
    import os  # local import: the notebook's import cell does not bring in os

    def _mean_per_k(apks_by_k):
        # float(): np.mean returns a NumPy scalar, and some of those
        # (e.g. float32) are rejected by the json encoder.
        return {k: float(np.mean(values)) for k, values in apks_by_k.items()}

    parent_dir = os.path.dirname(filename)
    if parent_dir:
        # Avoid failing only at the very end of a long evaluation because the
        # output folder was never created.
        os.makedirs(parent_dir, exist_ok=True)

    with open(filename, "w") as f:
        json.dump(
            {
                "same_category": _mean_per_k(same_category_apks),
                "same_item": _mean_per_k(same_item_apks),
                "same_style": _mean_per_k(same_style_apks),
            },
            f,
            indent=4,
        )


In [4]:
# Location handed to read_deepfashion_splits; blank in this copy of the
# notebook, so it has to be filled in before running.
DEEP_FASHION_DIR = ""

In [5]:
# Both values are passed straight to load_model.
MODEL_NAME = "convnextv2_base"
# Blank in this copy of the notebook; set it before running.
MODEL_CHECKPOINT = ""

In [6]:
# Cutoffs k passed to apk(); results are collected and saved per k.
ks = [1, 5, 10, 20]

In [7]:
# load_model returns the network and the preprocessing callable that is later
# applied to each PIL image before it is fed to the network.
# NOTE(review): meaning of is_wrapped_checkpoint is defined in load_model.py — confirm there.
model, transform = load_model(MODEL_NAME, MODEL_CHECKPOINT, is_wrapped_checkpoint=True)

In [8]:
# Switch to evaluation mode before extracting embeddings
# (presumably a torch.nn.Module, for which eval() returns the module itself).
model = model.eval()

In [9]:
# Load only the "validation" split. The result is iterated below with
# .items(), yielding (item_id, crops) pairs.
# NOTE(review): the positional False is an option of read_splits not visible here — confirm its meaning.
items_data = read_deepfashion_splits(DEEP_FASHION_DIR, ["validation"], False)

Reading validation split


  2%|▏         | 550/32153 [00:00<00:21, 1440.01it/s]

100%|██████████| 32153/32153 [00:19<00:00, 1659.35it/s]


In [10]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on a machine without one instead of failing at model.to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [11]:
# Place the model on the same device the input batches are moved to.
model = model.to(device)

In [12]:
# Embed every crop of every item and split the results into queries
# (crops whose source is "user") and gallery entries (everything else).
# Both lists are rebuilt from scratch, so re-running the cell is safe.
querries = []
gallery = []

for item_id, crops in tqdm.tqdm(items_data.items()):
    if len(crops) == 0:
        # torch.stack raises on an empty list; an item without crops
        # contributes nothing to either list.
        continue

    transformed = []
    for crop in crops:
        # Image.open is lazy and keeps the file open. The context manager
        # closes it as soon as transform() has consumed the pixels, so file
        # handles do not pile up over the whole split.
        with Image.open(crop.crop_file) as img:
            transformed.append(transform(img))

    # All crops of one item go through the model as a single batch.
    # squeeze(1) only drops axis 1 when it has size 1
    # (presumably a per-image batch axis added by transform — TODO confirm).
    transformed = torch.stack(transformed).squeeze(1).to(device)
    with torch.no_grad():
        embeds = model(transformed).cpu().numpy()

    # Row i of `embeds` belongs to crops[i].
    for i, crop in enumerate(crops):
        if crop.source == "user":
            querries.append(Query(crop, item_id, embeds[i, :]))
        else:
            gallery.append(GalleryItem(crop, item_id, embeds[i, :]))

  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 2279/2279 [07:31<00:00,  5.05it/s]


In [13]:
# One array holding all gallery embeddings; row i corresponds to gallery[i],
# so indices into this array can be used to index `gallery` directly.
gallery_embeds = np.stack([g.embedding for g in gallery])

In [14]:
# Number of gallery entries per item_id.
# NOTE(review): not read by any of the cells visible in this notebook export.
gallery_counts = defaultdict(int)
for g in gallery:
    gallery_counts[g.item_id] += 1

In [64]:
# Rank the whole gallery for every query by cosine similarity and collect
# apk(is_relevant, k) for three notions of relevance:
#   same_category - gallery crop has the query's category_id
#   same_item     - gallery crop has the query's item_id
#   same_style    - same item_id AND same item_style
same_category_apks = defaultdict(list)
same_item_apks = defaultdict(list)
same_style_apks = defaultdict(list)

# Gallery metadata does not change between queries, so read it out of the
# dataclasses once instead of re-walking gallery[gid].crop... for every
# (query, gallery entry) pair.
gallery_category_ids = [entry.crop.category_id for entry in gallery]
gallery_item_ids = [entry.item_id for entry in gallery]
gallery_item_styles = [entry.crop.item_style for entry in gallery]

for q in tqdm.tqdm(querries):
    # csim has shape (1, len(gallery)); negating it makes argsort return
    # gallery indices from most to least similar.
    csim = cosine_similarity(q.embedding[None, :], gallery_embeds)
    gallery_ids = np.argsort(-csim)[0]

    # Relevance flags are plain Python lists in ranking order, exactly what
    # apk() received before.
    same_category = [gallery_category_ids[gid] == q.crop.category_id for gid in gallery_ids]
    same_item = [gallery_item_ids[gid] == q.item_id for gid in gallery_ids]
    # Reuse the same-item flags rather than recomputing them.
    same_style = [
        item_hit and (gallery_item_styles[gid] == q.crop.item_style)
        for item_hit, gid in zip(same_item, gallery_ids)
    ]

    for apks, is_relevant in (
        (same_category_apks, same_category),
        (same_item_apks, same_item),
        (same_style_apks, same_style),
    ):
        for k in ks:
            apks[k].append(apk(is_relevant, k))

100%|██████████| 10844/10844 [30:55<00:00,  5.84it/s]


In [65]:
# Persist the per-k means of the three AP collections as JSON.
save_results(same_category_apks, same_item_apks, same_style_apks, "validation_results/res.json")