In [1]:
import os
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt
import torch

In [2]:
# The presence of KAGGLE_KERNEL_RUN_TYPE is treated as "running on Kaggle": read the
# mounted competition data there, otherwise expect a local ./data directory.
is_kaggle = os.environ.get('KAGGLE_KERNEL_RUN_TYPE') is not None
path = Path('../input/visual-taxonomy' if is_kaggle else 'data')

In [3]:
# Category metadata plus the train/test manifests, all read from the data root.
cat_info = pd.read_parquet(path / 'category_attributes.parquet')
df, test_df = (pd.read_csv(path / fname) for fname in ('train.csv', 'test.csv'))

In [4]:
def add_img_path(id_, img_folder, path=None):
    """Return the image file path for a product id, as a string.

    Images live at ``<path>/<img_folder>/<id zero-padded to 6 digits>.jpg``.

    Parameters
    ----------
    id_ : product id (int or str).
    img_folder : sub-folder under the data root, e.g. ``'train_images'``.
    path : data root (str or Path). When omitted, the notebook-level ``path``
        from the configuration cell is used; ``process_df`` passes it explicitly.
    """
    # The parameter shadows the notebook-global ``path``, so fall back to the
    # module namespace explicitly when no root is supplied.
    root = Path(globals()['path'] if path is None else path)
    return str(root / img_folder / f'{str(id_).zfill(6)}.jpg')

# Attach the on-disk image location of every training row.
df['img_path'] = df['id'].map(lambda id_: add_img_path(id_, img_folder='train_images'))

In [5]:
class LabelEncoder:
    """Per-category vocabularies that map attribute labels to integer ids and back.

    Expects a training frame with a ``Category`` column, a ``len`` column (how
    many attributes the category uses: ``attr_1`` .. ``attr_<len>``) and
    ``attr_<n>`` label columns. Any other columns (``id``, ``img_path``,
    ``labels``, ...) are ignored. Within each category and attribute, ids follow
    the order of first appearance (``Series.unique``); labels are keyed by
    ``str(label)``, so a NaN label is looked up as ``'nan'``.
    """

    def __init__(self, df):
        # Select the attr_<n> columns explicitly (ordered by n) instead of
        # dropping known extras: the notebook adds ``img_path`` before the
        # encoder is built, and that column must not become a "vocabulary".
        attr_cols = sorted(df.filter(regex=r'^attr_\d+$').columns,
                           key=lambda col: int(col.split('_')[1]))
        grouped = df.groupby('Category')
        # list(...) keeps every cell a sequence: pandas may unwrap a length-1
        # array returned by an aggregation into a bare scalar, which would make a
        # single-label attribute enumerate the characters of its label.
        self.vocab = grouped[attr_cols].agg(lambda col: list(col.unique()))
        # Number of attributes each category uses; get_labels encodes exactly this many.
        self.num_attrs = grouped['len'].max()
        self.label2id_map, self.id2label_map = {}, {}
        for cat, row in self.vocab.iterrows():
            self.label2id_map[cat] = {
                attr: {str(label): idx for idx, label in enumerate(labels)}
                for attr, labels in row.items()
            }
            self.id2label_map[cat] = {
                attr: {idx: str(label) for idx, label in enumerate(labels)}
                for attr, labels in row.items()
            }

    def label2id(self, cat, attr,  label):
        """Return the id of string ``label`` for attribute ``attr`` of category ``cat``.

        Raises KeyError for an unknown category, attribute or label.
        """
        return self.label2id_map[cat][attr][label]

    def id2label(self, cat, attr, id_):
        """Return the string label with id ``id_`` for ``attr`` of ``cat`` (KeyError if unknown)."""
        return self.id2label_map[cat][attr][id_]

    def num_classes(self, cat):
        """Return the class count of each attribute ``cat`` uses, in attr_1.. order.

        The list has exactly ``len`` entries for the category, so position ``i``
        lines up with position ``i`` of the label tuples built by ``get_labels``
        (an attribute with a single observed label still gets its slot).
        """
        n_used = int(self.num_attrs.loc[cat])
        return [len(labels) for labels in self.vocab.loc[cat].values[:n_used]]

