In [None]:
!ls ${CUDA_DIR}/nvvm/libdevice

In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
# import sys, os
# os.environ["CUDA_VISIBLE_DEVICES"]="0"
!set TF_GPU_THREAD_MODE=gpu_private

In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import math
from json import dumps

# Ask TensorFlow which physical GPU devices it can see and print the list.
gpus = tf.config.list_physical_devices("GPU")
print(f"gpus={gpus}")

from tensorflow.keras.datasets import mnist
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Input, Flatten,\
                                    Reshape, LeakyReLU as LR,\
                                    Activation, Dropout
from tensorflow.keras.models import Model, Sequential
from matplotlib import pyplot as plt
from IPython import display # If using IPython, Colab or Jupyter
import numpy as np
import tensorflow_addons as tfa
import datetime
import random
from sklearn.model_selection import train_test_split
import time

import os
import re
import pathlib

In [None]:
# !wget https://drive.google.com/uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM

In [None]:
# Record the TensorFlow version this notebook was run with.
print(tf.__version__)

In [None]:
def add_vertical_lines_noise(x, shifts=2, line_height=2):
    """Return a copy of ``x`` in which bands of axis-1 slices are replaced by
    the slices found ``shifts`` positions further along axis 1 (wrapping around).

    A phase counter runs over the axis-1 indices and restarts whenever it
    reaches ``2 * line_height``; an index is replaced while the counter is
    ``<= line_height``.  The input array itself is never modified.

    Args:
        x: array indexed as ``x[:, i, :]`` (at least 3 dimensions).
        shifts: offset handed (negated) to ``np.roll`` along axis 1.
        line_height: controls band width and spacing, see above.

    Returns:
        A new array with the same shape and dtype as ``x``.
    """
    band_axis = 1
    noisy = np.array(x, copy=True)
    rolled = np.roll(noisy, -shifts, axis=band_axis)

    # First decide which indices fall inside a band, then copy them in one go.
    period = line_height * 2
    replace = np.zeros(noisy.shape[band_axis], dtype=bool)
    phase = 0
    for col in range(replace.size):
        replace[col] = phase <= line_height
        phase += 1
        if phase == period:
            phase = 0

    noisy[:, replace, :] = rolled[:, replace, :]
    return noisy

def add_horizontal_lines_noise(x, shifts=2, line_height=2):
    """Return a copy of ``x`` in which bands of axis-0 slices (rows) are
    replaced by the same rows rolled by ``-shifts`` along axis 1.

    A phase counter runs over the row indices, counting 0 .. ``2 * line_height``
    and then restarting at 0; a row is replaced while the counter is
    ``<= line_height``.  The input array itself is never modified.

    Args:
        x: array indexed as ``x[i, :, :]`` (at least 3 dimensions).
        shifts: offset handed (negated) to ``np.roll`` along axis 1.
        line_height: controls band height and spacing, see above.

    Returns:
        A new array with the same shape and dtype as ``x``.
    """
    band_axis = 0
    noisy = np.array(x, copy=True)
    rolled = np.roll(noisy, -shifts, axis=1)

    # First decide which rows fall inside a band, then copy them in one go.
    last_phase = line_height * 2
    replace = np.zeros(noisy.shape[band_axis], dtype=bool)
    phase = 0
    for row in range(replace.size):
        replace[row] = phase <= line_height
        phase = 0 if phase == last_phase else phase + 1

    noisy[replace, :, :] = rolled[replace, :, :]
    return noisy

In [None]:
import tensorflow_datasets as tfds
from functools import reduce
# Ask tfds for 200 even sub-splits of 'train' (drop_remainder=True).
# NOTE(review): `splits` is not referenced by any other visible cell.
splits = tfds.even_splits('train', n=200, drop_remainder=True)

In [None]:
# Height, width and channel count of the network input; the same numbers appear
# as the literal (480, 720, 3) in the dataset and model cells below.
# NOTE(review): these constants are not referenced by the visible code.
IMG_H = 480
IMG_W = 720
IMG_CHANNELS = 3

In [None]:
# train_len = 7164

In [None]:
# BATCH_SIZE = 3
# EPOCHS=100
# steps_per_epoch=int(train_len/BATCH_SIZE)

## Dataset

In [None]:
import tensorflow_datasets as tfds

def unsharp(x):
    """Convert array ``x`` to an image and apply an unsharp mask
    (radius=2, percent=150), returning the filtered image object.

    NOTE(review): ``Image`` and ``ImageFilter`` are not imported in any visible
    cell (presumably ``from PIL import Image, ImageFilter``) - confirm before
    re-enabling the commented-out uses of this helper.
    """
    sharpen = ImageFilter.UnsharpMask(radius=2, percent=150)
    return Image.fromarray(x).filter(sharpen)

def pixelation_noise(x, ranges=[1/3]):
    """Pixelate an image by shrinking it and resizing it back to its original size.

    Args:
        x: image whose first two dimensions are height and width.
        ranges: candidate downscale ratios; one is drawn with ``random.choice``
            on every call.  The default list is only read, never mutated.

    Returns:
        The result of the second ``tf.image.resize`` call (a tensor at the
        original height/width, subject to ``preserve_aspect_ratio``).
    """
    downsize_image_ratio = random.choice(ranges)

    # tf.image.resize expects an int32 size; keep the target size as integers
    # instead of passing the float-cast shape to the final resize.
    full_size = tf.shape(x)[:2]
    small_size = tf.cast(tf.cast(full_size, tf.float32) * downsize_image_ratio, tf.int32)
    # Never ask for a zero-sized intermediate image on very small inputs.
    small_size = tf.maximum(small_size, 1)

    down = tf.image.resize(
        x,
        small_size,
        preserve_aspect_ratio=True,
        antialias=False,
        name=None)
    return tf.image.resize(
        down,
        full_size,
        preserve_aspect_ratio=True,
        antialias=False,
        name=None)

def random_invert_img(x, p=0.5):
    """Return ``255 - x`` when a fresh uniform random draw is below ``p``,
    otherwise return ``x`` unchanged."""
    should_invert = tf.random.uniform([]) < p
    if should_invert:
        x = (255-x)
    return x

def random_apply_saturation(x, p=0.5):
    """Apply ``tf.image.random_saturation(x, 5, 10)`` when a fresh uniform
    random draw is below ``p``; otherwise return ``x`` unchanged."""
    if tf.random.uniform([]) < p:
        x = tf.image.random_saturation(x, 5, 10)
    return x

def augment_img(x):
    """Randomly augment one image: left/right flip, up/down flip, then
    inversion and saturation change, each of the last two with p=0.4."""
    flipped = tf.image.random_flip_up_down(tf.image.random_flip_left_right(x))
    inverted = random_invert_img(flipped, p=0.4)
    return random_apply_saturation(inverted, p=0.4)

# def split_fn(x, y):
#     print(x, y)
#     return tf.image.resize(x, size=[480, 720]), y

@tf.function
def random_noise_and_resize(y):
    """Build a (noisy input, clean target) pair from a clean image ``y``.

    ``y`` is resized to 480x720 and one randomly chosen corruption is applied
    to the resized copy through ``tf.numpy_function``: shifted horizontal
    lines, pixelation, or pixelation followed by shifted lines.  The target
    ``y`` is returned untouched.
    """
    def corrupt(image):
        def shifted_lines(img):
            return add_horizontal_lines_noise(img, random.randint(-10, -8), random.randint(2, 5))

        # (an unsharp-mask option existed here but is disabled)
        candidates = [
            shifted_lines,
            pixelation_noise,
            lambda img: shifted_lines(pixelation_noise(img, ranges=[1/3])),
        ]
        return random.choice(candidates)(image)

    resized = tf.image.resize(y, size=[480, 720])
    noisy = tf.numpy_function(func=corrupt, inp=[resized], Tout=tf.float32)
    return noisy, y


@tf.function
def random_noise(y):
    """Build a (noisy input, clean target) pair from ``y`` without resizing.

    One randomly chosen corruption is applied to ``y`` through
    ``tf.numpy_function``: shifted horizontal lines, pixelation, or pixelation
    followed by shifted lines.

    NOTE(review): the numpy_function output is declared as float32, while the
    line-noise helper keeps its input dtype - presumably ``y`` must already be
    float32 here; verify against callers.
    """
    def corrupt(image):
        def shifted_lines(img):
            return add_horizontal_lines_noise(img, random.randint(-10, -8), random.randint(2, 5))

        # (an unsharp-mask option existed here but is disabled)
        candidates = [
            shifted_lines,
            pixelation_noise,
            lambda img: shifted_lines(pixelation_noise(img, ranges=[1/3])),
        ]
        return random.choice(candidates)(image)

    noisy = tf.numpy_function(func=corrupt, inp=[y], Tout=tf.float32)
    return noisy, y


def set_shapes(image, label):
    """Attach static shapes to a dataset element and return it:
    (480, 720, 3) for ``image`` and (720, 1080, 3) for ``label``."""
    static_shapes = ((image, (480, 720, 3)), (label, (720, 1080, 3)))
    for tensor, shape in static_shapes:
        tensor.set_shape(shape)
    return image, label

