In [1]:
import json
from typing import List

import pandas as pd

from src.ablation_experiment.validate_notebook_cirr import cirr_val_retrieval_text_image_combinations
from src.blip_modules.blip_text_encoder import BLIPTextEncoder
from src.blip_modules.blip_img_encoder import BLIPImgEncoder

from src.ablation_experiment.validate_notebook import fiq_val_retrieval_text_image_combinations
from src.data_utils import targetpad_transform
from src.fashioniq_experiment.utils import get_combing_function_with_alpha
from src.utils import device

%load_ext autoreload
%autoreload 2

# <div style="font-family: 'Garamond', serif; font-size: 22px; color: #ffffff; background-color: #34568B; text-align: center; padding: 15px; border-radius: 10px; border: 2px solid #FF6F61; box-shadow: 0 6px 12px rgba(0, 0, 0, 0.3); margin-bottom: 20px;">Step 1: Set up the experiment</div>

## <div style="font-family: 'Lucida Sans Unicode', sans-serif; font-size: 18px; color: #4A235A; background-color: #D7BDE2; text-align: left; padding: 10px; border-left: 5px solid #7D3C98; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2); margin-bottom: 10px;">Set up the cache for the experiment</div>

In [2]:
cache = {}

## <div style="font-family: 'Lucida Sans Unicode', sans-serif; font-size: 18px; color: #4A235A; background-color: #D7BDE2; text-align: left; padding: 10px; border-left: 5px solid #7D3C98; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2); margin-bottom: 10px;">Same concept as script version here</div>

In [3]:
BLIP_PRETRAINED_PATH = '../../models/model_base.pth'
MED_CONFIG_PATH = '../blip_modules/med_config.json'

In [4]:
blip_text_encoder = BLIPTextEncoder(
    BLIP_PRETRAINED_PATH, 
    MED_CONFIG_PATH,
    use_pretrained_proj_layer=True
)

blip_text_encoder = blip_text_encoder.to(device)
print("blip text encoder loaded.")
blip_text_encoder.eval()

load checkpoint from ../../models/model_base.pth for text_encoder.
load checkpoint from ../../models/model_base.pth for text_proj.
blip text encoder loaded.


