# Imports and Hyper-params

In [None]:
# Project root and output directory.
# Raw strings keep the Windows backslashes literal; without the r-prefix,
# sequences such as "\D", "\R" and "\o" are invalid escape sequences that
# Python warns about.
path = r'D:\DLProject\RRNProject'
output_path = r'D:\DLProject\RRNProject\output'

%run tools/imports.py
%run -i tools/functions.py
%run -i tools/models.py
%run -i tools/phrasecut.py

# Create the output directory (and any missing parents). exist_ok=True makes
# this a no-op when the directory is already there, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(output_path, exist_ok=True)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Option container with attribute-style access.
# NOTE(review): dotdict is not defined in this notebook; presumably it comes
# from one of the tools/*.py scripts executed in the first cell.
opt = dotdict({})

## General
opt.dataset = 'phrasecut'    # 'phrasecut'
opt.split = 'test'           # 'train' 'test'
opt.test_iter = 1500         # num of test iterations to run
opt.test_log_every = 50      # num of iterations to log test info
opt.save_im_every = 500      # num of iterations to save mask output
# Path to the .pth checkpoint that is loaded for testing. Raw string so the
# backslashes ("\D", "\R", "\o", "\c") are not parsed as escape sequences.
opt.checkpoint = r'D:\DLProject\RRNProject\output\checkpoint_100000.pth'

## Hyperparams
opt.phrasecut_categories = ['c_coco'] # filter categories
opt.new_img_proc = True               # replace old image proc with new
opt.dcrf = True                       # DCRF post-processing
opt.im_h = 320                        # height the input image is resized to
opt.im_w = 320                        # width the input image is resized to
opt.vf_h = 40
opt.vf_w = 40

# Testing

In [None]:
# Test initialisation: metric logs, checkpoint/model restore, data and vocab
# loading, running accumulators and the image normalisation transform.

# Init
# Metric logs keyed by iteration; the {'iter': 'val'} entry ends up as a
# header-like first item when opt is dumped to test.log.
opt.test_acc = {'iter': 'val'}
opt.test_acc_pos = {'iter': 'val'}
opt.test_acc_neg = {'iter': 'val'}
opt.test_miou = {'iter': 'val'}
opt.test_overall_iou = {'iter': 'val'}
opt.test_miou_dcrf = {'iter': 'val'}
opt.test_overall_iou_dcrf = {'iter': 'val'}
opt.vocab_size = 8407

# Resolve the device first so the checkpoint can be mapped onto it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load pre-trained cfg
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
checkpoint = torch.load(opt.checkpoint, map_location=device)
train_opt = dotdict(checkpoint['opt'])

# Initialize Model (built from the options stored with the checkpoint)
model = Model(train_opt)
model.load_state_dict(checkpoint['model'])

# Move the model to the selected device.
# The previous guard `if device == 'cuda'` compared a torch.device with a str,
# which is never equal, so the model was never moved. Calling .to(device)
# unconditionally is a no-op when the model already lives on that device.
model.to(device)

# Load Data
refvg_loader = RefVGLoader(split=opt.split)
img_ref_data = refvg_loader.get_img_ref_data()
# Starts at -2 so that the first `task_i += 1` in the test loop yields -1,
# the value the loop uses to trigger the initial image load.
task_i = -2
print('Loaded phrasecut: %s images, %s tasks' % (len(refvg_loader.img_ids), refvg_loader.task_num))

# Init Vocab
with open(str(dataset_dir) + '/name_att_rel_count.json', 'r') as file:
    data = json.load(file)
corpus = Corpus()
corpus.split_and_add_words_to_vocab_from_data(data)

# Reusable single-image batch buffer: (1, im_h, im_w, 3) float32
image_batch = np.zeros((1, opt.im_h, opt.im_w, 3), dtype=np.float32)

# Acc and loss initialize
running_acc, running_acc_pos, running_acc_neg, running_miou, running_overall_iou, I, U = 0, 0, 0, 0, 0, 0, 0
running_miou_dcrf, running_overall_iou_dcrf, I_dcrf, U_dcrf, down_I, down_U = 0, 0, 0, 0, 0, 0
bw, n_batch = 0, 0   # bw: number of skipped grayscale images; n_batch: slot in image_batch
iou_metrics = [.5, .6, .7, .8, .9]   # IoU thresholds for the precision counters
iou_precision = np.zeros(len(iou_metrics), dtype=np.int32)
dcrf_iou_precision = np.zeros(len(iou_metrics), dtype=np.int32)


# im preprocess
if opt.new_img_proc:
    # per-channel mean/std normalisation applied to [0, 1] scaled images
    preprocess = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
else:
    # per-channel mean subtracted from the 0-255 image in the old pipeline
    mu = np.array((104.00698793, 116.66876762, 122.67891434))

# init test
model.eval()

