In [1]:
# Use a pipeline as a high-level helper
from transformers import pipeline
import os
import json
import pandas as pd

In [3]:
# Choose the accelerator at runtime instead of hard-coding Apple's "mps"
# backend, so the notebook also runs on CUDA or CPU-only machines.
import torch

if torch.backends.mps.is_available():
    pipeline_device = "mps"
elif torch.cuda.is_available():
    pipeline_device = "cuda"
else:
    pipeline_device = "cpu"

# Image-captioning pipeline backed by the BLIP base checkpoint.
image_to_text = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-base",
    device=pipeline_device,
)

In [2]:
# NOTE: despite its name, `current_dir` is the *parent* of the working
# directory (dirname of the resolved "./"); every data path hangs off it.
current_dir = os.path.dirname(os.path.realpath("./"))

# SentiCap text data locations.
senticap_data_dir = os.path.join(current_dir, "txt_data", "data")
senticap_data_json_path = os.path.join(senticap_data_dir, "senticap_dataset.json")
senticap_data_csv_path = os.path.join(senticap_data_dir, "senticap_dataset.csv")

# COCO val2014 image and annotation locations.
coco_img_data_dir = os.path.join(current_dir, "img_data", "coco_val2014", "val2014")
coco_ann_data_dir = os.path.join(current_dir, "img_data", "coco_ann2014", "annotations")
coco_cap_data_path = os.path.join(coco_ann_data_dir, "captions_val2014.json")

# Read the COCO captions JSON and pull out its two top-level record lists.
with open(coco_cap_data_path) as caption_file:
    coco_cap_data = json.load(caption_file)

coco_cap_data_ann = coco_cap_data["annotations"]
coco_cap_data_img = coco_cap_data["images"]

# One DataFrame per record list; the image frame's "id" column is exposed as
# "image_id".
coco_cap_ann_df = pd.DataFrame(coco_cap_data_ann)
coco_cap_img_df = pd.DataFrame(coco_cap_data_img).rename(columns={"id": "image_id"})

In [8]:
import glob

# Build the search pattern from `current_dir` (the parent of the working
# directory, computed in the path-setup cell) instead of a machine-specific
# absolute path. This assumes the notebook runs from a direct subdirectory of
# the repository root, as the "../img_data/..." paths used elsewhere in this
# notebook also do.
#
# glob is called without recursive=True, so "**" behaves exactly like "*":
# the pattern matches .jpg files located one directory below coco_val2014.
img_file_paths = sorted(
    glob.glob(os.path.join(current_dir, "img_data", "coco_val2014", "*", "*.jpg"))
)

In [4]:
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader


# https://discuss.huggingface.co/t/progress-bar-for-hf-pipelines/20498
class ListDataset(Dataset):
    """Expose an in-memory sequence through the torch ``Dataset`` interface.

    Items are returned exactly as stored in ``original_list``; no copying or
    transformation is applied.
    """

    def __init__(self, original_list):
        """Keep a reference to ``original_list`` (it is not copied)."""
        self.original_list = original_list

    def __getitem__(self, i):
        """Return the item stored at position ``i``."""
        items = self.original_list
        return items[i]

    def __len__(self):
        """Return how many items the wrapped sequence holds."""
        items = self.original_list
        return len(items)


# Wrap the image paths and serve them in their existing order, 16 per batch.
dataset = ListDataset(img_file_paths)
dataloader = DataLoader(dataset=dataset, shuffle=False, batch_size=16)

In [None]:
# Run the captioning pipeline batch by batch, collecting every result in a
# single flat list that follows the dataloader's order.
image_captions = []
for path_batch in tqdm(dataloader):
    batch_captions = image_to_text(path_batch)
    image_captions += batch_captions

In [66]:
# Load the generated captions from CSV; the path is relative to the
# notebook's working directory.
generated_captions = pd.read_csv(
    filepath_or_buffer="../img_data/coco_ann2014/annotations/generated_captions.csv"
)

In [72]:
import ast

# File name without its directory and without anything after the first ".".
generated_captions["file_name"] = generated_captions["file_path"].apply(
    lambda x: x.split("/")[-1].split(".")[0]
)

# Each "caption" cell is a string encoding a sequence whose first element has
# a "generated_text" key (presumably the repr of the pipeline output — verify
# against the code that wrote the CSV). Parse it with ast.literal_eval, which
# only accepts Python literals, rather than eval, which would execute any
# code found in the CSV.
generated_captions["generated_caption"] = generated_captions["caption"].apply(
    lambda x: str(ast.literal_eval(x)[0]["generated_text"])
)

# The raw column is no longer needed once the text has been extracted.
generated_captions = generated_captions.drop(columns=["caption"])

In [81]:
# The numeric image id is whatever follows the last "_" in the file name.
generated_captions["image_id"] = [
    int(name.rsplit("_", 1)[-1]) for name in generated_captions["file_name"]
]

