In [1]:
import sys
sys.path.append("../models")
sys.path.append("../utils")
import config
import time
import numpy as np
import tensorflow as tf
import data_loader
from tra_helper.image_aug import img_missing
from tra_helper.plot_dset_one import plot_dset_one
from base_models.mobilenet import mobilenet_v2
from seg_models.unet import unet
from seg_models.unet_backbone import unet_backbone
from seg_models.deeplabv3_plus import deeplabv3_plus
from seg_models.model_triple import model_triple
from seg_models.unet_triple import unet_triple


### Training configuration

In [2]:
root_dir = '/home/yons/Desktop/developer-luo/SWatNet'
strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])
print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
Number of devices: 2


In [3]:
## 1. dataset
tra_dset = data_loader.get_tra_dset()
test_dset = data_loader.get_eva_dset()
tra_dset_dist = strategy.experimental_distribute_dataset(tra_dset)
test_dset_dist = strategy.experimental_distribute_dataset(test_dset)

## 2. training configuration
with strategy.scope():
    ## 2.1 loss function
    loss_fun = tf.keras.losses.BinaryCrossentropy(
                from_logits=False,
                reduction=tf.keras.losses.Reduction.NONE)
    def compute_loss(labels, predictions):
        per_example_loss = loss_fun(labels, predictions)
        # per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=config.BATCH_SIZE)
    tra_loss_tracker = config.tra_loss_tracker
    test_loss_tracker = config.test_loss_tracker
    ## 2.2 accuracy metric
    tra_oa = config.tra_oa
    tra_miou = config.tra_miou


### Model loading

In [4]:
## model
with strategy.scope():
    #### single scale
    # model = unet(nclass=2)
    # model = unet_backbone(input_shape=(256,256,4), base_model=mobilenet_v2)
    # model = deeplabv3_plus(input_shape=[256,256,4], base_model=mobilenet_v2, nclasses=2)
    #### multiple scale
    model = model_triple(input_shape=(256,256,4), base_model=mobilenet_v2)
    # model = unet_triple(scale_high=2048, scale_mid=512, scale_low=256, nclass=2)
    optimizer = config.optimizer
    # model.summary()


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Redu

In [5]:
'''------train step------'''
# @tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        y_pre = model(x, training=True)
        loss = compute_loss(y, y_pre)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    tra_loss_tracker.update_state(loss)
    tra_oa.update_state(y, y_pre)
    tra_miou.update_state(y, y_pre)
    return tra_loss_tracker.result(), tra_oa.result(), tra_miou.result()

