In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import glob

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.applications import VGG19, Xception

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
import os

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Number of examples per batch used by the tf.data pipelines below.
batch_size = 256
# Height and width every decoded image is reshaped to (see decode_image);
# matches the 512x512 TFRecord folder that is read below.
IMAGE_SIZE = [512, 512]
# Upper bound on training epochs; the EarlyStopping callback may end training sooner.
num_epochs = 65

## Set strategy

In [None]:
# Set distribution strategy to use TPUs
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
print('Found connected TPU: ', resolver.cluster_spec().as_dict()['worker'])

# Connect to the resolved cluster and initialize the TPU system before the
# strategy object is created from the same resolver.
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
# `strategy` is used further down (strategy.scope(), num_replicas_in_sync) when training.
strategy = tf.distribute.TPUStrategy(resolver)

In [None]:
# Sentinel that lets tf.data pick the degree of parallelism at runtime;
# passed as num_parallel_reads / num_parallel_calls in the dataset cells below.
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
# TPU requires everything to be in a GCS bucket to work.

from kaggle_datasets import KaggleDatasets

# GCS location of the 'tpu-getting-started' dataset; the TFRecord file globs
# in the next section are built on top of this path.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('tpu-getting-started')
print(GCS_DS_PATH) # what do gcs paths look like?

## Read Images 

In [None]:
# Collect the TFRecord shard paths of the train / val / test splits (in that order).
train_files, val_files, test_files = map(
    tf.io.gfile.glob,
    (
        f'{GCS_DS_PATH}/tfrecords-jpeg-512x512/train/*',
        f'{GCS_DS_PATH}/tfrecords-jpeg-512x512/val/*',
        f'{GCS_DS_PATH}/tfrecords-jpeg-512x512/test/*',
    ),
)

In [None]:
# Each shard's file name ends with its sample count: 00-224x224-798.tfrec -> 798 samples.
def get_num_samples(file_list):
    """Return the total number of samples across the given .tfrec file names.

    For every name, the count is the integer found between the last '-' and
    the first '.tfrec'; the per-file counts are summed. An empty list gives 0.
    """
    return sum(
        int(path.split('.tfrec')[0].rsplit('-', 1)[1])
        for path in file_list
    )

### Sizes of each dataset

In [None]:
# Sample counts per split, derived from the shard file names.
train_size, val_size, test_size = map(
    get_num_samples, (train_files, val_files, test_files)
)

print(
    f"Train dataset size: {train_size}",
    f"Validation dataset size: {val_size}",
    f"Test dataset size: {test_size}",
    sep="\n",
)

In [None]:
# Functions and the CLASSES list in this cell were taken from https://www.kaggle.com/code/ryanholbrook/create-your-first-submission/notebook
#
# CLASSES lists the flower class names; the trailing comment on each row gives
# the index range of that row. Its length sets the width of the models' output layer.
# NOTE(review): 'cyclamen ' and 'hippeastrum ' carry a trailing space inside the
# string literal — confirm whether that is intended before using the names as keys.

CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 102

def data_augment(image, label):
    """Randomly mirror the image left-to-right; the label is passed through unchanged.

    Written to be applied with `Dataset.map`, so the flip is performed inside
    the tf.data input pipeline.
    """
    flipped = tf.image.random_flip_left_right(image)
    return flipped, label

def decode_image(image_data):
    """Decode JPEG bytes into a float32 RGB tensor of shape [*IMAGE_SIZE, 3] with values in [0, 1]."""
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(pixels, tf.float32) / 255.0  # rescale to floats in the [0, 1] range
    # TPU needs an explicit static shape, so pin it here.
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled example into an (image, label) pair.

    The record is expected to hold an "image" bytestring and an int64 "class".
    Returns the decoded image (via `decode_image`) and the class cast to int32.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),   # shape [] means single element
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), tf.cast(parsed['class'], tf.int32)

def read_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled example into an (image, id) pair.

    The record is expected to hold an "image" bytestring and a string "id";
    there is no "class" feature because the class is what has to be predicted.
    Returns the decoded image (via `decode_image`) and the raw id string tensor.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "id": tf.io.FixedLenFeature([], tf.string),     # shape [] means single element
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), parsed['id']

## Create TF data objects

In [None]:
# One TFRecordDataset of serialized examples per split; AUTO lets tf.data
# decide how many shard files are read in parallel.
train_dataset = tf.data.TFRecordDataset(train_files, num_parallel_reads=AUTO)
test_dataset = tf.data.TFRecordDataset(test_files, num_parallel_reads=AUTO)
val_dataset = tf.data.TFRecordDataset(val_files, num_parallel_reads=AUTO)

