Add files via upload
sg-nm committed Apr 7, 2019
1 parent d398fc9 commit afb6099
Showing 12 changed files with 847 additions and 0 deletions.
Binary file added Example_results/detection_supp-1.png
Binary file added Example_results/mydata-1.png
63 changes: 63 additions & 0 deletions README.md
@@ -0,0 +1,63 @@
# Attention-based Adaptive Selection of Operations for Image Restoration in the Presence of Unknown Combined Distortions

This repository contains the code for the following paper:

Masanori Suganuma, Xing Liu, Takayuki Okatani, "Attention-based Adaptive Selection of Operations for Image Restoration in the Presence of Unknown Combined Distortions," CVPR, 2019. [[arXiv](https://arxiv.org/abs/1812.00733)]

If you find this work useful in your research, please cite:

    @inproceedings{suganumaCVPR2019,
      Author = {M. Suganuma and X. Liu and T. Okatani},
      Title = {Attention-based Adaptive Selection of Operations for Image Restoration in the Presence of Unknown Combined Distortions},
      Booktitle = {CVPR},
      Year = {2019}
    }


Sample results on image restoration:

![example](Example_results/mydata-1.png "Sample image restoration results")

Sample results on object detection:

![example](Example_results/detection_supp-1.png "Sample object detection results")


## Requirements

* Ubuntu 16.04 LTS
* CUDA version 10.0
* Python version 3.6.2
* PyTorch version 1.0


## Usage

### Train a model on the dataset proposed by [RL-Restore](https://arxiv.org/abs/1804.03312)

```shell
python main.py -m mix -g 1
```

To train with multiple GPUs, specify the number of GPUs with the `-g` option (default: 1), e.g., `python main.py -m mix -g 4`.

### Train a model on your own dataset

```shell
python main.py -m yourdata -g 1
```
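
The `yourdata` loader in `cnn_train.py` expects paired training images under `/dataset/yourdata_train/input/` and `/dataset/yourdata_train/target/`, and test pairs under `/dataset/yourdata_test/input/` and `/dataset/yourdata_test/target/`; adjust these paths in `CNN_train.__init__` if your data lives elsewhere.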


### Test

Put the trained model (XXXX.pth) in `Trained_model/`, and run the following command:

```shell
python test.py -m mix -g 1
```
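
For reference, the evaluation loop in `cnn_train.py` writes restored images to `./results/Outputs/`, with the corresponding inputs and ground truths in `./results/Inputs/` and `./results/Targets/`, and appends per-evaluation PSNR/SSIM values to `PSNR.txt`.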


### Dataset

The dataset used in [RL-Restore](https://arxiv.org/abs/1804.03312) is available [here](https://github.com/yuke93/RL-Restore).
To generate the training dataset, run `data/train/generate_train.m` in that repository and place the generated file (train.h5) in `dataset/train/` on your machine.
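
Note that the `mix` loader in `cnn_train.py` reads from the absolute paths `/dataset/train/`, `/dataset/val/`, and `/dataset/test/` by default, so either create matching directories or adjust these paths in `CNN_train.__init__`.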
Binary file added Trained_model/model_best.pth
159 changes: 159 additions & 0 deletions cnn_train.py
@@ -0,0 +1,159 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from skimage.measure import compare_psnr as ski_psnr
from skimage.measure import compare_ssim as ski_ssim
import os
import csv
import logging

from model import Network
import torch.nn.functional as F
from data_load_own import get_training_set, get_test_set
from data_load_mix import get_dataset_deform
import utils


class CNN_train():
    def __init__(self, dataset_name, imgSize=63, batchsize=32):
        self.imgSize = imgSize
        self.batchsize = batchsize
        self.dataset_name = dataset_name

        # load dataset
        if dataset_name == 'mix' or dataset_name == 'yourdata':
            if dataset_name == 'mix':
                self.num_work = 8
                train_dir = '/dataset/train/'
                val_dir = '/dataset/val/'
                test_dir = '/dataset/test/'
                train_set = get_dataset_deform(train_dir, val_dir, test_dir, 0)
                # val_set = get_dataset_deform(train_dir, val_dir, test_dir, 1)
                test_set = get_dataset_deform(train_dir, val_dir, test_dir, 2)
                self.dataloader = DataLoader(dataset=train_set, num_workers=self.num_work, batch_size=self.batchsize, shuffle=True, pin_memory=True)
                # self.val_loader = DataLoader(dataset=val_set, num_workers=self.num_work, batch_size=1, shuffle=False, pin_memory=False)
                self.test_dataloader = DataLoader(dataset=test_set, num_workers=self.num_work, batch_size=1, shuffle=False, pin_memory=False)
            elif dataset_name == 'yourdata':
                self.num_work = 8
                train_input_dir = '/dataset/yourdata_train/input/'
                train_target_dir = '/dataset/yourdata_train/target/'
                test_input_dir = '/dataset/yourdata_test/input/'
                test_target_dir = '/dataset/yourdata_test/target/'
                train_set = get_training_set(train_input_dir, train_target_dir, True)
                test_set = get_training_set(test_input_dir, test_target_dir, False)
                self.dataloader = DataLoader(dataset=train_set, num_workers=self.num_work, batch_size=self.batchsize, shuffle=True, drop_last=True)
                self.test_dataloader = DataLoader(dataset=test_set, num_workers=self.num_work, batch_size=1, shuffle=False)
        else:
            print('\tInvalid input dataset name at CNN_train()')
            exit(1)

    def __call__(self, cgp, gpuID, epoch_num=150, gpu_num=1):
        print('GPUID :', gpuID)
        print('epoch_num:', epoch_num)

        # model
        torch.manual_seed(2018)
        torch.cuda.manual_seed(2018)
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True
        L1_loss = nn.L1Loss()
        L1_loss = L1_loss.cuda(gpuID)
        model = Network(16, 10, L1_loss, gpuID=gpuID)
        if gpu_num > 1:
            device_ids = [i for i in range(gpu_num)]
            model = torch.nn.DataParallel(model, device_ids=device_ids)
        model = model.cuda(gpuID)
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        print('Param:', utils.count_parameters_in_MB(model))
        optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoch_num)
        test_interval = 5
        # for results
        if not os.path.exists('./results'):
            os.makedirs('./results/Inputs')
            os.makedirs('./results/Outputs')
            os.makedirs('./results/Targets')

        # Train loop
        for epoch in range(1, epoch_num+1):
            scheduler.step()
            start_time = time.time()
            print('epoch', epoch)
            train_loss = 0
            for module in model.children():
                module.train(True)
            for ite, (input, target) in enumerate(self.dataloader):
                lr_patch = Variable(input, requires_grad=False).cuda(gpuID)
                hr_patch = Variable(target, requires_grad=False).cuda(gpuID)
                optimizer.zero_grad()
                output = model(lr_patch)
                l1_loss = L1_loss(output, hr_patch)
                l1_loss.backward()
                optimizer.step()
                train_loss += l1_loss.item()
                if ite % 500 == 0:
                    vutils.save_image(lr_patch.data, './input_sample%d.png' % gpuID, normalize=False)
                    vutils.save_image(hr_patch.data, './target_sample%d.png' % gpuID, normalize=False)
                    vutils.save_image(output.data, './output_sample%d.png' % gpuID, normalize=False)
            print('Train set : Average loss: {:.4f}'.format(train_loss))
            print('time ', time.time()-start_time)

            # check test performance
            if epoch % test_interval == 0:
                with torch.no_grad():
                    print('------------------------')
                    for module in model.children():
                        module.train(False)
                    test_ite = 0
                    test_psnr = 0
                    test_ssim = 0
                    eps = 1e-10
                    for i, (input, target) in enumerate(self.test_dataloader):
                        lr_patch = Variable(input, requires_grad=False).cuda(gpuID)
                        hr_patch = Variable(target, requires_grad=False).cuda(gpuID)
                        output = model(lr_patch)
                        # save images
                        vutils.save_image(output.data, './results/Outputs/%05d.png' % (int(i)), padding=0, normalize=False)
                        vutils.save_image(lr_patch.data, './results/Inputs/%05d.png' % (int(i)), padding=0, normalize=False)
                        vutils.save_image(hr_patch.data, './results/Targets/%05d.png' % (int(i)), padding=0, normalize=False)
                        # SSIM and PSNR on [0, 1] images (clip, then CHW -> HWC)
                        output = output.data.cpu().numpy()[0]
                        output[output > 1] = 1
                        output[output < 0] = 0
                        output = output.transpose((1, 2, 0))
                        hr_patch = hr_patch.data.cpu().numpy()[0]
                        hr_patch[hr_patch > 1] = 1
                        hr_patch[hr_patch < 0] = 0
                        hr_patch = hr_patch.transpose((1, 2, 0))
                        # SSIM
                        test_ssim += ski_ssim(output, hr_patch, data_range=1, multichannel=True)
                        # PSNR = 10 * log10(MAX^2 / MSE), with MAX = 1 here; eps guards against log10(0)
                        imdf = (output - hr_patch) ** 2
                        mse = np.mean(imdf) + eps
                        test_psnr += 10 * math.log10(1.0/mse)
                        test_ite += 1
                    test_psnr /= (test_ite)
                    test_ssim /= (test_ite)
                    print('Test PSNR: {:.4f}'.format(test_psnr))
                    print('Test SSIM: {:.4f}'.format(test_ssim))
                    f = open('PSNR.txt', 'a')
                    writer = csv.writer(f, lineterminator='\n')
                    writer.writerow([epoch, test_psnr, test_ssim])
                    f.close()
                    print('------------------------')
                torch.save(model.state_dict(), './model_%d.pth' % int(epoch))

        return train_loss
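
For reference, a minimal sketch of how `CNN_train` might be driven from an entry-point script. `main.py` is not included in this commit, so the call below is an assumption based on the signatures above, not the authors' actual script.

```python
# Hypothetical driver for CNN_train (main.py is not part of this commit).
from cnn_train import CNN_train

if __name__ == '__main__':
    trainer = CNN_train('mix', imgSize=63, batchsize=32)
    # `cgp` is accepted by __call__ but never used in its body, so None is safe here.
    final_train_loss = trainer(None, gpuID=0, epoch_num=150, gpu_num=1)
    print('final epoch train loss:', final_train_loss)
```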
107 changes: 107 additions & 0 deletions data_load_mix.py
@@ -0,0 +1,107 @@

# Parts of this code are borrowed from:
# https://github.com/yuke93/RL-Restore

import numpy as np
import os
import h5py
import sys
import torch.utils.data as data
import torch
from torchvision import transforms
import cv2


def load_imgs(list_in, list_gt, size=63):
    assert len(list_in) == len(list_gt)
    img_num = len(list_in)
    imgs_in = np.zeros([img_num, size, size, 3])
    imgs_gt = np.zeros([img_num, size, size, 3])
    for k in range(img_num):
        imgs_in[k, ...] = cv2.imread(list_in[k]) / 255.
        imgs_gt[k, ...] = cv2.imread(list_gt[k]) / 255.
    return imgs_in, imgs_gt


def data_reformat(data):
    """RGB <--> BGR, swap H and W"""
    assert data.ndim == 4
    out = data[:, :, :, ::-1] - np.zeros_like(data)
    out = np.swapaxes(out, 1, 2)
    out = out.astype(np.float32)
    return out
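
# For reference: given an array of shape (N, H, W, C) in BGR channel order,
# data_reformat returns a float32 copy of shape (N, W, H, C) in RGB order;
# subtracting np.zeros_like(data) forces a real copy of the reversed view.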

def get_dataset_deform(train_root, val_root, test_root, is_train):
    dataset = DeformedData(
        train_root=train_root,
        val_root=val_root,
        test_root=test_root,
        is_train=is_train,
        transform=transforms.Compose([transforms.ToTensor()]),
        target_transform=transforms.Compose([transforms.ToTensor()])
    )
    return dataset

class DeformedData(data.Dataset):
    def __init__(self, train_root, val_root, test_root, is_train=0, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        self.is_train = is_train
        self.train_dir = train_root
        self.val_dir = val_root
        self.test_dir = test_root

        if self.is_train == 0:
            # training data
            self.train_list = [self.train_dir + file for file in os.listdir(self.train_dir) if file.endswith('.h5')]
            self.train_cur = 0
            self.train_max = len(self.train_list)
            f = h5py.File(self.train_list[self.train_cur], 'r')
            self.data = f['data'].value
            self.label = f['label'].value
            f.close()
            self.data_index = 0
            self.data_len = len(self.data)
            print('training images:', self.data_len)
        elif self.is_train == 1:
            # validation data
            f = h5py.File(self.val_dir + os.listdir(self.val_dir)[0], 'r')
            self.data = f['data'].value
            self.label = f['label'].value
            f.close()
            self.data_index = 0
            self.data_len = len(self.data)
        elif self.is_train == 2:
            # test data
            self.test_in = self.test_dir + 'moderate' + '_in/'
            self.test_gt = self.test_dir + 'moderate' + '_gt/'
            list_in = [self.test_in + name for name in os.listdir(self.test_in)]
            list_in.sort()
            list_gt = [self.test_gt + name for name in os.listdir(self.test_gt)]
            list_gt.sort()
            self.name_list = [os.path.splitext(os.path.basename(file))[0] for file in list_in]
            self.data_all, self.label_all = load_imgs(list_in, list_gt)
            self.test_total = len(list_in)
            self.test_cur = 0
            # data reformat, because the data for tools training are in a different format
            self.data = data_reformat(self.data_all)
            self.label = data_reformat(self.label_all)
            self.data_index = 0
            self.data_len = len(self.data)
        else:
            print("not implemented yet")
            sys.exit()

    def __getitem__(self, index):
        img = self.data[index]
        img_gt = self.label[index]

        # transforms (numpy -> Tensor)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            img_gt = self.target_transform(img_gt)
        return img, img_gt

    def __len__(self):
        return self.data_len
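
For reference, a minimal sketch of using this loader directly; the directory paths below are assumptions matching the defaults hard-coded in `cnn_train.py`.

```python
from torch.utils.data import DataLoader
from data_load_mix import get_dataset_deform

# is_train: 0 = training split, 1 = validation, 2 = test (see DeformedData.__init__)
train_set = get_dataset_deform('/dataset/train/', '/dataset/val/', '/dataset/test/', 0)
loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=8)
img, img_gt = train_set[0]  # a (degraded, ground-truth) tensor pair
```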