In [2]:
import os
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt
import torch

In [3]:
# The presence of KAGGLE_KERNEL_RUN_TYPE is taken to mean we run inside a Kaggle
# kernel; pick the dataset root accordingly.
is_kaggle = 'KAGGLE_KERNEL_RUN_TYPE' in os.environ
if is_kaggle:
    path = Path('../input/visual-taxonomy')
else:
    path = Path('data')

In [4]:
# Category metadata (including each category's Attribute_list) plus the train/test id tables.
cat_info = pd.read_parquet(path / 'category_attributes.parquet')
df, test_df = (pd.read_csv(path / csv_name) for csv_name in ('train.csv', 'test.csv'))

In [5]:
def add_img_path(id_, path, img_folder):
    """Return, as a string, the JPEG path of image ``id_`` inside ``path/img_folder``.

    The id is left-padded with zeros to (at least) six characters,
    e.g. ``42 -> 'train_images/000042.jpg'``.
    """
    file_name = str(id_).zfill(6) + '.jpg'
    return str(path / f'{img_folder}/{file_name}')

# Attach the on-disk JPEG location of every training image.
df['img_path'] = df['id'].apply(lambda image_id: add_img_path(image_id, path, 'train_images'))

In [6]:
class LabelEncoder:
    """Bidirectional label <-> integer-id mapping, kept separately per category.

    Every column of ``df`` whose name starts with ``'attr'`` is treated as an
    attribute. For each ``(Category, attribute)`` pair the distinct values are
    numbered 0, 1, ... in order of first appearance within that category.
    Labels are stored as strings, so a missing value (NaN) is keyed as ``'nan'``.

    Attributes:
        unique_labels: frame indexed by category holding the distinct values
            of each attribute column.
        label2id_map: ``{category: {attr: {label_str: id}}}``.
        id2label_map: ``{category: {attr: {id: label_str}}}``.
    """

    def __init__(self, df):
        attribute_columns = [name for name in df.columns if name.startswith('attr')]
        per_category = df.groupby('Category')[attribute_columns]
        self.unique_labels = per_category.agg(lambda col: col.unique())
        self.label2id_map = {}
        self.id2label_map = {}
        for category, uniques in self.unique_labels.iterrows():
            forward = {}
            backward = {}
            for attr, values in uniques.items():
                as_text = [str(value) for value in values]
                forward[attr] = {text: position for position, text in enumerate(as_text)}
                backward[attr] = dict(enumerate(as_text))
            self.label2id_map[category] = forward
            self.id2label_map[category] = backward

    def label2id(self, category, attr, label):
        """Return the integer id of string ``label`` (KeyError if it was never seen)."""
        attr_vocab = self.label2id_map[category][attr]
        return attr_vocab[label]

    def id2label(self, category, attr, id_):
        """Return the string label stored under integer ``id_``."""
        attr_vocab = self.id2label_map[category][attr]
        return attr_vocab[id_]

In [7]:
def get_labels(row, encoder):
    """Encode the ``attr_1`` .. ``attr_10`` values of one row as a tuple of ints.

    Each value is stringified before lookup (so NaN becomes ``'nan'``) and is
    resolved through ``encoder.label2id`` under the row's ``Category``.
    """
    category = row['Category']
    return tuple(
        encoder.label2id(category, f'attr_{n}', str(row[f'attr_{n}']))
        for n in range(1, 11)
    )

# Fit the label vocabulary on the training frame, then encode every row.
encoder = LabelEncoder(df)
df['labels'] = df.apply(lambda row: get_labels(row, encoder), axis=1)

In [9]:
# Fixed random_state so this preview is reproducible across Restart & Run All.
df.sample(5, random_state=0)