BLIPTextEncoder(
  (text_encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30524, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Drop

In [5]:
blip_img_encoder = BLIPImgEncoder(BLIP_PRETRAINED_PATH)
blip_img_encoder = blip_img_encoder.to(device)
print("blip img encoder loaded.")
blip_img_encoder.eval()

reshape position embedding from 196 to 576
load checkpoint from ../../models/model_base.pth for visual_encoder.
load checkpoint from ../../models/model_base.pth for vision_proj.
blip img encoder loaded.


BLIPImgEncoder(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
      

In [6]:
print('Target pad preprocess pipeline is used')
preprocess = targetpad_transform(1.25, 384)

Target pad preprocess pipeline is used


# <div style="font-family: 'Garamond', serif; font-size: 22px; color: #ffffff; background-color: #34568B; text-align: center; padding: 15px; border-radius: 10px; border: 2px solid #FF6F61; box-shadow: 0 6px 12px rgba(0, 0, 0, 0.3); margin-bottom: 20px;">Step 2: Load the MLLM generated text captions</div>

## <div style="font-family: 'Lucida Sans Unicode', sans-serif; font-size: 18px; color: #4A235A; background-color: #D7BDE2; text-align: left; padding: 10px; border-left: 5px solid #7D3C98; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2); margin-bottom: 10px;">Load the addition text captions</div>

In [7]:
with open('../../fashionIQ_dataset/labeled_images_cir_cleaned.json', 'r') as f:
    text_captions = json.load(f)
    
total_recall_list: List[List[pd.DataFrame]] = []

print(f'Total number of text captions: {len(text_captions)}')

Total number of text captions: 74357


In [8]:
with open('../../cirr_dataset/cirr_labeled_images_cir_cleaned.json', 'r') as f:
    text_captions_cirr = json.load(f)

print(f'Total number of text captions: {len(text_captions_cirr)}')

Total number of text captions: 4609


# <div style="font-family: 'Garamond', serif; font-size: 22px; color: #ffffff; background-color: #34568B; text-align: center; padding: 15px; border-radius: 10px; border: 2px solid #FF6F61; box-shadow: 0 6px 12px rgba(0, 0, 0, 0.3); margin-bottom: 20px;">Step 3: Perform retrieval on the FashionIQ dataset</div>

## <div style="font-family: 'Lucida Sans Unicode', sans-serif; font-size: 18px; color: #4A235A; background-color: #D7BDE2; text-align: left; padding: 10px; border-left: 5px solid #7D3C98; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2); margin-bottom: 10px;">Perform retrieval on the shirt category</div>

In [9]:
shirt_recall = fiq_val_retrieval_text_image_combinations(
    'shirt',
    get_combing_function_with_alpha(0.95),
    blip_text_encoder,
    blip_img_encoder,
    text_captions,
    preprocess,
    0.2,
    cache,
)

Evaluating feature combinations: 100%|██████████| 7/7 [00:21<00:00,  3.14s/it]


In [10]:
shirt_recall

Unnamed: 0,beta,recall_at10,recall_at50,Combination
0,0.2,22.52208,36.457312,First set
1,0.2,21.687929,35.525024,Second set
2,0.2,22.571148,35.672227,Third set
3,0.2,22.816487,37.095192,First and second set
4,0.2,22.669284,36.26104,Second and third set
5,0.2,23.110893,36.800784,First and third set
6,0.2,23.159961,36.997056,All sets


## <div style="font-family: 'Lucida Sans Unicode', sans-serif; font-size: 18px; color: #4A235A; background-color: #D7BDE2; text-align: left; padding: 10px; border-left: 5px solid #7D3C98; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2); margin-bottom: 10px;">Perform retrieval on the dress category</div>

In [11]:
dress_recall = fiq_val_retrieval_text_image_combinations(
    'dress',
    get_combing_function_with_alpha(0.95),
    blip_text_encoder,
    blip_img_encoder,
    text_captions,
    preprocess,
    0.2,
    cache,
)

Evaluating feature combinations: 100%|██████████| 7/7 [00:22<00:00,  3.15s/it]


In [12]:
dress_recall

Unnamed: 0,beta,recall_at10,recall_at50,Combination
0,0.2,20.327219,38.07635,First set
1,0.2,18.889439,37.233517,Second set
2,0.2,18.641546,35.448685,Third set
3,0.2,20.525533,38.572136,First and second set
4,0.2,19.732276,37.332672,Second and third set
5,0.2,20.079325,37.729302,First and third set
6,0.2,20.475954,38.621715,All sets


## <div style="font-family: 'Lucida Sans Unicode', sans-serif; font-size: 18px; color: #4A235A; background-color: #D7BDE2; text-align: left; padding: 10px; border-left: 5px solid #7D3C98; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2); margin-bottom: 10px;">Perform retrieval on the toptee category</div>

In [13]:
toptee_recall = fiq_val_retrieval_text_image_combinations(
    'toptee',
    get_combing_function_with_alpha(0.95),
    blip_text_encoder,
    blip_img_encoder,
    text_captions,
    preprocess,
    0.2,
    cache,
)

Evaluating feature combinations: 100%|██████████| 7/7 [00:23<00:00,  3.30s/it]


In [14]:
toptee_recall

Unnamed: 0,beta,recall_at10,recall_at50,Combination
0,0.2,24.579297,46.09893,First set
1,0.2,23.406425,45.690975,Second set
2,0.2,24.987252,45.28302,Third set
3,0.2,25.038245,46.81285,First and second set
4,0.2,24.783275,46.863845,Second and third set
5,0.2,25.650179,46.965834,First and third set
6,0.2,25.140235,47.37379,All sets


In [15]:
# Change the index to 'Combination' column
shirt_recall.set_index('Combination', inplace=True)
dress_recall.set_index('Combination', inplace=True)
toptee_recall.set_index('Combination', inplace=True)

In [16]:
# Average the recall values
average_recall = (shirt_recall + dress_recall + toptee_recall) / 3
average_recall

Unnamed: 0_level_0,beta,recall_at10,recall_at50
Combination,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
First set,0.2,22.476199,40.210864
Second set,0.2,21.327931,39.483172
Third set,0.2,22.066649,38.80131
First and second set,0.2,22.793422,40.826726
Second and third set,0.2,22.394945,40.152519
First and third set,0.2,22.946799,40.49864
All sets,0.2,22.925383,40.99752


# <div style="font-family: 'Garamond', serif; font-size: 22px; color: #ffffff; background-color: #34568B; text-align: center; padding: 15px; border-radius: 10px; border: 2px solid #FF6F61; box-shadow: 0 6px 12px rgba(0, 0, 0, 0.3); margin-bottom: 20px;">Step 3: Perform retrieval on the CIRR dataset</div>

In [17]:
cirr_recall = cirr_val_retrieval_text_image_combinations(
    get_combing_function_with_alpha(0.95),
    blip_text_encoder,
    blip_img_encoder,
    text_captions_cirr,
    preprocess,
    0.2,
    cache,
)

Evaluating feature combinations: 100%|██████████| 7/7 [00:49<00:00,  7.11s/it]


In [18]:
cirr_recall

Unnamed: 0,beta,recall_at1,recall_at5,recall_at10,recall_at50,group_recall_at1,group_recall_at2,group_recall_at3,Combination
0,0.2,25.568047,52.786416,66.515189,86.319065,56.756759,77.397752,88.495576,First set
1,0.2,24.730925,52.858168,65.606314,85.840708,56.995934,76.560634,88.878256,Second set
2,0.2,24.611337,53.049511,65.3193,86.079884,56.230569,77.493423,88.423824,Third set
3,0.2,25.257117,53.527862,66.371679,86.079884,56.948102,77.278161,88.782591,First and second set
4,0.2,24.850515,52.858168,66.108584,86.127722,56.613249,77.230328,88.950014,Second and third set
5,0.2,25.639799,53.049511,66.395599,86.342978,57.091606,77.541256,88.543409,First and third set
6,0.2,25.448456,53.647453,66.539103,86.319065,57.019854,77.493423,88.519496,All sets
