## **Importing Libraries.**

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")

import tensorflow as tf
from functools import partial
from tensorflow import keras
import keras.layers as L

from keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger, LearningRateScheduler

from kaggle_datasets import KaggleDatasets

import re

## **Device**

In [None]:
# Detect a TPU and build the matching distribution strategy. If any step of the
# TPU setup fails, fall back to TensorFlow's default strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print("Device:", tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # `Exception` instead of a bare `except:` so that KeyboardInterrupt and
    # SystemExit still propagate; the reason for the fallback is reported
    # rather than silently discarded.
    print("TPU setup failed, using the default strategy:", repr(tpu_error))
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

## **Data Configuration**

In [None]:
# Sentinel passed to tf.data calls below (parallel reads, map, prefetch) so that
# tf.data picks the level of parallelism / buffering itself.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# GCS path of the 'tpu-getting-started' dataset as reported by KaggleDatasets.
GCS_DS_Path = KaggleDatasets().get_gcs_path('tpu-getting-started')
print(GCS_DS_Path)

In [None]:
# Height/width used when reshaping decoded images and as the model input shape.
IMAGE_SIZE = [512,512]
# Maximum number of epochs passed to model.fit.
EPOCHS = 22
# Global batch size: 16 examples per replica of the active strategy.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync

# Sub-directory of the dataset that is globbed for the train/val/test TFRecords.
GCS_PATH = GCS_DS_Path + "/tfrecords-jpeg-512x512"

In [None]:
# List the TFRecord shards of every split (train, val, test -- in that order).
train_files, val_files, test_files = [
    tf.io.gfile.glob(GCS_PATH + pattern)
    for pattern in ("/train/*.tfrec", "/val/*.tfrec", "/test/*.tfrec")
]

In [None]:
# Report how many TFRecord shards were found for each split.
for _caption, _shards in (
    ("Train TFRecord Files:", train_files),
    ("Validation TFRecord Files:", val_files),
    ("Test TFRecord Files:", test_files),
):
    print(_caption, len(_shards))

## **Decoding the Data**

In [None]:
def decode_image(image):
    """Turn encoded JPEG bytes into a float32 image tensor.

    The JPEG is decoded with 3 channels, cast to float32, divided by 255.0
    and finally reshaped to ``[*IMAGE_SIZE, 3]``.
    """
    pixels = tf.image.decode_jpeg(image, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

In [None]:
def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into an ``(image, label)`` pair.

    Reads the scalar byte-string feature ``"image"`` and the scalar int64
    feature ``"class"``. The image goes through ``decode_image`` and the class
    is cast to int32.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # encoded image bytes
        "class": tf.io.FixedLenFeature([], tf.int64),   # single integer label
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), tf.cast(parsed['class'], tf.int32)

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example into an ``(image, id)`` pair.

    Reads the scalar byte-string features ``"image"`` and ``"id"``; no
    ``"class"`` feature is read here. The image goes through ``decode_image``
    and the id is returned as parsed.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # encoded image bytes
        "id": tf.io.FixedLenFeature([], tf.string),     # single string identifier
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['id']

## **Defining Data Loading Functions**

In [None]:
def load_dataset(filenames, labeled = True, ordered = False):
    """Build a dataset of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: when true, records are parsed with ``read_labeled_tfrecord``
            (``(image, label)`` pairs), otherwise with
            ``read_unlabeled_tfrecord`` (``(image, id)`` pairs).
        ordered: when false, deterministic ordering is switched off in the
            dataset options so elements may be produced out of order.

    Returns:
        The mapped, un-batched ``tf.data`` dataset.
    """
    options = tf.data.Options()
    if not ordered:
        # Trade element order for throughput.
        options.experimental_deterministic = False

    parse_fn = read_labeled_tfrecord if labeled else read_unlabeled_tfrecord

    # num_parallel_reads lets several files be read concurrently.
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    records = records.with_options(options)
    return records.map(parse_fn, num_parallel_calls=AUTOTUNE)

