# Compare EPViT with the baselines CLIP and histogram

In [None]:
import relational_image_generation_evaluation as rige

### Datasets

In [None]:
# Fetch the adversarial-attribute dataset through rige. The evaluation code
# in this notebook reads 'original_graph' and 'adv_graph' from each sample.
dataset_daa = rige.get_adversarial_attribute_dataset()

def graph_to_caption(graph):
    """Render the first edge of a scene graph as a lower-case caption.

    The caption has the form
    ``"<subject attrs> <subject> <predicate> <object attrs> <object>."``,
    where each attribute list is comma-joined. Empty components are skipped,
    so a node without attributes no longer leaves a doubled space in the
    middle of the caption (e.g. ``"car on  road."``).

    Args:
        graph: Object exposing ``edges`` (iterable of ``(subject_id,
            object_id)`` pairs, indexable by such a pair to get a dict with a
            ``'predicate'`` entry) and ``nodes`` (indexable by node id to get
            a dict with ``'name'`` and ``'attributes'`` entries).

    Returns:
        The caption string, stripped, lower-cased and terminated by ``'.'``.

    Raises:
        IndexError: If the graph has no edges.
    """
    # Only the first edge is described; any further edges are ignored.
    subject_id, object_id = list(graph.edges)[0]
    predicate = graph.edges[(subject_id, object_id)]['predicate']
    subject_node = graph.nodes[subject_id]
    object_node = graph.nodes[object_id]

    parts = [
        ','.join(subject_node['attributes']),
        subject_node['name'],
        predicate,
        ','.join(object_node['attributes']),
        object_node['name'],
    ]
    # Joining only the non-empty parts avoids stray/double spaces when an
    # attribute list is empty.
    caption = ' '.join(part for part in parts if part).strip().lower()
    return caption + '.'

import random
# Seed the `random` module's global generator with a fixed value.
# NOTE(review): nothing visible in this notebook draws from `random`;
# presumably this is for reproducibility of code run elsewhere - confirm.
random.seed(123456)

# (label, dataset) pairs iterated by the evaluation cells below.
datasets = [("daa", dataset_daa)]

### Evaluation

In [None]:
from PIL import Image
def evaluate(evaluator, dataset, *, image_dirs=(
    '../datasets/visual_genome/raw/VG_100K/',
    '../datasets/visual_genome/raw/VG_100K_2/',
)):
    """Measure how often an evaluator prefers the original graph.

    Every sample's image is scored once against its original graph and once
    against its adversarial graph. A sample counts as 1 when the original
    scores strictly higher, as 0.5 on an exact tie and as 0 otherwise.

    Args:
        evaluator: Callable invoked as ``evaluator(images, graphs)``. It must
            return a mapping holding either ``'rel_scores'`` together with
            ``'attr_scores'`` (EPViT-style output; an attribute entry may be
            the string ``'noattributes'``) or ``'overall_scores'``.
        dataset: Sized iterable of samples. Each sample provides
            ``'original_graph'`` and ``'adv_graph'``, both carrying the same
            ``image_id``.
        image_dirs: Directories searched, in order, for ``<image_id>.jpg``.
            Adapt these to your local Visual Genome checkout.

    Returns:
        ``{'acc': accuracy}`` with the accuracy as a float in ``[0, 1]``.

    Raises:
        ValueError: If the dataset is empty, or the evaluator output has
            neither ``'rel_scores'`` nor ``'overall_scores'``.
        FileNotFoundError: If an image is missing from every directory.
    """
    n_total = len(dataset)
    if n_total == 0:
        # Fail before the evaluator runs instead of ending in a
        # ZeroDivisionError after both scoring passes.
        raise ValueError("dataset is empty; accuracy is undefined")

    orig_graphs = []
    adv_graphs = []
    images = []
    for sample in dataset:
        orig_graph = sample['original_graph']
        adv_graph = sample['adv_graph']
        # Both graphs must describe the same image for the comparison to
        # be meaningful.
        assert orig_graph.image_id == adv_graph.image_id
        orig_graphs.append(orig_graph)
        adv_graphs.append(adv_graph)
        # NOTE(review): the opened images are never closed here; presumably
        # the evaluator consumes them - confirm this is fine for large sets.
        images.append(_open_image(orig_graph.image_id, image_dirs))

    orig_scores = evaluator(images, orig_graphs)
    print(".............................")
    adv_scores = evaluator(images, adv_graphs)

    if 'rel_scores' in orig_scores:
        # EPViT-style output: relation score plus optional attribute score.
        orig_values = [_epvit_score(orig_scores, i) for i in range(n_total)]
        adv_values = [_epvit_score(adv_scores, i) for i in range(n_total)]
        acc, n_equal = _pairwise_accuracy(orig_values, adv_values)
        print("n_equal", n_equal)
    elif 'overall_scores' in orig_scores:
        orig_values = [orig_scores['overall_scores'][i] for i in range(n_total)]
        adv_values = [adv_scores['overall_scores'][i] for i in range(n_total)]
        acc, _ = _pairwise_accuracy(orig_values, adv_values)
    else:
        # Previously this fell through and crashed on an unbound `acc`.
        raise ValueError(
            "evaluator output must contain 'rel_scores' or 'overall_scores', "
            f"got keys {list(orig_scores)!r}"
        )
    return {'acc': acc}


