-
Notifications
You must be signed in to change notification settings - Fork 0
/
generators.py
78 lines (64 loc) · 2.38 KB
/
generators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import random
import config as C
def mk_triplets(directory):
    """Yield random (anchor, positive, negative) image triplets forever.

    ``directory`` must contain one sub-directory per class, each holding that
    class's image files. Non-directory entries in ``directory`` and
    non-file entries inside a class folder are ignored. Class indices are
    assigned in sorted order of the sub-directory names, so they are stable
    across runs and platforms (``os.listdir`` order is arbitrary).

    Each yielded item is ``(pos_class, neg_class, anchor, pos, neg)``:
    ``pos_class`` and ``neg_class`` are two different class indices;
    ``anchor`` and ``pos`` are file paths from ``pos_class`` and are distinct
    files whenever that class has at least two images; ``neg`` is a file path
    from ``neg_class``.

    Raises:
        ValueError: (on the first ``next()``) if ``directory`` has fewer than
            two class sub-directories, so no negative class can be chosen.
        IndexError: if a selected class folder contains no files.
    """
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    if len(classes) < 2:
        raise ValueError(
            "mk_triplets needs at least two class sub-directories in %r, "
            "found %d" % (directory, len(classes)))
    # Sorted so a seeded `random` gives reproducible selections.
    images = [
        sorted(f for f in os.listdir(os.path.join(directory, c))
               if os.path.isfile(os.path.join(directory, c, f)))
        for c in classes
    ]
    n_classes = len(classes)
    while True:
        # Pick a random positive class.
        pos_class = random.randrange(n_classes)
        # Pick a different negative class uniformly: draw from the other
        # n_classes - 1 indices and shift past pos_class.
        neg_class = random.randrange(n_classes - 1)
        if neg_class >= pos_class:
            neg_class += 1
        pos_dir = os.path.join(directory, classes[pos_class])
        neg_dir = os.path.join(directory, classes[neg_class])
        pos_images = images[pos_class]
        # Anchor and positive should be two *different* images when possible;
        # an identical pair has zero anchor-positive distance and carries no
        # training signal.
        if len(pos_images) >= 2:
            anchor_name, pos_name = random.sample(pos_images, 2)
        else:
            anchor_name = pos_name = random.choice(pos_images)
        anchor = os.path.join(pos_dir, anchor_name)
        pos = os.path.join(pos_dir, pos_name)
        neg = os.path.join(neg_dir, random.choice(images[neg_class]))
        yield (pos_class, neg_class, anchor, pos, neg)
from PIL import Image
import numpy as np
# Scale to image size, paste on white background
def paste(img):
i = np.ones((299,299,3))
# NB: Mono images lack the third dimension and will fail here:
# (x,y,z) = img.shape
(x,y) = img.shape
start_x = int((299-x)/2)
end_x = start_x + x
start_y = int((299-y)/2)
end_y = start_y + y
i[start_x:end_x,start_y:end_y,0] = img
return i
def _load_image(path):
    """Read the image at ``path`` into a float array divided by 255.

    The file is opened in a ``with`` block so its handle is released as soon
    as the pixels have been copied into the array.
    """
    with Image.open(path) as im:
        return np.array(im) / 255


def triplet_generator(batch_size, cache_size, directory):
    """Yield batches of image triplets for triplet-loss training, forever.

    Args:
        batch_size: number of triplets per batch.
        cache_size: currently unused; kept for interface compatibility.
        directory: root folder with one sub-directory per class, passed to
            ``mk_triplets``.

    Yields:
        ``([a, p, n], y)``: ``a``, ``p`` and ``n`` stack the anchor, positive
        and negative images (each centred on a canvas by ``paste``) along a
        new leading axis. ``y`` holds one ``(pos_class, neg_class)`` pair per
        triplet, so it has shape ``(batch_size, 2)`` when ``batch_size > 0``.
    """
    trips = mk_triplets(directory)
    while True:
        ys, ans, pss, ngs = [], [], [], []
        for _ in range(batch_size):
            pc, nc, anc, pos, neg = next(trips)
            ys.append((pc, nc))
            # TODO: augmentation
            ans.append(paste(_load_image(anc)))
            pss.append(paste(_load_image(pos)))
            ngs.append(paste(_load_image(neg)))
        yield [np.asarray(ans), np.asarray(pss), np.asarray(ngs)], np.asarray(ys)
if __name__ == "__main__":
    # Smoke test: pull a few batches from the training directory and print
    # their shapes. Guarded so that importing this module (e.g. from a
    # training script) does not read the dataset or print anything.
    print("### Testing triplet_generator ###")
    g = triplet_generator(4, None, C.train_dir)
    for x in range(0, 4):
        [a, p, n], y = next(g)
        print(x, "a:", a.shape, "p:", p.shape, "n:", n.shape, "y:", y.shape)