https://www.kaggle.com/datamafia7/efficientnet-b5-on-tpu/data?scriptVersionId=30017888


Testing of DropBlock and GeM (generalized mean pooling) with a conv layer before EfficientNet.


In [1]:
!pip install ../input/kaggle-efficientnet-repo/efficientnet-1.0.0-py3-none-any.whl

Processing /kaggle/input/kaggle-efficientnet-repo/efficientnet-1.0.0-py3-none-any.whl
Installing collected packages: efficientnet
Successfully installed efficientnet-1.0.0


In [2]:
import os
import numpy as np
import pandas as pd
import argparse
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import categorical_crossentropy, sparse_categorical_crossentropy
from tensorflow.keras.metrics import categorical_accuracy, top_k_categorical_accuracy
from kaggle_datasets import KaggleDatasets
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

from tensorflow.keras import layers as L
import efficientnet.tfkeras as efn


In [3]:
# Display the GCS path of the 'tfrecords-grapheme-stratified' dataset (the same path main() reads from).
KaggleDatasets().get_gcs_path('tfrecords-grapheme-stratified')

'gs://kds-85103a0b82f80def92c1434d603b9689a68feee15283d8f5722d339a'

In [4]:
def normalize(image):
  """Standardise an RGB image tensor channel-wise.

  Subtracts a per-channel mean and divides by a per-channel standard deviation.
  Both constants are expressed on a 0-255 scale, so the input is presumably
  expected to hold raw pixel intensities (TODO confirm against the caller).

  :param image: Tensor whose last axis holds the three RGB channels.
  :return: The standardised tensor, same shape as `image`.
  """
  # Constants taken from the reference EfficientNet pipeline:
  # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/main.py#L325-L326
  # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_builder.py#L31-L32
  channel_mean = tf.constant([0.485 * 255, 0.456 * 255, 0.406 * 255])  # RGB
  channel_std = tf.constant([0.229 * 255, 0.224 * 255, 0.225 * 255])  # RGB
  return (image - channel_mean) / channel_std

In [5]:
# https://www.kaggle.com/c/bengaliai-cv19/discussion/134905

