### !!!Do not forget to change the kernel to PHRILO 3.5!!!

In [1]:

%matplotlib inline
import os
import sys

sys.path.append('..')

import numpy as np
import pickle
import skimage.io as io
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

from IPython.display import display, clear_output
import ipywidgets

import os, h5py, json, argparse
from os.path import isfile, join
from os import listdir

from collections import defaultdict


%reload_ext autoreload
%autoreload 2

In [2]:
class Phrase(object):
    """A contiguous run of caption tokens, optionally tied to an annotated entity.

    Attributes:
        tokens: list of word tokens in this span.
        entity: entity id (parse_sentence stores it as an int) or the string
            'N/A' when the span is plain text.
        etype:  entity type string taken from the annotation tag, or 'N/A'.
    """
    def __init__(self, tokens=None, entity='N/A', etype='N/A'):
        # Build a new list per instance: a literal `[]` default would be one
        # list object shared by every Phrase created without `tokens`.
        self.tokens = [] if tokens is None else tokens
        self.entity = entity
        self.etype = etype
def get_sentence(sentence):
    """Flatten a parsed sentence (list of phrases) back into plain text.

    Each phrase's tokens are space-joined, and the phrase strings are then
    space-joined in order.
    """
    phrase_texts = (" ".join(p.tokens) for p in sentence)
    return " ".join(phrase_texts)
def get_entities(sentence):
    """Return the entity ids of the phrases that have one, in sentence order.

    Phrases whose entity is the placeholder 'N/A' are skipped.
    """
    entities = []
    for phrase in sentence:
        if phrase.entity == 'N/A':
            continue
        entities.append(phrase.entity)
    return entities

In [3]:
def parse_sentence(tokens):
    """Split one annotated caption string into a list of Phrase objects.

    Entity mentions are written as ``[/EN#<id>/<type> word ... word]``. The
    opening token carries the entity id and the type, and the last word of the
    mention ends with ``]``. Text outside brackets becomes phrases with the
    default entity/type 'N/A'.

    Args:
        tokens: the raw annotated caption as a single string.

    Returns:
        list of Phrase in sentence order. Phrases without tokens are not
        emitted, so a blank caption yields [].

    Raises:
        NotImplementedError: on a '[' while a bracket is already open, or a
            ']' with no open bracket (an error message is printed first).
    """
    phrase = Phrase([])
    sentence = []
    open_bracket = False
    # split() without an argument drops the empty strings that split(' ')
    # produced for blank lines or repeated spaces (they crashed token[0]).
    for token in tokens.split():
        if token[0] == '[':
            if open_bracket:
                print("ERROR! open bracket seen multiple times")
                raise NotImplementedError()
            # Close the plain-text phrase collected so far, if any.
            if len(phrase.tokens) > 0:
                sentence.append(phrase)
                phrase = Phrase([])
            # e.g. '[/EN#18/people' -> fields ['[', 'EN#18', 'people']
            fields = token.split('/')
            phrase.entity = int(fields[1].split('#')[1])
            phrase.etype = fields[2]
            open_bracket = True
        elif token[-1] == ']':
            if not open_bracket:
                print("ERROR! closed bracket without opening one")
                raise NotImplementedError()
            phrase.tokens.append(token[:-1])
            sentence.append(phrase)
            open_bracket = False
            phrase = Phrase([])
        else:
            phrase.tokens.append(token)
    # After a closing ']' the pending phrase is empty; emitting it would only
    # add a token-less 'N/A' phrase (and a trailing space in get_sentence).
    if phrase.tokens:
        sentence.append(phrase)
    return sentence

In [4]:
def read_sentences(path = '../data/flickr30k-entities/Sentences'):
    """Parse every per-image caption file in ``path``.

    Each ``<image_id>.txt`` file holds one annotated caption per line. Every
    non-blank line is parsed with ``parse_sentence``.

    Args:
        path: directory containing the caption files.

    Returns:
        defaultdict(list) mapping the integer image id (file name without
        ``.txt``) to its parsed sentences, in file order.
    """
    idx2sentences = defaultdict(list)
    for name in listdir(path):
        fpath = join(path, name)
        # Skip sub-directories and non-caption files (e.g. hidden files such
        # as .DS_Store), whose names would make int() below fail.
        if not name.endswith('.txt') or not isfile(fpath):
            continue
        idx = int(os.path.splitext(name)[0])  # loop-invariant per file
        with open(fpath, encoding='utf-8') as fh:
            for line in fh:
                l = line.strip()
                if not l:
                    # Blank lines carry no caption.
                    continue
                idx2sentences[idx].append(parse_sentence(l))
    return idx2sentences

