-
Notifications
You must be signed in to change notification settings - Fork 3
/
search.py
35 lines (26 loc) · 816 Bytes
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import json
import os
import sys
from glob import glob
from tasks.utils import *
def find_best_run(result_files, metric):
    """Return the best-scoring run among *result_files* for *metric*.

    Every file is parsed as JSON and must contain the key
    ``"best_eval_" + metric``; the file with the largest value wins.

    Args:
        result_files: Iterable of paths to JSON result files.
        metric: Metric suffix appended to ``"best_eval_"`` to form the
            lookup key (e.g. ``"accuracy"`` or ``"f1"``).

    Returns:
        A ``(score, metrics, path)`` tuple for the winning file, or ``None``
        when *result_files* is empty. On ties the earliest file is kept.

    Raises:
        KeyError: If a result file has no ``best_eval_<metric>`` entry.
    """
    key = "best_eval_" + metric
    best = None
    for path in result_files:
        # Use a context manager so each handle is closed as soon as the
        # file has been parsed instead of lingering until garbage collection.
        with open(path, "r") as handle:
            metrics = json.load(handle)
        score = metrics[key]
        # Seed the comparison with the first file rather than a literal 0,
        # so runs whose scores are all zero or negative are still reported.
        if best is None or score > best[0]:
            best = (score, metrics, path)
    return best


if __name__ == "__main__":
    # Four positional arguments are mandatory; the metric is optional.
    if len(sys.argv) < 5:
        sys.exit(f"usage: {sys.argv[0]} TASK MODEL PARADIGM ISPE [METRIC]")

    TASK = sys.argv[1]
    MODEL = sys.argv[2]
    PARADIGM = sys.argv[3]
    ISPE = sys.argv[4]

    # An explicit fifth argument overrides the per-dataset default metric.
    if len(sys.argv) == 6:
        METRIC = sys.argv[5]
    elif TASK in GLUE_DATASETS + SUPERGLUE_DATASETS + OTHER_DATASETS:
        METRIC = "accuracy"
    elif TASK in NER_DATASETS + SRL_DATASETS + QA_DATASETS:
        METRIC = "f1"
    else:
        # Previously METRIC stayed unbound here and failed later with a
        # NameError; fail early with an actionable message instead.
        sys.exit(
            f"cannot infer a metric for task {TASK!r}; "
            "pass METRIC explicitly as the fifth argument"
        )

    files = glob(f"./checkpoints/{TASK}-{MODEL}-search/{PARADIGM}/{ISPE}/*/best_results.json")
    best = find_best_run(files, METRIC)
    if best is None:
        # No matching result files: nothing to report.
        sys.exit(
            f"no best_results.json found under "
            f"./checkpoints/{TASK}-{MODEL}-search/{PARADIGM}/{ISPE}/"
        )

    best_score, best_metrics, best_file_name = best
    print(f"best_{METRIC}: {best_score}")
    print(f"best_metrics: {best_metrics}")
    print(f"best_file: {best_file_name}")