Unnamed: 0,id,Category,len,attr_1,attr_2,attr_3,attr_4,attr_5,attr_6,attr_7,attr_8,attr_9,attr_10,img_path,labels
59275,59441,Women Tops & Tunics,10,black,boxy,crop,round neck,casual,printed,typography,short sleeves,regular sleeves,,../input/visual-taxonomy/train_images/059441.jpg,"(0, 2, 1, 2, 1, 2, 2, 1, 0, 0)"
37236,37402,Women Tshirts,8,pink,loose,regular,printed,funky print,short sleeves,regular sleeves,,,,../input/visual-taxonomy/train_images/037402.jpg,"(4, 0, 2, 2, 4, 2, 0, 0, 0, 0)"
19774,19939,Sarees,10,,zari,small border,cream,traditional,,,zari woven,peacock,yes,../input/visual-taxonomy/train_images/019939.jpg,"(1, 1, 0, 1, 1, 1, 1, 0, 5, 1)"
43060,43226,Women Tshirts,8,multicolor,,crop,,solid,,,,,,../input/visual-taxonomy/train_images/043226.jpg,"(0, 2, 1, 3, 2, 3, 1, 0, 0, 0)"
2126,2126,Men Tshirts,5,multicolor,round,,default,short sleeves,,,,,,../input/visual-taxonomy/train_images/002126.jpg,"(1, 0, 2, 0, 0, 0, 0, 0, 0, 0)"


In [10]:
def process_df(df, path, is_test_df=False, encoder=None):
    """Add an ``img_path`` column and, for training data, a ``labels`` column.

    ``df`` is modified in place and also returned.

    Args:
        df: frame with an ``id`` column; non-test frames also need ``Category``
            and ``attr_1`` .. ``attr_10`` (consumed by ``get_labels``).
        path: dataset root containing ``train_images`` / ``test_images``.
        is_test_df: take images from ``test_images`` and skip label encoding
            (the test split is presumably unlabeled; label encoding needs
            ``attr_*`` columns).
        encoder: optional already-fitted ``LabelEncoder``. When omitted, a new
            one is fitted on ``df``; pass the training encoder to keep ids
            consistent across splits.

    Returns:
        The same ``df`` object.
    """
    img_folder = 'test_images' if is_test_df else 'train_images'
    # ``path`` must be a keyword: a second positional argument binds to one of
    # ``Series.apply``'s own parameters instead of being forwarded to ``func``.
    df['img_path'] = df['id'].apply(add_img_path, path=path, img_folder=img_folder)
    if is_test_df:
        return df
    if encoder is None:
        encoder = LabelEncoder(df)
    df['labels'] = df.apply(get_labels, axis=1, encoder=encoder)
    return df

In [11]:
from transformers import BlipImageProcessor, BertTokenizerFast

# BLIP checkpoint (the "itm" variant, presumably image-text matching); the image
# processor and the BERT tokenizer are both loaded from this same repo.
ckpt = "Salesforce/blip-itm-base-coco"
tokenizer = BertTokenizerFast.from_pretrained(ckpt)
img_processor = BlipImageProcessor.from_pretrained(ckpt)



In [12]:
# Register '[Encode]' as an additional special token so the tokenizer keeps it as
# one piece, and remember its id for building the marker mask later.
# NOTE(review): this grows the vocabulary (the batch printed further down shows
# id 30522 for the marker); any text model consuming these ids presumably needs
# its embeddings resized to len(tokenizer) - confirm where the model is built.
tokenizer.add_special_tokens(
    {'additional_special_tokens': ['[Encode]']}
)
encode_token_id = tokenizer.convert_tokens_to_ids(
    '[Encode]'
)

