# Analysis of Model flowerclass-efficientnetv2-2


In [None]:
import math, re, os
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
print(tf.__version__)
print(tfa.__version__)

from flowerclass_read_tf_ds import get_datasets
import tensorflow_hub as hub
import pandas as pd
import math
import plotly_express as px
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import itertools

In [None]:
# Show which GPU device (if any) TensorFlow reports for this session.
tf.test.gpu_device_name()

# I. Data Loading

In [None]:
# Side length of the square model input; also passed to get_datasets as IMAGE_SIZE.
image_size: int = 224
# Batch size requested from get_datasets for the evaluation pipeline.
batch_size: int = 64

In [None]:
#%%debug (50, 480)


In [None]:
# Human-readable class names; the position in this list is the integer class id
# (it is used as `target_names` for the reports and to map names back to indices).
class_names = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 103
# Displays the number of class names (10 rows of 10 plus 4 = 104).
# NOTE: 'cyclamen ' and 'hippeastrum ' carry a trailing space; they are left as-is because
# these strings are used as report labels and lookup keys further down.
len(class_names)

# II. Model Loading: EfficientNetV2

In [None]:
# TF-Hub handle of the feature-vector module that is wrapped in hub.KerasLayer below
# (per the handle name: EfficientNetV2-S, ImageNet-21k).
effnet2_base = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_s/feature_vector/2"

In [None]:
    effnet2_tfhub = tf.keras.Sequential([
    # Explicitly define the input shape so the model can be properly
    # loaded by the TFLiteConverter
    tf.keras.layers.InputLayer(input_shape=(image_size, image_size,3)),
    hub.KerasLayer(effnet2_base, trainable=False),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(104, activation='softmax')
])
effnet2_tfhub.build((None, image_size, image_size,3,)) #This is to be used for subclassed models, which do not know at instantiation time what their inputs look like.


effnet2_tfhub.summary()

In [None]:
# Restore the weights from the checkpoint numbered `best_phase`
# (checkpoint files are named cp-<4-digit, zero-padded phase>.ckpt).
best_phase = 12
checkpoint_path = f"../input/flowerclass-efficientnetv2-2/training/cp-{best_phase:04d}.ckpt"
effnet2_tfhub.load_weights(checkpoint_path)

# III. Model Analysis

In [None]:
from sklearn.metrics import classification_report, confusion_matrix

Ensure that the validation data loader returns its elements in a fixed order.

In [None]:
# Build the input pipelines; only the validation split is evaluated below.
ds_train, ds_valid, ds_test = get_datasets(BATCH_SIZE=batch_size, IMAGE_SIZE=(image_size, image_size),
                                           RESIZE=None, tpu=False)

# Images and labels come from the same iteration step, so predictions and
# labels stay aligned element by element.
batch_pred_ids = []
batch_label_ids = []
for imgs, label in tqdm(ds_valid):
    # The dataset is already batched: run one forward pass per batch with
    # predict_on_batch instead of starting a full predict() loop for every batch.
    batch_probs = effnet2_tfhub.predict_on_batch(imgs)
    batch_pred_ids.append(batch_probs.argmax(1))
    # argmax over axis 1 assumes one-hot encoded labels of shape (batch, n_classes).
    batch_label_ids.append(label.numpy().argmax(1))

# 1-D arrays of predicted / true class ids over the whole validation set.
img_preds = np.concatenate(batch_pred_ids)
img_labels = np.concatenate(batch_label_ids)


In [None]:
# One row per validation image: predicted class id and true class id.
val_results = pd.DataFrame(dict(pred=img_preds, label=img_labels))

In [None]:
# Peek at the first rows to sanity-check the pred/label columns.
val_results.head()

# IIIa) Overall Evaluation

In [None]:
# Confusion matrix over all classes, computed from true labels vs. predictions.
confusion_matrix(val_results['label'], val_results['pred'])

In [None]:
# Plain-text per-class report; class ids are rendered with the names from class_names.
print(classification_report(val_results['label'], val_results['pred'], target_names=class_names))

In [None]:
# Same report as a DataFrame: one row per report entry (classes plus summary rows),
# with the entry name moved from the index into a 'class' column.
report_dict = classification_report(val_results['label'], val_results['pred'],
                                    target_names=class_names, output_dict=True)
class_report = pd.DataFrame(report_dict).T
class_report = class_report.assign(**{'class': class_report.index}).reset_index(drop=True)

class_report.head()

Most problematic classes (F1 score below 0.90):


> How would improving these classes raise the macro f1 score?

In [None]:
# Keep only the per-class rows: summary entries (e.g. accuracy) are not in class_names.
# Filtering by name avoids hard-coding the row position of the last class.
class_report = class_report[class_report['class'].isin(class_names)]

In [None]:
# Order classes from lowest to highest F1 and renumber the rows 0..n-1,
# so that row label k is the (k+1)-th worst class.
class_report = class_report.sort_values(by="f1-score", ignore_index=True)

