In [1]:
%cd /sun/codebase/harmer_hdnet

/sunjinsheng/codebase/harmer_hdnet


In [2]:
import os
import time
from pathlib import Path
from tqdm.notebook import tqdm
import numpy as np
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
import torch
from skimage.metrics import mean_squared_error
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import structural_similarity as ssim
from skimage import data, io

# from options.train_options import TrainOptions
# from data import CustomDataset
# from models import create_model

In [3]:
def tensor2im(input_image, imtype=np.uint8):
    """Convert the first image of a tensor batch into a numpy image array.

    Parameters:
        input_image (tensor | ndarray) --  a batch tensor indexed as [N, C, H, W]
                                           with values nominally in [-1, 1], or an
                                           already-converted numpy array
        imtype (type)                  --  the desired type of the converted numpy array

    Returns:
        An [H, W, C] numpy array of dtype ``imtype`` scaled to [0, 255] when a
        tensor is given; the input cast to ``imtype`` when an ndarray is given;
        the input itself, untouched, for any other type.
    """
    if isinstance(input_image, np.ndarray):
        # Already a numpy image: only the dtype cast below applies.
        image_numpy = input_image
    elif isinstance(input_image, torch.Tensor):
        # Only the first element of the batch is converted.
        image_numpy = input_image.detach()[0].cpu().float().numpy()
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        # CHW -> HWC, then map [-1, 1] onto [0, 255].
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
        # Saturate instead of letting the integer cast wrap around when the
        # network produces values slightly outside [-1, 1].
        image_numpy = np.clip(image_numpy, 0.0, 255.0)
    else:
        return input_image
    return image_numpy.astype(imtype)

def setup_seed(seed):
    """Seed the torch (CPU and every CUDA device) and numpy global RNGs with ``seed``."""
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
    
def calculateMean(vars):
    """Return the arithmetic mean of a sized collection of numbers.

    An empty collection raises ZeroDivisionError (division by ``len(vars)``).
    """
    total = sum(vars)
    return total / len(vars)

def save_img(path, img):
    """Write ``img`` to ``path`` with ``io.imsave``, creating the parent folder first.

    Parameters:
        path (str) -- destination file path; may be a bare file name
        img        -- image data handed to ``io.imsave`` unchanged
    """
    fold = os.path.dirname(path)
    # A bare file name has an empty directory part, and os.makedirs('') raises
    # FileNotFoundError, so only create the folder when there is one.
    if fold:
        os.makedirs(fold, exist_ok=True)
    io.imsave(path, img)

def resolveResults(results):
    """Average the per-image metrics inside three mask-value intervals.

    Parameters:
        results (dict) -- equal-length sequences under the keys 'mask', 'mse',
                          'psnr', 'fmse' and 'ssim'

    Returns:
        dict mapping the labels '0.00-0.05', '0.05-0.15' and '0.15-1.00' to
        ``[mean mse, mean psnr, mean fmse, mean ssim]`` over the entries whose
        mask value lies in that interval (lower bound exclusive, upper bound
        inclusive). The dict is also printed.
    """
    ratios = np.array(results['mask'])
    columns = [np.array(results[key]) for key in ('mse', 'psnr', 'fmse', 'ssim')]

    # One boolean selector per interval; entries with mask == 0 match none of them.
    selectors = {
        '0.00-0.05': (ratios <= 0.05) & (ratios > 0.0),
        '0.05-0.15': (ratios <= 0.15) & (ratios > 0.05),
        '0.15-1.00': ratios > 0.15,
    }

    interval_metrics = {
        label: [np.mean(column[chosen]) for column in columns]
        for label, chosen in selectors.items()
    }

    print(interval_metrics)
    return interval_metrics

def updateWriterInterval(writer, metrics, epoch):
    """Log the first two values of every interval entry as MSE and PSNR scalars.

    Parameters:
        writer  -- object exposing ``add_scalar(tag, value, step)``
        metrics (dict) -- interval label -> sequence whose items 0 and 1 are
                          logged as MSE and PSNR respectively
        epoch   -- step value passed to ``add_scalar``
    """
    tag_templates = ('interval/{}-MSE', 'interval/{}-PSNR')
    for label, values in metrics.items():
        for position, template in enumerate(tag_templates):
            writer.add_scalar(template.format(label), values[position], epoch)

In [4]:

from argparse import Namespace

# Hand-written option set. The disabled legacy path parsed these with
# TrainOptions().parse(...) and built the test set through CustomDataset;
# neither is used in this notebook any more.
opt = Namespace(
    dataset_root='datasets/HAdobe5k',
    name='',
    gpu_ids=[0],
    checkpoints_dir='/sun/home_logs/instance_harmer/results0320',
    is_train=False,
    model='hdnet', input_nc=3, output_nc=3, ngf=32, ndf=64,
    netD='basic', netG='hdnet', n_layers_D=3,
    normD='instance', normG='RAIN',
    init_type='normal', init_gain=0.02, no_dropout=False,
    dataset_mode='iharmony4', serial_batches=False, num_threads=32,
    batch_size=12, load_size=256, crop_size=256,
    max_dataset_size=np.inf, preprocess='none', display_winsize=256,
    epoch='latest', load_iter=0, verbose=False, suffix='',
    display_freq=500, display_id=1, display_server='http://localhost',
    display_env='main', display_port=8097, update_html_freq=500,
    print_freq=300, no_html=False,
    save_latest_freq=5000, save_epoch_freq=1, save_by_iter=False,
    continue_train=False, epoch_count=1, phase='train',
    niter=120, niter_decay=0, beta1=0.9, lr=0.001,
    g_lr_ratio=1.0, d_lr_ratio=1.0, gan_mode='vanilla', pool_size=0,
    lr_policy='target_decay', lr_decay_iters=100,
    lambda_L1=1.0, lambda_Fft=0.0, gp_ratio=1.0, lambda_a=1.0, lambda_v=1.0,
    isTrain=True,
)

