In [1]:
import json
from typing import List

import pandas as pd
import torch
from clip import tokenize
from transformers import CLIPTextModelWithProjection, CLIPVisionModelWithProjection, CLIPImageProcessor

from src.ablation_experiment.validate_notebook import fiq_val_retrieval_text_image_combinations_clip
from src.ablation_experiment.validate_notebook_cirr import cirr_val_retrieval_text_image_combinations_clip
from src.fashioniq_experiment.utils import get_combing_function_with_alpha
from src.utils import device

%load_ext autoreload
%autoreload 2

# <div style="font-family: 'Garamond', serif; font-size: 22px; color: #ffffff; background-color: #34568B; text-align: center; padding: 15px; border-radius: 10px; border: 2px solid #FF6F61; box-shadow: 0 6px 12px rgba(0, 0, 0, 0.3); margin-bottom: 20px;">Step 1: Set up the experiment</div>

## <div style="font-family: 'Lucida Sans Unicode', sans-serif; font-size: 18px; color: #4A235A; background-color: #D7BDE2; text-align: left; padding: 10px; border-left: 5px solid #7D3C98; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2); margin-bottom: 10px;">Set up the cache for the experiment</div>

In [2]:
# Starts empty and is passed as the last argument to every retrieval call in
# this notebook. NOTE(review): presumably lets the helpers reuse computed
# features between calls — confirm in the helper implementations.
cache: dict = {}

## <div style="font-family: 'Lucida Sans Unicode', sans-serif; font-size: 18px; color: #4A235A; background-color: #D7BDE2; text-align: left; padding: 10px; border-left: 5px solid #7D3C98; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2); margin-bottom: 10px;">Load the CLIP encoders and preprocessing pipeline (same setup as the script version)</div>

In [3]:
# Checkpoint identifier handed to from_pretrained() for both the text and the
# image encoder, so the two towers always come from the same model.
CLIP_NAME = 'laion/CLIP-ViT-L-14-laion2B-s32B-b82K'

In [4]:
# Text tower of the checkpoint named by CLIP_NAME: loaded as float32, cast with
# .float(), and moved to the configured device in a single chained expression.
clip_text_encoder = (
    CLIPTextModelWithProjection
    .from_pretrained(CLIP_NAME, torch_dtype=torch.float32, projection_dim=768)
    .float()
    .to(device)
)

print("clip text encoder loaded.")
# Left as the cell's last expression so the module returned by eval() is
# displayed after the encoder is switched to evaluation mode.
clip_text_encoder.eval()

clip text encoder loaded.


