In [4]:
from tqdm import tqdm
import matplotlib.pyplot as plt
import networkx as nx
import io
from PIL import Image
LOCAL_DATA_PATH = '/local/home/jthomm/GraphCLIP/datasets/visual_genome/'
import json
# you need to download the scene graph data from visual genome, it's not included in dario's folder (and i don't have write access there)
with open(LOCAL_DATA_PATH+'raw/scene_graphs.json', 'r') as f:
    scene_graphs_dict = json.load(f)

def build_graph(g_dict):
        G = nx.DiGraph()
        G.image_id=g_dict['image_id']
        with open(LOCAL_DATA_PATH+'raw/VG/'+str(G.image_id)+'.jpg', 'rb') as f:
            image_bytes = f.read()
            s = Image.open(io.BytesIO(image_bytes)).size
            G.image_w = s[0]
            G.image_h = s[1]
        G.labels = {}
        for obj in g_dict['objects']:
            G.add_node(obj['object_id'], w=obj['w'], h=obj['h'], x=obj['x'], y=obj['y'], attributes=obj.get('attributes',[]), name=obj['names'][0])
            G.labels[obj['object_id']] = obj['names'][0]
        for rel in g_dict['relationships']:
            G.add_edge(rel['subject_id'], rel['object_id'], synsets=rel['synsets'] ,relationship_id=rel['relationship_id'], predicate=rel['predicate'])
        return G
graphs = [] 
for g_dict in tqdm(scene_graphs_dict):
    graphs.append(build_graph(g_dict))

In [7]:
# convert all object labels, attributes and predicates to lower case and remove trailing spaces
for g in tqdm(graphs):
    for n in g.nodes:
        g.nodes[n]['name'] = g.nodes[n]['name'].lower().strip()
        g.nodes[n]['attributes'] = [a.lower().strip() for a in g.nodes[n]['attributes']]
    for e in g.edges:
        g.edges[e]['predicate'] = g.edges[e]['predicate'].lower().strip()


100%|██████████| 108077/108077 [00:15<00:00, 6927.48it/s] 


In [8]:
# extract all object labels and all relationship labels and all attribute labels from the graphs
object_labels = {}
relationship_labels = {}
attribute_labels = {}
for g in graphs:
    for obj_label in g.labels.values():
        # # remove trailing spaces
        # obj_label = obj_label.strip().lower()
        object_labels[obj_label] = object_labels.get(obj_label, 0) + 1
    for rel_label in [g.edges[e]['predicate'] for e in g.edges]:
        # rel_label = rel_label.strip().lower()
        relationship_labels[rel_label] = relationship_labels.get(rel_label, 0) + 1
    for attr_label in [g.nodes[n]['attributes'] for n in g.nodes]:
        for a in attr_label:
            # a = a.strip().lower()
            attribute_labels[a] = attribute_labels.get(a, 0) + 1
print(f'number of graphs: {len(graphs)}')
print(f'number of object labels: {len(object_labels)}')
print(f'number of relationship labels: {len(relationship_labels)}')
print(f'number of attribute labels: {len(attribute_labels)}')
# sort the labels for both the objects and relationships and extract the 100 most frequent ones in the graphs
object_labels_occurrences = sorted(object_labels.items(), key=lambda x: x[1], reverse=True)[:200]
relationship_labels_occurrences = sorted(relationship_labels.items(), key=lambda x: x[1], reverse=True)[:100]
attribute_labels_occurrences = sorted(attribute_labels.items(), key=lambda x: x[1], reverse=True)[:100]
# extract the labels from the tuples
object_labels = [l[0] for l in object_labels_occurrences]
relationship_labels = [l[0] for l in relationship_labels_occurrences]
attribute_labels = [l[0] for l in attribute_labels_occurrences]

number of graphs: 108077
number of object labels: 100298
number of relationship labels: 34459
number of attribute labels: 65557


In [9]:
print(object_labels)
print(relationship_labels)
print(attribute_labels)