@tf.function
def distributed_train_step(x,y):
  per_replica_losses = strategy.run(train_step, args=(x,y,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

'''------test step------'''
# @tf.function
def test_step(x, y):
    # with tf.GradientTape() as tape:
    y_pre = model(x, training=False)
    loss = loss_fun(y, y_pre)
    test_loss_tracker.update_state(loss)
    test_oa.update_state(y, y_pre)
    test_miou.update_state(y, y_pre)
    return test_loss_tracker.result(), test_oa.result(), test_miou.result()

@tf.function
def distributed_test_step(x,y):
  return strategy.run(test_step, args=(x,y,))

'''------train loops------'''
def train_loops(tra_dset, test_dset, epochs):
    max_miou_pre = 0.8
    for epoch in range(epochs):
        start = time.time()

        '''---train the model---'''
        for x_batch, y_batch in tra_dset:
            # x_batch = img_missing(x_batch)    
            # x_batch=x_batch[2]   ##!!note: x_batch[2] for single-scale model
            tra_loss_epoch,tra_oa_epoch,tra_miou_epoch = train_step(x_batch, y_batch)

        '''---test the model---'''
        for x_batch, y_batch in test_dset:
            # x_batch=x_batch[2]  ##!note: x_batch[2] for single-scale model
            test_loss_epoch, test_oa_epoch, test_miou_epoch = test_step(x_batch, y_batch)
        
        '''---reset the metrics---'''
        tra_loss_tracker.reset_states(), tra_oa.reset_states(), tra_miou.reset_states()
        test_loss_tracker.reset_states(), test_oa.reset_states(), test_miou.reset_states()

        # '''---write into tensorboard---'''
        # train_summary_writer = tf.summary.create_file_writer(config.train_log_dir)
        # test_summary_writer = tf.summary.create_file_writer(config.test_log_dir)
        # with train_summary_writer.as_default():
        #     tf.summary.scalar('learning rate', data=config.optimizer.learning_rate(epoch*16), step=epoch)
        #     tf.summary.scalar('loss', data=tra_loss_epoch, step=epoch)
        #     tf.summary.scalar('oa', data=tra_oa_epoch, step=epoch)
        #     tf.summary.scalar('miou', data=tra_miou_epoch, step=epoch)
        # with test_summary_writer.as_default():
        #     tf.summary.scalar('loss', data=test_loss_epoch, step=epoch)
        #     tf.summary.scalar('oa', data=test_oa_epoch, step=epoch)
        #     tf.summary.scalar('miou', data=test_miou_epoch, step=epoch)
        # print the metrics
        print('epoch {}: traLoss:{:.3f}, traOA:{:.2f}, traMIoU:{:.2f}; evaLoss:{:.3f}, evaOA:{:.2f}, evaMIoU:{:.2f}, time:{:.0f}s'.format(epoch + 1, tra_loss_epoch, tra_oa_epoch, tra_miou_epoch, test_loss_epoch, test_oa_epoch, test_miou_epoch, time.time() - start))
        # if test_miou_epoch>max_miou_pre:
        #     max_miou_pre = test_miou_epoch
        #     model.save_weights(config.path_savedmodel+'/unet_mobilenetv2/weights_epoch_%d'%(epoch+1))

        '''---visualize the results---'''
        if epoch%20 == 0:
            figure = plot_dset_one(model, test_dset.take(1), \
                        i_patch=np.random.randint(8), binary=False, \
                        multiscale=True, weight=False)


In [6]:
## training
train_loops(tra_dset_dist, test_dset_dist, epochs=200)


TypeError: Inputs to a layer should be tensors. Got: PerReplica:{
  0: tf.Tensor(
[[[[0.6200799  0.45047766 0.6109032  0.49702972]
   [0.63449526 0.35282668 0.6195055  0.355333  ]
   [0.4722984  0.4315957  0.59727067 0.30518308]
   ...
   [0.62156916 0.51233846 0.6774574  0.38955602]
   [0.61946267 0.3803712  0.6243561  0.48660418]
   [0.57406944 0.5501763  0.7096416  0.5246348 ]]

  [[0.5731365  0.49619794 0.7276622  0.5051314 ]
   [0.53605354 0.38739836 0.6072549  0.5007934 ]
   [0.4940034  0.20950058 0.6089326  0.33804658]
   ...
   [0.43885306 0.45036268 0.59669507 0.5086227 ]
   [0.67249584 0.43728137 0.67402637 0.46411133]
   [0.46359062 0.45490137 0.7083179  0.38257778]]

  [[0.64405715 0.47959828 0.6886976  0.5809803 ]
   [0.5436801  0.39817566 0.75190455 0.5559174 ]
   [0.54790485 0.41492885 0.74833465 0.51595914]
   ...
   [0.57650906 0.36684358 0.60683197 0.43213516]
   [0.6041667  0.5679051  0.66766006 0.5440774 ]
   [0.6530315  0.5631654  0.5873621  0.55319875]]

  ...

  [[0.57870245 0.4904776  0.569982   0.48717722]
   [0.675678   0.44852623 0.7076017  0.61935705]
   [0.63620746 0.3559937  0.74030733 0.48152104]
   ...
   [0.67686474 0.55168396 0.64442474 0.53311515]
   [0.6447901  0.5222975  0.6838312  0.4116634 ]
   [0.5637487  0.4349667  0.5124273  0.48753336]]

  [[0.60596377 0.45426124 0.6641635  0.5152875 ]
   [0.5972198  0.5087473  0.5722151  0.519834  ]
   [0.52623814 0.3715008  0.61717355 0.5523248 ]
   ...
   [0.6348872  0.5259268  0.530748   0.5144385 ]
   [0.6052295  0.52829105 0.64737904 0.53423494]
   [0.69295454 0.5241862  0.69305307 0.60232884]]

  [[0.47146899 0.45613816 0.61953855 0.59028465]
   [0.646555   0.49210677 0.7046961  0.5223357 ]
   [0.6861547  0.46346393 0.6362788  0.65621054]
   ...
   [0.6764069  0.3680003  0.6432085  0.4983265 ]
   [0.6178616  0.5580512  0.6536784  0.51407385]
   [0.5907875  0.51930463 0.6444537  0.52340317]]]


 [[[0.6779002  0.5929388  0.60490185 0.51474744]
   [0.6791135  0.5954022  0.6093698  0.5227072 ]
   [0.6786674  0.5914422  0.5950984  0.5109111 ]
   ...
   [0.73086303 0.62979895 0.6303798  0.51431245]
   [0.745668   0.65055096 0.6311302  0.5294845 ]
   [0.74521756 0.64010715 0.62397605 0.53848004]]

  [[0.6600023  0.5568421  0.5676186  0.46096736]
   [0.64739114 0.544474   0.5824879  0.45670682]
   [0.6592534  0.5364333  0.5580777  0.44500348]
   ...
   [0.74853927 0.65938884 0.62522066 0.5179338 ]
   [0.7308316  0.6355038  0.62164026 0.49674234]
   [0.7350072  0.6191506  0.64295906 0.5493782 ]]

  [[0.645548   0.5445172  0.5682392  0.4676565 ]
   [0.65245014 0.52097183 0.5828586  0.42032582]
   [0.6480167  0.52508396 0.5712775  0.42350033]
   ...
   [0.8108812  0.6866859  0.57250345 0.46871915]
   [0.72570306 0.63127476 0.6415249  0.5461432 ]
   [0.72693384 0.6320735  0.6521011  0.5423312 ]]

  ...

  [[0.55265987 0.47454885 0.49067217 0.3619585 ]
   [0.56848764 0.4879316  0.51282114 0.3647284 ]
   [0.5675279  0.4814898  0.5145748  0.36334434]
   ...
   [0.5531348  0.41095763 0.42784536 0.3713311 ]
   [0.5558372  0.41049546 0.3850472  0.3774124 ]
   [0.5620006  0.40213954 0.4094788  0.3821677 ]]

  [[0.5342193  0.47420168 0.51034284 0.3569032 ]
   [0.5650562  0.47919896 0.51353407 0.36098593]
   [0.5658158  0.46943966 0.5085595  0.3621707 ]
   ...
   [0.54897237 0.40192282 0.4149623  0.36838573]
   [0.5653127  0.4175742  0.39190713 0.3825006 ]
   [0.5546146  0.40844393 0.4033177  0.3808438 ]]

  [[0.57760334 0.4825362  0.5028734  0.35562307]
   [0.5726126  0.47554937 0.5160849  0.3605693 ]
   [0.55283606 0.4786302  0.5017649  0.3457927 ]
   ...
   [0.5580689  0.40012127 0.3965158  0.37497142]
   [0.5412557  0.40758985 0.38960683 0.37669647]
   [0.5595341  0.40184546 0.40456626 0.3747781 ]]]


 [[[0.69011474 0.51809067 0.8303911  0.556691  ]
   [0.5367384  0.47534376 0.7434681  0.66309106]
   [0.5428682  0.49242294 0.66704106 0.53563356]
   ...
   [0.6138183  0.55313677 0.7490944  0.6460047 ]
   [0.71853447 0.54652584 0.71949553 0.70268625]
   [0.7347653  0.57981056 0.70072746 0.6097221 ]]

  [[0.4776051  0.45990682 0.62496376 0.6025801 ]
   [0.58232826 0.4247902  0.6046299  0.45733637]
   [0.5782217  0.4253936  0.66953397 0.5493336 ]
   ...
   [0.7350459  0.6978751  0.7828038  0.5817892 ]
   [0.7894968  0.5882003  0.6857585  0.6233401 ]
   [0.67994934 0.7440958  0.7057019  0.6872395 ]]

  [[0.4184674  0.4310146  0.74349993 0.51617455]
   [0.57681715 0.57464904 0.55974895 0.5770126 ]
   [0.65883666 0.46718824 0.5377805  0.50637615]
   ...
   [0.49076992 0.5778077  0.7003197  0.6112028 ]
   [0.6102093  0.57369643 0.7820222  0.62829965]
   [0.8115684  0.6235488  0.54087466 0.5528984 ]]

  ...

  [[0.5854448  0.41908026 0.74518627 0.53965956]
   [0.4993828  0.44233453 0.7500272  0.6387221 ]
   [0.5498151  0.5056493  0.70106834 0.5486287 ]
   ...
   [0.56729186 0.32939345 0.6117595  0.45728973]
   [0.45704716 0.35254928 0.58766633 0.4544697 ]
   [0.48551154 0.4575078  0.59535134 0.46302265]]

  [[0.56165034 0.54938954 0.77683425 0.7540078 ]
   [0.61578447 0.4887574  0.7216398  0.64194345]
   [0.41565967 0.44595695 0.6647097  0.5978068 ]
   ...
   [0.51869553 0.39634025 0.61919177 0.40772176]
   [0.5745941  0.32065505 0.5753324  0.48839596]
   [0.59205025 0.40397194 0.6222968  0.4539236 ]]

  [[0.57873774 0.42123798 0.8317123  0.73783565]
   [0.56747055 0.3392237  0.74293756 0.7647513 ]
   [0.63820314 0.5332192  0.72202647 0.56892014]
   ...
   [0.574303   0.45185184 0.585426   0.5598804 ]
   [0.54824823 0.3732354  0.51991266 0.38531393]
   [0.5321991  0.3739239  0.60976183 0.4310439 ]]]


 [[[0.7061902  0.48764575 0.7257625  0.5554827 ]
   [0.7212381  0.51019216 0.6921668  0.51561654]
   [0.7513532  0.53171426 0.73157173 0.47411513]
   ...
   [0.71765    0.48499686 0.7281051  0.527851  ]
   [0.72232753 0.45574862 0.7263005  0.49917987]
   [0.7244492  0.510188   0.734292   0.5254228 ]]

  [[0.69661015 0.48352683 0.72633946 0.49765033]
   [0.71148115 0.48142028 0.70593196 0.53248733]
   [0.7219084  0.5018693  0.7009566  0.5089324 ]
   ...
   [0.7330965  0.47442764 0.7152611  0.5089783 ]
   [0.7305123  0.49367133 0.7072634  0.53501093]
   [0.7278129  0.5489129  0.7219517  0.49949282]]

  [[0.69594115 0.51982063 0.7352047  0.51955813]
   [0.7219883  0.5531279  0.7011726  0.50958437]
   [0.7004027  0.49089047 0.6906016  0.48873162]
   ...
   [0.71901333 0.5092857  0.71502346 0.51624936]
   [0.74258804 0.50049496 0.7393354  0.53046834]
   [0.7290355  0.5137393  0.73813665 0.47449973]]

  ...

  [[0.7841882  0.59425235 0.72927755 0.5131614 ]
   [0.7904011  0.60012484 0.7368751  0.47192597]
   [0.797974   0.6253978  0.73768616 0.5218378 ]
   ...
   [0.824933   0.6964982  0.8334528  0.7028405 ]
   [0.7611148  0.61072147 0.8121204  0.6653917 ]
   [0.7080559  0.54088616 0.8333328  0.65909725]]

  [[0.7661769  0.5801199  0.72793776 0.50730354]
   [0.77757215 0.58386177 0.72494406 0.5407432 ]
   [0.7703137  0.6221273  0.7244024  0.47488326]
   ...
   [0.72987056 0.5980804  0.8136256  0.657699  ]
   [0.7032179  0.54564404 0.82444644 0.6494334 ]
   [0.70237595 0.5765192  0.8314518  0.6490667 ]]

  [[0.74660885 0.57222223 0.74399745 0.53571343]
   [0.7607232  0.5777111  0.73058414 0.55139977]
   [0.7512212  0.5483578  0.7340484  0.48960444]
   ...
   [0.6919118  0.5459949  0.8053299  0.6383021 ]
   [0.7076301  0.52230144 0.78987193 0.630432  ]
   [0.7168624  0.5674957  0.7964668  0.6292666 ]]]], shape=(4, 256, 256, 4), dtype=float32),
  1: tf.Tensor(
[[[[0.6980177  0.50487685 0.67965627 0.44201088]
   [0.5453956  0.4179813  0.75454724 0.6060525 ]
   [0.5570014  0.5724523  0.57242805 0.5503787 ]
   ...
   [0.6579627  0.46012917 0.59680235 0.4904221 ]
   [0.57494426 0.57474554 0.77546144 0.46749806]
   [0.8152574  0.5811218  0.6373651  0.49589378]]

  [[0.6173389  0.46603483 0.5711473  0.48233983]
   [0.58904845 0.30846065 0.5937186  0.48642004]
   [0.54620504 0.49443543 0.6381861  0.50901943]
   ...
   [0.48299435 0.4486877  0.6541782  0.47558942]
   [0.5129143  0.44187245 0.5931112  0.5245163 ]
   [0.7530662  0.5838285  0.68591154 0.58147824]]

  [[0.6350501  0.54151416 0.5664855  0.47625837]
   [0.79522866 0.41552535 0.7187772  0.45274168]
   [0.61252123 0.49984223 0.56256545 0.4629143 ]
   ...
   [0.5163064  0.3358295  0.5816408  0.42449215]
   [0.6088235  0.37527218 0.60684794 0.40145504]
   [0.59716105 0.4135163  0.56410545 0.44984916]]

  ...

  [[0.5863483  0.43416467 0.70422566 0.53971064]
   [0.6601184  0.34448975 0.59199405 0.39355952]
   [0.6161669  0.40598807 0.65502405 0.43844143]
   ...
   [0.6390679  0.522605   0.5968133  0.41181946]
   [0.62860894 0.5931733  0.68937486 0.479916  ]
   [0.6525632  0.44832838 0.6413091  0.514915  ]]

  [[0.59833616 0.48594135 0.6648734  0.46829164]
   [0.5070472  0.36852464 0.69267786 0.49711913]
   [0.79988724 0.4774609  0.62820256 0.4345489 ]
   ...
   [0.6472962  0.37394297 0.7261696  0.55994856]
   [0.68397856 0.5110107  0.6830772  0.5728423 ]
   [0.7978518  0.47247508 0.59421045 0.54122144]]

  [[0.54861057 0.3068362  0.47204524 0.3771782 ]
   [0.6034157  0.3833111  0.6268264  0.43909422]
   [0.7161011  0.46206456 0.5427064  0.46400702]
   ...
   [0.5459927  0.47412464 0.631869   0.53150666]
   [0.6494275  0.3468663  0.64244264 0.437023  ]
   [0.700458   0.46944085 0.6749605  0.4996889 ]]]


 [[[0.5684179  0.44260204 0.59815246 0.46139407]
   [0.61114633 0.50493336 0.63851565 0.53486186]
   [0.58341366 0.45847726 0.67826223 0.5061015 ]
   ...
   [0.5503462  0.41694254 0.5435787  0.41235864]
   [0.5290821  0.40679938 0.5297975  0.42446733]
   [0.5259849  0.410484   0.50999665 0.4293456 ]]

  [[0.6045965  0.48101863 0.62191653 0.46969062]
   [0.5913812  0.4809881  0.6710499  0.5431011 ]
   [0.57447916 0.434522   0.6495024  0.48557693]
   ...
   [0.5426984  0.42132175 0.51045287 0.42325163]
   [0.5332482  0.41344637 0.5278368  0.4254073 ]
   [0.5360429  0.41952288 0.5361966  0.41660717]]

  [[0.6460563  0.5145819  0.6599333  0.5192232 ]
   [0.57905686 0.45853415 0.70456344 0.54519343]
   [0.5657539  0.44178653 0.65124995 0.47665983]
   ...
   [0.5394812  0.41803268 0.55729115 0.41893506]
   [0.5339772  0.4110178  0.5395593  0.4129644 ]
   [0.55301964 0.4246767  0.5251739  0.4178446 ]]

  ...

  [[0.52855635 0.41770628 0.513637   0.42535636]
   [0.53666973 0.41750064 0.5181126  0.4217773 ]
   [0.5340571  0.4180799  0.5304321  0.4324241 ]
   ...
   [0.58525765 0.4535622  0.6278628  0.4726885 ]
   [0.5824047  0.45181668 0.61736536 0.4694157 ]
   [0.574299   0.45048517 0.6364736  0.49133104]]

  [[0.5094199  0.41676408 0.53845227 0.43122166]
   [0.5377668  0.40933067 0.5300929  0.4230702 ]
   [0.52497584 0.4162562  0.5546571  0.42732233]
   ...
   [0.57612735 0.44493923 0.6149667  0.4812014 ]
   [0.58240765 0.44419268 0.6189145  0.48169845]
   [0.58876085 0.4523928  0.6299436  0.4938084 ]]

  [[0.52731127 0.419869   0.5328092  0.4428544 ]
   [0.52995276 0.413065   0.53036445 0.43293613]
   [0.53205407 0.41343707 0.55017054 0.43606442]
   ...
   [0.5813354  0.44279486 0.6233833  0.4760692 ]
   [0.57516706 0.43857947 0.60621655 0.46161002]
   [0.57025933 0.43756458 0.6173605  0.47082642]]]


 [[[0.5655534  0.36839136 0.58174586 0.34563088]
   [0.54608667 0.36937332 0.5791079  0.3672496 ]
   [0.56192535 0.3456948  0.5808942  0.39568752]
   ...
   [0.52043366 0.41062054 0.5631013  0.37144056]
   [0.56303555 0.38082594 0.6377531  0.43040663]
   [0.5544235  0.3741211  0.65610546 0.4220673 ]]

  [[0.55437684 0.36387864 0.5854902  0.37111676]
   [0.55697274 0.3602262  0.5888927  0.28559923]
   [0.5662822  0.4029544  0.57197917 0.37391502]
   ...
   [0.5708256  0.33178616 0.648239   0.45220834]
   [0.55805236 0.3655151  0.6579042  0.45469016]
   [0.55960363 0.3829378  0.6603882  0.4368236 ]]

  [[0.5485182  0.37965205 0.58044946 0.394744  ]
   [0.5692835  0.37539297 0.5946481  0.35360625]
   [0.56923157 0.34741938 0.5698515  0.29563943]
   ...
   [0.5557749  0.39075902 0.6583134  0.47868124]
   [0.5458944  0.3799564  0.66556805 0.43646297]
   [0.5898541  0.41478547 0.6798965  0.46572414]]

  ...

  [[0.69440675 0.5694208  0.6358253  0.5329408 ]
   [0.673816   0.5310493  0.64333576 0.535164  ]
   [0.6192345  0.5330959  0.7827805  0.6370175 ]
   ...
   [0.58882207 0.42793608 0.6229692  0.45498204]
   [0.5814902  0.3878239  0.61147946 0.4299092 ]
   [0.57533634 0.41898167 0.617739   0.40178338]]

  [[0.71718776 0.5973412  0.6281087  0.5104146 ]
   [0.7119311  0.5786357  0.6530032  0.5194165 ]
   [0.6632911  0.5571174  0.7404576  0.6107363 ]
   ...
   [0.6326399  0.45663238 0.63283396 0.4930504 ]
   [0.6271349  0.44466752 0.6344468  0.48071292]
   [0.6303797  0.4668091  0.6333036  0.49308228]]

  [[0.755142   0.6531967  0.6405743  0.52183145]
   [0.6964464  0.5800834  0.63714486 0.51977044]
   [0.67017454 0.5562973  0.69260514 0.55814964]
   ...
   [0.60702264 0.45291156 0.63401747 0.49119887]
   [0.6117743  0.4443963  0.635783   0.46990573]
   [0.6211374  0.44046128 0.63295054 0.46737388]]]


 [[[0.6847946  0.48908246 0.561081   0.62964594]
   [0.6268649  0.5349496  0.7396407  0.37238842]
   [0.53284127 0.41250592 0.709383   0.49886572]
   ...
   [0.6551061  0.4373563  0.7903039  0.66368705]
   [0.51584595 0.4892563  0.7668909  0.7081881 ]
   [0.72090733 0.6036976  0.74380094 0.65755653]]

  [[0.60804075 0.35836548 0.6597215  0.53454316]
   [0.69323033 0.393558   0.5943859  0.486979  ]
   [0.54798275 0.459925   0.55207026 0.57920116]
   ...
   [0.6672802  0.4785284  0.6409271  0.43508995]
   [0.8013577  0.48821634 0.7916648  0.606381  ]
   [0.7665148  0.6134348  0.5245516  0.50607574]]

  [[0.7240704  0.41462973 0.6329266  0.47905776]
   [0.659194   0.49441278 0.722987   0.44935083]
   [0.60205907 0.4087535  0.7270777  0.5523374 ]
   ...
   [0.64498556 0.49781516 0.5999912  0.48188797]
   [0.82006663 0.56945634 0.57228684 0.5321662 ]
   [0.8694314  0.7019377  0.574485   0.5583311 ]]

  ...

  [[0.6654574  0.5295334  0.6171953  0.49293798]
   [0.6110889  0.52219224 0.80382824 0.48180252]
   [0.6128119  0.5091644  0.6212684  0.537459  ]
   ...
   [0.57025206 0.3470519  0.6895666  0.6070633 ]
   [0.6172749  0.44273257 0.7490457  0.7270975 ]
   [0.61988556 0.5504454  0.6902837  0.58120066]]

  [[0.5622207  0.40841195 0.5216717  0.4717768 ]
   [0.5723436  0.40867108 0.57033545 0.52637064]
   [0.5335992  0.5073532  0.5528992  0.53210944]
   ...
   [0.5923158  0.41920957 0.78495276 0.59107697]
   [0.7191378  0.51242626 0.72348785 0.65524215]
   [0.6374901  0.5877345  0.75149935 0.54320997]]

  [[0.63734937 0.37488455 0.6383838  0.5286166 ]
   [0.46551698 0.49511963 0.5064932  0.50222844]
   [0.57494766 0.4176346  0.52923405 0.4211309 ]
   ...
   [0.69452554 0.48695442 0.7649058  0.48640925]
   [0.54667246 0.40535632 0.7729947  0.5487324 ]
   [0.6653685  0.56958497 0.8087021  0.49816516]]]], shape=(4, 256, 256, 4), dtype=float32)
}

In [5]:
# %load_ext tensorboard
# !kill 6006
# %tensorboard --logdir logs/tensorb/
# http://localhost:16006


In [6]:
# model saving and loading
# path_weight = root_dir + '/models/temporal/UNet_gru_triple/weights_epoch_40'
# path_save_model = root_dir + '/models/pretrained/triple_backbone'
# model.save(path_save_model)
# model.save_weights(path_model)
# model = tf.keras.models.load_model(path_model)  ## load model
# model.load_weights(path_weight)