In [None]:
# The lowest-F1 classes (first 9 rows of the sorted report).
class_report.head(9)

> What is wrong with the rose class? It performs badly despite having many images.

> * Even if all 8 worst-performing classes were improved to an F1 score of 1, the macro F1 would still rise by only 1%! See below.
> * 


In [None]:
# What-if: set the F1 of the 8 lowest-F1 classes to 1 and recompute the mean per-class F1.
# Work on a copy so class_report itself is not modified.
class_report_test = class_report.copy()
# .loc slices by label and includes the end label, so `:7` covers rows 0..7 (8 classes).
class_report_test.loc[:7, 'f1-score'] = 1
class_report_test['f1-score'].mean()

In [None]:
# What-if: set the F1 of the 20 lowest-F1 classes to 1.
# .loc slices by label and includes the end label, so `:19` covers rows 0..19 (20 classes).
class_report_test.loc[:19, 'f1-score'] = 1
class_report_test['f1-score'].mean()

> * Improving the 20 worst classes to an F1 score of 1 would raise the macro F1 by another 1%.
> * It might be better to improve the overall performance of the model than to try to improve individual classes.

> * Nevertheless continue with error analysis

In [None]:
# Back to the unmodified (sorted) report.
class_report.head()

In [None]:
# Distribution of the per-class F1 scores (histogram only, no KDE overlay).
sns.displot(class_report['f1-score'], kde=False)

In [None]:
# Summary statistics of the per-class F1 scores, shown as a single row.
class_report['f1-score'].describe().to_frame().T

Group the classes into three difficulty categories: hard (the 8 classes with the lowest F1 score), easy (F1 score above 0.969) and medium (all remaining classes).

In [None]:
# Rows 0..7 (the lowest-F1 classes after sorting) are 'hard'; every other class is
# 'easy' when its F1 exceeds 0.969 and 'medium' otherwise.
remaining_f1 = class_report.loc[8:, 'f1-score']
class_report['difficulty'] = 'hard'
class_report.loc[8:, 'difficulty'] = remaining_f1.gt(0.969).map({True: 'easy', False: 'medium'})

In [None]:
# Mean and median of the numeric report columns per difficulty group.
# The columns are selected explicitly: the frame also holds the string column 'class',
# on which mean/median cannot be computed.
class_report.groupby("difficulty")[['precision', 'recall', 'f1-score', 'support']].agg(['mean', 'median'])

Hypothesis test with the nonparametric Mann-Whitney U test, comparing the support (number of samples per class) of the classes labelled easy and hard above:

In [None]:
# Import the submodule explicitly: a bare `import scipy` does not guarantee that
# `scipy.stats` has been loaded.
import scipy.stats

# Compare the support (samples per class) of the 'easy' and 'hard' classes.
scipy.stats.mannwhitneyu(class_report.loc[class_report['difficulty'] == 'easy', 'support'],
                         class_report.loc[class_report['difficulty'] == 'hard', 'support'])


> We cannot reject the null hypothesis that both samples, easy and hard, come from the same distribution. In other words, at the 5% level there is no evidence that the number of data points is a reason for the difference between the easy and hard classes.

### Common Errors