In [None]:
def data_augment(image, label):
    """Randomly mirror the image left/right; the label is passed through as is.

    Other augmentations (up/down flip, random saturation, rot90) are currently
    disabled.
    """
    flipped = tf.image.random_flip_left_right(image)
    return flipped, label

def count_data_items(filenames):
    """Sum the item counts encoded in TFRecord file names.

    Each file name is expected to contain ``-<digits>.`` (for example
    ``00-512x512-798.tfrec`` holds 798 items); the digits of the first such
    occurrence are taken as that file's item count.

    Args:
        filenames: iterable of file names / paths.

    Returns:
        The total count as a NumPy int64 (0 for an empty input).

    Raises:
        ValueError: if a file name does not contain a ``-<digits>.`` count.
    """
    # Compiled once instead of once per file name; `+` (not `*`) so that an
    # empty capture can never reach int().
    count_pattern = re.compile(r"-([0-9]+)\.")
    counts = []
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(
                f"Cannot read the item count from file name: {filename!r}"
            )
        counts.append(int(match.group(1)))
    # An explicit integer dtype keeps the result integral for an empty list too.
    return np.sum(counts, dtype=np.int64)

In [None]:
# Number of examples per split, derived from the counts embedded in the file names.
NUM_TRAIN_IMAGES, NUM_VAL_IMAGES, NUM_TEST_IMAGES = [
    count_data_items(shards) for shards in (train_files, val_files, test_files)
]
print(f"train : {NUM_TRAIN_IMAGES}, test : {NUM_TEST_IMAGES}, val : {NUM_VAL_IMAGES}")

In [None]:
def get_train_dataset(filenames, labeled = True, shuffle = True):
    """Build the infinitely repeating, augmented and batched training dataset.

    Args:
        filenames: TFRecord file paths to read.
        labeled: forwarded to ``load_dataset``.
        shuffle: when true (the default), elements are shuffled with a
            2048-element buffer. The flag was previously accepted but ignored.

    Returns:
        A dataset of batches of ``BATCH_SIZE`` elements that never ends, so
        ``steps_per_epoch`` must be given when fitting on it.
    """
    dataset = load_dataset(filenames, labeled = labeled)
    dataset = dataset.map(data_augment, num_parallel_calls = AUTOTUNE)
    if shuffle:
        # Shuffle before repeat so each pass over the data is fully consumed
        # before the next begins, instead of mixing epochs in the buffer.
        dataset = dataset.shuffle(2048)
    dataset = dataset.repeat() # the training dataset must repeat for several epochs
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset

def get_val_dataset(filenames, labeled = True, ordered = False):
    """Build the validation dataset: load, batch by ``BATCH_SIZE``, cache, prefetch.

    ``labeled`` and ``ordered`` are forwarded to ``load_dataset``.
    """
    records = load_dataset(filenames, labeled=labeled, ordered=ordered)
    # Cache the batched elements so later passes skip reading and decoding.
    return records.batch(BATCH_SIZE).cache().prefetch(AUTOTUNE)
    
def get_test_dataset(filenames, labeled = False, ordered = False):
    """Build the test dataset: load, batch by ``BATCH_SIZE``, prefetch.

    ``labeled`` (false by default) and ``ordered`` are forwarded to
    ``load_dataset``.
    """
    records = load_dataset(filenames, labeled=labeled, ordered=ordered)
    return records.batch(BATCH_SIZE).prefetch(AUTOTUNE)

In [None]:
# Instantiate the input pipelines with their default flags: the training set is
# labeled; the validation set is labeled and not order-preserving.
train_dataset = get_train_dataset(filenames=train_files)
val_dataset = get_val_dataset(filenames=val_files)

**LABELS**

In [None]:
# Flower class names, indexed by integer class label (104 entries; the trailing
# comment on each row gives the index range of that row).
# NOTE(review): 'cyclamen ' and 'hippeastrum ' carry a trailing space -- confirm
# whether that is intentional before using these names as lookup keys.
CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']

## **Visualize some Images**

