In [None]:
from __future__ import print_function, absolute_import

import os 
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim
import torchvision
import torchvision.datasets as datasets
import numpy as np

import sys
sys.path.append('..')

import scripts.utils
from scripts.utils.logger import Logger, savefig
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.misc import save_checkpoint, save_pred, adjust_learning_rate
from scripts.utils.osutils import mkdir_p, isfile, isdir, join
from scripts.utils.imutils import batch_with_heatmap,normalize_batch,im_to_numpy
from scripts.utils.transforms import fliplr, flip_back
import scripts.models as models
import scripts.datasets as datasets

from skimage.measure import compare_ssim as ssim
from skimage.measure import compare_psnr as psnr

class objectview(object):
    """Convert dict(or parameters of dict) to object view
    See also:
        - https://goodcode.io/articles/python-dict-object/
        - https://stackoverflow.com/questions/1305532/convert-python-dict-to-object
    >>> o = objectview({'a': 1, 'b': 2})
    >>> o.a, o.b
    (1, 2)
    >>> o = objectview(a=1, b=2)
    >>> o.a, o.b
    (1, 2)
    """
    def __init__(self, *args, **kwargs):
        d = dict(*args, **kwargs)
        self.__dict__ = d


dataset_name= 'val5k'
dataroot = '/home/oishii/Datasets/backup/fivek'
input_size = 256

data_config_normal  = objectview({'input_size':input_size,
                           'limited_dataset':0,
                           'resize_and_crop':'resize',
                           'data_augumentation':False,
                           'norm_type':'none',
                           'base_dir':dataroot,
                           'data':dataset_name})

data_config_gan  = objectview({'input_size':input_size,
                           'limited_dataset':0,
                           'resize_and_crop':'resize',
                           'norm_type':'gan',
                           'base_dir':dataroot,
                           'data':dataset_name})


methods = [
           ('/home/oishii/Documents/imageharmonization/rasc_temp/psnr/adobe5k/1e3_bs16_256_basic_5k_rascv1',data_config_normal),
        ('/home/oishii/Documents/imageharmonization/rasc_temp/psnr/adobe5k/1e3_bs16_256_basic_5k_rascv2',data_config_normal),
#             ('/home/oishii/Documents/imageharmonization/rasc-v2/limited/cocov4/1e3_bs16_256_multimaskedpixel_2017_naivemmucross',data_config_normal),
#             ('/home/oishii/Documents/imageharmonization/rasc-v2/limited/cocov4gan/1e3_bs16_256_mmaskedgan_2017_naivemmucross',data_config_gan),
#     ('/home/oishii/Documents/imageharmonization/rasc-v2/limited/cocov4gan/1e3_bs16_256_baseline_mmaskedgan_2017_naivemmucross',data_config_gan),
#     ('/home/oishii/Documents/imageharmonization/rasc-v2/limited/cocov4gan/1e3_bs16_256_gagan_2017_globalattentionnetwork',data_config_gan),
#     ('/home/oishii/Documents/imageharmonization/rasc-v2/limited/cocov4gan/1e3_bs16_256_mmaskedgan_2017_naivemmucrossx',data_config_gan)
]

sample = []

for resume,data_config in methods:
  
    #os.makedirs(os.path.join(resume,dataset_name,'output'),exist_ok=True)
    
            
    val_loader = torch.utils.data.DataLoader(datasets.COCO('',args=data_config),
                                         batch_size=1, shuffle=False,
                                         num_workers=2, pin_memory=False)
    epoches = '/checkpoint.pth.tar'
    checkpoint_dict = torch.load(resume+epoches)
    checkpoint = checkpoint_dict['state_dict']
        

    name_of_model = resume.split('_')[-1]
    

    model = models.__dict__[name_of_model]().cuda()
    model.load_state_dict(checkpoint)
    model.eval()
    
    MSE = AverageMeter()
    PSNR = AverageMeter()
    SSIM = AverageMeter()
    
    with torch.no_grad():
        for i, (inputs, target) in enumerate(val_loader):            
            inputs = inputs.cuda()
            target = target.cuda()
            mask = inputs[0,3:4]
           
            if 'rasc' in resume:
                output = model(inputs)
                mask2s = inputs[0,3:4]
            elif 'gagan' in resume:
                output,mask2s = model(inputs)
            else:
                output,mask8s,mask4s,mask2s = model(inputs)
                mask2s = nn.functional.interpolate(mask2s,scale_factor=2)
            
             # if gan based method
            if 'gan' in data_config.norm_type:
                outputnp = (im_to_numpy(output[0]) + 1) / 2.0 * 255.0
                inputsnp = (im_to_numpy(inputs[0,0:3]) + 1) / 2.0 * 255.0
                targetnp = (im_to_numpy(target[0]) + 1) / 2.0 * 255.0
                outputnp = outputnp.clip(0,255).astype(np.uint8)
                inputsnp = inputsnp.clip(0,255).astype(np.uint8)
                targetnp = targetnp.clip(0,255).astype(np.uint8)
            else:
                inputsnp = (im_to_numpy(inputs[0,0:3])*255).clip(0,255).astype(np.uint8)
                outputnp = (im_to_numpy(output[0])*255).clip(0,255).astype(np.uint8)
                targetnp = (im_to_numpy(target[0])*255).clip(0,255).astype(np.uint8)
                
            masknp = im_to_numpy(mask.repeat(3,1,1)).astype(np.uint8)
            
            tmp_ssim = ssim(targetnp,outputnp,multichannel=True)
            tmp_psnr = psnr(targetnp,outputnp)
            tmp_mse = np.mean( (outputnp - targetnp) ** 2 )
            
            MSE.update(tmp_mse, inputs.size(0))
            PSNR.update(tmp_psnr, inputs.size(0))
            SSIM.update(tmp_ssim, inputs.size(0))
            
