-
Notifications
You must be signed in to change notification settings - Fork 10
/
load_utils.py
111 lines (97 loc) · 3.97 KB
/
load_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
Load images and corresponding npy labels from file.
"""
import numpy as np
import os.path as path
import glob
import os
import random
import cv2
import folders as f
import csv
def load_imgs_labels(batch_size, label_folder, img_folder, rand, angle_folder=None):
    """
    Internal generator that endlessly yields batches of images and labels.

    Every file in ``label_folder`` is loaded with ``np.load`` as a label. Its image
    is read (grayscale) from ``img_folder`` under the label's file name with the
    last extension removed, e.g. label ``x.jpg.npy`` -> image ``x.jpg``.
    Samples that do not fill a complete final batch are skipped in every epoch.

    :param batch_size: number of samples per batch; must be positive.
    :param label_folder: folder containing the label files.
    :param img_folder: folder containing the images.
    :param rand: if True, reshuffle the sample order at the start of every epoch.
    :param angle_folder: The folder that contains "angles.csv" and "filenames.csv".
        If given, each batch also carries the matching angles as float lists.
    :return: generator of (imgs, labels), or (imgs, labels, angles) when
        angle_folder is given.
    :raises ValueError: if batch_size is not positive, if fewer than batch_size
        label files are found (previously this looped forever without yielding),
        or if an image name is not listed in filenames.csv.
    :raises IOError: if an image cannot be read by cv2.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))
    # Sorted so the un-shuffled order is deterministic regardless of filesystem.
    label_list = sorted(glob.glob(path.join(label_folder, "*")))
    total_size = len(label_list)
    # Drop the trailing partial batch so every batch has exactly batch_size items.
    loop_range = total_size - (total_size % batch_size)
    if loop_range == 0:
        # Without this guard the while-loop below would spin forever, yielding nothing.
        raise ValueError("found %d label files in %r, fewer than batch_size=%d"
                         % (total_size, label_folder, batch_size))

    if angle_folder is not None:  # Load filenames.csv and angles.csv
        filenames, angles = load_filename_angle(angle_folder)
        # name -> row index; setdefault keeps the first occurrence, matching the
        # previous list.index lookup, but each lookup is now O(1).
        name_to_row = {}
        for row, name in enumerate(filenames):
            name_to_row.setdefault(name, row)

    while True:
        if rand:
            random.shuffle(label_list)
        for i in range(0, loop_range, batch_size):
            batch_label_path = label_list[i:i + batch_size]
            batch_label = [np.load(p) for p in batch_label_path]
            # Label files are "<image name>.npy": strip the last extension.
            batch_img_name = [path.splitext(path.basename(p))[0] for p in batch_label_path]

            batch_img = []
            for name in batch_img_name:
                img_path = path.join(img_folder, name)
                img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
                if img is None:  # cv2.imread reports failure by returning None
                    raise IOError("cannot read image %r" % (img_path,))
                batch_img.append(img)

            if angle_folder is None:
                yield batch_img, batch_label
                continue

            batch_angles = []
            for name in batch_img_name:
                # A "_flip" copy is looked up under its source image's name.
                key = name.replace("_flip.jpg", ".jpg")
                try:
                    row = name_to_row[key]
                except KeyError:
                    raise ValueError("%r is not listed in filenames.csv" % (key,)) from None
                batch_angles.append([float(a) for a in angles[row]])
            yield batch_img, batch_label, batch_angles
def train_loader(batch_size, load_angle=False, use_trainval=False):
    """
    Endless generator over the training data, reshuffled every epoch.

    :param batch_size: number of samples per batch.
    :param load_angle: if True, each batch also carries the samples' angles.
    :param use_trainval: if True, draw from the combined train+val folders (for the
        final training run); otherwise use only the train folders, leaving the val
        split for validation.
    :return: batches of (batch_img, batch_label), or
        (batch_img, batch_label, batch_angles) when load_angle is True.
    """
    angle_folder = None
    if use_trainval:  # every sample, for the final training run
        img_folder, label_folder = f.resize_trainval_img, f.resize_trainval_label
        if load_angle:
            angle_folder = f.trainval_angle
    else:  # train split only; val is held out for validation
        img_folder, label_folder = f.resize_train_img, f.resize_train_label
        if load_angle:
            angle_folder = f.train_angle
    yield from load_imgs_labels(batch_size, label_folder, img_folder,
                                rand=True, angle_folder=angle_folder)
def test_loader(batch_size, load_angle=False):
    """
    Endless generator over the resized test folders, without shuffling.

    :param batch_size: number of samples per batch.
    :param load_angle: if True, each batch also carries the samples' angles.
    :return: batches of (batch_img, batch_label), or
        (batch_img, batch_label, batch_angles) when load_angle is True.
    """
    img_folder, label_folder = f.resize_test_img, f.resize_test_label
    # NOTE(review): angles are read from f.val_angle while images/labels come from
    # the resize_test_* folders; presumably these folders hold the val split -- confirm.
    angle_folder = f.val_angle if load_angle else None
    yield from load_imgs_labels(batch_size, label_folder, img_folder,
                                rand=False, angle_folder=angle_folder)
# CSV Loader
def load_filename_angle(folder):
    """
    Load image file names and their corresponding angles from CSV files.

    Reads ``filenames.csv`` (one file name per row, first column) and
    ``angles.csv`` (one row of angle values per file name) from ``folder``.
    Row i of angles.csv belongs to row i of filenames.csv.

    :param folder: directory containing "filenames.csv" and "angles.csv".
    :return: (filenames, angles) where filenames is a list of str and angles is a
        list of rows, each a list of str exactly as read by csv.reader.
    :raises ValueError: if the two files have a different number of rows.
    """
    angle_path = path.join(folder, "angles.csv")
    filename_path = path.join(folder, "filenames.csv")
    # newline='' is what the csv module requires for correct line handling.
    with open(angle_path, mode='r', newline='') as angle_csv, \
            open(filename_path, mode='r', newline='') as filename_csv:
        # Each filenames.csv row is a one-element list; keep just the name.
        filenames = [row[0] for row in csv.reader(filename_csv)]
        angles = list(csv.reader(angle_csv))
    # Explicit check instead of `assert`, which is stripped when running with -O.
    if len(filenames) != len(angles):
        raise ValueError("%s has %d rows but %s has %d rows"
                         % (filename_path, len(filenames), angle_path, len(angles)))
    return filenames, angles