In [5]:
# Load the preprocessed dataset and the parsed captions.
# np.load refuses to unpickle object arrays unless allow_pickle=True (the
# default became False in numpy 1.16.3). Unpickling can run arbitrary code,
# so only load trusted local files this way.
# NOTE(review): `[()]` presumably unwraps an object saved as a 0-d object
# array (e.g. a list of instance dicts) — confirm against how the file is saved.
dataset = np.load('../data/flickr30k.imdb.npy', allow_pickle=True)[()]

idx2sentences = read_sentences()

In [6]:
# Two-way lookup between dataset positions and numeric image ids (the .jpg
# file name, e.g. '.../123.jpg' -> 123). The image id is the key used for
# idx2sentences.
fname2idx = {}
idx2fname = {}
for position, record in enumerate(dataset):
    image_name = record['im_path'].rsplit('/', 1)[-1]
    image_id = int(image_name.replace('.jpg', ''))
    idx2fname[position] = image_id
    fname2idx[image_id] = position

In [7]:
def entity2boxid(regions):
    """Build entity<->box index maps from an instance's regions.

    Args:
        regions: sequence of (bbox, ann_ids, sentences) triples; each entry
            of ann_ids is converted with int() to an entity id.

    Returns:
        (e2box, box2e): defaultdict(set) mapping entity id -> box indices, and
        box index -> entity ids. Boxes with no ann_ids do not appear in box2e.
    """
    e2box, box2e = defaultdict(set), defaultdict(set)
    for box_idx, (_bbox, ann_ids, _sentences) in enumerate(regions):
        for entity_id in map(int, ann_ids):
            e2box[entity_id].add(box_idx)
            box2e[box_idx].add(entity_id)
    return e2box, box2e

In [8]:
def drawTree(sent,name):
    """Render a bracketed tree string to an image file.

    The tree is drawn on a CanvasFrame, printed to ``<name>.ps``, and then
    converted to ``name`` with the external ``convert`` command.

    NOTE(review): CanvasFrame, Tree and TreeWidget are not imported anywhere in
    this notebook. Presumably they come from nltk (nltk.draw.util.CanvasFrame,
    nltk.Tree, nltk.draw.TreeWidget); import them before calling this.

    Args:
        sent: tree string accepted by ``Tree.fromstring``.
        name: output image path passed to ``convert``.
    """
    import shlex

    cf = CanvasFrame()
    try:
        t = Tree.fromstring(sent)
        tc = TreeWidget(cf.canvas(), t)
        tc['node_font'] = 'ubuntu 18 bold'
        tc['leaf_font'] = 'ubuntu 16 bold'
        # Grayscale palette. Earlier colored alternative:
        # node '#007996', leaf '#668D3C', line '#816C5B'.
        tc['node_color'] = '#636363'
        tc['leaf_color'] = '#000000'
        tc['line_color'] = '#000000'

        cf.add_widget(tc, 0, 0)  # place the tree at the canvas origin
        ps_path = '{}.ps'.format(name)
        cf.print_to_file(ps_path)
        # Quote both paths so names with spaces or shell metacharacters reach
        # `convert` verbatim instead of being interpreted by the shell.
        os.system('convert {} {}'.format(shlex.quote(ps_path), shlex.quote(name)))
    finally:
        # Release the canvas frame even if parsing or rendering fails.
        cf.destroy()

def drawTrees(fname):
    """Draw two trees derived from the first line of ``fname``.

    '@' characters are stripped from the line before parsing. The sub-tree at
    children[0].children[1].children[1] is drawn to 'tree.png'. A second,
    freshly parsed copy is cleaned/transformed (preClean, adopt x2,
    preTriplet), and the raw tree returned by findTriplet is drawn to
    'ctree.png'.

    NOTE(review): `T` (with from_sexpr/preClean/adopt/preTriplet/findTriplet)
    and `blacklist` are not defined in this notebook; presumably they come from
    a project module — TODO confirm.
    """
    with open(fname) as fh:
        lines = [line.strip().replace("@", "") for line in fh]
    first = lines[0]

    tree = T.from_sexpr(first)
    drawTree(tree.children[0].children[1].children[1].getRaw(), 'tree.png')

    # Parse again: the transformations below mutate the tree in place.
    t = T.from_sexpr(first)
    t.preClean()
    t.adopt()
    t.adopt()
    t.preTriplet(blacklist)
    to3 = t.children[0] if len(t.children) == 1 else t
    _triplets, ct_raw = to3.findTriplet(blacklist)
    drawTree(ct_raw, 'ctree.png')
    