#             torchvision.utils.save_image(output[0],os.path.join(resume,dataset_name,'output','%s.png'%(i)),padding=0)
            
    print("%s: MSE:%s, SSIM:%s, PSNR:%s"%(resume.split('/')[-1], MSE.avg, SSIM.avg, PSNR.avg))


total Dataset of val5k is :  2420


In [10]:
# ablation study

from __future__ import print_function, absolute_import

import os 
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torchvision
import torchvision.datasets as datasets
import numpy as np

import sys
sys.path.append('..')

import scripts.utils
from scripts.utils.logger import Logger, savefig
from scripts.utils.evaluation import accuracy, AverageMeter, final_preds
from scripts.utils.misc import save_checkpoint, save_pred, adjust_learning_rate
from scripts.utils.osutils import mkdir_p, isfile, isdir, join
from scripts.utils.imutils import batch_with_heatmap,normalize_batch,im_to_numpy
from scripts.utils.transforms import fliplr, flip_back
import scripts.models as models
import scripts.datasets as datasets

from skimage.measure import compare_ssim as ssim
from skimage.measure import compare_psnr as psnr

class objectview(object):
    """Convert dict(or parameters of dict) to object view
    See also:
        - https://goodcode.io/articles/python-dict-object/
        - https://stackoverflow.com/questions/1305532/convert-python-dict-to-object
    >>> o = objectview({'a': 1, 'b': 2})
    >>> o.a, o.b
    (1, 2)
    >>> o = objectview(a=1, b=2)
    >>> o.a, o.b
    (1, 2)
    """
    def __init__(self, *args, **kwargs):
        d = dict(*args, **kwargs)
        self.__dict__ = d
            
# COCO-val2017
methods = (
           '/home/oishii/Documents/deep-harimonization-improved/psnr/cocov4/1e3_bs8_256_2017_unet',
           '/home/oishii/Documents/deep-harimonization-improved/psnr/cocov4/1e3_bs8_256_2017_unetconv',
           '/home/oishii/Documents/deep-harimonization-improved/psnr/cocov4/1e3_bs8_256_2017_radhnv3xgaussian',
            '/home/oishii/Documents/deep-harimonization-improved/psnr/cocov4/1e3_bs8_256_2017_radhnv1',
    '/home/oishii/Documents/deep-harimonization-improved/psnr/cocov4/1e3_bs8_256_2017_radhnv2',
    '/home/oishii/Documents/deep-harimonization-improved/psnr/cocov4/1e3_bs8_256_2017_radhnv3',
    '/home/oishii/Documents/deep-harimonization-improved/psnr/cocov4/1e3_bs8_256_2017_radhnv3x6'
         )

dataset_name= 'val2017'
dataroot = '/home/oishii/Documents/coco_data_maker/synthesis_coco_v4/'

sample = []

