# CLIP Interrogator

This notebook reuses code from Leonid Kulyk's notebook (https://www.kaggle.com/code/leonidkulyk/lb-0-45836-blip-clip-clip-interrogator)
and changes how the predictions of the CLIP Interrogator and ViT models are ensembled.

The `interrogate` function used below is a modified version of the original Interrogator code:
https://github.com/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator/clip_interrogator.py#L213



In [1]:
# Directory of pre-downloaded wheels; it is handed to `pip --find-links` in the next cell.
wheels_path = "/kaggle/input/clip-interrogator-wheels-x"
# Wheel file of the clip_interrogator package (version 0.4.3) inside that directory.
CI_wheel_path = f"{wheels_path}/clip_interrogator-0.4.3-py3-none-any.whl"

In [2]:
# Install clip_interrogator from the local wheel: --no-index disables the package index,
# --find-links points pip at the wheels directory, -q reduces the output.
!pip install --no-index --find-links $wheels_path $CI_wheel_path -q

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-io 0.21.0 requires tensorflow-io-gcs-filesystem==0.21.0, which is not installed.
beatrix-jupyterlab 3.1.7 requires google-cloud-bigquery-storage, which is not installed.
tfx-bsl 1.9.0 requires google-api-python-client<2,>=1.7.11, but you have google-api-python-client 2.52.0 which is incompatible.
tfx-bsl 1.9.0 requires tensorflow!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,<3,>=1.15.5, but you have tensorflow 2.6.4 which is incompatible.
tensorflow 2.6.4 requires h5py~=3.1.0, but you have h5py 3.8.0 which is incompatible.
tensorflow 2.6.4 requires numpy~=1.19.2, but you have numpy 1.21.6 which is incompatible.
tensorflow 2.6.4 requires typing-extensions<3.11,>=3.7, but you have typing-extensions 4.1.1 which is incompatible.
tensorflow-transform 1.9.0 requires tens

Importing necessary libraries

In [3]:
import inspect
import importlib

from blip.models import blip
from clip_interrogator import clip_interrogator

In [4]:
# Replace the tokenizer path to prevent downloading: the installed blip module
# is rewritten on disk so that BertTokenizer is loaded from a local directory
# instead of the 'bert-base-uncased' identifier.
blip_tokenizer_path = inspect.getfile(blip)

# `with` guarantees the handle is closed even if reading or writing raises.
with open(blip_tokenizer_path, "rt") as file_in:
    dt = file_in.read()

dt = dt.replace(
    "BertTokenizer.from_pretrained('bert-base-uncased')",
    "BertTokenizer.from_pretrained('/kaggle/input/clip-interrogator-models-x/bert-base-uncased')"
)

with open(blip_tokenizer_path, "wt") as file_out:
    file_out.write(dt)

# reload module so the patched source is the one used by this kernel
importlib.reload(blip)

<module 'blip.models.blip' from '/opt/conda/lib/python3.7/site-packages/blip/models/blip.py'>

In [5]:
# Fix clip_interrogator bug: rewrite the installed module so the tokenizer is
# looked up by the part of `config.clip_model_name` before the first "/"
# instead of by the `clip_model_name` variable.
CI_path = inspect.getfile(clip_interrogator.Interrogator)

# `with` guarantees the handle is closed even if reading or writing raises.
with open(CI_path, "rt") as file_in:
    dt = file_in.read()

dt = dt.replace(
    'open_clip.get_tokenizer(clip_model_name)',
    'open_clip.get_tokenizer(config.clip_model_name.split("/", 2)[0])'
)

with open(CI_path, "wt") as file_out:
    file_out.write(dt)

# reload module so the patched source is the one used by this kernel
importlib.reload(clip_interrogator)

<module 'clip_interrogator.clip_interrogator' from '/opt/conda/lib/python3.7/site-packages/clip_interrogator/clip_interrogator.py'>

In [6]:
import os
import sys
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt 

import numpy as np
import pandas as pd
import torch
import open_clip


# Make the sentence-transformers sources in the input directory importable.
sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

# Root of the competition data; the sample submission and the images are read from here.
competition_path = Path('/kaggle/input/stable-diffusion-image-to-prompts/')

