In [2]:
from jt_training import get_dataloader, train_one_epoch, evaluate, get_free_gpu


def load_model(clip_model_type, clip_pretrained_dataset, n_rel_classes, n_obj_classes, n_attr_classes, shallow=True, input_mode="text_embeddings", with_object_heads=False):
    from open_clip.jt_ViT_RelClassifier_lightning import ViT_RelClassifier
    model = ViT_RelClassifier(n_rel_classes, n_obj_classes, n_attr_classes, clip_model_type, clip_pretrained_dataset, shallow=shallow, mode=input_mode, with_object_heads=with_object_heads)
    prepocess_function = model.preprocess
    device = get_free_gpu(min_mem=20000)
    print(f"Using device {device}")
    model.to(device)
    return model, prepocess_function, device

In [3]:
import torch
clip_model_type = 'ViT-L-14' # 'ViT-L-14' #'ViT-B/32'
clip_pretrained_dataset = 'laion2b_s32b_b82k' # 'laion2b_s32b_b82k' #'laion400m_e32'
image_dir = "/local/home/jthomm/GraphCLIP/datasets/visual_genome/raw/VG/"
metadata_path = "/local/home/jthomm/GraphCLIP/datasets/visual_genome/processed/"


model, prepocess_function, device = load_model(clip_model_type, clip_pretrained_dataset, 100, 200, 100, with_object_heads=True)

loaded = torch.load('/local/home/jthomm/GraphCLIP/experiments/2023-05-27/vision_transformer_39/model_epoch-v29.ckpt', map_location=device)
# loaded = torch.load('/local/home/jthomm/GraphCLIP/experiments/2023-06-24/vision_transformer_8/model_epoch-v9.ckpt', map_location= device)
print(loaded.keys())
model.load_state_dict(loaded['state_dict'])


Using text embeddings as input to the model.
Using device cuda:2
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters', 'datamodule_hparams_name', 'datamodule_hyper_parameters'])


<All keys matched successfully>

In [4]:
dataset_train, dataset_val, graphs_train, graphs_val = get_dataloader(prepocess_function,metadata_path,image_dir, testing_only=False, get_pure_graphs=True)

Loading filtered graphs...
100
Done loading filtered graphs.
Filtered relationships loaded from file
Filtered objects loaded from file
Filtered attributes loaded from file
Filtered objects loaded from file
Filtered relationships loaded from file
Filtered objects loaded from file
Filtered attributes loaded from file
Filtered objects loaded from file


In [10]:
from tqdm import tqdm, trange
# shuffle not in place
import random
graphs_val_list = list(graphs_val)
random.shuffle(graphs_val_list)


first_n = 10000000
len(graphs_val)
print(len(graphs_val))
mean_recall_at_50 = 0
mean_recall_at_100 = 0
mean_recall_at_50_k100 = 0
mean_recall_at_100_k100 = 0
t = trange(len(graphs_val), desc='Bar desc', leave=True)
n_processed = 0
for i in t:
    graph = graphs_val_list[i]
    if len(graph.edges) == 0:
        continue
    first_n -= 1
    if first_n < 0:
        break
    guesses_k1 = []
    true_rels = []
    guesses_k100 = []
    for edge in graph.edges:
        obj1_name = graph.nodes[edge[0]]['name']
        obj2_name = graph.nodes[edge[1]]['name']
        predicate = graph.edges[edge]['predicate']
        image, full_text_clip_embd, rel_label, obj1_label, obj2_label, attr1_label, attr2_label, rel_mask, attr1_mask, attr2_mask = dataset_val.getitem_from_id_edge(graph.image_id, edge, "text_embeddings")
        rel, obj1, obj2, attr1, attr2 = model(image.unsqueeze(0).to(device), full_text_clip_embd.unsqueeze(0).to(device))
        true_rel_confidence = rel[0][rel_label].item()
        true_rel = rel_label
        pred_rel_confidence = rel[0].max().item()
        pred_rel = rel[0].argmax().item()
        true_rels.append((obj1_name, true_rel, obj2_name))
        guesses_k1.append((true_rel_confidence, obj1_name, pred_rel, obj2_name))
        guesses_k100 += [(rel[0][i].item(), obj1_name, i, obj2_name) for i in range(100)]

    guesses_k1.sort(key=lambda x: x[0], reverse=True)
    guesses_k100.sort(key=lambda x: x[0], reverse=True)
    # calculate recall at 50 and 100
    recall_at_50 = 0
    recall_at_100 = 0
    recall_at_50_k100 = 0
    recall_at_100_k100 = 0
    for i in range(50):
        if i>=len(guesses_k1):
            break
        if guesses_k1[i][1:] in true_rels:
            recall_at_50 += 1
    for i in range(100):
        if i>=len(guesses_k1):
            break
        if guesses_k1[i][1:] in true_rels:
            recall_at_100 += 1
    for i in range(50):
        if i>=len(guesses_k100):
            break
        if guesses_k100[i][1:] in true_rels:
            recall_at_50_k100 += 1
    for i in range(100):
        if i>=len(guesses_k100):
            break
        if guesses_k100[i][1:] in true_rels:
            recall_at_100_k100 += 1
    mean_recall_at_50 += recall_at_50/len(true_rels)
    mean_recall_at_100 += recall_at_100/len(true_rels)
    mean_recall_at_50_k100 += recall_at_50_k100/len(true_rels)
    mean_recall_at_100_k100 += recall_at_100_k100/len(true_rels)
    n_processed += 1

    t.set_description(f"Mean recall at 50: {mean_recall_at_50/n_processed}, Mean recall at 100: {mean_recall_at_100/n_processed}, Mean recall at 50 k100: {mean_recall_at_50_k100/n_processed}, Mean recall at 100 k100: {mean_recall_at_100_k100/n_processed}")
    

