In [4]:
from __future__ import print_function, absolute_import


import os 
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
import torchvision
import torchvision.datasets as datasets
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageChops
import sys, math
sys.path.append('..')

import scripts.utils
from scripts.utils.logger import Logger, savefig
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.misc import save_checkpoint, save_pred, adjust_learning_rate
from scripts.utils.osutils import mkdir_p, isfile, isdir, join
from scripts.utils.imutils import batch_with_heatmap,normalize_batch,im_to_numpy
from scripts.utils.transforms import fliplr, flip_back
import scripts.models as models
import scripts.datasets as datasets
import random
from matplotlib import cm
from scripts.utils.pytorch_modelsize import SizeEstimator
 

def count_parameters(model):
    """Return how many trainable scalar parameters *model* holds.

    Only parameters whose ``requires_grad`` flag is set are counted; each
    contributes its ``numel()`` element count.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def get_jet():
    """Build a 256-entry RGB lookup table from matplotlib's ``jet`` colormap.

    Returns:
        np.ndarray of shape (256, 3) and dtype uint8; row ``k`` holds the red,
        green and blue components of ``cm.jet(k)`` scaled to 0-255 and rounded.
    """
    lut = np.zeros((256, 3), np.uint8)
    for level in range(256):
        # cm.jet returns an RGBA tuple; the alpha component is ignored.
        red, green, blue = cm.jet(level)[:3]
        lut[level] = [np.int_(np.round(c * 255.0)) for c in (red, green, blue)]
    return lut

# Module-level jet lookup table of shape (256, 3), uint8, computed once when
# this cell runs; suitable as the ``color_map`` argument of gray2color.
jet_map = get_jet()

def clamp(num, min_value, max_value):
    """Limit *num* to the closed range [min_value, max_value].

    The upper bound is applied first, then the lower bound, so when
    ``min_value > max_value`` the result is ``min_value``.
    """
    upper_bounded = max_value if max_value < num else num
    return min_value if min_value > upper_bounded else upper_bounded

def gray2color(gray_array, color_map):
    """Colorize a 2-D single-channel array through a 256-entry color map.

    Each pixel value ``v`` selects ``color_map[min(int(abs(v) * 10), 255)]``:
    magnitudes are amplified 10x, truncated toward zero and saturated at the
    last color-map entry (e.g. with ``get_jet()``).

    Args:
        gray_array: 2-D array of real pixel values of any numeric dtype.
        color_map: array-like with 256 rows of 3 color components.

    Returns:
        uint8 array of shape ``(rows, cols, 3)``.

    Raises:
        ValueError: if ``gray_array`` is not 2-D or contains NaN/inf values.
    """
    gray = np.asarray(gray_array)
    if gray.ndim != 2:
        raise ValueError("gray_array must be 2-D, got shape %r" % (gray.shape,))
    rows, cols = gray.shape

    # Compute in a floating type wide enough for the input: integer images
    # (e.g. uint8 from PIL .convert('L')) would otherwise overflow and wrap
    # when multiplied by 10 instead of saturating at index 255.
    work_dtype = np.promote_types(gray.dtype, np.float32)
    magnitude = np.abs(gray.astype(work_dtype)) * 10
    if not np.isfinite(magnitude).all():
        raise ValueError("gray_array must contain only finite values")

    # Clip then truncate: for non-negative values this equals
    # clamp(int(x), 0, 255), done for all pixels at once.
    indices = np.clip(magnitude, 0, 255).astype(np.intp)

    color_array = np.zeros((rows, cols, 3), np.uint8)
    color_array[...] = np.asarray(color_map)[indices]
    return color_array


# SSIM / PSNR image-quality metrics. scikit-image >= 0.16 exposes them under
# skimage.metrics, and the old skimage.measure.compare_* aliases were removed
# in 0.18, so prefer the new location and fall back for old installs.
try:
    from skimage.metrics import structural_similarity as ssim
    from skimage.metrics import peak_signal_noise_ratio as psnr
except ImportError:  # scikit-image < 0.16
    from skimage.measure import compare_ssim as ssim
    from skimage.measure import compare_psnr as psnr

class objectview(object):
    """Expose the entries of a mapping as instance attributes.

    Takes the same arguments as ``dict()``, so both
    ``objectview({'a': 1}).a`` and ``objectview(a=1).a`` evaluate to 1.
    """

    def __init__(self, *args, **kwargs):
        vars(self).update(*args, **kwargs)

dataset_name = 'real_data'
dataroot = '/home/oishii/Datasets/backup/'
input_size = 512

# Dataset options for checkpoints using the 'none' normalization.
data_config_normal = objectview(
    input_size=input_size,
    limited_dataset=0,
    resize_and_crop='resize',
    data_augumentation=False,
    norm_type='none',
    base_dir=dataroot,
    data=dataset_name,
)

# Dataset options for checkpoints using the 'gan' normalization.
# NOTE(review): unlike data_config_normal this config has no
# 'data_augumentation' attribute — confirm the dataset class never reads it.
data_config_gan = objectview(
    input_size=input_size,
    limited_dataset=0,
    resize_and_crop='resize',
    norm_type='gan',
    base_dir=dataroot,
    data=dataset_name,
)


# Checkpoints to evaluate; each entry is (checkpoint_dir, data_config).
# The evaluation loop loads '<checkpoint_dir>/model_best.pth.tar', builds the
# network whose name is the text after the directory's last '_' (looked up in
# scripts.models), and picks the model's calling convention from substrings
# of the directory name ('rascv2', 'gagan'). Uncomment an entry to include it.
methods = [
#            ('/home/oishii/Documents/imageharmonization/rasc-v2/limited/cocov4/1e3_bs16_256_basic_2017_naiveuno',data_config_normal),
#            ('/home/oishii/Documents/imageharmonization/rasc-v2/limited/cocov2/1e3_bs16_256_gradient9_basic_2017_naiveuno',data_config_normal),    
#             ('/home/oishii/Documents/imageharmonization/rasc-v2/limited/cocov4/1e3_bs16_256_multimaskedpixel_2017_naivemmucross',data_config_normal),
#             ('/home/oishii/Documents/imageharmonization/rasc-v2/limited/cocov4gan/1e3_bs16_256_mmaskedgan_2017_naivemmucross',data_config_gan),
#     ('/home/oishii/Documents/imageharmonization/rasc-v2/limited/cocov4gan/1e3_bs16_256_baseline_mmaskedgan_2017_naivemmucross',data_config_gan),
        ('/home/oishii/Documents/imageharmonization/rasc-v2/psnr/cocov4/1e3_bs4_512_basic_final_rascv2512',data_config_normal),
]

# Image selection forwarded to datasets.COCO(..., sample=sample).
# Indices tried previously: [905], [843], [27], and 201, 712, 668, 504;
# `random.sample(range(80), 1)` picks a single random index instead.
# NOTE(review): an empty list presumably means "use the whole split" (the
# logged run processed every image) — confirm in scripts.datasets.COCO.

sample=[]

def _tensor_to_uint8_image(tensor, norm_type):
    """Convert an image tensor to a uint8 array, undoing the dataset normalization.

    Conversion goes through im_to_numpy (presumably CHW tensor -> HWC array).
    With a 'gan' norm_type, pixels are first mapped back via (x + 1) / 2
    (presumably from [-1, 1]); the result is scaled by 255 and clipped to
    [0, 255].
    """
    array = im_to_numpy(tensor)
    if 'gan' in norm_type:
        array = (array + 1) / 2.0
    return (array * 255.0).clip(0, 255).astype(np.uint8)


def _run_model(model, inputs, resume):
    """Run *model* on *inputs* and return ``(output, mask2s)``.

    The calling convention is selected from the checkpoint directory name
    *resume*: 'rascv2' models return only the image (the mask is then read
    from input channel 3), 'gagan' models return ``(image, mask)``, and all
    other models return four tensors whose last mask is upsampled 2x.
    """
    if 'rascv2' in resume:
        return model(inputs), inputs[0, 3:4]
    if 'gagan' in resume:
        output, mask2s = model(inputs)
        return output, mask2s
    output, _mask8s, _mask4s, mask2s = model(inputs)
    return output, nn.functional.interpolate(mask2s, scale_factor=2)


# Evaluate every configured checkpoint and save one output image per input
# under <checkpoint_dir>/<dataset_name>/.
with torch.no_grad():
    for resume, data_config in methods:
        output_dir = os.path.join(resume, dataset_name)
        os.makedirs(output_dir, exist_ok=True)
        val_loader = torch.utils.data.DataLoader(
            datasets.COCO('', data_config, sample=sample),
            batch_size=1, shuffle=False, num_workers=2, pin_memory=False)

        # Build and load the model once per checkpoint, not once per image.
        # Inference below runs on CPU (nothing is moved to a GPU), so load the
        # weights onto the CPU; GPU-saved checkpoints then work without CUDA.
        checkpoint_dict = torch.load(os.path.join(resume, 'model_best.pth.tar'),
                                     map_location='cpu')
        checkpoint = checkpoint_dict['state_dict']
        name_of_model = resume.split('_')[-1]
        model = models.__dict__[name_of_model]()
        model.load_state_dict(checkpoint)
        model.eval()

        for inputs, target, name_of_img in val_loader:
            target = target[0]
            output, mask2s = _run_model(model, inputs, resume)

            # inputsnp / targetnp stay available as globals for inspection or
            # metrics in later cells, e.g. psnr(targetnp, outputnp).
            norm_type = data_config.norm_type
            inputsnp = _tensor_to_uint8_image(inputs[0, 0:3], norm_type)
            outputnp = _tensor_to_uint8_image(output[0], norm_type)
            targetnp = _tensor_to_uint8_image(target, norm_type)

            im_output = Image.fromarray(outputnp)
            # Save into this checkpoint's output directory; saving under the
            # bare file name wrote into the CWD, and several checkpoints
            # overwrote each other's results.
            im_output.save(os.path.join(output_dir, name_of_img[0]))


            

total Dataset of real_data is :  99
('test01_Su.png',)
('test02_Su.png',)
('test03_Su.png',)
('test04_Su.png',)
('test05_Su.png',)
('test06_Su.png',)
('test07_Su.png',)
('test08_Su.png',)
('test09_Su.png',)
('test10_Su.png',)
('test11_Su.png',)
('test12_Su.png',)
('test13_Su.png',)
('test14_Su.png',)
('test15_Su.png',)
('test16_Su.png',)
('test17_Su.png',)
('test18_Su.png',)
('test19_Su.png',)
('test20_Su.png',)
('test21_Su.png',)
('test22_Su.png',)
('test23_Su.png',)
('test24_Su.png',)
('test25_Su.png',)
('test26_Su.png',)
('test27_Su.png',)
('test28_Su.png',)
('test29_Su.png',)
('test30_Su.png',)
('test31_Su.png',)
('test32_Su.png',)
('test33_Su.png',)
('test34_Su.png',)
('test35_Su.png',)
('test36_Su.png',)
('test37_Su.png',)
('test38_Su.png',)
('test39_Su.png',)
('test40_Su.png',)
('test41_Su.png',)
('test42_Su.png',)
('test43_Su.png',)
('test44_Su.png',)
('test45_Su.png',)
('test46_Su.png',)
('test47_Su.png',)
('test48_Su.png',)
('test_01.png',)
('test_02.png',)
('test_03.png',)
(