Configuration

In [7]:
class CONFIGURE:
    """Constants for the CLIP Interrogator / sentence-embedding part of the notebook."""
    # NOTE(review): `device` and `seed` are not referenced in the visible cells
    # (the model code reads `model_configuration.device`) - confirm they are still needed.
    device = "cuda"
    seed = 42
    # Number of embedding values per image; used to build the `imgId_eId` index.
    len_embedding = 384
    # Local path given to `SentenceTransformer(...)`.
    sent_mod_path = "/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2"
    # Checkpoint given to `blip.blip_decoder(pretrained=...)`.
    blip_mod_path = "/kaggle/input/clip-interrogator-models-x/model_large_caption.pth"
    # Model name given to `clip_interrogator.Config(clip_model_name=...)`.
    ci_clip_mod_name = "ViT-H-14/laion2b_s32b_b79k"
    # Architecture name given to `open_clip.create_model`.
    clip_mod_name = "ViT-H-14"
    # Checkpoint loaded into the CLIP model with `open_clip.load_checkpoint`.
    clip_mod_path = "/kaggle/input/clip-interrogator-models-x/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin"
    # Assigned to `model_configuration.cache_path`.
    path_cache = "/kaggle/input/clip-interrogator-models-x"

In [8]:
# Load the sample submission indexed by `imgId_eId`; the index is compared with the locally built ids below.
submission_data = pd.read_csv(competition_path / 'sample_submission.csv', index_col='imgId_eId')
submission_data.head()

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
20057f34d_0,0.018848
20057f34d_1,0.03019
20057f34d_2,0.072792
20057f34d_3,-0.000673
20057f34d_4,0.016774


In [9]:
# File names in the competition image folder, in the order os.listdir returns them.
images = os.listdir(competition_path / 'images')
# Image id = the part of the file name before the first dot.
imgIds = [file_name.split('.')[0] for file_name in images]

eIds = list(range(CONFIGURE.len_embedding))

# One "<imgId>_<eId>" row id per (image, embedding position) pair, grouped by image.
imgId_eId = [
    f"{img_id}_{e_id}"
    for img_id in imgIds
    for e_id in range(CONFIGURE.len_embedding)
]

# The ids built here must match the rows of the sample submission exactly.
assert sorted(imgId_eId) == sorted(submission_data.index)

In [10]:
# Sentence-embedding model loaded from the local path; it later encodes the generated prompts.
sentence_model = SentenceTransformer(CONFIGURE.sent_mod_path)

In [11]:
# clip_interrogator configuration object; the BLIP and CLIP models are attached to it in the next cells.
model_configuration = clip_interrogator.Config(clip_model_name=CONFIGURE.ci_clip_mod_name)
# Point the interrogator's cache at the local models directory.
model_configuration.cache_path = CONFIGURE.path_cache

In [12]:
# med_config.json lives in the `configs` folder two directory levels above the blip module file.
blip_package_dir = os.path.dirname(os.path.dirname(blip_tokenizer_path))
configurations_path = os.path.join(blip_package_dir, 'configs')
med_json_config = os.path.join(configurations_path, 'med_config.json')

# Build the BLIP caption decoder from the local checkpoint, switch it to
# inference mode and move it to the configured device.
blip_decoder_options = {
    "pretrained": CONFIGURE.blip_mod_path,
    "image_size": model_configuration.blip_image_eval_size,
    "vit": model_configuration.blip_model_type,
    "med_config": med_json_config,
}
blip_model = blip.blip_decoder(**blip_decoder_options)
blip_model.eval()
blip_model = blip_model.to(model_configuration.device)

# Hand the ready model to the interrogator configuration.
model_configuration.blip_model = blip_model

load checkpoint from /kaggle/input/clip-interrogator-models-x/model_large_caption.pth


In [13]:
# Half precision when the configured device is CUDA, full precision otherwise.
clip_precision = 'fp16' if model_configuration.device == 'cuda' else 'fp32'

# Create the CLIP architecture, load the local weights, then move it to the
# device and put it in inference mode.
clip_model = open_clip.create_model(CONFIGURE.clip_mod_name, precision=clip_precision)
open_clip.load_checkpoint(clip_model, CONFIGURE.clip_mod_path)
clip_model.to(model_configuration.device)
clip_model.eval()