class Generalized_mean_pooling2D(tf.keras.layers.Layer):
    """Generalized-mean (GeM) pooling over the spatial axes of a 4-D tensor.

    For an input of shape (b, h, w, c) the layer returns a (b, c) tensor::

        (mean_{h,w}(|x| ** p) + epsilon) ** (1 / p)

    where the exponent ``p`` is a trainable scalar initialised to the ``p``
    constructor argument.

    :param p: Initial value of the pooling exponent.
    :param epsilon: Small constant added to the mean before taking the root.
    :param name: Optional layer name; an empty string lets Keras pick one.
    :param kwargs: Arguments for the parent `Layer` class.
    """

    def __init__(self, p=3, epsilon=1e-6, name='', **kwargs):
        # `name` has to be passed by keyword: the first positional parameter of
        # `Layer.__init__` is `trainable`, so passing the name positionally
        # silently froze the layer (an empty string is falsy) and dropped the name.
        super(Generalized_mean_pooling2D, self).__init__(name=name or None, **kwargs)

        self.init_p = p
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create the trainable exponent ``p`` after validating the input rank."""
        if isinstance(input_shape, list) or len(input_shape) != 4:
            raise ValueError('`GeM` pooling layer only allow 1 input with 4 dimensions(b, h, w, c)')

        self.build_shape = input_shape

        self.p = self.add_weight(
            name='p',
            shape=[1,],
            initializer=tf.keras.initializers.Constant(value=self.init_p),
            regularizer=None,
            trainable=True,
            dtype=tf.float32
        )

        self.built = True

    def call(self, inputs):
        """Pool `inputs` over axes 1 and 2 with the generalized mean.

        :raises ValueError: if `inputs` is a list or is not 4-dimensional.
        """
        # Check for a list first: a list has no `get_shape`, so the original
        # order raised AttributeError before the intended ValueError.
        if isinstance(inputs, list) or len(inputs.get_shape()) != 4:
            raise ValueError('`GeM` pooling layer only allow 1 input with 4 dimensions(b, h, w, c)')

        # Take the absolute value *before* the power: a negative base raised to
        # a non-integer exponent is NaN, which would poison the whole output as
        # soon as `p` is trained away from an integer value.
        powered = tf.abs(inputs) ** self.p
        pooled = tf.reduce_mean(powered, axis=[1, 2], keepdims=False)
        return (pooled + self.epsilon) ** (1.0 / self.p)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(Generalized_mean_pooling2D, self).get_config()
        config.update({'p': self.init_p, 'epsilon': self.epsilon})
        return config

In [6]:
import tensorflow.keras.backend as K

In [7]:
class DropBlock1D(tf.keras.layers.Layer):
    """DropBlock regularisation for 3-D (batch, steps, channels) inputs.

    During training, contiguous blocks of `block_size` steps are zeroed and the
    surviving activations are rescaled; at inference the input passes through
    unchanged.

    See: https://arxiv.org/pdf/1810.12890.pdf
    """

    def __init__(self,
                 block_size,
                 keep_prob,
                 sync_channels=False,
                 data_format='channels_last',
                 **kwargs):
        """Initialize the layer.
        :param block_size: Size for each mask block.
        :param keep_prob: Probability of keeping the original feature.
        :param sync_channels: Whether to use the same dropout for all channels.
        :param data_format: 'channels_first' or 'channels_last' (default).
        :param kwargs: Arguments for parent class.
        """
        super(DropBlock1D, self).__init__(**kwargs)
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.sync_channels = sync_channels
        self.data_format = data_format #K.normalize_data_format(data_format)
        self.input_spec = tf.keras.layers.InputSpec(ndim=3)
        self.supports_masking = True

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = {'block_size': self.block_size,
                  'keep_prob': self.keep_prob,
                  'sync_channels': self.sync_channels,
                  'data_format': self.data_format}
        base_config = super(DropBlock1D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_mask(self, inputs, mask=None):
        """Pass any incoming Keras mask through unchanged."""
        return mask

    def compute_output_shape(self, input_shape):
        """The layer never changes the shape of its input."""
        return input_shape

    def _get_gamma(self, feature_dim):
        """Return the Bernoulli rate used to sample block seeds.

        The rate is chosen so that, once every seed is grown into a block of
        `block_size` steps, roughly `1 - keep_prob` of the units are dropped.
        """
        feature_dim = K.cast(feature_dim, K.floatx())
        block_size = K.constant(self.block_size, dtype=K.floatx())
        return ((1.0 - self.keep_prob) / block_size) * (feature_dim / (feature_dim - block_size + 1.0))

    def _compute_valid_seed_region(self, seq_length):
        """Return a (1, seq_length, 1) 0/1 tensor marking admissible seed positions.

        A position is admissible when it lies at least `block_size // 2` steps
        away from both ends of the sequence.
        """
        positions = K.arange(seq_length)
        half_block_size = self.block_size // 2
        valid_seed_region = K.switch(
            K.all(
                K.stack(
                    [
                        positions >= half_block_size,
                        positions < seq_length - half_block_size,
                    ],
                    axis=-1,
                ),
                axis=-1,
            ),
            K.ones((seq_length,)),
            K.zeros((seq_length,)),
        )
        return K.expand_dims(K.expand_dims(valid_seed_region, axis=0), axis=-1)

    def _compute_drop_mask(self, shape):
        """Sample a keep-mask of the given channels-last `shape` (1 = keep, 0 = drop)."""
        seq_length = shape[1]
        mask = K.random_binomial(shape, p=self._get_gamma(seq_length))
        mask *= self._compute_valid_seed_region(seq_length)
        # Max-pooling with stride 1 grows every sampled seed into a full block.
        mask = tf.keras.layers.MaxPool1D(
            pool_size=self.block_size,
            padding='same',
            strides=1,
            data_format='channels_last',
        )(mask)
        return 1.0 - mask

    def call(self, inputs, training=None):
        """Apply DropBlock in the training phase; return `inputs` untouched otherwise."""

        def dropped_inputs():
            outputs = inputs
            if self.data_format == 'channels_first':
                outputs = K.permute_dimensions(outputs, [0, 2, 1])
            shape = K.shape(outputs)
            if self.sync_channels:
                mask = self._compute_drop_mask([shape[0], shape[1], 1])
            else:
                mask = self._compute_drop_mask(shape)
            # Rescale by count(mask) / count_ones(mask). The element count must
            # be the mask's own: with `sync_channels` the mask has a single
            # channel, so using the feature-map size over-scaled the output by
            # the number of channels.
            mask_size = K.cast(K.prod(K.shape(mask)), dtype=K.floatx())
            # The mask is binary, so its sum is 0 or >= 1; clamping only guards
            # the fully-dropped case against a division by zero (NaN output).
            kept = K.maximum(K.sum(mask), 1.0)
            outputs = outputs * mask * (mask_size / kept)
            if self.data_format == 'channels_first':
                outputs = K.permute_dimensions(outputs, [0, 2, 1])
            return outputs

        return K.in_train_phase(dropped_inputs, inputs, training=training)


class DropBlock2D(tf.keras.layers.Layer):
    """DropBlock regularisation for 4-D image feature maps.

    During training, square `block_size` x `block_size` regions are zeroed and
    the surviving activations are rescaled; at inference the input passes
    through unchanged.

    See: https://arxiv.org/pdf/1810.12890.pdf
    """

    def __init__(self,
                 block_size,
                 keep_prob,
                 sync_channels=False,
                 data_format='channels_last',
                 **kwargs):
        """Initialize the layer.
        :param block_size: Size for each mask block.
        :param keep_prob: Probability of keeping the original feature.
        :param sync_channels: Whether to use the same dropout for all channels.
        :param data_format: 'channels_first' or 'channels_last' (default).
        :param kwargs: Arguments for parent class.
        """
        super(DropBlock2D, self).__init__(**kwargs)
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.sync_channels = sync_channels
        self.data_format = data_format #K.normalize_data_format(data_format)
        self.input_spec = tf.keras.layers.InputSpec(ndim=4)
        self.supports_masking = True

    def get_config(self):
        """Return the constructor arguments merged into the base layer config."""
        config = {'block_size': self.block_size,
                  'keep_prob': self.keep_prob,
                  'sync_channels': self.sync_channels,
                  'data_format': self.data_format}
        base_config = super(DropBlock2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_mask(self, inputs, mask=None):
        """Pass any incoming Keras mask through unchanged."""
        return mask

    def compute_output_shape(self, input_shape):
        """The layer never changes the shape of its input."""
        return input_shape

    def _get_gamma(self, height, width):
        """Return the Bernoulli rate used to sample block seeds.

        The rate is chosen so that, once every seed is grown into a
        `block_size` x `block_size` block, roughly `1 - keep_prob` of the units
        are dropped.
        """
        height, width = K.cast(height, K.floatx()), K.cast(width, K.floatx())
        block_size = K.constant(self.block_size, dtype=K.floatx())
        return ((1.0 - self.keep_prob) / (block_size ** 2)) *\
               (height * width / ((height - block_size + 1.0) * (width - block_size + 1.0)))

    def _compute_valid_seed_region(self, height, width):
        """Return a (1, height, width, 1) 0/1 tensor marking admissible seed positions.

        A position is admissible when it lies at least `block_size // 2` pixels
        away from every border of the feature map.
        """
        # positions[i, j] == (i, j): a grid of row/column indices.
        positions = K.concatenate([
            K.expand_dims(K.tile(K.expand_dims(K.arange(height), axis=1), [1, width]), axis=-1),
            K.expand_dims(K.tile(K.expand_dims(K.arange(width), axis=0), [height, 1]), axis=-1),
        ], axis=-1)
        half_block_size = self.block_size // 2
        valid_seed_region = K.switch(
            K.all(
                K.stack(
                    [
                        positions[:, :, 0] >= half_block_size,
                        positions[:, :, 1] >= half_block_size,
                        positions[:, :, 0] < height - half_block_size,
                        positions[:, :, 1] < width - half_block_size,
                    ],
                    axis=-1,
                ),
                axis=-1,
            ),
            K.ones((height, width)),
            K.zeros((height, width)),
        )
        return K.expand_dims(K.expand_dims(valid_seed_region, axis=0), axis=-1)

    def _compute_drop_mask(self, shape):
        """Sample a keep-mask of the given channels-last `shape` (1 = keep, 0 = drop)."""
        height, width = shape[1], shape[2]
        mask = K.random_binomial(shape, p=self._get_gamma(height, width))
        mask *= self._compute_valid_seed_region(height, width)
        # Max-pooling with stride 1 grows every sampled seed into a full block.
        mask = tf.keras.layers.MaxPool2D(
            pool_size=(self.block_size, self.block_size),
            padding='same',
            strides=1,
            data_format='channels_last',
        )(mask)
        return 1.0 - mask

    def call(self, inputs, training=None):
        """Apply DropBlock in the training phase; return `inputs` untouched otherwise."""

        def dropped_inputs():
            outputs = inputs
            if self.data_format == 'channels_first':
                outputs = K.permute_dimensions(outputs, [0, 2, 3, 1])
            shape = K.shape(outputs)
            if self.sync_channels:
                mask = self._compute_drop_mask([shape[0], shape[1], shape[2], 1])
            else:
                mask = self._compute_drop_mask(shape)
            # Rescale by count(mask) / count_ones(mask). The element count must
            # be the mask's own: with `sync_channels` the mask has a single
            # channel, so using the feature-map size over-scaled the output by
            # the number of channels.
            mask_size = K.cast(K.prod(K.shape(mask)), dtype=K.floatx())
            # The mask is binary, so its sum is 0 or >= 1; clamping only guards
            # the fully-dropped case against a division by zero (NaN output).
            kept = K.maximum(K.sum(mask), 1.0)
            outputs = outputs * mask * (mask_size / kept)
            if self.data_format == 'channels_first':
                outputs = K.permute_dimensions(outputs, [0, 3, 1, 2])
            return outputs

        return K.in_train_phase(dropped_inputs, inputs, training=training)

In [8]:
def get_model(input_size, backbone='dense', weights='imagenet', tta=False):
    """Build the three-headed grapheme classifier on top of a Keras backbone.

    The input image is standardised with `normalize`, passed through the chosen
    backbone (without its classification top), global-average pooled, and fed
    through dropout into three softmax heads named 'grapheme' (168 classes),
    'vowel' (11 classes) and 'consonant' (7 classes).

    :param input_size: (height, width, channels) of the input images.
    :param backbone: One of 'dense' (DenseNet169), 'inception'
        (InceptionResNetV2) or 'xception' (Xception).
    :param weights: `weights` argument forwarded to the backbone constructor.
    :param tta: Test-time augmentation switch; not implemented.
    :return: A `tf.keras.Model` mapping the image to [grapheme, vowel, consonant].
    :raises NotImplementedError: if `tta` is truthy.
    :raises ValueError: if `backbone` is not a supported name.
    """
    if tta:
        # The flip-averaging TTA wrapper was never finished. Fail up front with
        # an explicit error instead of an `assert` (stripped under `python -O`)
        # that only fired after the backbone had been built and downloaded.
        raise NotImplementedError('This does not make sense yet')

    backbones = {
        'dense': tf.keras.applications.densenet.DenseNet169,
        'inception': tf.keras.applications.inception_resnet_v2.InceptionResNetV2,
        'xception': tf.keras.applications.xception.Xception,
    }
    if backbone not in backbones:
        # Previously an unknown name fell through every `if` and crashed later
        # with an unbound `model_fn`.
        raise ValueError(f'Unknown backbone {backbone!r}; expected one of {sorted(backbones)}')

    print(f'Using backbone {backbone} and weights {weights}')
    x = L.Input(shape=input_size, name='imgs', dtype='float32')
    y = normalize(x)
    model_fn = backbones[backbone](
        input_shape=(input_size[0], input_size[1], 3), weights=weights, include_top=False)

    # Experiment hooks that are currently disabled: a 3-channel Conv2D and a
    # DropBlock2D in front of the backbone, Generalized_mean_pooling2D instead
    # of average pooling, an extra Dense(512), and a DropBlock1D per head.
    features = model_fn(y)
    pooled = L.GlobalAveragePooling2D()(features)
    y = L.Dropout(0.2)(pooled)

    # 1292 of 1295 are present
    y1 = L.Dense(168, activation='softmax', name='grapheme')(y)
    y2 = L.Dense(11, activation='softmax', name='vowel')(y)
    y3 = L.Dense(7, activation='softmax', name='consonant')(y)

    return tf.keras.Model(x, [y1, y2, y3])

In [9]:
# Build one model at notebook level so its architecture can be inspected in the next cell;
# main() constructs its own model inside the distribution-strategy scope.
model = get_model((160,256,3),'xception')

Using backbone xception and weights imagenet
Downloading data from https://github.com/fchollet/deep-learning-models/releases/download/v0.4/xception_weights_tf_dim_ordering_tf_kernels_notop.h5


In [10]:
# Print the layer-by-layer summary of the model built in the previous cell.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
imgs (InputLayer)               [(None, 160, 256, 3) 0                                            
__________________________________________________________________________________________________
tf_op_layer_sub (TensorFlowOpLa [(None, 160, 256, 3) 0           imgs[0][0]                       
__________________________________________________________________________________________________
tf_op_layer_truediv (TensorFlow [(None, 160, 256, 3) 0           tf_op_layer_sub[0][0]            
__________________________________________________________________________________________________
xception (Model)                (None, 5, 8, 2048)   20861480    tf_op_layer_truediv[0][0]        
______________________________________________________________________________________________

In [11]:
def mixup(img_batch, label_batch, batch_size):
    """Blend every sample of a batch with a randomly chosen partner (mixup).

    A mixing weight is drawn uniformly from [0, 1) for each sample and the same
    weight and the same partner permutation are applied to the image and to
    each of the three label tensors.

    :param img_batch: Image batch; reshaping the weights to
        (batch_size, 1, 1, 1) implies a 4-D tensor.
    :param label_batch: Sequence whose first three entries are label tensors
        broadcastable against a (batch_size, 1) weight.
    :param batch_size: Number of samples in the batch.
    :return: (mixed images, (mixed label 0, mixed label 1, mixed label 2)).
    """
    # https://github.com/tensorpack/tensorpack/blob/master/examples/ResNet/cifar10-preact18-mixup.py
    lam = tf.random.uniform([batch_size])
    lam_for_images = tf.reshape(lam, [batch_size, 1, 1, 1])
    lam_for_labels = tf.reshape(lam, [batch_size, 1])
    partner = tf.random.shuffle(tf.range(batch_size, dtype=tf.int32))

    def blend(batch, lam_b):
        # Convex combination of each sample with its shuffled partner.
        return batch * lam_b + tf.gather(batch, partner) * (1. - lam_b)

    mixed_images = blend(img_batch, lam_for_images)
    mixed_labels = tuple(blend(label_batch[head], lam_for_labels) for head in range(3))
    return mixed_images, mixed_labels

In [12]:
def get_strategy():
    """Return a TPU distribution strategy when a TPU is found, else the current default.

    TPU detection relies on `TPUClusterResolver()` raising `ValueError` when no
    TPU is available. The number of replicas of the chosen strategy is printed
    before it is returned.
    """
    def detect_tpu():
        # A ValueError anywhere in detection means "no TPU".
        try:
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
            print('Running on TPU ', resolver.cluster_spec().as_dict()['worker'])
            return resolver
        except ValueError:
            return None

    tpu = detect_tpu()
    if not tpu:
        strategy = tf.distribute.get_strategy()
    else:
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)

    print('REPLICAS: ', strategy.num_replicas_in_sync)
    return strategy


In [13]:
def one_hot_concatenated(image, label1, label2, label3):
    """One-hot encode the three labels and join them along the last axis.

    The encodings use depths 168, 11 and 7 (in that order), so the concatenated
    label has 186 entries on its last axis.

    :return: (image, concatenated one-hot label).
    """
    depths = (168, 11, 7)
    encoded = [tf.one_hot(lbl, depth) for lbl, depth in zip((label1, label2, label3), depths)]
    return image, tf.concat(encoded, -1)

def one_hot(image, label1, label2, label3):
    """One-hot encode the three labels with depths 168, 11 and 7 respectively.

    :return: (image, (one-hot label1, one-hot label2, one-hot label3)).
    """
    first = tf.one_hot(label1, 168)
    second = tf.one_hot(label2, 11)
    third = tf.one_hot(label3, 7)
    return image, (first, second, third)

def read_tfrecords(example, input_size):
    """Parse one serialized TFRecord example into an image and its three labels.

    The encoded image is decoded, reshaped to `input_size + (1,)`, cast to
    float32 and its single channel is repeated three times to obtain an RGB
    tensor. The 'image_id' and 'unique_tuple' fields are parsed but not returned.

    :param example: Serialized example.
    :param input_size: Tuple (height, width) of the stored image.
    :return: (image, grapheme_root, vowel_diacritic, consonant_diacritic) with
        the labels cast to int32.
    """
    schema = {'img': tf.io.FixedLenFeature([], tf.string)}
    for int_key in ('image_id', 'grapheme_root', 'vowel_diacritic',
                    'consonant_diacritic', 'unique_tuple'):
        schema[int_key] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(example, schema)

    pixels = tf.image.decode_image(parsed['img'])
    pixels = tf.cast(tf.reshape(pixels, input_size + (1, )), tf.float32)
    # grayscale -> RGB
    rgb = tf.repeat(pixels, 3, -1)

    def as_int32(key):
        return tf.cast(parsed[key], tf.int32)

    return (rgb,
            as_int32('grapheme_root'),
            as_int32('vowel_diacritic'),
            as_int32('consonant_diacritic'))


In [14]:
class CustomRecall(tf.keras.callbacks.Callback):
    """Track the weighted macro-recall on a validation dataset and checkpoint the best model.

    After every validation batch the callback predicts the matching batch of
    `val_ds` and stores the arg-max classes; when validation ends it computes
    the macro recall of each head, averages them with weights 0.5 (grapheme),
    0.25 (vowel) and 0.25 (consonant), and saves the model weights whenever
    that score improves.

    :param val_ds: Dataset yielding (images, (grapheme, vowel, consonant))
        batches with one-hot targets; must be iterable eagerly.
    """

    def __init__(self, val_ds):
        # The base initialiser sets up attributes Keras expects on every callback.
        super(CustomRecall, self).__init__()
        self.best_score = 0
        self.val_ds = val_ds
        self._reset_buffers()

    def _reset_buffers(self):
        """Clear the accumulated predictions and targets."""
        self.val_pred_grapheme = []
        self.val_pred_vowel = []
        self.val_pred_cons = []

        self.val_targ_grapheme = []
        self.val_targ_vowel = []
        self.val_targ_cons = []

    def on_test_begin(self, logs=None):
        # Validation runs once per epoch; without a reset the buffers would mix
        # predictions made with the weights of earlier epochs into the score.
        self._reset_buffers()

    def on_test_batch_end(self, batch, logs=None):
        """Predict validation batch number `batch` and record predictions/targets."""
        # skip/take fetches exactly the requested batch instead of iterating
        # over every preceding batch in Python.
        for images, val_targ in self.val_ds.skip(batch).take(1):
            val_predict = self.model.predict(images)

            self.val_pred_grapheme += val_predict[0].argmax(1).tolist()
            self.val_pred_vowel += val_predict[1].argmax(1).tolist()
            self.val_pred_cons += val_predict[2].argmax(1).tolist()

            self.val_targ_grapheme += val_targ[0].numpy().argmax(1).tolist()
            self.val_targ_vowel += val_targ[1].numpy().argmax(1).tolist()
            self.val_targ_cons += val_targ[2].numpy().argmax(1).tolist()

    def on_test_end(self, logs=None):
        """Compute the weighted recall and save the weights if it is a new best."""
        if not self.val_targ_grapheme:
            # Nothing was collected, so there is nothing to score.
            return

        recall_grapheme = recall_score(self.val_targ_grapheme, self.val_pred_grapheme, average='macro')
        recall_vowel = recall_score(self.val_targ_vowel, self.val_pred_vowel, average='macro')
        recall_cons = recall_score(self.val_targ_cons, self.val_pred_cons, average='macro')

        overall_recall = np.average([recall_grapheme, recall_vowel, recall_cons], weights=[.5, .25, .25])

        print (overall_recall, recall_grapheme, recall_vowel, recall_cons)

        if overall_recall > self.best_score:
            # Was `overall_score`, an undefined name that raised NameError on
            # the first improvement.
            self.best_score = overall_recall

            weight_fn = 'model-%d.h5' % (self.best_score*100)
            # Save the model this callback is attached to, not a notebook global.
            self.model.save_weights(weight_fn)
            print(f'Saved weights to: {weight_fn}')

In [15]:
def recall_old(y_true, y_pred):
    """Batch-wide recall: rounded true positives over rounded actual positives.

    Both products are clipped to [0, 1] and rounded before summing over every
    element; `K.epsilon()` in the denominator avoids division by zero.
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    actual_positives = K.round(K.clip(y_true, 0, 1))
    return K.sum(hits) / (K.sum(actual_positives) + K.epsilon())