In [None]:
# Pull the first batch of (images, labels) from the validation dataset for a visual check.
image_batch, label_batch = next(iter(val_dataset))

def show_batch(image_batch, label_batch):
    """Plot up to 25 images of a batch on a 5x5 grid, titled with their class names.

    Args:
        image_batch: batch of images, indexable along the first axis.
        label_batch: batch of integer class labels (tensors exposing
            ``.numpy()``), aligned with ``image_batch``.
    """
    # The batch can hold fewer than 25 images (BATCH_SIZE is 16 per replica),
    # so never index past its end.
    n_shown = min(25, int(image_batch.shape[0]))
    plt.figure(figsize = [20,12])
    for i in range(n_shown):
        plt.subplot(5,5,i+1)
        plt.imshow(image_batch[i])
        plt.title(CLASSES[label_batch[i].numpy()])
        plt.axis('off')
    plt.show()

In [None]:
# Display the validation batch fetched above together with its ground-truth class names.
show_batch(image_batch, label_batch)

## **MODEL**

### **Define Callbacks**

In [None]:
# Learning-rate schedule: start at init_lr and multiply by 0.96 once every
# 10_000 optimizer steps (staircase=True makes the decay happen in discrete jumps).
# NOTE(review): if EPOCHS * STEPS_PER_EPOCH stays below 10_000 steps the rate
# never actually decays -- confirm decay_steps against the real step count.
init_lr = 1e-4
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    init_lr, decay_steps=10_000, decay_rate=0.96, staircase=True
)


# Training callbacks:
#   * ModelCheckpoint - writes "flowers&flowers.h5", keeping only the best model seen so far
#   * EarlyStopping   - stops training once val_loss has not improved for 4 epochs
#   * CSVLogger       - appends the per-epoch metrics to "train.log"
my_callbacks = [ModelCheckpoint("flowers&flowers.h5", save_best_only=True),
               EarlyStopping(monitor="val_loss", patience=4),
               CSVLogger("train.log")]

### **Build the model.**

In [None]:
def make_model(num_classes = 104):
    """Build and compile a DenseNet121-based flower classifier.

    The ImageNet-initialised DenseNet121 backbone (global average pooling, no
    top) has its first 54 layers frozen; the remaining layers are trainable.
    A softmax ``Dense`` layer with ``num_classes`` units is placed on top.

    Args:
        num_classes: number of output classes (defaults to 104, the size
            previously hard-coded here).

    Returns:
        The compiled Keras model (Adam with ``lr_schedule``, sparse categorical
        cross-entropy loss, sparse categorical accuracy metric).
    """
    base_model = tf.keras.applications.DenseNet121(
        input_shape = [*IMAGE_SIZE, 3], include_top = False,
        weights = "imagenet", pooling="avg"
    )

    # Fine-tune the backbone except for its first 54 layers.
    base_model.trainable = True
    for layer in base_model.layers[:54]:
        layer.trainable = False

    inputs = L.Input([*IMAGE_SIZE, 3])
    # NOTE(review): densenet.preprocess_input is not applied; inputs are only
    # scaled to [0, 1] by decode_image -- confirm this is what the ImageNet
    # weights should be fine-tuned with.
    x = base_model(inputs)
    outputs = L.Dense(num_classes, activation = "softmax")(x)

    model = tf.keras.models.Model(inputs, outputs)

    model.compile(
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule),
        loss = "sparse_categorical_crossentropy",
        metrics = ["sparse_categorical_accuracy"]
    )

    # summary() prints by itself and returns None; wrapping it in print()
    # only added a stray "None" line to the output.
    model.summary()
    return model

In [None]:
# Create and compile the model inside the strategy scope so that it is built
# under the distribution strategy selected in the Device section.
with strategy.scope():
    model = make_model()

### **Training**

In [None]:
# Number of whole batches per pass over the training images (floor division).
STEPS_PER_EPOCH = NUM_TRAIN_IMAGES // BATCH_SIZE

