In [1]:
import torch
import torch.nn as nn
import clip
from PIL import Image
import pandas as pd
import requests
import os.path as osp
import pickle
import random
import numpy as np
from pathlib import Path
import sys
from operator import itemgetter
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import time
import shutil

  warn(f"Failed to load image Python extension: {e}")


In [3]:
def read_pickle(dir):
    with open(dir, 'rb') as handle:
        b = pickle.load(handle)
    return b


def write_pickle(dir, data):
    with open(dir, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [4]:
def image_path(uid):
    return osp.join(image_storage, f"{uid}.jpg")

In [11]:
def combine_category(x):
    res = x["sub_category_1"]
    if not pd.isna(x["sub_category_2"]):
        res += ", " + x["sub_category_2"]
    if not pd.isna(x["sub_category_3"]):
        res += ", " + x["sub_category_3"]
    
    return res


def clean_data(d1):
    d1 = d1[d1["primary_category"] == "Furniture"]
    d1 = d1[~(d1["primary_category"].isna() & d1["sub_category_1"].isna() & d1["sub_category_2"].isna() & d1["sub_category_3"].isna())]
    d1 = d1[~(d1["sub_category_1"].isna() & d1["sub_category_2"].isna() & d1["sub_category_3"].isna())]
    d1 = d1[~d1["description"].isna()]
    d1 = d1[~d1["colors"].isna()]
    d1["colors"] = d1["colors"].astype(str)
    d1["combined_category"] = d1.apply(combine_category, axis=1)
    return d1

In [16]:
def combine_title(x):
    res = x["title"]
    if not pd.isna(x["colors"]):
        res += ", " + x["colors"]
    if not pd.isna(x["material"]):
        res += ", " + x["material"]
    
    return res

In [27]:
dataset_path = "data/cleaned_target_furniture_dataset.csv"
pickle_path = "C:/Users/aphri/Documents/t0002/work/data/w210_data/data_finetuned_pickle"
image_pickle_path = "C:/Users/aphri/Documents/t0002/work/data/w210_data/data_finetuned_image_pickle"
model_path = "C:/Users/aphri/Documents/t0002/work/data/w210_data/finetuned_model/finetuned_model.pt"
image_storage = "C:/Users/aphri/Documents/t0002/work/data/w210_data/target_images"

Path(pickle_path).mkdir(parents=True, exist_ok=True)
Path(image_pickle_path).mkdir(parents=True, exist_ok=True)

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [23]:
if osp.exists(model_path):
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint["model_state_dict"])
    # optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

In [12]:
d1 = pd.read_csv(dataset_path)
d1 = clean_data(d1)

In [13]:
d1 = d1[d1["combined_category"] == "Home Office Furniture, Bookshelves & Bookcases"]

In [14]:
len(d1)

1103

In [17]:
d1["combined_title"] = d1.apply(combine_title, axis=1)

In [18]:
d1["combined_title"].iloc[0]

'Best Choice Products 9-Cube Bookshelf, Display Storage System, Compartment Organizer w/ 3 Removable Back Panels - Black, black, MDF (Medium-Density Fiberboard) '

In [20]:
d2 = d1[["uniq_id", "combined_category", "combined_title"]]
d2.columns = ["uniq_id", "combined_category", "title"]
write_pickle(osp.join(pickle_path, f"meta_data.pkl"), d2)

In [24]:
cat_set = set(d1["combined_category"].values)

cat_list = []
ecat_list = []

for cat in cat_set:
    
    ecat = clip.tokenize(cat)
    cat_list.append(cat)
    ecat_list.append(ecat)
    
ecat_list = torch.cat(ecat_list).to(device)
with torch.no_grad():
    ecat_list = list(model.encode_text(ecat_list))

print(f"saving encoded categories")
df = pd.DataFrame(data={
    "category": cat_list,
    "encoded_category": ecat_list
})

write_pickle(osp.join(pickle_path, f"categories.pkl"), df)

saving encoded categories


In [25]:
title_set = set(d1["combined_title"].values)

title_list = []
etitle_list = []

for title in title_set:
    
    etitle = clip.tokenize(title)
    title_list.append(title)
    etitle_list.append(etitle)
    
res = []
chunk = 500
idx = 0
total_len = len(etitle_list)
while True:
    print(f"{round(idx*100/total_len, 2)}%")
    if idx >= len(etitle_list):
        break
    curr_list = torch.cat(etitle_list[idx:idx+chunk]).to(device)
    with torch.no_grad():
        curr_list = list(model.encode_text(curr_list))
        res += curr_list
    idx += chunk

print(f"saving encoded titles")
df = pd.DataFrame(data={
    "title": title_list,
    "encoded_title": res
})

write_pickle(osp.join(pickle_path, f"titles.pkl"), df)

0.0%
45.5%
90.99%
136.49%
saving encoded titles


In [28]:
total_len = len(cat_set)
cidx = 0
printed_cidx = set()
for cat in cat_set:
    
    pct = int(round(cidx*100/total_len, 0))
    if pct % 10 == 0 and pct not in printed_cidx:
        print(f"{pct}%")
        printed_cidx.add(pct)
    
    cidx += 1
    
    store_path = osp.join(image_pickle_path, f"{cat}.pkl")
    if osp.exists(store_path):
        continue

    uid_list = []
    for idx, row in d1[d1["combined_category"] == cat].iterrows():
        uid = row.uniq_id
        uid_list.append(uid)

    image = torch.cat([preprocess(Image.open(image_path(uid))).unsqueeze(0) for uid in uid_list]).to(device)
    with torch.no_grad():
        eimage_list = list(model.encode_image(image))
        
    if len(uid_list) > 0:
        df = pd.DataFrame(data={
            "uid": uid_list,
            "encoded_image": eimage_list
        })
        write_pickle(store_path, df)

0%