def ccast(x, y):
    """Cast both members of a dataset element to float16."""
    half = tf.float16
    return tf.cast(x, half), tf.cast(y, half)

def div2k_ds(split):
    """Build the DIV2K pipeline yielding (noisy 480x720 input, clean 720x1080 target).

    Steps: load the ``div2k`` tfds split, keep the ``"hr"`` image, crop/pad it
    to 720x1080, cast to float16, augment, derive the noisy low-resolution
    input, attach static shapes and cast both members to float16.

    Args:
        split: tfds split name passed to ``tfds.load``.

    Returns:
        An unbatched ``tf.data.Dataset`` of ``(image, label)`` pairs.
    """
    # Parenthesised chain: the previous form ended in a dangling `\` that would
    # have swallowed whatever line came to follow it.
    return (
        tfds.load('div2k', split=split, shuffle_files=True)
        .map(lambda x: x["hr"])
        .map(lambda y: tf.image.resize_with_crop_or_pad(y, 720, 1080))
        .map(lambda z: tf.cast(z, tf.float16))
        .map(augment_img)
        .map(random_noise_and_resize)
        .map(set_shapes)
        .map(ccast)
    )

def div2k_ds_same_size(split):
    """Build the DIV2K pipeline where input and target are both 480x720.

    Same steps as the 720x1080 variant, except that the clean image is
    cropped/padded to 480x720, so both members of the pair share one shape.

    Args:
        split: tfds split name passed to ``tfds.load``.

    Returns:
        An unbatched ``tf.data.Dataset`` of ``(image, label)`` pairs.
    """
    def set_same_size_shapes(image, label):
        # The target is cropped to 480x720 above, so asserting the
        # (720, 1080, 3) label shape used by the large pipeline would raise.
        image.set_shape((480, 720, 3))
        label.set_shape((480, 720, 3))
        return image, label

    return (
        tfds.load('div2k', split=split, shuffle_files=True)
        .map(lambda x: x["hr"])
        .map(lambda y: tf.image.resize_with_crop_or_pad(y, 480, 720))
        .map(lambda z: tf.cast(z, tf.float16))
        .map(augment_img)
        .map(random_noise_and_resize)
        .map(set_same_size_shapes)
    )

# Train / validation datasets, batched two pairs at a time.
div2k_ds_train = div2k_ds('train').batch(2)
div2k_ds_test = div2k_ds('validation').batch(2)

# # i = iter(div2k_ds_train)
# # for j in range(3):
# #     x, y = next(i)
# #     plt.figure()
# #     plt.imshow(tf.cast(x, tf.uint8))
# #     plt.figure()
# #     plt.imshow(tf.cast(y, tf.uint8))

In [None]:
# from utils import NoiseUtil, ImgUtils, DataLoader, DataManager

           
# def add_noise(x,y):
#     downsize_image_ratio = random.choice([1/5])
#     sh = tf.cast(tf.shape(x), tf.float32)
    
#     print(x)
    
#     resized_size_h = sh[0]
#     resized_size_w = sh[1]
#     down = tf.image.resize(
#         x,
#         [tf.cast(resized_size_h * downsize_image_ratio, tf.int32), tf.cast(resized_size_w * downsize_image_ratio, tf.int32)],
#         preserve_aspect_ratio=True,
#         antialias=False,
#         name=None)
    
    

#     x= tf.image.resize(
#         down,
#         [resized_size_h, resized_size_w],
#         preserve_aspect_ratio=True,
#         antialias=False,
#         name=None)
        
#     return tf.reshape(x, (resized_size_h, resized_size_w, 3)), y
    
    
# #     print(x.shape)

    
# #     n = NoiseUtil.pixel_noise(x, random.choice([50,60]), 15, downsize_image_ratios=[1/4, 1/6])

# #     n = x + 0.2 * tf.random.normal(
# #         x.shape[1:],
# #         mean=0.0,
# #         stddev=1.0,
# #         dtype=tf.dtypes.float32,
# #     )

# #     return n,y

# random_bright = tf.keras.layers.RandomBrightness(factor=0.2)
# random_contrast = tf.keras.layers.RandomContrast(factor=0.2)
# random_flip = tf.keras.layers.RandomFlip()


# def augment(x):
# #     seed = (random.randint(0, 100),random.randint(0, 100))
#     x = random_bright(x, training=False)
#     x = random_contrast(x, training=False)
#     x = random_flip(x, training=False)
#     return x

# def saturation(x):
#     return tf.image.random_saturation(x, 5, 10)

# def get_train_data():
#     return tf.data.Dataset.from_generator(train_gen, output_signature=tf.TensorSpec(shape=(480, 720, 3)))
    
# def get_test_data():
#     return tf.data.Dataset.from_generator(test_gen, output_signature=tf.TensorSpec(shape=(480, 720, 3)))


# def get_dist_ds(ds, ds_len):
#     c = ds.map(normm).map(expp).map(augment)
#     a = c.map(lambda y: (y,1))
#     b = c.map(lambda y: (y,0)).map(add_noise)
#     return a.concatenate(b).batch(BATCH_SIZE)

# def get_gen_ds(ds):
#     return ds.map(augment).map(lambda x: (x,x)).map(add_noise).batch(BATCH_SIZE)

# # train_ds = get_gen_ds(get_train_data())
# # test_ds = get_gen_ds(get_test_data())

## Training config

In [None]:
class TrainingConfig:
    """Bundle of training hyper-parameters passed to the model builders.

    Attributes:
        dropout_rate: dropout rate stored for the main path.
        ff_dropout_rate: dropout rate stored for the skip ("ff") connections.
        kernel_regularizer: Keras regularizer object.
        learning_rate: learning-rate value(s), stored as given.
        optimizer: optimizer identifier string.
    """

    def __init__(self,
                 dropout_rate=0.1,
                 ff_dropout_rate=0.4,
                 kernel_regularizer = tf.keras.regularizers.L2(0.001),
                 learning_rate=[5e-5, 1e-5, 9e-6, 7e-6],
                 optimizer="adamf"):
        self.dropout_rate = dropout_rate
        self.kernel_regularizer = kernel_regularizer
        self.ff_dropout_rate = ff_dropout_rate
        self.learning_rate = learning_rate
        self.optimizer = optimizer

    def to_map(self):
        """Return every setting as a dict; the regularizer is represented by
        its ``get_config()`` dict."""
        return {
            "dropout_rate": self.dropout_rate,
            "kernel_regularizer": self.kernel_regularizer.get_config(),
            "ff_dropout_rate": self.ff_dropout_rate,
            "learning_rate": self.learning_rate,
            "optimizer": self.optimizer,
        }

    def to_hp_map(self):
        """Return the flat dict used for hyper-parameter logging; the
        regularizer is reduced to the ``"l2"`` entry of its config."""
        l2_strength = self.kernel_regularizer.get_config()["l2"]
        assert l2_strength is not None
        return {
            "dropout_rate": self.dropout_rate,
            "kernel_regularizer": l2_strength,
            "ff_dropout_rate": self.ff_dropout_rate,
        }

    def toString(self):
        """Return ``to_map()`` serialised as a JSON string."""
        return dumps(self.to_map())

    def __str__(self):
        return self.toString()
        

In [None]:
import datetime

def now():
    """Return the current local time as ``YYYY_MM_DD_THH_MM_SS_cc`` where
    ``cc`` is the two-digit hundredths-of-a-second part."""
    moment = datetime.datetime.now()
    hundredths = moment.microsecond // 10000
    return f"{moment:%Y_%m_%d_T%H_%M_%S}" + '_%02d' % hundredths

def minute():
    """Return the current local time formatted as ``HH_MM``."""
    return format(datetime.datetime.now(), '%H_%M')

# Show the current HH_MM stamp as the cell output.
minute()

In [None]:
# class MetricsManager:
#     def __init__(self):

In [None]:
from tensorboard.plugins.hparams import api as hp

# HP_NUM_UNITS = hp.HParam('dropout_rate', hp.Discrete([0.1, 0.2]))
# HP_DROPOUT = hp.HParam('ff_dropout_rate', hp.Discrete([0.1,0.4]))
# HP_OPTIMIZER = hp.HParam('kernel_regularizer', hp.Discrete([
#     tf.keras.regularizers.L2(1e-2),
#     tf.keras.regularizers.L2(1e-5),
# ]))
# class HpSearch():
#     def __init__(self, log_dir):
# Tag names shared by the HParams metric declarations and the scalar summaries
# written in TensorboardHPSearch below.
METRIC_ACCURACY = 'accuracy'
METRIC_LOSS = 'loss'
METRIC_MSE = 'mse'

