In [103]:
import os
import random
import math
import json
import yaml

import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from PIL import Image
import hydra
from transformers import CLIPTokenizerFast, CLIPProcessor, CLIPModel
import dgl

# from omegaconf import DictConfig, OmegaConf
from src.datamodules.mscoco import MSCOCODataset

In [2]:
# conf
root = "coco"
partition = "val2017"
batch_size = 100
num_workers = 8

vision_model = "ViT-B/32"
model_id = "openai/clip-vit-base-patch32"

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = CLIPTokenizerFast.from_pretrained(model_id)
processor = CLIPProcessor.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id).to(device)

In [107]:
ds = MSCOCODataset(
        root=root,
        annFile=os.path.join(root, "annotations", f"instances_{partition}.json")
        #transform=preprocess
    )

ann_items = ds.coco.loadCats(ds.coco.getCatIds())
tags = [item['name'] for item in ann_items]
tag_ids = [item['id'] for item in ann_items]


loading annotations into memory...
Done (t=0.30s)
creating index...
index created!


In [102]:
print(tag_ids)

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]


In [6]:
images_tensor = None
tags_tensor = None

with torch.no_grad():
    for i in tqdm(range(0, len(ds.ids), batch_size), desc='encoding image batches'):
        batch_ids = ds.ids[i:i+batch_size]
        batch_imgs = [Image.open('coco/images/'+ds.coco.loadImgs(id)[0]["file_name"]).convert('RGB') for id in batch_ids]
        batch = processor(
                    text=None,
                    images=batch_imgs,
                    return_tensors='pt',
                    padding=True
                )['pixel_values'].to(device)
        batch_emb = model.get_image_features(pixel_values=batch)
        batch_emb = batch_emb.squeeze(0)
        
        if images_tensor is None:
            images_tensor = batch_emb
        else:
            images_tensor = torch.cat((images_tensor, batch_emb), dim=0)

    for tag in tqdm(tags, desc='encoding tags'):
        inputs = tokenizer(tag, return_tensors="pt")
        tag_emb = model.get_text_features(**inputs)

        if tags_tensor is None:
            tags_tensor = tag_emb
        else:
            tags_tensor = torch.cat((tags_tensor, tag_emb), dim=0)
        

encoding image batches: 100%|██████████| 50/50 [01:41<00:00,  2.02s/it]
encoding tags: 100%|██████████| 80/80 [00:01<00:00, 45.83it/s]


In [20]:
print(images_tensor.shape)
print(tags_tensor.shape)
print(images_tensor[0][:10])
print(tags_tensor[0][:10])

torch.Size([5000, 512])
torch.Size([80, 512])
tensor([-0.0724, -0.0059, -0.3013,  0.0826,  0.2546, -0.2976,  0.2935,  0.0293,
        -0.0527,  0.0108])
tensor([ 0.1695,  0.0864,  0.1535,  0.0774,  0.0044, -0.3187, -0.3133, -1.1808,
        -0.2216, -0.0016])


In [10]:
ds.get_img_text_table()
test_key = list(ds.img_text_data.keys())[0]
ds.img_text_data[test_key]

{'img_embed_row': 0,
 'tag_ids': [64, 1, 67, 72, 78, 82, 84, 85, 86, 62],
 'tag_embed_rows': [63, 0, 66, 71, 77, 81, 83, 84, 85, 61],
 'tag_names': ['microwave',
  'tv',
  'vase',
  'chair',
  'potted plant',
  'clock',
  'dining table',
  'book',
  'person',
  'refrigerator']}

## Save image / tag embeddings, as well as relational JSON table

In [19]:
# Save everything
torch.save(images_tensor, 'coco_image_embeddings.pt')
torch.save(tags_tensor, 'coco_tag_embeddings.pt')

with open("img_text_data.json", "w") as outfile:
    json.dump(ds.img_text_data, outfile)

In [18]:
# Make sure datatypes for img_text_data are intact
with open("img_text_data.json", "r") as img_text_json:
    img_text_data = json.load(img_text_json)

test_key = list(img_text_data.keys())[0]
test_img = img_text_data[test_key]
print(test_img['tag_ids'])
print(test_img['tag_ids'][0])

[64, 1, 67, 72, 78, 82, 84, 85, 86, 62]
64


In [23]:
# Check to ensure no data loss between saving and reloading image / text embeddings
images_tensor_reloaded = torch.load('coco_image_embeddings.pt', map_location=device)
tags_tensor_reloaded = torch.load('coco_tag_embeddings.pt', map_location=device)

print(images_tensor_reloaded.shape)
print(tags_tensor_reloaded.shape)

print(images_tensor_reloaded[0][:10])
print(tags_tensor_reloaded[0][:10])

torch.Size([5000, 512])
torch.Size([80, 512])
tensor([-0.0724, -0.0059, -0.3013,  0.0826,  0.2546, -0.2976,  0.2935,  0.0293,
        -0.0527,  0.0108])
