In [5]:
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

import numpy as np
import matplotlib.pyplot as plt
%load_ext tensorboard
 

In [8]:
from model import BaseGenerator

In [7]:
@tf.function
def RUN(trainer, var):
    '''Run one optimisation step of the generator on a single random latent vector.

    Input :
        trainer : object exposing `G` (a generator with a `forward_model` method) and
                  `G_opt` (an optimizer with `apply_gradients`).
        var : list of variables to differentiate the loss with respect to and to update.

    The loss is simply the sum of the generator output.
    '''
    # Sample with a TensorFlow op. A NumPy call here would be executed only once, while
    # tf.function traces this Python body, and its value would then be frozen into the
    # graph as a constant: every later call would reuse the exact same "random" input.
    # The latent width (200) is hard-coded and must match the generator's input size.
    latent = tf.random.uniform((1, 200))
    with tf.GradientTape() as tape:
        loss = tf.reduce_sum(trainer.G.forward_model(latent))
    grads = tape.gradient(loss, var)
    trainer.G_opt.apply_gradients(zip(grads, var))

In [None]:

class Generator_v0(BaseGenerator):
    '''Growable image generator assembled from a list of Keras sub-models.

    `self.blocks` holds one `keras.Model` per stage: the base block created by
    `_initialize_base`, followed by every block appended through `_extend_model`.
    `forward_model` chains the blocks in order.

    NOTE(review): `configbuild`, `initializer`, `initialized`, `extendable`, `img_size`,
    `targ_img_size`, `z_dim` and `strategy_scope` are read on `self` but never assigned in
    this class; presumably `BaseGenerator` provides them -- confirm. The names
    `NameCaller`, `latent_mapping`, `get_deconv`, `get_upsampling` and `get_noise_out`
    must also be available in the enclosing namespace.
    '''

    def __init__(self, img_size=(32, 32), targ_img_size=(128, 128), z_dim=128, seed=1010, strategy_scope=None, name="GEN01",
                up_type="deconv", apply_resize=False):
        '''
        Input :
            img_size : (tuple) target image size at starting training (before first extent model).
            targ_img_size : (tuple) target image size at the end of training.
            z_dim, seed, strategy_scope, name : forwarded unchanged to BaseGenerator.__init__.
            up_type : (str) "deconv" or "upsample", selects the helper used to build up blocks.
            apply_resize : (bool) whether to resize the img_size to the targ_img_size if is not equal in the forward call.
        '''

        super(Generator_v0, self).__init__(img_size, targ_img_size, z_dim, seed, strategy_scope, name)
        assert up_type in ["deconv", "upsample"], "up_type must be one of deconv or upsample"
        assert img_size[0] == img_size[1], "img_size must be square"
        assert targ_img_size[0] == targ_img_size[1], "targ_img_size must be square"
        self.up_type = up_type
        self.cur_img_size = img_size
        self.apply_resize = apply_resize
        self.blocks = []  # ordered list of keras.Model stages, base block first
        self.suff = NameCaller()  # source of the name suffixes read through `.n`
        self.optimizer = None
        self.cur_trainable = None # use just for creation of Variable that hold outside of tf.function

    def _at_target_size(self):
        '''Return True when the current image size equals the target image size.

        Both sides are converted to tuples so that a size given as a list still
        compares equal to the tuple produced by `_extend_model`.
        '''
        return tuple(self.cur_img_size) == tuple(self.targ_img_size)

    def _initialize_base(self):
        '''Build the base block (latent mapping followed by up blocks) and append it to self.blocks.

        The latent vector is mapped by `latent_mapping`, reshaped to a 1x1 feature map, passed
        through one up block per entry of configbuild["BaseFilters"] (each block doubles the
        spatial size) and finally projected to 3 channels by a Conv2D.
        '''
        inputer = keras.layers.Input(shape=(self.z_dim, ), name="BaseInput"+self.suff.n)
        x1 = latent_mapping(inputer, num_layers=self.configbuild["num_layers"], units=self.configbuild["dense_units"],
                out_units=self.configbuild["out_units"], activation=self.configbuild["dense_act"], initializer=self.initializer, max_norm=None,
                batch_norm=True, namemap="map_z", name_suffix="")

        x1 = keras.layers.Reshape(target_shape=(1, 1, self.configbuild["out_units"]), name='reshape_mapper'+self.suff.n)(x1)

        if self.up_type == "deconv":
            for i_layer in self.configbuild["BaseFilters"]:
                x1 = get_deconv(x1, i_layer, self.configbuild["kernel_size"], strides=(2, 2), padding="same", activation=self.configbuild["conv_act"],
                        initializer=self.initializer, add_noise=None, max_norm=None, batch_norm=True, name_suffix=self.suff.n)

        elif self.up_type == "upsample":
            for i_layer in self.configbuild["BaseFilters"]:
                x1 = get_upsampling(x1, i_layer, self.configbuild["kernel_size"], strides=(1, 1), padding="same", up_size=(2, 2),
                        activation=self.configbuild["conv_act"], interpolation="nearest", initializer=self.initializer, add_noise=None,
                        batch_norm=True, name_suffix=self.suff.n)

        outputer = keras.layers.Conv2D(3, self.configbuild["kernel_size"], strides=(1, 1), padding="same", activation=self.configbuild["out_act"], use_bias=True,
                        name="BaseOutConv")(x1)
        # Starting from 1x1 and doubling once per block, log2(img_size) blocks are needed.
        assert tf.reduce_prod(x1.shape[1:-1]) == tf.reduce_prod(self.img_size), "mapping blocks do not correctly output shape, "\
        "must provide filters with length of {}".format(int(np.log2(self.img_size[0])))
        self.blocks.append(keras.Model(inputer, outputer, name="BaseGeneratorModel"))

    def _extend_model(self, filters=400, up=(1, 1), noise=False):
        '''Append one convolutional block, built with either deconvolution or upsampling.

        Input :
            filters : number of filters given to the up block helper.
            up : (tuple) spatial scale factor; used as strides for "deconv", as up_size for "upsample".
            noise : forwarded to the helper as `add_noise`.

        When `up[0] > 1` a 3-filter 5x5 Conv2D is added after the up block and the resulting
        spatial size is checked against the new image size.
        Raises ValueError when the generator has not been initialized.
        '''
        if self._at_target_size():
            assert (tuple(up) == (1, 1) and filters==3), "extent model cannot be called when image size is equal to target image size"

        if not self.initialized:
            raise ValueError("Generator model not initialized")
        inputer = keras.layers.Input(shape=self.blocks[-1].output_shape[1:], name="ExtendedInput"+self.suff.n)
        # Scale each axis by its own factor. The new size is committed to self.cur_img_size
        # only once the block has been built and validated, so a failure leaves state intact.
        new_img_size = (self.cur_img_size[0] * up[0], self.cur_img_size[1] * up[1])
        last_shape = self.blocks[-1].layers[-1].output_shape

        if self.up_type == "deconv":
            outputer = get_deconv(inputer, filters, self.configbuild["kernel_size"], strides=up, padding="same",
                activation=self.configbuild["conv_act"], initializer=self.initializer, add_noise=noise, max_norm=None, batch_norm=True, name_suffix=self.suff.n)

        elif self.up_type =="upsample":
            outputer = get_upsampling(inputer, filters, self.configbuild["kernel_size"], strides=(1, 1), padding="same",
                up_size=up, activation=self.configbuild["conv_act"], interpolation="nearest", initializer=self.initializer,
                add_noise=noise, batch_norm=True, name_suffix=self.suff.n)

        if up[0] > 1:
            outputer = keras.layers.Conv2D(3, (5, 5), strides=(1, 1), padding="same", activation=None, use_bias=True, name="Conv"+self.suff.n)(outputer)
            assert tf.reduce_prod(outputer.shape[1:-1]) == tf.reduce_prod(new_img_size), "extended model do not correctly output shape, "\
            "must provide filters with length of {}".format(int(np.log2(new_img_size[0])))
        self.cur_img_size = new_img_size
        self.blocks.append(keras.Model(inputer, outputer))
        print("extended model from size of", last_shape, "to", outputer.shape)

    def _auto_extend(self):
        '''extend model automatically with arbitrary number of blocks provided in configbuild'''
        # Fetch the filter group once: asking for next(iter(...)) a second time would either
        # re-read the same group (list) or silently consume a different one (iterator).
        filter_group = next(iter(self.configbuild["filters"]))
        for ifilt in filter_group[:-1]:
            self._extend_model(ifilt, up=(1, 1))
        # The last entry of the group builds the block that doubles the resolution.
        self._extend_model(filter_group[-1], up=(2, 2), noise=get_noise_out)

    def set_mapping_trainable(self, trainable=True, prefix='map_z'):
        '''Set `trainable` on every layer whose name contains `prefix`.'''
        for ly in self.get_flat_layers():
            if prefix in ly.name:
                ly.trainable = trainable
                print("set trainable_variables of layers {} to {}".format(ly.name, trainable))

    def set_joint_trainable(self, trainable=True, not_prefix='map_z'):
        '''Set `trainable` on every weight-owning layer whose name does not contain `not_prefix`.'''
        for ly in self.get_flat_layers():
            # Test `weights`, not `trainable_variables`: the latter is empty for a frozen
            # layer, so a layer frozen by this method could never be unfrozen again.
            if not_prefix not in ly.name and ly.weights:
                ly.trainable = trainable
                print("set trainable_variables of layers {} to {}".format(ly.name, trainable))

    def forward_model(self, inputs, training=True): ############ use get_model instead of calling function sequentially, this may result in async weights update
        '''Run `inputs` through every block in order and return the result as float32.

        Input :
            inputs : tensor fed to the first block.
            training : forwarded to each block call.

        When `apply_resize` is set and the current size differs from the target size, the
        output is resized to `targ_img_size` with nearest-neighbour interpolation.
        '''
        for bk in self.blocks:
            # Forward the flag; previously it was accepted but never reached the blocks.
            inputs = bk(inputs, training=training)
        if self.apply_resize and not self._at_target_size():
            inputs = tf.image.resize(inputs, self.targ_img_size, method="nearest")
        return tf.cast(inputs, dtype=tf.float32)

    def get_model(self, get_with_functional=False, with_scope=False):
        '''Chain the flattened layers of all blocks into a single keras.Model.

        Input :
            get_with_functional : must be False (asserted); the block-level path is disabled.
            with_scope : when True the model is created inside `self.strategy_scope.scope()`.

        The layer at index 0 is skipped (presumably the base block's own input layer) and a
        fresh input layer is used instead.
        '''
        assert not get_with_functional
        assert self.extendable, "This Generator is not extendable"
        U_layers = self.blocks if get_with_functional else self.get_flat_layers()
        # (self.z_dim,) is a real 1-tuple, consistent with the input built in _initialize_base.
        inputer = keras.layers.Input(shape=(self.z_dim, ), name='BaseInput000')
        xi = U_layers[1](inputer)
        for J_layers in U_layers[2:]:
            xi = J_layers(xi)
        if with_scope: # this actually returns Model with scope since it is sync for all tf.Variable, Just in case
            with self.strategy_scope.scope():
                return keras.Model(inputer, xi)
        else:
            return keras.Model(inputer, xi)

    def get_flat_layers(self):
        '''Return the layers of every block, concatenated in block order.'''
        return [ms for msl in self.blocks for ms in msl.layers]

    def print_config_layers(self):
        '''Print name, input/output shape, trainable flag and name scope of every flat layer.'''
        # NOTE(review): `name_scope()` is called as a method here; confirm that it is callable
        # on the Keras version in use.
        for i_layer in self.get_flat_layers():
            print("Name:{} - in_shape:{} - out_shape:{} - trainable:{} - scope{}".format(
            i_layer.name, i_layer.input_shape, i_layer.output_shape, i_layer.trainable, i_layer.name_scope()))

    def get_trainable(self):
        '''Rebuild, cache in `self.cur_trainable` and return the trainable variables of all flat layers.'''
        # self.cur_trainable = self.get_model().trainable_variables
        self.cur_trainable = [s for iu in self.get_flat_layers() for s in iu.trainable_variables]
        return self.cur_trainable

    def update_params(self, grads):
        '''Apply `grads` to the cached trainable variables with `self.optimizer`.

        NOTE(review): the cache is only filled when it is None; after the model is extended
        or layers are (un)frozen, `get_trainable()` has to be called again to refresh it.
        '''
        assert self.optimizer is not None, "optimizer is not provided"
        if self.cur_trainable is None: # this must not run in the context of tf.function since it create new variables
            self.get_trainable()
        self.optimizer.apply_gradients(zip(grads, self.cur_trainable))

    def save_model(self, path):
        '''Save the model returned by `get_model()` to `path`.'''
        self.get_model().save(path)