['man', 'person', 'window', 'tree', 'building', 'shirt', 'wall', 'woman', 'sign', 'sky', 'ground', 'grass', 'table', 'pole', 'head', 'light', 'water', 'car', 'hand', 'hair', 'people', 'leg', 'trees', 'clouds', 'ear', 'plate', 'leaves', 'fence', 'door', 'pants', 'eye', 'train', 'chair', 'floor', 'road', 'street', 'hat', 'snow', 'wheel', 'shadow', 'jacket', 'nose', 'boy', 'line', 'shoe', 'clock', 'sidewalk', 'boat', 'tail', 'cloud', 'handle', 'letter', 'girl', 'leaf', 'horse', 'bus', 'helmet', 'bird', 'giraffe', 'field', 'plane', 'flower', 'elephant', 'umbrella', 'dog', 'shorts', 'arm', 'zebra', 'face', 'windows', 'sheep', 'glass', 'bag', 'cow', 'bench', 'cat', 'food', 'bottle', 'rock', 'tile', 'kite', 'tire', 'post', 'number', 'stripe', 'surfboard', 'truck', 'logo', 'glasses', 'roof', 'skateboard', 'motorcycle', 'picture', 'flowers', 'bear', 'player', 'foot', 'bowl', 'mirror', 'background', 'pizza', 'bike', 'shoes', 'spot', 'tracks', 'pillow', 'shelf', 'cap', 'mouth', 'box', 'jeans', 'd

In [10]:
print(graphs[0].nodes[list(graphs[0].nodes)[0]])
print(graphs[0].edges[list(graphs[0].edges)[0]])

{'w': 79, 'h': 339, 'x': 421, 'y': 91, 'attributes': ['green', 'tall'], 'name': 'clock'}
{'synsets': ['along.r.01'], 'relationship_id': 15927, 'predicate': 'on'}


In [12]:
def build_filtered_graph(g, object_labels, relationship_labels, attribute_labels):
        G = nx.DiGraph()
        G.image_id=g.image_id
        G.image_w = g.image_w
        G.image_h = g.image_h
        G.labels = {}
        for n in g.nodes:
            if g.labels[n] in object_labels:
                filtered_attributes = [a for a in g.nodes[n]['attributes'] if a in attribute_labels]
                G.add_node(n, w=g.nodes[n]['w'], h=g.nodes[n]['h'], x=g.nodes[n]['x'], y=g.nodes[n]['y'], attributes=filtered_attributes, name=g.labels[n])
                G.labels[n] = g.labels[n]
        for e in g.edges:
            if g.edges[e]['predicate'] in relationship_labels and e[0] in G.nodes and e[1] in G.nodes:
                G.add_edge(e[0], e[1], synsets=g.edges[e]['synsets'].copy() ,relationship_id=g.edges[e]['relationship_id'], predicate=g.edges[e]['predicate'])
        return G

# filter the graphs relationships and objects and attributes to only keep the 100/200 most frequent ones. Remove graphs which have no objects or relationships left after filtering
filtered_graphs = []
for g in tqdm(graphs):
    g_filtered = build_filtered_graph(g, object_labels, relationship_labels, attribute_labels)
    if len(g_filtered.nodes) > 0 and len(g_filtered.edges) > 0:
        filtered_graphs.append(g_filtered)

100%|██████████| 108077/108077 [00:18<00:00, 5842.15it/s]


In [13]:
# print the graphs stats: number of graphs, number of objects, number of relationships, number of attributes
print(f'number of graphs: {len(filtered_graphs)}')
print(f'number of objects: {sum([len(g.nodes) for g in filtered_graphs])}')
print(f'number of relationships: {sum([len(g.edges) for g in filtered_graphs])}')
print(f'number of attributes: {sum([len(g.nodes[n]["attributes"]) for g in filtered_graphs for n in g.nodes])}')

number of graphs: 97216
number of objects: 2204832
number of relationships: 776509
number of attributes: 1001979


In [None]:
import torch
device = torch.device('cuda:1')
from open_clip import create_model_and_transforms
model,preprocess, _ = create_model_and_transforms('ViT-bigG-14', pretrained='laion2b_s39b_b160k', device=device) # the biggest model available
from open_clip import get_tokenizer
tokenizer = get_tokenizer(model_name='ViT-bigG-14')

In [10]:
batch_size = 2
labels_to_embed = ["a "+ l for l in object_labels]
chunked_object_labels = [labels_to_embed[i:i + batch_size] for i in range(0, len(labels_to_embed), batch_size)]
obj_label_embeddings = []
print(chunked_object_labels[0])
for i, chunk in enumerate(tqdm(chunked_object_labels)):
    chunk = tokenizer(chunk).to(device)
    emd_chunk = model.encode_text(chunk).detach().cpu().numpy()
    for s in emd_chunk:
        obj_label_embeddings.append(s)

['a man', 'a person']


100%|██████████| 100/100 [00:02<00:00, 34.32it/s]


In [13]:
# encode a radom astronaut image
from skimage import data
from PIL import Image
img = data.astronaut()
img_preprocessed = preprocess(Image.fromarray(img)).unsqueeze(0).to(device)
with torch.cuda.amp.autocast():
    img_embedding = model.encode_image(img_preprocessed).detach().cpu().numpy()
print(img_embedding.shape)

(1, 1280)


In [15]:
choice1 = obj_label_embeddings[0]
choice2 = obj_label_embeddings[1]
choice3 = obj_label_embeddings[2]
choice4 = obj_label_embeddings[3]
print("similarities: ")
print(img_embedding @ choice1.T, f"(for {labels_to_embed[0]})")
print(img_embedding @ choice2.T, f"(for {labels_to_embed[1]})")
print(img_embedding @ choice3.T, f"(for {labels_to_embed[2]})")
print(img_embedding @ choice4.T, f"(for {labels_to_embed[3]})")

similarities: 
[334.43994] (for a man)
[346.87274] (for a person)
[283.4137] (for a window)
[260.72723] (for a tree)


In [16]:
# save the filtered graphs with torch.save
import torch
torch.save(filtered_graphs, LOCAL_DATA_PATH+'processed/filtered_graphs.pt')
torch.save(filtered_graphs[0:100], LOCAL_DATA_PATH+'processed/filtered_graphs_test_small.pt')
torch.save(object_labels, LOCAL_DATA_PATH+'processed/filtered_object_labels.pt')
torch.save(obj_label_embeddings, LOCAL_DATA_PATH+'processed/filtered_object_label_embeddings.pt')
torch.save(relationship_labels, LOCAL_DATA_PATH+'processed/filtered_relationship_labels.pt')
torch.save(attribute_labels, LOCAL_DATA_PATH+'processed/filtered_attribute_labels.pt')

In [None]:
def plot_graph(g):
    pos = nx.nx_agraph.graphviz_layout(g, prog="dot")
    max_y = max([y for x,y in pos.values()])
    n_nodes_top = len([n for n in g.nodes if pos[n][1] == max_y])
    longest_label = max([len(g.labels[n]) for n in g.nodes])
    plt.figure(figsize=(max(n_nodes_top*longest_label/10,15),5))
    nx.draw(g,pos=pos,labels=g.labels, with_labels=True, node_size=10, node_color="lightgray", font_size=8)
    nx.draw_networkx_edge_labels(g,pos=pos,edge_labels=nx.get_edge_attributes(g,'predicate'),font_size=8)
    plt.show()
print(f'there are now {len(filtered_graphs)} many graphs left')
print(f'example graph:')
import random
idx = random.randint(0,len(filtered_graphs))
plot_graph(filtered_graphs[idx])
print(idx)

## Small script to clean the adversarial dataset attributes (don't run again)

In [16]:
import torch
path = '/local/home/jthomm/GraphCLIP/datasets/visual_genome/processed/ra_selections_curated_adversarial_nofilteredattr.pt'
ra_selections = torch.load(path)
print(len(ra_selections))
print(ra_selections[0])
new_selections = []
n_removed_attributes = 0
for sel in ra_selections:
    original_graph = sel[0]
    new_graph = build_filtered_graph(sel[0], object_labels, relationship_labels, attribute_labels)
    new_selections.append((new_graph, sel[1], sel[2]))
    for n in original_graph.nodes:
        if n not in new_graph.nodes:
            assert False # this should not happen
        for a in original_graph.nodes[n]['attributes']:
            if a not in new_graph.nodes[n]['attributes']:
                print("removed attribute ", a)
                n_removed_attributes += 1
        for a in new_graph.nodes[n]['attributes']:
            if a not in original_graph.nodes[n]['attributes']:
                assert False # this should not happen
    for e in original_graph.edges:
        if e not in new_graph.edges:
            assert False
print(f"removed {n_removed_attributes} attributes")

103
(<networkx.classes.digraph.DiGraph object at 0x7f007886f8b0>, (1058686, 5989), 'lying on')
removed attribute  wide
removed attribute  pedestrian line
removed attribute  faded
removed attribute  sunny
removed attribute  wide
removed attribute  pedestrian line
removed attribute  faded
removed attribute  sunny
removed attribute  carpeted
removed attribute  clean
removed attribute  dry erase board
removed attribute  cardboard
removed attribute  cardboard
removed attribute  folded
removed attribute  greenish
removed attribute  cork
removed attribute  full
removed attribute  folded
removed attribute  greenish
removed attribute  cork
removed attribute  full
removed attribute  folded
removed attribute  greenish
removed attribute  cork
removed attribute  full
removed attribute  balding
removed attribute  bald
removed attribute  floral
removed attribute  printed
removed attribute  white haired
removed attribute  gray haired
removed attribute  steel gray
removed attribute  covered
removed att

In [17]:
torch.save(new_selections, '/local/home/jthomm/GraphCLIP/datasets/visual_genome/processed/ra_selections_curated_adversarial.pt')