In [96]:
import torch
import torch.nn as nn
import clip
from PIL import Image
import pandas as pd
import requests
import os.path as osp
import pickle
import random
import numpy as np
from pathlib import Path
import sys

### 1 save images

In [86]:
dataset_path = "C:/Users/aphri/Documents/t0002/work/data/w210_data/target_store_furniture_datasets.csv"
image_storage = "C:/Users/aphri/Documents/t0002/work/data/w210_data/target_images"
pickle_path = "C:/Users/aphri/Documents/t0002/work/data/w210_data/pickle"
model_path = "C:/Users/aphri/Documents/t0002/work/data/w210_data/model"

Path(image_storage).mkdir(parents=True, exist_ok=True)
Path(pickle_path).mkdir(parents=True, exist_ok=True)
Path(model_path).mkdir(parents=True, exist_ok=True)

In [2]:
d1 = pd.read_csv(dataset_path)

In [3]:
def image_path(uid):
    return osp.join(image_storage, f"{uid}.jpg")

In [4]:
count = 0
total = len(d1)
for idx, v in d1[["main_image", "uniq_id"]].iterrows():
    url = v.main_image
    uid = v.uniq_id
    path = image_path(uid)
    if not osp.exists(path):
        image = Image.open(requests.get(url, stream=True).raw)
        image.save(image_path(uid))
    
    count += 1
    if count % 1000 == 0:
        print(f"prcoessed {round(count/total * 100, 2)}%")

prcoessed 2.37%
prcoessed 4.74%
prcoessed 7.11%
prcoessed 9.48%
prcoessed 11.84%
prcoessed 14.21%
prcoessed 16.58%
prcoessed 18.95%
prcoessed 21.32%
prcoessed 23.69%
prcoessed 26.06%
prcoessed 28.43%
prcoessed 30.79%
prcoessed 33.16%
prcoessed 35.53%
prcoessed 37.9%
prcoessed 40.27%
prcoessed 42.64%
prcoessed 45.01%
prcoessed 47.38%
prcoessed 49.75%
prcoessed 52.11%
prcoessed 54.48%
prcoessed 56.85%
prcoessed 59.22%
prcoessed 61.59%
prcoessed 63.96%
prcoessed 66.33%
prcoessed 68.7%
prcoessed 71.06%
prcoessed 73.43%
prcoessed 75.8%
prcoessed 78.17%
prcoessed 80.54%
prcoessed 82.91%
prcoessed 85.28%
prcoessed 87.65%
prcoessed 90.02%
prcoessed 92.38%
prcoessed 94.75%
prcoessed 97.12%
prcoessed 99.49%


### 2 process images and texts

In [27]:
def read_pickle(dir):
    with open(dir, 'rb') as handle:
        b = pickle.load(handle)
    return b


