In [2]:
import os
import random
import pandas as pd
from tqdm import tqdm
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

from torch.utils.data import DataLoader

from utils.dataset import PolyvoreDataset, DatasetArgs
from model.model import *
from model.encoder import ItemEncoder

import matplotlib.pyplot as plt

from transformers import AutoTokenizer
import cv2

import json
from tqdm import tqdm

In [18]:
class Preprocessor():
    """Convert outfit items (image path + text description) into model inputs.

    Images are read with OpenCV, converted from BGR to RGB, resized to
    224x224, normalized with ``A.Normalize()``'s default statistics and turned
    into a tensor. Descriptions are tokenized and padded/truncated to
    ``max_token_len`` tokens.

    Args:
        tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        max_token_len: Fixed token length every description is padded or
            truncated to.
        query_img_path: Image used for the query slot of a fill-in-the-blank
            input. When ``None`` the original hard-coded project path is used.
    """

    def __init__(self,
                 tokenizer_name='sentence-transformers/paraphrase-albert-small-v2',
                 max_token_len=16,
                 query_img_path=None):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_token_len = max_token_len
        self.transform = A.Compose([A.Resize(224, 224), A.Normalize(), ToTensorV2()])
        if query_img_path is None:
            # Historical default; pass `query_img_path` to run on another machine.
            query_img_path = os.path.join('F:/Projects/outfit-transformer', 'data', 'query_img.jpg')
        self.query_img_path = query_img_path

    def _load_img(self, path):
        """Read the image at ``path`` and return it transformed, with a leading batch dim of 1.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # without this check the error surfaces later inside cvtColor.
            raise FileNotFoundError(f'Could not read image: {path!r}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return self.transform(image=img)['image'].unsqueeze(0)

    def preprocess_item(self, item_dict):
        """Preprocess a single item.

        Args:
            item_dict: Mapping with keys ``'image_path'`` and ``'desc'``.

        Returns:
            Tuple ``(img, input_ids, attention_mask)``. ``img`` has a leading
            batch dimension of 1; the token tensors are whatever the tokenizer
            returns for one string with ``return_tensors='pt'``
            (presumably shape ``(1, max_token_len)`` -- TODO confirm).
        """
        img = self._load_img(item_dict['image_path'])
        encoded = self.tokenizer(
            item_dict['desc'],
            max_length=self.max_token_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # Look the fields up by name rather than by position in `.values()`,
        # so the result does not depend on key order or on `token_type_ids`
        # being present.
        return (img, encoded['input_ids'], encoded['attention_mask'])

    @staticmethod
    def _with_fallback_desc(transformer_item_dicts):
        """Yield each item dict, using the category name when 'desc' is empty.

        The caller's dicts are never modified: an item that needs the fallback
        is yielded as a shallow copy with 'desc' replaced.
        """
        for category, item_dict in transformer_item_dicts.items():
            if not item_dict['desc']:
                item_dict = {**item_dict, 'desc': str(category)}
            yield item_dict

    def _stack_items(self, item_dicts):
        """Preprocess every item dict and concatenate the results along dim 0."""
        imgs, ids, masks = [], [], []
        for item_dict in item_dicts:
            img, input_ids, attention_mask = self.preprocess_item(item_dict)
            imgs.append(img)
            ids.append(input_ids)
            masks.append(attention_mask)

        return (
            torch.concat(imgs, dim=0),
            torch.concat(ids, dim=0),
            torch.concat(masks, dim=0),
        )

    def preprocess_compatibility_input(self, transformer_item_dicts):
        """Build the batched input for an outfit given as ``{category: item_dict}``.

        Items with an empty description are tokenized with their category name
        instead. Items are concatenated in the mapping's iteration order.

        Returns:
            Tuple ``(img, input_ids, attention_mask)`` with one row per item.
        """
        return self._stack_items(self._with_fallback_desc(transformer_item_dicts))

    def preprocess_fitb_input(self, target_desc, transformer_item_dicts):
        """Build the batched input for a fill-in-the-blank query.

        Row 0 is the query: ``self.query_img_path`` paired with ``target_desc``.
        The outfit items follow in the mapping's iteration order, with the same
        empty-description fallback as ``preprocess_compatibility_input``.

        Returns:
            Tuple ``(img, input_ids, attention_mask)`` with one row for the
            query plus one per item.
        """
        query_dict = {
            'image_path': self.query_img_path,
            'desc': target_desc
            }
        outfit_items = self._with_fallback_desc(transformer_item_dicts)
        return self._stack_items([query_dict, *outfit_items])

In [19]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

preprocessor = Preprocessor()
model = OutfitTransformer(embedding_dim=128).to(device)
encoder = ItemEncoder(embedding_dim=128).to(device)
# Both networks are used for inference only, so *both* go into eval mode.
# (eval() turns off training-only behavior such as dropout, if the modules use any.)
model.eval()
encoder.eval()

# Raw string: the backslashes are literal path separators, not escape sequences.
save_path = r'F:\Projects\outfit-transformer\model\saved_model'
model_name = 'checkpoint_FITB_2023-12-25_0_0.371'

# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(os.path.join(save_path, f'{model_name}.pth'), map_location=device)
# strict=False tolerates missing/unexpected keys; the returned value (shown as
# the cell output for the last call) lists any mismatches.
model.load_state_dict(checkpoint['model_state_dict'], strict=False)
encoder.load_state_dict(checkpoint['encoder_state_dict'], strict=False)



<All keys matched successfully>

In [20]:
# Root directory of the Polyvore Outfits dataset on this machine.
data_dir = 'F:/Projects/datasets/polyvore_outfits'

# Item metadata, later iterated as {item_id: fields} where each entry's
# 'url_name' is read. The `with` block closes the file handle; the explicit
# encoding avoids depending on the platform default (e.g. cp1252 on Windows).
with open(os.path.join(data_dir, 'polyvore_item_metadata.json'), encoding='utf-8') as metadata_file:
    meta_data = json.load(metadata_file)

In [26]:
# Placeholder inputs illustrating the dict layouts the Preprocessor accepts:
# a {category: item} mapping for the batch methods, and a single item for
# `preprocess_item`. They are not used by the loop below, and `item_dict` is
# reassigned on every iteration.
transformer_item_dicts = {
    'top': {
        'image_path': '~/1212.jpg',
        'desc': 'dfdsf'
        },
    'bottom' : {
        'image_path': '~/1212.jpg',
        'desc': 'dfdsf'
        }
    }

item_dict = {
    'image_path': '~/1212.jpg',
    'desc': 'dfdsf'
    }


# Encode every item listed in the metadata, one item per forward pass, and
# collect the results as {item_id: embedding}. Gradients are disabled because
# this is inference only.
# NOTE(review): embeddings stay on `device`; confirm that is what downstream
# cells expect before moving them to the CPU.
embeds = {}
with torch.no_grad():
    # `item_id` rather than `id`: a top-level loop variable named `id` would
    # shadow the builtin for the rest of the kernel session.
    for item_id, data in tqdm(meta_data.items()):
        # Images are stored as <data_dir>/images/<item_id>.jpg.
        img_path = os.path.join(data_dir, 'images', f'{item_id}.jpg')

        item_dict = {
            'image_path': img_path,
            'desc': data['url_name']
            }

        item = preprocessor.preprocess_item(item_dict)
        embed = encoder(item[0].to(device),
                        item[1].to(device),
                        item[2].to(device))
        embeds[item_id] = embed

  0%|          | 414/251008 [00:07<1:19:25, 52.59it/s]


KeyboardInterrupt: 

In [28]:
# Preview only the first few embeddings; displaying the whole dict would write
# every tensor into the notebook output.
dict(list(embeds.items())[:3])

{'211990161': tensor([[ 0.0591,  0.0110,  0.0507,  0.2518, -0.0626,  0.2710, -0.1453, -0.1677,
           0.0083, -0.1675,  0.0201,  0.0303,  0.1181, -0.0740,  0.1095, -0.0420,
           0.1090,  0.0622, -0.1296,  0.0447, -0.0017, -0.0370,  0.0106, -0.0959,
           0.0890, -0.0061,  0.1069,  0.1158, -0.1148,  0.0204,  0.0786,  0.1994,
          -0.0356, -0.2912, -0.2796, -0.1847, -0.0747, -0.0021, -0.0263, -0.2919,
           0.0748,  0.2191,  0.0868, -0.0493,  0.0692, -0.0683, -0.0571, -0.0049,
          -0.0929, -0.0836,  0.0213,  0.1592, -0.0932, -0.0975,  0.2266,  0.0607,
          -0.1993,  0.1942, -0.1474, -0.0104,  0.1307, -0.0662,  0.0229,  0.0412,
           0.0320,  0.0326, -0.1665, -0.0777,  0.0475,  0.0731,  0.2797,  0.0311,
          -0.1340, -0.0403, -0.1896, -0.1690, -0.0581,  0.0099, -0.0085,  0.1036,
           0.2123, -0.1225, -0.0896,  0.0557,  0.1332,  0.1657,  0.0594, -0.0307,
           0.0737, -0.0355,  0.0692, -0.0014, -0.0633,  0.0352,  0.2442,  0.2521,
   