In [1]:
import torch
import clip
from pathlib import Path
from PIL import Image
from typing import List
import json
from tqdm.auto import tqdm
import joblib
from captioned_image import CaptionedImage
from img2vec_pytorch import Img2Vec
from PIL import Image
from sentence_transformers import SentenceTransformer
from collections import defaultdict
from gensim import corpora, models, matutils
from nltk.stem import WordNetLemmatizer
from typing import Tuple
import numpy as np


# Directory that receives every joblib dump written by this notebook
# (image embeddings, text embeddings and fitted TF-IDF models).
EMBEDDINGS_BASE_PATH = Path("embeddings")
# Fraction of the distinct class names assigned to the training split; the
# remaining classes form the test split (the split is per class, not per image).
TRAIN_TEST_SPLIT = 0.7

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Make sure the output directory exists before anything is dumped into it.
EMBEDDINGS_BASE_PATH.mkdir(parents=True, exist_ok=True)

# Run the models on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
BIRDS_PATH = Path("data/birds/birds.json")
FLOWERS_PATH = Path("data/flowers/flowers.json")


def _load_captioned_images(json_path: Path) -> List[CaptionedImage]:
    """Read a JSON array from ``json_path`` and parse each entry into a CaptionedImage."""
    with open(json_path) as handle:
        records = json.load(handle)
    return [CaptionedImage.parse_obj(record) for record in records]


# Same load order as before: flowers first, then birds.
flowers = _load_captioned_images(FLOWERS_PATH)
birds = _load_captioned_images(BIRDS_PATH)

In [4]:
# Fixed seed so the class-level split is identical on every kernel restart.
SPLIT_SEED = 0


def split_classes(classes, train_fraction, seed=SPLIT_SEED):
    """Split class names into ``(train_classes, test_classes)`` reproducibly.

    Args:
        classes: iterable of class names (duplicates are ignored).
        train_fraction: fraction of the distinct classes that goes to training.
        seed: seed for the shuffle; the same seed always yields the same split.

    Returns:
        Two lists: the first ``round(n * train_fraction)`` shuffled classes
        (train) and the remaining ones (test).
    """
    # Iterating a set of strings has no stable order across interpreter runs
    # (string hashing is randomised per process), so sort first and then apply
    # a seeded shuffle to obtain an order that is both arbitrary and repeatable.
    ordered = sorted(set(classes))
    np.random.default_rng(seed).shuffle(ordered)
    cut = round(len(ordered) * train_fraction)
    return ordered[:cut], ordered[cut:]


bird_classes = sorted(set(b.class_name for b in birds))
flower_classes = sorted(set(f.class_name for f in flowers))

bird_train_classes, bird_test_classes = split_classes(bird_classes, TRAIN_TEST_SPLIT)
flower_train_classes, flower_test_classes = split_classes(
    flower_classes, TRAIN_TEST_SPLIT
)

print(
    len(bird_test_classes),
    len(bird_train_classes),
    len(flower_test_classes),
    len(flower_train_classes),
)

# Sets give constant-time membership checks while filtering the samples.
_bird_train, _bird_test = set(bird_train_classes), set(bird_test_classes)
_flower_train, _flower_test = set(flower_train_classes), set(flower_test_classes)

train_birds = [b for b in birds if b.class_name in _bird_train]
test_birds = [b for b in birds if b.class_name in _bird_test]
train_flowers = [f for f in flowers if f.class_name in _flower_train]
test_flowers = [f for f in flowers if f.class_name in _flower_test]

60 140 31 71


In [5]:
# Img2Vec takes a boolean GPU switch rather than a device string.
use_cuda = device == "cuda"

# Image / text encoders used by the embedding functions below.
clip_model, preprocess = clip.load("ViT-B/32", device=device)  # 512 components
resnet = Img2Vec(cuda=use_cuda, model="resnet-18")  # 512 components
vgg = Img2Vec(cuda=use_cuda, model="vgg")  # 4096 components
bert_model = SentenceTransformer("all-distilroberta-v1")

# Lemmatizer used when tokenising captions for TF-IDF.
lemmatizer = WordNetLemmatizer()

In [6]:
def generate_image_embeddings(data: List[CaptionedImage], name: str) -> None:
    """Compute CLIP, ResNet and VGG embeddings for every image and dump them.

    Args:
        data: samples to embed; each sample's ``image_path`` is opened from disk.
        name: suffix of the output file, written to
            ``EMBEDDINGS_BASE_PATH / f"image_embeddings_{name}.p"``.

    The dumped object is a dict keyed by the image ``Path``; each value holds
    the sample's ``class_name`` plus one vector per model under the keys
    ``"clip"``, ``"resnet"`` and ``"vgg"``.
    """
    image_embeddings = {}

    for d in tqdm(data):
        p = Path(d.image_path)
        image = Image.open(p).convert(
            "RGB"
        )  # there are some greyscale images in the dataset

        # Inference only: skip autograd bookkeeping for the CLIP forward pass.
        with torch.no_grad():
            clip_image = preprocess(image).unsqueeze(0).to(device)
            clip_features = (
                clip_model.encode_image(clip_image).cpu().detach()[0].numpy()
            )

        resnet_features = resnet.get_vec(image)

        # Use the VGG extractor here; calling ``resnet`` again would store a
        # copy of the ResNet vector under the "vgg" key.
        vgg_features = vgg.get_vec(image)

        image_embeddings[p] = {
            "class_name": d.class_name,
            "clip": clip_features,
            "resnet": resnet_features,
            "vgg": vgg_features,
        }

    joblib.dump(image_embeddings, EMBEDDINGS_BASE_PATH / f"image_embeddings_{name}.p")