class TensorboardHPSearch():
    """Grid search over training hyper-parameters, logged through the
    TensorBoard HParams plugin.

    Summaries are written under ``{log_dir}/{name}``; every trial gets its own
    sub-directory named ``{prefix}{trial number}``.
    """

    # NOTE(review): the `prefix` default is an f-string evaluated once, when
    # this class is defined, so every instance created without an explicit
    # prefix shares the minute at which the cell was run. `v` is never used.
    def __init__(self, log_dir, name, prefix=f"t{minute()}", v=1):
        self.log_dir_base = log_dir
        self.log_dir = f"{log_dir}/{name}"
        self.log_dir_name = name
        self.prefix = prefix
        self.hp_writer = tf.summary.create_file_writer(self.log_dir)

        
    def init(self, config):
        """Print the hyper-parameters of ``config`` and return an
        ``hp.KerasCallback`` that logs them to this search's log directory."""
#         with self.hp_writer.as_default():
#             hp.hparams(config.to_hp_map())
        print(f"HPs: {config.to_hp_map()}, logdir={self.log_dir}")
        return hp.KerasCallback(self.log_dir, config.to_hp_map())
            
    def log_results(self, name, loss, accuracy, mse):
        """Write accuracy, loss and mse as scalar summaries (all at step 1)
        using a new file writer at ``{log_dir_base}/{name}``."""
        with tf.summary.create_file_writer(f"{self.log_dir_base}/{name}").as_default():
            tf.summary.scalar(METRIC_ACCURACY, accuracy, step=1)
            tf.summary.scalar(METRIC_LOSS, loss, step=1)
            tf.summary.scalar(METRIC_MSE, mse, step=1)

        
    def hp_search(self,
                  dropout_rate=[0.1, 0.2],
                  ff_dropout_rate=[0.1,0.4],
                  regularizers=[1e-2,1e-5],
                  learning_rate=[(5e-5, 7e-6)],
                  optimizer=["adamf", "adam"],
                  run=lambda config, callback, name:0):
        """Run ``run`` once for every combination of the given value lists.

        Args:
            dropout_rate: candidate dropout rates.
            ff_dropout_rate: candidate skip-connection dropout rates.
            regularizers: candidate L2 factors (each wrapped in
                ``tf.keras.regularizers.L2``).
            learning_rate: candidate learning-rate settings; each is handed to
                ``TrainingConfig`` as is and logged via ``str()``.
            optimizer: candidate optimizer identifiers.
            run: callable ``run(config, callback, name)`` executed per trial,
                where ``callback`` is the trial's ``hp.KerasCallback`` and
                ``name`` is ``{log dir name}/{prefix}{trial number}``.

        Returns:
            The list of ``TrainingConfig`` objects, in trial order.
        """
        
        DROPOUT_RATE = hp.HParam('dropout_rate', hp.Discrete(dropout_rate))
        FF_DROPOUT_RATE = hp.HParam('ff_dropout_rate', hp.Discrete(ff_dropout_rate))
        KERNEL_REGULARIZER = hp.HParam('kernel_regularizer', hp.Discrete(regularizers))
        # Learning-rate settings are registered by their string form.
        LR = hp.HParam('learning_rate', hp.Discrete(map(lambda x: str(x), learning_rate)))
        OPTIMIZER = hp.HParam('optimizer', hp.Discrete(optimizer))

        
        # Declare the search space and the tracked metrics once for the whole run.
        with self.hp_writer.as_default():
    #                         group = f"t{t}/validation"
            hp.hparams_config(
                hparams=[DROPOUT_RATE, FF_DROPOUT_RATE, KERNEL_REGULARIZER, LR, OPTIMIZER],
                metrics=[
                    hp.Metric(METRIC_ACCURACY, display_name='Accuracy'),
                    hp.Metric(METRIC_LOSS, display_name='Perceptual Loss'),
                    hp.Metric(METRIC_MSE, display_name='MSE')
                ],
            )
                        
        configs = []
        # Trial counter; the nesting order below determines each trial's number.
        t = 1
        for dr in DROPOUT_RATE.domain.values:
            for ffdr in FF_DROPOUT_RATE.domain.values:
                for reg in KERNEL_REGULARIZER.domain.values:
                    # Iterates the raw argument (not LR.domain.values) so the
                    # config receives the original value rather than its string.
                    for lr in learning_rate: 
                        for opt in OPTIMIZER.domain.values:
                            config = TrainingConfig(dropout_rate=dr, ff_dropout_rate=ffdr,
                                                   kernel_regularizer=tf.keras.regularizers.L2(reg),
                                                   learning_rate=lr,
                                                   optimizer=opt)

                            hpConfig = {
                                DROPOUT_RATE:dr,
                                FF_DROPOUT_RATE: ffdr,
                                KERNEL_REGULARIZER: reg,
                                LR: str(lr),
                                OPTIMIZER:opt
                            }
                            run(config, hp.KerasCallback(
                                f"{self.log_dir}/{self.prefix}{t}", hpConfig, trial_id = f"{self.prefix}{t}"),
                                f"{self.log_dir_name}/{self.prefix}{t}")
                            configs.append(config)
                            t += 1
        return configs

class TensorboardUtil():
    """Small helper around ``tf.summary`` writers for one log directory.

    Creates a main file writer at ``log_dir`` and a second writer at
    ``t_{log_dir}/{timestamp}``; the visible methods only use the main one.
    """

    def __init__(self, log_dir):
        self.log_dir = log_dir
        self.file_writer = tf.summary.create_file_writer(log_dir)
        hp_dir = f"t_{log_dir}/{now()}"
        self.hp_writer = tf.summary.create_file_writer(hp_dir)

    def log_text(self, data, name, step=0):
        """Write ``data`` as a text summary called ``name`` at ``step``."""
        writer_scope = self.file_writer.as_default()
        with writer_scope:
            tf.summary.text(name, data, step=step, description=None)

    def log_scalar(self, value, name, step=0):
        """Write ``value`` as a scalar summary called ``name`` at ``step``."""
        writer_scope = self.file_writer.as_default()
        with writer_scope:
            tf.summary.scalar(name, value, step=step)

    def save_image(self, image, label, step=0):
        """Write ``image`` as an image summary tagged ``label`` at ``step``
        and print a confirmation line."""
        writer_scope = self.file_writer.as_default()
        with writer_scope:
            print(f"Saved {label} to Tensorboard")
            tf.summary.image(label, image, step=step)

    def get_callback(self, profile_batch=0):
        """Return a Keras TensorBoard callback for this log directory that
        writes the graph and per-epoch histograms."""
        options = dict(
            log_dir=self.log_dir,
            write_graph=True,
            histogram_freq=1,
            profile_batch=profile_batch,
        )
        return tf.keras.callbacks.TensorBoard(**options)

# tb_util = TensorboardUtil("test")
# [s.toString() for s in tb_util.hp_search()]

In [None]:
# from tensorflow.keras.applications.vgg16 import VGG16
# from tensorflow.keras.applications.vgg16 import preprocess_input
# import tensorflow.keras.backend as K

# def build_perceptual_loss(input_shape=(480, 720, 3)):
#     print(input_shape)
#     vgg = VGG16(weights="imagenet", include_top=False, input_shape=input_shape)
#     vgg.trainable = False ## Not trainable weights


#     selected_layers = ['block1_conv1', 'block2_conv2',"block3_conv3" ,'block4_conv3','block5_conv3']
#     selected_layer_weights = [1.0, 4.0 , 4.0 , 8.0 , 16.0]


#     outputs = [vgg.get_layer(l).output for l in selected_layers]
#     prediction_model = Model(vgg.input, outputs, name="perceptual_loss_model")
#     prediction_model.trainable = False

#     @tf.function
#     def perceptual_loss(input_image , reconstruct_image):
#         input_image = tf.keras.applications.vgg16.preprocess_input(input_image)
#         reconstruct_image = tf.keras.applications.vgg16.preprocess_input(reconstruct_image)

#         h1_list = prediction_model(input_image, training=False)
#         h2_list = prediction_model(reconstruct_image, training=False)

#         rc_loss = 0.0

#         for h1, h2, weight in zip(h1_list, h2_list, selected_layer_weights):
#             h1 = K.batch_flatten(h1)
#             h2 = K.batch_flatten(h2)
#             rc_loss = rc_loss + weight * K.sum(K.square(h1 - h2), axis=-1)

#         return rc_loss
    
#     return perceptual_loss

# build_perceptual_loss(input_shape=(480, 720, 3))

## Perceptual loss

In [None]:
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
import tensorflow.keras.backend as K
from tensorflow.keras import mixed_precision

# Switch Keras to the 'mixed_float16' policy globally; layers created after
# this point pick it up unless they are given an explicit dtype.
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)


