In [7]:
import warnings
warnings.filterwarnings('ignore','.*conversion.*')

import os
import h5py
from config import opt
import main
import ipywidgets as widgets
from IPython.display import display
from IPython.display import clear_output
from ipywidgets import (interact, interact_manual, interactive, fixed)
from evaluation import load_result
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from pylab import *

# Choose the dataset classes matching the on-disk format selected in config:
# one `<dataset>.hdf5` file per dataset, or individual JPEG files.
if opt.hdf5:
    from datasets import Train_Dataset_HDF5 as Train_Dataset
    from datasets import Test_Dataset_HDF5 as Test_Dataset
else:
    # The alias must match the HDF5 branch so `Train_Dataset` is defined in
    # both modes (it was previously imported as `Train_D` here).
    from datasets import Train_Dataset_IMAGE as Train_Dataset
    from datasets import Test_Dataset_IMAGE as Test_Dataset
    
# Controls for the train/test launcher. The button handlers copy their current
# values into the global `opt` config when a button is clicked.
dataset_name = widgets.Dropdown(
    description='Dataset:',
    options=['Market1501', 'DukeMTMC'],
    value='Market1501',
    disabled=False,
)
model = widgets.Dropdown(
    description='Model:',
    options=['ResNet50', 'DenseNet121'],
    value='ResNet50',
    disabled=False,
)

re_ranking = widgets.Checkbox(
    description='Re_Ranking:',
    value=True,
    disabled=False,
)
load_features = widgets.Checkbox(
    description='Load Extracted Features:',
    value=True,
    disabled=False,
)

# The two action buttons, laid out side by side in a Box.
train_button = widgets.Button(description="Train")
test_button = widgets.Button(description='Test')
items = [train_button, test_button]
button = widgets.Box(items)

# Shared output area that the button handlers write into.
out = widgets.Output()
def click_train(b):
    """Handle the Train button.

    Copies the selected dataset and model into the global `opt` config, then
    runs `main.train()`. All output goes to the shared `out` widget, which is
    cleared first. `b` is the clicked button passed by `on_click` (unused).
    """
    with out:
        clear_output()
        opt.dataset_name, opt.model = dataset_name.value, model.value
        main.train()
def click_test(b):
    """Handle the Test button.

    Copies the dataset, model, re-ranking and feature-loading choices into the
    global `opt` config, then runs `main.test()`. All output goes to the shared
    `out` widget, which is cleared first. `b` is the clicked button passed by
    `on_click` (unused).

    NOTE(review): the widgets are looked up by their module-level names at
    click time, so rebinding any of those names in a later cell changes which
    widget this handler reads.
    """
    with out:
        clear_output()
        settings = {
            'dataset_name': dataset_name.value,
            'model': model.value,
            're_ranking': re_ranking.value,
            'load_features': load_features.value,
        }
        for key, value in settings.items():
            setattr(opt, key, value)
        main.test()
# Attach each button to its handler, then render the launcher as the cell output.
for btn, handler in ((train_button, click_train), (test_button, click_test)):
    btn.on_click(handler)

widgets.VBox([dataset_name, model, re_ranking, load_features, button, out])

In [79]:
# ---- Query browser: pick a query image and inspect its top-R gallery matches ----
result, CMC, mAP = load_result()
query_name = result['query_name']


def _person_id_of(name):
    """Return the person-ID label of an image name: the text before the first '_'."""
    return name.split('_')[0]


def _names_of(pid):
    """Return every query image name whose person-ID label equals `pid`."""
    return [name for name in query_name if _person_id_of(name) == pid]


# Distinct person IDs in first-seen order (dict keys preserve insertion order).
query_ids = list(dict.fromkeys(_person_id_of(name) for name in query_name))
if not query_ids:
    raise ValueError('load_result() returned no query images to browse')

person_id = widgets.Dropdown(options=query_ids,
                             value=query_ids[0],
                             description='Person ID')

person_name = widgets.Dropdown(options=_names_of(person_id.value),
                               value=_names_of(person_id.value)[0],
                               description='Person Name')

# IntSlider bounds are inclusive, so the last valid query position is len - 1.
index = widgets.IntSlider(value=0,
                          min=0,
                          max=len(query_name) - 1,
                          description='Index')

R = widgets.IntSlider(value=5, min=0, max=20)
size = widgets.IntSlider(value=20, min=10, max=40)
# Deliberately NOT named `re_ranking`: rebinding that global would silently
# change which checkbox the launcher cell's Test button handler reads.
show_re_ranking = widgets.Checkbox(value=True, description='Show Re_Ranking:')


# Keep the ID / name / index selectors in sync. These observers trigger one
# another; the cascade relies on traitlets only notifying on a real change.
def update_name(*args):
    """Restrict the name dropdown to images of the selected person ID."""
    person_name.options = _names_of(person_id.value)

def update_index_from_id(*args):
    """Move the slider to the first query image of the selected person ID."""
    index.value = query_name.index(_names_of(person_id.value)[0])