In [13]:
def add_question_and_tokenize(row, tokenizer, encode_token_id):
    """Build the per-category question prompt and attach its tokenization.

    ``row`` is one row of ``cat_info`` after indexing by ``Category`` (so
    ``row.name`` is the category) and must carry an ``Attribute_list`` entry.
    The prompt has one ``[Encode] What is the <attr>?`` line per attribute.
    The returned ``special_token_mask`` flags positions equal to
    ``encode_token_id``, so downstream code can locate one marker per attribute.

    Args:
        row: ``pd.Series`` named by its category.
        tokenizer: callable tokenizer; called with ``return_tensors='pt'``.
        encode_token_id: token id that ``'[Encode]'`` maps to.

    Returns:
        ``pd.Series`` with the row's original fields, ``question``, every
        tokenizer output converted to nested lists, and ``special_token_mask``.

    Raises:
        ValueError: if the tokenized question does not contain exactly one
            marker per attribute (e.g. truncation dropped some, or '[Encode]'
            was not registered as a single token).
    """
    cat = row.name
    attrs = row['Attribute_list']
    # Crude singularisation ("Men Tshirts" -> "Men Tshirt"); only a trailing 's'
    # is stripped, so names without one are left intact.
    subject = cat[:-1] if cat.endswith('s') else cat
    question_lines = [f'For this image of {subject}, please answer the following:']
    question_lines.extend(f'[Encode] What is the {attr}?' for attr in attrs)
    question = '\n'.join(question_lines)

    tokenized = tokenizer(question, padding=True, truncation=True, return_tensors='pt')
    special_token_mask = tokenized['input_ids'] == encode_token_id
    n_markers = int(special_token_mask.sum())
    # Markers and attributes are matched up positionally downstream, so a
    # missing marker would silently shift every later attribute.
    if n_markers != len(attrs):
        raise ValueError(
            f'{cat!r}: expected {len(attrs)} [Encode] tokens after tokenization, '
            f'found {n_markers} (question truncated or token not registered?)'
        )
    tokenized['special_token_mask'] = special_token_mask
    tokenized = {k: v.tolist() for k, v in tokenized.items()}

    return pd.Series({**row.to_dict(), 'question': question, **tokenized})

# Index by category and attach each category's question + tokenization.
# Avoids ``inplace=True``; note this cell still expects the raw ``cat_info``
# (with a ``Category`` column), so re-run the loading cell before re-running it.
cat_info = cat_info.set_index('Category').apply(
    add_question_and_tokenize, axis=1, args=(tokenizer, encode_token_id)
)

In [14]:
def process_cat_info(cat_info, tokenizer):
    """Register the ``[Encode]`` token and tokenize one question per category.

    ``tokenizer`` is mutated (a special token is added to it). ``cat_info`` is
    left untouched; a new frame indexed by ``Category`` is returned, with the
    columns produced by ``add_question_and_tokenize``.
    """
    tokenizer.add_special_tokens({'additional_special_tokens': ['[Encode]']})
    encode_token_id = tokenizer.convert_tokens_to_ids('[Encode]')
    # No ``inplace=True``: mutating the caller's frame made a second call
    # (e.g. re-running a cell) fail with KeyError on 'Category'.
    indexed = cat_info.set_index('Category')
    return indexed.apply(add_question_and_tokenize, axis=1, args=(tokenizer, encode_token_id))

In [15]:
from PIL import Image
from dataclasses import dataclass
import numpy as np

@dataclass
class MiniBatch:
    """One same-category batch as assembled by ``MeeshoDataloader``.

    Fields (shapes as built by the loader in this notebook; confirm if reused):
        category: the ``Category`` shared by every row in the batch.
        input_ids: token ids of the category's question, ``(1, seq_len)``.
        special_token_mask: bool mask of ``[Encode]`` positions, ``(1, seq_len)``.
        attention_mask: tokenizer attention mask, ``(1, seq_len)``.
        pixel_values: processed images, presumably ``(batch, 3, 224, 224)``.
        labels: integer attribute ids, ``(batch, 10)``.
    """
    # Plain ``torch.Tensor``: ids/labels are integer and the masks are bool,
    # so the previous ``FloatTensor`` annotations were wrong.
    category: str
    input_ids: torch.Tensor
    special_token_mask: torch.Tensor
    attention_mask: torch.Tensor
    pixel_values: torch.Tensor
    labels: torch.Tensor