# Options consumed by RealDataset and the DataLoader in the next cell.
opt_2 = Namespace(
    dataset_root='/sun/home_datasets/iharmony4/HAdobe5k',
    preprocess='resize',
    load_size=256,
    batch_size=16,
    num_threads=0,
)


In [5]:

from data.real_dataset import RealDataset

# The flag is forwarded as `shuffle`, so this run iterates in dataset order.
is_for_train = False
dataset = RealDataset(opt_2)

# One sample per batch (opt_2.batch_size is not consulted here); the worker
# count comes from opt_2.num_threads.
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=int(opt_2.num_threads),
    shuffle=is_for_train,
    batch_size=1,
    drop_last=False,
)

loading training file...


In [6]:
# The create_model(opt) / model.setup(opt) path is disabled; the generator is
# built directly and its weights are loaded from a checkpoint file.
from models.networks import HDNet
from models.normalize import RAIN

# Constructor arguments for HDNet.
input_nc = 3
output_nc = 3
ngf = 32
norm_layer = RAIN
use_dropout = False
use_attention = True

# Pass the `use_attention` variable (previously a literal True was passed, so
# editing the variable above had no effect).
model = HDNet(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, use_attention=use_attention)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load('/sun/codebase/harmer_hdnet/home_logs/old_version/results_0701/latest_net_G.pth', map_location='cpu'))
model = model.cuda()
model.eval()

# Leftovers from the disabled evaluateModel(...) routine, kept so that any
# code relying on these names still finds them.
epoch = 0
max_psnr = 0
epoch_number = epoch
iters = None

In [8]:
# Intended to toggle writing the harmonized images to disk.
flag_save_img = True
eval_results_path = os.path.join('/sun/home_logs/instance_harmer/', 'results0320')
csv_eval = os.path.join(eval_results_path, 'eval.csv')

# Per-image metric lists, keyed the way resolveResults expects.
total_eval_results = {'mask': [], 'mse': [], 'psnr': [], 'fmse': [], 'ssim': []}
flag_exists = os.path.exists(csv_eval)
# Disabled: appending rows to csv_eval under the header
# 'img_path,mask_ratio,mse,psnr,fmse,ssim'.

# Destination folder for the saved predictions. Create it up front so that
# the image writes in the inference loop cannot fail on a missing directory.
root = Path('/sun/home_logs/instance_harmer/real_image_results/instance2')
root.mkdir(parents=True, exist_ok=True)

# CUDA events used to time each forward pass.
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
it = len(dataloader)
times = torch.zeros(it)     # elapsed time of each iteration

In [9]:
%%time
# NOTE: the loop variable `data` rebinds the name imported by `from skimage import data`.
for i, data in tqdm(enumerate(dataloader), total = len(dataloader)):
    # Keep the current batch on the model object; the profiling cell later
    # reads model.comp and model.mask from the last iteration.
    model.comp = data['comp'].cuda()
    model.mask = data['mask'].cuda()
    paths = data['img_path']
    model.inputs = model.comp

    with torch.no_grad():
        # Time only the forward pass with CUDA events; synchronize so both
        # events have completed before the elapsed time is read.
        starter.record()
        model.output = model(model.comp, model.mask)
        ender.record()
        torch.cuda.synchronize()

        curr_time = starter.elapsed_time(ender)
        times[i] = curr_time

        # Composite: network output inside the mask, the first three input
        # channels outside of it.
        model.attentioned = model.output * model.mask + model.inputs[:,:3,:,:] * (1 - model.mask)
        model.fake_f = model.output * model.mask
        model.harmonized = model.attentioned
        output = model.attentioned

        # Honour the save flag set in the setup cell (it used to be ignored).
        if flag_save_img:
            for i_img in range(output.size(0)):
                img_path = paths[i_img]
                img_name = os.path.basename(img_path)
                save_path = root/img_name
                # Slice (not index) so tensor2im still receives a batch dimension.
                pred = output[i_img:i_img+1]
                img_pred = tensor2im(pred)
                img = Image.fromarray(img_pred)
                img.save(str(save_path))

  0%|          | 0/100 [00:00<?, ?it/s]

CPU times: user 1min 13s, sys: 525 ms, total: 1min 14s
Wall time: 7.07 s


In [10]:
from thop import profile

In [11]:
# Profile one forward pass with thop, reusing the last batch that the
# inference loop left on model.comp / model.mask.
# NOTE(review): thop's README names the first return value `macs`
# (multiply-accumulates) — confirm before reporting this number as FLOPs.
flops, params = profile(model, inputs=(model.comp, model.mask))
print(flops, params)

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.instancenorm.InstanceNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
48041641920.0 10408495.0


In [12]:
# Raw counts first, then the same two numbers expressed in millions.
print('flops: ', flops, 'params: ', params)
print('flops: %.2f M, params: %.2f M' % (flops / 1e6, params / 1e6))

flops:  48041641920.0 params:  10408495.0
flops: 48041.64 M, params: 10.41 M


In [13]:
# Average the per-iteration CUDA-event timings collected in `times`.
# torch.cuda.Event.elapsed_time reports milliseconds, so 1000 / mean gives
# iterations per second (one image per iteration, since the DataLoader uses
# batch_size=1).
# NOTE(review): no warm-up iterations are excluded, so any start-up overhead
# in the first passes is part of this mean.
mean_time = times.mean().item()
print("Inference time: {:.6f}, FPS: {} ".format(mean_time, 1000/mean_time))

Inference time: 19.308174, FPS: 51.79153622171355 