def update_index_from_name(*args):
    """Move the slider to the selected image name."""
    index.value = query_name.index(person_name.value)

def updata_id_from_index(*args):
    """Select the person ID of the image at the slider position."""
    person_id.value = _person_id_of(query_name[index.value])

def updata_name_from_index(*args):
    """Select the image name at the slider position."""
    person_name.value = query_name[index.value]

person_id.observe(update_name, 'value')
person_id.observe(update_index_from_id, 'value')
person_name.observe(update_index_from_name, 'value')
index.observe(updata_id_from_index, 'value')
index.observe(updata_name_from_index, 'value')


def _read_image(hdf5_file, dataset_dir, split, name):
    """Load image `name` of `split` ('query' or 'gallery') as an array.

    Reads from the open HDF5 file when one is given; otherwise reads the JPEG
    from the dataset's 'query' or 'bounding_box_test' folder.
    """
    if hdf5_file is not None:
        # `[()]` copies the dataset into memory so it outlives the open file.
        return hdf5_file[split][name]['img'][()]
    folder = 'query' if split == 'query' else 'bounding_box_test'
    return mpimg.imread(os.path.join(dataset_dir, folder, name + '.jpg'))


def _outline_match(ax):
    """Frame `ax` in red, drawn outside its clip box, to flag a same-ID match."""
    from matplotlib.patches import Rectangle
    x0, x1, y0, y1 = ax.axis()
    frame = Rectangle((x0 - 0.7, y0 - 0.2), (x1 - x0) + 1, (y1 - y0) + 0.4,
                      fill=False, lw=2, color='red')
    ax.add_patch(frame).set_clip_on(False)


def result_demo(person_id, person_name, index, R, size, re_ranking):
    """Plot query image `index` next to its top-`R` ranked gallery images.

    When `re_ranking` is true, two rows are drawn. The top row ("Original")
    uses the result loaded with `opt.re_ranking = False`; the bottom row
    ("Re-Ranked") uses the result loaded with `opt.re_ranking = True`.
    Otherwise a single row is drawn from `load_result()` under the current
    `opt.re_ranking`. Gallery images whose person ID equals the query's are
    framed in red.

    `person_id` and `person_name` are accepted only so `interact_manual` shows
    their selectors; the query is chosen by `index` alone. `size` is the figure
    width passed to `figsize`.
    """
    # NOTE(review): load_result() presumably picks its file from
    # opt.re_ranking. The global is restored afterwards so that browsing does
    # not permanently change the config.
    saved_re_ranking = opt.re_ranking
    try:
        if re_ranking:
            rows = []
            for flag in (False, True):
                opt.re_ranking = flag
                rows.append(load_result()[0])
            row_titles = ['Original Images', 'Re-Ranked Images']
        else:
            rows = [load_result()[0]]
            row_titles = ['Gallery Images']
    finally:
        opt.re_ranking = saved_re_ranking

    query = rows[0]['query_name'][index]
    query_label = _person_id_of(query)
    dataset_dir = os.path.join(opt.data_dir, opt.dataset_name)
    two_rows = len(rows) == 2

    # squeeze=False keeps `axes` 2-D even for a single row or R == 0.
    _, axes = plt.subplots(len(rows), R + 1,
                           figsize=(size, size / 1.5) if two_rows else (size, size),
                           squeeze=False)

    hdf5_file = None
    if opt.hdf5:
        hdf5_path = os.path.join(opt.data_dir, opt.dataset_name, opt.dataset_name + '.hdf5')
        hdf5_file = h5py.File(hdf5_path, 'r')
    try:
        query_img = _read_image(hdf5_file, dataset_dir, 'query', query)
        axes[0, 0].imshow(query_img)
        axes[0, 0].set_title('Query Image \n ID:' + query_label)
        if two_rows:
            # White placeholder under the query so both rows line up.
            blank = np.full(query_img.shape[:2], 255)
            axes[1, 0].imshow(blank, cmap='gray', vmin=0, vmax=255)

        for row, (res, title) in enumerate(zip(rows, row_titles)):
            # Pair each ranking with the gallery names from the same result.
            gallery_name = res['gallery_name']
            ranking = res['ranking']
            for rank in range(1, R + 1):
                ax = axes[row, rank]
                match = gallery_name[int(ranking[index][rank - 1])]
                label = _person_id_of(match)
                ax.imshow(_read_image(hdf5_file, dataset_dir, 'gallery', match))
                ax.set_title((title + ' \n ID:' if rank == 1 else 'ID:') + label)
                if label == query_label:
                    _outline_match(ax)
    finally:
        if hdf5_file is not None:
            hdf5_file.close()

    for ax in axes.flat:
        ax.axis('off')
    if two_rows:
        plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()


# The function parameter stays `re_ranking`; it is fed by the separately named
# `show_re_ranking` checkbox.
demo = interact_manual(result_demo,
                       person_id=person_id,
                       person_name=person_name,
                       index=index,
                       R=R,
                       size=size,
                       re_ranking=show_re_ranking)