## Install useful packages

In [1]:
!pip install transformers accelerate tokenizers nltk



## Eval method of competition:

Cosine similiarity between predictions and actual prompt embeddings

## Example using SBERT to get embeddings

In [2]:
import sys
import numpy as np
import pandas as pd
from pathlib import Path
import os

sys.path.append('../data/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models, util

# Not available during re-run
comp_path = Path('../data/stable-diffusion-image-to-prompts/')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
prompts = pd.read_csv(comp_path / 'prompts.csv', index_col='imgId')
prompts.head(7)

Unnamed: 0_level_0,prompt
imgId,Unnamed: 1_level_1
20057f34d,hyper realistic photo of very friendly and dys...
227ef0887,"ramen carved out of fractal rose ebony, in the..."
92e911621,ultrasaurus holding a black bean taco in the w...
a4e1c55a9,a thundering retro robot crane inks on parchme...
c98f79f71,"portrait painting of a shimmering greek hero, ..."
d8edf2e40,an astronaut standing on a engaging white rose...
f27825b2c,Kaggle employee Phil at a donut shop ordering ...


In [4]:
st_model = SentenceTransformer('../data/sentence-transformers-222/all-MiniLM-L6-v2')
prompt_embeddings = st_model.encode(prompts['prompt'], device="cuda:1", convert_to_tensor=True).flatten()
print(prompt_embeddings.shape)

torch.Size([2688])


#### Hidden dimension is 384, thus for each image id, there will be a vector of size 384

In [5]:
!pip install clip-interrogator==0.5.4

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


### CLIP Interrogator helps with generating similiar images from given image
Basically it consists of two models: BLIP and CLIP. Firstly, image is encoded through BLIP and outputs text description. At the same time, image is ecnoded via CLIP and then find the most similiar text from predefined text label files. Then texts from BLIP and output of CLIP branch are concatenated.

In [6]:
df_submission = pd.read_csv('../data/stable-diffusion-image-to-prompts/sample_submission.csv', index_col='imgId_eId')
df_submission.head()

Unnamed: 0_level_0,val
imgId_eId,Unnamed: 1_level_1
20057f34d_0,0.018848
20057f34d_1,0.03019
20057f34d_2,0.072792
20057f34d_3,-0.000673
20057f34d_4,0.016774


In [7]:
images = os.listdir('../data/stable-diffusion-image-to-prompts/images')
imgIds = [i.split('.')[0] for i in images]

eIds = list(range(384))

imgId_eId = [
    '_'.join(map(str, i)) for i in zip(
        np.repeat(imgIds, 384),
        np.tile(range(384), len(imgIds))
    )
]

In [8]:
imgId_eId

['227ef0887_0',
 '227ef0887_1',
 '227ef0887_2',
 '227ef0887_3',
 '227ef0887_4',
 '227ef0887_5',
 '227ef0887_6',
 '227ef0887_7',
 '227ef0887_8',
 '227ef0887_9',
 '227ef0887_10',
 '227ef0887_11',
 '227ef0887_12',
 '227ef0887_13',
 '227ef0887_14',
 '227ef0887_15',
 '227ef0887_16',
 '227ef0887_17',
 '227ef0887_18',
 '227ef0887_19',
 '227ef0887_20',
 '227ef0887_21',
 '227ef0887_22',
 '227ef0887_23',
 '227ef0887_24',
 '227ef0887_25',
 '227ef0887_26',
 '227ef0887_27',
 '227ef0887_28',
 '227ef0887_29',
 '227ef0887_30',
 '227ef0887_31',
 '227ef0887_32',
 '227ef0887_33',
 '227ef0887_34',
 '227ef0887_35',
 '227ef0887_36',
 '227ef0887_37',
 '227ef0887_38',
 '227ef0887_39',
 '227ef0887_40',
 '227ef0887_41',
 '227ef0887_42',
 '227ef0887_43',
 '227ef0887_44',
 '227ef0887_45',
 '227ef0887_46',
 '227ef0887_47',
 '227ef0887_48',
 '227ef0887_49',
 '227ef0887_50',
 '227ef0887_51',
 '227ef0887_52',
 '227ef0887_53',
 '227ef0887_54',
 '227ef0887_55',
 '227ef0887_56',
 '227ef0887_57',
 '227ef0887_58',
 '227ef

In [9]:
from blip.models import blip
from clip_interrogator import clip_interrogator

In [28]:
# fix clip_interrogator bug
import inspect
import importlib

clip_interrogator_path = inspect.getfile(clip_interrogator.Interrogator)

fin = open(clip_interrogator_path, "rt")
data = fin.read()
data = data.replace(
    'open_clip.get_tokenizer(clip_model_name)', 
    'open_clip.get_tokenizer(config.clip_model_name.split("/", 2)[0])'
)
fin.close()

fin = open(clip_interrogator_path, "wt")
fin.write(data)
fin.close()

# reload module
importlib.reload(clip_interrogator)

<module 'clip_interrogator.clip_interrogator' from '/home/bobrin_m_s/anaconda3/envs/DL/lib/python3.10/site-packages/clip_interrogator/clip_interrogator.py'>

In [29]:
model_config = clip_interrogator.Config(clip_model_name="ViT-H-14/laion2b_s32b_b79k",
                                        cache_path="../data/cache",
                                        download_cache=True, device="cuda:1")

In [30]:
blip_model = blip.blip_decoder(
    image_size=model_config.blip_image_eval_size, 
    vit=model_config.blip_model_type,
    med_config="../data/BLIP_configs/med_config.yaml"
)

blip_model.eval()
blip_model = blip_model.to(model_config.device)
model_config.blip_model = blip_model

In [31]:
import open_clip

clip_model = open_clip.create_model("ViT-H-14",
                                    precision='fp16',
                                   pretrained="laion2B-s32B-b79K")

#open_clip.load_checkpoint(clip_model, CFG.clip_model_path)
clip_model.to(model_config.device).eval()
model_config.clip_model = clip_model

In [32]:
clip_preprocess = open_clip.image_transform(
    clip_model.visual.image_size,
    is_train = False,
    mean = getattr(clip_model.visual, 'image_mean', None),
    std = getattr(clip_model.visual, 'image_std', None),
)
model_config.clip_preprocess = clip_preprocess

In [33]:
ci = clip_interrogator.Interrogator(model_config)

ViT-H-14_laion2b_s32b_b79k_artists.safetensors: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 21.6M/21.6M [00:00<00:00, 72.0MB/s]
ViT-H-14_laion2b_s32b_b79k_flavors.safetensors: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 207M/207M [00:02<00:00, 89.0MB/s]
ViT-H-14_laion2b_s32b_b79k_mediums.safetensors: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 195k/195k [00:00<00:00, 4.26MB/s]
ViT-H-14_laion2b_s32b_b79k_movements.safetensors: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 410k/410k [00:00<00:00, 6.01MB/s]
ViT-H-14_laion2b_s32b_b79k_trendings.safetensors: 100%|█████████████████████████████████████████████████████████████████████████████████████████████| 148k/148k [00:00<00:00, 3.71MB/s]
ViT-H-14_laion2b_s32b_b79k_negative.safetensors: 100%|██████████████████████████

Loaded CLIP model and data in 6.98 seconds.





In [34]:
import torch

cos = torch.nn.CosineSimilarity(dim=1)

mediums_features_array = torch.stack([torch.from_numpy(t) for t in ci.mediums.embeds]).to(ci.device)
movements_features_array = torch.stack([torch.from_numpy(t) for t in ci.movements.embeds]).to(ci.device)
flavors_features_array = torch.stack([torch.from_numpy(t) for t in ci.flavors.embeds]).to(ci.device)

In [41]:
def interrogate(image: Image) -> str:
    caption = ci.generate_caption(image)
    image_features = ci.image_to_features(image)
    
    medium = [ci.mediums.labels[i] for i in cos(image_features, mediums_features_array).topk(1).indices][0]
    movement = [ci.movements.labels[i] for i in cos(image_features, movements_features_array).topk(1).indices][0]
    flaves = ", ".join([ci.flavors.labels[i] for i in cos(image_features, flavors_features_array).topk(3).indices])

    if caption.startswith(medium):
        prompt = f"{caption}, {movement}, {flaves}"
    else:
        prompt = f"{caption}, {medium}, {movement}, {flaves}"

    return clip_interrogator._truncate_to_fit(prompt, ci.tokenize)

In [42]:
prompts = []
from PIL import Image

images_path = "../data/stable-diffusion-image-to-prompts/images/"
for image_name in images:
    img = Image.open(images_path + image_name).convert("RGB")

    generated = interrogate(img)
    
    prompts.append(generated)