-
Notifications
You must be signed in to change notification settings - Fork 54
/
val_data.py
48 lines (39 loc) · 1.47 KB
/
val_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
paper: GridDehazeNet: Attention-Based Multi-Scale Network for Image Dehazing
file: val_data.py
about: build the validation/test dataset
author: Xiaohong Liu
date: 01/08/19
"""
# --- Imports --- #
import os

import torch.utils.data as data
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Normalize
# --- Validation/test dataset --- #
class ValData(data.Dataset):
    """Paired hazy / ground-truth image dataset for validation and testing.

    Expected layout under ``val_data_dir``::

        val_list.txt     # one hazy image file name per line
        hazy/<name>      # hazy inputs listed in val_list.txt
        clear/<id>.png   # ground truth; <id> is <name> up to its first '_'

    Each item is ``(haze, gt, haze_name)``: the hazy image as a tensor
    normalized with mean = std = 0.5 per channel, the ground-truth image as a
    plain ``ToTensor`` tensor, and the hazy file name string.
    """

    def __init__(self, val_data_dir):
        """Read the list of validation images.

        Args:
            val_data_dir: Root directory of the split. A trailing path
                separator is optional; ``str`` and ``os.PathLike`` both work.

        Raises:
            OSError: If ``val_list.txt`` cannot be opened.
        """
        super().__init__()
        val_list = os.path.join(val_data_dir, 'val_list.txt')
        with open(val_list, encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing empty line): they would
            # otherwise become empty file names that fail at load time.
            haze_names = [line.strip() for line in f if line.strip()]
        # Ground truth is named after the part before the first '_'
        # (e.g. '1400_1.png' -> '1400.png').
        gt_names = [name.split('_')[0] + '.png' for name in haze_names]
        self.haze_names = haze_names
        self.gt_names = gt_names
        self.val_data_dir = val_data_dir

    def get_images(self, index):
        """Load and transform the hazy / ground-truth pair at ``index``.

        Returns:
            tuple: ``(haze, gt, haze_name)``.

        Raises:
            IndexError: If ``index`` is out of range.
            OSError: If either image file cannot be opened.
        """
        haze_name = self.haze_names[index]
        gt_name = self.gt_names[index]
        haze_path = os.path.join(self.val_data_dir, 'hazy', haze_name)
        gt_path = os.path.join(self.val_data_dir, 'clear', gt_name)
        # --- Transform to tensor --- #
        # The 3-element mean/std assume 3-channel (RGB) hazy images.
        transform_haze = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        transform_gt = Compose([ToTensor()])
        # ``with`` releases the file handles deterministically; the
        # transforms read the pixel data while the images are still open.
        with Image.open(haze_path) as haze_img, Image.open(gt_path) as gt_img:
            haze = transform_haze(haze_img)
            gt = transform_gt(gt_img)
        return haze, gt, haze_name

    def __getitem__(self, index):
        """Return ``(haze, gt, haze_name)`` for ``index`` (see ``get_images``)."""
        return self.get_images(index)

    def __len__(self):
        """Return the number of non-blank entries read from ``val_list.txt``."""
        return len(self.haze_names)