In [1]:
import os
from pathlib import Path

import pandas as pd
import torch

In [2]:
is_kaggle = 'KAGGLE_KERNEL_RUN_TYPE' in os.environ
path = Path('../input/visual-taxonomy') if is_kaggle else Path('data')

if is_kaggle:
    src_path = Path('../input/attribute-extraction-meesho/src')
    from shutil import copytree
    copytree(src = src_path, dst = "../working/src")

In [3]:
from src.data_preparation import *

In [4]:
cat_info = pd.read_parquet(path/'category_attributes.parquet')
df = pd.read_csv(path/'train.csv')
test_df = pd.read_csv(path/'test.csv')

In [5]:
from transformers import BlipImageProcessor, BertTokenizerFast

ckpt = "Salesforce/blip-itm-base-coco"
img_processor = BlipImageProcessor.from_pretrained(ckpt)
tokenizer = BertTokenizerFast.from_pretrained(ckpt)

preprocessor_config.json:   0%|          | 0.00/445 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/456 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]



In [6]:
label_encoder = LabelEncoder(df)
df_processed = process_df(df, path, encoder=label_encoder)
cat_info_processed = process_cat_info(cat_info, tokenizer)

In [7]:
batch_size = 4
dl = MeeshoDataloader(df_processed, cat_info_processed, batch_size=batch_size, img_processor=img_processor)
batch = next(iter(dl))

In [8]:
batch.input_ids.size(), batch.pixel_values.size()

(torch.Size([1, 82]), torch.Size([4, 3, 224, 224]))

In [9]:
from transformers import AutoConfig, BlipTextModel, BlipVisionModel

config = AutoConfig.from_pretrained(ckpt)
vision_model = BlipVisionModel(config.vision_config)
text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
text_encoder.resize_token_embeddings(len(tokenizer))

config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

Embedding(30523, 768, padding_idx=0)

In [10]:
vision_outputs = vision_model(
    pixel_values=batch.pixel_values,
    output_attentions=config.output_attentions,
    output_hidden_states=config.output_hidden_states,
    interpolate_pos_encoding=False,
)
image_embeds = vision_outputs[0]
image_embeds.size()

torch.Size([4, 197, 768])

In [18]:
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)

text_outputs = text_encoder(
    input_ids=batch.input_ids,
    attention_mask=batch.attention_mask,
    encoder_hidden_states=image_embeds,
    encoder_attention_mask=image_atts,
)
question_embeds = text_outputs[0]
attr_embeds = question_embeds[:, batch.special_token_mask, :]
question_embeds.size(), attr_embeds.size()

(torch.Size([4, 82, 768]), torch.Size([4, 9, 768]))

In [13]:
import torch.nn as nn

embed_size = config.text_config.hidden_size
heads = {}
for cat in label_encoder.vocab.index:
    heads[cat] = [nn.Linear(embed_size, n) for n in label_encoder.num_classes(cat)]

In [58]:
attr_loss = torch.ones(attr_embeds.size()[1], dtype=torch.float32)
preds = torch.ones(attr_embeds.size()[:-1], dtype=torch.int8)

In [59]:
for i, head in enumerate(heads[batch.category]):
    logits = head(attr_embeds[:, i, :])
    preds[:, i] = logits.argmax(dim=1)
    
    attr_loss[i] = nn.CrossEntropyLoss()(logits, batch.labels[:,i])
    print(attr_loss[i])

loss = attr_loss.mean()
print(loss)

tensor(3.2601, grad_fn=<SelectBackward0>)
tensor(1.2027, grad_fn=<SelectBackward0>)
tensor(0.7630, grad_fn=<SelectBackward0>)
tensor(0.5768, grad_fn=<SelectBackward0>)
tensor(0.8727, grad_fn=<SelectBackward0>)
tensor(0.9106, grad_fn=<SelectBackward0>)
tensor(1.0647, grad_fn=<SelectBackward0>)
tensor(1.0684, grad_fn=<SelectBackward0>)
tensor(0.8854, grad_fn=<SelectBackward0>)
tensor(1.1783, grad_fn=<MeanBackward0>)


In [44]:
from typing import Optional
from dataclasses import dataclass
from transformers.utils import ModelOutput

@dataclass
class BlipAttributeExtractionModelOutput(ModelOutput):
    preds: torch.FloatTensor
    loss: Optional[torch.FloatTensor] = None
    attr_loss: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    question_embeds: Optional[torch.FloatTensor] = None

In [64]:
from transformers import BlipPreTrainedModel

class BlipForAttributeExtraction(BlipPreTrainedModel):
    def __init__(self, config, tokenizer, label_encoder):
        super().__init__(config)
        self.vision_model = BlipVisionModel(config.vision_config)
        self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
        self.text_encoder.resize_token_embeddings(len(tokenizer))
        
        embed_size = config.text_config.hidden_size
        self.heads = {}
        for cat in label_encoder.vocab.index:
            self.heads[cat] = [nn.Linear(embed_size, n) for n in label_encoder.num_classes(cat)]
        self.post_init()

    def forward(self, batch):
        image_embeds = self.vision_model(pixel_values=batch.pixel_values,
                                         output_attentions=self.config.output_attentions,
                                         output_hidden_states=self.config.output_hidden_states,
                                         interpolate_pos_encoding=False)[0]
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long)
        
        question_embeds = self.text_encoder(input_ids=batch.input_ids,
                                            attention_mask=batch.attention_mask,
                                            encoder_hidden_states=image_embeds,
                                            encoder_attention_mask=image_atts)[0]
        attr_embeds = question_embeds[:, batch.special_token_mask, :]

        attr_loss = torch.ones(attr_embeds.size()[1], dtype=torch.float32)
        preds = torch.ones(attr_embeds.size()[:-1], dtype=torch.int8)
        
        for i, head in enumerate(self.heads[batch.category]):
            logits = head(attr_embeds[:, i, :])
            preds[:, i] = logits.argmax(dim=1)
            if batch.labels is not None:
                attr_loss[i] = nn.CrossEntropyLoss()(logits.cpu(), batch.labels[:,i])
        loss = attr_loss.mean() if batch.labels is not None else None

        return BlipAttributeExtractionModelOutput(loss=loss, attr_loss=attr_loss,
                                                  preds=preds, image_embeds=image_embeds,
                                                  question_embeds=question_embeds)

In [65]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BlipForAttributeExtraction.from_pretrained(
    ckpt, config=config, label_encoder=label_encoder,
    tokenizer=tokenizer, ignore_mismatched_sizes=True
).to(device)

Some weights of BlipForAttributeExtraction were not initialized from the model checkpoint at Salesforce/blip-itm-base-coco and are newly initialized because the shapes did not match:
- text_encoder.embeddings.word_embeddings.weight: found shape torch.Size([30524, 768]) in the checkpoint and torch.Size([30523, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [66]:
with torch.no_grad():
    output = model(batch)

In [67]:
output.preds.size(), batch.labels.size()

(torch.Size([4, 9]), torch.Size([4, 9]))