class LossNetwork(tf.keras.models.Model):
    """Frozen VGG16 feature extractor used to compute a perceptual loss.

    Exposes the activations of five VGG16 convolution layers and combines
    their per-layer mean squared differences with fixed weights.
    """

    def __init__(self):
        super().__init__()
        backbone = VGG16(include_top=False, weights='imagenet')
        backbone.trainable = False

        self.selected_layers = ['block1_conv1', 'block2_conv2',"block3_conv3" ,'block4_conv3','block5_conv3']
        self.selected_layer_weights = [0.1, 0.4 , 0.4 , 0.8 , 1.6]

        taps = [backbone.get_layer(layer_name).output for layer_name in self.selected_layers]
        self.model = tf.keras.models.Model(backbone.input, taps)
        # Identity activation pinned to float32 so outputs leave mixed precision.
        self.linear = layers.Activation('linear', dtype='float32')

    def call(self, x):
        """Return the selected VGG16 activations for ``x`` after VGG16
        ``preprocess_input``, passed through the float32 identity layer."""
        return self.linear(self.model(preprocess_input(x)))

    @tf.function
    def loss(self, x, y):
        """Weighted sum over the selected layers of the mean squared
        difference between the activations of ``x`` and ``y``.

        Note that, unlike ``call``, the inputs are fed to the feature model
        without ``preprocess_input``.
        """
        feats_x = self.model(x)
        feats_y = self.model(y)

        total = 0.0
        for fx, fy, layer_weight in zip(feats_x, feats_y, self.selected_layer_weights):
            fx = K.batch_flatten(tf.cast(fx, tf.float32))
            fy = K.batch_flatten(tf.cast(fy, tf.float32))
            layer_mse = tf.cast(tf.reduce_mean((fx - fy)**2), tf.float32)
            total = tf.cast(total + layer_mse * layer_weight, tf.float32)

        return total

class PeceptualLoss(tf.keras.losses.Loss):
    """Keras loss that delegates to ``LossNetwork.loss`` (perceptual loss)."""

    def __init__(self):
        super().__init__()
        network = LossNetwork()
        network.trainable = False
        self.loss_network = network

    def call(self, y_true, y_pred):
        """Return the perceptual loss between ``y_true`` and ``y_pred``."""
        return self.loss_network.loss(y_true, y_pred)

    
def build_perceptual_loss(*args, **kwargs):
    """Return a new ``PeceptualLoss``; all arguments are accepted and ignored."""
    perceptual_loss = PeceptualLoss()
    return perceptual_loss

# Smoke test: evaluate the loss between an all-zero and an all-255 image batch
# of shape (1, 720, 1080, 3); the value is shown as the cell output.
x = tf.zeros((1,720,1080,3))
y = tf.ones((1,720,1080,3)) * 255.
ploss = build_perceptual_loss()
ploss(x,y)


## Render test images

In [None]:
class ImageRenderer():
    """Plots denoised / noisy / original image triplets with matplotlib and,
    when configured, logs them to TensorBoard.

    Configure with the chainable ``withTensorboard`` and ``saveImages``
    methods, then call ``render``.
    """

    def __init__(self, batch_size, max_size=20):
        """
        Args:
            batch_size: number of test samples drawn into the matplotlib grid.
            max_size: base size from which the figure size is derived.
        """
        self.batch_size = batch_size
        self.tb_util = None
        # Same defaults as withTensorboard(), so the attributes always exist.
        self.tb_batch = 4
        self.tb_sample = 4
        self.save = False
        self.save_path = None
        self.max_size = max_size
        self.tb_dataset = None

    def withTensorboard(self, tensorboard_util, tb_sample=4, tb_batch=4):
        """Enable TensorBoard logging through ``tensorboard_util`` (an object
        with a ``save_image(image, label, step)`` method). Returns ``self``."""
        self.tb_util = tensorboard_util
        self.tb_batch = tb_batch
        self.tb_sample = tb_sample
        return self

    def saveImages(self, path):
        """Save the rendered figure to ``path`` on every render. Returns ``self``."""
        self.save = True
        self.save_path = path
        return self

    def render(self, model, datasets=([], []), batch=1):
        """Run ``model`` on samples of the test dataset and display the results.

        Args:
            model: callable mapping a batch of noisy images to denoised images.
            datasets: ``(train_ds, test_ds)``; only ``test_ds`` is used and it
                must be a batched ``tf.data.Dataset`` of ``(noisy, original)``
                pairs (the empty default is only a placeholder).
            batch: step value attached to the TensorBoard image summaries.
        """
        batch_size = self.batch_size
        _train_ds, test_ds = datasets
        # Keep only the first image of every (already batched) dataset element.
        first_images = test_ds.map(lambda x, y: (x[0], y[0]))

        # The TensorBoard path used to run unconditionally and failed with an
        # AttributeError on `tb_batch` when withTensorboard() was never called.
        if self.tb_util is not None:
            if self.tb_dataset is None:
                # Cached so every call logs the same samples.
                self.tb_dataset = list(first_images.take(self.tb_batch).batch(self.tb_batch))
            for noisy, original in self.tb_dataset:
                denoised = model(noisy)
                self.tb_util.save_image(tf.cast(denoised, tf.uint8), f"denoised", batch)
                self.tb_util.save_image(tf.cast(noisy, tf.uint8), f"with_noise", batch)
                self.tb_util.save_image(tf.cast(original, tf.uint8), f"original", batch)

        dataset = first_images.take(batch_size).batch(batch_size)

        # Grid layout: 3 rows (denoised / noisy / original), batch_size columns.
        rows = batch_size
        cols = 3
        results = [(model(x), x, y) for x, y in dataset]
        plt.ion()
        plt.show()
        plt.figure(figsize=((self.max_size+15) / cols, self.max_size / rows))
        for denoised, noisy, original in results:
            panels = (("Denoised", denoised), ("With noise", noisy), ("Original", original))
            for panel_row, (title, images) in enumerate(panels):
                images = tf.cast(images, tf.uint8)
                for i in range(images.shape[0]):
                    # Offset by the grid width so each kind of image stays on
                    # its own row even if fewer than batch_size samples exist.
                    plt.subplot(cols, rows, panel_row * rows + i + 1)
                    plt.imshow(images[i, :, :, :])
                    plt.title(title)
                    plt.axis('off')

        plt.subplots_adjust(wspace = 0.1, hspace = 0.5)
        if self.save:
            plt.savefig(self.save_path)

        plt.draw()
        plt.pause(0.001)
            
# def rndd(x):
#     assert len(x.shape) == 4
#     return x
# ImageRenderer(2).withTensorboard(TensorboardUtil("prepa_autencoder_logs")).render(model=lambda x:x, datasets=(div2k_ds_train, div2k_ds_test))

## Model utils

In [None]:
def conv(filters, kernel_size, strides, norm=False, use_init = True, use_bias=True, activation="relu", config=None, **args):
    """Create a ``Conv2D`` layer with "same" padding.

    Args:
        filters, kernel_size, strides: forwarded to ``Conv2D``.
        norm, use_init: accepted for compatibility; currently unused.
        use_bias: forwarded to ``Conv2D``.
        activation: forwarded to ``Conv2D`` (default "relu").
        config: object providing ``kernel_regularizer``; when ``None`` (the
            default) the layer is created without a kernel regularizer.
        **args: extra keyword arguments forwarded to ``Conv2D`` (e.g. ``name``).

    Returns:
        The configured ``Conv2D`` layer (not yet applied to any input).
    """
    # `config` defaults to None, so do not dereference it unconditionally.
    kernel_regularizer = config.kernel_regularizer if config is not None else None
    return layers.Conv2D(filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding="same",
                         use_bias=use_bias,
                         activation=activation,
                         kernel_regularizer=kernel_regularizer,
                         **args)