model_configuration.clip_model = clip_model

In [14]:
# Evaluation-time image transform built from the CLIP visual tower's own
# settings; mean/std fall back to None when the tower does not define them.
visual_tower = clip_model.visual
clip_preprocess = open_clip.image_transform(
    visual_tower.image_size,
    is_train=False,
    mean=getattr(visual_tower, 'image_mean', None),
    std=getattr(visual_tower, 'image_std', None),
)
model_configuration.clip_preprocess = clip_preprocess

In [15]:
# Build the interrogator from the configuration that already carries the BLIP and CLIP models.
ci = clip_interrogator.Interrogator(model_configuration)

Loaded CLIP model and data in 2.45 seconds.


In [16]:
# Cosine similarity computed along dim=1.
cos_sim = torch.nn.CosineSimilarity(dim=1)


def _label_embeddings_to_tensor(label_table):
    """Stack a label table's numpy embeddings into one tensor on the interrogator's device."""
    stacked = torch.stack([torch.from_numpy(embedding) for embedding in label_table.embeds])
    return stacked.to(ci.device)


mediums_features_array = _label_embeddings_to_tensor(ci.mediums)
movements_features_array = _label_embeddings_to_tensor(ci.movements)
flavors_features_array = _label_embeddings_to_tensor(ci.flavors)
# The artists table is not used (left commented out):
#artists_features_array = torch.stack([torch.from_numpy(t) for t in ci.artists.embeds]).to(ci.device)

In [17]:
def interrogate(image: Image.Image) -> str:
    """Build a text prompt describing `image`.

    The prompt is the BLIP caption followed by the best-matching medium
    (skipped when the caption already starts with it), the best-matching
    movement and the three best-matching flavors, ranked by cosine similarity
    between the image features and the precomputed label features.

    Args:
        image: PIL image to describe.

    Returns:
        The prompt, passed through `clip_interrogator._truncate_to_fit`.
    """
    caption = ci.generate_caption(image)
    image_features = ci.image_to_features(image)

    def best_labels(label_table, label_features, k):
        # Labels of the k entries most similar to the image.
        top_indices = cos_sim(image_features, label_features).topk(k).indices
        return [label_table.labels[i] for i in top_indices]

    medium = best_labels(ci.mediums, mediums_features_array, 1)[0]
    movement = best_labels(ci.movements, movements_features_array, 1)[0]
    flaves = ", ".join(best_labels(ci.flavors, flavors_features_array, 3))

    # Do not repeat the medium when the caption already leads with it.
    if caption.startswith(medium):
        prompt = f"{caption}, {movement}, {flaves}"
    else:
        prompt = f"{caption}, {medium}, {movement}, {flaves}"

    return clip_interrogator._truncate_to_fit(prompt, ci.tokenize)

In [18]:
images_path = "../input/stable-diffusion-image-to-prompts/images/"

# One generated prompt per image, in the same order as `images`.
prompts = []
for image_name in images:
    rgb_image = Image.open(f"{images_path}{image_name}").convert("RGB")
    prompts.append(interrogate(rgb_image))

In [19]:
def add_text_limiters(text: str, words_per_line: int = 15) -> str:
    """Wrap `text` by inserting a line break after every `words_per_line` words.

    Words are the pieces obtained by splitting on a single space. Wrapped
    lines carry no leading space and the result has no trailing newline.

    Args:
        text: Text to wrap.
        words_per_line: Maximum number of words per line (default 15).

    Returns:
        The wrapped text; text with at most `words_per_line` words is returned unchanged.

    Raises:
        ValueError: If `words_per_line` is smaller than 1.
    """
    if words_per_line < 1:
        raise ValueError("words_per_line must be at least 1")

    words = text.split(" ")
    # Join each chunk with spaces and the chunks with newlines, so the line
    # break replaces the separating space instead of being added next to it.
    lines = [
        " ".join(words[start:start + words_per_line])
        for start in range(0, len(words), words_per_line)
    ]
    return "\n".join(lines)