for resume in methods:

    if 'pix2pix' in resume:
        data_config  = objectview({'input_size':256,'normalized_input':False,'data_augumentation':False,'withseg':False})
        val_loader = torch.utils.data.DataLoader(datasets.COCO(dataroot,dataset_name,config=data_config,sample=sample,gan_norm=True),
                                             batch_size=1, shuffle=False,
                                             num_workers=2, pin_memory=False)
        epoches = '/latest_net_G.pth'
        checkpoint = torch.load(resume+epoches)
    else:
        if 'seg' in resume:
            data_config  = objectview({'input_size':256,'normalized_input':False,'data_augumentation':False,'withseg':True})
        else:
            data_config  = objectview({'input_size':256,'normalized_input':False,'data_augumentation':False,'withseg':False})
            
        val_loader = torch.utils.data.DataLoader(datasets.COCO(dataroot,dataset_name,config=data_config,sample=sample,gan_norm=False),
                                             batch_size=1, shuffle=False,
                                             num_workers=2, pin_memory=False)
        epoches = '/model_best.pth.tar'
        checkpoint_dict = torch.load(resume+epoches)
        checkpoint = checkpoint_dict['state_dict']
    
    if 'pix2pix' in resume:
        name_of_model = 'ounet'
    else:
        name_of_model = resume.split('_')[-1]
    
    model = models.__dict__[name_of_model]().cuda()
    model.load_state_dict(checkpoint)
    model.eval()
    
    MSE = AverageMeter()
    PSNR = AverageMeter()
    SSIM = AverageMeter()
    
    with torch.no_grad():
        for i, (inputs, target) in enumerate(val_loader):            
            inputs = inputs.cuda()
            target = target[0].cuda()
            
            if 'unet' in resume:
                output = model(inputs)
            elif 'seg' in resume:
                output,_ = model(inputs)
            else:
                output = model(inputs)
                
             # if gan based method
            if 'pix2pix' in resume:
                outputnp = (im_to_numpy(output[0]) + 1) / 2.0 * 255.0
                inputsnp = (im_to_numpy(inputs[0,0:3]) + 1) / 2.0 * 255.0
                outputnp = outputnp.clip(0,255).astype(np.uint8)
                inputsnp = inputsnp.clip(0,255).astype(np.uint8)
                targetnp = (im_to_numpy(target[0])*255).astype(np.uint8)
            else:
                inputsnp = (im_to_numpy(inputs[0,0:3])*255).clip(0,255).astype(np.uint8)
                outputnp = (im_to_numpy(output[0])*255).clip(0,255).astype(np.uint8)
                targetnp = (im_to_numpy(target[0])*255).astype(np.uint8)
            
            tmp_ssim = ssim(targetnp,outputnp,multichannel=True)
            tmp_psnr = psnr(targetnp,outputnp)
            tmp_mse = np.mean( (outputnp - targetnp) ** 2 )
            
            MSE.update(tmp_mse, inputs.size(0))
            PSNR.update(tmp_psnr, inputs.size(0))
            SSIM.update(tmp_ssim, inputs.size(0))
            
#             torchvision.utils.save_image(inputs[0,3:4,:,:],os.path.join(resume,dataset_name,'mask','%s.png'%(i)),padding=0)
#             torchvision.utils.save_image(target[0],os.path.join(resume,dataset_name,'target','%s.png'%(i)),padding=0)
#             torchvision.utils.save_image(inputs[0,0:3,:,:],os.path.join(resume,dataset_name,'input','%s.png'%(i)),padding=0)
#             torchvision.utils.save_image(output[0],os.path.join(resume,dataset_name,'output','%s.png'%(i)),padding=0)
            
    print("%s: MSE:%s, SSIM:%s, PSNR:%s"%(resume, MSE.avg, SSIM.avg, PSNR.avg))


total Dataset of val2017 is :  1716
/home/oishii/Documents/deep-harimonization-improved/psnr/cocov4/1e3_bs8_256_2017_unet: MSE:17.28237579215285, SSIM:0.9757684633437844, PSNR:33.55681231416759
total Dataset of val2017 is :  1716
/home/oishii/Documents/deep-harimonization-improved/psnr/cocov4/1e3_bs8_256_2017_unetconv: MSE:17.85627645805258, SSIM:0.9757584460044089, PSNR:33.349426540741185
total Dataset of val2017 is :  1716
/home/oishii/Documents/deep-harimonization-improved/psnr/cocov4/1e3_bs8_256_2017_radhnv3xgaussian: MSE:15.599634133481945, SSIM:0.979036533803898, PSNR:34.74366374129477
total Dataset of val2017 is :  1716
/home/oishii/Documents/deep-harimonization-improved/psnr/cocov4/1e3_bs8_256_2017_radhnv1: MSE:17.935396112947377, SSIM:0.9758088011985688, PSNR:33.324841910413326
total Dataset of val2017 is :  1716
/home/oishii/Documents/deep-harimonization-improved/psnr/cocov4/1e3_bs8_256_2017_radhnv2: MSE:15.80662498385201, SSIM:0.9779920592634598, PSNR:34.494394724813844
tota