Skip to content

Commit 8a80ebb

Browse files
script to find a bad datasets
1 parent 439a92d commit 8a80ebb

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

Core/CTestLoader.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,29 @@
11
import tensorflow as tf
22
import numpy as np
33
import os, glob
4+
from functools import lru_cache
45

56
class CTestLoader(tf.keras.utils.Sequence):
67
def __init__(self, testFolder):
8+
self._folder = testFolder
79
self._batchesNpz = [
810
f for f in glob.glob(os.path.join(testFolder, 'test-*.npz'))
911
]
1012
self.on_epoch_end()
1113
return
1214

15+
@property
16+
def folder(self):
17+
return self._folder
18+
19+
@lru_cache(maxsize=1)
20+
def parametersIDs(self):
21+
batch, _ = self[0]
22+
userId = batch['userId'][0, 0, 0]
23+
placeId = batch['placeId'][0, 0, 0]
24+
screenId = batch['screenId'][0, 0, 0]
25+
return userId, placeId, screenId
26+
1327
def on_epoch_end(self):
1428
return
1529

scripts/make-blacklist.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-.
3+
'''
4+
This script performs the following steps:
5+
- Load the best model from the Data folder
6+
- Load the test datasets from the Data/test-main folder
7+
- Evaluate the model on the test datasets
8+
- Add each test dataset to blacklists if the model mean loss is greater than the threshold
9+
'''
10+
# TODO: add the W&B integration
11+
import argparse, os, sys
12+
# add the root folder of the project to the path
13+
ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
14+
sys.path.append(ROOT_FOLDER)
15+
16+
import numpy as np
17+
from Core.CDatasetLoader import CDatasetLoader
18+
from Core.CTestLoader import CTestLoader
19+
from collections import defaultdict
20+
import time
21+
from Core.CModelTrainer import CModelTrainer
22+
import tqdm
23+
import json
24+
import glob
25+
26+
def _eval(dataset, model):
27+
T = time.time()
28+
# evaluate the model on the val dataset
29+
losses = []
30+
predDist = []
31+
for batchId in range(len(dataset)):
32+
batch = dataset[batchId]
33+
loss, _, dist = model.eval(batch)
34+
predDist.append(dist)
35+
losses.append(loss)
36+
continue
37+
38+
loss = np.mean(losses)
39+
dist = np.mean(predDist)
40+
T = time.time() - T
41+
return loss, dist, T
42+
43+
def evaluate(dataset, model):
44+
loss, dist, T = _eval(dataset, model)
45+
print('Test | %.2f sec | Loss: %.5f. Distance: %.5f' % (
46+
T, loss, dist,
47+
))
48+
return loss, dist
49+
50+
def main(args):
51+
timesteps = args.steps
52+
folder = args.folder
53+
stats = None
54+
with open(os.path.join(folder, 'remote', 'stats.json'), 'r') as f:
55+
stats = json.load(f)
56+
57+
model = dict(timesteps=timesteps, stats=stats, use_encoders=False)
58+
assert args.model is not None, 'The model should be specified'
59+
if args.model is not None:
60+
model['weights'] = dict(folder=folder, postfix=args.model, embeddings=True)
61+
62+
model = CModelTrainer(**model)
63+
badDatasets = [] # list of tuples (userId, placeId, screenId) for the blacklisted datasets
64+
# find folders with the name "/test-*/"
65+
for nm in glob.glob(os.path.join(folder, 'test-main', 'test-*/')):
66+
evalDataset = CTestLoader(nm)
67+
loss, dist = evaluate(evalDataset, model)
68+
if args.threshold < dist:
69+
badDatasets.append(evalDataset.parametersIDs())
70+
continue
71+
# convert indices to the strings
72+
res = []
73+
for userId, placeId, screenId in badDatasets:
74+
userId = stats['userId'][userId]
75+
placeId = stats['placeId'][placeId]
76+
screenId = stats['screenId'][screenId]
77+
res.append((userId, placeId, screenId))
78+
continue
79+
print('Blacklisted datasets:')
80+
print(json.dumps(res, indent=2))
81+
# save the blacklisted datasets
82+
with open(os.path.join(folder, 'blacklist.json'), 'w') as f:
83+
json.dump(res, f, indent=2)
84+
return
85+
86+
if __name__ == '__main__':
87+
parser = argparse.ArgumentParser()
88+
parser.add_argument('--steps', type=int, default=5)
89+
parser.add_argument('--model', type=str, default='best')
90+
parser.add_argument('--folder', type=str, default=os.path.join(ROOT_FOLDER, 'Data'))
91+
parser.add_argument(
92+
'--threshold', type=float, required=True,
93+
)
94+
95+
main(parser.parse_args())
96+
pass

0 commit comments

Comments
 (0)