|
1 |
| -from Utils.utils import setupGPU, load_config, setGPUMemoryLimit |
| 1 | +from Utils.utils import setupGPU, load_config, setGPUMemoryLimit, upgrade_configs_structure |
2 | 2 | setupGPU() # call it on startup to prevent OOM errors on my machine
|
3 | 3 |
|
4 | 4 | import argparse, os, shutil, json
|
5 | 5 | import tensorflow as tf
|
6 | 6 | from NN import model_from_config, model_to_architecture
|
7 | 7 | from Utils import dataset_from_config
|
8 | 8 |
|
| 9 | +def validateLayersNames(model): |
| 10 | + not_unique_layers = [] |
| 11 | + layers_names = set() |
| 12 | + for layer in model.trainable_variables: |
| 13 | + if layer.name in layers_names: |
| 14 | + not_unique_layers.append(layer.name) |
| 15 | + layers_names.add(layer.name) |
| 16 | + continue |
| 17 | + for layer in not_unique_layers: |
| 18 | + print(f"Layer name '{layer}' is not unique") |
| 19 | + assert not not_unique_layers, "Model contains not unique layers names" |
| 20 | + return |
| 21 | + |
9 | 22 | def main(args):
|
10 | 23 | folder = os.path.dirname(__file__)
|
11 | 24 | config = load_config(args.config, folder=folder)
|
| 25 | + |
12 | 26 | assert "experiment" in config, "Config must contain 'experiment' key"
|
13 | 27 | # store args as part of config
|
14 | 28 | config['experiment']['command line arguments'] = vars(args)
|
@@ -37,17 +51,7 @@ def main(args):
|
37 | 51 | # Create model
|
38 | 52 | model = model_from_config(config["model"], compile=True)
|
39 | 53 | model.summary(expand_nested=True)
|
40 |
| - # check if model is contain only unique layers names |
41 |
| - not_unique_layers = [] |
42 |
| - layers_names = set() |
43 |
| - for layer in model.trainable_variables: |
44 |
| - if layer.name in layers_names: |
45 |
| - not_unique_layers.append(layer.name) |
46 |
| - layers_names.add(layer.name) |
47 |
| - continue |
48 |
| - for layer in not_unique_layers: |
49 |
| - print(f"Layer name '{layer}' is not unique") |
50 |
| - assert not not_unique_layers, "Model contains not unique layers names" |
| 54 | + validateLayersNames(model) |
51 | 55 | # save to config model architecture and number of parameters
|
52 | 56 | config['architecture'] = model_to_architecture(model)
|
53 | 57 |
|
@@ -82,10 +86,13 @@ def main(args):
|
82 | 86 | ),
|
83 | 87 | tf.keras.callbacks.TerminateOnNaN(),
|
84 | 88 | ]
|
85 |
| - |
86 |
| - if args.wandb: # init wandb |
| 89 | + |
| 90 | + if args.wandb: |
87 | 91 | import wandb
|
| 92 | + |
88 | 93 | wandb.init(project=args.wandb, entity=args.wandb_entity, config=config)
|
| 94 | + # assign run name if specified |
| 95 | + if args.wandb_name: wandb.run.name = args.wandb_name |
89 | 96 | # track model metrics only
|
90 | 97 | callbacks.append(wandb.keras.WandbCallback(
|
91 | 98 | save_model=False, # save model to wandb manually
|
@@ -130,6 +137,7 @@ def main(args):
|
130 | 137 |
|
131 | 138 | parser.add_argument('--wandb', type=str, help='Wandb project name (optional)')
|
132 | 139 | parser.add_argument('--wandb-entity', type=str, help='Wandb entity name (optional)')
|
| 140 | + parser.add_argument('--wandb-name', type=str, help='Wandb run name (optional)') |
133 | 141 |
|
134 | 142 | args = parser.parse_args()
|
135 | 143 | if args.gpu_memory_mb: setGPUMemoryLimit(args.gpu_memory_mb)
|
|
0 commit comments