## Calculating prompt embeddings for the Stable Diffusion Competition

This notebook demonstrates how to use the [sentence-transformers](https://github.com/UKPLab/sentence-transformers/) library ([docs](https://www.sbert.net/index.html)) to calculate sentence embeddings.


**NOTE:** Since the re-run notebooks won't have internet access, you will need to attach the dataset [sentence-transformers-222](https://www.kaggle.com/datasets/inversion/sentence-transformers-222) to your notebooks to be able to install the Sentence Transformers library and access the `all-MiniLM-L6-v2` model used for encodings during the notebook re-run on the test data.

In [7]:
import sys
import numpy as np
import pandas as pd
from pathlib import Path

sys.path.append('../input/sentence-transformers-222/sentence-transformers')
from sentence_transformers import SentenceTransformer, models

comp_path = Path('/kaggle/input/stable-diffusion-image-to-prompts/')

## Actual prompts used for the images

**NOTE:** This file will *not* be available for the notebook re-run. References to it will create notebook failures.

In [13]:
prompts = pd.read_csv(comp_path / 'prompts.csv', index_col='imgId')
prompts.head(7)

imgId,prompt
20057f34d,hyper realistic photo of very friendly and dys...
227ef0887,"ramen carved out of fractal rose ebony, in the..."
92e911621,ultrasaurus holding a black bean taco in the w...
a4e1c55a9,a thundering retro robot crane inks on parchme...
c98f79f71,"portrait painting of a shimmering greek hero, ..."
d8edf2e40,an astronaut standing on a engaging white rose...
f27825b2c,Kaggle employee Phil at a donut shop ordering ...


## The Sample Submission contains correct embeddings (for the example images)

The `sample_submission.csv` file on the Data page has the correct embeddings for the prompts listed in the `prompts.csv` file, so you can test whether you are calculating embeddings correctly.

In [40]:
import os
if os.path.exists("../input/stable-diffusion-image-to-prompts/prompts.csv"):
    df_prompts = pd.read_csv("../input/stable-diffusion-image-to-prompts/prompts.csv")
df_prompts

,imgId,prompt
0,20057f34d,hyper realistic photo of very friendly and dys...
1,227ef0887,"ramen carved out of fractal rose ebony, in the..."
2,92e911621,ultrasaurus holding a black bean taco in the w...
3,a4e1c55a9,a thundering retro robot crane inks on parchme...
4,c98f79f71,"portrait painting of a shimmering greek hero, ..."
5,d8edf2e40,an astronaut standing on a engaging white rose...
6,f27825b2c,Kaggle employee Phil at a donut shop ordering ...


## Load the embedding model `all-MiniLM-L6-v2`

This model maps sentences to a 384 dimensional dense vector space. Load it from the [dataset](https://www.kaggle.com/datasets/inversion/sentence-transformers-222).

In [41]:
st_model = SentenceTransformer('/kaggle/input/sentence-transformers-222/all-MiniLM-L6-v2')

## Calculate prompt embeddings

In [42]:
# Encode each prompt to a 384-dim vector, then flatten all rows into one 1-D array
df_Y = st_model.encode(df_prompts['prompt'].to_list()).flatten()

In [63]:
df_Y[:7]

array([ 0.01884852,  0.03018975,  0.07279224, -0.00067344,  0.01677445,
       -0.11378008, -0.03121929], dtype=float32)
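
Because the model produces one 384-dimensional row per prompt, `.flatten()` lays those rows end to end in row-major order. The following is a minimal sketch of that layout using random stand-in data rather than the real model; the names `n_prompts`, `emb_dim`, and `fake_embeddings` are illustrative and not part of the notebook.

```python
import numpy as np

n_prompts, emb_dim = 7, 384  # 7 example prompts, 384-dim model output

# Stand-in for the encoder output: one row per prompt (random values, NOT real embeddings)
rng = np.random.default_rng(0)
fake_embeddings = rng.standard_normal((n_prompts, emb_dim)).astype(np.float32)

# Row-major flatten: prompt 0's 384 values first, then prompt 1's, and so on
flat = fake_embeddings.flatten()

# Map a position in the flat vector back to (prompt row, embedding dimension)
flat_index = 400
row, dim = divmod(flat_index, emb_dim)

# Reshaping recovers the per-prompt matrix
restored = flat.reshape(n_prompts, emb_dim)
```

With 7 prompts the flat vector holds 7 × 384 = 2688 values, so the `df_Y[:7]` shown above is only the first seven dimensions of the first prompt's embedding.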

## Compare calculated embeddings with ground truth (within tolerance)

In [44]:
sample_submission = pd.read_csv(comp_path / 'sample_submission.csv')
assert np.all(np.isclose(sample_submission['val'].values, df_Y, atol=1e-07))
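
To make the tolerance check concrete, here is a small sketch of how `np.isclose` behaves with `atol=1e-07`. The arrays are made-up values chosen for illustration, not competition data. Note that `np.isclose` also applies its default relative tolerance (`rtol=1e-05`), so the effective threshold is `atol + rtol * |b|`.

```python
import numpy as np

# Made-up reference values, purely illustrative
reference = np.array([0.5, -0.25, 0.125])

tiny_noise = reference + 1e-9   # difference well below atol
real_error = reference + 1e-3   # difference far above atol + rtol * |b|

# Element-wise test: |a - b| <= atol + rtol * |b|
matches_noise = bool(np.all(np.isclose(reference, tiny_noise, atol=1e-7)))
matches_error = bool(np.all(np.isclose(reference, real_error, atol=1e-7)))
```

Here `matches_noise` is `True` and `matches_error` is `False`: the check absorbs small floating-point differences but still catches embeddings that are genuinely different.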

## Function to load images and flatten them for the DataFrame

In [46]:
def image_id2path(
    img_id: str, 
    folder: str = "stable-diffusion-image-to-prompts"
) -> str:
    return f"../input/{folder}/images/{img_id}.png"

In [60]:
import cv2
import matplotlib.pyplot as plt

# Flattened RGB pixel arrays for the example images (a plain list, not a DataFrame)
df_X = []
for _, row in df_prompts[:7].iterrows():
    img_id = row["imgId"]
    path = image_id2path(img_id, "stable-diffusion-image-to-prompts")
    image = cv2.imread(path)  # OpenCV loads images in BGR channel order
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    df_X.append(image.flatten())

In [62]:
df_X

[array([143, 135, 132, ...,  56,  68,  78], dtype=uint8),
 array([167, 174, 181, ..., 168, 172, 177], dtype=uint8),
 array([ 93, 125,  56, ..., 157, 205,  60], dtype=uint8),
 array([71, 70, 56, ...,  1,  0,  1], dtype=uint8),
 array([167, 125,  80, ..., 110,  85,  69], dtype=uint8),
 array([121, 110,  97, ...,  30,  32,   8], dtype=uint8),
 array([110,  69,  71, ...,  85,  61,  45], dtype=uint8)]
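
A list of flattened arrays like this can be stacked into a single 2-D matrix, but only if every image has the same shape. The sketch below uses small synthetic images to show the idea; the 8×8×3 size and the names `images`, `flat_images`, and `X` are hypothetical and do not reflect the competition data.

```python
import numpy as np

# Hypothetical image size, chosen small for illustration only
height, width, channels = 8, 8, 3

# Synthetic stand-ins for decoded RGB images (random pixels, not real data)
rng = np.random.default_rng(0)
images = [
    rng.integers(0, 256, size=(height, width, channels), dtype=np.uint8)
    for _ in range(7)
]

# Flatten each image, as in the loop above
flat_images = [img.flatten() for img in images]

# Stack into one matrix: one row per image, one column per pixel value.
# np.stack raises ValueError if the arrays differ in length.
X = np.stack(flat_images)

# A row can be reshaped back to (height, width, channels) for display
first_restored = X[0].reshape(height, width, channels)
```

Keeping the original height, width, and channel count around is what makes the flattening reversible.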

### We now have `df_X` and `df_Y` for the example images provided with the competition. However, since only a few image-prompt pairs are provided, it seems we will need to generate more `df_X` images from a Stable Diffusion model using prompts.