tensor([ 0.1695,  0.0864,  0.1535,  0.0774,  0.0044, -0.3187, -0.3133, -1.1808,
        -0.2216, -0.0016])


## Build Graph from Embeddings using DGL

Special Note: To avoid having duplicate node IDs between images and tags, we reserve node IDs 1 - max(tag_ids) for tags (text) and node IDs >= max(tag_ids) + 1 for images.

In [106]:
tests = tag_ids.copy()
random.shuffle(tests)
print(tests)
print(tag_ids)

[10, 37, 88, 75, 54, 72, 24, 52, 38, 59, 3, 65, 85, 73, 57, 2, 50, 35, 47, 28, 49, 27, 77, 67, 8, 17, 4, 18, 46, 25, 61, 40, 32, 5, 39, 89, 23, 15, 6, 31, 78, 42, 20, 62, 87, 51, 11, 64, 53, 70, 80, 81, 63, 79, 41, 21, 58, 55, 44, 56, 86, 36, 7, 48, 9, 16, 76, 13, 82, 74, 34, 14, 60, 33, 19, 84, 22, 43, 90, 1]
[73, 36, 27, 19, 88, 78, 79, 7, 10, 9, 56, 57, 53, 18, 4, 46, 25, 49, 74, 3, 58, 55, 44, 28, 86, 41, 81, 60, 15, 16, 6, 50, 32, 61, 70, 77, 64, 23, 1, 37, 54, 62, 22, 2, 39, 38, 34, 67, 11, 48, 87, 40, 13, 47, 51, 59, 33, 20, 76, 31, 21, 8, 84, 43, 80, 90, 65, 89, 72, 35, 63, 75, 85, 5, 17, 24, 82, 52, 42, 14]


In [77]:
def train_val_test_ids(node_ids):
    random.shuffle(node_ids)
    n = len(node_ids)
    train_ct, val_ct = [math.floor(n*0.7), math.ceil(n*0.15)]
    return [node_ids[0:train_ct], node_ids[train_ct:train_ct+val_ct], node_ids[train_ct+val_ct:]]

def split_mask(node_id, split_ids):
    return True if node_id in split_ids else False

def nodes_table(node_ids, features):
    nodes = pd.DataFrame({'node_id': list(node_ids)})
    train_ids, val_ids, test_ids = train_val_test_ids(node_ids)
    nodes['train_mask'] = nodes['node_id'].apply(lambda x: split_mask(x, train_ids))
    nodes['val_mask'] = nodes['node_id'].apply(lambda x: split_mask(x, val_ids))
    nodes['test_mask'] = nodes['node_id'].apply(lambda x: split_mask(x, test_ids))
    features = [', '.join([str(val) for val in item]) for item in features.tolist()]
    nodes['feat'] = features

    return nodes

img_id_offset = max(tag_ids) + 1
img_ids_rand = [int(key) + img_id_offset for key in img_text_data.keys()]
tag_ids_rand = tag_ids.copy()

img_nodes = nodes_table([int(key) + img_id_offset for key in img_text_data.keys()], images_tensor)
tag_nodes = nodes_table(tag_ids, tags_tensor)

nodes = pd.concat([img_nodes, tag_nodes])
nodes.head()



Unnamed: 0,node_id,train_mask,val_mask,test_mask,feat
0,220,False,False,True,"-0.07242653518915176, -0.00588227529078722, -0..."
1,366,True,False,False,"-0.015050115995109081, -0.6406027674674988, -0..."
2,713,True,False,False,"0.10296262800693512, 0.34772026538848877, 0.35..."
3,805,True,False,False,"0.00568790640681982, -0.02356841042637825, 0.5..."
4,857,True,False,False,"-0.374350905418396, -0.18562030792236328, -0.0..."


In [88]:
def edges_table(img_text_data):
    edges = pd.DataFrame(columns=['src_id', 'dst_id'])
    for img_id in img_text_data.keys():
        for tag in img_text_data[img_id]['tag_ids']:
            edges = pd.concat([edges, pd.DataFrame({'src_id': [int(img_id)+img_id_offset], 'dst_id': [int(tag)]})])
    return edges

edges = edges_table(img_text_data)
edges.head()

Unnamed: 0,src_id,dst_id
0,220,64
0,220,1
0,220,67
0,220,72
0,220,78


In [89]:
# Store graph csv files

if not os.path.exists('graph_csv'):
    os.mkdir('graph_csv')

edges.to_csv('graph_csv/coco_edges.csv')
nodes.to_csv('graph_csv/coco_nodes.csv')

g_metadata = {
    'dataset_name': 'coco_val_graph',
    'edge_data': [{'file_name': 'coco_edges.csv'}],
    'node_data': [{'file_name': 'coco_nodes.csv'}]
}

with open('graph_csv/meta.yaml', 'w') as file:
    yaml.dump(g_metadata, file)

In [90]:
dataset = dgl.data.CSVDataset('./graph_csv')
dataset



KeyError: 82