def write_pickle(dir, data):
    with open(dir, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [28]:
def save_processed_data(name, uid_list, text_list, eimage_list, etext_list):
    df = pd.DataFrame(data={
        "uid": uid_list,
        "text": text_list,
        "encoded_image": eimage_list,
        "encoded_text": etext_list
    })
    
    write_pickle(name, df)

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [6]:
d1 = d1[["uniq_id", "sub_category_2"]]

In [34]:
uid_list = []
text_list = []
eimage_list = []
etext_list = []

count = 0
total = len(d1)
for idx, row in d1.iterrows():
    uid = row.uniq_id
    text = row.sub_category_2
    
    uid_list.append(uid)
    text_list.append(text)
    try:
        image = preprocess(Image.open(image_path(uid))).unsqueeze(0)
        text = clip.tokenize(text)

        eimage_list.append(image)
        etext_list.append(text)
    except:
        # print(f"failed: {uid}, {text}")
        eimage_list.append(None)
        etext_list.append(None)
    
    count += 1
    if count % 1000 == 0:
        print(f"prcoessed {round(count/total * 100, 2)}%")
        save_processed_data(osp.join(pickle_path, f"{count}.pkl"), uid_list, text_list, eimage_list, etext_list)
        uid_list = []
        text_list = []
        eimage_list = []
        etext_list = []
    
if len(uid_list) > 0:
    save_processed_data(osp.join(pickle_path, f"{count}.pkl"), uid_list, text_list, eimage_list, etext_list)

prcoessed 2.37%
prcoessed 4.74%
prcoessed 7.11%
prcoessed 9.48%
prcoessed 11.84%
prcoessed 14.21%
prcoessed 16.58%
prcoessed 18.95%
prcoessed 21.32%
prcoessed 23.69%
prcoessed 26.06%
prcoessed 28.43%
prcoessed 30.79%
prcoessed 33.16%
prcoessed 35.53%
prcoessed 37.9%
prcoessed 40.27%
prcoessed 42.64%
prcoessed 45.01%
prcoessed 47.38%
prcoessed 49.75%
prcoessed 52.11%
prcoessed 54.48%
prcoessed 56.85%
prcoessed 59.22%
prcoessed 61.59%
prcoessed 63.96%
prcoessed 66.33%
prcoessed 68.7%
prcoessed 71.06%
prcoessed 73.43%
prcoessed 75.8%
prcoessed 78.17%
prcoessed 80.54%
prcoessed 82.91%
prcoessed 85.28%
prcoessed 87.65%
prcoessed 90.02%
prcoessed 92.38%
prcoessed 94.75%
prcoessed 97.12%
prcoessed 99.49%


### 3 fine tune the model

In [103]:
def calc_loss(data, device):
    encoded_images = torch.cat(list(data["encoded_image"].values)).to(device)
    encoded_texts = torch.cat(list(data["encoded_text"].values)).to(device)

    logits_per_image, logits_per_text = model(encoded_images, encoded_texts)

    # during training # images == # texts, so calc only once
    n_classes = logits_per_image.shape[0]
    device = device
    labels = torch.arange(n_classes, device=device, dtype=torch.long)

    loss_image = criterion(logits_per_image, labels)
    loss_text = criterion(logits_per_text, labels)
    curr_loss = (loss_image + loss_text) / 2
    
    return curr_loss


def train(model, pickle_path, pickles, criterion, device, batch_size, optimizer, max_norm):
    
    model.train()
    total_loss = 0
    total_count = 0
    
    for p in pickles:
        print(f"training: {p}")
        file = osp.join(pickle_path, f"{p}.pkl")
        data = read_pickle(file)
        data = data[~data["encoded_text"].isna()]
        
        start = 0
        end = batch_size
        while start < len(data):
            sub_data = data[start:end]
            
            optimizer.zero_grad()
            curr_loss = calc_loss(sub_data, device)
            
            curr_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            
            total_loss += curr_loss.item()
            total_count += len(sub_data)
            
            start = end
            end += batch_size
            sys.stdout.write(".")
        sys.stdout.write("\n")
            
        print(f"average loss: {total_loss/total_count}")
            
            
def evaluate(model, pickle_path, pickles, criterion, device, batch_size):
    
    model.eval()
    total_loss = 0
    total_count = 0
    
    with torch.no_grad():
        for p in pickles:
            print(f"evaluating: {p}")

            file = osp.join(pickle_path, f"{p}.pkl")
            data = read_pickle(file)
            data = data[~data["encoded_text"].isna()]

            start = 0
            end = batch_size
            while start < len(data):
                sub_data = data[start:end]
                curr_loss = calc_loss(sub_data, device)

                total_loss += curr_loss.item()
                total_count += len(sub_data)

                start = end
                end += batch_size
                sys.stdout.write(".")
            sys.stdout.write("\n")

            print(f"average loss: {total_loss/total_count}")
    
    
def run_epoch(
        model,
        pickle_path,
        train_pickles,
        eval_pickles,
        criterion,
        device,
        batch_size,
        optimizer,
        max_norm,
        n_epoch,
        seed=1234,
):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
    for n in range(n_epoch):
        print(f"training epoch: {n}")
        train(model, pickle_path, train_pickles, criterion, device, batch_size, optimizer, max_norm)
        evaluate(model, pickle_path, eval_pickles, criterion, device, batch_size)

In [None]:
train_pickles = [
    1000, 2000, 3000, 4000, 5000, 6000, 
    7000, 8000, 9000, 10000, 11000, 12000, 
    13000, 14000, 15000, 16000, 17000, 18000, 
    19000, 20000, 21000, 22000, 23000, 24000, 
    25000, 26000, 27000, 28000, 29000, 30000, 
    31000, 32000, 33000, 34000, 35000, 36000, 
    37000, 38000, 39000 
]
eval_pickles = [40000, 41000, 42215]
learning_rate = 0.01
criterion = nn.CrossEntropyLoss().to(device)
batch_size = 20
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
max_norm = 1 # for gradient clipping
n_epoch = 10

run_epoch(model, pickle_path, train_pickles, eval_pickles, criterion, device, batch_size, optimizer, max_norm, n_epoch)
torch.save(
    model.state_dict(),
    str(osp.join(model_path, 'model.pt'))
)

training epoch: 0
training: 1000
..................................................
average loss: 0.15060437881585323
training: 2000
..................................................
average loss: 0.15064976537354258
training: 3000
..................................................
average loss: 0.15059188721998246
training: 4000
..................................................
average loss: 0.15061008336270285
training: 5000
..................................................
average loss: 0.15062650029475871
training: 6000
.................................................
average loss: 0.1504874652100576
training: 7000
.................................................
average loss: 0.15040228786391094
training: 8000
..................................................
average loss: 0.15043451290394316
training: 9000
..................................................
average loss: 0.150459579973052
training: 10000
..................................................
average loss: 0.1504