In [None]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print and plot a confusion matrix with one annotated cell per entry.

    Parameters
    ----------
    cm : 2-D numpy array
        Confusion matrix; rows are labelled 'True label', columns 'Predicted label'.
    classes : sequence
        Tick labels, used for both axes.
    normalize : bool
        If True, every row is divided by its row sum before printing and plotting,
        so the image, the annotations and the printed matrix all show the same values.
        Rows that sum to zero stay at 0.
    title : str
        Plot title.
    cmap : matplotlib colormap
        Colormap passed to imshow.

    Note: a later cell in this notebook defines another function with the same
    name and a different signature, which replaces this one once it has run.
    """
    # Normalize first so that the image and the annotations use the same values.
    if normalize:
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Cells darker than half the maximum get white text for readability.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        cell_text = f"{cm[i, j]:.2f}" if normalize else cm[i, j]
        plt.text(j, i, cell_text,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout last so the axis labels are taken into account.
    plt.tight_layout()

In [None]:
# Confusion matrix of the validation predictions, kept for the per-class error analysis below.
conf_matrix = confusion_matrix(y_true=val_results['label'], y_pred=val_results['pred'])

In [None]:
#plot_confusion_matrix(confusion_matrix(val_results['label'], val_results['pred']), class_names)

In [None]:
# Shape check of the confusion matrix (rows x columns).
conf_matrix.shape

### Confusion (matrix) of the 8 worst-performing classes

In [None]:
# First three class names (class ids 0, 1, 2).
class_names[:3]

In [None]:
# Display the raw confusion matrix.
conf_matrix

In [None]:
# Reverse lookup: class name -> integer class id (its position in class_names).
class_names_mapping = {name: class_id for class_id, name in enumerate(class_names)}

In [None]:
# Attach the integer class id to every report row so the rows/columns of
# conf_matrix can be selected by class.
class_report['idx'] = class_report['class'].map(class_names_mapping)

In [None]:
# The 8 lowest-F1 classes: the same rows that `.loc[:7]` selects in the cells below.
class_report.head(8)

In [None]:
# Confusion-matrix rows of the 8 lowest-F1 classes (`.loc[:7]` includes row label 7).
worst_class_ids = class_report.loc[:7, "idx"].to_numpy()
worst_classes_FN = conf_matrix[worst_class_ids]
worst_classes_FN.shape

In [None]:
# Drop the columns that are all zero, i.e. classes never involved for these rows.
# (Counts are non-negative, so "any non-zero entry" equals "column sum > 0".)
worst_classes_FN_sub = worst_classes_FN[:, worst_classes_FN.any(axis=0)]
worst_classes_FN_sub.shape

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

def plot_confusion_matrix(cm, xclasses, yclasses, title_prefix, figsize=(16,8)):
    """
    Plot a (sub-)confusion matrix as an annotated heat map.

    Parameters
    ----------
    cm : 2-D array
        Matrix to draw; one row per entry of `yclasses`, one column per entry of `xclasses`.
    xclasses : sequence
        Tick labels for the x axis (columns of `cm`).
    yclasses : sequence
        Tick labels for the y axis (rows of `cm`); its length is shown in the title.
    title_prefix : str
        Text prepended to the title. With 'FP' the x axis is labelled 'True label'
        and the y axis 'Predicted label'; with any other value it is the other way round.
    figsize : tuple
        Figure size passed to plt.figure.
    """
    plt.figure(figsize=figsize)
    plt.title(title_prefix+f" top {len(yclasses)} classes by f1 score")
    ax = plt.gca()
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)

    plt.xticks(np.arange(len(xclasses)), xclasses, rotation=45)
    plt.yticks(np.arange(len(yclasses)), yclasses)

    # ndenumerate yields (row, col) indices, while text() expects (x, y) = (col, row).
    for (row, col), count in np.ndenumerate(cm):
        ax.text(col, row, count, ha='center', va='center')

    if title_prefix == 'FP':
        plt.xlabel('True label')
        plt.ylabel('Predicted label')
    else:
        plt.ylabel('True label')
        plt.xlabel('Predicted label')

    # Run the layout last so the axis labels and rotated tick labels are taken into account.
    plt.tight_layout()

In [None]:
# False negatives: rows are the 8 lowest-F1 classes, columns are the classes
# that have at least one non-zero count in those rows.
plot_confusion_matrix(worst_classes_FN_sub, xclasses= np.array(class_names)[worst_classes_FN.sum(0) > 0], 
                     yclasses=class_report.loc[:7, "class"], title_prefix='FN' )

### FN Results:
* globe-flower (true): only 1 confused with buttercup
* clematis: 2 confused with windflower and columbine
* canterbury bells: no FN
* mexican petunia: 1 confused with petunia, maybe label error?
* black-eyed susan: 5 confused with sunflower.
* peruvian lily: 1 confused with lenten rose, 1 with rose. Both are of type rose — by chance?
* rose: 2 with sunflower, 3 with common tulip, 1 confused with barberton daisy, daisy, 2 sunflower, 1 lotus: mix-ups are spread among classes
* gazania: 1 tiger lily, 1 barberton daisy, 1 rose, 1 blanket flower.

In [None]:
# Confusion-matrix columns of the 8 lowest-F1 classes (`.loc[:7]` includes row label 7).
worst_class_cols = class_report.loc[:7, "idx"].to_numpy()
worst_classes_FP = conf_matrix[:, worst_class_cols]
print(worst_classes_FP.shape)

# Keep only the rows with at least one non-zero count in those columns.
# (Counts are non-negative, so "any non-zero entry" equals "row sum > 0".)
worst_classes_conf_FP = worst_classes_FP[worst_classes_FP.any(axis=1), :]
worst_classes_conf_FP.shape

In [None]:
# False positives: the matrix is transposed so that the 8 lowest-F1 classes are on
# the y axis again; the x axis holds the classes with a non-zero count.
plot_confusion_matrix(worst_classes_conf_FP.T, xclasses= np.array(class_names)[worst_classes_FP.sum(1) > 0], 
                     yclasses=class_report.loc[:7, "class"], title_prefix='FP' )

### FP Results
* globe-flower: 1 confused with (true) lotus
* clematis: no wrong detection in other classes
* canterbury bells: confused with true balloon flower
* mexican petunia: confused with 1 true petunia and 1 true desert-rose
* black-eyed susan: confused with 1 true daisy
* peruvian lily: confused with 1 true tiger lily
* rose: confused with 1 true snapdragon, 1 true peruvian lily, 2 carnation and other classes. The algorithm tends to think everything is a rose, which could be due to the relatively large number of images for this class.
* gazania: confused with 2 true marigold

> The differences between the classes confused as FN and as FP for the 8 worst classes indicate that the two types of confusion might be of a different nature.