In [12]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import hydra
from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
import json
# from omegaconf import DictConfig, OmegaConf
from src.datamodules.mscoco import MSCOCODataset

In [2]:
# Notebook configuration -- everything a reader might want to tune.

# CLIP checkpoint
model_id = "openai/clip-vit-base-patch32"  # passed to every from_pretrained() call below
vision_model = "ViT-B/32"                  # not referenced by the visible cells

# Dataset location
root = "coco"            # dataset root: handed to MSCOCODataset and used to find the annotations
partition = "val2017"    # selects annotations/instances_<partition>.json

# Batching
batch_size = 100         # number of images encoded per forward pass
num_workers = 8          # not referenced by the visible cells

In [3]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tokenizer, processor and model are all loaded from the same checkpoint (`model_id`).
tokenizer = CLIPTokenizerFast.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id)
model = model.to(device)

In [4]:
# Instance annotations of the selected partition live under <root>/annotations/.
ann_file = os.path.join(root, "annotations", f"instances_{partition}.json")
ds = MSCOCODataset(root=root, annFile=ann_file)

# Category names ("tags"), in the order the COCO API returns the category ids.
ann_items = ds.coco.loadCats(ds.coco.getCatIds())
tags = [item['name'] for item in ann_items]


loading annotations into memory...
Done (t=0.27s)
creating index...
index created!


In [6]:
# Encode every image and every category name with CLIP.
#
# Results (names kept for the cells below):
#   images_tensor -- one row of image features per entry of ds.ids, in ds.ids order
#   tags_tensor   -- one row of text features per entry of `tags`, in `tags` order
# Either one is None when there was nothing to encode.
#
# Per-batch results are collected in lists and concatenated once at the end, instead of
# re-copying the growing tensor with torch.cat on every iteration.
image_chunks = []
tag_chunks = []

with torch.no_grad():
    for start in tqdm(range(0, len(ds.ids), batch_size), desc='encoding image batches'):
        batch_ids = ds.ids[start:start + batch_size]

        batch_imgs = []
        for img_id in batch_ids:  # `img_id` rather than `id`, so the builtin is not shadowed
            file_name = ds.coco.loadImgs(img_id)[0]["file_name"]
            # The image directory is derived from `root` instead of repeating the literal "coco/".
            # convert() returns a new, fully loaded image, so the file can be closed right away.
            with Image.open(os.path.join(root, "images", file_name)) as img:
                batch_imgs.append(img.convert('RGB'))

        pixel_values = processor(
            text=None,
            images=batch_imgs,
            return_tensors='pt',
            padding=True
        )['pixel_values'].to(device)

        # Deliberately no squeeze(0): a trailing batch holding a single image must stay 2-D,
        # otherwise it cannot be concatenated with the other batches along dim 0.
        image_chunks.append(model.get_image_features(pixel_values=pixel_values))

    for tag in tqdm(tags, desc='encoding tags'):
        # The token tensors must live on the same device as the model (matters on CUDA).
        inputs = tokenizer(tag, return_tensors="pt").to(device)
        tag_chunks.append(model.get_text_features(**inputs))

images_tensor = torch.cat(image_chunks, dim=0) if image_chunks else None
tags_tensor = torch.cat(tag_chunks, dim=0) if tag_chunks else None
        

encoding image batches: 100%|██████████| 50/50 [01:41<00:00,  2.02s/it]
encoding tags: 100%|██████████| 80/80 [00:01<00:00, 45.83it/s]


In [20]:
# Sanity check: shapes of both embedding matrices first, then the first 10 values of row 0 of each.
for emb in (images_tensor, tags_tensor):
    print(emb.shape)
for emb in (images_tensor, tags_tensor):
    print(emb[0][:10])

torch.Size([5000, 512])
torch.Size([80, 512])
tensor([-0.0724, -0.0059, -0.3013,  0.0826,  0.2546, -0.2976,  0.2935,  0.0293,
        -0.0527,  0.0108])
tensor([ 0.1695,  0.0864,  0.1535,  0.0774,  0.0044, -0.3187, -0.3133, -1.1808,
        -0.2216, -0.0016])


In [10]:
# Ask the dataset to build its image/tag table (presumably this populates ds.img_text_data;
# confirm in MSCOCODataset), then display the entry stored under the first key.
# NOTE(review): the entry displayed below lists tag_embed_rows up to 85, while tags_tensor
# has only 80 rows. The rows look like `tag_id - 1` rather than positions in `tags`;
# confirm the mapping before indexing tags_tensor with them.
ds.get_img_text_table()
test_key = list(ds.img_text_data)[0]
ds.img_text_data[test_key]

{'img_embed_row': 0,
 'tag_ids': [64, 1, 67, 72, 78, 82, 84, 85, 86, 62],
 'tag_embed_rows': [63, 0, 66, 71, 77, 81, 83, 84, 85, 61],
 'tag_names': ['microwave',
  'tv',
  'vase',
  'chair',
  'potted plant',
  'clock',
  'dining table',
  'book',
  'person',
  'refrigerator']}

## Save image / tag embeddings, as well as relational JSON table

In [19]:
# Persist the two embedding matrices and the image/tag table to the working directory.
for emb, out_path in (
    (images_tensor, 'coco_image_embeddings.pt'),
    (tags_tensor, 'coco_tag_embeddings.pt'),
):
    torch.save(emb, out_path)

with open("img_text_data.json", "w") as outfile:
    json.dump(ds.img_text_data, outfile)

In [18]:
# Round-trip check: reload the JSON table and print one entry's tag ids, followed by the
# first id on its own, so the value types can be inspected after deserialisation.
# NOTE: JSON object keys are always strings. If ds.img_text_data was keyed by ints, the
# reloaded keys are str; only the values are inspected here.
with open("img_text_data.json", "r") as fh:
    img_text_data = json.load(fh)

test_key = list(img_text_data)[0]
test_img = img_text_data[test_key]
reloaded_tag_ids = test_img['tag_ids']
print(reloaded_tag_ids)
print(reloaded_tag_ids[0])

[64, 1, 67, 72, 78, 82, 84, 85, 86, 62]
64


In [23]:
# Reload both saved embedding files onto `device` and print the same summary that was
# printed before saving (shapes, then the first 10 values of row 0) for a visual comparison.
images_tensor_reloaded, tags_tensor_reloaded = (
    torch.load(path, map_location=device)
    for path in ('coco_image_embeddings.pt', 'coco_tag_embeddings.pt')
)

for reloaded in (images_tensor_reloaded, tags_tensor_reloaded):
    print(reloaded.shape)

for reloaded in (images_tensor_reloaded, tags_tensor_reloaded):
    print(reloaded[0][:10])

torch.Size([5000, 512])
torch.Size([80, 512])
tensor([-0.0724, -0.0059, -0.3013,  0.0826,  0.2546, -0.2976,  0.2935,  0.0293,
        -0.0527,  0.0108])
tensor([ 0.1695,  0.0864,  0.1535,  0.0774,  0.0044, -0.3187, -0.3133, -1.1808,
        -0.2216, -0.0016])
