# Data Embedding

In [None]:
import os
import json
import torch
import pickle
import pandas as pd

from PIL import Image
from tqdm import tqdm
from torchvision import models
from torchvision import datasets
from torchvision import transforms
from transformers import AutoModel
from transformers import AutoTokenizer

In [None]:
# Image preprocessing applied before the image encoder: resize to a fixed
# 224x224, convert to a tensor, then normalize each channel with the mean/std
# conventionally used for ImageNet-pretrained torchvision models.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Image encoder: pretrained ResNet-50 whose final fully-connected layer is
# swapped for an identity, so the model outputs the features that would have
# fed the classifier instead of class logits.
image_encoder = models.resnet50(pretrained=True)
image_encoder.fc = torch.nn.Identity()
# Module.to() moves parameters in place and eval() returns the module itself,
# so the two calls can be issued as separate statements.
image_encoder.to("cuda")
image_encoder.eval()

In [None]:
# Text side: the tokenizer and encoder are loaded from the same checkpoint so
# token ids and embedding table agree. The encoder lives on the GPU in eval mode.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text_encoder = AutoModel.from_pretrained("bert-base-uncased")
text_encoder = text_encoder.to("cuda").eval()

In [None]:
def image_embedder(image_folder, output_file):
    """Embed every image file found directly in ``image_folder`` and pickle the result.

    Each file is opened with PIL, converted to RGB, preprocessed with the
    module-level ``transform`` and encoded with the module-level
    ``image_encoder`` on the GPU. Embeddings are moved back to the CPU before
    being stored.

    Args:
        image_folder: Directory whose direct entries are the image files.
            Sub-directories are skipped.
        output_file: Path of the pickle file that receives the result.

    Returns:
        dict mapping each file name (not the full path) to its embedding tensor.
    """
    image_embeddings = {}

    # os.listdir returns entries in arbitrary order; sorting makes the dict
    # (and therefore the pickle) reproducible across runs and machines.
    for image_name in tqdm(sorted(os.listdir(image_folder))):
        image_path = os.path.join(image_folder, image_name)
        if not os.path.isfile(image_path):
            # A nested directory would make Image.open raise and abort the
            # whole run before anything has been written to disk.
            continue

        # The context manager closes the underlying file handle as soon as the
        # pixels have been converted, instead of leaving that to the GC.
        with Image.open(image_path) as image_file:
            image_tensor = transform(image_file.convert("RGB")).unsqueeze(0).to("cuda")

        with torch.no_grad():
            # Drop the batch dimension and move back to CPU for saving.
            image_embedding = image_encoder(image_tensor).squeeze(0).cpu()
        image_embeddings[image_name] = image_embedding

    with open(output_file, "wb") as f:
        pickle.dump(image_embeddings, f)

    return image_embeddings

In [None]:
def cifar_image_embedder(dataset, output_file):
    """Encode every sample of ``dataset`` with ``image_encoder`` and pickle the result.

    Args:
        dataset: Iterable yielding ``(image_tensor, label)`` pairs; the tensor
            is sent to the encoder as a batch of one.
        output_file: Path of the pickle file that receives the result.

    Returns:
        dict mapping the sample's position in ``dataset`` to
        ``{"label": label, "emb": embedding}`` with the embedding on the CPU.
    """
    records = {}

    with torch.no_grad():
        for sample_idx, (pixels, target) in enumerate(tqdm(dataset)):
            batch = pixels.unsqueeze(0).to("cuda")
            # Strip the batch dimension and bring the features back to the CPU
            # so the pickle does not hold CUDA tensors.
            features = image_encoder(batch).squeeze(0).cpu()
            records[sample_idx] = {"label": target, "emb": features}

    with open(output_file, "wb") as sink:
        pickle.dump(records, sink)

    return records

In [None]:
def coco_text_embedder(caption_file, output_file):
    """Embed every caption of a COCO captions JSON file and pickle the result.

    The file is expected to hold an ``"images"`` list (entries with ``"id"``
    and ``"file_name"``) and an ``"annotations"`` list (entries with ``"id"``,
    ``"image_id"`` and ``"caption"``). Every annotation is encoded with the
    module-level ``tokenizer`` / ``text_encoder``; the embedding is the hidden
    state of the first token (the [CLS] position for the BERT tokenizer loaded
    in this file).

    Args:
        caption_file: Path to the captions JSON file.
        output_file: Path of the pickle file that receives the result.

    Returns:
        dict mapping annotation id to
        ``{"img_fn": file name, "emb": CPU tensor, "caption": text}``.

    Raises:
        KeyError: if an annotation refers to an image id that is absent from
            the ``"images"`` list.
    """
    with open(caption_file, "r", encoding="utf-8") as file:
        captions = json.load(file)

    annotations = captions["annotations"]

    # Longest tokenized caption plus one, so truncation=True below can never
    # actually cut a caption short. default=0 keeps an empty file from raising.
    max_length = max(
        (len(tokenizer(annotation["caption"])["input_ids"]) for annotation in annotations),
        default=0,
    ) + 1
    print("max_length", max_length)

    img_filenames = {image["id"]: image["file_name"] for image in captions["images"]}

    text_embeddings = {}
    # Iterate over the annotations themselves. Sizing the loop by the number
    # of images while indexing into the annotations list would silently drop
    # every caption past the first len(images) entries, since an image can
    # have several captions.
    for annotation in tqdm(annotations):
        caption = annotation["caption"]
        tokens = tokenizer(caption,
                           return_tensors="pt",
                           padding=True,
                           truncation=True,
                           max_length=max_length).to("cuda")
        with torch.no_grad():
            # First-token hidden state, batch dimension removed, moved to CPU.
            text_embedding = text_encoder(**tokens).last_hidden_state[:, 0, :].squeeze(0).cpu()
        text_embeddings[annotation["id"]] = {
            "img_fn": img_filenames[annotation["image_id"]],
            "emb": text_embedding,
            "caption": caption,
        }

    with open(output_file, "wb") as f:
        pickle.dump(text_embeddings, f)

    return text_embeddings