In [7]:
def preprocess_text(text: str) -> List[str]:
    """Split ``text`` on whitespace and return the lower-cased, lemmatized tokens."""
    tokens = []
    for raw_word in text.split():
        normalized = raw_word.lower().strip()
        tokens.append(lemmatizer.lemmatize(normalized))
    return tokens


def train_tfidf(
    data: List[CaptionedImage],
) -> Tuple[corpora.Dictionary, models.TfidfModel]:
    """Fit a gensim dictionary and TF-IDF model on every caption in ``data``.

    Each caption is treated as one document. Returns the pair
    ``(dictionary, tfidf)`` that ``get_tfidf_vector`` expects.
    """
    tokenized_captions = []
    for item in data:
        for caption in item.captions:
            tokenized_captions.append(preprocess_text(caption))

    vocabulary = corpora.Dictionary(tokenized_captions)
    bow_corpus = [vocabulary.doc2bow(tokens) for tokens in tokenized_captions]
    weighting = models.TfidfModel(bow_corpus)
    return vocabulary, weighting


def get_tfidf_vector(
    model: Tuple[corpora.Dictionary, models.TfidfModel], text: str
) -> np.ndarray:
    """Return the dense TF-IDF vector of ``text`` under a fitted ``(dictionary, tfidf)`` pair.

    The vector has one entry per dictionary term (``len(dictionary)``).
    """
    vocabulary = model[0]
    weighting = model[1]
    bag_of_words = vocabulary.doc2bow(preprocess_text(text))
    weighted_bag = weighting[bag_of_words]
    vector_size = len(vocabulary)
    return matutils.sparse2full(weighted_bag, vector_size)

In [8]:
def generate_text_embeddings(data: List[CaptionedImage], name: str, tfidf_model=None):
    """Compute CLIP, sentence-transformer and TF-IDF embeddings for every caption.

    Args:
        data: samples whose ``captions`` are embedded.
        name: suffix of the output files.
        tfidf_model: optional ``(dictionary, tfidf)`` pair as returned by
            ``train_tfidf``. When ``None`` a new model is fitted on ``data`` and
            dumped to ``EMBEDDINGS_BASE_PATH / f"tfidf_{name}.p"``.

    Returns:
        The TF-IDF model that was used (the one passed in, or the newly fitted
        one), so it can be reused for another split.

    The embeddings are dumped to
    ``EMBEDDINGS_BASE_PATH / f"text_embeddings_{name}.p"`` as a mapping
    ``image Path -> caption text -> {"class_name", "clip", "bert", "tfidf"}``.
    """
    if tfidf_model is None:
        tfidf_model = train_tfidf(data)
        joblib.dump(tfidf_model, EMBEDDINGS_BASE_PATH / f"tfidf_{name}.p")

    text_embeddings = defaultdict(dict)

    for d in tqdm(data):
        p = Path(d.image_path)
        for text in d.captions:
            # Inference only: skip autograd bookkeeping for the CLIP forward pass.
            with torch.no_grad():
                clip_text = clip.tokenize([text], truncate=True).to(device)
                clip_features = (
                    clip_model.encode_text(clip_text).cpu().detach()[0].numpy()
                )

            bert_features = bert_model.encode(text)

            tfidf_features = get_tfidf_vector(tfidf_model, text)

            text_embeddings[p][text] = {
                "class_name": d.class_name,
                "clip": clip_features,
                "bert": bert_features,
                "tfidf": tfidf_features,
            }

    joblib.dump(text_embeddings, EMBEDDINGS_BASE_PATH / f"text_embeddings_{name}.p")

    return tfidf_model

In [9]:
def generate_embeddings(
    data: List[CaptionedImage], name: str, tfidf=None
) -> Tuple[corpora.Dictionary, models.TfidfModel]:
    """Generate and dump both the image and the text embeddings for ``data``.

    Args:
        data: samples to embed.
        name: suffix used for all output files of this split.
        tfidf: optional fitted ``(dictionary, tfidf)`` pair to reuse; when
            ``None`` a new one is fitted on ``data``.

    Returns:
        The TF-IDF model used for the text embeddings. Returning it is what
        lets a caller pass the training-split model on to the test split;
        without it the caller receives ``None`` and the next call fits a
        separate model.
    """
    generate_image_embeddings(data, name)
    return generate_text_embeddings(data, name, tfidf)

In [10]:
# The TF-IDF model must be fitted on the training captions only and then reused
# for the test split, so that train and test TF-IDF vectors share one
# vocabulary and dimension. The fitted model is taken from
# generate_text_embeddings, which returns it.
generate_image_embeddings(train_birds, "train_birds")
birds_tfidf = generate_text_embeddings(train_birds, "train_birds")
generate_embeddings(test_birds, "test_birds", birds_tfidf)

generate_image_embeddings(train_flowers, "train_flowers")
flowers_tfidf = generate_text_embeddings(train_flowers, "train_flowers")
generate_embeddings(test_flowers, "test_flowers", flowers_tfidf)

  0%|          | 0/8272 [00:00<?, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


  0%|          | 0/8272 [00:00<?, ?it/s]

  0%|          | 0/3516 [00:00<?, ?it/s]

  0%|          | 0/3516 [00:00<?, ?it/s]

  0%|          | 0/5514 [00:00<?, ?it/s]

  0%|          | 0/5514 [00:00<?, ?it/s]

  0%|          | 0/2675 [00:00<?, ?it/s]

  0%|          | 0/2675 [00:00<?, ?it/s]