def plot_image(image: np.ndarray, original_prompt: str, generated_prompt: str) -> None:
    """Show `image` with the original and generated prompts annotated to its right.

    Args:
        image: Image data accepted by `plt.imshow`.
        original_prompt: Reference prompt for the image.
        generated_prompt: Prompt produced for the image.
    """
    annotation_text = (
        "Original prompt:\n" + add_text_limiters(original_prompt)
        + "\n\nGenerated prompt:\n" + add_text_limiters(generated_prompt)
    )
    # xy is given in axes fractions; x > 1 places the text outside the right edge.
    annotation_style = dict(
        xy=(1.05, 0.5), xycoords='axes fraction', ha='left', va='center',
        fontsize=16, rotation=0, color="#104a6e",
    )

    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.annotate(annotation_text, **annotation_style)

In [20]:
# DO NOT FORGET TO COMMENT OUT THIS CELL DURING SUBMISSION

# original_prompts_df = pd.read_csv("/kaggle/input/stable-diffusion-image-to-prompts/prompts.csv")

# for image_name, prompt in zip(images, prompts):
#     img = Image.open(images_path + image_name).convert("RGB")
#     original_prompt = original_prompts_df[
#         original_prompts_df.imgId == image_name.split(".")[0]
#     ].prompt.iloc[0]
#     plot_image(img, original_prompt, prompt)

In [21]:
# Encode the generated prompts with the sentence model and flatten the result to one dimension.
# NOTE(review): despite the name, these are sentence-model embeddings of the interrogator prompts.
clip_embeddings = sentence_model.encode(prompts).flatten()

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

# ViT

In [22]:
import numpy as np
import pandas as pd
from pathlib import Path
from PIL import Image
from tqdm.notebook import tqdm
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm

In [23]:
class CustomViT(nn.Module):
    """Backbone with its classification head removed, followed by a three-layer MLP head.

    Args:
        backbone: Module exposing `num_features` and a replaceable `head` attribute.
        num_classes: Size of the output vector (default 384).
    """

    def __init__(self, backbone, num_classes=384):
        super().__init__()
        self.backbone = backbone
        # Drop the original classification head so the backbone returns features.
        self.backbone.head = nn.Identity()

        # Regression head: features -> 1024 -> 512 -> num_classes.
        self.fc1 = nn.Linear(backbone.num_features, 1024)
        self.dropout1 = nn.Dropout(0.5)
        self.activation1 = nn.GELU()
        self.fc2 = nn.Linear(1024, 512)
        self.dropout2 = nn.Dropout(0.5)
        self.activation2 = nn.GELU()
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return the `num_classes`-dimensional output for a batch `x`."""
        features = self.backbone(x)
        # Each hidden stage applies linear -> dropout -> GELU, in that order.
        hidden = self.activation1(self.dropout1(self.fc1(features)))
        hidden = self.activation2(self.dropout2(self.fc2(hidden)))
        return self.fc3(hidden)


In [24]:
class CFG:
    """Settings for the ViT inference step; its attributes are passed to `predict`."""
    # Weights file loaded in `load_model`.
    model_path = '/kaggle/input/trained-vit-3/vit_best_model_3ep.pth'
    # Architecture name given to `timm.create_model`.
    model_name = 'vit_base_patch16_224'
    # Size given to `transforms.Resize` in `predict`.
    input_size = 224
    # DataLoader batch size.
    batch_size = 64


In [25]:
class DiffusionTestDataset(Dataset):
    """Inference dataset that yields transformed images without labels.

    Args:
        images: Sequence of image files (anything `Image.open` accepts).
        transform: Callable applied to each loaded RGB image.
    """

    def __init__(self, images, transform):
        self.images = images
        self.transform = transform

    def __len__(self):
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image `idx` as 3-channel RGB and return the transformed result."""
        # Force RGB: the transform used by `predict` normalizes with 3-channel
        # statistics, so an RGBA or grayscale file would otherwise fail. The
        # CLIP branch of this notebook converts to RGB as well.
        with Image.open(self.images[idx]) as raw_image:
            image = raw_image.convert("RGB")
        return self.transform(image)