def plotInstance(instance, gid = None, pid = 0, fname = 'prediction.png', add_text = True,
                idx_set = None):
    """Draw an instance's image with its region boxes and save it to ``fname``.

    Box edge colors: '#FF9642' by default, '#C0362C' for the box equal to
    ``pid``, '#007996' for boxes in ``gid``, and '#668D3C' for a box that is
    both ``pid`` and in ``gid``.

    Args:
        instance: dict with 'im_path' (image file read by io.imread) and
            'regions' (sequence of (bbox, ann_ids, sentences) triples; bbox is
            read as [x1, y1, x2, y2]).
        gid: iterable of box indices to color as gold; None means none.
        pid: box index to color as the prediction; a non-int (e.g. []) matches
            no box.
        fname: output image path.
        add_text: if True, label each drawn box with its index.
        idx_set: if non-empty, only boxes whose index is in it are drawn;
            empty or None draws every box.

    Returns:
        The shape of the loaded image array.
    """
    gold_ids = set(gid) if gid is not None else set()  # built once, not per box

    # TODO (original note): fix the path to the image with .replace() if needed.
    im_path = instance['im_path']
    regions = instance['regions']

    fig, ax = plt.subplots()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    image = io.imread(im_path)
    imshape = image.shape
    ax.imshow(image)

    lwidth = 2
    for i, region in enumerate(regions):
        # An empty idx_set means "no filter": draw every box.
        if idx_set and i not in idx_set:
            continue
        # Same (bbox, ann_ids, sentences) order as entity2boxid; only bbox is used.
        bbox, _ann_ids, _sentences = region

        is_pred = i == pid
        is_gold = i in gold_ids
        if is_pred and is_gold:
            pcolor = '#668D3C'
        elif is_gold:
            pcolor = '#007996'
        elif is_pred:
            pcolor = '#C0362C'
        else:
            pcolor = '#FF9642'

        # Rectangle takes the top-left corner plus width/height.
        box_plot = Rectangle((bbox[0], bbox[1]), bbox[2] - bbox[0], bbox[3] - bbox[1],
                             fill=False, edgecolor=pcolor, linewidth=lwidth)
        ax.add_patch(box_plot)
        if add_text:
            ax.text(bbox[0] - 25, bbox[1] - 15, '{}'.format(i),
                    bbox=dict(facecolor='#ad4141', alpha=0.5), fontsize=18, color='white')
    fig.tight_layout()
    fig.savefig(fname)
    # Close this specific figure so repeated redraws do not accumulate figures.
    plt.close(fig)

    return imshape

In [10]:
# Row holding an (initially empty) text label.
input_nl = ipywidgets.Label(value="")
NL = ipywidgets.HBox([input_nl])

# Instance navigation: step through dataset entries.
ibutton_prev = ipywidgets.Button(description="Previous")
ibutton_next = ipywidgets.Button(description="Next")
islider = ipywidgets.IntSlider(description='Instance#', max=len(dataset) - 1, value=0)

# Sentence navigation within the current instance (indices 0..4, presumably
# five captions per image).
sbutton_prev = ipywidgets.Button(description="Previous Sent")
sbutton_next = ipywidgets.Button(description="Next Sent")
sslider = ipywidgets.IntSlider(description='Sentence#', max=4, value=0)

# When checked, on_button_clicked limits drawing to boxes linked to entities
# of the current sentence.
cb = ipywidgets.Checkbox(value=False, description='Only Annotated Boxes', disabled=False)

# Placeholders: on_button_clicked refers to these layout boxes, which are only
# built after its first call.
ihbox, shbox, tbox, inputs = None, None, None, None

# Navigation callbacks. islider/sslider have value observers that redraw on
# every change, so these handlers only change slider values. The explicit
# on_button_clicked(None) calls they used to make re-rendered the same view
# one or two extra times per click.

def on_nexti(b):
    """'Next' button: go to the next instance, starting at its first sentence."""
    # Change the instance first: handle_islider_change resets the sentence
    # slider and redraws exactly once.
    islider.value += 1
    # At the last instance the value is clamped and unchanged (no event);
    # still reset the sentence, whose observer redraws if it changed.
    sslider.value = 0

def on_previousi(b):
    """'Previous' button: go to the previous instance, starting at its first sentence."""
    islider.value -= 1
    sslider.value = 0

def on_nexts(b):
    """'Next Sent' button: show the next sentence of the current instance."""
    sslider.value += 1