In [None]:
def flickr30k_text_embedder(caption_file, output_file):
    """Embed every caption row of a Flickr30k caption table and pickle the result.

    The table is read with ``pandas.read_csv`` using its default settings and
    must provide the columns ``comment``, ``comment_number`` and
    ``image_name``. Each comment is encoded with the module-level
    ``tokenizer`` / ``text_encoder``; the stored embedding is the hidden state
    of the first token.

    Args:
        caption_file: Path to the caption table.
        output_file: Path of the pickle file that receives the result.

    Returns:
        dict mapping the row position to ``{"comment_number", "emb",
        "comment", "image_name"}`` with the embedding on the CPU.
    """
    caption_table = pd.read_csv(caption_file)

    # Token count per comment; the largest one (+1) bounds the tokenizer call
    # below so that truncation never removes tokens.
    caption_table["len"] = caption_table["comment"].apply(
        lambda text: len(tokenizer(text)["input_ids"])
    )
    max_length = int(caption_table["len"].max()) + 1
    print("max_length", max_length)

    text_embeddings = {}
    for row_idx, comment in enumerate(tqdm(caption_table["comment"].values)):
        encoded = tokenizer(
            comment,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        ).to("cuda")

        with torch.no_grad():
            first_token_state = text_encoder(**encoded).last_hidden_state[:, 0, :]
            # Remove the batch dimension and move back to CPU for pickling.
            embedding = first_token_state.squeeze(0).cpu()

        text_embeddings[row_idx] = {
            "comment_number": caption_table["comment_number"].values[row_idx],
            "emb": embedding,
            "comment": comment,
            "image_name": caption_table["image_name"].values[row_idx],
        }

    with open(output_file, "wb") as sink:
        pickle.dump(text_embeddings, sink)

    return text_embeddings

In [None]:
def cifar_text_embedder(dataset, output_file):
    """Embed the class names of ``dataset`` and pickle the result.

    Only ``dataset.classes`` is read. Each class name is encoded with the
    module-level ``tokenizer`` / ``text_encoder`` and represented by the
    hidden state of the first token.

    Args:
        dataset: Object exposing a ``classes`` sequence of class-name strings.
        output_file: Path of the pickle file that receives the result.

    Returns:
        dict mapping class index to ``{"name", "emb", "idx"}`` with the
        embedding on the CPU.
    """
    class_names = dataset.classes

    # Longest tokenized class name; used as the truncation bound below.
    max_length = max(len(tokenizer(class_name)["input_ids"]) for class_name in class_names)
    print("max_length", max_length)

    text_embeddings = {}
    for class_idx, class_name in enumerate(class_names):
        encoded = tokenizer(
            class_name,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )
        encoded = encoded.to("cuda")

        with torch.no_grad():
            hidden = text_encoder(**encoded).last_hidden_state
            # First-token state, batch dimension removed, moved back to CPU.
            class_embedding = hidden[:, 0, :].squeeze(0).cpu()

        text_embeddings[class_idx] = {
            "name": class_name,
            "emb": class_embedding,
            "idx": class_idx,
        }

    with open(output_file, "wb") as sink:
        pickle.dump(text_embeddings, sink)

    return text_embeddings

## Flickr30k

### Images

In [None]:
# Root folder of the Flickr30k data; embeddings are written to its "embs" sub-folder.
flickr30k_target_dir = "data/flickr30k"
# exist_ok=True makes re-running this cell harmless.
os.makedirs(f"{flickr30k_target_dir}/embs", exist_ok=True)
# Bare expression so the notebook echoes the path.
flickr30k_target_dir

In [None]:
# Embed the Flickr30k images and store the pickle in the "embs" sub-folder.
image_embeddings_flickr30k = image_embedder(
    f"{flickr30k_target_dir}/flickr30k_images",
    f"{flickr30k_target_dir}/embs/flickr30k_images.pkl",
)
len(image_embeddings_flickr30k)  # number of embedded images