In [16]:
def recall(y_true, y_pred):
    """Per-class recall of rounded predictions, computed over the batch axis.

    Predictions are rounded to 0/1, then true positives and false negatives are
    summed along axis 0. The result therefore holds one recall value per
    remaining position (per class for 2-D inputs); it is not averaged here.

    :param y_true: Ground-truth indicator tensor.
    :param y_pred: Predicted scores, rounded to the nearest integer internally.
    :return: tp / (tp + fn + epsilon), reduced over axis 0.
    """
    y_pred = K.round(y_pred)
    tp = K.sum(K.cast(y_true*y_pred, 'float'), axis=0)
    fn = K.sum(K.cast(y_true*(1-y_pred), 'float'), axis=0)

    # False positives and precision were computed here before but never used;
    # recall only needs true positives and false negatives.
    return tp / (tp + fn + K.epsilon())

In [17]:
def f1(y_true, y_pred):
    """Mean F1 score of rounded predictions.

    Predictions are rounded to 0/1; true positives, false positives and false
    negatives are summed along axis 0, turned into precision and recall (each
    stabilised with `K.epsilon()`), combined into F1, and the F1 values are
    averaged into a single scalar.
    """
    hard_pred = K.round(y_pred)

    def count(indicator):
        # Sum an indicator tensor along the batch axis.
        return K.sum(K.cast(indicator, 'float'), axis=0)

    true_pos = count(y_true * hard_pred)
    false_pos = count((1 - y_true) * hard_pred)
    false_neg = count(y_true * (1 - hard_pred))

    eps = K.epsilon()
    precision = true_pos / (true_pos + false_pos + eps)
    sensitivity = true_pos / (true_pos + false_neg + eps)

    score = 2 * precision * sensitivity / (precision + sensitivity + eps)
    return K.mean(score)

