In [1]:
import argparse
import random
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data import Dataset, DataLoader
import wandb
import sys
sys.path.append('../')
from dataset_clevr_ryan import RelationalDataset, BoundingBox
from utils import *
from tqdm.auto import tqdm

from einops import rearrange, reduce
from einops.layers.torch import Rearrange
import pickle


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
sys.path.append('../bbox_classifier')
from classifier import BboxClassifier
from eval_pipeline import *
# Module-level classifier used by single_image_eval / evaluate to judge whether a pair of
# boxes satisfies a relation (its first output element is thresholded at 0.5 there).
metric_model = BboxClassifier()
# map_location='cpu' lets the checkpoint deserialize on machines without CUDA even if it
# was saved from GPU tensors. load_state_dict copies the values into the model's existing
# parameters, and the evaluation code moves the model to its target device afterwards.
metric_model.load_state_dict(
    torch.load('../bbox_classifier/4-layer-DNN-48_multi_rels-100.pth', map_location='cpu')
)
def single_image_eval(bboxes, relations, relations_ids, eval_info = None):
    """Score one image's boxes against its relations with the module-level ``metric_model``.

    For relation ``i`` the two box rows selected by ``relations_ids[i]`` are concatenated
    with the relation id (the last element of ``relations[i]``) and passed to
    ``metric_model``. The relation counts as satisfied when the first element of the
    model output is greater than 0.5.

    Args:
        bboxes: tensor indexed by object id; each selected row is one box.
        relations: sequence of relation rows; only the last element of each row is read.
        relations_ids: sequence of ``(a, b)`` pairs of 0-dim tensors indexing ``bboxes``.
        eval_info: accumulator exposing ``update(rel_id, is_correct)``. When omitted a
            fresh ``EvalInfo()`` is created for this call (a shared default instance
            would silently leak counts between unrelated calls).

    Returns:
        ``(fraction_of_satisfied_relations, eval_info)``.

    Raises:
        ZeroDivisionError: if ``relations`` is empty.

    Side effects:
        Moves ``metric_model`` to the selected device and casts it to float64.
    """
    if eval_info is None:
        eval_info = EvalInfo()

    # Same device choice as evaluate(): CUDA when present, otherwise CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Loop-invariant: move/cast the model once instead of once per relation.
    metric_model.to(device).double()

    correct_relations = 0
    with torch.no_grad():  # pure inference; no autograd graph needed
        for (i, rel) in enumerate(relations):
            (a, b) = relations_ids[i]
            a = a.item()
            b = b.item()
            rel_id = rel[-1]
            # Model input layout: [box_a..., box_b..., relation_id], all float64.
            features = torch.cat([
                bboxes[a].to(device).double(),
                bboxes[b].to(device).double(),
                torch.as_tensor(rel_id, dtype=torch.float64, device=device).reshape(1),
            ])
            pred = metric_model(features)[0].item()
            is_correct = pred > 0.5
            correct_relations += is_correct
            eval_info.update(rel_id, is_correct)

    return correct_relations / len(relations), eval_info