class Resblock(tf.keras.layers.Layer):
    """Residual block made of three stride-1 convolutions.

    The output of the first convolution is added to the output of the third
    one (which follows the second), and the sum goes through a ReLU.
    """

    def __init__(self, filters, kernel_size, config=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.config = config

    def get_config(self):
        """Return the base layer config extended with this block's settings
        (``config`` itself is not included)."""
        return {
            **super().get_config(),
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "conv_size": 3,
            "res_activation": "relu",
        }

    def build(self, input_shape):
        """Create the three convolutions, the add layer and the activation."""
        def make_conv():
            return conv(self.filters, kernel_size=self.kernel_size, strides=1, config=self.config)

        self.conv1 = make_conv()
        self.conv2 = make_conv()
        self.conv3 = make_conv()
        self.add = layers.Add()
        self.activation = layers.Activation('relu', dtype=tf.float16)

    def call(self, inputs):
        shortcut = self.conv1(inputs)
        residual = self.conv3(self.conv2(shortcut))
        return self.activation(self.add([shortcut, residual]))


class MyRescale(tf.keras.layers.Layer):
    """Multiplies its input by a single weight initialised to 255."""

    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        init_255 = tf.keras.initializers.Constant(value=255)
        self.kernel = self.add_weight("kernel", initializer=init_255)

    def call(self, inputs):
        return inputs * self.kernel


class MyDeconv(tf.keras.layers.Layer):
    """Upsampling block: nearest-neighbour resize followed by a ``Resblock``.

    When ``resize_to`` is ``None`` the target size is derived at build time
    from the input height/width multiplied by ``strides``.
    """

    def __init__(self, filters, kernel_size, strides, name, resize_to=None, config=None):
        super().__init__(name=name)
        self.filters = filters
        self.kernel_size = kernel_size
        self.resize_to = resize_to
        self.strides = strides
        self.config = config

    def get_config(self):
        """Return the base layer config extended with this block's settings
        (``config`` itself is not included)."""
        return {
            **super().get_config(),
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "resize_to": self.resize_to,
            "strides": self.strides,
        }

    def build(self, input_shape):
        """Resolve the resize target and create the resize and Resblock layers."""
        # Accept either a single int or an explicit (height, width) tuple.
        stride_hw = self.strides if isinstance(self.strides, tuple) else (self.strides, self.strides)

        if self.resize_to is None:
            in_h, in_w = input_shape[1], input_shape[2]
            self.resize_to = (int(in_h * stride_hw[0]), int(in_w * stride_hw[1]))

        self.resize_layer = layers.Lambda(
            lambda x: tf.image.resize(x, self.resize_to, method="nearest"),
            name=f"resize_nearest_{self.name}")
        self.resblock = Resblock(self.filters, self.kernel_size, config=self.config)

    def call(self, inputs):
        return self.resblock(self.resize_layer(inputs))


def maxconv_name():
    """Endless generator of names ``max_conv_1``, ``max_conv_2``, ..."""
    counter = 1
    while True:
        yield f"max_conv_{counter}"
        counter += 1

        
# Shared name iterator used by maxpoolconv; re-running this cell restarts the
# numbering at max_conv_1.
max_conv_name = iter(maxconv_name())


def maxpoolconv(filters, pool_kernel_size, kernel_size=3, strides=1, config=None):
    """Return a function applying a ``Resblock`` followed by max pooling.

    A unique base name is drawn from the shared ``max_conv_name`` iterator
    when this factory is called; the layers themselves are created each time
    the returned function is applied.  ``strides`` is currently unused.
    """
    name = next(max_conv_name)

    def fn(x):
        features = Resblock(filters, kernel_size, name=f"resblock_{name}_1", config=config)(x)
        return layers.MaxPool2D(pool_kernel_size, name=f"{name}_maxpool")(features)

    return fn

## Autoencoder

In [None]:
def build_autoencoder_large_v3(config):    
    """Build the denoising autoencoder as a Keras functional model.

    The input is (480, 720, 3) and is rescaled by 1/255.  The encoder is a
    first Resblock followed by Resblock + max-pool stages (32, 64, 128, 512 and
    1024 filters); the output of each stage except the last is kept, after
    spatial dropout, as a skip ("ff") tensor.  The decoder mirrors the encoder
    with MyDeconv blocks that add the matching skip tensor, then resizes to
    (360, 540) and upsamples by 2.  The last layers are a sigmoid convolution
    with 3 filters rescaled by 255.

    Args:
        config: object providing ``dropout_rate``, ``ff_dropout_rate`` and
            ``kernel_regularizer`` (e.g. ``TrainingConfig``).

    Returns:
        ``tf.keras.Model`` mapping the input to the reconstructed image.
    """
    inputs = keras.Input(shape=(480, 720, 3))
    # NOTE(review): `norm` is only referenced by commented-out code below.
    norm = True
    

    def deconv_with_ff(filters, x, ff, kernel_size, stride, name):
        """Upsample ``x`` with MyDeconv, apply spatial dropout, add the skip
        tensor ``ff`` and finish with a ReLU."""
        x = MyDeconv(filters, kernel_size, stride, name=name, resize_to=None, config=config)(x)
        x = layers.SpatialDropout2D(config.dropout_rate)(x)
        x = tf.add(x,ff)
        x = layers.ReLU()(x)
        return x
    
    x = tf.keras.layers.Rescaling(1./255)(inputs)
#     x = layers.Lambda(lambda x: tf.image.per_image_standardization(x))(inputs)
    
    # Skip tensors, collected in encoder order and reversed before decoding.
    ffs = []
    
    def conv_with_ff(x, filters, strides):
        """Apply one Resblock + max-pool stage (``strides`` is the pool size)
        and record its dropout-regularised output as a skip tensor."""
        x = maxpoolconv(filters, strides, config=config)(x)
        ffs.append(layers.SpatialDropout2D(config.ff_dropout_rate)(x))
        return x
#     large_inputs = MyDeconv(16, 3, 1, name="input_ff", resize_to=(720, 1080), norm=norm)(inputs)
#     large_inputs = MyDeconv(32, 3, 1, name="input_ff", resize_to=None, norm=norm)(inputs)
    
#     x = conv(32, 3, 1, name="first_conv1", use_bias=True, norm=norm)(x)
#     x = conv(32, 5, 1, name="first_conv2", use_bias=True, norm=norm)(x)
    x = Resblock(16, 3, name=f"first_resblock", config=config)(x)
    ffs.append(layers.SpatialDropout2D(config.ff_dropout_rate)(x))
    
    x = conv_with_ff(x, 32, 2)
    x = conv_with_ff(x, 64, 2)
    x = conv_with_ff(x, 128, 2)
    x = conv_with_ff(x, 512, 2)
    # Deepest stage: pooled by (2, 3) and not recorded as a skip tensor.
    x = maxpoolconv(1024, (2,3), config=config)(x)
    x = layers.SpatialDropout2D(config.dropout_rate)(x)
        
    # Deepest skip first, so ffs[i] lines up with the i-th decoder stage.
    ffs.reverse()
    x = conv(1024, 3, 1, name="middle_conv", use_bias=True, config=config)(x)
    
  
    # The (2, 3) stride undoes the (2, 3) pooling of the deepest encoder stage.
    x = deconv_with_ff(512, x, ffs[0], 3, (2,3), "mydeconv7")
    x = deconv_with_ff(128, x, ffs[1], 3, 2, "mydeconv6")
    x = deconv_with_ff(64, x, ffs[2], 3, 2, "mydeconv5")
    x = deconv_with_ff(32, x, ffs[3], 3, 2, "mydeconv4")
    x = Resblock(32, 3, name=f"rb1",config=config)(x)
    x = deconv_with_ff(16, x, ffs[4], 3, 2, "mydeconv3")
    
    # Resize to (360, 540), then upsample by 2 - presumably giving the
    # 720x1080 target size used by the dataset labels; verify.
    x = MyDeconv(16, 3, 1, name="mydeconv2", resize_to=(360, 540), config=config)(x)
    x = Resblock(16, 5, name=f"rb2", config=config)(x)
    x = MyDeconv(16, 3, 2, name="mydeconv1", resize_to=None, config=config)(x)
#     x = deconv_with_ff(16, x, large_inputs, 3, 2, "mydeconv1")
    
    x = conv(16, 5, 1, name="last_conv1", use_bias=True, config=config)(x)
    x = Resblock(9, 3, name=f"last_resconv", config=config)(x)
    # Sigmoid keeps values in [0, 1]; the Rescaling layer below maps to [0, 255].
    x = conv(3, 3, 1, name="output_conv", activation="sigmoid", use_bias=True, config=config)(x)

#     x = deconv(64, kernel_size=1, strides=1, name="prelast_deconv", norm=True)(x)
#     x = deconv(3, kernel_size=1, strides=1, name="last_deconv", norm=True)(x)
#     x = layers.Conv2DTranspose(3, kernel_size=3, strides=1, padding="same", 
#                       activation="sigmoid", kernel_initializer=tf.keras.initializers.HeNormal(seed=32))(x)
    
#     x = MyRescale()(x)
    x = tf.keras.layers.Rescaling(255.)(x)
#     x = tf.keras.layers.Lambda(lambda x: tf.cast(x, tf.uint8))(x)
    
    return tf.keras.Model(inputs, x)


## Defining the best range of learning rates

In [None]:
# best_lr_range = [1e-4, 5e-7]
# [start, end] learning rates read by lr_from_search below.
best_lr_range = [5e-6, 5e-7]

# First value of the range.
initial_learning_rate = best_lr_range[0]
# NOTE(review): neither initial_learning_rate nor search_epochs is used by the
# visible cells.
search_epochs = 100

def lr_from_search(epochs):
    """Return a ``fn(epoch, lr)`` scheduler that looks the rate up in a table
    of ``epochs`` values spaced linearly across ``best_lr_range``.

    The incoming ``lr`` argument is ignored.
    """
    first, last = best_lr_range[0], best_lr_range[1]
    table = tf.linspace(first, last, epochs, name=None, axis=0)

    def fn(epoch, lr):
        return table[epoch]

    return fn

In [None]:
def lr_from_range(epochs, lr_range):
    """Return a ``fn(epoch, lr)`` scheduler interpolating linearly over a range.

    Args:
        epochs: number of table entries (epochs 0 .. epochs-1).
        lr_range: ``(start, end)`` learning rates.

    Returns:
        A function that returns the table entry for ``epoch`` and ``end`` for
        any epoch at or beyond ``epochs``.  The incoming ``lr`` is ignored.
    """
    start,end = lr_range
    lrs = tf.linspace(start, end, epochs, name=None, axis=0)
    def fn(epoch, lr):
        # The table holds indices 0 .. epochs-1, so `epoch == epochs` must also
        # fall back to `end` instead of indexing one past the last entry.
        if epoch >= epochs:
            return end
        return lrs[epoch]
    return fn
    
# Quick check: the rate for epoch 4 of a 10-epoch schedule (the second
# argument, the current lr, is ignored by the scheduler).
lr_from_range(10, (5e-6, 5e-7))(4, 0.1)

In [None]:
class RangedExpDecay(tf.keras.optimizers.schedules.ExponentialDecay):
    """Exponential decay that never drops below a floor value ``end``.

    Args:
        end: lowest learning rate the schedule may return.
        *args, **kwargs: forwarded to ``ExponentialDecay`` (initial learning
            rate, decay steps, decay rate, ...).
    """

    def __init__(self, end, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.end = end

    def __call__(self, step):
        # `end` is the lower bound of a decaying schedule. Taking the minimum
        # capped the rate at `end` from the very first step, because the
        # decayed value starts above it; the floor needs the maximum.
        return tf.math.maximum(super().__call__(step), self.end)

    def get_config(self):
        """Include ``end`` so the schedule can be rebuilt from its config."""
        config = super().get_config()
        config["end"] = self.end
        return config

def schedule_exp_lr_range(start, end):
    """Return a ``RangedExpDecay`` starting at ``start`` and bounded by
    ``end``, decaying by a factor of 0.95 every 50 steps (no staircase)."""
    decay_options = dict(decay_steps=50, decay_rate=0.95, staircase=False)
    return RangedExpDecay(end, start, **decay_options)

### Build learning-rate schedules (exponential and piecewise) from a range of learning rates

In [None]:
def getExponentialOptimizer(start, end, epochs):
    """Return an ``ExponentialDecay`` schedule (not an optimizer) going from
    ``start`` towards ``end``.

    The per-step decay rate is ``(end / start) ** (1 / max(epochs - 1, 1))``
    with ``decay_steps=1``.
    """
    intervals = tf.math.maximum(epochs-1, 1)
    per_step_rate = tf.math.pow(end/start, 1/intervals)

    return tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=start, decay_steps=1, decay_rate=per_step_rate)

def getPieceWiseOptimizer(values, epochs):
    """Return a ``PiecewiseConstantDecay`` schedule (not an optimizer) with one
    rate per step in ``0 .. epochs-1``.

    One linear segment is built per entry of ``values``: the first runs from
    ``values[0]`` to ``values[0]``, each following one from the previous entry
    to the current one.  Every segment gets ``int(epochs / len(values))``
    points, except the last, which takes all remaining points.  The resulting
    table and boundaries are printed.
    """
    epochs_per_range = int(epochs / len(values))
    final_index = len(values) - 1

    segments = []
    used_epochs = 0
    segment_start = values[0]
    for index, segment_end in enumerate(values):
        if index == final_index:
            # The last segment absorbs whatever integer division left over.
            epochs_per_range = epochs - used_epochs
        segments.append(tf.linspace(segment_start, segment_end, epochs_per_range))
        used_epochs += epochs_per_range
        segment_start = segment_end

    new_values = segments[0]
    for segment in segments[1:]:
        new_values = tf.concat([new_values, segment], axis=0)

    boundaries = list(range(epochs-1))
    print(f"{tf.squeeze(new_values).shape} ls={new_values}")
    print(f"epochs_per_range = {epochs_per_range}, boundaries={boundaries} ")

    per_step_rates = list(new_values)

    return keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries, per_step_rates)

    