print(f"Mean recall at 50: {mean_recall_at_50/100}")
print(f"Mean recall at 100: {mean_recall_at_100/100}")
print(f"Mean recall at 50 k100: {mean_recall_at_50_k100/100}")
print(f"Mean recall at 100 k100: {mean_recall_at_100_k100/100}")


19419


Mean recall at 50: 0.6089737843343489, Mean recall at 100: 0.609061925747666, Mean recall at 50 k100: 0.954872754106252, Mean recall at 100 k100: 1.013989807486511:  62%|██████▏   | 11965/19419 [53:24<33:16,  3.73it/s]     


KeyboardInterrupt: 

In [6]:
from tqdm import tqdm, trange
# shuffle not in place
import random
graphs_val_list = list(graphs_val)
random.shuffle(graphs_val_list)


first_n = 10000000
len(graphs_val)
print(len(graphs_val))
mean_recall_at_50 = 0
mean_recall_at_100 = 0
mean_recall_at_50_k100 = 0
mean_recall_at_100_k100 = 0
mean_acc = 0
acc_steps = 0
t = trange(len(graphs_val), desc='Bar desc', leave=True)
n_processed = 0
for i in t:
    graph = graphs_val_list[i]
    if len(graph.edges) == 0:
        continue
    first_n -= 1
    if first_n < 0:
        break
    guesses_k1 = []
    true_rels = []
    guesses_k100 = []
    for edge in graph.edges:
        obj1_name = graph.nodes[edge[0]]['name']
        obj2_name = graph.nodes[edge[1]]['name']
        predicate = graph.edges[edge]['predicate']
        image, full_text_clip_embd, rel_label, obj1_label, obj2_label, attr1_label, attr2_label, rel_mask, attr1_mask, attr2_mask = dataset_val.getitem_from_id_edge(graph.image_id, edge, "text_embeddings")
        rel, obj1, obj2, attr1, attr2 = model(image.unsqueeze(0).to(device), full_text_clip_embd.unsqueeze(0).to(device))
        true_rel_confidence = rel[0][rel_label].item()
        true_rel = rel_label
        pred_rel_confidence = rel[0].max().item()
        pred_rel = rel[0].argmax().item()
        if true_rel == pred_rel:
            mean_acc += 1
        acc_steps += 1
        true_rels.append((obj1_name, true_rel, obj2_name))
        guesses_k1.append((true_rel_confidence, obj1_name, pred_rel, obj2_name))
        guesses_k100 += [(rel[0][i].item(), obj1_name, i, obj2_name) for i in range(100)]

    guesses_k1.sort(key=lambda x: x[0], reverse=True)
    guesses_k100.sort(key=lambda x: x[0], reverse=True)
    # calculate recall at 50 and 100
    recall_at_50 = 0
    recall_at_100 = 0
    recall_at_50_k100 = 0
    recall_at_100_k100 = 0
    for i in range(50):
        if i>=len(guesses_k1):
            break
        if guesses_k1[i][1:] in true_rels:
            recall_at_50 += 1
    for i in range(100):
        if i>=len(guesses_k1):
            break
        if guesses_k1[i][1:] in true_rels:
            recall_at_100 += 1
    for i in range(50):
        if i>=len(guesses_k100):
            break
        if guesses_k100[i][1:] in true_rels:
            recall_at_50_k100 += 1
    for i in range(100):
        if i>=len(guesses_k100):
            break
        if guesses_k100[i][1:] in true_rels:
            recall_at_100_k100 += 1
    mean_recall_at_50 += recall_at_50/len(true_rels)
    mean_recall_at_100 += recall_at_100/len(true_rels)
    mean_recall_at_50_k100 += recall_at_50_k100/len(true_rels)
    mean_recall_at_100_k100 += recall_at_100_k100/len(true_rels)
    n_processed += 1

    t.set_description(f"Mean recall at 50: {mean_recall_at_50/n_processed}, Mean recall at 100: {mean_recall_at_100/n_processed}, Mean recall at 50 k100: {mean_recall_at_50_k100/n_processed}, Mean recall at 100 k100: {mean_recall_at_100_k100/n_processed}, Mean acc: {mean_acc/acc_steps}")
    

print(f"Mean recall at 50: {mean_recall_at_50/100}")
print(f"Mean recall at 100: {mean_recall_at_100/100}")
print(f"Mean recall at 50 k100: {mean_recall_at_50_k100/100}")
print(f"Mean recall at 100 k100: {mean_recall_at_100_k100/100}")

19419


Mean recall at 50: 0.5662773817134613, Mean recall at 100: 0.5662773817134613, Mean recall at 50 k100: 0.9311437132423748, Mean recall at 100 k100: 0.9903194200546717, Mean acc: 0.596767333049766:   2%|▏         | 295/19419 [01:19<1:25:29,  3.73it/s]  


KeyboardInterrupt: 