-
Notifications
You must be signed in to change notification settings - Fork 3
/
openset_by_fold_rand.py
124 lines (110 loc) · 4.03 KB
/
openset_by_fold_rand.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import dirhandler as dh
import traintestsplit as ttsplit
import json
import os
import random
# Largest tolerated ratio of one fold's training directories that are shared
# with another fold's training set; gen_splits() regenerates all folds while
# the worst fold is above this value.
COMMON_THRESH = 0.6
# Number of folds generated; also passed as the divisor that sizes each
# fold's validation split.
FOLDS = 5
def write_dicts(splits, fold):
    """Write each split of one fold to ``openset_splits/<fold>/<split>.json``.

    Args:
        splits: Mapping of split name (e.g. ``'train'``) to a JSON-serialisable
            value; every entry is written to its own file.
        fold: Fold identifier, used as the sub-directory name.
    """
    filedir = os.path.join('openset_splits', str(fold))
    # makedirs also creates the parent 'openset_splits' directory (os.mkdir
    # raised FileNotFoundError when it was missing); exist_ok keeps re-runs
    # harmless.
    os.makedirs(filedir, exist_ok=True)
    for key, value in splits.items():
        filename = os.path.join(filedir, key + '.json')
        # The with-block closes the file; no explicit close() is needed.
        with open(filename, 'w') as f:
            json.dump(value, f)
        print(filename + " written")
def gen_dict_from_list(lst):
    """Map every directory in ``lst`` to what ``dh.get_photos_in_dir`` returns for it.

    Keys keep the order of ``lst``. A directory listed twice is looked up
    twice, and the later result is kept.
    """
    return {photo_dir: dh.get_photos_in_dir(photo_dir) for photo_dir in lst}
def gen_validation(src, folds):
    """Move a random ``1/folds`` share of each key's photos into a validation set.

    ``src`` is modified in place: selected photos are popped from
    ``src[key]`` so they no longer appear in training.

    Args:
        src: Mapping of directory -> list of photos (the training split).
        folds: Divisor; ``len(src[key]) // folds`` photos are moved per key.

    Returns:
        Mapping of directory -> list of the photos moved out of ``src``.
        Keys for which nothing was moved (fewer than ``folds`` photos) are
        omitted.

    Raises:
        ZeroDivisionError: if ``folds`` is 0.
    """
    valdict = {}
    for key, photos in src.items():
        n_remove = len(photos) // folds
        for _ in range(n_remove):
            idx = random.randint(0, len(photos) - 1)
            # Accumulate every removed photo. Assigning a fresh one-element
            # list here would silently drop all but the last popped photo
            # from both training and validation.
            valdict.setdefault(key, []).append(photos.pop(idx))
    return valdict
def get_random_dirs(photodirs, n_training_dirs):
    """Draw ``n_training_dirs`` entries from ``photodirs`` at random, without replacement.

    Chosen entries are popped from ``photodirs`` (it is mutated in place),
    so afterwards it holds only the entries that were not picked. Asking
    for more entries than ``photodirs`` holds ends in the ``ValueError``
    that ``random.randint`` raises for an empty range.
    """
    chosen = []
    while n_training_dirs > len(chosen):
        pick = random.randint(0, len(photodirs) - 1)
        chosen += [photodirs.pop(pick)]
    return chosen
def gen_train_list(photodirs, divisor):
    """Randomly select ``len(photodirs) // divisor`` directories for training.

    The selected directories are popped out of ``photodirs`` by
    ``get_random_dirs``, so afterwards ``photodirs`` holds only the
    unselected ones.

    Returns:
        List of the selected directories.
    """
    n_training_dirs = len(photodirs) // divisor
    return get_random_dirs(photodirs, n_training_dirs)
def gen_train_test():
    """Build one random open-set train/test split of the photo directories.

    Half (rounded down) of the directories returned by
    ``dh.get_photo_dirs(..., exclude=5)`` are drawn at random for training.
    Every other directory from ``dh.get_photo_dirs(..., exclude=1)``, plus
    any undrawn ``exclude=5`` directory, goes to test. Each test directory is
    listed once.

    Returns:
        ``(traindict, testdict)``: each maps a directory to the photos
        ``dh.get_photos_in_dir`` lists for it.
    """
    photodirs = dh.get_photo_dirs(path='final_dataset/processed', exclude=1)
    photodirs_eligable = dh.get_photo_dirs('final_dataset/processed', exclude=5)
    # NOTE(review): directories absent from the exclude=5 list were described
    # as having "not enough seals" and end up in the open (test) set;
    # presumably ``exclude`` is a minimum photo count - confirm in dirhandler.
    trainlist = gen_train_list(photodirs_eligable, 2)  # pops picks out of photodirs_eligable
    # Single ordered pass with set lookups: skips training directories and
    # duplicates so each test directory is scanned only once.
    excluded = set(trainlist)
    testlist = []
    for photodir in list(photodirs) + list(photodirs_eligable):
        if photodir not in excluded:
            excluded.add(photodir)
            testlist.append(photodir)
    # NOTE(review): a probe/gallery split of the test set was once sketched
    # here (testprobe.extend(gen_validation(testgal))) but is not implemented.
    return gen_dict_from_list(trainlist), gen_dict_from_list(testlist)