In [6]:
def get_labels(row, encoder):
    """Encode the first ``row['len']`` attribute labels of a row as a tuple of ids.

    ``row`` supplies ``Category``, ``len`` and ``attr_1`` .. ``attr_<len>``;
    each label is stringified before lookup so it matches the encoder's keys.
    """
    category = row['Category']
    keys = (f'attr_{n}' for n in range(1, row['len'] + 1))
    return tuple(encoder.label2id(category, key, str(row[key])) for key in keys)

# Fit the label vocabularies on the training frame, then encode each row's attributes.
encoder = LabelEncoder(df)
df['labels'] = df.apply(lambda row: get_labels(row, encoder), axis=1)

In [7]:
# Spot-check a few random training rows.
df.sample(n=3)

Unnamed: 0,id,Category,len,attr_1,attr_2,attr_3,attr_4,attr_5,attr_6,attr_7,attr_8,attr_9,attr_10,img_path,labels
296,296,Men Tshirts,5,default,polo,solid,solid,short sleeves,,,,,,../input/visual-taxonomy/train_images/000296.jpg,"(0, 1, 1, 1, 0)"
54581,54747,Women Tops & Tunics,10,default,fitted,regular,v-neck,casual,default,solid,long sleeves,regular sleeves,,../input/visual-taxonomy/train_images/054747.jpg,"(3, 1, 2, 5, 1, 1, 1, 4, 0, 0)"
65146,65312,Women Tops & Tunics,10,multicolor,fitted,regular,square neck,casual,printed,default,short sleeves,puff sleeves,,../input/visual-taxonomy/train_images/065312.jpg,"(12, 1, 2, 6, 1, 2, 4, 1, 4, 0)"


In [11]:
def process_df(df, path, encoder, is_test_df=False):
    """Return a copy of ``df`` with ``img_path`` (and, for training data, ``labels``) added.

    Parameters
    ----------
    df : frame with an ``id`` column (plus ``Category``/``len``/``attr_*`` for training data).
    path : data root (str or Path); images are read from ``<path>/train_images``
        or ``<path>/test_images``.
    encoder : LabelEncoder used to turn attribute labels into ids (unused for test data).
    is_test_df : when True, use the test image folder and skip label encoding.
    """
    df = df.copy()
    img_folder = 'test_images' if is_test_df else 'train_images'
    root = Path(path)
    # Build the path from the explicit ``path`` argument; the file name is the
    # id zero-padded to 6 digits, matching the training-image layout.
    df['img_path'] = df['id'].map(lambda id_: str(root / img_folder / f'{str(id_).zfill(6)}.jpg'))
    if is_test_df:
        return df
    df['labels'] = df.apply(get_labels, axis=1, encoder=encoder)
    return df

In [12]:
from transformers import BlipImageProcessor, BertTokenizerFast

# Pretrained BLIP checkpoint whose image preprocessing and tokenizer are reused below.
ckpt = "Salesforce/blip-itm-base-coco"
img_processor, tokenizer = (
    BlipImageProcessor.from_pretrained(ckpt),
    BertTokenizerFast.from_pretrained(ckpt),
)

preprocessor_config.json:   0%|          | 0.00/445 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/456 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]



In [13]:
# Register the [Encode] marker as an additional special token; its id is used to
# flag the marker positions in each tokenized question.
# NOTE(review): this grows the tokenizer vocabulary — confirm the model's token
# embeddings are resized accordingly.
tokenizer.add_special_tokens(dict(additional_special_tokens=['[Encode]']))
encode_token_id = tokenizer.convert_tokens_to_ids('[Encode]')

