In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers 
import tensorflow_addons as tfa
import numpy as np

 The versions of TensorFlow you are currently using is 2.4.1 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


In [2]:
nfnet_params = {'F0': {'width': [256, 512, 1536, 1536], 'depth': [1, 2, 6, 3],'drop_rate': 0.2},
                'F1': {'width': [256, 512, 1536, 1536], 'depth': [2, 4, 12, 6],'drop_rate': 0.3},
                'F2': {'width': [256, 512, 1536, 1536], 'depth': [3, 6, 18, 9],'drop_rate': 0.4},
                'F3': {'width': [256, 512, 1536, 1536], 'depth': [4, 8, 24, 12],'drop_rate': 0.4},
                'F4': {'width': [256, 512, 1536, 1536], 'depth': [5, 10, 30, 15],'drop_rate': 0.5},
                'F5': {'width': [256, 512, 1536, 1536], 'depth': [6, 12, 36, 18],'drop_rate': 0.5},
                'F6': {'width': [256, 512, 1536, 1536], 'depth': [7, 14, 42, 21],'drop_rate': 0.5},
                'F7': {'width': [256, 512, 1536, 1536], 'depth': [8, 16, 48, 24], 'drop_rate': 0.5}}

nonlinearities = {"identity": lambda x: x,
                  "celu": lambda x: tf.nn.crelu(x) * 1.270926833152771,
                  "elu": lambda x: tf.nn.elu(x) * 12716004848480225,
                  "gelu": lambda x: tf.nn.gelu(x) * 1.7015043497085571,
                  "leaky_relu": lambda x: tf.nn.leaky_relu(x) * 1.70590341091156,
                  "log_sigmoid": lambda x: tf.math.log(tf.nn.sigmoid(x)) * 1.9193484783172607,
                  "log_softmax": lambda x: tf.math.log(tf.nn.softmax(x)) * 1.0002083778381348,
                  "relu": lambda x: tf.nn.relu(x) * 1.7139588594436646,
                  "relu6": lambda x: tf.nn.relu6(x) * 1.7131484746932983,
                  "selu": lambda x: tf.nn.selu(x) * 1.0008515119552612,
                  "sigmoid": lambda x: tf.nn.sigmoid(x) * 4.803835391998291,
                  "silu": lambda x: tf.nn.silu(x) * 1.7881293296813965,
                  "soft_sign": lambda x: tf.nn.softsign(x) * 2.338853120803833,
                  "softplus": lambda x: tf.nn.softplus(x) * 1.9203323125839233,
                  "tanh": lambda x: tf.nn.tanh(x) * 1.5939117670059204,}

In [3]:
class WSConv2D(layers.Conv2D):
    def __init__(self, *args, **kwargs):
        super(WSConv2D, self).__init__(kernel_initializer=keras.initializers.VarianceScaling(scale=1.0, mode='fan_in', distribution='untruncated_normal',), *args, **kwargs)
        self.gain = self.add_weight(name='gain',shape=(self.filters,),initializer="ones",trainable=True,dtype=self.dtype)

    def standardize_weight(self, eps):
        mean = tf.math.reduce_mean(self.kernel, axis=(0, 1, 2), keepdims=True)
        var = tf.math.reduce_variance(self.kernel, axis=(0, 1, 2), keepdims=True)
        fan_in = tf.cast(tf.math.reduce_prod(self.kernel.shape[:-1]),'float32')
        scale = tf.math.rsqrt(tf.math.maximum(var * fan_in,eps)) * self.gain
        shift = mean * scale
        return self.kernel * scale - shift

    def call(self, inputs, eps=1e-4):
        weight = self.standardize_weight(eps)
        return tf.nn.conv2d(inputs, weight, strides=self.strides, padding=self.padding.upper(), dilations=self.dilation_rate) + self.bias
    
class SE_Block(keras.Model):
    def __init__(self,out_ch,se_ratio=0.5, act='relu'):
        super(SE_Block, self).__init__()
        self.GAP = layers.GlobalAvgPool2D()
        self.main = keras.Sequential([layers.Conv2D(int(out_ch*se_ratio),1,activation = act),layers.Conv2D(out_ch,1)])

    def call(self,x):
        x = self.GAP(x)[:,tf.newaxis,tf.newaxis]
        x = self.main((x))
        return tf.nn.sigmoid(x) * x

class StochDepth(keras.Model):
    def __init__(self, drop_rate):
        super(StochDepth, self).__init__()
        self.drop_rate = drop_rate

    def call(self, x, training):
        if not training:
            return x
        r = tf.random.uniform(shape=[x.shape[0], 1, 1, 1], dtype=x.dtype)
        keep_prob = 1. - self.drop_rate
        binary_tensor = tf.floor(keep_prob + r)
        return x * binary_tensor

