# IoU Filter Setting

In [None]:
import ujson
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from visual_genome import api as vg
from PIL import Image as PIL_Image
import requests
from io import BytesIO

In [None]:
# Load the dataset. The context manager closes the file even if parsing fails,
# and the explicit encoding keeps the read independent of the platform locale.
with open("VQAmb_Dataset.json", encoding="utf-8") as file:
    data = ujson.load(file)

In [None]:
# IoU cut-offs used by coalesce():
#   iou >= THRESHOLD                  -> the two boxes are merged (attributes combined)
#   ELIM_THRESHOLD <= iou < THRESHOLD -> only the bigger of the two boxes is kept
THRESHOLD = 0.8
ELIM_THRESHOLD = 0.2

In [None]:
def iouFilter(o1, o2):
    """Return the intersection-over-union (IoU) of two axis-aligned boxes.

    Args:
        o1, o2: mappings with keys 'x' and 'y' (origin corner) and 'w' and 'h'
            (extent along each axis).

    Returns:
        Intersection area divided by union area. The result is 0.0 when the
        boxes do not overlap, and also when the union area is not positive
        (e.g. both boxes are degenerate), instead of raising
        ZeroDivisionError.

    The overlap is computed in closed form rather than by stepping through
    every unit cell of the smaller box. For integer coordinates the result is
    the same as the cell-by-cell count; it is additionally exact for
    fractional coordinates and takes constant time.
    """
    # Overlap extent on each axis: right-most start to left-most end.
    overlap_w = min(o1['x'] + o1['w'], o2['x'] + o2['w']) - max(o1['x'], o2['x'])
    overlap_h = min(o1['y'] + o1['h'], o2['y'] + o2['h']) - max(o1['y'], o2['y'])

    if overlap_w > 0 and overlap_h > 0:
        areaIntersect = overlap_w * overlap_h
    else:
        areaIntersect = 0  # boxes are disjoint or only touch along an edge

    areaUnion = o1['w'] * o1['h'] + o2['w'] * o2['h'] - areaIntersect
    if areaUnion <= 0:
        # Nothing to compare (e.g. two zero-area boxes): treat as no overlap.
        return 0.0
    return areaIntersect / areaUnion

In [None]:
# Display the dataset entry stored under key '1' as a quick look at the data.
data['1']