def validate_split(traindict, testdict, valdict):
    """Exit the program if any validation photo is still in the training split.

    Args:
        traindict: Mapping of directory -> photos used for training.
        testdict: Accepted for interface compatibility; not inspected.
        valdict: Mapping of directory -> photos held out for validation;
            every key must also be present in ``traindict``.

    Prints an error and exits with status 1 on the first overlap found.
    """
    import sys

    for key, val_photos in valdict.items():
        train_photos = traindict[key]
        # Check every held-out photo, not just the first one: a validation
        # list may contain several photos.
        if any(photo in train_photos for photo in val_photos):
            print("Error: Photo in val set also in training")
            # sys.exit works even when the interactive-only exit() builtin
            # is unavailable (e.g. python -S).
            sys.exit(1)
def get_intersections(folds_list):
    """Count, for each fold, the training keys it shares with every other fold.

    Returns one list per fold. Entry ``i`` holds, for each other fold ``j``
    in order (``j != i``), the size of the intersection of
    ``folds_list[i]['train']`` and ``folds_list[j]['train']``.
    """
    return [
        [
            len(set(fold_i['train']).intersection(fold_j['train']))
            for j, fold_j in enumerate(folds_list)
            if j != i
        ]
        for i, fold_i in enumerate(folds_list)
    ]
def validate_folds(folds_list):
    """Print and return the worst training-directory overlap between folds.

    For fold ``i`` the ratio is the largest count ``get_intersections``
    reports for it, divided by ``len(folds_list[i]['train'])``. Each fold's
    ratio is printed, and the largest ratio over all folds is returned.
    A fold with no other fold to compare against, or with an empty training
    set, counts as ratio 0.
    """
    common = get_intersections(folds_list)
    maximum = 0
    # enumerate yields the real fold index. common.index(...) returned the
    # first fold with an *equal* count list, which mislabelled folds and
    # divided by the wrong training-set size.
    for idx, counts in enumerate(common):
        most_shared = max(counts, default=0)
        n_train = len(folds_list[idx]['train'])
        foldmax = most_shared / n_train if n_train else 0
        if maximum < foldmax:
            maximum = foldmax
        print("Fold-" + str(idx+1) + " max common train dirs with all folds: " + str(most_shared) + "/" + str(n_train) + " = " + str(foldmax))
    return maximum
def gen_openset(nfold, method="kfold"):
    """Generate ``nfold`` randomly drawn open-set folds.

    Each fold is a dict with ``'train'``, ``'validation'`` and ``'test'``
    entries. Train and test come from a fresh ``gen_train_test()`` call;
    ``'validation'`` is what ``gen_validation(train, nfold)`` returns, and
    that call also modifies the training dict in place.

    ``method`` is accepted but currently unused.
    """
    folds = []
    for _ in range(nfold):
        train, test = gen_train_test()
        folds.append({
            'train': train,
            'validation': gen_validation(train, nfold),
            'test': test,
        })
    return folds
def gen_splits():
    """Generate FOLDS folds, retrying until their overlap is acceptable, then save them.

    Whole fold sets are regenerated while the value returned by
    ``validate_folds`` exceeds ``COMMON_THRESH``. The accepted folds are
    written with ``write_dicts`` and numbered from 1.
    NOTE(review): there is no retry cap, so an unreachable threshold loops
    forever.
    """
    while True:
        fold_list = gen_openset(FOLDS)
        maximum = validate_folds(fold_list)
        if not maximum > COMMON_THRESH:
            break
        print("Max: {} over alloted % {}, rerunning...".format(maximum, COMMON_THRESH))
    for fold_number, splits in enumerate(fold_list, start=1):
        write_dicts(splits, fold_number)
def count_photos():
    """Print the total number of photos across all directories.

    Directories come from ``dh.get_photo_dirs('final_dataset/processed',
    exclude=1)``, and each one is counted with ``dh.get_photos_in_dir``.
    """
    photodirs = dh.get_photo_dirs('final_dataset/processed', exclude=1)
    # Use the builtin sum() instead of shadowing it with a local accumulator.
    total = sum(len(dh.get_photos_in_dir(photodir)) for photodir in photodirs)
    print(total)
# Entry point: generate the folds and write every split to openset_splits/.
# NOTE(review): this call also runs whenever the module is imported, and
# write_dicts opens the files with 'w', overwriting existing split files.
# Consider guarding it with `if __name__ == "__main__":`.
gen_splits()