# Main evaluation loop: one (image, phrase) task is scored per iteration.
# The loop variable is called `step` so the builtin `iter` is not shadowed.
for step in tq.tqdm(range(opt.test_iter)):

    ############## Load Data ##############

    # Read next task: advance until a task is found whose subsets contain
    # every requested category. A new image is loaded when the current one has
    # no tasks left (task_i == -1 only occurs on the very first pass).
    while True:
        task_i += 1
        if (task_i >= len(img_ref_data['task_ids'])) or (task_i == -1):
            img_ref_data = refvg_loader.get_img_ref_data()          # load img
            img_id = img_ref_data['image_id']
            img_p = str(img_fpath) + '/' + str(img_id) + '.jpg'     # (original shape) image
            img = Image.open(img_p)
            # PIL's resize expects (width, height); the resulting array is
            # (im_h, im_w, C), matching image_batch.
            img = img.resize((opt.im_w, opt.im_h))
            image = np.array(img).astype(np.float32)                # (im_h, im_w, C) np float
            ubyte_im = skimage.img_as_ubyte(image.astype(np.uint8)) # same shape, np uint8
            task_i = 0

        # get task and categories
        task_ids = img_ref_data['task_ids']
        task = task_ids[task_i]
        subsets_of_img = refvg_loader.get_task_subset(img_id, task)

        # Filter data by category
        if set(opt.phrasecut_categories).issubset(subsets_of_img):
            break

    # Skip for black/white: a grayscale image decodes to a 2-D array and
    # cannot be sliced into colour channels below. Skipped iterations are
    # counted in bw so the running averages can exclude them.
    if image.ndim == 2:
        bw += 1
        continue

    # extract phrase
    sentence = img_ref_data['phrases'][task_i]     # phrase string for this task
    # NOTE(review): the token tensor stays on the CPU here; confirm the model
    # moves it to `device` itself.
    text_pass = corpus.tokenize_sentence(sentence).type(torch.LongTensor)

    # Ground truth mask
    original_h = img_ref_data['height']
    original_w = img_ref_data['width']
    mask = np.zeros((original_h, original_w))                        # (H, W) np
    gt_Polygons = img_ref_data['gt_Polygons'][task_i]                # [plg0, plg1,..] for separate objects
    for plg in gt_Polygons:
        mask += polygons_to_mask(plg, w=original_w, h=original_h)
    # Overlapping polygons would sum to values above 1; keep the mask binary.
    mask = (mask > 0).astype(mask.dtype)
    mask_up = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0)       # (1, 1, H, W) torch
    # Resize the ground truth to the network input resolution: (1, 1, im_h, im_w)
    mask_pass = nn.functional.interpolate(mask_up, size=(opt.im_h, opt.im_w), mode='bilinear', align_corners=False).to(device)

    # processing image before pass
    image_flip = image[:, :, 0:3]           # drop a possible alpha channel
    image_flip = image_flip[:, :, ::-1]     # reverse channel order, (im_h, im_w, 3)
    if not opt.new_img_proc:
        # Out-of-place on purpose: image_flip is a view of `image`, which is
        # reused for every task of the same picture, so an in-place
        # subtraction would remove the mean again on each task.
        image_flip = image_flip - mu

    # add batch_size dimension
    image_batch[n_batch, ...] = image_flip  # (1, im_h, im_w, 3)

    # turn into tensor
    image_pass = torch.from_numpy(image_batch).permute(0, 3, 1, 2).to(device)  # (1, 3, im_h, im_w) torch

    # normalize [0,1] => normalize mean, std
    if opt.new_img_proc:
        image_pass = preprocess(image_pass.view(3, opt.im_h, opt.im_w) / 255).view(1, 3, opt.im_h, opt.im_w)

    ############## Test Step ##############

    # forward pass; no gradients are needed during evaluation
    with torch.no_grad():
        # output_up holds pre-activation scores (>= 0 is treated as foreground below)
        output_down, output_up = model(image_pass, text_pass)

    ############## Log Test Info ##############

    # transform output to prediction in original image shape
    output_predict = (output_up >= 0)                              # boolean tensor
    output_up_predict_detached = output_predict.detach().cpu()
    output_up_predict_detached_numpy = np.squeeze(output_up_predict_detached.numpy().astype(np.float32))
    predicts = resize_and_crop(output_up_predict_detached_numpy, mask.shape[0], mask.shape[1])

    # Accuracy (at original resolution) and IoU (at network resolution)
    acc, acc_pos, acc_neg = compute_accuracy_test(torch.from_numpy(predicts), torch.from_numpy(mask))
    iou_down, intersect_down, union_down = compute_iou(output_up.detach().cpu(), mask_pass.detach().cpu())

    # IoU @ Precision: count samples whose IoU reaches each threshold
    for i, threshold in enumerate(iou_metrics):
        iou_precision[i] += (iou_down >= threshold)

    running_acc += acc
    running_acc_pos += acc_pos
    running_acc_neg += acc_neg
    running_miou += iou_down
    down_I += intersect_down
    down_U += union_down

    # Dense CRF post-processing
    if opt.dcrf:
        if opt.dataset == 'phrasecut':
            predicts_dcrf = dcrf_calc(output_up.detach().cpu(), ubyte_im, opt.im_h, opt.im_w, opt.im_h, opt.im_w)
        else:
            predicts_dcrf = dcrf_calc(output_up.detach().cpu(), ubyte_im, opt.im_h, opt.im_w, mask.shape[0], mask.shape[1])
        iou_dcrf, intersect_dcrf, union_dcrf = compute_iou_np(predicts_dcrf, mask_pass.detach().cpu().numpy())
        running_miou_dcrf += iou_dcrf
        I_dcrf += intersect_dcrf
        U_dcrf += union_dcrf
        for i, threshold in enumerate(iou_metrics):  # IoU @ Precision
            dcrf_iou_precision[i] += (iou_dcrf >= threshold)

    # log test info
    if step % opt.test_log_every == 0 and step != 0:
        # Average over the samples actually scored: iterations skipped for
        # grayscale images added nothing to the running sums.
        n_scored = (step + 1) - bw

        # calc
        avg_miou = running_miou / n_scored
        avg_overall_iou = (down_I / down_U)
        avg_acc = running_acc / n_scored
        avg_acc_pos = running_acc_pos / n_scored
        avg_acc_neg = running_acc_neg / n_scored

        # log
        opt.test_miou[step] = avg_miou
        opt.test_overall_iou[step] = avg_overall_iou
        opt.test_acc[step] = avg_acc
        opt.test_acc_pos[step] = avg_acc_pos
        opt.test_acc_neg[step] = avg_acc_neg
        opt.iou_precision = iou_precision
        opt.dcrf_iou_precision = dcrf_iou_precision

        print('\niter[%s]: mean_IoU=%.2f, Overall_IoU=%.4f' % (step, avg_miou, avg_overall_iou))
        print('acc_pos=%.2f, acc_neg=%.2f' % (avg_acc_pos, avg_acc_neg))

        if opt.dcrf:
            avg_miou_dcrf = running_miou_dcrf / n_scored
            avg_overall_iou_dcrf = I_dcrf / U_dcrf
            opt.test_miou_dcrf[step] = avg_miou_dcrf
            opt.test_overall_iou_dcrf[step] = avg_overall_iou_dcrf
            print('DCRF: mean_IoU=%.4f, Overall_IoU=%.4f' % (avg_miou_dcrf, avg_overall_iou_dcrf))

        # Snapshot all options and metric logs (file is overwritten each time)
        log_file = os.path.join(output_path, 'test.log')
        with open(log_file, 'w') as file:
            for k, v in opt.items():
                file.write(str(k) + '=' + str(v) + '\n\n')

    # Save image
    if step % opt.save_im_every == 0 and step != 0:
        # The phrase is used as the file name: replace characters that are not
        # allowed in file names so the save cannot fail on e.g. '/' or ':'.
        safe_sentence = ''.join(
            '_' if (ch in '<>:"/\\|?*' or ord(ch) < 32) else ch
            for ch in sentence
        )
        output_file_real = output_path + '/' + safe_sentence + ' - ground truth.png'
        output_file_pred = output_path + '/' + safe_sentence + ' - predict truth.png'
        visualize(torch.from_numpy(image), mask_pass.detach().cpu(), sentence + ' - ground truth', show=False, save=output_file_real)
        visualize(torch.from_numpy(image), (output_up >= 0).detach().cpu(), sentence + ' - prediction', show=False, save=output_file_pred)

        if opt.dcrf:
            output_file_dcrf = output_path + '/' + safe_sentence + ' - DCRF.png'
            visualize(torch.from_numpy(image), torch.from_numpy(predicts_dcrf), sentence + ' - DCRF', show=False, save=output_file_dcrf)
        
# Final dump of every option and metric log to test.log (overwrites any
# snapshot written during the loop). The directory and file name are passed
# to os.path.join as separate components rather than one pre-concatenated
# string.
log_file = os.path.join(output_path, 'test.log')
with open(log_file, 'w') as file:
    for k, v in opt.items():
        file.write(str(k) + '=' + str(v) + '\n\n')

print('Done testing: Saved results at %s' % (log_file))


# Test the output (for debug)


In [None]:
# Debug view of the last processed sample: ground truth, raw prediction and,
# when enabled, the DCRF-refined prediction.
visualize(torch.from_numpy(image), mask_pass.detach().cpu(), sentence + ' - ground truth', show=True, save=None)
visualize(torch.from_numpy(image), (output_up>=0).detach().cpu(), sentence + ' - prediction', show=True, save=None)
# predicts_dcrf only exists when DCRF post-processing ran in the test loop.
if opt.dcrf:
    visualize(torch.from_numpy(image), torch.from_numpy(predicts_dcrf), sentence + ' - DCRF', show=True, save=None)