In [None]:
def coalesce(quest, imgID):
    """Coalesce overlapping objects of one question and report ambiguity.

    ``quest`` is mutated in place (``quest['all_objs']`` is edited directly):

    * Pairs with IoU >= THRESHOLD are merged: the first object receives the
      combined attributes (duplicates removed, first-seen order kept) and the
      coordinates of the bigger box, and the second object is removed.
      If the combined attributes already contain every entry of
      ``quest['all_ans']``, the first object of the pair is removed instead
      and the second one stays.
    * Pairs with ELIM_THRESHOLD <= IoU < THRESHOLD keep only the object with
      the bigger box.
    * If more than one object remains, objects without any answer attribute
      are dropped, ``quest['points']`` receives one centre point per kept
      object (labelled with that object's first matching attribute), and
      ``quest['all_ans']`` is replaced by the distinct matched answers.

    Args:
        quest: question dict with 'all_objs' (list of dicts with 'x', 'y',
            'w', 'h', 'attributes') and 'all_ans' (list of answers).
        imgID: image identifier; currently unused, kept for callers.

    Returns:
        True if more than one remaining object carries an answer attribute.
    """
    objs = quest['all_objs']
    ans = quest['all_ans']
    answers = set(ans)  # built once; used for subset and membership tests

    i = 0
    # Compare every pair (i, j) with j > i; the list shrinks as we go.
    while i < len(objs) - 1:
        j = i + 1
        while j < len(objs):
            overlap = iouFilter(objs[i], objs[j])
            if overlap >= THRESHOLD:
                # Boxes are the "same": combine attributes. dict.fromkeys
                # removes duplicates while keeping a deterministic order, so
                # the answer picked below does not vary between runs.
                combined = objs[i]['attributes'] + objs[j]['attributes']
                objs[i]['attributes'] = list(dict.fromkeys(combined))

                # One object holding every answer: drop it and restart the
                # scan from whatever element slides into position i.
                if answers.issubset(objs[i]['attributes']):
                    objs.pop(i)
                    i -= 1  # cancels the i += 1 after the inner loop
                    break

                # Adopt the coordinates of the bigger box.
                if objs[i]['w'] * objs[i]['h'] < objs[j]['w'] * objs[j]['h']:
                    objs[i]['w'] = objs[j]['w']
                    objs[i]['h'] = objs[j]['h']
                    objs[i]['x'] = objs[j]['x']
                    objs[i]['y'] = objs[j]['y']

                objs.pop(j)
                continue  # index j now refers to the next candidate
            if overlap >= ELIM_THRESHOLD:
                # Close but not similar enough to merge: keep the bigger box.
                if objs[i]['w'] * objs[i]['h'] < objs[j]['w'] * objs[j]['h']:
                    objs[i] = objs[j]
                objs.pop(j)
                continue  # index j now refers to the next candidate
            j += 1
        i += 1

    newAns = []
    if len(objs) > 1:
        points = []
        kept = []
        for obj in objs:
            matching = [attrb for attrb in obj['attributes'] if attrb in answers]
            if not matching:
                continue  # object carries no answer attribute: drop it
            label = matching[0]
            kept.append(obj)
            newAns.append(label)
            points.append({'x': int(obj['x'] + obj['w'] / 2),
                           'y': int(obj['y'] + obj['h'] / 2),
                           'ans': label})
        objs[:] = kept  # in place, so quest['all_objs'] sees the filtered list
        quest['points'] = points
        quest['all_ans'] = list(dict.fromkeys(newAns))

    # NOTE(review): this counts matching objects, not distinct answers, so two
    # objects sharing one answer still yield True - confirm that is intended.
    return len(newAns) > 1


In [None]:
# Run coalesce() on every question and keep, per image, only the questions it
# reports as ambiguous. Images left without any question are not recorded.
iou = {}
questionCountOrig = 0
questionCountNew = 0

for img, questions in data.items():
    confirmed = []
    for quest in questions:
        questionCountOrig += 1
        if coalesce(quest, img):
            confirmed.append(quest)
            questionCountNew += 1  # question confirmed
    if confirmed:
        iou[img] = confirmed

print(f"Original # of Images: {len(data)}")
print(f"New # of Images: {len(iou)}")
print(f"Original Question Count: {questionCountOrig}")
print(f"New Question Count: {questionCountNew}")

In [None]:
def visualize_points(imgID, objs, points):
    """Download an image and draw answer points and bounding boxes on it.

    Args:
        imgID: identifier passed to ``vg.get_image_data`` to look up the
            image; the image is then fetched from its ``url``.
        objs: iterable of dicts with 'x', 'y', 'w', 'h'; each is drawn as a
            red rectangle outline.
        points: iterable of dicts with 'x', 'y', 'ans'; each is drawn as a red
            dot with the 'ans' text next to it.

    Raises:
        requests.HTTPError: if the image download returns an error status.
    """
    image = vg.get_image_data(id=imgID)
    response = requests.get(image.url)
    # Fail with a clear HTTP error instead of an image-decoding error later.
    response.raise_for_status()
    img = PIL_Image.open(BytesIO(response.content))
    plt.imshow(img)
    fig = plt.gcf()
    fig.set_size_inches(18.5, 10.5)
    ax = plt.gca()
    for point in points:
        ax.plot(point['x'], point['y'], 'ro')
        ax.text(point['x'], point['y'], point['ans'], style='italic',
                bbox={'facecolor': 'red', 'alpha': .5, 'pad': 10})
    for o in objs:
        ax.add_patch(Rectangle((o['x'], o['y']),
                               o['w'],
                               o['h'],
                               fill=False,
                               edgecolor='red',
                               linewidth=3))
    # tick_params expects booleans; the legacy 'off' string (deprecated in
    # Matplotlib 2.2) is just a truthy value and leaves the labels visible.
    plt.tick_params(labelbottom=False, labelleft=False)
    plt.show()

