-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
178 lines (141 loc) · 6.15 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from torch.utils.data import Dataset
import os
import torch
import numpy as np
from skimage import io, transform
import consts
from augment import label_func
def create_csv_file(dir, filename, label_dirs=None):
    """Write an index of the dataset's image files, one path per line.

    Only immediate sub-directories of ``dir`` whose names are in ``label_dirs``
    are scanned, and only regular files directly inside them are listed (nested
    directories are skipped). Entries are written in sorted order so the
    generated file is reproducible across runs and platforms. Each written path
    is also printed to stdout.

    :param dir: Root directory of the dataset; each class lives in its own
        sub-directory. (The name shadows the ``dir`` builtin but is kept because
        callers pass it by keyword.)
    :param filename: Path of the CSV file to create; it is overwritten if it exists.
    :param label_dirs: Container of sub-directory names to include. Defaults to
        ``consts.IMAGENETTE_LABEL_DICT``.
    """
    if label_dirs is None:
        label_dirs = consts.IMAGENETTE_LABEL_DICT
    with open(filename, 'w') as file_handle:
        # os.listdir returns entries in arbitrary order; sort so the CSV is deterministic.
        for d in sorted(os.listdir(dir)):
            d_full_path = os.path.join(dir, d)
            if not (os.path.isdir(d_full_path) and d in label_dirs):
                continue
            for path in sorted(os.listdir(d_full_path)):
                full_path = os.path.join(d_full_path, path)
                if os.path.isfile(full_path):
                    file_handle.write(f'{full_path}\n')
                    print(full_path)
class Rescale(object):
    """Resize an image to a target size.

    An ``int`` target fixes the length of the shorter edge and scales the other
    edge to keep the aspect ratio; a ``tuple`` target gives the exact
    ``(height, width)`` of the result.
    """

    def __init__(self, output_size):
        """
        :param output_size: (tuple or int) Target size. A tuple is used verbatim as (height, width); an int becomes
            the length of the shorter edge.
        """
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Resize ``sample`` using ``skimage.transform.resize``.

        :param sample: Image array whose first two axes are height and width.
        :return: The resized image, as returned by ``skimage.transform.resize``. Target dimensions are truncated
            to integers.
        """
        height, width = sample.shape[:2]
        target = self.output_size
        if not isinstance(target, int):
            out_h, out_w = target
        elif height > width:
            # Width is the shorter edge: pin it to the target and stretch the height proportionally.
            out_h, out_w = target * height / width, target
        else:
            # Height is the shorter (or equal) edge: pin it and stretch the width.
            out_h, out_w = target, target * width / height
        return transform.resize(sample, (int(out_h), int(out_w)))
class RandomCrop(object):
    """Crop a randomly positioned window out of an image."""

    def __init__(self, output_size):
        """
        :param output_size: (tuple or int) Desired (height, width) of the crop. If int, a square crop is made.
        """
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Take a uniformly random crop of ``sample``.

        :param sample: Image array whose first two axes are height and width.
        :return: A slice of ``sample`` whose height and width are ``self.output_size``. Any trailing axes
            (e.g. channels) are kept unchanged.
        :raises ValueError: If the image is smaller than the requested crop along either axis.
        """
        image = sample
        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if h < new_h or w < new_w:
            raise ValueError(
                f'Cannot take a {new_h}x{new_w} crop from a {h}x{w} image.')
        # np.random.randint's upper bound is exclusive, so +1 makes the last valid
        # offset reachable and allows a crop equal to the full image size.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        return image[top: top + new_h,
                     left: left + new_w]
class ToTensor(object):
    """Turn an H x W (x C) numpy image into a C x H x W torch tensor."""

    def __call__(self, sample):
        """
        :param sample: Numpy image, either 2-D (grayscale) or 3-D with the channel axis last.
        :return: A torch tensor laid out channels-first ([C, H, W]). A 2-D grayscale input is first replicated
            into three identical channels; a 3-D input keeps its own channel count.
        """
        if sample.ndim == 2:
            # Grayscale: copy the single plane into three identical channels.
            sample = np.stack((sample, sample, sample), axis=2)
        # numpy images are H x W x C, while torch expects C x H x W.
        return torch.from_numpy(np.transpose(sample, (2, 0, 1)))
class ImagenetteDataset(Dataset):
    """Imagenette dataset backed by a CSV file that lists one image path per line."""

    # Number of samples exposed in debug mode so training/evaluation loops finish quickly.
    _DEBUG_LENGTH = 16

    def __init__(self, root_dir, csv_file, transform=None, labels=False, debug=False):
        """
        :param root_dir: The string directory with all the images. Only stored here; this class opens the paths
            from ``csv_file`` exactly as written.
        :param csv_file: String path to the csv file with paths to all the images, one per line. Blank lines are
            ignored.
        :param transform: (callable, optional) transform to be applied on a sample.
        :param labels: True if the dataset should contain integer labels to represent the type of content of the image.
            False if the dataset should consist strictly of the images, without their labels.
        :param debug: True if we are training in an easy-debug mode where training and evaluation must run quickly.
            False otherwise.
        """
        self.root_dir = root_dir
        self.csv_file = csv_file
        self.transform = transform
        self.labels = labels
        self.debug = debug
        # Class name -> integer label returned by __getitem__ when `labels` is True.
        self.translate_labels = {'tench': 0,
                                 'English springer': 1,
                                 'cassette player': 2,
                                 'chain saw': 3,
                                 'church': 4,
                                 'French horn': 5,
                                 'garbage truck': 6,
                                 'gas pump': 7,
                                 'golf ball': 8,
                                 'parachute': 9}
        with open(csv_file, newline='') as f:
            # Drop blank lines so they are never treated as image paths.
            self.paths_to_images = [line for line in f.read().splitlines() if line.strip()]

    def __len__(self):
        """Return the number of samples; capped at ``_DEBUG_LENGTH`` in debug mode."""
        if self.debug:
            # Never report more samples than exist, or samplers would index past the end.
            return min(self._DEBUG_LENGTH, len(self.paths_to_images))
        return len(self.paths_to_images)

    def __getitem__(self, idx):
        """Load, optionally transform, and optionally label the image at position ``idx``.

        :param idx: Integer index, or a tensor whose ``tolist()`` yields an integer
            (presumably a 0-dim tensor; a multi-element tensor would produce a list and fail to index).
        :return: ``image``, or ``(image, label)`` when ``self.labels`` is True.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        path = self.paths_to_images[idx]
        image = io.imread(path)
        if self.transform:
            image = self.transform(image)
        if self.labels:
            # label_func is expected to return one of the translate_labels keys for this path.
            return image, self.translate_labels[label_func(path)]
        return image
if __name__ == '__main__':
    # Echo the configured locations, then (re)build the CSV index from the image directory.
    for setting in (consts.image_dir, consts.csv_filename):
        print(setting)
    create_csv_file(dir=consts.image_dir, filename=consts.csv_filename)