-
Notifications
You must be signed in to change notification settings - Fork 0
/
ndnet.py
2229 lines (2000 loc) · 99.7 KB
/
ndnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 23 14:17:11 2018
@author: soenke
adapted from
https://github.com/zhengyang-wang/Unet_3D/tree/master/model
and inspired by
https://github.com/aicodes/tf-bestpractice/blob/master/README.md
"""
import numpy as np
import tensorflow as tf
import os
import json
import csv
from warnings import warn
import datetime
from math import ceil
from itertools import permutations
# own modules
import dataset_handlers as dh
import loss_functions as lf
import network_architectures as na
import training_utils
from training_utils import experimental_model_params
from decorators import deprecated
# for tensorflow 1.14 use this to avoid some warnings:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
#from tf_toolbox.tf_toolbox import _ftconvolve # -> needed for unsupervised
# IMPORTANT:
# if you get an error with mkl while running the test on cpu,
# switch to the eigen builds of tensorflow
#IDEAS:
## use explicit graph
# self.graph = tf.graph()
# with tf.Graph().as_default() as g (but graph is included in sess right?)
# --> need to learn more about tf internals
## build inference during init or using a build_network method
# --> would probably be cleaner
## configure session this way
# gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
# config = tf.ConfigProto(gpu_options=gpu_options)
# sess = tf.Session(config) # or similar
# /// or sth like
# session_config.gpu_options.visible_device_list= '1' #only see the gpu 1
# config = tf.ConfigProto()
# config.gpu_options.allow_growth = True
# session = tf.Session(config=config, ...)
# config=tf.ConfigProto(log_device_placement=True)
# --> currently needs all memory anyways...
## network architecture:
# allow circular/reflection padding!
# change conv size in activ3 in unet3d.build_final_blocks to (1,1,1)
# --> test other settings first to see if this is necessary
# use fully connected layer in the end (?)
## adaption to input:
# allow different network_depths in different directions
# -->
## pre- and post-processing:
# should I omit post-processing?
# should I scale output from postprocessing to 0..255 (?)
# --> experiment
## profile where work is done in net
# https://www.tensorflow.org/guide/graph_viz
# run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
# run_metadata = tf.RunMetadata()
# summary, _ = sess.run([merged, train_step],
# feed_dict=feed_dict(True),
# options=run_options,
# run_metadata=run_metadata)
# train_writer.add_run_metadata(run_metadata, 'step%d' % i)
# train_writer.add_summary(summary, i)
# print('Adding run metadata for', i)
# --> interesting, but not so important
#
# histogram of activations
# https://www.tensorflow.org/guide/tensorboard_histograms
# --> would be nice
#
# debugging
# https://www.tensorflow.org/guide/debugger
# --> if necessary
#
# possibly omit _parse_load_step in favor of using sth like this:
# https://stackoverflow.com/questions/45077445/ ...
# ... how-to-use-method-recover-last-checkpoints-of-tf-train-saver
# max number of parameters is about 10 million, but it depends
# at which sampling they occur
# possibly document attributes here
class NDNet(object):
    def __init__(
            self, sess,
            arch="unetv3", padding="same", net_channel_growth=1,
            force_pos=False, normalize_input=False,
            use_batch_renorm=False, use_batch_norm=False,
            last_layer_batch_norm=None,
            custom_preprocessing_fn=None, custom_postprocessing_fn=None,
            dataset_means=None, dataset_stds=None,
            data_format="channels_last", comment=None):
        """
        NDNet provides a framework for training, running and testing Neural
        Networks for image-to-image translation (any mapping from image scale
        to image scale including deconvolution, denoising or semantic
        segmentation).

        Training and testing data consists of pairs of images (x, y),
        where x and y must have the same spatial shape. The number of input
        and output channels may vary according to net_channel_growth, which
        can be a fraction.

        It supports 2d or 3d network architectures, and likewise
        such input images. It manages all attributes that are common to
        training, running and testing including:
        - a tf-session (note: might be removed from it in the future)
        - network architecture, pre- and postprocessing
        - details
        This class is designed to standardize tasks such as data loading via
        tf-data, run inference and calculate loss.

        Args:
            sess: a tf-session
            # net features
            arch (str): defines network-architecture. Currently only "unetv3"
                is officially supported. There is also a number of
                experimental models (such as "_unetv3_2d") in training_utils.
                Note that the dimensionality of the network is encoded in
                the arch-string.
                Default: "unetv3" (a 3d model)
            padding (str): padding of all convolutions. Can be "same" or
                "valid". Default: "same"
            net_channel_growth (float) : EXPERIMENTAL: change this, if you
                want to have a different number of in- and output channels.
                In initial tests, it was possible to set this even to float
                values such as 2/3 (from 3 channel input to 2 channel output).
                But I am actually surprised that worked, so use this with care.
            force_pos (bool) : add positivity constraint by squaring
                network output. Default: False
            normalize_input (bool) : Do preprocessing using mean subtraction
                with dataset_means and std normalization using dataset_stds.
                Default: False
            use_batch_renorm (bool) : Add batch normalization layers and
                use batch renormalization scheme as suggested in
                arXiv:1702.03275v2. Default: False
            use_batch_norm (bool) : Add batch normalization layers and perform
                standard batch normalization. Default: False
            # details
            last_layer_batch_norm (None or bool) : perform batch_norm also in
                last layer. None means True if use_batch_(re)normalization
                else False. Default: None
            custom_preprocessing_fn (None or function) : a function of a
                single image from the dataset. It must have one argument
                (the input tensor) and return one output tensor.
                Default: None
            custom_postprocessing_fn (None or function) : a function of a
                single image from the dataset. It must have one argument
                (the input tensor) and return one output tensor.
                Default: None
            dataset_means (None or float) : used for preprocessing.
                Default: None
            dataset_stds (None or float) : used for preprocessing.
                Default: None
            data_format (str) : data format of arrays in tf-graph. Used by
                tf.layers module. Can be "channels_last" or "channels_first"
                (TODO: more explicit 2D/3D would be to specify ndhwc etc.)
            comment (None or str) : will be added to model_id.
                Default: None

        Note:
            Preprocessing is applied in the order::
                force_pos -> custom_preprocessing -> normalize_input
            Postprocessing undoes this in the order::
                undo_normalization -> custom_postprocessing -> force_pos
        """
        # Note:
        # The network operates on minibatches of data that have shape
        # (N, (D,) H, W, C), i.e. consisting of N images, each with (depth D,)
        # height H and width W and with C input channels.
        self.sess = sess
        self.arch = arch  # store more info in this
        self.force_pos = force_pos
        # it could be useful to have self.data_format, because in principle
        # pre- and postprocessing could have different data_format than
        # the model.  But for now this will not be allowed.
        # self.data_format = data_format
        self.normalize_input = normalize_input
        if self.normalize_input:
            # Fill in a sensible default when only one of means/stds is given;
            # refuse normalization when neither is provided.
            if dataset_means is None and dataset_stds is None:
                raise ValueError(
                    "normalize_input is True, but neither dataset_means " +
                    "nor dataset_stds are provided.")
            elif dataset_means is None:
                print("Only dataset_stds was given. Setting dataset_means " +
                      "to default value 0.")
                dataset_means = 0
            elif dataset_stds is None:
                print("Only dataset_means was given. Setting dataset_stds " +
                      "to default value 1.")
                dataset_stds = 1
            # mean 0 / std 1 normalization would be a no-op, so disable it.
            if dataset_means == 0 and dataset_stds == 1:
                print("skipping preprocessing, because dataset_means is " +
                      "already 0 and dataset_stds is already 1.")
                self.normalize_input = False
        self.dataset_means = dataset_means
        self.dataset_stds = dataset_stds
        # TODO (EXPERIMENTAL -> TEST)
        # TODO: possibly do shape corrections before applying
        #       custom_postprocessing_fn
        self.custom_preprocessing_fn = custom_preprocessing_fn
        self.custom_postprocessing_fn = custom_postprocessing_fn
        # define model: translate the arch string into the hyperparameters
        # of the Unet_v3 constructor below.
        if arch == "unetv3":
            network_depth=3
            initial_channel_growth=32
            channel_growth=2
            conv_size=(3,3,3)
            pool_size=(2,2,2)
            input_output_skip=False
            nonlinearity=tf.nn.relu  # must be function of a tensor or None
        elif arch == "unetv3_small":
            network_depth=2
            initial_channel_growth=2
            channel_growth=2
            conv_size=(3,3,3)
            pool_size=(2,2,2)
            input_output_skip=False
            nonlinearity=tf.nn.relu  # must be function of a tensor or None
        elif arch in experimental_model_params.keys():
            # experimental archs are looked up in training_utils
            d = experimental_model_params[arch]
            network_depth = d["network_depth"]
            initial_channel_growth = d["initial_channel_growth"]
            channel_growth = d["channel_growth"]
            conv_size = d["conv_size"]
            pool_size = d["pool_size"]
            input_output_skip = d["input_output_skip"]
            nonlinearity = d["nonlinearity"]
        else:
            raise ValueError(
                "Unsupported arch '" + arch + "'. " +
                "Currently the models 'unetv3' and 'unetv3_small' are " +
                "officially supported. Experimental models are " +
                str(list(experimental_model_params.keys())) + ".")
        self.model = na.unet.Unet_v3(
            padding=padding,
            nonlinearity=nonlinearity,
            network_depth=network_depth,
            net_channel_growth=net_channel_growth,  # experimental
            initial_channel_growth=initial_channel_growth,
            channel_growth=channel_growth,
            conv_size=conv_size,
            pool_size=pool_size,
            use_batch_renorm=use_batch_renorm,
            use_batch_norm=use_batch_norm,
            last_layer_batch_norm=last_layer_batch_norm,
            data_format=data_format,
            input_output_skip=input_output_skip)
        # saving checkpoints in
        # self.modeldir/model_id/dataset_id/run_id/run_0/"ckpts"/run_0".ckpt")
        # and logs in
        # self.modeldir/model_id/dataset_id/run_id/run_0/"logs"/xxx".logs")
        # ---> just delete the folder containing model to delete both
        self.model_id = self._set_model_id(comment=comment)
        # TODO: possible os.path.abspath("./models") to fix windows-problem
        self.modeldir = "models"
## High level control
def train(self, training_dataset_handler, n_epochs, batch_size,
optimizer_fn=lambda lr : tf.train.AdamOptimizer(lr),
learning_rate_fn=lambda training_step: 1e-3,
loss_fn=tf.losses.mean_squared_error, cut_loss_to_valid=False,
weight_reg_str=None, weight_reg_fn=None, data_reg_str=None,
data_reg_fn=None,
ckpt=None, load_step=None, random_seed=None,
weight_init=None,
batch_renorm_fn=None, dropout_rate=0.0,
validate=False, validation_dataset_handler=None,
summary_interval=None, save_interval=None,
comment=None):
"""
Configure network for training,
load training data
and run optimizer loop on vars to minimize loss.
Args:
training_dataset_handler (dataset_handler) : dataset_handler from
dataset_handlers.tf_data_dataset_handlers. These provide a thin
layer around tf.data datasets. If you have an existing tf.data
dataset, using BaseDatasetHandler is sufficient. The module also
provides Handlers that can be initialized from numpy-arrays
or lists of files.
n_epochs (int) : number of training epochs. An epoch is completed,
when the network has seen all images once.
batch_size (int) : training batch size.
.
# optimization params
optimizer_fn (function) : optimizer_fn must be a function
taking one parameter (learning rate) and return an optimizer
operation. An optimizer can be transformed to an optimizer_fn
as simple as:
optimizer_fn = lambda lr: tf.train.AdamOptimizer(lr)
This is also the default.
learning_rate_fn (function) : learning_rate_fn must be a function
taking one parameter (global_step) and return a
float/tf.constant/tf.Variable. The function makes it easier to
define learning rate decay.
Default: A constant learning rate of 1e-3:
learning_rate_fn = lambda training_step: 1e-3
# loss function
loss_fn (function) : Loss_fn must be a function taking two parameters
(labels, predictions) and return a loss-tensor. Common losses can
be found in tf.losses and in this repo's loss_functions (lf)
Default: tf.losses.mean_squared_error
cut_loss_to_valid (bool) : This can be used for 'same' padding to
calculate the loss only on that part of the image that is not
impacted by padding. This has no impact, if padding='valid'
Default: False
weight_reg_str (float) : float to scale the strength of weight
regularization (similar to weight decay)
weight_reg_fn (function) : a function of a single weight, typically
the square or absolute value. All values of reg_fn(weight) are
added and multiplied by weight_reg_str and then added to the loss.
data_reg_str (float) : float to scale the strength of data
regularization.
data_reg_fn (function) : a function of a single image returning a loss
tensor. Its value is multiplied by data_reg_str and then added
to the loss. NOTE that this currently only works correctly for
a batch size of 1.
# model loading or init
ckpt (None or str) : path to checkpoint that will be loaded to
initialize training. load_step should not be included in ckpt,
but should be provided separately. If None, network will be
randomly initialized depending on weight_init and random_seed
load_step (None, "previous" or int) : step that will be loaded from ckpt.
"previous" is converted to the last ckpt that was written
to the folder containing the ckpt.
Cannot be None, if ckpt is not None.
random_seed (None or int) : The random seed is passed to all random
number generators, e.g. to shuffling of dataset and initialization
of weights. If left as "None", the shuffling or initialization
cannot be deterministically repeated.
NOTE: This should be changed when loading from ckpt. Otherwise
the exact same sequence will be returned again.
weight_init (None or initializer) : if no ckpt is provided, an
initializer (eg. from tf.initializers or from this repo's
training_utils must be provided. Recommendation for ReLU activation:
training_utils.he_init
# training specific features
batch_renorm_fn (None or function) : batch_renorm_fn must be a function
taking one parameter (global_step) and return rmin, rmax, dmax
values for clipping as described in the batch renormalization paper.
See default_batch_renorm_scheme as an example, which imitates the
suggested scheme from the paper. If self.use_batch_renorm is
False, this has no impact.
NOTE: I do not recommend using batch_renorm any more and this
may disappear in the future.
dropout_rate (float) : fraction of the activations that are dropped
out.
validate (bool) : decides whether or not to put mean loss on
validation set in summary. If True, validation_dataset_handler
must also be provided.
validation_dataset_handler (None or dataset_handler) : See
training_dataset_handler for details about dataset_handler. The
loss for every image in the validation_dataset will be calculated
sequentially and then the mean is calculated.
NOTE that this will slow down training if done frequently and on a
large set.
summary_interval (int or None) : number of steps after which a log
is written. Default (None): 2 logs per epoch
save_interval (int or None) : number of steps after which a ckpt is
written. Default (None): save ckpt after every 2 epochs
comment (None or str) : comment can be provided to modify run_id to
label runs
Returns:
None. But updates variables of model and saves ckpts.
Raises:
TODO
"""
(dataset, training, total_loss, saver, new_ckpt, writer, summary,
validation_dataset, val_loss_single, val_loss_ph) = self._train_init(
training_dataset_handler, n_epochs, batch_size,
# model loading or init
ckpt=ckpt, load_step=load_step, random_seed=random_seed,
weight_init=weight_init,
# training specific args
loss_fn=loss_fn, cut_loss_to_valid=cut_loss_to_valid,
weight_reg_str=weight_reg_str, weight_reg_fn=weight_reg_fn,
data_reg_str=data_reg_str, data_reg_fn=data_reg_fn,
batch_renorm_fn=batch_renorm_fn,
dropout_rate=dropout_rate, validate=validate,
validation_dataset_handler=validation_dataset_handler,
optimizer_fn=optimizer_fn, learning_rate_fn=learning_rate_fn,
# added to run_id
comment=comment)
self._train_loop(
n_epochs, batch_size, dataset, total_loss, training, saver,
new_ckpt, writer, summary, validation_dataset, val_loss_single,
val_loss_ph, summary_interval=summary_interval,
save_interval=save_interval)
# TODO: make trainer object
    def _train_init(
            self, training_dataset_handler, n_epochs, batch_size,
            # model loading or init
            ckpt=None, load_step=None, random_seed=None,
            weight_init=None,
            # training specific args
            loss_fn=tf.losses.mean_squared_error, cut_loss_to_valid=False,
            weight_reg_str=None, weight_reg_fn=None, data_reg_str=None,
            data_reg_fn=None, batch_renorm_fn=None, dropout_rate=0.0,
            validate=False, validation_dataset_handler=None,
            optimizer_fn=lambda lr : tf.train.AdamOptimizer(lr),
            learning_rate_fn=lambda training_step: 1e-3,
            # appended to run_id
            comment=None):
        """
        Build everything train() needs before the loop starts: the input
        pipeline, the forward pass, the losses, the (optional) validation
        branch, the training op, the ckpt/log directory structure, the
        saver, the summaries and the variable initialization (either from
        ckpt or random init).

        Arguments are documented in train().

        Returns:
            Tuple (dataset, training, total_loss, saver, new_ckpt, writer,
            summary, validation_dataset, val_loss_single, val_loss_ph) as
            consumed positionally by _train_loop. The last three are None
            when validate is False.
        """
        # TODO: kwarg: add summary (and which kinds)
        is_training = True
        tf.set_random_seed(random_seed)  # for tf-random generators in the graph
        # get infos from ckpt ("previous" is resolved to a concrete step)
        load_step = _parse_load_step(ckpt, load_step)
        if ckpt:
            if weight_init:
                warn("weight_init is " + str(weight_init) + ". It will be " +
                     "ignored. Weights are not re-initialized, when loading " +
                     "from ckpt.")
                weight_init=None  # already init'ed
            # TODO training will restart with same sequence if initialized
            # with same random seed. Possible solutions that are not yet
            # implemented:
            # -> save random seed in ckpt/model id or use global_step
            # -> or make random_seed placeholder/Variable
            # -> or better: save iterator state in ckpt!
            # -> just issue a warning
        else:
            # NOTE(review): these two assignments appear to be dead code --
            # _parse_ckpt_info(ckpt) further below unconditionally reassigns
            # ckpt_dataset_id and ckpt_run_id. Verify before removing.
            ckpt_dataset_id = ""
            ckpt_run_id = ""
        # define basic run parameters
        self.model.extra_training_parameters(
            weight_init=weight_init,  # do not change in case of relu
            batch_renorm_fn=batch_renorm_fn,
            dropout_rate=dropout_rate)  # often not used
        # input pipeline: the dataset handler already applies preprocessing
        with tf.name_scope("training_input_and_preprocessing"):
            dataset = self.TrainingDatasetAndPreprocess(
                training_dataset_handler=training_dataset_handler,
                batch_size=batch_size, n_epochs=n_epochs,
                random_seed=random_seed)
            x_batch, y_batch = dataset.next_batch()
            input_shape = dataset.x_shape
            print("input_shape:", input_shape)
        # forward pass + postprocessing
        with tf.variable_scope("model"):
            # x_batch is already preprocessed by dataset_handler
            y_predicted_batch = self.model.inference(x_batch, is_training)
            y_predicted_batch = self.postprocess(y_predicted_batch, input_shape)
            print("output_shape:", y_predicted_batch.shape)
        with tf.name_scope("losses"):
            data_loss, weight_reg_loss, data_reg_loss= self.calculate_losses(
                y_batch, y_predicted_batch, loss_fn=loss_fn,
                cut_loss_to_valid=cut_loss_to_valid,
                weight_reg_str=weight_reg_str, weight_reg_fn=weight_reg_fn,
                data_reg_str=data_reg_str, data_reg_fn=data_reg_fn)
            total_loss = data_loss + weight_reg_loss + data_reg_loss
        if validate:
            # build a second inference branch (shared variables via
            # reuse=True) that runs on the validation pipeline
            with tf.name_scope("validation_input_and_preprocessing"):
                if validation_dataset_handler is None:
                    raise ValueError(
                        "validate is True, but " +
                        "validation_dataset_handler is None.")
                validation_dataset = self.ValidationDatasetAndPreprocess(
                    validation_dataset_handler=validation_dataset_handler,
                    batch_size=batch_size)
                if not validation_dataset.n_images:
                    raise ValueError(
                        "validate is True, but validation_dataset " +
                        "contains no images.")
                # this is ideally always the same sequence
                x_val_batch, y_val_batch = validation_dataset.next_batch()
            print("validating with validation set (", validation_dataset.n_images,
                  "images sequentially) during every summary.")
            with tf.name_scope("validation"):
                with tf.variable_scope("model", reuse=True):
                    y_predicted_val_batch = self.model.inference(
                        x_val_batch, is_training=False, print_info=False)
                    y_predicted_val_batch = self.postprocess(
                        y_predicted_val_batch, input_shape)
                # currently only validates data_loss for performance reasons
                val_loss_single = self.data_loss(y_val_batch, y_predicted_val_batch,
                    loss_fn=loss_fn, cut_loss_to_valid=cut_loss_to_valid)
                # trick with placeholder to be able to calculate val_losses
                # over entire validation set, before taking the mean
                val_loss_ph = tf.placeholder(tf.float32, name="val_loss")
        else:
            val_loss_single = None
            validation_dataset = None
            val_loss_ph = None
        # get training op
        global_step = tf.train.create_global_step()
        # NOTE(review): tf.assign only *creates* an assign op; it is never
        # passed to sess.run here, so global_step is not actually set to
        # load_step by this line -- confirm whether that is intended.
        tf.assign(global_step, load_step)
        optimizer = self._define_optimizer(global_step,
            optimizer_fn=optimizer_fn, learning_rate_fn=learning_rate_fn)
        # extra_update_ops part is needed when using batch_norm (see tf-docs)
        extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(extra_update_ops):
            training = optimizer.minimize(total_loss, global_step=global_step)
        # define id of new model
        dataset_id = dataset.dataset_id
        run_id = _run_id(ckpt=ckpt,
                         load_step=load_step,
                         random_seed=random_seed,
                         batch_size=batch_size,
                         dropout_rate=dropout_rate,
                         optimizer_fn=optimizer_fn,
                         learning_rate_fn=learning_rate_fn,
                         loss_fn=loss_fn,
                         cut_loss_to_valid=cut_loss_to_valid,
                         weight_reg_str=weight_reg_str,
                         weight_reg_fn=weight_reg_fn,
                         data_reg_str=data_reg_str,
                         data_reg_fn=data_reg_fn,
                         comment=comment)
        # run_name is determined depending on existing runs
        # runs are numbered from 0 upwards.
        # get ids of ckpt
        (ckpt_model_id,
         ckpt_dataset_id,
         ckpt_run_id,
         ckpt_run_name) = _parse_ckpt_info(ckpt)
        # create folders to store ckpts and logs
        # TODO: The way it is done now can lead to a very deep folder structure.
        # -> Omit saving in subfolders in favor of putting
        #    model_id, run_id etc in a config file
        # TODO: windows paths are restricted to 256 characters!
        run_dir = os.path.join(self.modeldir, self.model_id, ckpt_dataset_id,
                               ckpt_run_id, dataset_id, run_id)
        #if os.name == 'nt':  # windows
        if len(run_dir) > 200:
            # I think windows restriction is 256 chars
            warn("Path to run_dir is very long. " +
                 "Consider also that the length of windows paths is " +
                 "restricted and that additional chars are added " +
                 "during ckpt generation.")
        if not os.path.exists(run_dir):
            os.makedirs(run_dir)
        run_name = _run_name(run_dir)
        if _exists_in_rundir(run_name, run_dir):  # I think this is redundant.
            raise RuntimeError(
                "run_name already exists in run_dir. This is a bug. " +
                "Please fix. Old ckpt was not overridden.")
        run_dir = os.path.join(run_dir, run_name)
        print("saving new ckpt and logs in", run_dir)
        logdir = os.path.join(run_dir, "logs")
        ckptdir = os.path.join(run_dir, "ckpts")
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        if not os.path.exists(ckptdir):
            os.makedirs(ckptdir)
        new_ckpt = os.path.join(ckptdir, run_name)
        # op for loading and saving ckpts
        saver = tf.train.Saver(
            max_to_keep=5, keep_checkpoint_every_n_hours=2)
        # op for logs
        loss_summaries = self._loss_summaries(
            total_loss, data_loss, weight_reg_loss, data_reg_loss)
        if validate:
            loss_summaries.append(tf.summary.scalar("val_loss", val_loss_ph))
        image_summaries = self._image_summaries(
            x_batch, y_batch, y_predicted_batch, tf.nn.softmax)
        prediction_summaries = self._prediction_summaries(y_predicted_batch)
        summary = tf.summary.merge(
            [loss_summaries, image_summaries, prediction_summaries])
        writer = tf.summary.FileWriter(logdir, self.sess.graph)
        writer.add_graph(tf.get_default_graph())  # explicitly adding graph
        # load vars from ckpt into model or initialize new model
        if ckpt:
            if self._ckpt_compatible(ckpt_model_id):
                _check_ids(dataset_id, run_id, ckpt_dataset_id, ckpt_run_id)
                self._load_ckpt(saver, ckpt, load_step)  # might throw error here
            else:
                raise ValueError("ckpt is not compatible with model.")
        else:
            self.sess.run(tf.global_variables_initializer())
        return (dataset, training, total_loss, saver, new_ckpt, writer, summary,
                validation_dataset, val_loss_single, val_loss_ph)
# TODO: make trainer object
    def _train_loop(
            self, n_epochs, batch_size, dataset, total_loss, training, saver,
            new_ckpt, writer, summary,
            validation_dataset, val_loss_single, val_loss_ph,
            summary_interval=None, save_interval=None):
        """
        Run the optimization loop on the graph prepared by _train_init.

        Iterates n_epochs over the dataset, running the training op every
        step; a summary (optionally including the mean validation loss)
        is written every summary_interval steps and a ckpt is saved every
        save_interval steps. Both record the state *before* the step.
        A final ckpt (and, if the iterator permits, a final summary) is
        written after the loop.

        Args are the tuple returned by _train_init plus the intervals
        documented in train().
        """
        # validation branch exists iff _train_init built one
        validate = validation_dataset is not None
        # tf.data.Dataset makes last batch smaller, if needed.
        n_batches_per_epoch = ceil(dataset.n_images / batch_size)
        if validate:
            n_val_batches = ceil(validation_dataset.n_images / batch_size)
        n_iterations_per_epoch = n_batches_per_epoch
        if summary_interval is None:
            print("Saving 2 logs per epoch by default.")
            # "or 1" guards against an interval of 0 for tiny datasets
            summary_interval = int(0.5*n_iterations_per_epoch) or 1  # iterations
        if save_interval is None:
            print("Saving checkpoint every 2 epochs by default.")
            save_interval = 2*n_iterations_per_epoch  # iterations
        # validation_interval = 2*summary_interval if validate else None
        # training loop
        # run_options = tf.RunOptions(report_tensor_allocations_upon_oom = True) #DBG
        # self.sess.run(..., options=run_options) # DBG
        global_step = tf.train.get_global_step()
        step = tf.train.global_step(self.sess, global_step)
        print("starting training with start_step", step)
        for epoch in range(n_epochs):
            print("epoch", epoch+1, "/", n_epochs)
            for iteration in range(n_iterations_per_epoch):
                if step % save_interval == 0:
                    # saves the state before the iteration!
                    print('---->saving', step)
                    saver.save(self.sess, new_ckpt, global_step=step)  # test
                if step % summary_interval == 0:
                    # summarizes state before iteration!
                    print('---->summarizing', step)
                    if validate:
                        # if step % validation_interval == 0:
                        # run sequentially over the whole validation set,
                        # then feed the mean into the summary placeholder
                        val_losses = list()
                        print("----> calculating losses on validation set")
                        for _ in range(n_val_batches):
                            val_losses.append(self.sess.run(val_loss_single))
                        val_lossc = np.mean(val_losses)
                        lossc, summaryc, _ = self.sess.run(
                            [total_loss, summary, training],
                            feed_dict={val_loss_ph:val_lossc})
                        print("---->done", val_losses)
                        print('---->validation loss', val_lossc)
                        # else:
                        #     lossc, summaryc, _ = self.sess.run(
                        #         [totalloss, summary, training],
                        #         feed_dict={val_loss_ph:None})
                    else:
                        lossc, summaryc, _ = self.sess.run(
                            [total_loss, summary, training])
                    writer.add_summary(summaryc, step)
                    print('---->loss', lossc)
                else:  # do the same, but without summary
                    lossc, _ = self.sess.run([total_loss, training])
                print("iteration", iteration+1, "/", n_iterations_per_epoch)
                step = tf.train.global_step(self.sess, global_step)
        try:
            # Attempt one final summary after the loop. The training
            # iterator is usually exhausted at this point, so the expected
            # outcome is the OutOfRangeError handled below.
            # cannot ensure that final summary is run because of iterator
            if validate:
                val_losses = list()
                print("---->calculating losses on validation set")
                for _ in range(n_val_batches):
                    val_losses.append(self.sess.run(val_loss_single))
                val_lossc = np.mean(val_losses)
                lossc, summaryc = self.sess.run(
                    [total_loss, summary], feed_dict={val_loss_ph:val_lossc})
                print("---->done", val_losses)
                print('---->validation loss', val_lossc)
            # NOTE(review): when validate is True, this second run has no
            # feed_dict for val_loss_ph; it is normally unreachable because
            # the iterator raises first, but verify the intended control flow.
            lossc, summaryc = self.sess.run([total_loss, summary])
            writer.add_summary(summaryc, step)
            print('---->summarizing', step, '(final state)')
            print('---->loss', lossc)
            warn("I was expecting to be at the end of sequence here.")
        except tf.errors.OutOfRangeError:
            print("End of sequence")
            print("It is not possible to run the final summary. " +
                  "\nThis behaviour is expected.")
        # always save the final state
        print('---->saving', step, '(final state)')
        writer.flush()
        saver.save(self.sess, new_ckpt, global_step=step)
# TODO: pass x_format as format string (?)
    def run_on_image(self, np_x, ckpt, load_step="previous", data_format=None):
        """
        Load model from ckpt and use on an input image np_x.

        Args:
            np_x (np-array) : image as np-array.
                Must be 2D+channel in the form "height-width-channel" or
                3D+channel in the form "depth-height-width-channel."
            ckpt (str) : path to checkpoint that will be loaded.
                load_step should not be included in ckpt, but should be provided
                separately.
            load_step ("previous" or int) : step that will be loaded from ckpt.
                "previous" is converted to the last ckpt that was written
                to the folder containing the ckpt.
            data_format (None or str) : passed to self._check_data_format;
                presumably validated against the model's data format --
                TODO confirm against _check_data_format.

        Returns:
            np_y_pred : np-array of the same number of dimensions as np_x.
                Output size depends on convolution mode of net (same or valid).
        """
        is_training = False
        self._check_data_format(data_format)
        # setup inference: wrap the numpy image in a constant and add a
        # batch dimension at the model's image axis
        with tf.name_scope("input_and_preprocessing"):
            x = tf.constant(np_x, dtype=tf.float32, shape=np_x.shape)
            print("input shape:", x.shape)
            input_shape = tf.expand_dims(x, self.model.im_axis).shape
            x_batch = tf.expand_dims(self.preprocess(x), self.model.im_axis)
        with tf.variable_scope("model"):
            self.model.set_ready()  # going to load from ckpt
            y_pred = self.model.inference(x_batch, is_training)
            #print("net output_shape:", y_pred.shape)
            y_pred = self.postprocess(y_pred, input_shape)
            # remove the batch dimension again
            y_pred = tf.squeeze(y_pred, axis=self.model.im_axis)
            print("output_shape:", y_pred.shape)
        # Load model from ckpt
        saver = tf.train.Saver(max_to_keep=None)
        load_step = _parse_load_step(ckpt, load_step)
        ckpt_model_id, ckpt_dataset_id, ckpt_run_id, ckpt_run_name = \
            _parse_ckpt_info(ckpt)
        if self._ckpt_compatible(ckpt_model_id):
            self._load_ckpt(saver, ckpt, load_step, False)
        else:
            raise ValueError("ckpt is not compatible with model.")
        np_y_pred = self.sess.run(y_pred)
        # write graph to tensorboard. Will show up as . in tensorboard
        run_dir = os.path.join(
            self.modeldir, self.model_id, ckpt_dataset_id, ckpt_run_id,
            ckpt_run_name)
        logdir = os.path.join(run_dir, "logs")
        writer = tf.summary.FileWriter(logdir=logdir, graph=self.sess.graph)
        writer.flush()
        return np_y_pred
# make loss_fn required arg
# add weight_reg?
def test(self, testing_dataset_handler, ckpt, load_step="previous",
loss_fn=tf.losses.mean_squared_error, cut_loss_to_valid=False,
data_reg_str=None, data_reg_fn=None, batch_size=1):
"""
Configure network for testing,
Load model from ckpt
load testdata
calculate losses on test set
Args:
testing_dataset_handler (dataset_handler) : dataset_handler from
dataset_handlers.tf_data_dataset_handlers that provides method
to get x and y from test set. DatasetHandlers provide a thin
layer around tf.data datasets. If you have an existing tf.data
dataset, using BaseDatasetHandler is sufficient. The module also
provides Handlers that can be initialized from numpy-arrays
or lists of files.
ckpt (str) : path to checkpoint that will be loaded.
load_step should not be included in ckpt, but should be provided
separately.
load_step ("previous" or int) : step that will be loaded from ckpt.
"previous" is converted to the last ckpt that was written
to the folder containing the ckpt.
loss_fn (function) : Loss_fn must be a function taking two parameters
(labels, predictions) and return a loss-tensor. Common losses can
be found in tf.losses and in this repo's loss_functions (lf)
Default: tf.losses.mean_squared_error
cut_loss_to_valid (bool) : This can be used for 'same' padding to
calculate the loss only on that part of the image that is not
impacted by padding. This has no impact, if padding='valid'
Default: False
data_reg_str (float) : float to scale the strength of data
regularization.
data_reg_fn (function) : a function of a single image returning a loss
tensor. Its value is multiplied by data_reg_str and then added
to the loss. NOTE that this currently only works correctly for
a batch size of 1.
batch_size (int) : choose how many losses are calculated
simultaneously. Use a larger batch size to speed up testing.
Returns:
total_loss : (= data_loss + data_reg_loss)
.
"""
is_training = False
# load and preprocess dataset
with tf.name_scope("training_input_and_preprocessing"):
dataset = self.TestDatasetAndPreprocess(
testing_dataset_handler=testing_dataset_handler,
batch_size=batch_size)
input_shape = dataset.x_shape
print("input_shape:", input_shape)
x_batch, y_batch = dataset.next_batch()
with tf.variable_scope("model"):
# x_batch is already preprocessed by dataset_handler
self.model.set_ready()
y_predicted_batch = self.model.inference(x_batch, is_training)
y_predicted_batch = self.postprocess(y_predicted_batch, input_shape)
print("output_shape:", y_predicted_batch.shape)
with tf.name_scope("losses"):
data_loss_single, _, data_reg_loss_single = self.calculate_losses(
y_batch, y_predicted_batch, loss_fn=loss_fn,
cut_loss_to_valid=cut_loss_to_valid,
data_reg_str=data_reg_str, data_reg_fn=data_reg_fn)
total_loss_single = data_loss_single + data_reg_loss_single
# Load model from ckpt
saver = tf.train.Saver(max_to_keep=None)
load_step = _parse_load_step(ckpt, load_step)
dataset_id = dataset.dataset_id
ckpt_model_id, ckpt_dataset_id, ckpt_run_id, ckpt_run_name = \
_parse_ckpt_info(ckpt)
if self._ckpt_compatible(ckpt_model_id):
_check_dataset_ids(dataset_id, ckpt_dataset_id)
self._load_ckpt(saver, ckpt, load_step, False)
else:
raise ValueError("ckpt is not compatible with model.")
print("---->calculating losses on test set")
n_batches = ceil(dataset.n_images / batch_size)
data_losses = list()
data_reg_losses = list()
total_losses = list()
for _ in range(n_batches):
(data_loss_singlec,
data_reg_loss_singlec,
total_loss_singlec) = self.sess.run(
[data_loss_single,
data_reg_loss_single,
total_loss_single])
data_losses.append(data_loss_singlec)
data_reg_losses.append(data_reg_loss_singlec)
total_losses.append(total_loss_singlec)
# data_lossc = np.mean(data_losses)
# data_reg_lossc = np.mean(data_reg_losses)
lossc = np.mean(total_losses)
print("---->data losses: ", data_losses)
print("---->data reg losses:", data_reg_losses)
print("---->total losses: ", total_losses)
print('---->mean total loss:', lossc)
return lossc
## lower level functions
# TODO: this differs from dataset_handler._set_data_format
def _check_data_format(self, data_format):
# TODO: change default to the format of model?
if data_format is None:
warn("data_format is None. I am unable to check, if input " +
"data_format matches model.data_format. Set input " +
"data_format to avoid this warning.")
elif data_format in ["channels_last", "channels_first"]:
if data_format != self.model.data_format:
raise RuntimeError(
"data_format " + data_format + " does not match " +
"model.data_format " + self.model.data_format + ".")
else:
raise ValueError("Unknown data_format: " + data_format)
# summaries
def _loss_summaries(self, total_loss, data_loss, weight_reg_loss,
data_reg_loss):
summaries = []
summaries.append(tf.summary.scalar("total_loss", total_loss))
summaries.append(tf.summary.scalar("data_loss", data_loss))
summaries.append(tf.summary.scalar("weight_reg_loss", weight_reg_loss))
summaries.append(tf.summary.scalar("data_reg_loss", data_reg_loss))
return summaries
# TODO: need to update this to allow different channels and 2d
# TODO: this is quite data dependent.
def _image_summaries(self, x_batch, y_batch, y_predicted_batch, activation=None):
summaries = []
if len(x_batch.shape) != len(y_batch.shape):
raise ValueError(
"not implemented for the case, where x and y have a " +
"different number of dimensions. Detected x_batch.shape: " +
str(x_batch.shape) + " and y_batch.shape: " +
str(y_batch.shape) + ".")
### TODO: _get_projection_fn(batch) ###
# need to define a mapping depending on the data domain
# TODO use _get_channel_axis(...)[0] instead
if self.model.data_format == "channels_last":
proj_axis = 1 # z_axis
elif self.model.data_format == "channels_first":
proj_axis = 2 # z_axis
else:
warn("Will not generate image_summaries, because it is unknown " +
"how to handle data_format " + self.model.data_format + ".")
if len(x_batch.shape) == 5: # 3d input
x_projection_fn = lambda batch: tf.reduce_max(batch, proj_axis)
elif len(x_batch.shape) == 4:
x_projection_fn = lambda batch: batch # identity
else:
raise RuntimeError("x_batch is not 4d and not 5d.")
if len(y_batch.shape) == 5: # 3d input
y_projection_fn = lambda batch: tf.reduce_max(batch, proj_axis)
elif len(x_batch.shape) == 4:
y_projection_fn = lambda batch: batch # identity
else:
raise RuntimeError("y_batch is not 4d and not 5d.")
def pad_color_channel(batch):
# add empty 3rd color channel
paddings = [[0,0]] * len(batch.shape)
paddings[self.model.channel_axis] = [0,1]
paddings = tf.constant(paddings)
return tf.pad(batch, paddings)
if x_batch.shape[self.model.channel_axis] == 2:
x_color_fn = pad_color_channel
elif x_batch.shape[self.model.channel_axis] in [1,3]:
x_color_fn = lambda batch: batch # identity
else:
warn("Will not generate image_summaries, because it is unknown " +
"how to handle " + str(x_batch.shape[self.model.channel_axis]) +
" color channels.")
return summaries
if y_batch.shape[self.model.channel_axis] == 2:
y_color_fn = pad_color_channel
elif y_batch.shape[self.model.channel_axis] in [1,3]:
y_color_fn = lambda batch: batch # identity
else:
warn("Will not generate image_summaries, because it is unknown " +
"how to handle " + str(y_batch.shape[self.model.channel_axis]) +
" color channels.")
return summaries
if activation is None:
activation = lambda batch: batch # identity
### ###
# TODO: handle the case where output can be negative, e.g. with hinge_loss
with tf.name_scope("image_summaries"):
# print(x_batch.shape)
# print(y_batch.shape)
# print(y_predicted_batch.shape)
summaries.append(tf.summary.image(
"x", x_projection_fn(x_color_fn(x_batch))))
summaries.append(tf.summary.image(
"y", y_projection_fn(y_color_fn(y_batch))))
summaries.append(tf.summary.image(
"yp", y_projection_fn(y_color_fn(activation(y_predicted_batch)))))
return summaries