In [None]:
import random

In [None]:
# Pick a random 10% of the kept image keys (without replacement) to spot-check.
sample_size = int(.1 * len(iou))
sample = random.sample(list(iou), sample_size)

In [None]:
# For each sampled image, print every kept question with its object count and
# plot the answer points and boxes.
for img in sample:
    for quest in iou[img]:
        answer_points, question_text, boxes = (
            quest['points'],
            quest['question'],
            quest['all_objs'],
        )
        print(question_text)
        print(len(boxes))
        visualize_points(img, boxes, answer_points)

In [None]:
# Disabled earlier version of visualize_regions, kept as an inert string
# literal. It read regions through attributes (.x, .y, .width, .height,
# .phrase) instead of dictionary keys.
"""
def visualize_regions(image, regions):
    response = requests.get(image.url)
    img = PIL_Image.open(BytesIO(response.content))
    plt.imshow(img)
    ax = plt.gca()
    for region in regions:
        ax.add_patch(Rectangle((region.x, region.y),
                               region.width,
                               region.height,
                               fill=False,
                               edgecolor='red',
                               linewidth=3))
        ax.text(region.x, region.y, region.phrase, style='italic', bbox={'facecolor':'white', 'alpha':0.7, 'pad':10})
    fig = plt.gcf()
    fig.set_size_inches(18.5, 10.5)
    plt.tick_params(labelbottom='off', labelleft='off')
    plt.show()
"""
def visualize_regions(image, regions):
    """Download an image and outline the given regions on it.

    Args:
        image: object exposing a ``url`` attribute that points at the image.
        regions: iterable of dicts with 'x', 'y', 'w', 'h'; each is drawn as a
            red rectangle outline.

    Raises:
        requests.HTTPError: if the image download returns an error status.
    """
    response = requests.get(image.url)
    # Fail with a clear HTTP error instead of an image-decoding error later.
    response.raise_for_status()
    img = PIL_Image.open(BytesIO(response.content))
    plt.imshow(img)
    ax = plt.gca()
    for region in regions:
        ax.add_patch(Rectangle((region['x'], region['y']),
                               region['w'],
                               region['h'],
                               fill=False,
                               edgecolor='red',
                               linewidth=3))
    fig = plt.gcf()
    fig.set_size_inches(18.5, 10.5)
    # tick_params expects booleans; the legacy 'off' string (deprecated in
    # Matplotlib 2.2) is just a truthy value and leaves the labels visible.
    plt.tick_params(labelbottom=False, labelleft=False)
    plt.show()

In [None]:
# NOTE(review): this rebinds `iou` to an empty dict. If this cell runs after
# the filtering cell and before the cell that dumps `iou` to JSON, the dump
# writes an empty object - confirm whether this cell should still be executed.
# The quoted block below is an inert string literal: a trial loop limited to
# the first 10 images that calls coalesce() with a single argument, whereas
# the coalesce() defined in this notebook takes two.
iou = {} # revised dictionary
# iterate questions

"""
count = 0
for img in data:
    iou[img] = []
    if count < 10:
        for quest in data[img]:
            print(quest)
            iou[img] += [quest]
            lastIndex = len(iou[img]) - 1
            if coalesce(iou[img][lastIndex]) == False:
                iou[img].pop(lastIndex)
        if len(iou[img]) == 0:
            iou.pop(img)
    else:
        break
    count += 1
print(iou)
"""

In [None]:
# Persist the filtered questions (the `iou` mapping) as a new JSON file.
output_path = "VQAmb_IoU_5.json"
with open(output_path, "w") as handle:
    ujson.dump(iou, handle)

"""
Section End
"""

In [None]:
# Print every image key of the loaded dataset, one per line.
for image_key in data:
    print(image_key)

References
- https://www.w3schools.com/python/python_casting.asp
- https://docs.python.org/2/library/sets.html
- https://matplotlib.org/3.1.0/tutorials/introductory/pyplot.html