In [82]:
# Inner-join generated captions with the COCO annotations on image_id and
# keep only the columns needed for the comparison. An image with several
# COCO annotations yields one row per annotation.
captions = generated_captions.merge(coco_cap_ann_df, how="inner", on="image_id")[
    ["image_id", "generated_caption", "caption"]
]

In [87]:
# COCO captions of every row whose generated caption is exactly this string.
captions.loc[captions["generated_caption"].eq("a train on the tracks"), "caption"]

350                 a green train is coming down the tracks
351                 The train is on the tracks by the road.
352              A train riding on a track near a platform.
353       A commuter train pulling out of a suburban sta...
354             a green and yellow train on the train track
                                ...                        
201717    A monorail is going down the track with people...
201718    A passenger train drives passed a station on a...
201719            a train is on the tracks near some trees.
201720     A three car passenger train on the train tracks.
201721    The transit train stretches down the track und...
Name: caption, Length: 3277, dtype: object

In [13]:
# Sentence-embedding model used to compare the two sets of captions.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("paraphrase-MiniLM-L6-v2")

# Flatten both caption columns into plain lists of strings, row-aligned with
# `captions`. Embeddings are NOT computed here.
# NOTE: this rebinds `generated_captions` from the DataFrame to a list.
coco_captions = captions["caption"].astype(str).to_list()
generated_captions = captions["generated_caption"].astype(str).to_list()

In [41]:
import numpy as np
from tqdm.auto import tqdm

# The sklearn `cosine_similarity` import that used to live here was removed:
# the `cosine_similarity` function defined in this cell shadowed it at once,
# so it was never called.

# Number of caption pairs processed per iteration of the scoring loop.
batch_size = 1000
num_batches = int(np.ceil(len(captions) / batch_size))

# Filled by the scoring loop with one score per row of `captions`.
similarity_scores = []


def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    Computed as ``dot(a, b) / (||a|| * ||b||)``. There is no guard against a
    zero-norm input, so such vectors divide by zero.
    """
    dot_product = np.dot(a, b)
    norm_product = np.linalg.norm(a) * np.linalg.norm(b)
    return dot_product / norm_product


# Encode both caption lists one batch at a time and score each generated
# caption against the COCO caption at the same position.
for batch_idx in tqdm(range(num_batches)):
    batch_slice = slice(
        batch_idx * batch_size,
        min((batch_idx + 1) * batch_size, len(captions)),
    )

    generated_embeddings = model.encode(generated_captions[batch_slice])
    normal_embeddings = model.encode(coco_captions[batch_slice])

    # Pair the embeddings up row by row; both batches have the same length.
    similarity_scores.extend(
        cosine_similarity(generated_vec, normal_vec)
        for generated_vec, normal_vec in zip(generated_embeddings, normal_embeddings)
    )

# Convert the accumulated Python list into a NumPy array.
similarity_scores = np.array(similarity_scores)

  0%|          | 0/203 [00:00<?, ?it/s]

In [46]:
# Attach the scores as a new column. This relies on `similarity_scores`
# holding one entry per row of `captions`, in the same row order.
captions["similarity_score"] = similarity_scores

In [56]:
# Show the caption pairs ordered from most to least similar.
captions.sort_values("similarity_score", ascending=False)

Unnamed: 0,image_id,generated_caption,caption,similarity_score
104960,300962,a baseball player swinging a bat at a ball,a baseball player swinging a bat at a ball,1.000000
189261,543692,a little girl holding a teddy bear in a field,A little girl holding a teddy bear in a field,1.000000
188595,541887,a stop sign with graffiti written on it,a stop sign with graffiti written on it,1.000000
148431,426053,a bathroom with a toilet and a sink,a bathroom with a toilet and a sink,1.000000
111278,318573,a baseball player holding a bat on a field,a baseball player holding a bat on a field,1.000000
...,...,...,...,...
136332,390826,a man in a white shirt,The panda bear likes to walk in the grass.,-0.117677
185781,533532,a group of people sitting on a wooden bench,A woman standing around with several pieces of...,-0.119673
180417,518326,a woman in a hospital bed with a child,A man is working on some parade floats.,-0.122634
62666,180357,a red and white wall,Three woman dancing with open umbrellas in the...,-0.128763


In [57]:
# Histogram of the cosine similarity between each generated caption and its
# COCO reference caption, drawn with Plotly Express.
import plotly.express as px

histogram_layout = {
    "title": "Histogram of cosine similarity scores between generated captions and COCO captions",
    "xaxis_title": "Similarity score",
    "yaxis_title": "Frequency",
    "width": 800,
    "height": 600,
}

fig = px.histogram(data_frame=captions, x="similarity_score")
fig.update_layout(**histogram_layout)
fig.show()

# Also write the figure to a standalone HTML file for later use.
fig.write_html("../img_data/coco_ann2014/annotations/similarity_scores.html")