In [14]:
def add_question_and_tokenize(row, tokenizer, encode_token_id):
    """Build the category prompt for one ``cat_info`` row and tokenize it.

    The row's name is the category (the frame is indexed by ``Category``) and
    its ``Attribute_list`` holds the attribute names. The prompt is a header
    line followed by one ``[Encode] What is the <attr>?`` line per attribute.

    Returns the row's original fields plus ``question``, every tokenizer output
    converted to Python lists, and ``special_token_mask``: a boolean list that
    is True where the token id equals ``encode_token_id``.
    """
    attributes = row['Attribute_list']
    category = row.name
    # Drops the last character of the category name (presumably a plural "s").
    header = f'For this image of {category[:-1]}, please answer the following:'
    question = '\n'.join([header] + [f'[Encode] What is the {attr}?' for attr in attributes])

    encoded = tokenizer(question, padding=True, truncation=True, return_tensors='pt')
    encoded['special_token_mask'] = (encoded['input_ids'] == encode_token_id).squeeze()
    as_lists = {key: value.tolist() for key, value in encoded.items()}

    return pd.Series({**row.to_dict(), 'question': question, **as_lists})

# Index the category metadata by name and attach each category's question and
# its tokenization. The guard keeps the cell re-runnable: after the first run
# 'Category' is already the index and set_index would raise a KeyError.
if 'Category' in cat_info.columns:
    cat_info = cat_info.set_index('Category')
cat_info = cat_info.apply(add_question_and_tokenize, axis=1, args=(tokenizer, encode_token_id))

In [15]:
def process_cat_info(cat_info, tokenizer):
    """Return a Category-indexed copy of ``cat_info`` with per-category questions and token ids.

    Registers ``[Encode]`` as an additional special token on ``tokenizer`` (this
    mutates the tokenizer passed in), then builds one tokenized question per
    category with ``add_question_and_tokenize``. The caller's frame is not modified.

    ``cat_info`` may be the raw frame (with a ``Category`` column) or a frame
    that is already indexed by ``Category``.

    Raises KeyError if ``Category`` is neither a column nor the index name.
    """
    tokenizer.add_special_tokens({'additional_special_tokens': ['[Encode]']})
    encode_token_id = tokenizer.convert_tokens_to_ids('[Encode]')
    if 'Category' in cat_info.columns:
        # set_index returns a new frame, so no explicit copy is needed.
        cat_info = cat_info.set_index('Category')
    elif cat_info.index.name != 'Category':
        raise KeyError("cat_info needs a 'Category' column or index")
    return cat_info.apply(add_question_and_tokenize, axis=1, args=(tokenizer, encode_token_id))

In [16]:
from PIL import Image
import numpy as np

class Batch:
    """A single-category mini-batch, as yielded by ``MeeshoDataloader``.

    Attributes
    ----------
    category : category name shared by every image in the batch.
    input_ids, special_token_mask, attention_mask : tensors for the category's
        tokenized question.
    pixel_values : preprocessed images, e.g. ``(batch, 3, 224, 224)``.
    labels : encoded attribute ids, e.g. ``(batch, n_attrs)``, or ``None`` for
        unlabeled data.
    """

    def __init__(self, category, input_ids, special_token_mask, attention_mask, pixel_values, labels):
        self.category = category
        self.input_ids = input_ids
        self.special_token_mask = special_token_mask
        self.attention_mask = attention_mask
        self.pixel_values = pixel_values
        self.labels = labels

    def to(self, device):
        """Move every tensor, including ``labels`` when present, to ``device``.

        ``device`` is forwarded to ``Tensor.to`` unchanged. Returns ``self`` so
        the call can be chained, e.g. ``batch = batch.to(device)``.
        """
        self.input_ids = self.input_ids.to(device)
        self.special_token_mask = self.special_token_mask.to(device)
        self.attention_mask = self.attention_mask.to(device)
        self.pixel_values = self.pixel_values.to(device)
        # Labels used to stay where they were created, which would presumably
        # cause a device mismatch when compared with on-device model outputs.
        if self.labels is not None:
            self.labels = self.labels.to(device)
        return self