def _open_image(image_id, image_dirs):
    """Open ``<image_id>.jpg`` from the first directory that contains it."""
    import os

    last_error = None
    for image_dir in image_dirs:
        try:
            return Image.open(os.path.join(image_dir, str(image_id) + '.jpg'))
        except FileNotFoundError as exc:
            # Only a missing file justifies trying the next directory; any
            # other failure (e.g. a corrupt image) should surface directly.
            last_error = exc
    raise FileNotFoundError(
        f"image {image_id}.jpg not found in any of {list(image_dirs)!r}"
    ) from last_error


def _epvit_score(scores, index):
    """Return the relation score plus the attribute score, if there is one."""
    attr_score = scores['attr_scores'][index]
    rel_score = scores['rel_scores'][index]
    if attr_score == 'noattributes':
        return rel_score
    return rel_score + attr_score


def _pairwise_accuracy(orig_values, adv_values):
    """Return ``(accuracy, n_ties)`` for equally long score sequences."""
    n_correct = 0
    n_equal = 0
    for orig_value, adv_value in zip(orig_values, adv_values):
        if orig_value > adv_value:
            n_correct += 1
        elif orig_value == adv_value:
            # An exact tie counts as half a correct answer.
            n_correct += 0.5
            n_equal += 1
    return n_correct / len(orig_values), n_equal

#### EPViT

In [None]:
# Build the 'ViT-L/14' evaluator (the EPViT model of this section) and report
# its accuracy on every dataset.
evaluator_epvit = rige.Evaluator('ViT-L/14')
for name, dataset in datasets:
    # `acc` holds the result dict returned by evaluate().
    acc = evaluate(evaluator_epvit, dataset)
    print('epvit', name, acc)
# Drop the reference before the next evaluator is created
# (presumably to release model memory - confirm).
del evaluator_epvit

#### CLIP ViT-L/14

In [None]:
# Build the 'CLIP_ViT-L/14' baseline evaluator and report its accuracy on
# every dataset.
evaluator_clip = rige.Evaluator('CLIP_ViT-L/14')
for name, dataset in datasets:
    # `acc` holds the result dict returned by evaluate().
    acc = evaluate(evaluator_clip, dataset)
    print('clip-l/14', name, acc)
# Drop the reference before the next evaluator is created
# (presumably to release model memory - confirm).
del evaluator_clip

#### CLIP ViT-G/14

In [None]:
# Build the 'CLIP_ViT-G/14' baseline evaluator and report its accuracy on
# every dataset. The loop uses the shared `datasets` list like the other
# cells; the former `dataset_dar` is never defined in this notebook and
# raised a NameError.
evaluator_clip = rige.Evaluator('CLIP_ViT-G/14')
for name, dataset in datasets:
    # `acc` holds the result dict returned by evaluate().
    acc = evaluate(evaluator_clip, dataset)
    print('clip-G/14', name, acc)
# Drop the reference before the next evaluator is created
# (presumably to release model memory - confirm).
del evaluator_clip

#### Histogram

In [None]:
# Build the 'histogram' baseline evaluator and report its accuracy on every
# dataset.
evaluator_hist = rige.Evaluator('histogram')

for name, dataset in datasets:
    # `acc` holds the result dict returned by evaluate().
    acc = evaluate(evaluator_hist, dataset)
    print('histogram', name, acc)
# Drop the reference once the evaluation is done
# (presumably to release memory - confirm).
del evaluator_hist