In [None]:
# Build the training pipeline from the shard list in one chain, so re-running
# this cell recreates it instead of stacking more transformations onto the
# previous `train_dataset`.
train_dataset = (
    tf.data.TFRecordDataset(train_files, num_parallel_reads=AUTO)
    .map(read_labeled_tfrecord, num_parallel_calls=AUTO)  # parse + decode in parallel
    .map(data_augment, num_parallel_calls=AUTO)
    .shuffle(buffer_size=batch_size)
    .repeat()  # endless stream; epoch length is set by steps_per_epoch in fit()
    .batch(batch_size=batch_size)
    # After batch() the prefetch buffer counts whole batches, so let tf.data
    # size it instead of allowing up to `batch_size` batches to sit in memory.
    .prefetch(buffer_size=AUTO)
)

In [None]:
# Build the validation pipeline from the shard list in one chain, so re-running
# this cell recreates it instead of re-mapping an already parsed dataset.
val_dataset = (
    tf.data.TFRecordDataset(val_files, num_parallel_reads=AUTO)
    .map(read_labeled_tfrecord, num_parallel_calls=AUTO)  # parse + decode in parallel
    .batch(batch_size=batch_size)
    # After batch() the prefetch buffer counts whole batches; let tf.data size it.
    .prefetch(buffer_size=AUTO)
)

In [None]:
# Build the test pipeline from the shard list in one chain, so re-running this
# cell recreates it instead of re-mapping an already parsed dataset.
# The dataset is left unbatched: each element is a single (image, id) pair.
test_dataset = (
    tf.data.TFRecordDataset(test_files, num_parallel_reads=AUTO)
    .map(read_unlabeled_tfrecord, num_parallel_calls=AUTO)  # parse + decode in parallel
    .prefetch(buffer_size=AUTO)  # let tf.data size the buffer instead of pinning 256 decoded images
)

## Build model

In [None]:
# VGG19 model
def create_vgg19_model():
    """Build and compile a flower classifier on top of a frozen ImageNet VGG19.

    The model takes float RGB images of shape (*IMAGE_SIZE, 3) with values in
    [0, 1] (the output of `decode_image`). Because the pretrained backbone was
    trained on VGG-preprocessed pixels, the inputs are rescaled to 0-255 and
    passed through `vgg19.preprocess_input` inside the model.

    Returns:
        A compiled `tf.keras.Model` whose output is a softmax over
        `len(CLASSES)` classes, using Adam and sparse categorical
        cross-entropy, and reporting sparse categorical accuracy.
    """
    input_shape = (*IMAGE_SIZE, 3)  # same size decode_image reshapes to

    base_model = VGG19(
        weights='imagenet',
        input_shape=input_shape,
        include_top=False
    )
    base_model.trainable = False  # use the backbone as a fixed feature extractor

    inputs = tf.keras.Input(shape=input_shape)
    # Bring the [0, 1] pixels into the input convention the ImageNet weights expect.
    x = layers.Lambda(
        lambda t: tf.keras.applications.vgg19.preprocess_input(t * 255.0),
        name='vgg19_preprocess')(inputs)
    x = base_model(x, training=False)
    # Global pooling already yields a flat (batch, channels) tensor, so no Flatten is needed.
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(4112, activation='relu')(x)
    x = layers.Dense(2056, activation='relu')(x)
    x = layers.Dense(1024, activation='relu')(x)
    x = layers.Dense(512, activation='relu')(x)
    outputs = layers.Dense(len(CLASSES), activation='softmax')(x)
    model = tf.keras.Model(inputs, outputs)

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_accuracy'])

    return model

