## Intro
**Hello Kagglers.
In this notebook I will load, train, and save TFHub models on a TPU. It is not meant for leaderboard (LB) points, just for study.**
**Some code was taken from this** [notebook](https://www.kaggle.com/smirnyaginandr/notebooke180f47a7c)

## Step One
**Import libraries**

In [None]:
import pandas as pd 
import os
from functools import partial
import tensorflow as tf
import re, math
import numpy as np
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets
import csv
import tensorflow_hub as hub
from sklearn.model_selection import train_test_split
from kaggle_secrets import UserSecretsClient


**Get the Google Cloud Storage (GCS) paths of the datasets, because the TPU reads its input directly from GCS. The 'tf-hub' dataset contains 4 models from TFHub, which I will use for training.**

In [None]:
GCS_DS_PATH = KaggleDatasets().get_gcs_path('cassava-leaf-disease-classification')
GCS_PATH_TO_SAVEDMODEL = KaggleDatasets().get_gcs_path('tf-hub')

## Step Two
**Initialize the TPU cluster resolver and the distribution strategy. This MUST happen before the next step!**

In [None]:
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy() 
print("REPLICAS: ", strategy.num_replicas_in_sync)
AUTO = tf.data.experimental.AUTOTUNE

## Step Three
**Set up credentials that grant Kaggle access to your Google Cloud account. This is needed for reading private datasets and for saving trained models to Cloud Storage.
This step MUST come after initializing the TPU strategy. Thanks to @morodertobias for the advice.**

In [None]:
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

## Step Four
**Get filenames for the train and validation sets**

In [None]:
TRAINING_FILENAMES, VALID_FILENAMES = train_test_split(
    tf.io.gfile.glob(GCS_DS_PATH + '/train_tfrecords/ld_train*.tfrec'),
    test_size=0.35, random_state=5
)
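Note that `train_test_split` here splits the list of TFRecord *files*, not individual images, so the validation share is only approximately 35% of the images. For a float `test_size`, scikit-learn rounds the number of test items up. The stdlib-only sketch below mimics that sizing; the helper `split_shards` and the shard names are hypothetical, and because it shuffles with Python's `random` module, the actual files it picks will differ from what `train_test_split(..., random_state=5)` picks.

```python
import math
import random

def split_shards(filenames, test_size=0.35, seed=5):
    """Split shard filenames into (train, valid) lists.

    Sizes the split like scikit-learn does for a float test_size
    (test count rounded up), but shuffles with Python's random module.
    """
    n_test = math.ceil(test_size * len(filenames))
    shuffled = list(filenames)
    random.Random(seed).shuffle(shuffled)
    return shuffled[n_test:], shuffled[:n_test]

# Hypothetical shard names; in the notebook they come from tf.io.gfile.glob(...).
files = [f"ld_train{i:02d}-1338.tfrec" for i in range(16)]
train_files, valid_files = split_shards(files)
print(len(train_files), len(valid_files))  # 10 6
```

With 16 shards, 0.35 × 16 = 5.6 is rounded up to 6 validation files, leaving 10 for training.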

**Define some variables**

In [None]:
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
AUTOTUNE = tf.data.experimental.AUTOTUNE
CLASSES = 5
CHANNELS = 3
SEED = 42
DIM = 224

**Define functions to augment data and create datasets**

In [None]:
def decode_image(image_data):
    image = tf.image.decode_jpeg(image_data, channels=CHANNELS)
    image = tf.cast(image, tf.float32) / 255.0 
    image = tf.image.resize(image, [DIM, DIM])
    return image


def read_labeled_tfrecord(example):
    LABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string),   # JPEG-encoded bytes
        "target": tf.io.FixedLenFeature([], tf.int64),   # class index
    }

    example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    # One-hot labels to match the categorical_crossentropy loss used for training.
    label = tf.one_hot(example["target"], depth=CLASSES)
    label = tf.cast(label, tf.float32)
    return image, label


def read_unlabeled_tfrecord(example):
    UNLABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "image_name": tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(example, UNLABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    idnum = example['image_name']
    return image, idnum

def load_dataset(filenames, labeled=True, ordered=False):
    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False 
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) 
    dataset = dataset.with_options(ignore_order)
    dataset = dataset.map(read_labeled_tfrecord if labeled else read_unlabeled_tfrecord, num_parallel_calls=AUTO)
    return dataset

def data_augment(image, label):
    # One uniform draw per decision. dataset.map traces this function with
    # AutoGraph, so the Python `if`s on tensors become graph conditionals.
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_1 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_2 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_3 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    # Flips
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    if p_spatial > .75:
        image = tf.image.transpose(image)

    # Rotations (counter-clockwise by k * 90°)
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3)  # rotate 270°
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2)  # rotate 180°
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1)  # rotate 90°

    # Pixel-level transforms
    if p_pixel_1 >= .4:
        image = tf.image.random_saturation(image, lower=.7, upper=1.3)
    if p_pixel_2 >= .4:
        image = tf.image.random_contrast(image, lower=.8, upper=1.2)
    if p_pixel_3 >= .4:
        image = tf.image.random_brightness(image, max_delta=.1)

    # Contrast/brightness can push values outside [0, 1]; clip back to the range decode_image produces.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

def get_training_dataset():
    dataset = load_dataset(TRAINING_FILENAMES, labeled=True)
    dataset = dataset.map(data_augment, num_parallel_calls=AUTOTUNE)
    dataset = dataset.repeat() 
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset

def get_validation_dataset(ordered=False):
    dataset = load_dataset(VALID_FILENAMES, labeled=True, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset
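In `data_augment`, each `p_*` value is a single uniform draw in [0, 1), so the thresholds translate directly into probabilities: the transpose fires 25% of the time, each of the four rotations (0°, 90°, 180°, 270°) is chosen with probability 0.25, and each color transform runs with probability 0.6 (`>= .4`). The plain-Python simulation below replays only the rotation thresholds to check this; `rotation_k` is a hypothetical helper and no TensorFlow is involved.

```python
import random
from collections import Counter

def rotation_k(p_rotate):
    """Return the rot90 `k` that data_augment picks for a draw p_rotate in [0, 1)."""
    if p_rotate > .75:
        return 3  # 270 degrees
    elif p_rotate > .5:
        return 2  # 180 degrees
    elif p_rotate > .25:
        return 1  # 90 degrees
    return 0      # no rotation

rng = random.Random(42)
n_draws = 100_000
counts = Counter(rotation_k(rng.random()) for _ in range(n_draws))
freqs = {k: counts[k] / n_draws for k in sorted(counts)}
print(freqs)  # each frequency is close to 0.25
```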

**Get datasets**

In [None]:
train_dataset = get_training_dataset()
validation_dataset = get_validation_dataset()

**Count steps for train and validation**

In [None]:
def count_data_items(filenames):
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

STEPS_PER_EPOCH = count_data_items(TRAINING_FILENAMES) // BATCH_SIZE
VALID_STEPS = count_data_items(VALID_FILENAMES) // BATCH_SIZE
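`count_data_items` relies on the shard naming convention: the number between the dash and the dot in each filename is the number of images in that shard. The stdlib-only sketch below repeats that arithmetic, assuming hypothetical filenames and 8 replicas (the core count of a TPU v3-8), so the global batch is 16 × 8 = 128.

```python
import re

def count_data_items(filenames):
    # The digits between '-' and '.' in e.g. 'ld_train00-1338.tfrec' are the image count.
    pattern = re.compile(r"-([0-9]*)\.")
    return sum(int(pattern.search(name).group(1)) for name in filenames)

REPLICAS = 8                 # assumption: a TPU v3-8 reports 8 replicas in sync
BATCH_SIZE = 16 * REPLICAS   # global batch = per-replica batch x number of replicas

# Hypothetical shard names for illustration.
train_files = ["ld_train00-1338.tfrec", "ld_train01-1338.tfrec"]
n_images = count_data_items(train_files)
steps = n_images // BATCH_SIZE
print(n_images, BATCH_SIZE, steps)  # 2676 128 20
```

Floor division drops the last partial batch (116 images here) from each epoch's step count; since the training dataset is repeated, those images are still seen in later epochs.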

**Define callbacks**

In [None]:
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_acc', min_delta=0.001,
                                              patience=10, mode='max', verbose=1,
                                              restore_best_weights=True)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,
                                                 patience=2, min_delta=0.001,
                                                 mode='min', verbose=1)
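With these settings, `reduce_lr` multiplies the learning rate by 0.1 after 2 consecutive epochs in which `val_loss` fails to improve by more than 0.001, and `early_stop` ends a run after 10 such epochs on `val_acc` and, with `restore_best_weights=True`, rolls back to the best epoch's weights when it stops early. The sketch below is a simplified plain-Python re-implementation of the plateau rule (no cooldown, no `min_lr`), meant only for intuition; it is not Keras's own code, and the loss values are made up.

```python
def plateau_schedule(val_losses, lr=1e-3, factor=0.1, patience=2, min_delta=0.001):
    """Learning rate in effect after each epoch under a simplified ReduceLROnPlateau (mode='min')."""
    best = float("inf")
    wait = 0
    lrs = []
    for loss in val_losses:
        if loss < best - min_delta:   # counts as an improvement
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:      # plateau reached: shrink the learning rate
                lr *= factor
                wait = 0
        lrs.append(lr)
    return lrs

# Made-up validation losses: two improvements, then a plateau.
print(plateau_schedule([1.0, 0.9, 0.95, 0.96, 0.97]))
```

Here the learning rate drops from 1e-3 to 1e-4 after the fourth epoch. The same callback objects are reused for every model in the training loop below; Keras resets their counters at the start of each `fit` call.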

## Step Five
**Before this step you need to create a bucket in Google Cloud Storage.** 

![](https://storage.cloud.google.com/cassava_saved_models/cr_bucket-1.png)

**I called my bucket "cassava_saved_models".**

![](https://storage.cloud.google.com/cassava_saved_models/cr_bucket-2.png)

**Then just train the models. I am skipping "efficientnet" because this model is packaged as a TF1 module, and unfortunately I haven't found a way to load it correctly.**

In [None]:
tf.keras.backend.clear_session()
models = [GCS_PATH_TO_SAVEDMODEL + '/nets/' + x for x in tf.io.gfile.listdir(GCS_PATH_TO_SAVEDMODEL + '/nets/')]
with strategy.scope():
    
    for model_path in models:
        model_name = model_path.split('/')[-2]
        if model_name.startswith('efficientnet'):
            continue
        print()
        print('=' * 100)
        print(f'Load model {model_name}')
        
        # Load the SavedModel inside strategy.scope() so its variables are placed on the TPU replicas.
        loaded_model = tf.saved_model.load(model_path)
        base_model = hub.KerasLayer(loaded_model, trainable=True)
        model = tf.keras.Sequential([
            tf.keras.Input(shape=(DIM, DIM, CHANNELS)),
            base_model,
            tf.keras.layers.Dense(CLASSES, activation='softmax')
        ])
        model.compile(
                optimizer=tf.keras.optimizers.Adam(),
                loss='categorical_crossentropy',  
                metrics=['acc'])

        history = model.fit(train_dataset, 
                            epochs=20,
                            callbacks=[early_stop, reduce_lr],
                            validation_data = validation_dataset,
                            steps_per_epoch = STEPS_PER_EPOCH,
                            verbose=1)
        model.save(f'gs://cassava_saved_models/saved_{model_name}_{int(max(history.history["val_acc"])*100)}')
        print(f'Model {model_name} saved')

**After training, you will find the saved models in your bucket.**
![](https://storage.cloud.google.com/cassava_saved_models/cr_bucket-3.png)
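Each model is saved under a name built from its TFHub directory and its best validation accuracy. The `split('/')[-2]` index relies on each listed model directory ending with `/`, so the last non-empty path component sits at index -2. The plain-Python sketch below reproduces that naming; the helper `saved_model_path`, the dataset path, the model directory, and the accuracy history are all hypothetical.

```python
def saved_model_path(model_path, val_acc_history, bucket="cassava_saved_models"):
    # model_path ends with '/', so the model directory name is at index -2.
    model_name = model_path.split('/')[-2]
    # Best validation accuracy as an integer percentage (int() truncates, it does not round).
    score = int(max(val_acc_history) * 100)
    return f"gs://{bucket}/saved_{model_name}_{score}"

# Hypothetical GCS path and accuracy history.
path = "gs://kds-example/nets/resnet_50/"
print(saved_model_path(path, [0.71, 0.8765, 0.85]))
# gs://cassava_saved_models/saved_resnet_50_87
```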