-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
847 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Attention-based Adaptive Selection of Operations for Image Restoration in the Presence of Unknown Combined Distortions | ||
|
||
This repository contains the code for the following paper: | ||
|
||
Masanori Suganuma, Xing Liu, Takayuki Okatani, "Attention-based Adaptive Selection of Operations for Image Restoration in the Presence of Unknown Combined Distortions," CVPR, 2019. [[arXiv](https://arxiv.org/abs/1812.00733)] | ||
|
||
If you find this work useful in your research, please cite: | ||
|
||
@inproceedings{suganumaCVPR2019, | ||
Author = {M. Suganuma and X. Liu and T. Okatani}, | ||
Title = {Attention-based Adaptive Selection of Operations for Image Restoration in the Presence of Unknown Combined Distortions}, | ||
Booktitle = {CVPR}, | ||
Year = {2019} | ||
} | ||
|
||
|
||
Sample results on image restoration: | ||
|
||
![example](Example_results/mydata-1.png "Sample image restoration results") | ||
|
||
Sample results on object detection: | ||
|
||
![example](Example_results/detection_supp-1.png "Sample object detection results") | ||
|
||
|
||
## Requirement | ||
|
||
* Ubuntu 16.04 LTS | ||
* CUDA version 10.0 | ||
* Python version 3.6.2 | ||
* PyTorch version 1.0 | ||
|
||
|
||
## Usage | ||
|
||
### Train a model on the dataset proposed by [RL-Restore](https://arxiv.org/abs/1804.03312) | ||
|
||
```shell | ||
python main.py -m mix -g 1 | ||
``` | ||
|
||
When using multiple GPUs, specify the number of GPUs with the `-g` option (default: 1).
|
||
### Train a model on your own dataset | ||
|
||
```shell | ||
python main.py -m yourdata -g 1 | ||
``` | ||
|
||
|
||
### Test | ||
|
||
Put the trained model (XXXX.pth) into `Trained_model/`, and run the following code:
|
||
```shell | ||
python test.py -m mix -g 1 | ||
``` | ||
|
||
|
||
### Dataset | ||
|
||
The dataset used in [RL-Restore](https://arxiv.org/abs/1804.03312) is available [here](https://github.com/yuke93/RL-Restore). | ||
To generate the training dataset, please run `data/train/generate_train.m` in the above repository and put the generated file (train.h5) into `dataset/train/` on your machine.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
import time | ||
import math | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.parallel | ||
import torch.backends.cudnn as cudnn | ||
import torch.optim as optim | ||
import torch.utils.data | ||
from torch.utils.data import DataLoader | ||
import torchvision.transforms as transforms | ||
import torchvision.utils as vutils | ||
from torch.autograd import Variable | ||
from skimage.measure import compare_psnr as ski_psnr | ||
from skimage.measure import compare_ssim as ski_ssim | ||
import os | ||
import csv | ||
import logging | ||
|
||
from model import Network | ||
import torch.nn.functional as F | ||
from data_load_own import get_training_set, get_test_set | ||
from data_load_mix import get_dataset_deform | ||
import utils | ||
|
||
|
||
class CNN_train():
    """Training/evaluation harness for the attention-based restoration network.

    Args:
        dataset_name: 'mix' (the RL-Restore combined-distortion dataset) or
            'yourdata' (user-supplied paired input/target image folders).
        imgSize: spatial size of the training patches (default 63).
        batchsize: mini-batch size used for training (default 32).
    """

    def __init__(self, dataset_name, imgSize=63, batchsize=32):
        self.imgSize = imgSize
        self.batchsize = batchsize
        self.dataset_name = dataset_name

        # Build the train/test DataLoaders for the requested dataset.
        if dataset_name == 'mix' or dataset_name == 'yourdata':
            if dataset_name == 'mix':
                self.num_work = 8
                train_dir = '/dataset/train/'
                val_dir = '/dataset/val/'
                test_dir = '/dataset/test/'
                train_set = get_dataset_deform(train_dir, val_dir, test_dir, 0)
                # val_set = get_dataset_deform(train_dir, val_dir, test_dir, 1)
                test_set = get_dataset_deform(train_dir, val_dir, test_dir, 2)
                self.dataloader = DataLoader(dataset=train_set, num_workers=self.num_work, batch_size=self.batchsize, shuffle=True, pin_memory=True)
                # self.val_loader = DataLoader(dataset=val_set, num_workers=self.num_work, batch_size=1, shuffle=False, pin_memory=False)
                self.test_dataloader = DataLoader(dataset=test_set, num_workers=self.num_work, batch_size=1, shuffle=False, pin_memory=False)
            elif dataset_name == 'yourdata':
                self.num_work = 8
                train_input_dir = '/dataset/yourdata_train/input/'
                train_target_dir = '/dataset/yourdata_train/target/'
                test_input_dir = '/dataset/yourdata_test/input/'
                test_target_dir = '/dataset/yourdata_test/target/'
                # NOTE(review): the test split is also built via get_training_set
                # (flag False); get_test_set is imported at module level but never
                # used here — confirm this is intentional.
                train_set = get_training_set(train_input_dir, train_target_dir, True)
                test_set = get_training_set(test_input_dir, test_target_dir, False)
                self.dataloader = DataLoader(dataset=train_set, num_workers=self.num_work, batch_size=self.batchsize, shuffle=True, drop_last=True)
                self.test_dataloader = DataLoader(dataset=test_set, num_workers=self.num_work, batch_size=1, shuffle=False)
        else:
            print('\tInvalid input dataset name at CNN_train()')
            exit(1)

    def __call__(self, cgp, gpuID, epoch_num=150, gpu_num=1):
        """Train for `epoch_num` epochs, evaluating every `test_interval` epochs.

        Args:
            cgp: unused; kept for interface compatibility with the caller.
            gpuID: id of the primary CUDA device.
            epoch_num: number of training epochs (default 150).
            gpu_num: number of GPUs; >1 wraps the model in DataParallel.

        Returns:
            Accumulated L1 training loss of the final epoch.
        """
        print('GPUID    :', gpuID)
        print('epoch_num:', epoch_num)

        # model
        torch.manual_seed(2018)
        torch.cuda.manual_seed(2018)
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True
        L1_loss = nn.L1Loss()
        L1_loss = L1_loss.cuda(gpuID)
        model = Network(16, 10, L1_loss, gpuID=gpuID)
        if gpu_num > 1:
            device_ids = list(range(gpu_num))
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        model = model.cuda(gpuID)
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        print('Param:', utils.count_parameters_in_MB(model))
        optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch_num)
        test_interval = 5
        # Create each output leaf explicitly: the original only checked
        # './results', which crashed when that directory existed without
        # its subfolders.
        os.makedirs('./results/Inputs', exist_ok=True)
        os.makedirs('./results/Outputs', exist_ok=True)
        os.makedirs('./results/Targets', exist_ok=True)

        # Train loop
        for epoch in range(1, epoch_num + 1):
            # NOTE(review): stepping the scheduler at the top of the epoch is
            # the PyTorch 1.0 convention (the README pins PyTorch 1.0); for
            # PyTorch >= 1.1 this call should move after the optimizer steps.
            scheduler.step()
            start_time = time.time()
            print('epoch', epoch)
            train_loss = 0
            model.train()
            for ite, (input, target) in enumerate(self.dataloader):
                # Variable is a deprecated no-op since PyTorch 0.4; move the
                # tensors to the GPU directly.
                lr_patch = input.cuda(gpuID)
                hr_patch = target.cuda(gpuID)
                optimizer.zero_grad()
                output = model(lr_patch)
                l1_loss = L1_loss(output, hr_patch)
                l1_loss.backward()
                optimizer.step()
                train_loss += l1_loss.item()
                if ite % 500 == 0:
                    # Periodic visual sanity check of input/target/output.
                    vutils.save_image(lr_patch.data, './input_sample%d.png' % gpuID, normalize=False)
                    vutils.save_image(hr_patch.data, './target_sample%d.png' % gpuID, normalize=False)
                    vutils.save_image(output.data, './output_sample%d.png' % gpuID, normalize=False)
            print('Train set : Average loss: {:.4f}'.format(train_loss))
            print('time ', time.time() - start_time)

            # check test performance
            if epoch % test_interval == 0:
                with torch.no_grad():
                    print('------------------------')
                    model.eval()
                    test_ite = 0
                    test_psnr = 0
                    test_ssim = 0
                    eps = 1e-10  # guards log10(0) when output == target
                    for i, (input, target) in enumerate(self.test_dataloader):
                        lr_patch = input.cuda(gpuID)
                        hr_patch = target.cuda(gpuID)
                        output = model(lr_patch)
                        # save images
                        vutils.save_image(output.data, './results/Outputs/%05d.png' % (int(i)), padding=0, normalize=False)
                        vutils.save_image(lr_patch.data, './results/Inputs/%05d.png' % (int(i)), padding=0, normalize=False)
                        vutils.save_image(hr_patch.data, './results/Targets/%05d.png' % (int(i)), padding=0, normalize=False)
                        # Clamp to [0, 1] and convert CHW -> HWC for the metrics.
                        output = np.clip(output.data.cpu().numpy()[0], 0., 1.).transpose((1, 2, 0))
                        hr_patch = np.clip(hr_patch.data.cpu().numpy()[0], 0., 1.).transpose((1, 2, 0))
                        # SSIM (ski_ssim is the skimage alias imported at module top)
                        test_ssim += ski_ssim(output, hr_patch, data_range=1, multichannel=True)
                        # PSNR
                        mse = np.mean((output - hr_patch) ** 2) + eps
                        test_psnr += 10 * math.log10(1.0 / mse)
                        test_ite += 1
                    test_psnr /= test_ite
                    test_ssim /= test_ite
                    print('Test PSNR: {:.4f}'.format(test_psnr))
                    print('Test SSIM: {:.4f}'.format(test_ssim))
                    # Append per-epoch metrics; `with` guarantees the handle
                    # closes even if the write fails (original left it open).
                    with open('PSNR.txt', 'a') as f:
                        writer = csv.writer(f, lineterminator='\n')
                        writer.writerow([epoch, test_psnr, test_ssim])
                    print('------------------------')
                torch.save(model.state_dict(), './model_%d.pth' % int(epoch))

        return train_loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
|
||
# we borrowed a part of the following code: | ||
# https://github.com/yuke93/RL-Restore | ||
|
||
import numpy as np | ||
import os | ||
import h5py | ||
import sys | ||
import torch.utils.data as data | ||
import torch | ||
from torchvision import transforms | ||
import cv2 | ||
|
||
|
||
def load_imgs(list_in, list_gt, size=63):
    """Read paired distorted/clean images and scale pixels to [0, 1].

    Args:
        list_in: paths of the distorted (input) images.
        list_gt: paths of the clean (ground-truth) images, same length/order.
        size: expected square side length of every image (default 63).

    Returns:
        Two float arrays of shape (N, size, size, 3): inputs and targets.
    """
    assert len(list_in) == len(list_gt)
    total = len(list_in)
    batch_in = np.zeros([total, size, size, 3])
    batch_gt = np.zeros([total, size, size, 3])
    for idx, (path_in, path_gt) in enumerate(zip(list_in, list_gt)):
        batch_in[idx, ...] = cv2.imread(path_in) / 255.
        batch_gt[idx, ...] = cv2.imread(path_gt) / 255.
    return batch_in, batch_gt
|
||
def data_reformat(data):
    """RGB <--> BGR, swap H and W.

    Args:
        data: 4-D array of shape (N, H, W, C).

    Returns:
        float32 array of shape (N, W, H, C) with the channel axis reversed.
    """
    assert data.ndim == 4
    # Reversing the last axis yields a negatively-strided view; .copy()
    # materialises it (the original forced a copy by subtracting zeros).
    out = data[:, :, :, ::-1].copy()
    out = np.swapaxes(out, 1, 2)
    return out.astype(np.float32)
|
||
def get_dataset_deform(train_root, val_root, test_root, is_train):
    """Build a DeformedData dataset with ToTensor transforms on both sides.

    Args:
        train_root: directory holding the training HDF5 shards.
        val_root: directory holding the validation HDF5 file.
        test_root: directory holding the test image folders.
        is_train: split selector passed through (0=train, 1=val, 2=test).

    Returns:
        A configured DeformedData instance.
    """
    input_tf = transforms.Compose([transforms.ToTensor()])
    target_tf = transforms.Compose([transforms.ToTensor()])
    return DeformedData(
        train_root=train_root,
        val_root=val_root,
        test_root=test_root,
        is_train=is_train,
        transform=input_tf,
        target_transform=target_tf,
    )
|
||
class DeformedData(data.Dataset):
    """Paired (distorted, clean) image dataset for combined-distortion restoration.

    Args:
        train_root: directory containing training HDF5 shards ('data'/'label').
        val_root: directory containing the validation HDF5 file.
        test_root: directory containing 'moderate_in/' and 'moderate_gt/' folders.
        is_train: split selector — 0: train, 1: validation, 2: test.
        transform: transform applied to the input image.
        target_transform: transform applied to the ground-truth image.
    """

    def __init__(self, train_root, val_root, test_root, is_train=0, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.is_train = is_train
        self.train_dir = train_root
        self.val_dir = val_root
        self.test_dir = test_root

        if self.is_train == 0:
            # training data: load the first HDF5 shard found under train_root
            self.train_list = [self.train_dir + file for file in os.listdir(self.train_dir) if file.endswith('.h5')]
            self.train_cur = 0
            self.train_max = len(self.train_list)
            # `dataset[()]` reads the whole dataset into memory; the `.value`
            # accessor used originally was removed in h5py 3.0.  The context
            # manager closes the file even if a read fails.
            with h5py.File(self.train_list[self.train_cur], 'r') as f:
                self.data = f['data'][()]
                self.label = f['label'][()]
            self.data_index = 0
            self.data_len = len(self.data)
            print('training images:', self.data_len)
        elif self.is_train == 1:
            # validation data: single HDF5 file in val_root
            with h5py.File(self.val_dir + os.listdir(self.val_dir)[0], 'r') as f:
                self.data = f['data'][()]
                self.label = f['label'][()]
            self.data_index = 0
            self.data_len = len(self.data)
        elif self.is_train == 2:
            # test data: paired image folders, matched by sorted filename order
            self.test_in = self.test_dir + 'moderate' + '_in/'
            self.test_gt = self.test_dir + 'moderate' + '_gt/'
            list_in = [self.test_in + name for name in os.listdir(self.test_in)]
            list_in.sort()
            list_gt = [self.test_gt + name for name in os.listdir(self.test_gt)]
            list_gt.sort()
            self.name_list = [os.path.splitext(os.path.basename(file))[0] for file in list_in]
            self.data_all, self.label_all = load_imgs(list_in, list_gt)
            self.test_total = len(list_in)
            self.test_cur = 0
            # data reformat, because the data for tools training are in a different format
            self.data = data_reformat(self.data_all)
            self.label = data_reformat(self.label_all)
            self.data_index = 0
            self.data_len = len(self.data)
        else:
            print("not implement yet")
            sys.exit()

    def __getitem__(self, index):
        """Return the (input, target) pair at `index`, transformed if configured."""
        img = self.data[index]
        img_gt = self.label[index]

        # transforms (numpy -> Tensor)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            img_gt = self.target_transform(img_gt)
        return img, img_gt

    def __len__(self):
        """Number of image pairs in the selected split."""
        return self.data_len
Oops, something went wrong.