In [4]:
class NFBlock(keras.Model):
    def __init__(self,in_ch,out_ch,expansion=0.5,se_ratio=0.5,kernel_shape=3,group_size=128,stride=1,
                 beta=1.0,alpha=0.2,activation=None,use_two_convs=True,stochdepth_rate=None):
        super(NFBlock, self).__init__()
        self.beta, self.alpha = beta, alpha
        self.activation = activation
        self.avgpool = layers.AveragePooling2D(int(stride)) if stride > 1 else None
        self.conv0 = WSConv2D(group_size * int((out_ch) * expansion) // group_size,1,1,padding="same")
        self.conv1 = WSConv2D(group_size * int((out_ch) * expansion) // group_size,kernel_shape,int(stride),padding="same",groups=int((out_ch) * expansion) // group_size)
        self.conv1b = WSConv2D(group_size * int((out_ch) * expansion) // group_size,kernel_shape,1,padding="same",groups=int((out_ch) * expansion) // group_size) if use_two_convs else None
        self.conv2 = WSConv2D(out_ch,1,1,padding="same")
        self.short_conv = WSConv2D(out_ch,1,padding="same") if stride>1 or in_ch != out_ch else None
        self.se = SE_Block(out_ch,se_ratio)
        self.stoch_depth = StochDepth(stochdepth_rate)  if (stochdepth_rate is not None) else None
        self.skip_gain = self.add_weight(name="skip_gain",shape=(),initializer="zeros",trainable=True,dtype=self.dtype)
            
    def call(self, x, training):
        out = layers.Lambda(self.activation)(x * self.beta)
        x = self.avgpool(out) if self.avgpool else x
        shortcut = self.short_conv(x) if self.short_conv  else x
        out = self.conv0(out)
        out = self.conv1(layers.Lambda(self.activation)(out))
        out = self.conv1b(layers.Lambda(self.activation)(out)) if self.conv1b else out
        out = self.conv2(layers.Lambda(self.activation)(out))
        out = self.se(out) * 2.0
        out = self.stoch_depth(out, training) if self.stoch_depth else out
        out = out * self.skip_gain
        return out * self.alpha + shortcut

In [5]:
def make_NFNet(num_classes=1000,variant="F0",group=128,expand_width = 1.0,se_ratio=0.5,alpha=0.2,depth_rate=0.1,act="relu",last_expand=2):
    activation = nonlinearities[act]
    Input_x = layers.Input((224,224,3))
    x = WSConv2D(16,3,2,padding="same")(Input_x)
    x = layers.Lambda(activation)(x)
    x = WSConv2D(32,3,1,padding="same")(x)
    x = layers.Lambda(activation)(x)
    x = WSConv2D(64,3,1,padding="same")(x)
    x = layers.Lambda(activation)(x)
    x = WSConv2D(nfnet_params[variant]['width'][0]//2,3,2,padding="same")(x)
    index = 0  
    for (block_width,stage_depth,stride) in zip(nfnet_params[variant]['width'],nfnet_params[variant]['depth'],[1]+[2]*len(nfnet_params[variant]['depth'])):
        expected_std = 1.0
        for block_index in range(stage_depth):
            x = NFBlock(x.shape[3],(block_width*expand_width),0.5,se_ratio,3,128,stride if block_index == 0 else 1,
                        1.0/expected_std,alpha,activation,True,depth_rate*index/sum(nfnet_params[variant]['depth']))(x)
            expected_std = (expected_std **2 + alpha**2)**0.5
            index += 1
    x = WSConv2D(int(last_expand * x.shape[3]),1, padding="same")(x)
    x = layers.Lambda(activation)(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(nfnet_params[variant]['drop_rate'])(x)
    x = layers.Dense(num_classes)(x)
    return keras.Model(Input_x,x)

In [6]:
@tf.function
def clip_gradient(grad, weight, clipping=0.01, eps=1e-3):
    param_norm = tf.math.maximum(tf.math.reduce_sum(weight**2, axis=[0,1,2], keepdims=True) ** 0.5, eps)
    grad_norm = tf.math.reduce_sum(grad**2, axis=[0,1,2], keepdims=True)
    max_norm = param_norm * clipping
    trigger = grad_norm < max_norm
    clipped_grad = grad * (max_norm / tf.maximum(grad_norm, 1e-6))
    return tf.where(trigger, grad, clipped_grad)


@tf.function
def unitwise_norm(x):
    if len(x.shape) <= 1:
        axis = None
        keepdims = False
    elif len(x.shape) in [2, 3]:
        axis = 0
        keepdims = True
    elif len(x.shape) == 4:
        axis = [0, 1, 2]
        keepdims = True
    return tf.math.reduce_sum(x ** 2, axis=axis, keepdims=keepdims) ** 0.5

@tf.function
def clip_gradient(grad, weight, clipping=0.01, eps=1e-3):
    param_norm = tf.math.maximum(unitwise_norm(weight), eps)
    grad_norm = unitwise_norm(grad)
    max_norm = param_norm * clipping
    trigger = grad_norm < max_norm
    clipped_grad = grad * (max_norm / tf.math.maximum(grad_norm, 1e-6))
    return tf.where(trigger, grad, clipped_grad)

In [7]:
@tf.function
def train_step(x, y,model,optimizer,ema,clipping_factor=0.01):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y, y_pred)
    gradients = tape.gradient(loss, model.trainable_weights)
    clipped_gradients = []
    for grad, weight in zip(gradients, model.trainable_weights):
        if ("dense" in weight.name):
            clipped_gradients.append(grad)
        else:
            clipped_gradients.append(clip_gradient(grad, weight, clipping=clipping_factor))
    decay_var_list = []
    for layer in model.trainable_weights:
        if not ("gain" in layer.name or "bias" in layer.name):
            decay_var_list.append(layer)
    opt_op = optimizer.apply_gradients(zip(clipped_gradients, model.trainable_weights),decay_var_list=decay_var_list)
    with tf.control_dependencies([opt_op]):
        ema.apply(model.trainable_variables)    
    return loss

In [8]:
class WarmUp(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self,initial_learning_rate,decay_schedule_fn,warmup_steps,power = 1.0):
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.warmup_steps = warmup_steps
        self.power = power
        self.decay_schedule_fn = decay_schedule_fn

    def __call__(self, step):
        global_step_float = tf.cast(step, tf.float32)
        warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
        warmup_percent_done = global_step_float / warmup_steps_float
        warmup_learning_rate = self.initial_learning_rate * tf.math.pow(warmup_percent_done, self.power)
        return tf.cond(global_step_float < warmup_steps_float,lambda: warmup_learning_rate,lambda: self.decay_schedule_fn(step - self.warmup_steps))

In [9]:
NUM_CLASSES = 2
variant="F0"
num_step=10
warm_step = 5
lr=0.1
clipping=0.01
batch_size = 4
max_lr = lr * batch_size / 256
ema_decay = 0.99999

lr_decayed_fn = keras.experimental.CosineDecay(max_lr,num_step - 5000)
lr_schedule = WarmUp(max_lr,lr_decayed_fn,5000)
ema = tf.train.ExponentialMovingAverage(decay=ema_decay)

model = make_NFNet(NUM_CLASSES,variant)
optimizer = tfa.optimizers.SGDW(learning_rate=lr_schedule, weight_decay=2e-5, momentum=0.9)

x = np.concatenate([np.ones((2,224,224,3),'float32'),np.zeros((2,224,224,3),'float32')])
y = np.concatenate([np.ones((2,1),'float32'),np.zeros((2,1),'float32')])

for i in range(100):
    print(train_step(x,y,model,optimizer,ema,clipping))

Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
tf.Tensor(0.7018745, shape=(), dtype=float32)
tf.Tensor(0.70016325, shape=(), dtype=float32)
tf.Tensor(0.70016336, shape=(), dtype=float32)
tf.Tensor(0.7004414, shape=(), dtype=float32)
tf.Tensor(0.69982904, shape=(), dtype=float32)
tf.Tensor(0.6995492, shape=(), dtype=float32)
tf.Tensor(0.7006166, shape=(), dtype=float32)
tf.Tensor(0.69781744, shape=(), dtype=float32)
tf.Tensor(0.703122, shape=(), dtype=float32)
tf.Tensor(0.7002796, shape=(), dtype=float32)
tf.Tensor(0.7002015, shape=(), dtype=float32)
tf.Tensor(0.6988789, shape=(), dtype=float32)
tf.Tensor(0.70240563, shape=(), dtype=float32)
tf.Tensor(0.70096356, shape=(), dtype=float32)
tf.Tensor(0.70213974, shape=(), dtype=float32)
tf.Tensor(0.703694, shape=(), dtype=float32)
tf.Tensor(0.70119774, shape=(), dtype=float32)
tf.Tensor(0.7011466, shape=(), dtype=float32)
tf.Tensor(0.699