def load_model(model_path, model_name):
    """Create a `CustomViT` on a timm backbone and load trained weights into it.

    Args:
        model_path: Path of the saved state dict.
        model_name: timm architecture name of the backbone.

    Returns:
        The model in evaluation mode, on the CPU; callers move it to their device.
    """
    backbone = timm.create_model(model_name, pretrained=False)
    model = CustomViT(backbone)
    # map_location keeps loading independent of the device the checkpoint was
    # saved from, so a GPU-saved file also loads on a CPU-only machine.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model



In [26]:
def predict(images, model_path, model_name, input_size, batch_size):
    """Run the trained ViT over `images` and return all outputs as one flat array.

    Args:
        images: Sequence of image files, consumed in order.
        model_path: Path of the saved model weights.
        model_name: timm architecture name of the backbone.
        input_size: Side length, in pixels, of the square model input.
        batch_size: Number of images per batch.

    Returns:
        1-D numpy array with the per-image model outputs concatenated in the
        order of `images` (the loader does not shuffle or drop samples).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        # An int size only rescales the shorter edge, so a non-square image
        # would produce a non-square tensor that cannot be batched with the
        # others. An explicit (h, w) gives the same result for square images.
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    dataset = DiffusionTestDataset(images, transform)
    dataloader = DataLoader(
        dataset=dataset,
        shuffle=False,
        batch_size=batch_size,
        # Pinned memory only helps host-to-GPU copies.
        pin_memory=(device.type == 'cuda'),
        num_workers=2,
        drop_last=False
    )

    model = load_model(model_path, model_name)
    model.to(device)
    model.eval()

    preds = []
    # Inference only: no gradients are needed for any batch.
    with torch.no_grad():
        for X in tqdm(dataloader, leave=False):
            X_out = model(X.to(device))
            preds.append(X_out.cpu().numpy())

    return np.vstack(preds).flatten()




In [27]:
# Build the ViT inputs in exactly the order already used for the interrogator
# prompts. `clip_embeddings` and `prompt_embeddings` are blended element-wise
# further down, so both must follow the same image order, and a second,
# independent directory listing gives no such guarantee.
IMAGES_DIR = Path('/kaggle/input/stable-diffusion-image-to-prompts/images')
# Path(...).name accepts both the plain file names from the earlier listing and
# the Path objects this cell produces, so re-running the cell is harmless.
images = [IMAGES_DIR / Path(image_file).name for image_file in images]
imgIds = [i.stem for i in images]
EMBEDDING_LENGTH = 384
# One "<imgId>_<eId>" row id per (image, embedding position) pair, grouped by image.
imgId_eId = [
    f"{img_id}_{e_id}"
    for img_id in imgIds
    for e_id in range(EMBEDDING_LENGTH)
]


In [28]:
# Flat array of ViT outputs for all images, in the order of `images`.
prompt_embeddings = predict(images, CFG.model_path, CFG.model_name, CFG.input_size, CFG.batch_size)


  0%|          | 0/1 [00:00<?, ?it/s]

In [29]:
# Total number of predicted values (shown as the cell output).
len(prompt_embeddings)

2688

In [30]:
# Blend weights for the two embedding sources; they sum to 1.
clip_weight = 0.25
vit_weight = 0.75

In [31]:
# Element-wise weighted blend; both arrays must be flattened in the same image order.
final_embeddings = (clip_weight * clip_embeddings) + (vit_weight * prompt_embeddings)

In [32]:
# Number of blended values (shown as the cell output).
len(final_embeddings)

2688

In [33]:
# Single-column frame: one 'val' per "<imgId>_<eId>" row id.
submission = pd.DataFrame({'val': final_embeddings}, index=imgId_eId)
submission = submission.rename_axis('imgId_eId')

In [34]:
# Write the submission file; the index is written along with the 'val' column.
submission.to_csv('submission.csv')

In [35]:
# Display the final submission frame.
submission

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
f27825b2c_0,-1.244370
f27825b2c_1,1.671970
f27825b2c_2,-1.163932
f27825b2c_3,0.959149
f27825b2c_4,-0.874024
...,...
c98f79f71_379,0.406330
c98f79f71_380,7.430394
c98f79f71_381,1.012362
c98f79f71_382,-1.398369