In [16]:
class MeeshoDataloader:
    """Yield ``MiniBatch`` objects whose rows all share a single ``Category``.

    Every row of a category is asked the same question, so the tokenized
    question is looked up once per batch in ``cat_info`` rather than per row.
    On each pass the categories are visited in a random order drawn from the
    global ``np.random`` state. Rows inside a category keep their ``df`` order,
    and the last batch of a category may be smaller than ``batch_size``.

    Args:
        df: frame with ``Category``, ``img_path`` and ``labels`` columns.
        cat_info: frame indexed by category with ``input_ids``,
            ``special_token_mask`` and ``attention_mask`` columns.
        batch_size: maximum number of rows per batch.
        img_processor: called as
            ``img_processor(images=..., return_tensors='pt', size=(224, 224))``;
            its result must expose ``pixel_values``.
    """

    def __init__(self, df, cat_info, batch_size, img_processor):
        self.df = df
        self.cat_info = cat_info
        self.bs = batch_size
        self.img_processor = img_processor
        # Row labels of each category, computed once so that both ``__len__``
        # and ``_sample`` can use them (``__len__`` needs it as an attribute).
        self.group2idxs = {
            cat: list(idxs) for cat, idxs in df.groupby('Category').groups.items()
        }

    def __len__(self):
        """Number of batches per pass: each category is ceil-divided by ``bs``."""
        return sum((len(idxs) + self.bs - 1) // self.bs for idxs in self.group2idxs.values())

    def _sample(self):
        """Yield lists of ``df`` row labels, one list per batch."""
        for cat in np.random.permutation(list(self.group2idxs)):
            idxs = self.group2idxs[cat]
            for start in range(0, len(idxs), self.bs):
                yield idxs[start:start + self.bs]

    def __iter__(self):
        for idxs in self._sample():
            cat = self.df.loc[idxs[0], 'Category']

            # The question is shared by the whole category.
            input_ids = self.cat_info.loc[cat, 'input_ids']
            special_token_mask = self.cat_info.loc[cat, 'special_token_mask']
            attention_mask = self.cat_info.loc[cat, 'attention_mask']

            images = []
            for img_path in self.df.loc[idxs, 'img_path']:
                # ``copy()`` forces the pixel data to load, so the file handle can
                # be closed right away instead of lingering until garbage collection.
                with Image.open(img_path) as img:
                    images.append(img.copy())
            pixel_values = self.img_processor(images=images, return_tensors='pt', size=(224, 224))

            labels = self.df.loc[idxs, 'labels']
            yield MiniBatch(category=cat, input_ids=torch.tensor(input_ids),
                            special_token_mask=torch.tensor(special_token_mask),
                            attention_mask=torch.tensor(attention_mask),
                            pixel_values=pixel_values.pixel_values,
                            labels=torch.tensor(labels.tolist()))

In [17]:
# Same-category batches of up to ``batch_size`` rows over the training frame.
batch_size = 32
dl = MeeshoDataloader(
    df,
    cat_info,
    batch_size=batch_size,
    img_processor=img_processor,
)

In [18]:
# Inspect one batch via rich display instead of print(); an empty loader
# raises StopIteration here rather than silently showing nothing.
batch = next(iter(dl))
batch

MiniBatch(category='Men Tshirts', input_ids=tensor([[  101,  2005,  2023,  3746,  1997,  2273, 24529, 11961,  2102,  1010,
          3531,  3437,  1996,  2206,  1024, 30522,  2054,  2003,  1996,  3609,
          1029, 30522,  2054,  2003,  1996,  3300,  1029, 30522,  2054,  2003,
          1996,  5418,  1029, 30522,  2054,  2003,  1996,  6140,  1035,  2030,
          1035,  5418,  1035,  2828,  1029, 30522,  2054,  2003,  1996, 10353,
          1035,  3091,  1029,   102]]), special_token_mask=tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False,  True, False, False, False, False,
         False,  True, False, False, False, False, False,  True, False, False,
         False, False, False,  True, False, False, False, False, False, False,
         False, False, False, False, False,  True, False, False, False, False,
         False, False, False, False]]), attention_mask=tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1