# EDA on Flower Classification with TPU Competition Dataset

# I. Goals

Classify images of flowers into 104 classes. This is a classical image classification problem; the difficulty lies in distinguishing flowers that can be very similar in shape and color. The images come from 5 public datasets.


* There appears to be a label hierarchy (a flower-type hierarchy): some classes are very narrow, containing only a particular sub-type of flower (e.g. pink primroses), while other classes contain many sub-types (e.g. wild roses).

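* Metric: Macro-F1. F1 is computed per class and the scores are averaged with equal weight, so class frequency does not enter the average: a rare class counts as much as a common one.
* Performance is measured on the public test set; there is no hidden set. Be careful not to overfit to it.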
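Macro-F1 is the unweighted mean of the per-class F1 scores. The minimal pure-Python sketch below (`macro_f1` is a hypothetical helper, not competition code) shows why always predicting the majority class scores poorly on this metric even when accuracy looks fine:

```python
def macro_f1(y_true, y_pred):
    """Unweighted mean of the per-class F1 scores over all classes seen in y_true or y_pred."""
    classes = sorted(set(y_true) | set(y_pred))
    f1_scores = []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        f1_scores.append(f1)
    # every class contributes equally, regardless of how many images it has
    return sum(f1_scores) / len(f1_scores)


# 8 images of class 0, 2 of class 1; the "model" always predicts class 0
y_true = [0] * 8 + [1] * 2
y_pred = [0] * 10
print(macro_f1(y_true, y_pred))  # accuracy is 0.8, macro-F1 is 4/9 (about 0.444)
```

`sklearn.metrics.f1_score(y_true, y_pred, average='macro')` computes the same quantity.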

# II. Data Extraction

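* The data is available locally in Kaggle and also in a Google Cloud Storage (GCS) bucket. See below.
* The data is in TFRecord format.

* The same data appears to be provided in several resolutions (192x192, 224x224, 331x331 and 512x512).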



In [None]:
! ls ../input/tpu-getting-started


In [None]:
import tensorflow as tf
print(tf.__version__)
import pandas as pd
import seaborn as sns

In [None]:
from kaggle_datasets import KaggleDatasets

GCS_DS_PATH = KaggleDatasets().get_gcs_path('tpu-getting-started')
print(GCS_DS_PATH) # what do gcs paths look like?

Class names taken from https://www.kaggle.com/ryanholbrook/create-your-first-submission

In [None]:
class_names = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 103
len(class_names)

# III. Meet & Greet Data

For the purposes of this EDA, focus partly on the 512x512 images.

## Categories of Flowers

In [None]:
class_names

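Several classes belong to the same flower group (e.g. several kinds of lily or rose), so group the class names by a shared keyword: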

In [None]:
categories = ['lily', 'rose', 'iris', 'tulip', 'daisy', 'poppy']

In [None]:
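# Assign each class a group id: classes whose name contains one of the category
# keywords share that category's id; every other class gets a new id of its own.
category_id_map = {name: i for i, name in enumerate(categories)}
id_count = max(category_id_map.values())
ids = []
for name in class_names:
    for cat in categories:
        if cat in name.split():  # whole-word match, e.g. 'rose' in 'wild rose'
            ids.append(category_id_map[cat])
            break
    else:  # no keyword matched: open a new singleton group
        id_count += 1
        ids.append(id_count)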

In [None]:
class_groups = pd.DataFrame(zip(class_names, ids), columns=['names', 'id'])
class_groups.groupby('id')['names'].apply(list).head(len(categories)).values

> Some classes belong to the same flower type/category, so expect classification errors among them.

In [None]:
class_name_mapping = {i:name for i, name in enumerate(class_names)}


## Images



In [None]:
IMAGE_SIZE = [512, 512]

data_root = "../input/tpu-getting-started"
#data_root = GCS_DS_PATH

data_path = data_root + '/tfrecords-jpeg-512x512'


train_512 = tf.io.gfile.glob(data_path + '/train/*.tfrec')
val_512 = tf.io.gfile.glob(data_path + '/val/*.tfrec')
test_512 = tf.io.gfile.glob(data_path + '/test/*.tfrec') 
all_512 = [train_512, val_512, test_512]

Each split (train, val, test) consists of 16 TFRecord files:

In [None]:
[len(dset) for dset in all_512]

In [None]:
train_512

# IV. Univariate Analysis

In [None]:
len(class_names)

In [None]:
from tensorflow.data import Dataset, TFRecordDataset

In [None]:
record_sample = TFRecordDataset(train_512)

In [None]:
num_elements = 0
for element in record_sample:
    num_elements += 1
num_elements
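Iterating over every record is slow. The Kaggle starter notebooks for this competition instead read the record count from the shard file names. The sketch below assumes names ending in `-<count>.tfrec` (a name such as `00-512x512-798.tfrec` is used here only as an illustration); check this against the `ls` output before relying on it. `count_data_items` is a hypothetical helper name.

```python
import re

# Assumed naming scheme: "<shard>-<resolution>-<record count>.tfrec"
COUNT_PATTERN = re.compile(r"-([0-9]+)\.tfrec$")


def count_data_items(filenames):
    """Sum the per-file record counts encoded in the TFRecord file names."""
    total = 0
    for filename in filenames:
        match = COUNT_PATTERN.search(filename)
        if match is None:
            raise ValueError(f"no record count found in file name: {filename}")
        total += int(match.group(1))
    return total
```

If the naming assumption holds, `count_data_items(train_512)` should equal `num_elements` from the loop above.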

Image loading pipeline, adapted from:

* https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle
* https://www.kaggle.com/ryanholbrook/create-your-first-submission

In [None]:

def decode_image(image_data):
    # images are encoded as jpg
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.cast(image, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    #image = tf.reshape(image, [*IMAGE_SIZE, 3]) # explicit size needed for TPU
    return image

def read_labeled_tfrecord(example):
    LABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
    }
    example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    label = tf.cast(example['class'], tf.int32)
    # one-hot labels are not needed for this EDA:
    # one_hot_encoded = tf.one_hot(indices=label, depth=len(class_names))
    return image, label # one (image, label) pair per record

def read_unlabeled_tfrecord(example):
    UNLABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "id": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
        # class is missing: this competition's challenge is to predict flower classes for the test dataset
    }
    example = tf.io.parse_single_example(example, UNLABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    idnum = example['id']
    return image, idnum # one (image, id) pair per record

def load_dataset(filenames, labeled=True, ordered=False):
    # Read from TFRecords, interleaving reads from multiple files at once.
    options = tf.data.Options()
    if not ordered:
        # Setting this to False would let elements arrive in whatever order they are read,
        # which is faster. It is kept True so that the order is deterministic:
        # display_batch_by_class below iterates the same files twice and matches records
        # by position, which only works if both passes yield the same order.
        options.experimental_deterministic = True

    AUTO = tf.data.experimental.AUTOTUNE
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) # automatically interleaves reads from multiple files
    dataset = dataset.with_options(options)
    dataset = dataset.map(read_labeled_tfrecord if labeled else read_unlabeled_tfrecord, num_parallel_calls=AUTO)
    # returns a dataset of (image, label) pairs if labeled=True or (image, id) pairs if labeled=False
    return dataset
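For training, the dataset returned by `load_dataset` would typically be shuffled, batched and prefetched (see the `Dataset.shuffle` reference above). A minimal sketch; `prepare_for_training` and its default values are assumptions, not taken from the referenced notebooks. `drop_remainder=True` keeps every batch the same size, which TPUs need; on TPU the images would also need the static shape set by the commented-out `tf.reshape` in `decode_image`.

```python
import tensorflow as tf

AUTO = tf.data.experimental.AUTOTUNE


def prepare_for_training(dataset, batch_size=16, shuffle_buffer=2048, seed=None):
    """Shuffle, batch and prefetch a dataset of (image, label) pairs."""
    dataset = dataset.shuffle(shuffle_buffer, seed=seed)     # random order drawn from a buffer of elements
    dataset = dataset.batch(batch_size, drop_remainder=True)  # fixed batch size, last partial batch dropped
    return dataset.prefetch(AUTO)                             # prepare the next batch while the current one is used


# Tiny in-memory stand-in for ds_train_512: 10 fake 4x4 RGB images with labels 0..9
toy = tf.data.Dataset.from_tensor_slices((tf.zeros([10, 4, 4, 3]), tf.range(10)))
for images, labels in prepare_for_training(toy, batch_size=4, seed=0):
    print(images.shape, labels.numpy())
```

With 10 elements and `batch_size=4`, only two full batches are produced; the remaining 2 elements are dropped.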

In [None]:
ds_train_512 = load_dataset(train_512, labeled=True)
ds_val_512 = load_dataset(val_512, labeled=True)
ds_test_512 = load_dataset(test_512, labeled=False)


In [None]:
for b, l in ds_train_512:
    break

In [None]:
l

In [None]:
def get_ds_size(dataset, dtype='train'):
    """Count the elements of a dataset; for labeled splits also return the class names as a Series."""
    num_elements = 0
    labels = []
    for img, label in dataset:
        num_elements += 1
        labels.append(label.numpy())
    print(f"{dtype}: number of images: {num_elements}")
    if dtype != 'test':  # test "labels" are image ids, not classes
        return pd.Series([class_name_mapping[label] for label in labels])

ds_train_512_labels = get_ds_size(ds_train_512, dtype='train')
ds_val_512_labels = get_ds_size(ds_val_512, dtype='val')
get_ds_size(ds_test_512, dtype='test')

In [None]:
total = 12753 + 3712 + 7382  # train + val + test image counts from the cell above
12753/total, 3712/total, 7382/total

> Train, validation and test make up about 53%, 16% and 31% of all images; the test set is about 2x the size of the validation set.

## Class Distribution

In [None]:
def get_class_distr(ds_labels):
    ds_dist = pd.concat([ds_labels.value_counts(),
                         ds_labels.value_counts(normalize=True)], axis=1)
    ds_dist.columns = ['counts', 'fraction']
    return ds_dist
ds_train_512_labeldist = get_class_distr(ds_train_512_labels)
ds_val_512_labeldist = get_class_distr(ds_val_512_labels)

In [None]:
ds_train_512_labeldist

In [None]:
ds_train_512_labeldist.head(20)

Problem classes: 27 classes have fewer than 10 images in the validation set!

In [None]:
ds_val_512_labeldist.tail(28)

In [None]:
problem_classes = ds_val_512_labeldist[ds_val_512_labeldist['counts'] < 10].index
problem_classes

In [None]:
import numpy as np

In [None]:
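# Compare the class fractions of train and validation (the original loop compared val with itself)
train_frac = ds_train_512_labeldist['fraction']
val_frac = ds_val_512_labeldist['fraction'].reindex(train_frac.index)  # NaN if a class is missing in val
for key in train_frac.index:
    # atol=1e-3 (0.1 percentage points) is a chosen tolerance
    if not np.isclose(train_frac[key], val_frac[key], atol=1e-3):
        print(f"{key} not close: train={train_frac[key]:.4f}, val={val_frac[key]:.4f}")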
    


> * Classes are highly imbalanced: some have only 18 images in train and 5 images in the validation set!
> * The majority class makes up only 6% of the data.
> * The class distribution in train and validation is the same, as it should be.
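One common way to counter such imbalance, not used in this notebook so far, is to weight the loss per class. A sketch of the "balanced" heuristic `n_samples / (n_classes * count_c)`, the same formula scikit-learn uses for `compute_class_weight(class_weight='balanced')`; `balanced_class_weights` is a hypothetical helper name:

```python
import pandas as pd


def balanced_class_weights(labels):
    """Map each class to n_samples / (n_classes * count of that class)."""
    counts = pd.Series(labels).value_counts()
    n_samples = counts.sum()
    n_classes = len(counts)
    return (n_samples / (n_classes * counts)).to_dict()


# 3 samples of class 0, 1 sample of class 1: the rare class gets 3x the weight
print(balanced_class_weights([0, 0, 0, 1]))
```

Keras' `Model.fit` accepts such a dict through its `class_weight` argument, keyed by integer class index, so integer labels (not the class names stored in `ds_train_512_labels`) would be passed in that case.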

In [None]:
ds_train_512_labeldist['counts'].plot(kind='hist')

In [None]:
sns.boxplot(x=ds_train_512_labeldist['counts'])

In [None]:
ds_train_512_labeldist['counts'].median()

> 9 classes have a large number of images (boxplot outliers, above ~280 images), while the median is 88 images per class.
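The outlier threshold in the boxplot comes from the whisker rule: with the default `whis=1.5`, seaborn (like matplotlib) draws points beyond Q3 + 1.5 * IQR as outliers, where IQR = Q3 - Q1. A small sketch to list the outlier classes explicitly; `boxplot_upper_fence` is a hypothetical helper, and pandas' default linear quantile interpolation should match the plotted quartiles:

```python
import pandas as pd


def boxplot_upper_fence(counts, whis=1.5):
    """Upper outlier threshold Q3 + whis * IQR of a Series of counts."""
    q1, q3 = counts.quantile(0.25), counts.quantile(0.75)
    return q3 + whis * (q3 - q1)


counts = pd.Series([1, 2, 3, 4, 100])
fence = boxplot_upper_fence(counts)
print(fence, counts[counts > fence].tolist())  # 7.0 [100]
```

Applied to the notebook: `ds_train_512_labeldist[ds_train_512_labeldist['counts'] > boxplot_upper_fence(ds_train_512_labeldist['counts'])]`.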

## Images Visual Analysis

In [None]:
from matplotlib import pyplot as plt
import math

def batch_to_numpy_images_and_labels(data):
    images, labels = data
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
    if numpy_labels.dtype == object: # binary string in this case,
                                     # these are image ID strings
        numpy_labels = [None for _ in enumerate(numpy_images)]
    # If no labels, only image IDs, return None for labels (this is
    # the case for test data)
    return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
    if correct_label is None:
        return class_names[label], True
    correct = (label == correct_label)
    return "{} [{}{}{}]".format(class_names[label], 'OK' if correct else 'NO', u"\u2192" if not correct else '',
                                class_names[correct_label] if not correct else ''), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        plt.title(title, fontsize=int(titlesize) if not red else int(titlesize/1.2), color='red' if red else 'black', fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
    return (subplot[0], subplot[1], subplot[2]+1)
    
def display_batch_of_images(databatch, predictions=None, FIGSIZE=13):
    """This will work with:
    display_batch_of_images(images)
    display_batch_of_images(images, predictions)
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None for _ in enumerate(images)]
        
    # auto-squaring: rows ~ sqrt(n), and the extra column guarantees
    # rows * cols > n, so every image is shown (some cells may stay empty)
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows + 1
        
    # size and spacing
    #FIGSIZE = 13.0
    SPACING = 0.1
    subplot=(rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))
    
    # display
    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        title = '' if label is None else class_names[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)
    
    #layout
    plt.tight_layout()
    if label is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

In [None]:
ds_train_512

### Random sample of flowers

In [None]:
ds_train_512 = load_dataset(train_512, labeled=True)

ds_train_512 = ds_train_512.batch(10)
bs = next(iter(ds_train_512))

display_batch_of_images(bs)

Findings

* Background details are visible.
* Some images show slightly blurry flowers.
* Images appear to be taken outdoors: gardens, nature, etc.
* Sometimes an image shows one flower, sometimes several.
* The camera angle on the plant(s) seems fairly arbitrary.


In [None]:
ds_train_512 = load_dataset(train_512, labeled=True)

ds_train_512 = ds_train_512.batch(40)
bs = next(iter(ds_train_512))

display_batch_of_images(bs)

Analysis:
* Indoor plants are also possible!
* Close-up shots showing only part of the flower exist.
* It is not clear whether flowers are shown in different blooming stages.
* Some flowers have insects on them.
* Flowers within a category can differ significantly, e.g. geranium vs. wild geranium.

## Flowers by Class

In [None]:
#class_name_mapping

In [None]:
inverse_class_name_mapping = {class_name_mapping[i]: i for i in class_name_mapping}

In [None]:
from tqdm import tqdm
from numpy.random import default_rng


In [None]:
def display_batch_by_class(files, name='iris', top_n=10, FIGSIZE=13):
    """Display up to top_n randomly chosen images of class `name` from the TFRecord `files`."""
    class_idx = inverse_class_name_mapping[name]
    print(class_idx)

    # First pass: positions of the class images in the dataset
    sample_idx = []
    ds = load_dataset(files, labeled=True).batch(1)
    for i, (img, label) in tqdm(enumerate(ds)):
        if label.numpy()[0] == class_idx:
            sample_idx.append(i)

    # Cap top_n by the number of images actually found in `files`
    # (a lookup in the validation counts would be wrong for the training files)
    if top_n > len(sample_idx):
        top_n = len(sample_idx)
        print(f"warning, class has only {len(sample_idx)} images. Show all images for class")
    if top_n == 0:
        print(f"no images of class '{name}' found")
        return

    # Choose top_n positions at random (fixed seed for reproducibility)
    rng = default_rng(42)
    sample_idx_shuffled = sample_idx.copy()
    rng.shuffle(sample_idx_shuffled)
    top_n_sample = set(sample_idx_shuffled[:top_n])

    # Second pass: collect the chosen images. This relies on load_dataset
    # returning the records in the same order on every pass.
    ds = load_dataset(files, labeled=True).batch(1)
    images_class = []
    for i, (img, label) in tqdm(enumerate(ds)):
        if i in top_n_sample:
            images_class.append(tf.squeeze(img, axis=0))

    batch = tf.stack(images_class), tf.fill([len(images_class)], class_idx)

    display_batch_of_images(batch, FIGSIZE=FIGSIZE)
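`display_batch_by_class` reads the files twice: once to find the positions of the class images and once to fetch the chosen ones. A swapped-in alternative, reservoir sampling (Algorithm R), draws a uniform random sample of k items in a single pass while holding only k items in memory. A sketch with a hypothetical helper name:

```python
import random


def reservoir_sample(stream, k, seed=42):
    """Uniformly sample up to k items from an iterable of unknown length in one pass."""
    rng = random.Random(seed)
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)   # fill the reservoir with the first k items
        else:
            j = rng.randint(0, i)    # uniform over 0..i (inclusive)
            if j < k:
                reservoir[j] = item  # item i is kept with probability k / (i + 1)
    return reservoir


print(reservoir_sample(range(1000), k=5))
```

In the notebook the stream could be a generator over the class images of `load_dataset(files).batch(1)`, filtered by `label.numpy()[0] == class_idx`; the images picked this way differ from those chosen by the seeded `default_rng` shuffle.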

#### Most Common Class: Iris

In [None]:
display_batch_by_class(train_512, name = 'iris', top_n= 10)

### Problem Classes

#### Siam Tulip: one of the least common classes, with only 5 images in the validation set

In [None]:
display_batch_by_class(train_512, name = 'siam tulip', top_n= 20)

* **Danger**: Is there something common in the backgrounds of these images that could mislead the model into using the wrong features for identification? This class is especially prone to this because it has so few images.

In [None]:
display_batch_by_class(train_512, name = 'moon orchid', top_n= 20)

### Look at all problem classes

In [None]:
problem_classes

In [None]:
for class_name in problem_classes:
    display_batch_by_class(train_512, name = class_name, top_n= 10, FIGSIZE=6)

Analysis for problem classes:
* Hard-leaved pocket orchid: similar shots, taken from the front, showing a single plant.
* There can still be large variation in color and shape between images of the same class.
* Some classes have only a few images, with similar backgrounds.

## Impact of Resolution

* How do the image attributes change when the resolution decreases?
* Which features are no longer visible?

In [None]:
files_all_res = [
    tf.io.gfile.glob(data_root + '/tfrecords-jpeg-512x512' + '/train/*.tfrec'),
    tf.io.gfile.glob(data_root + '/tfrecords-jpeg-331x331' + '/train/*.tfrec'),
    tf.io.gfile.glob(data_root + '/tfrecords-jpeg-224x224' + '/train/*.tfrec'),
    tf.io.gfile.glob(data_root + '/tfrecords-jpeg-192x192' + '/train/*.tfrec')
]
resolutions = [512, 331, 224, 192]


Compare the impact of resolution by looking at the same flower images at each resolution.
Unfortunately, the images are not stored in the same order across resolutions,
and no unique flower id exists to link images of different resolutions.
Hence I pick one class for which I can plot all flower images.
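A swapped-in alternative that sidesteps the ordering problem is to downsample the 512x512 images yourself and measure what is lost. This only approximates the provided lower-resolution files, which may have been produced with a different resampling method and JPEG re-encoding. The sketch below (hypothetical helper `resolution_loss`) resizes a float image in [0, 1] down to each resolution and back up, and reports the mean absolute pixel difference as a rough proxy for lost detail:

```python
import tensorflow as tf


def resolution_loss(image, resolutions=(331, 224, 192)):
    """Mean absolute difference between a float image and its down- then up-sampled copy."""
    original_size = tf.shape(image)[:2]
    losses = {}
    for res in resolutions:
        small = tf.image.resize(image, [res, res])        # bilinear by default
        restored = tf.image.resize(small, original_size)  # back to the original grid
        losses[res] = float(tf.reduce_mean(tf.abs(image - restored)))
    return losses


print(resolution_loss(tf.random.uniform([512, 512, 3])))
```

On real data, `image, label = next(iter(load_dataset(train_512)))` yields a suitable float image from `decode_image`.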

In [None]:
for res, files in zip(resolutions, files_all_res):
    print(f"Image Resolution: {res}")
    display_batch_by_class(files, name='moon orchid', top_n=20)


### Augmentation Strategy

* Some images are blurry: augment with random blur
* Images are zoomed in and out: augment with random zoom in/out
* Images are brighter and darker: augment with random brightness changes
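A minimal sketch of these three augmentations with `tf.image` and `tf.nn` ops, assuming float images in [0, 1] as produced by `decode_image`. The probability, crop fraction and brightness delta are arbitrary assumptions. Blur is approximated by a 3x3 box filter (average pooling), and only zooming in (random crop, then resize back) is shown, since zooming out would additionally need padding.

```python
import tensorflow as tf


def random_blur(image, p=0.3):
    """With probability p, blur with a 3x3 box filter (average pooling)."""
    blurred = tf.nn.avg_pool2d(image[tf.newaxis], ksize=3, strides=1, padding='SAME')[0]
    return tf.cond(tf.random.uniform([]) < p, lambda: blurred, lambda: image)


def random_zoom_in(image, min_frac=0.7):
    """Crop a random window covering min_frac..1 of each side and resize it back."""
    size = tf.shape(image)[:2]
    frac = tf.random.uniform([], min_frac, 1.0)
    crop_size = tf.cast(tf.cast(size, tf.float32) * frac, tf.int32)
    crop = tf.image.random_crop(image, tf.concat([crop_size, [3]], axis=0))
    return tf.image.resize(crop, size)


def augment(image, label):
    image = random_blur(image)
    image = random_zoom_in(image)
    image = tf.image.random_brightness(image, max_delta=0.2)  # brighter or darker
    image = tf.clip_by_value(image, 0.0, 1.0)                  # stay in the [0, 1] range
    return image, label
```

`load_dataset(train_512).map(augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)` would apply it to the training split only; validation and test images would stay unaugmented.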