-
Notifications
You must be signed in to change notification settings - Fork 18
/
Datagenerator.py
67 lines (59 loc) · 2.32 KB
/
Datagenerator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import cv2
class ImageDataGenerator:
    """Sequential batch loader for an image-classification list file.

    Each non-blank line of the list file holds an image path and a class
    label separated by whitespace, e.g. ``images/cat_001.jpg 3``. Batches are
    read from disk with OpenCV, optionally mirrored, resized, mean-subtracted
    and returned together with one-hot label matrices.
    """

    def __init__(self, class_list, n_class, batch_size=1, flip=True,
                 shuffle=False, mean=np.array([104., 117., 124.]),
                 scale_size=(227, 227)):
        """Read the list file and optionally shuffle it.

        Args:
            class_list: path of the text file listing ``<image_path> <label>``.
            n_class: number of classes, i.e. the width of the one-hot labels.
            batch_size: number of samples returned by each ``getNext_batch``.
            flip: if true, mirror each image horizontally with probability 0.5.
            shuffle: if true, shuffle now and on every ``reset_pointer()``.
            mean: value(s) subtracted from every pixel, broadcast over the
                channel axis. OpenCV loads colour images in BGR order, so the
                default presumably holds BGR channel means -- confirm.
            scale_size: ``(height, width)`` of the returned images.
        """
        self.horizontal = flip
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.class_list = class_list
        # Copy so that mutating self.mean never alters the shared default
        # array (or a caller's array); also accepts plain lists/tuples.
        self.mean = np.array(mean, dtype=float)
        self.scale_size = scale_size
        self.pointer = 0
        self.n_class = n_class
        self.read_class_list(class_list)
        if shuffle:
            self.shuffle_data()

    def read_class_list(self, class_list):
        """Load image paths and (string) labels from ``class_list``.

        Blank lines are skipped. Only the first two whitespace-separated
        fields of each line are used. Sets ``images``, ``labels`` and
        ``data_size``.
        """
        images = []
        labels = []
        with open(class_list) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines, e.g. a trailing empty line.
                    continue
                images.append(fields[0])
                labels.append(fields[1])
        # Assign only after a successful read so a failure leaves the
        # previous state intact.
        self.images = images
        self.labels = labels
        self.data_size = len(labels)

    def shuffle_data(self):
        """Shuffle images and labels in place with one shared permutation."""
        order = np.random.permutation(self.data_size)
        self.images = [self.images[k] for k in order]
        self.labels = [self.labels[k] for k in order]

    def reset_pointer(self):
        """Rewind to the first sample, reshuffling if shuffling is enabled."""
        self.pointer = 0
        if self.shuffle:
            self.shuffle_data()

    def _load_image(self, path):
        """Read one image and apply random flip, resize and mean subtraction."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise IOError('could not read image file {}'.format(path))
        if self.horizontal and np.random.random() < 0.5:
            image = cv2.flip(image, 1)
        # cv2.resize expects dsize as (width, height); scale_size is
        # (height, width), matching the layout of the batch array.
        image = cv2.resize(image, (self.scale_size[1], self.scale_size[0]))
        image = image.astype(np.float32)
        image -= self.mean
        return image

    def getNext_batch(self):
        """Load the next batch and advance the read pointer.

        Returns:
            ``(images, one_hot_labels)``:
            - ``images``: a float64 array of shape
              ``(batch_size, scale_size[0], scale_size[1], 3)``.
            - ``one_hot_labels``: an array of shape
              ``(batch_size, n_class)``.
            When fewer than ``batch_size`` samples remain, the trailing rows
            of both arrays are zero. The pointer does not wrap around; call
            ``reset_pointer()`` to start a new pass.

        Raises:
            IOError: if an image file cannot be read by OpenCV.
        """
        start = self.pointer
        stop = start + self.batch_size
        paths = self.images[start:stop]
        labels = self.labels[start:stop]
        self.pointer = stop

        height, width = self.scale_size[0], self.scale_size[1]
        # np.zeros rather than np.ndarray: the padding rows of a short final
        # batch must be zeros, not uninitialised memory.
        images = np.zeros((self.batch_size, height, width, 3))
        for i, path in enumerate(paths):
            images[i] = self._load_image(path)

        one_hot_labels = np.zeros((self.batch_size, self.n_class))
        for i, label in enumerate(labels):
            one_hot_labels[i, int(label)] = 1
        return images, one_hot_labels