In [17]:
class MeeshoDataloader:
    """Iterate over ``df`` in batches where every batch holds a single category.

    Each batch pairs the category's pre-tokenized question (looked up in
    ``cat_info``) with the preprocessed images of up to ``batch_size`` rows of
    that category.

    Parameters
    ----------
    df : frame with ``Category`` and ``img_path`` columns and, optionally,
        ``labels`` (tuples of attribute ids). Without ``labels`` the batches
        carry ``labels=None``.
    cat_info : Category-indexed frame with ``input_ids``, ``special_token_mask``
        and ``attention_mask`` columns.
    batch_size : maximum rows per batch; a category's last batch may be smaller.
    img_processor : called as ``img_processor(images=..., return_tensors='pt', size=...)``
        and expected to return an object with ``.pixel_values``.
    shuffle : if True (default), the category order and the rows within each
        category are reshuffled on every pass; if False, categories come in
        ``groupby`` order (sorted by default) and rows in frame order.
    image_size : ``size`` argument forwarded to ``img_processor``.
    """

    def __init__(self, df, cat_info, batch_size, img_processor, shuffle=True, image_size=(224, 224)):
        self.df = df
        self.cat_info = cat_info
        self.bs = batch_size
        self.img_processor = img_processor
        self.shuffle = shuffle
        self.image_size = image_size

    def _group_indices(self):
        """Map each category to the list of its ``df`` index labels, in frame order."""
        return {cat: list(idxs) for cat, idxs in self.df.groupby('Category').groups.items()}

    def __len__(self):
        """Number of batches per pass: ceil(rows / batch_size) summed over categories."""
        return sum((len(idxs) + self.bs - 1) // self.bs for idxs in self._group_indices().values())

    def _sample(self):
        """Yield lists of ``df`` index labels, one list per batch."""
        group2idxs = self._group_indices()
        cats = list(group2idxs)
        if self.shuffle:
            cats = [cats[i] for i in np.random.permutation(len(cats))]
        for cat in cats:
            idxs = group2idxs[cat]
            if self.shuffle:
                # Reshuffle rows too; otherwise every pass reuses identical batches.
                idxs = [idxs[i] for i in np.random.permutation(len(idxs))]
            for start in range(0, len(idxs), self.bs):
                yield idxs[start:start + self.bs]

    @staticmethod
    def _load_image(img_path):
        """Read an image into memory as RGB and close the underlying file."""
        with Image.open(img_path) as img:
            return img.convert('RGB')

    def __iter__(self):
        has_labels = 'labels' in self.df.columns
        for idxs in self._sample():
            cat = self.df.loc[idxs[0], 'Category']
            info = self.cat_info.loc[cat]

            images = [self._load_image(p) for p in self.df.loc[idxs, 'img_path']]
            processed = self.img_processor(images=images, return_tensors='pt', size=self.image_size)

            labels = torch.tensor(self.df.loc[idxs, 'labels'].tolist()) if has_labels else None
            yield Batch(category=cat,
                        input_ids=torch.tensor(info['input_ids']),
                        special_token_mask=torch.tensor(info['special_token_mask']),
                        attention_mask=torch.tensor(info['attention_mask']),
                        pixel_values=processed.pixel_values,
                        labels=labels)

In [18]:
# Category-homogeneous batches of up to 32 training images each.
batch_size = 32
dl = MeeshoDataloader(df, cat_info, img_processor=img_processor, batch_size=batch_size)

In [20]:
# Pull one batch and inspect its category and tensor shapes.
batch = next(iter(dl))
(batch.category, batch.pixel_values.shape, batch.special_token_mask.shape, batch.labels.shape)

('Women Tshirts',
 torch.Size([32, 3, 224, 224]),
 torch.Size([78]),
 torch.Size([32, 8]))