In [None]:
# Train for up to EPOCHS epochs of STEPS_PER_EPOCH batches each, validating on
# val_dataset; checkpointing, early stopping and CSV logging come from my_callbacks.
history = model.fit(train_dataset,
                    epochs = EPOCHS,
                    steps_per_epoch = STEPS_PER_EPOCH,
                    validation_data = val_dataset,
                    callbacks = my_callbacks)

**Accuracy / loss**

In [None]:
# Side-by-side training curves: accuracy on the left, loss on the right.
plt.figure(figsize=(15,6))

_panels = (
    (1, "Accuracy", 'sparse_categorical_accuracy', 'val_sparse_categorical_accuracy'),
    (2, "Loss", 'loss', 'val_loss'),
)
for _position, _title, _train_key, _val_key in _panels:
    plt.subplot(1, 2, _position)
    plt.plot(history.epoch, history.history[_train_key], label = 'Training')
    plt.plot(history.epoch, history.history[_val_key], label = 'validation')
    plt.title(_title)
    plt.legend()

plt.show()

## **Test and Predictions.**

In [None]:
# Reload the checkpoint written by ModelCheckpoint during training. Loading
# inside the strategy scope ties the reloaded model to the same distribution
# strategy as the one it was trained under.
with strategy.scope():
    model = tf.keras.models.load_model("./flowers&flowers.h5")

In [None]:
# Loss and metric values of the current `model` on the validation dataset.
model.evaluate(val_dataset)

In [None]:
# ordered=True keeps the element order deterministic; the submission code
# iterates this dataset more than once and needs the passes to line up.
test_dataset = get_test_dataset(test_files, ordered = True)

In [None]:
def show_batch_predictions(image_batch):
    """Plot up to 25 images of a batch, titled with the predicted class names.

    Args:
        image_batch: a batch whose first element (``image_batch[0]``) holds
            the images; any further elements are ignored.

    Uses the module-level ``model`` for the predictions.
    """
    images = image_batch[0]
    # The batch can hold fewer than 25 images; never index past its end.
    n_shown = min(25, int(images.shape[0]))
    # One predict call for all displayed images instead of one call per image.
    probabilities = model.predict(images[:n_shown])
    predicted = np.argmax(probabilities, axis=-1)

    plt.figure(figsize=(20, 12))
    for n in range(n_shown):
        plt.subplot(5, 5, n + 1)
        plt.imshow(images[n])
        plt.title(CLASSES[predicted[n]])
        plt.axis("off")


# First (images, ids) batch of the test dataset. Bound to its own name so the
# `image_batch` used by the validation preview is not overwritten with a
# tuple, which would break a re-run of that earlier cell.
test_batch = next(iter(test_dataset))

show_batch_predictions(test_batch)

### **Submission.**

In [None]:
# Load the sample submission file that ships with the competition input.
sample = pd.read_csv("../input/tpu-getting-started/sample_submission.csv")

In [None]:
# Show the sample submission as the cell output.
sample

In [None]:
print("Making Predictions....")

# test_dataset yields (image, id) pairs; the model only consumes the images.
test_images_ds = test_dataset.map(lambda image, idnum : image)
prob = model.predict(test_images_ds)
pred = np.argmax(prob, axis = -1)
print(pred)

print("Generating CSV file....")

# Second pass over the same dataset to collect the ids. Collect them in one
# batch sized from NUM_TEST_IMAGES (derived from the file names) instead of a
# hard-coded count.
test_ids_ds = test_dataset.map(lambda image, idnum : idnum).unbatch()
test_ids = next(iter(test_ids_ds.batch(int(NUM_TEST_IMAGES)))).numpy().astype("U")

# Ids and predictions come from two separate passes; refuse to write a file
# whose columns do not line up.
if len(test_ids) != len(pred):
    raise ValueError(
        f"Got {len(test_ids)} test ids but {len(pred)} predictions"
    )

print("Saving CSV....")

np.savetxt(
    "submission.csv",
    np.rec.fromarrays([test_ids, pred]),
    fmt = ['%s', '%d'],
    delimiter = ",",
    header = "id,label",
    comments="",
)
print("Completed!")