In [1]:
import torch
import torch.nn as nn
import clip
from PIL import Image
import pandas as pd
import requests
import os.path as osp
import pickle
import random
import numpy as np
from pathlib import Path
import sys
from operator import itemgetter
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import time
import shutil

  warn(f"Failed to load image Python extension: {e}")


In [2]:
def read_pickle(dir):
    """Load and return the object pickled at path ``dir``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(dir, "rb") as fh:
        return pickle.load(fh)


def write_pickle(dir, data):
    """Serialise ``data`` to path ``dir`` with the highest available pickle protocol."""
    protocol = pickle.HIGHEST_PROTOCOL
    with open(dir, "wb") as fh:
        pickle.dump(data, fh, protocol=protocol)
        

class Timer:
    """Context manager that prints how long its ``with`` block took.

    Usage::

        with Timer():
            work()
        # prints e.g. "1 minutes, 22.016 seconds elapsed."
    """

    def __init__(self):
        # Start timestamp (a time.perf_counter() value), set by __enter__.
        self.t1 = None

    @staticmethod
    def delta_to_string(td):
        """Format a duration as "[D days, ][H hours, ][M minutes, ]S seconds elapsed.".

        Seconds are rounded to 3 decimals. Minutes, hours and days are whole numbers and
        appear only once the duration reaches them; a smaller unit is still printed (as 0)
        when a larger one is present.

        Args:
            td: duration in seconds (int or float, expected non-negative).

        Returns:
            The formatted string, largest unit first.
        """
        parts = [f"{round(td % 60, 3)} seconds"]
        # int() so the larger units print as "1 minutes" rather than "1.0 minutes"
        # (floor division of a float keeps it a float).
        remaining = int(td // 60)
        if remaining > 0:
            parts.append(f"{remaining % 60} minutes")
            remaining //= 60
        if remaining > 0:
            parts.append(f"{remaining % 24} hours")
            remaining //= 24
        if remaining > 0:
            parts.append(f"{remaining} days")
        return ", ".join(reversed(parts)) + " elapsed."

    def __enter__(self):
        # perf_counter is monotonic, so the measurement is unaffected by system clock changes.
        self.t1 = time.perf_counter()
        # Return self so `with Timer() as t:` binds the timer instead of None.
        return self

    def __exit__(self, *args, **kwargs):
        elapsed = time.perf_counter() - self.t1
        print(self.delta_to_string(elapsed))
        # Returns None, so exceptions raised inside the block still propagate.


def top_n(input_dict, n):
    """Return a dict of the ``n`` entries of ``input_dict`` with the largest values.

    Entries are ordered from largest to smallest value; ties keep their original order.
    """
    ranked = sorted(input_dict.items(), key=lambda kv: kv[1], reverse=True)
    return dict(ranked[:n])


def find_products(text_input, category_df, image_pickle_path, n_categories=10, n_results=5):
    """Two-stage CLIP retrieval: text query -> best categories -> best product images.

    Stage 1 scores every category in ``category_df`` by cosine similarity between its
    ``encoded_category`` embedding and the CLIP text embedding of ``text_input``, and keeps
    the ``n_categories`` highest. Stage 2 loads ``<image_pickle_path>/<category>.pkl`` for
    each kept category (pickled DataFrames with ``uid`` and ``encoded_image`` columns),
    scores every image embedding against the same text embedding, and keeps the
    ``n_results`` highest.

    NOTE: relies on the notebook-level globals ``model`` (a loaded CLIP model) and ``device``.

    Args:
        text_input: free-text query string.
        category_df: DataFrame with ``category`` and ``encoded_category`` columns; rows
            whose embedding is missing are ignored.
        image_pickle_path: directory holding one pickled DataFrame per category.
        n_categories: number of top categories searched in stage 2 (default 10).
        n_results: number of product uids returned (default 5).

    Returns:
        set of ``uid`` values of the best-matching images (unordered).
    """
    # Stage 1: rank categories against the query.
    category_df = category_df[~category_df["encoded_category"].isna()]
    categories = list(category_df["category"].values)
    category_features = torch.stack(list(category_df["encoded_category"].values))
    tokens = clip.tokenize([text_input]).to(device)

    with torch.no_grad():
        text_features = _l2_normalize(model.encode_text(tokens))
        category_scores = 100 * _l2_normalize(category_features) @ text_features.T
    top_categories = _top_keys(categories, category_scores, n_categories)

    # Stage 2: rank the images belonging to the selected categories.
    image_df = pd.concat(
        [read_pickle(osp.join(image_pickle_path, f"{cat}.pkl")) for cat in top_categories],
        axis=0,
    )
    uids = list(image_df["uid"].values)
    image_features = torch.stack(list(image_df["encoded_image"].values))
    with torch.no_grad():
        # Normalise the image embeddings as well, so stage 2 is a true cosine similarity
        # like stage 1 (a no-op if they were already stored normalised).
        image_scores = 100 * _l2_normalize(image_features) @ text_features.T

    return set(_top_keys(uids, image_scores, n_results))


def _l2_normalize(features):
    """Return ``features`` with every vector along the last dimension scaled to unit L2 norm."""
    return features / features.norm(dim=-1, keepdim=True)


def _top_keys(keys, scores, n):
    """Return the ``n`` keys with the highest score, best first.

    ``scores`` is a tensor holding one score per key (any shape that flattens to
    ``len(keys)``). Duplicate keys keep their last score.
    """
    scored = dict(zip(keys, scores.reshape(-1).tolist()))
    ranked = sorted(scored.items(), key=itemgetter(1), reverse=True)
    return [key for key, _ in ranked[:n]]


def show_images(res):
    """Display product images side by side in a single row.

    Args:
        res: iterable of product uids (e.g. the set returned by ``find_products``); each
            uid is resolved to a file with ``image_path``. Nothing is drawn if it is empty.
    """
    # Materialise first: sets (as returned by find_products) cannot be indexed.
    uids = list(res)
    n = len(uids)
    if n == 0:
        return

    # squeeze=False always yields a 2-D array of Axes, so one code path handles n == 1 too.
    fig, axes = plt.subplots(1, n, squeeze=False)
    fig.set_figheight(5)
    fig.set_figwidth(5 * n)

    for ax, uid in zip(axes[0], uids):
        ax.imshow(mpimg.imread(image_path(uid)))
        ax.axis("off")

    plt.subplots_adjust(wspace=0, hspace=0.1)
    plt.show()
    
    
def image_path(uid):
    """Return the path of the ``<uid>.jpg`` image inside the ``image_storage`` directory."""
    file_name = f"{uid}.jpg"
    return osp.join(image_storage, file_name)


def load_data(pickle_path):
    """Load the pickled lookup tables and the CLIP ViT-B/32 model.

    Args:
        pickle_path: directory containing ``categories.pkl`` and ``meta_data.pkl``.

    Returns:
        Tuple ``(device, model, preprocess, category_df, meta_df)``; ``device`` is
        ``"cuda"`` when available, otherwise ``"cpu"``.
    """
    def _load(file_name):
        return read_pickle(osp.join(pickle_path, file_name))

    category_df = _load("categories.pkl")
    meta_df = _load("meta_data.pkl")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    return device, model, preprocess, category_df, meta_df

In [3]:
def combine_category(x):
    """Join a row's non-missing ``sub_category_1..3`` values with ", ".

    Args:
        x: a row (Series or mapping) with ``sub_category_1``, ``sub_category_2`` and
            ``sub_category_3`` entries.

    Returns:
        e.g. ``"Home Office Furniture, Bookshelves & Bookcases"``. If all three are
        missing, the (missing) ``sub_category_1`` value is returned unchanged.
    """
    parts = [
        x[col]
        for col in ("sub_category_1", "sub_category_2", "sub_category_3")
        if not pd.isna(x[col])
    ]
    if not parts:
        return x["sub_category_1"]
    # Joining only the present parts also handles a missing sub_category_1 alongside present
    # lower levels, which previously raised TypeError (float + str).
    return ", ".join(parts)

In [4]:
def clean_data(d1):
    """Keep usable furniture rows and add a ``combined_category`` column.

    A row is kept when ``primary_category == "Furniture"``, at least one of
    ``sub_category_1..3`` is present, and ``description`` and ``colors`` are present.
    ``colors`` is cast to ``str``.

    Args:
        d1: raw product DataFrame; it is not modified.

    Returns:
        A new filtered DataFrame (original index preserved).
    """
    sub_cols = ["sub_category_1", "sub_category_2", "sub_category_3"]
    keep = (
        (d1["primary_category"] == "Furniture")
        & d1[sub_cols].notna().any(axis=1)
        & d1["description"].notna()
        & d1["colors"].notna()
    )
    # Explicit copy so the column assignments below do not hit SettingWithCopyWarning.
    cleaned = d1.loc[keep].copy()
    cleaned["colors"] = cleaned["colors"].astype(str)
    cleaned["combined_category"] = cleaned.apply(combine_category, axis=1)
    return cleaned

In [5]:
def comb(x):
    """Build a text query from a product row: "<colors> <material> <sub_category_2>".

    Missing fields are skipped, so there are no leading or doubled spaces. For example, a
    row without ``colors`` yields ``"Wood Bookcases"`` rather than ``" Wood Bookcases"``.
    Returns ``""`` when all three are missing.
    """
    parts = [x[col] for col in ("colors", "material", "sub_category_2") if not pd.isna(x[col])]
    return " ".join(parts)

In [22]:
def is_match(answer, res, lookup_df=None):
    """Return True if any retrieved product shares (sub_category_2, colors) with the answer.

    Args:
        answer: ``uniq_id`` of the ground-truth product.
        res: iterable of retrieved ``uniq_id`` values (e.g. the set from find_products).
        lookup_df: DataFrame with ``uniq_id``, ``sub_category_2`` and ``colors`` columns.
            Defaults to the notebook-level ``d10`` frame.
    """
    if lookup_df is None:
        lookup_df = d10
    key_cols = ["sub_category_2", "colors"]

    def _keys(ids):
        # Set of (sub_category_2, colors) pairs for the given ids.
        rows = lookup_df.loc[lookup_df["uniq_id"].isin(list(ids)), key_cols]
        return set(rows.itertuples(index=False, name=None))

    return not _keys([answer]).isdisjoint(_keys(res))

In [14]:
# Inputs for the demo evaluation, relative to the notebook's working directory.
dataset_path = "data/cleaned_target_furniture_dataset.csv"  # product table read with pd.read_csv
pickle_path = "demo_data/data3_pickle"                      # holds categories.pkl and meta_data.pkl
image_pickle_path = "demo_data/data3_image_pickle"          # one <category>.pkl per category
image_storage = "demo_data/target_images"                   # <uid>.jpg product images

In [15]:
# Destination for the pickled evaluation summary (dict with total / right / right_set).
eval_path = "demo_data/eval_res2_comb_no_finetune.pkl"

In [9]:
# Load the category / metadata tables and the CLIP model, timing the (slow) model load.
with Timer():
    device, model, preprocess, category_df, meta_df = load_data(pickle_path)

1.0 minutes, 22.016 seconds elapsed.


In [23]:
# Load and clean the full product table.
products_clean = clean_data(pd.read_csv(dataset_path))

# Lookup table consulted by is_match() through the global name `d10`.
d10 = products_clean[["uniq_id", "sub_category_2", "colors"]]

# Evaluation set: products of a single category, with a text query built by comb().
d1 = products_clean.loc[
    products_clean["combined_category"] == "Home Office Furniture, Bookshelves & Bookcases",
    ["uniq_id", "sub_category_2", "material", "colors"],
].copy()  # explicit copy: adding "comb" to a slice would be chained assignment
d1["comb"] = d1.apply(comb, axis=1)

In [24]:
# Running evaluation state. The evaluation loop skips the first curr_total rows of d1,
# so re-run this cell to start the evaluation from scratch.
curr_total, curr_right = 0, 0
right_set = set()

In [25]:
# Resumable evaluation: rows already scored (the first curr_total rows) are skipped, so
# re-running this cell after an interruption continues where it stopped.
df1 = d1
for idx, row in df1.iloc[curr_total:].iterrows():
    query = row["comb"]
    answer = row["uniq_id"]
    res = find_products(query, category_df, image_pickle_path)
    matched = is_match(answer, res)

    # Update the counters only after the row is fully scored, so an interrupt can never
    # leave a row counted in the total but not checked for a hit.
    curr_total += 1
    if matched:
        curr_right += 1
        right_set.add(answer)

    if curr_total % 200 == 0:
        print(f"{curr_total} current accuracy: {round(curr_right * 100/curr_total, 2)}%")

200 current accuracy: 72.5%
400 current accuracy: 71.75%
600 current accuracy: 74.5%
800 current accuracy: 74.38%
1000 current accuracy: 74.7%


In [26]:
# Persist the evaluation summary so it can be reloaded without re-running the loop.
res = {"total": curr_total, "right": curr_right, "right_set": right_set}
write_pickle(eval_path, res)

In [27]:
# Final top-5 match accuracy, in percent.
round(100 * res["right"] / res["total"], 2)

74.71

In [28]:
# Number of queries scored by the evaluation loop.
res["total"]

1103