-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathfind_best_ckpt.py
executable file
·54 lines (41 loc) · 1.12 KB
/
find_best_ckpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import re
from pathlib import Path
def get_score(s: str) -> float:
    """Extract the criterion score embedded in a checkpoint name.

    Parameters
    ----------
    s : str
        Checkpoint file name or path. The score is the decimal number
        (``<digits>.<digits>``) that sits directly in front of a literal
        ``.ckpt``; when several such occurrences exist, the last one wins.

    Returns
    -------
    float
        The criterion score parsed from ``s``.

    Raises
    ------
    ValueError
        If ``s`` contains no ``<digits>.<digits>.ckpt`` sequence.
    """
    # Dots are escaped on purpose: an unescaped "." matches any character,
    # which lets names such as "top_1_5.ckpt" parse as a bogus score.
    matches = re.findall(r"(\d+\.\d+)\.ckpt", s)
    if not matches:
        raise ValueError(f"no '<float>.ckpt' score found in {s!r}")
    # Take the last occurrence so that a look-alike earlier in the path
    # (e.g. a directory name) cannot shadow the score in the file name.
    return float(matches[-1])
def main(path: str, op: str) -> "str | None":
    """Find the best checkpoint in a directory and print its path.

    Parameters
    ----------
    path : str
        Directory that is searched (non-recursively) for ``*.ckpt`` files.
    op : str
        ``"max"`` selects the highest score (for mAP for example); any
        other value selects the lowest score (for loss).

    Returns
    -------
    str or None
        Path of the best checkpoint, or ``None`` when the directory holds
        no checkpoint with a parsable score. Nothing is printed in that
        case.
    """
    scores = {}
    # Sorted so that ties are resolved the same way on every run; the
    # order produced by glob itself is not specified.
    for ckpt in sorted(Path(path).glob("*.ckpt")):
        try:
            # Parse the file name only, so digits in parent directory
            # names cannot be mistaken for the score.
            scores[str(ckpt)] = get_score(ckpt.name)
        except (IndexError, ValueError):
            # Checkpoints without a score in their name (e.g. "last.ckpt")
            # are not candidates; skip them instead of aborting.
            continue
    if not scores:
        return None
    select = max if op == "max" else min
    best = select(scores, key=scores.get)
    # The path is consumed by the calling shell script via stdout.
    print(best, flush=True)
    return best
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 3:
        # sys.exit with a string writes it to stderr and exits with
        # status 1, so a shell capturing stdout never mistakes the usage
        # message for a checkpoint path.
        sys.exit("provide checkpoint path and op either max or min")
    main(sys.argv[1], sys.argv[2])