# ran = getDecayRateOptimizer(1e-8,1e-3, 10)
# Demo: build an 11-step piecewise schedule over [0.1, 0.2, 0.1] and print the
# rate it returns for each step.
ran = getPieceWiseOptimizer([0.1, 0.2, 0.1], 11)
for i in range(11):
    print(ran(i))

## Gradient accumulator

In [None]:
class CustomTrainStep(tf.keras.Model):
    """Wraps an autoencoder and applies its gradients only every ``n_gradients`` batches.

    Each ``train_step`` adds the batch gradients to per-variable accumulators.
    When ``n_gradients`` batches have been accumulated, the summed gradients are
    applied with the compiled optimizer and the accumulators are reset. The
    gradients are summed, not averaged, across the accumulated batches.

    Args:
        n_gradients: Number of batches to accumulate before one optimizer update.
        autoencoder: The wrapped Keras model that is actually trained.
        *args, **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, n_gradients, autoencoder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.autoencoder = autoencoder
        print(f"CustomTrainStep: n_gradients = {n_gradients}")
        self.n_gradients = tf.Variable(n_gradients, dtype=tf.int32, trainable=False)
        # Batches accumulated since the last optimizer update.
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
        # Number of optimizer updates performed so far (read by GradientUpdates).
        self.update_count = tf.Variable(0, dtype=tf.int32, trainable=False)
        # One accumulator per trained variable. Built from the same list that
        # train_step differentiates against, so the indices always line up.
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
            for v in self.autoencoder.trainable_variables
        ]

    def call(self, data):
        """Run the wrapped autoencoder in inference mode."""
        return self.autoencoder(data, training=False)

    def train_step(self, data):
        """Accumulate gradients for one batch and update weights when due."""
        self.n_acum_step.assign_add(1)

        x, y = data
        # Loss scaling is only available on LossScaleOptimizer; plain optimizers
        # are used without scaling instead of failing with AttributeError.
        uses_loss_scaling = hasattr(self.optimizer, "get_scaled_loss")

        with tf.GradientTape() as tape:
            y_pred = self.autoencoder(x, training=True)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
            scaled_loss = self.optimizer.get_scaled_loss(loss) if uses_loss_scaling else loss

        # Calculate batch gradients
        gradients = tape.gradient(scaled_loss, self.autoencoder.trainable_variables)
        if uses_loss_scaling:
            gradients = self.optimizer.get_unscaled_gradients(gradients)

        # Accumulate batch gradients; variables that did not contribute to the
        # loss yield None and are left untouched.
        for accumulator, gradient in zip(self.gradient_accumulation, gradients):
            if gradient is not None:
                accumulator.assign_add(gradient)

        # Once n_gradients batches have been accumulated, apply them; otherwise do nothing.
        tf.cond(tf.equal(self.n_acum_step, self.n_gradients), self.apply_accu_gradients, lambda: None)

        # update metrics
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def apply_accu_gradients(self):
        """Apply the accumulated gradients, count the update, and reset the accumulators."""
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.autoencoder.trainable_variables))
        self.update_count.assign_add(1)

        # reset
        self.n_acum_step.assign(0)
        for accumulator, variable in zip(self.gradient_accumulation,
                                         self.autoencoder.trainable_variables):
            accumulator.assign(tf.zeros_like(variable, dtype=tf.float32))

Best result yet:

- In 150/300 epochs, with ranges best_lr_range = [3e-5, 5e-7] and linspace of size 300.

## Trainer tool

In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import LambdaCallback, EarlyStopping
from time import time
import json

class LRMetric(tf.keras.metrics.Mean):
    """Mean metric that records the optimizer learning rate of ``model``.

    The ``y_true`` / ``y_pred`` arguments of ``update_state`` are ignored; every
    update feeds ``model.optimizer.lr`` into the running mean instead.
    """

    def __init__(self, model, name="lr_metric", **kwargs):
        super().__init__(name=name, **kwargs)
        self.model = model

    def update_state(self, y_true, y_pred, sample_weight=None):
        current_lr = self.model.optimizer.lr
        super().update_state(current_lr)

class GradientUpdates(tf.keras.metrics.Metric):
    """Metric that mirrors ``model.update_count``.

    The counter is stored in the ``true_positives`` weight (the attribute name
    is historical; it holds the update count, not a classification statistic).
    """

    def __init__(self, model, name='gradient_updates_metric', **kwargs):
        super().__init__(name=name, **kwargs)
        self.model = model
        self.true_positives = self.add_weight(
            name='gradient_updates',
            initializer='zeros',
            dtype=tf.int32,
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Inputs are ignored: simply copy the model's current update counter.
        latest_count = self.model.update_count
        self.true_positives.assign(latest_count)

    def result(self):
        return self.true_positives

class Trainer:
    """Fluent builder that configures and runs one autoencoder training session.

    Configure it with the chainable ``with*`` / ``setLoss`` methods (each returns
    ``self``) and finish with ``train``.

    Args:
        name: Run name; used for the checkpoint directory, the saved-model path
            and the TensorBoard log directory.
        output_size: Indexable whose first two entries are passed, followed by 3,
            as ``input_shape`` to ``build_perceptual_loss``.
    """

    def __init__(self, name, output_size):
        self.name = name
        self.output_size = output_size
        self.v = 1
        self.lossFn = build_perceptual_loss(input_shape=(self.output_size[0], self.output_size[1], 3))
        self.aggregate = False
        self.useLRRange = False
        self.callbacks = []
        self.early_stop = False
        self.n_gradients = None
        self.lr_range = None
        self.loss_fn_name = "perceptual_loss"
        self.opt = None
        self.lrs = None
        self.tb_util = None
        self.run_eagerly = False
        self.use_hp = False
        self.log_dir = None
        # Previously only created by withMSEPerceptualLoss(); defined here so the
        # attribute always exists.
        self.add_hp_callback = None

    def withCallbacks(self, callbacks=None):
        """Replace the callback list with a copy of ``callbacks`` (empty if omitted)."""
        # Copy: the trainer appends to this list later and must not mutate the
        # caller's list or a shared default argument.
        self.callbacks = list(callbacks) if callbacks is not None else []
        return self

    def withVersion(self, v):
        """Set the version suffix used in the saved-model path."""
        self.v = v
        return self

    def setLoss(self, loss_name, lossFn):
        """Use ``lossFn`` as the loss and record ``loss_name`` in the training config."""
        self.lossFn = lossFn
        self.loss_fn_name = loss_name
        return self

    def withMSEPerceptualLoss(self, ploss_downscale=10000):
        """Use ``mse * (perceptual_loss / ploss_downscale)`` as the loss."""
        perceptual_loss = build_perceptual_loss(input_shape=(self.output_size[0], self.output_size[1], 3))
        mse = tf.keras.losses.MeanSquaredError()

        def custom_loss(x, y):
            return tf.cast(mse(x, y), tf.float32) * (perceptual_loss(x, y) / tf.cast(ploss_downscale, tf.float32))

        self.lossFn = custom_loss
        self.loss_fn_name = "mse_perceptual_loss"
        self.add_hp_callback = None
        return self

    def withEagerRun(self, active=True):
        """Toggle ``run_eagerly`` for ``model.compile``."""
        self.run_eagerly = active
        return self

    def withLRRange(self, lr_range):
        """Train with a piecewise schedule through the waypoints in ``lr_range``."""
        self.lr_range = lr_range
        self.useLRRange = True
        return self

    def withFindBestLR(self, lr_range, epochs):
        """Sweep the learning rate exponentially from ``lr_range[0]`` to ``lr_range[1]``."""
        self.lrs = tf.keras.callbacks.LearningRateScheduler(
            getExponentialOptimizer(lr_range[0], lr_range[1], epochs), verbose=1)
        return self

    def withLRScheduler(self, lrs):
        """Use ``lrs`` as the learning-rate callback; takes precedence over ``withLRRange``."""
        self.lrs = lrs
        return self

    def withOptimizer(self, opt):
        """Use ``opt`` instead of the default loss-scaled Adam optimizer."""
        self.opt = opt
        return self

    def withGradientAggregation(self, n_gradients=3):
        """Accumulate gradients over ``n_gradients`` batches via ``CustomTrainStep``."""
        self.aggregate = True
        self.n_gradients = n_gradients
        return self

    def getCheckpointPath(self):
        """Directory in which training checkpoints are stored."""
        return f'./prepa_{self.name}_ckpt'

    def getSavedModelPath(self):
        """Path the trained model is saved to."""
        return f'./saved_models/{self.name}_{self.v}'

    def withTensorboard(self, logs_dir, profile_batch=0, use_hp=False):
        """Log to TensorBoard under ``./{logs_dir}/{name}``.

        Use profile batch like this: profile_batch="5,10"
        """
        log_dir = f"./{logs_dir}/{self.name}"
        self.log_dir = log_dir
        self.tb_util = TensorboardUtil(log_dir)
        self.use_hp = use_hp
        self.callbacks.append(self.tb_util.get_callback(profile_batch=profile_batch))
        return self

    def withEarlyStopping(self, restore_best_weights=True, patience=5):
        """Stop training when ``val_loss`` has not improved for ``patience`` epochs."""
        self.early_stop = True
        self.callbacks.append(EarlyStopping(
            monitor='val_loss',
            restore_best_weights=restore_best_weights,
            patience=patience,
            min_delta=0.))
        return self

    def getTrainingConfig(self):
        """Return the current configuration as a dict."""
        return {
            "name": self.name,
            "version": self.v,
            "output_size": self.output_size,
            "aggregate_gradients": self.aggregate,
            "n_gradients": self.n_gradients,
            "lr_range": self.lr_range,
            "loss_fn_name": self.loss_fn_name,
            "model_path": self.getSavedModelPath(),
            "early_stop": self.early_stop
        }

    def validation_results(self, model, dataset):
        """Evaluate ``model`` on ``dataset`` and return its accuracy, loss and mse.

        The metrics are looked up by name, so this works whether or not the
        gradient-update metric is part of the compiled metric list.
        """
        scores = model.evaluate(dataset, return_dict=True)
        return {
            "accuracy": scores["accuracy"],
            "loss": scores["loss"],
            "mse": scores["mse"]
        }

    def getTrainingConfigPreTrain(self, epochs):
        """Training config extended with the current timestamp and the epoch count."""
        conf = self.getTrainingConfig()
        conf["time"] = time()
        conf["epochs"] = epochs
        return conf

    def saveTrainingConfig(self, epochs):
        """Append the pre-train config as one JSON line to ``training.log``."""
        with open("training.log", "a") as log_file:  # append mode
            log_file.write(json.dumps(self.getTrainingConfigPreTrain(epochs)))
            log_file.write("\n")

    def train(self, autoencoder, train_ds, test_ds, epochs, initial_lr, load_checkpoint=True, config=None):
        """Compile, fit, evaluate and save ``autoencoder``.

        Args:
            autoencoder: Keras model to train.
            train_ds: Training dataset passed to ``fit``.
            test_ds: Dataset used for validation and for the final evaluation.
            epochs: Number of epochs.
            initial_lr: Initial learning rate; must equal ``lr_range[0]`` when a
                learning-rate range is configured.
            load_checkpoint: Restore the latest checkpoint if one exists.
            config: Optional object exposing ``toString()`` and ``optimizer``;
                may be ``None``.

        Returns:
            Tuple ``(model_path, results, model)`` where ``results`` is the dict
            produced by ``validation_results``.
        """
        metrics = []
        # Work on a copy so repeated train() calls do not pile up duplicate
        # scheduler / lambda callbacks on the trainer.
        callbacks = list(self.callbacks)

        if self.tb_util and config is not None:
            self.tb_util.log_text(config.toString(), "training_config", step=0)

        if self.aggregate:
            model = CustomTrainStep(n_gradients=self.n_gradients, autoencoder=autoencoder)
            metrics.append(GradientUpdates(model, name="gradient_upds"))
        else:
            model = autoencoder

        metrics.append(LRMetric(model, name="lr_metric"))

        self.saveTrainingConfig(epochs)

        if self.lrs is not None:
            lr_schedule = initial_lr
            callbacks.append(self.lrs)
        elif self.useLRRange:
            assert initial_lr == self.lr_range[0]
            lr_schedule = self.lr_range[0]
            callbacks.append(tf.keras.callbacks.LearningRateScheduler(
                getPieceWiseOptimizer(self.lr_range, epochs), verbose=1))
        else:
            lr_schedule = initial_lr

        # Default to loss-scaled Adam when no optimizer was supplied, or when the
        # config explicitly requests "adam".
        wants_adam = config is not None and config.optimizer == "adam"
        if self.opt is None or wants_adam:
            self.opt = mixed_precision.LossScaleOptimizer(
                tf.keras.optimizers.Adam(learning_rate=lr_schedule))

        checkpoint_dir = self.getCheckpointPath()
        checkpoint = tf.train.Checkpoint(model=autoencoder, optimizer=self.opt)
        ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=5)

        if load_checkpoint and ckpt_manager.latest_checkpoint:
            checkpoint.restore(ckpt_manager.latest_checkpoint).expect_partial()
        else:
            print("No checkpoints found or loaded")

        images_renderer = ImageRenderer(4).withTensorboard(self.tb_util)

        def print_images(epoch, logs):
            images_renderer.render(
                model=lambda x: model(x, training=False),
                datasets=(train_ds, test_ds),
                batch=epoch)

        def initial_print(logs):
            print_images(0, logs)

        def on_epoch_end(epoch, logs):
            # Clear the notebook output every 5th epoch, then render sample
            # images and write a checkpoint.
            if epoch % 5 == 0:
                display.clear_output()
            print_images(epoch, logs)
            ckpt_manager.save()

        callbacks.append(LambdaCallback(on_train_begin=initial_print, on_epoch_end=on_epoch_end))
        metrics.append("mse")
        metrics.append("accuracy")

        model.compile(optimizer=self.opt, loss=self.lossFn, metrics=metrics, run_eagerly=self.run_eagerly)
        model.fit(train_ds.prefetch(tf.data.AUTOTUNE),
                  epochs=epochs,
                  validation_data=test_ds.prefetch(tf.data.AUTOTUNE),
                  callbacks=callbacks)

        results = self.validation_results(model, test_ds)

        model_path = self.getSavedModelPath()
        model.save(model_path)
        print(f'model saved to: {model_path}')
        return model_path, results, model
    
# t = Trainer("test_model", (720, 1080))\
#     .withMSEPerceptualLoss()\
#     .withLRRange((6e-5, 5e-8))\
#     .withGradientAggregation()\
#     .withTensorboard("prepa_autencoder_logs")\
#     .withEarlyStopping()

# # t.saveTrainingConfig(100)
# t.getTrainingConfig()


## Training

In [None]:
# epochs = 20 
# weights_path, ghistory, model = Trainer("d_and_s_srch_lr", (720, 1080))\
#     .withFindBestLR((5e-8, 2e-3), epochs)\
#     .withGradientAggregation(4)\
#     .withTensorboard("prepa_autencoder_logs", profile_batch=0)\
#     .train(
#         build_autoencoder_large_v3(),
#         div2k_ds("train").batch(4),
#         div2k_ds("validation").batch(4),
#         epochs,
#         1e-8,
#         load_checkpoint=False)

# 1e-4 : 7e-6

In [None]:
# Settings shared by runTest() and the hyper-parameter search in this cell.
epochs = 100                       # epochs passed to Trainer.train by runTest
log_dir = "prepa_autencoder_logs"  # log directory given to the Trainer and the HP search
name = "aenc"                      # name given to TensorboardHPSearch

hpSearch = TensorboardHPSearch(log_dir, name)

def runTest(config, hp_callback, tname):
    """Train one hyper-parameter trial and log its validation results.

    Builds a ``Trainer`` named ``tname`` with the learning-rate waypoints from
    ``config.learning_rate``, gradient aggregation over 3 batches and TensorBoard
    logging, trains ``build_autoencoder_large_v3(config)`` on DIV2K batches of 3,
    reports loss / accuracy / mse through ``hpSearch`` and clears the Keras session.
    """
    lr_waypoints = config.learning_rate

    trainer = (
        Trainer(tname, (720, 1080))
        .withCallbacks([hp_callback])
        .withLRRange(lr_waypoints)
        .withGradientAggregation(3)
        .withTensorboard(log_dir)
    )
    # Only the evaluation results are used; the saved path and the trained
    # model are discarded.
    _saved_path, results, _trained = trainer.train(
        build_autoencoder_large_v3(config),
        div2k_ds("train").batch(3),
        div2k_ds("validation").batch(3),
        epochs,
        lr_waypoints[0],
        load_checkpoint=False,
        config=config,
    )

    hpSearch.log_results(tname, results["loss"], results["accuracy"], results["mse"])
    tf.keras.backend.clear_session()

# Launch the hyper-parameter search, using runTest as the per-trial callback.
# Presumably each `learning_rate` entry reaches runTest as `config.learning_rate`
# (a list of LR waypoints) — TODO confirm against TensorboardHPSearch.hp_search.
# NOTE(review): this constructs a second TensorboardHPSearch(log_dir, name) while
# runTest logs its results through the module-level `hpSearch` instance; confirm
# whether `hpSearch.hp_search(...)` was intended here.
TensorboardHPSearch(log_dir, name).hp_search(
    dropout_rate=[0.1],
    ff_dropout_rate=[0.1],
    regularizers=[1e-2],
    learning_rate=[[1e-4, 5e-5], [1e-4,1e-4, 5e-5, 5e-5,1e-5]],
    optimizer=["adam"],
    run=runTest)
    
#     model = build_autoencoder_large_v3(config)
    
#     loss = 0
#     count = 0
#     tf.keras.utils.disable_interactive_logging()
#     tf.keras.utils.enable_interactive_logging()

#     validation_ds = div2k_ds("validation").take(3).batch(3)
    
#     model.compile(optimizer="adam", loss=build_perceptual_loss(input_shape=(720, 1080, 3)), metrics=["mse", "accuracy"])

#     results = model.evaluate(validation_ds)
    

#     print(results)
    
# #     trainer.tb_util.log_scalar(loss, "accuracy", step=0)
    
    

In [None]:
# epochs = 150 
# weights_path, ghistory, model = Trainer("d_and_s", (720, 1080))\
#     .withVersion(1.1)\
#     .withLRRange((1e-5, 4e-6))\
#     .withGradientAggregation(4)\
#     .withTensorboard("prepa_autencoder_logs")\
#     .train(
#         build_autoencoder_large_v3(),
#         div2k_ds("train").batch(4),
#         div2k_ds("validation").batch(4),
#         epochs,
#         1e-5,
#         load_checkpoint=True)

In [None]:
# Perceptual loss used by the comparisons at the end of this cell.
# NOTE(review): the Trainer calls build_perceptual_loss(input_shape=(h, w, 3));
# here an empty string is passed positionally — confirm this is a valid argument.
ploss = build_perceptual_loss("")

# (path, image type) pairs for every PNG found in ./prepa.
prepa = [(str(p),'png') for p in pathlib.Path("./prepa").glob('*.png')]

def gen_img(img_path, img_type):
    """Read the PNG at ``img_path`` and decode it into a 3-channel image tensor.

    Args:
        img_path: Path of the PNG file.
        img_type: Unused; kept so existing callers that pass it keep working.

    Returns:
        The tensor produced by ``tf.image.decode_png(..., channels=3)``.
    """
    # The file path is no longer forwarded as the op `name`: op names must match
    # TensorFlow's name grammar, which arbitrary file paths do not satisfy when
    # the function is traced outside eager mode.
    raw_png = tf.io.read_file(str(img_path))
    return tf.image.decode_png(raw_png, channels=3)

def gen_prepa():
    """Lazily yield the decoded image for every entry of the module-level ``prepa`` list."""
    for entry in prepa:
        path, kind = entry
        yield gen_img(path, kind)

# Visual comparison of the first image against a few transformed variants.
it = iter(gen_prepa())

# Takes the first two images; raises StopIteration if ./prepa holds fewer than two PNGs.
img1 = next(it)
img2 = next(it)
# Colour-inverted copy of the first image.
inverted = (255-img1)
plt.figure()
plt.imshow(img1)
plt.figure()
plt.imshow(inverted)

# Random contrast factor drawn from [0.2, 0.5); no seed is set, so the result
# differs between runs.
contrast = tf.image.random_contrast(img1, 0.2, 0.5)
plt.figure()
plt.imshow(tf.cast(contrast, tf.uint32))


# Random saturation factor drawn from [5, 10); also unseeded.
sat = tf.image.random_saturation(img1, 5, 10)
plt.figure()
plt.imshow(tf.cast(sat, tf.uint32))



def losss(x, y):
    """Evaluate ``ploss`` on two inputs after adding a leading axis to each."""
    batched_x = tf.expand_dims(x, axis=0)
    batched_y = tf.expand_dims(y, axis=0)
    return ploss(batched_x, batched_y)

# Print the perceptual loss of img1 against itself, against a second image, and
# against its inverted / contrast-adjusted / saturation-adjusted variants.
print(f"ploss self: {losss(img1, img1)}")
print(f"ploss img2: {losss(img1, img2)}")
print(f"ploss inverted: {losss(img1, inverted)}")
print(f"ploss contrast: {losss(img1, contrast)}")
print(f"ploss saturation: {losss(img1, sat)}")

In [None]:
def restore_grad_acc_model(model_name, model_fn):
    """Build a model with ``model_fn`` and load weights from ``./saved_models/{model_name}``.

    Args:
        model_name: Name of the entry under ``./saved_models``.
        model_fn: Zero-argument callable returning an object with ``load_weights``.

    Returns:
        The freshly built model after ``load_weights`` has been called on it.
    """
    restored = model_fn()
    weights_location = f'./saved_models/{model_name}'
    restored.load_weights(weights_location)
    return restored

def test_prepa(autoencoder):
    prepa = [(str(p),'png') for p in pathlib.Path("./prepa").glob('*.png')]

    def gen_img(img_path, img_type):
        raw_png = tf.io.read_file(str(img_path), name=img_path)
        return tf.image.decode_png(raw_png, channels=3, name=img_path)

    def gen_prepa():
        for img_path, img_type in prepa:
            yield gen_img(img_path, img_type)

    it = iter(gen_prepa())    
    for i in range(10):


        img = next(it)
        plt.figure()
        plt.imshow(img)
        plt.figure()
        y_pred = autoencoder.predict(tf.expand_dims(img, axis=0))
        y_pred = tf.cast(y_pred, tf.uint8)
        plt.imshow(y_pred[0])


    # tf.math.reduce_mean(
    #     y_pred, axis=[1,2], keepdims=False, name=None
    # )

    # y = tf.reshape(y_pred, (1,-1,3))
    # y = layers.AveragePooling1D(2)(tf.cast(y, tf.float32))
    # y = layers.AveragePooling1D(2)(tf.cast(y, tf.float32))
    # y = layers.AveragePooling1D(2)(tf.cast(y, tf.float32))
    # # y = tf.image.resize(y, (720, 1080))
    # y
# NOTE(review): `model` is not assigned at module level by any uncommented code in
# the cells visible here (runTest keeps it local) — confirm it is defined before
# this cell runs on a fresh kernel.
test_prepa(model)


In [None]:
# m = restore_grad_acc_model("denoise_and_scale_100e_large_v3_0002", build_autoencoder_large_v3)

In [None]:
# NOTE(review): relies on a module-level `model`; confirm which cell defines it.
model.summary()