### Captions

In [None]:
# Embed the Flickr30k captions and store the pickle in the "embs" sub-folder.
flickr30k_text_embeddings = flickr30k_text_embedder(
    f"{flickr30k_target_dir}/captions.txt",
    f"{flickr30k_target_dir}/embs/flickr30k_captions.pkl",
)
len(flickr30k_text_embeddings)  # number of embedded captions

## CIFAR-10, CIFAR-100

In [None]:
# Root folder shared by CIFAR-10 and CIFAR-100; embeddings go to its "embs" sub-folder.
cifar_target_dir = "data/cifar"
# exist_ok=True makes re-running this cell harmless.
os.makedirs(f"{cifar_target_dir}/embs", exist_ok=True)
# Bare expression so the notebook echoes the path.
cifar_target_dir

### Image

In [None]:
# CIFAR-10 with train=False (the non-training split), downloaded on demand and
# preprocessed with the shared image transform.
cifar10 = datasets.CIFAR10(
    root=cifar_target_dir,
    train=False,
    download=True,
    transform=transform,
)

In [None]:
# Embed every CIFAR-10 sample and store the pickle in the "embs" sub-folder.
image_embeddings_cifar10 = cifar_image_embedder(
    cifar10,
    f"{cifar_target_dir}/embs/cifar10_images.pkl",
)
len(image_embeddings_cifar10)  # number of embedded samples

In [None]:
# CIFAR-100 with train=False (the non-training split), downloaded on demand and
# preprocessed with the shared image transform.
cifar100 = datasets.CIFAR100(
    root=cifar_target_dir,
    train=False,
    download=True,
    transform=transform,
)

In [None]:
# Embed every CIFAR-100 sample and store the pickle in the "embs" sub-folder.
image_embeddings_cifar100 = cifar_image_embedder(
    cifar100,
    f"{cifar_target_dir}/embs/cifar100_images.pkl",
)
len(image_embeddings_cifar100)  # number of embedded samples

### Caption

In [None]:
# Embed the CIFAR-10 class names. cifar_text_embedder reads only
# `dataset.classes`, so no image data is touched here.
# NOTE(review): the `_train` suffix is misleading; `cifar10` is built with
# train=False in this notebook.
cifar10_text_embeddings_train = cifar_text_embedder(
    cifar10,
    f"{cifar_target_dir}/embs/cifar10_captions.pkl",
)
len(cifar10_text_embeddings_train)  # number of embedded class names

In [None]:
# Embed the CIFAR-100 class names. cifar_text_embedder reads only
# `dataset.classes`, so no image data is touched here.
# NOTE(review): the `_train` suffix is misleading; `cifar100` is built with
# train=False in this notebook.
cifar100_text_embeddings_train = cifar_text_embedder(
    cifar100,
    f"{cifar_target_dir}/embs/cifar100_captions.pkl",
)
len(cifar100_text_embeddings_train)  # number of embedded class names

## COCO

In [None]:
# Root folder of the COCO data; embeddings are written to its "embs" sub-folder.
coco_target_dir = "data/coco"
# exist_ok=True makes re-running this cell harmless.
os.makedirs(f"{coco_target_dir}/embs", exist_ok=True)
# Bare expression so the notebook echoes the path.
coco_target_dir

### Images

In [None]:
# Embed the COCO val2017 images and store the pickle in the "embs" sub-folder.
coco_image_embeddings_val = image_embedder(
    f"{coco_target_dir}/val2017/val2017",
    f"{coco_target_dir}/embs/val2017_images.pkl",
)
len(coco_image_embeddings_val)  # number of embedded images

In [None]:
# Embed the COCO train2017 images and store the pickle in the "embs" sub-folder.
coco_image_embeddings_train = image_embedder(
    f"{coco_target_dir}/train2017/train2017",
    f"{coco_target_dir}/embs/train2017_images.pkl",
)
len(coco_image_embeddings_train)  # number of embedded images

### Captions

In [None]:
# Embed the COCO val2017 captions and store the pickle in the "embs" sub-folder.
coco_text_embeddings_val = coco_text_embedder(
    f"{coco_target_dir}/annotations/annotations/captions_val2017.json",
    f"{coco_target_dir}/embs/val2017_captions.pkl",
)
len(coco_text_embeddings_val)  # number of embedded captions

In [None]:
# Embed the COCO train2017 captions and store the pickle in the "embs" sub-folder.
coco_text_embeddings_train = coco_text_embedder(
    f"{coco_target_dir}/annotations/annotations/captions_train2017.json",
    f"{coco_target_dir}/embs/train2017_captions.pkl",
)
len(coco_text_embeddings_train)  # number of embedded captions

In [None]:
from IPython import get_ipython

# Ask the running kernel to shut down and restart. Everything held in memory
# by this notebook (models, embeddings) is discarded; the pickles on disk remain.
ipython_shell = get_ipython()
ipython_shell.kernel.do_shutdown(restart=True)