def on_previouss(b):
    """'Previous Sent' button: show the previous sentence of the current instance."""
    sslider.value -= 1

def handle_islider_change(change):
    """Instance slider moved: show the first sentence of the new instance."""
    if sslider.value != 0:
        # The sentence slider's observer performs the redraw.
        sslider.value = 0
    else:
        on_button_clicked(None)

def handle_sslider_change(change):
    """Sentence slider moved: redraw."""
    on_button_clicked(None)

def _annotate_sentence(sentence, e2box):
    """Format a parsed sentence, tagging entity phrases with their box ids.

    Returns (texts, box_ids). texts has one string per phrase, with phrases
    whose entity has boxes rendered as '[words](i j ...)'. box_ids is the union
    of those boxes.
    """
    texts, box_ids = [], set()
    for phrase in sentence:
        text = " ".join(phrase.tokens)
        if phrase.entity != 'N/A' and phrase.entity in e2box:
            boxes = e2box[phrase.entity]
            ids = " ".join(str(b) for b in sorted(boxes))
            text = '[' + text + '](' + ids + ')'
            box_ids |= boxes
        texts.append(text)
    return texts, box_ids

def on_button_clicked(change):
    """Redraw the instance/sentence currently selected by the sliders.

    Renders the instance to 'prediction.png' (with box ids) and
    'prediction.noboxid.png' (without). Shows the first in an Image widget
    that is created on the first call and stored as on_button_clicked.ANNS,
    re-displays the layout, and prints the sentence with entity phrases
    tagged by their box ids. ``change`` is ignored.
    """
    instance_idx = islider.value
    sentence_idx = sslider.value
    instance = dataset[instance_idx]

    e2box, _ = entity2boxid(instance['regions'])

    # .get avoids inserting an empty entry into the idx2sentences defaultdict.
    sentences = idx2sentences.get(idx2fname[instance_idx], [])
    if sentence_idx < len(sentences):
        token, all_idx = _annotate_sentence(sentences[sentence_idx], e2box)
    else:
        # Fewer sentences than the slider allows: draw the image, report it.
        token = ['(no sentence #{} for this image)'.format(sentence_idx)]
        all_idx = set()

    # An empty idx_set makes plotInstance draw every box.
    idx_set = all_idx if cb.value else set()
    plotInstance(instance, gid=[], pid=[], idx_set=idx_set)
    plotInstance(instance, gid=[], pid=[], idx_set=idx_set,
                 add_text=False, fname='prediction.noboxid.png')

    with open("prediction.png", "rb") as tf:
        anns = tf.read()

    if not on_button_clicked.image_on:
        on_button_clicked.ANNS = ipywidgets.Image(value=anns, format='png')
        on_button_clicked.image_on = True
    else:
        on_button_clicked.ANNS.value = anns

    clear_output(wait=True)
    # On the very first call the layout boxes are still None placeholders;
    # displaying them would just print 'None'.
    if ihbox is not None:
        display(ihbox, shbox, tbox, inputs)
    print(" ".join(token))


# Redraw only when the checkbox value changes. Without names=..., observe()
# calls the handler for every trait change on the widget, not just `value`.
cb.observe(on_button_clicked, names='value')

# on_button_clicked creates its Image widget on the first call.
on_button_clicked.image_on = False

islider.observe(handle_islider_change, names='value')
sslider.observe(handle_sslider_change, names='value')

ibutton_next.on_click(on_nexti)
ibutton_prev.on_click(on_previousi)
sbutton_next.on_click(on_nexts)
sbutton_prev.on_click(on_previouss)

# First render; this also creates on_button_clicked.ANNS used in tbox below.
on_button_clicked(None)

inputs = ipywidgets.VBox([NL])
ihbox = ipywidgets.HBox([ibutton_prev, ibutton_next, islider, cb])
shbox = ipywidgets.HBox([sbutton_prev, sbutton_next, sslider])
tbox = ipywidgets.HBox([on_button_clicked.ANNS])

display(ihbox, shbox, tbox, inputs)


HBox(children=(Button(description='Previous', style=ButtonStyle()), Button(description='Next', style=ButtonSty…

HBox(children=(Button(description='Previous Sent', style=ButtonStyle()), Button(description='Next Sent', style…

HBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\xb0\x00\x00\x01 \x08\x06\x00\x00\…

VBox(children=(HBox(children=(Label(value=''),)),))

[Two pudgy](0 1) , [middle-aged women](0 1) stand in [a parking lot](3) next to [an empty shopping cart](2) .