CLIPTextModelWithProjection(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,

In [5]:
# Vision tower of the checkpoint named by CLIP_NAME: loaded as float32, cast
# with .float(), and moved to the configured device in a single chained
# expression.
clip_img_encoder = (
    CLIPVisionModelWithProjection
    .from_pretrained(CLIP_NAME, torch_dtype=torch.float32, projection_dim=768)
    .float()
    .to(device)
)

print("clip img encoder loaded.")
# Left as the cell's last expression so the module returned by eval() is
# displayed after the encoder is switched to evaluation mode.
clip_img_encoder.eval()

clip img encoder loaded.


CLIPVisionModelWithProjection(
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
      (position_embedding): Embedding(257, 1024)
    )
    (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-23): 24 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): GELUActivation()
            (fc1): Linear(in_features=1024, out_features=40

In [6]:
print('CLIP preprocess pipeline is used')
# Image processor handed to the retrieval helpers. Keyword arguments are
# grouped by the option they configure; the values are unchanged.
preprocess = CLIPImageProcessor(
    # Colour handling
    do_convert_rgb=True,
    # Resizing
    do_resize=True,
    size={'shortest_edge': 224},
    resample=3,  # PIL resampling filter code 3 (BICUBIC)
    # Centre crop
    do_center_crop=True,
    crop_size={'height': 224, 'width': 224},
    # Pixel rescaling and per-channel normalisation
    do_rescale=True,
    do_normalize=True,
    image_mean=[0.48145466, 0.4578275, 0.40821073],
    image_std=[0.26862954, 0.26130258, 0.27577711],
)

CLIP preprocess pipeline is used


In [7]:
# Alias for the `tokenize` function imported from the `clip` package; this is
# the tokenizer object passed to the retrieval helpers.
clip_tokenizer = tokenize

# <div style="font-family: 'Garamond', serif; font-size: 22px; color: #ffffff; background-color: #34568B; text-align: center; padding: 15px; border-radius: 10px; border: 2px solid #FF6F61; box-shadow: 0 6px 12px rgba(0, 0, 0, 0.3); margin-bottom: 20px;">Step 2: Load the MLLM generated text captions</div>

## <div style="font-family: 'Lucida Sans Unicode', sans-serif; font-size: 18px; color: #4A235A; background-color: #D7BDE2; text-align: left; padding: 10px; border-left: 5px solid #7D3C98; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2); margin-bottom: 10px;">Load the additional text captions</div>

In [8]:
# Load the MLLM-generated captions for FashionIQ.
# The encoding is given explicitly: JSON text is UTF-8, while open() would
# otherwise fall back to the platform's locale encoding and could mis-decode
# (or fail on) non-ASCII captions on some systems.
with open('../../fashionIQ_dataset/labeled_images_cir_cleaned.json', 'r', encoding='utf-8') as f:
    text_captions = json.load(f)

# Accumulator for recall tables; nothing in the visible cells appends to it.
total_recall_list: List[List[pd.DataFrame]] = []

print(f'Total number of text captions: {len(text_captions)}')

Total number of text captions: 74357


In [9]:
# Load the MLLM-generated captions for CIRR.
# The encoding is given explicitly: JSON text is UTF-8, while open() would
# otherwise fall back to the platform's locale encoding and could mis-decode
# (or fail on) non-ASCII captions on some systems.
with open('../../cirr_dataset/cirr_labeled_images_cir_cleaned.json', 'r', encoding='utf-8') as f:
    text_captions_cirr = json.load(f)

print(f'Total number of text captions: {len(text_captions_cirr)}')

Total number of text captions: 4609


# <div style="font-family: 'Garamond', serif; font-size: 22px; color: #ffffff; background-color: #34568B; text-align: center; padding: 15px; border-radius: 10px; border: 2px solid #FF6F61; box-shadow: 0 6px 12px rgba(0, 0, 0, 0.3); margin-bottom: 20px;">Step 3: Perform retrieval on the FashionIQ dataset</div>

## <div style="font-family: 'Lucida Sans Unicode', sans-serif; font-size: 18px; color: #4A235A; background-color: #D7BDE2; text-align: left; padding: 10px; border-left: 5px solid #7D3C98; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2); margin-bottom: 10px;">Perform retrieval on the shirt category</div>

In [10]:
# Run the FashionIQ validation retrieval helper for the 'shirt' category.
# The returned table has one row per caption-set combination.
shirt_recall = fiq_val_retrieval_text_image_combinations_clip(
    'shirt',
    get_combing_function_with_alpha(0.8),  # combining function built with alpha = 0.8
    clip_text_encoder,
    clip_img_encoder,
    clip_tokenizer,
    text_captions,
    preprocess,
    0.1,  # NOTE(review): presumably beta — the result table reports beta = 0.1; confirm against the helper's signature
    cache,  # same dict object is passed to every retrieval call in this notebook
)

Evaluating feature combinations: 100%|██████████| 7/7 [01:28<00:00, 12.60s/it]


In [11]:
# Display the recall table returned for the shirt category.
shirt_recall

Unnamed: 0,beta,recall_at10,recall_at50,Combination
0,0.1,32.679096,49.509323,First set
1,0.1,32.777232,48.135427,Second set
2,0.1,32.384691,47.988224,Third set
3,0.1,33.022571,49.018645,First and second set
4,0.1,32.924435,47.742885,Second and third set
5,0.1,33.07164,48.969579,First and third set
6,0.1,32.777232,48.969579,All sets


## <div style="font-family: 'Lucida Sans Unicode', sans-serif; font-size: 18px; color: #4A235A; background-color: #D7BDE2; text-align: left; padding: 10px; border-left: 5px solid #7D3C98; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2); margin-bottom: 10px;">Perform retrieval on the dress category</div>

In [12]:
# Run the FashionIQ validation retrieval helper for the 'dress' category.
# The returned table has one row per caption-set combination.
dress_recall = fiq_val_retrieval_text_image_combinations_clip(
    'dress',
    get_combing_function_with_alpha(0.8),  # combining function built with alpha = 0.8
    clip_text_encoder,
    clip_img_encoder,
    clip_tokenizer,
    text_captions,
    preprocess,
    0.1,  # NOTE(review): presumably beta — the result table reports beta = 0.1; confirm against the helper's signature
    cache,  # same dict object is passed to every retrieval call in this notebook
)

Evaluating feature combinations: 100%|██████████| 7/7 [01:20<00:00, 11.47s/it]


In [13]:
# Display the recall table returned for the dress category.
dress_recall

Unnamed: 0,beta,recall_at10,recall_at50,Combination
0,0.1,25.582549,46.603867,First set
1,0.1,24.144769,46.45513,Second set
2,0.1,24.243927,46.207237,Third set
3,0.1,26.177493,47.347546,First and second set
4,0.1,24.987605,46.306396,Second and third set
5,0.1,25.929597,47.198811,First and third set
6,0.1,25.880021,47.297966,All sets


## <div style="font-family: 'Lucida Sans Unicode', sans-serif; font-size: 18px; color: #4A235A; background-color: #D7BDE2; text-align: left; padding: 10px; border-left: 5px solid #7D3C98; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2); margin-bottom: 10px;">Perform retrieval on the toptee category</div>

In [14]:
# Run the FashionIQ validation retrieval helper for the 'toptee' category.
# The returned table has one row per caption-set combination.
toptee_recall = fiq_val_retrieval_text_image_combinations_clip(
    'toptee',
    get_combing_function_with_alpha(0.8),  # combining function built with alpha = 0.8
    clip_text_encoder,
    clip_img_encoder,
    clip_tokenizer,
    text_captions,
    preprocess,
    0.1,  # NOTE(review): presumably beta — the result table reports beta = 0.1; confirm against the helper's signature
    cache,  # same dict object is passed to every retrieval call in this notebook
)

Evaluating feature combinations: 100%|██████████| 7/7 [01:29<00:00, 12.76s/it]


In [15]:
# Display the recall table returned for the toptee category.
toptee_recall

Unnamed: 0,beta,recall_at10,recall_at50,Combination
0,0.1,36.053035,56.348801,First set
1,0.1,35.390106,55.940849,Second set
2,0.1,35.90005,55.685872,Third set
3,0.1,36.308005,56.45079,First and second set
4,0.1,35.237125,56.297809,Second and third set
5,0.1,36.308005,56.756759,First and third set
6,0.1,35.951045,56.705761,All sets


In [16]:
# Index each per-category recall table by its 'Combination' label so that the
# three tables align row-by-row when they are averaged.
# The column check keeps this cell idempotent: once 'Combination' has been
# moved into the index, calling set_index('Combination') again would raise
# KeyError, so a second run of the cell must skip tables already converted.
for category_recall in (shirt_recall, dress_recall, toptee_recall):
    if 'Combination' in category_recall.columns:
        category_recall.set_index('Combination', inplace=True)

In [17]:
# Element-wise mean of the three per-category tables: sum them (pandas aligns
# on index and column labels), then divide by the number of categories.
average_recall = shirt_recall.add(dress_recall).add(toptee_recall).div(3)
average_recall

Unnamed: 0_level_0,beta,recall_at10,recall_at50
Combination,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
First set,0.1,31.438227,50.820664
Second set,0.1,30.770702,50.177135
Third set,0.1,30.842889,49.960444
First and second set,0.1,31.836023,50.938994
Second and third set,0.1,31.049721,50.115697
First and third set,0.1,31.769748,50.97505
All sets,0.1,31.536099,50.991102


# <div style="font-family: 'Garamond', serif; font-size: 22px; color: #ffffff; background-color: #34568B; text-align: center; padding: 15px; border-radius: 10px; border: 2px solid #FF6F61; box-shadow: 0 6px 12px rgba(0, 0, 0, 0.3); margin-bottom: 20px;">Step 4: Perform retrieval on the CIRR dataset</div>

In [18]:
# Run the CIRR validation retrieval helper with the CIRR caption file.
# The returned table has one row per caption-set combination.
cirr_recall = cirr_val_retrieval_text_image_combinations_clip(
    get_combing_function_with_alpha(0.8),  # combining function built with alpha = 0.8
    clip_text_encoder,
    clip_img_encoder,
    clip_tokenizer,
    text_captions_cirr,
    preprocess,
    0.1,  # NOTE(review): presumably beta — the result table reports beta = 0.1; confirm against the helper's signature
    cache,  # same dict object that was passed to the FashionIQ calls
)

Evaluating feature combinations: 100%|██████████| 7/7 [03:18<00:00, 28.42s/it]


In [19]:
# Display the recall table returned for CIRR.
cirr_recall

Unnamed: 0,beta,recall_at1,recall_at5,recall_at10,recall_at50,group_recall_at1,group_recall_at2,group_recall_at3,Combination
0,0.1,30.25592,61.803395,74.336284,92.394161,59.315956,80.100453,89.332694,First set
1,0.1,29.873237,61.277205,73.570913,92.178905,58.93327,79.311168,89.476204,Second set
2,0.1,29.992825,61.468548,73.570913,91.91581,59.076774,79.693854,88.997847,Third set
3,0.1,30.351591,61.922985,74.216694,92.418081,59.028941,79.933029,89.476204,First and second set
4,0.1,30.184168,61.707723,73.905766,92.226738,59.196365,79.789525,89.18919,Second and third set
5,0.1,30.542931,62.090409,74.168861,92.322409,59.722555,80.220044,89.260942,First and third set
6,0.1,30.638602,62.090409,74.144942,92.418081,59.43554,80.100453,89.428365,All sets