In [None]:
# Xception model
def create_xception_model():
    """Build and compile a flower classifier on top of a frozen ImageNet Xception.

    The model takes float RGB images of shape (*IMAGE_SIZE, 3) with values in
    [0, 1] (the output of `decode_image`). Because the pretrained backbone
    expects Xception-preprocessed pixels, the inputs are rescaled to 0-255 and
    passed through `xception.preprocess_input` inside the model.

    Returns:
        A compiled `tf.keras.Model` whose output is a softmax over
        `len(CLASSES)` classes, using Adam and sparse categorical
        cross-entropy, and reporting sparse categorical accuracy.
    """
    input_shape = (*IMAGE_SIZE, 3)  # same size decode_image reshapes to

    base_model = Xception(
        weights='imagenet',  # Load weights pre-trained on ImageNet.
        input_shape=input_shape,
        include_top=False)
    base_model.trainable = False  # use the backbone as a fixed feature extractor

    inputs = tf.keras.Input(shape=input_shape)
    # Bring the [0, 1] pixels into the input convention the ImageNet weights expect.
    x = layers.Lambda(
        lambda t: tf.keras.applications.xception.preprocess_input(t * 255.0),
        name='xception_preprocess')(inputs)
    x = base_model(x, training=False)
    # Global pooling already yields a flat (batch, channels) tensor, so no Flatten is needed.
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(4112, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(4112, activation='relu')(x)
    outputs = layers.Dense(len(CLASSES), activation='softmax')(x)
    model = tf.keras.Model(inputs, outputs)

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_accuracy'])

    return model

## Callbacks

In [None]:
# Save the weights after every epoch in which validation accuracy improves.
# `save_freq='epoch'` already means "check every epoch", so the deprecated
# `period` argument is not passed.
cp_callback = ModelCheckpoint(filepath='flower_model_xception.hdf5',
                              monitor='val_sparse_categorical_accuracy',
                              mode='max',  # higher accuracy is better
                              save_freq='epoch', verbose=1,
                              save_best_only=True, save_weights_only=True)

# Stop once validation accuracy has not improved for 5 consecutive epochs.
early_stopping = EarlyStopping(monitor='val_sparse_categorical_accuracy',
                               mode='max',  # higher accuracy is better
                               verbose=1, patience=5)

## Train

In [None]:
# Derive the epoch length without overwriting the global `batch_size`, so that
# re-running this cell always produces the same number of steps.
global_batch_size = batch_size * strategy.num_replicas_in_sync
# NOTE(review): train_dataset is batched with `batch_size`, not with
# `global_batch_size`, so this step count does not correspond to whole passes
# over the training set — confirm the intended epoch length.
steps_per_epoch = max(1, (train_size // global_batch_size) * 2)

# The model (variables, optimizer, metrics) has to be created under the strategy scope.
with strategy.scope():
    model = create_xception_model()

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=num_epochs,
    steps_per_epoch=steps_per_epoch,
    callbacks=[cp_callback, early_stopping])

In [None]:
# Training vs. validation loss, one point per completed epoch.
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='train_loss')
ax.plot(history.history['val_loss'], label='val_loss')
ax.set(title='Loss per epoch', xlabel='epoch', ylabel='sparse categorical cross-entropy')
ax.legend()
plt.show()

In [None]:
# Training vs. validation accuracy, one point per completed epoch.
fig, ax = plt.subplots()
ax.plot(history.history['sparse_categorical_accuracy'], label='train_accuracy')
ax.plot(history.history['val_sparse_categorical_accuracy'], label='val_accuracy')
ax.set(title='Accuracy per epoch', xlabel='epoch', ylabel='sparse categorical accuracy')
ax.legend()
plt.show()

In [None]:
# Build a fresh Xception model and load the weights from the file the
# ModelCheckpoint callback writes to (same name as its `filepath`).
model = create_xception_model()
model.load_weights('flower_model_xception.hdf5')

In [None]:
# Accumulator for the submission: parallel lists of test image ids and predicted class indices.
results = {'id': [], 'label': []}

## Test

In [None]:
def predict(element):
    """Classify one test element and append the outcome to the global `results`.

    Args:
        element: an (image, id) pair; `element[0]` is the image tensor and
            `element[1]` a scalar string tensor holding the image id.

    Side effects:
        Appends the decoded id to `results['id']` and the index of the
        highest-scoring class to `results['label']`. Uses the global `model`.
    """
    image, raw_id = element[0], element[1]
    id_ = tf.keras.backend.get_value(raw_id).decode("utf-8")
    # Add a leading batch axis of size 1, predict, and take that single row of scores.
    probabilities = model.predict(np.array([image]))[0]
    results['id'].append(id_)
    # argmax returns the first index of the maximum.
    results['label'].append(int(np.argmax(probabilities)))

In [None]:
# Start from an empty accumulator so re-running this cell cannot append duplicate rows.
results = {'id': [], 'label': []}

predict_batch_size = 64  # images per forward pass, instead of one predict() call per image
report_every = 500       # print progress roughly every this many images

count = 0
next_report = report_every
for images, ids in test_dataset.batch(predict_batch_size):
    probabilities = model.predict_on_batch(images)
    labels = np.argmax(probabilities, axis=-1)  # index of the highest-scoring class per image
    results['id'].extend(raw_id.decode("utf-8") for raw_id in ids.numpy())
    results['label'].extend(int(label) for label in labels)
    count += len(labels)
    if count >= next_report:
        print(f"Finished predicting {count} images")
        next_report = (count // report_every + 1) * report_every

In [None]:
# Turn the accumulated ids / labels into a two-column table and write it out
# without the index column.
results_df = pd.DataFrame(results)
results_df.to_csv('submission.csv', index=False)