1
+ #!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
+ '''
4
+ This script performs the following steps:
5
+ - Load the best model from the Data folder
6
+ - Load the test datasets from the Data/test-main folder
7
+ - Evaluate the model on the test datasets
8
+ - Add each test dataset to blacklists if the model mean loss is greater than the threshold
9
+ '''
10
+ # TODO: add the W&B integration
11
+ import argparse , os , sys
12
+ # add the root folder of the project to the path
13
+ ROOT_FOLDER = os .path .abspath (os .path .dirname (__file__ ) + '/../' )
14
+ sys .path .append (ROOT_FOLDER )
15
+
16
+ import numpy as np
17
+ from Core .CDatasetLoader import CDatasetLoader
18
+ from Core .CTestLoader import CTestLoader
19
+ from collections import defaultdict
20
+ import time
21
+ from Core .CModelTrainer import CModelTrainer
22
+ import tqdm
23
+ import json
24
+ import glob
25
+
26
+ def _eval (dataset , model ):
27
+ T = time .time ()
28
+ # evaluate the model on the val dataset
29
+ losses = []
30
+ predDist = []
31
+ for batchId in range (len (dataset )):
32
+ batch = dataset [batchId ]
33
+ loss , _ , dist = model .eval (batch )
34
+ predDist .append (dist )
35
+ losses .append (loss )
36
+ continue
37
+
38
+ loss = np .mean (losses )
39
+ dist = np .mean (predDist )
40
+ T = time .time () - T
41
+ return loss , dist , T
42
+
43
def evaluate(dataset, model):
  '''
  Run _eval on the dataset, print a one-line summary and return (loss, dist).
  '''
  loss, dist, elapsed = _eval(dataset, model)
  print('Test | %.2f sec | Loss: %.5f. Distance: %.5f' % (elapsed, loss, dist))
  return loss, dist
49
+
50
def main(args):
  '''
  Evaluate the model on each Data/test-main/test-* dataset and write the
  identifiers of the underperforming ones to Data/blacklist.json.

  A dataset is blacklisted when its mean prediction *distance* exceeds
  args.threshold. NOTE(review): the module docstring says "mean loss" —
  confirm which metric is actually intended.
  '''
  timesteps = args.steps
  folder = args.folder
  # id -> name mappings produced during training; used to decode dataset ids
  with open(os.path.join(folder, 'remote', 'stats.json'), 'r') as f:
    stats = json.load(f)

  # explicit raise instead of assert: assert is stripped under `python -O`
  if args.model is None:
    raise ValueError('The model should be specified')
  model = CModelTrainer(
    timesteps=timesteps, stats=stats, use_encoders=False,
    weights=dict(folder=folder, postfix=args.model, embeddings=True),
  )

  badDatasets = []  # (userId, placeId, screenId) index tuples of blacklisted datasets
  # find folders with the name "/test-*/"
  for nm in glob.glob(os.path.join(folder, 'test-main', 'test-*/')):
    evalDataset = CTestLoader(nm)
    loss, dist = evaluate(evalDataset, model)
    if args.threshold < dist:
      badDatasets.append(evalDataset.parametersIDs())

  # convert the indices back to their string identifiers via stats
  res = [
    (stats['userId'][userId], stats['placeId'][placeId], stats['screenId'][screenId])
    for userId, placeId, screenId in badDatasets
  ]
  print('Blacklisted datasets:')
  print(json.dumps(res, indent=2))
  # save the blacklisted datasets
  with open(os.path.join(folder, 'blacklist.json'), 'w') as f:
    json.dump(res, f, indent=2)
  return
85
+
86
if __name__ == '__main__':
  # CLI entry point: parse arguments and run the evaluation/blacklisting pass.
  parser = argparse.ArgumentParser()
  parser.add_argument('--steps', type=int, default=5)
  parser.add_argument('--model', type=str, default='best')
  parser.add_argument('--folder', type=str, default=os.path.join(ROOT_FOLDER, 'Data'))
  parser.add_argument('--threshold', type=float, required=True)
  main(parser.parse_args())