In [16]:
def evaluate(data, obj_num, wandb_drawer = None, no_above_below = True):
    """Score a set of box layouts with ``metric_model`` and log/print the results.

    Every sample is scored with ``single_image_eval``; up to the first 8 layouts are
    drawn on a white 256x256 canvas, saved as ``images/Random_{obj_num}_{i}.png`` and,
    when ``wandb_drawer`` is given, logged as images. Per-relation accuracies and the
    overall accuracy are printed and optionally logged with ``step=obj_num``.

    Args:
        data: tuple ``(all_gen_bboxes, all_relations, all_relation_ids)`` of per-sample
            sequences, indexed in parallel.
        obj_num: used as the logging step and in the saved image file names.
        wandb_drawer: optional object with a ``log(dict, step=...)`` method.
        no_above_below: when True the overall accuracy is pooled over relation ids 0-3
            only (left/right/front/behind in the recorded output); otherwise it is the
            mean of the per-sample scores.

    Returns:
        The overall accuracy; 0 when there is nothing to average, matching the 0 that
        the per-relation report shows for relations with no examples.
    """
    import os  # local import: only needed to make sure the output folder exists

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    metric_model.to(device)
    metric_model.eval()
    all_gen_bboxes, all_relations, all_relation_ids = data

    size = len(all_gen_bboxes)

    eval_info = EvalInfo()
    scores = []

    for i in range(size):
        score, eval_info = single_image_eval(
            all_gen_bboxes[i], all_relations[i], all_relation_ids[i], eval_info
        )
        scores.append(score)

    # Draw at most 8 previews; fewer samples must not raise IndexError. Only the boxes
    # that are actually drawn are wrapped in BoundingBox.
    num_preview = min(8, size)
    colours = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)] # red, green, blue, yellow, cyan, magenta
    images = []
    for i in range(num_preview):
        image = Image.new('RGB', (256, 256), (255, 255, 255))
        for j, raw_box in enumerate(all_gen_bboxes[i]):
            image = BoundingBox(raw_box.tolist()).draw(image, color=colours[j % len(colours)])
        images.append(image)
    if wandb_drawer is not None:
        wandb_drawer.log({"images": [wandb.Image(image) for image in images]}, step = obj_num)
    # save images to file (create the folder first so a fresh checkout does not fail)
    os.makedirs("images", exist_ok=True)
    for i, image in enumerate(images):
        image.save(f"images/Random_{obj_num}_{i}.png")

    if no_above_below:
        # Pool the counts of relation ids 0..3; index one by one so this works whether
        # the counters are lists or dicts keyed by id.
        pooled_correct = sum(eval_info.correct_relations[k] for k in range(4))
        pooled_total = sum(eval_info.total_relations[k] for k in range(4))
        avg_score = pooled_correct / pooled_total if pooled_total else 0
    else:
        avg_score = sum(scores) / len(scores) if scores else 0

    # separately log each relation's acc
    eval_info_list = eval_info.to_list()
    eval_info.print()
    for i in range(6):
        print(f"acc_{eval_info.relation_names[i]}: {eval_info_list[i]}")
        if wandb_drawer is not None:
            wandb_drawer.log({f"acc_{eval_info.relation_names[i]}": eval_info_list[i]}, step = obj_num)
    print(f"avg_acc: {avg_score}")
    if wandb_drawer is not None:
        wandb_drawer.log({"acc": avg_score}, step = obj_num)
    return avg_score

    

In [17]:

# Driver: for each object count 2..9, collect the boxes, relations and relation index
# pairs of every dataset sample and score them with evaluate().
# Comment out the wandb.init(...) assignment below to run without W&B; the None default
# then disables all logging inside evaluate().
wandb_drawer = None
wandb_drawer = wandb.init(
            project="diffusion_bbox_eval",
            name="random",
            save_code=True)
try:
    for obj_num in range(2, 10):
        # NOTE(review): RelationalDatasetxO is not among the explicit imports shown here;
        # presumably it comes from one of the star imports -- confirm.
        dataset = RelationalDatasetxO(obj_num, upperbound = 1000)

        bboxes = []
        relations = []
        relations_ids = []

        for i in range(len(dataset)):
            # Each sample is an 8-tuple; only boxes, relations and relation ids are used.
            _clean_image, _objects, rels, bbes, _prompt, _raw_image, _raw_image_tensor, rels_ids = dataset[i]
            bboxes.append(bbes)
            relations.append(rels)
            relations_ids.append(rels_ids)

        data = (bboxes, relations, relations_ids)
        evaluate(data, obj_num, wandb_drawer = wandb_drawer)
finally:
    # Always close the run, even if an evaluation step raises, so it is not left open.
    if wandb_drawer is not None:
        wandb_drawer.finish()

Empty images


  return torch.tensor(self.pos)


left: 394 / 736
right: 180 / 736
front: 260 / 733
behind: 444 / 733
above: 0 / 0
below: 0 / 0
acc_left: 0.5353260869565217
acc_right: 0.24456521739130435
acc_front: 0.35470668485675305
acc_behind: 0.6057298772169167
acc_above: 0
acc_below: 0
avg_acc: 0.43498978897208984
Empty images
left: 1180 / 2152
right: 543 / 2152
front: 711 / 2164
behind: 1278 / 2164
above: 0 / 0
below: 0 / 0
acc_left: 0.5483271375464684
acc_right: 0.25232342007434944
acc_front: 0.3285582255083179
acc_behind: 0.5905730129390019
acc_above: 0
acc_below: 0
avg_acc: 0.43002780352177944
Empty images
left: 2391 / 4338
right: 1038 / 4338
front: 1495 / 4346
behind: 2612 / 4346
above: 0 / 0
below: 0 / 0
acc_left: 0.5511756569847857
acc_right: 0.2392807745504841
acc_front: 0.34399447768062585
acc_behind: 0.6010124252185918
acc_above: 0
acc_below: 0
avg_acc: 0.43390142791340397
Empty images
left: 3988 / 7236
right: 1844 / 7236
front: 2535 / 7175
behind: 4357 / 7175
above: 0 / 0
below: 0 / 0
acc_left: 0.5511332227750139
acc_r