Skip to content

Commit bfc2416

Browse files
A way to load and evaluate a model from a WandB run.
1 parent bf4ed1e commit bfc2416

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

train.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import tensorflow as tf
66
from NN import model_from_config, model_to_architecture
77
from Utils import dataset_from_config
8+
from Utils.WandBUtils import CWBRun
89

910
def validateLayersNames(model):
1011
not_unique_layers = []
@@ -21,7 +22,13 @@ def validateLayersNames(model):
2122

2223
def main(args):
2324
folder = os.path.dirname(__file__)
24-
config = load_config(args.config, folder=folder)
25+
if args.wandb_id:
26+
run = CWBRun(args.wandb_id)
27+
config = run.config
28+
args.model = run.bestModel.pathTo()
29+
args.no_train = True
30+
else:
31+
config = load_config(args.config, folder=folder)
2532

2633
assert "experiment" in config, "Config must contain 'experiment' key"
2734
# store args as part of config
@@ -138,7 +145,8 @@ def main(args):
138145
parser.add_argument('--wandb', type=str, help='Wandb project name (optional)')
139146
parser.add_argument('--wandb-entity', type=str, help='Wandb entity name (optional)')
140147
parser.add_argument('--wandb-name', type=str, help='Wandb run name (optional)')
141-
148+
parser.add_argument('--wandb-id', type=str, help='Wandb run id, to load and test model (optional)')
149+
142150
args = parser.parse_args()
143151
if args.gpu_memory_mb: setGPUMemoryLimit(args.gpu_memory_mb)
144152
main(args)

0 commit comments

Comments
 (0)