In [18]:
def main():
    """Train the grapheme classifier end to end.

    Parses hyperparameters (unknown command-line arguments are ignored), seeds
    NumPy and TensorFlow, builds and compiles the model under the distribution
    strategy returned by `get_strategy`, builds the training (with mixup) and
    validation pipelines from the TFRecord files of the
    'tfrecords-grapheme-stratified' dataset, and fits the model with early
    stopping, a learning-rate schedule and best-model checkpointing.

    Side effects: publishes `parser`, `train_ds`, `val_ds`, `model`,
    `num_val_samples` and `val_step` as module-level globals and writes the
    checkpoint file 'model-<model_id>.h5'.

    :raises FileNotFoundError: if no training or validation TFRecord file is found.
    """
    global parser, train_ds, val_ds, model, num_val_samples, val_step

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_id', type=int, default=0)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--input_size', type=str, default='160,256')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--backbone', type=str, default='xception')
    parser.add_argument('--weights', type=str, default='imagenet')
    args, _ = parser.parse_known_args()

    args.input_size = tuple(int(x) for x in args.input_size.split(','))
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)

    # Build AND compile inside the strategy scope, so that the optimizer slots
    # and the metric variables are created under the same distribution strategy
    # as the model weights.
    strategy = get_strategy()
    with strategy.scope():
        model = get_model(input_size=args.input_size + (3, ), backbone=args.backbone,
                          weights=args.weights)
        model.compile(optimizer=Adam(learning_rate=args.lr),
                      loss=categorical_crossentropy,
                      metrics=[categorical_accuracy, tf.keras.metrics.Recall()])

    # `summary()` prints by itself and returns None; wrapping it in print()
    # appended a stray "None" to the output.
    model.summary()

    AUTO = tf.data.experimental.AUTOTUNE
    # create the training and validation datasets
    ds_path = KaggleDatasets().get_gcs_path('tfrecords-grapheme-stratified')

    train_fns = tf.io.gfile.glob(os.path.join(ds_path, 'train*.tfrec'))
    val_fns = tf.io.gfile.glob(os.path.join(ds_path, 'val*.tfrec'))
    if not train_fns or not val_fns:
        raise FileNotFoundError(f'No train*/val* .tfrec files found under {ds_path}')

    def parse(example):
        return read_tfrecords(example, args.input_size)

    train_ds = tf.data.TFRecordDataset(train_fns, num_parallel_reads=AUTO)
    train_ds = train_ds.map(parse, num_parallel_calls=AUTO)
    # repeat() before batch() guarantees that every batch is full, which the
    # fixed-size reshapes inside mixup rely on.
    train_ds = train_ds.repeat().batch(args.batch_size)
    train_ds = train_ds.map(one_hot, num_parallel_calls=AUTO)
    train_ds = train_ds.map(lambda a, b: mixup(a, b, args.batch_size), num_parallel_calls=AUTO)
    train_ds = train_ds.prefetch(AUTO)  # overlap input preparation with training steps

    val_ds = tf.data.TFRecordDataset(val_fns, num_parallel_reads=AUTO)
    val_ds = val_ds.map(parse, num_parallel_calls=AUTO)
    val_ds = val_ds.batch(args.batch_size)
    val_ds = val_ds.map(one_hot, num_parallel_calls=AUTO)
    val_ds = val_ds.prefetch(AUTO)

    callback1 = tf.keras.callbacks.EarlyStopping(
        monitor='val_grapheme_categorical_accuracy', mode='max', patience=5, verbose=1)

    def scheduler(epoch):
        # Constant rate for the first four epochs, then exponential decay.
        # Returns a plain Python float rather than a Tensor.
        if epoch < 4:
            return args.lr
        return float(args.lr * np.exp(0.2 * (3 - epoch)))
    callback2 = tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1)

    weight_fn = 'model-%04d.h5' % args.model_id

    # CustomRecall(val_ds) is the alternative, recall-based checkpointing callback.
    callback3 = tf.keras.callbacks.ModelCheckpoint(
        monitor='val_grapheme_categorical_accuracy', mode='max', save_best_only=True,
        filepath=weight_fn, verbose=1)

    # train
    # The third '_'-separated field of each file path is read as the number of
    # records in that file (presumably encoded when the records were written;
    # verify against the dataset's naming scheme).
    num_train_samples = sum(int(fn.split('_')[2]) for fn in train_fns)
    num_val_samples = sum(int(fn.split('_')[2]) for fn in val_fns)
    steps_per_epoch = num_train_samples // args.batch_size
    val_step = num_val_samples // args.batch_size

    print(f'Training on {num_train_samples} samples. Each epochs requires {steps_per_epoch} steps')
    print(f'Validation on {num_val_samples} samples.')

    model.fit(train_ds, steps_per_epoch=steps_per_epoch, epochs=args.epochs, verbose=1,
              validation_data=val_ds, callbacks=[callback1, callback2, callback3])

In [19]:
# Run the training pipeline; hyperparameters come from the argparse defaults in main()
# unless overridden on the command line.
main()

Running on TPU  ['10.0.0.2:8470']
REPLICAS:  8
Using backbone xception and weights imagenet
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
imgs (InputLayer)               [(None, 160, 256, 3) 0                                            
__________________________________________________________________________________________________
tf_op_layer_sub_1 (TensorFlowOp [(None, 160, 256, 3) 0           imgs[0][0]                       
__________________________________________________________________________________________________
tf_op_layer_truediv_1 (TensorFl [(None, 160, 256, 3) 0           tf_op_layer_sub_1[0][0]          
__________________________________________________________________________________________________
xception (Model)                (None, 5, 8, 2048)   20861480    tf_op_layer_truediv_1[0][0]      
