In [None]:
try:
    import captum
except:
    !pip install captum
    
try:
    import flask_compress
except:
    !pip install flask_compress

<hr style="border: solid 3px blue;">

# Introduction

![](https://64.media.tumblr.com/d56f4f22049bb23c05da31c28671dd96/tumblr_nr3bspVKPz1u2bcamo2_540.gifv)

Picture Credit: https://64.media.tumblr.com

**What is a Herbarium?**
> A herbarium (plural: herbaria) is a collection of preserved plant specimens and associated data used for scientific study.
> 
> The specimens may be whole plants or plant parts; these will usually be in dried form mounted on a sheet of paper (called exsiccatae) but, depending upon the material, may also be stored in boxes or kept in alcohol or other preservative. The specimens in a herbarium are often used as reference material in describing plant taxa; some specimens may be types.

Memories of my childhood came to mind while doing this project. Gone are the days when I traveled through the mountains and fields for my school homework, collecting grasses and flowers, pressing them between books, and waiting for them to dry.
I did not know the species or names of the grasses and flowers I collected back then, but I have pleasant memories of looking at the dried plants.
Now that the world has changed so much, I am amazed once again that technology has advanced far enough to distinguish their species through machine learning.

Now, let's start the project while remembering the good old days.

In [None]:
import numpy as np
import pandas as pd 
from captum.attr import IntegratedGradients,NoiseTunnel,GradientShap,Occlusion
from captum.attr import visualization as viz

from matplotlib.colors import LinearSegmentedColormap

from captum.insights import AttributionVisualizer, Batch
from captum.insights.attr_vis.features import ImageFeature

from fastai.vision.all import *
import albumentations
from random import randint
import seaborn as sns
import matplotlib.pyplot as plt
import json
import cv2

------------------------
# Setting Up

In [None]:
# Root folder of the training images (Kaggle input layout).
train_dir = '../input/herbarium-2022-fgvc9/train_images/'

# Read the training metadata file and parse it as JSON.
with open("../input/herbarium-2022-fgvc9/train_metadata.json") as json_file:
    train_meta = json.loads(json_file.read())

In [None]:
# Build one row per training image: id, file path, category id and genus name.
# NOTE(review): images and annotations are paired purely by list position;
# this assumes both metadata lists are in the same order — confirm.
images = train_meta["images"]
annotations = train_meta["annotations"]

image_ids = [img["image_id"] for img in images]
image_dirs = [train_dir + img["file_name"] for img in images]
category_ids = [ann["category_id"] for ann in annotations]
genus_ids = [ann["genus_id"] for ann in annotations]

# Lookup table: genus_id -> genus name.
genus_map = {}
for entry in train_meta["genera"]:
    genus_map[entry["genus_id"]] = entry["genus"]

train_df = pd.DataFrame({
    "image_id": image_ids,
    "image_dir": image_dirs,
    "category": category_ids,
    "genus": genus_ids,
})
# Replace the numeric genus ids with their names.
train_df["genus"] = train_df["genus"].map(genus_map)

# Preview the first rows with a dark table style.
dark_cells = {'background-color': 'black',
              'color': 'white',
              'border-color': 'white'}
train_df.head().style.set_properties(**dark_cells)

In [None]:
# Print a summary of the full metadata frame.
train_df.info()

<span style="color:Blue"> Observation: </span>

Wow! The dataset is very large.

In [None]:
# Number of distinct target categories.
train_df['category'].nunique()

<span style="color:Blue"> Observation: </span>

The number of distinct target categories is also large.

In [None]:
# Count how many images each category has; the result is displayed as the cell output.
cat_val_cnt = train_df['category'].value_counts()
cat_val_cnt

There are too many categories. Let's reduce the number of target categories to shorten the training time.

In [None]:
# Keep the categories that have exactly 80 images, then take the first 20 of
# them as the reduced label set.
has_80_images = cat_val_cnt == 80
cat_index = cat_val_cnt[has_80_images].sort_values(ascending=False).index[:20]

In [None]:
# Restrict the training frame to the selected categories and summarize the result.
in_selected_category = train_df['category'].isin(cat_index)
herb_train_df = train_df[in_selected_category]
herb_train_df.info()

In [None]:
# Global plot look: seaborn "ticks" style with the "talk" context, then
# matplotlib's dark-background style sheet for the plots that follow.
sns.set(style="ticks", context="talk",font_scale = 1)
plt.style.use("dark_background")

-------------------------------
# Augmentation

> Our dataset has a long-tail distribution. The number of images per taxon is as few as seven and as many as 100 images. 

As noted above, the dataset has a long-tail distribution. There are many ways to address this, but in this notebook we use data augmentation.
Let's use albumentations to augment the original data in various ways.

In [None]:
def get_train_aug():
    """Build the albumentations pipeline used to augment training images.

    Returns:
        An `albumentations.Compose` applying, each with probability 0.5:
        transpose, vertical flip, shift/scale/rotate, hue/saturation/value
        jitter and two independent coarse-dropout passes.
    """
    return albumentations.Compose([
        # albumentations.RandomResizedCrop(300,300),  # disabled in the original
        albumentations.Transpose(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.ShiftScaleRotate(p=0.5),
        # NOTE(review): albumentations documents these limits in integer units
        # (library defaults are 20/30/20). 0.2 looks like it was meant as a
        # fraction and is probably close to a no-op — confirm the intended strength.
        albumentations.HueSaturationValue(
            hue_shift_limit=0.2,
            sat_shift_limit=0.2,
            val_shift_limit=0.2,
            p=0.5),
        albumentations.CoarseDropout(p=0.5),
        # The second dropout pass used `albumentations.Cutout`, which is
        # deprecated in favour of `CoarseDropout` and missing from newer
        # releases; the documented replacement is used here.
        albumentations.CoarseDropout(p=0.5),
    ])

In [None]:
class AlbumentationsTransform(DisplayedTransform):
    "Item transform that runs an albumentations pipeline on a `PILImage`."
    # Read by fastai's transform pipeline.
    # NOTE(review): presumably split_idx=0 limits this to the training subset
    # and `order` positions it among the item transforms — confirm.
    split_idx = 0
    order = 2

    def __init__(self, train_aug):
        # store_attr() saves `train_aug` as `self.train_aug`.
        store_attr()

    def encodes(self, img: PILImage):
        # albumentations works on arrays: convert, augment, wrap the result again.
        pixels = np.array(img)
        augmented = self.train_aug(image=pixels)
        return PILImage.create(augmented['image'])

In [None]:
# Per-item transforms: resize to 400, then apply the albumentations pipeline.
item_tfms = [Resize(400), AlbumentationsTransform(get_train_aug())]

----------------------------------
# Making Pipeline and Dataloaders

In [None]:
# Hold out 20% of the rows for validation. A fixed seed makes the split
# identical on every Restart & Run All.
splits = RandomSplitter(valid_pct=0.2, seed=42)

# Read inputs and labels by column NAME rather than by position: integer keys
# depend on positional lookup in a string-labelled row, which pandas deprecates.
# 'image_dir' and 'category' are the columns that positions 1 and 2 referred to.
dls = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_x=ColReader('image_dir'),
    get_y=ColReader('category'),
    splitter=splits,
    item_tfms=item_tfms,
).dataloaders(herb_train_df, bs=16)

--------------------------------------------
# Showing Batch

In [None]:
# Display up to 16 samples from one batch as a sanity check of the data pipeline.
dls.show_batch(max_n=16)

-----------------------------------------
# Modeling (Deeper and Deeper)

![](https://i.pinimg.com/originals/59/b6/34/59b634c891be23b5e235a6587e808dc3.gif)

Picture Credit: https://i.pinimg.com

It depends on the task, but in deep learning, deeper networks generally perform better.
Considering weight decay, an XResNet (xresnet18) is used in this notebook.

Also, use early stopping to prevent overfitting, and add a callback to record the distribution of activations.
For metrics, use accuracy and top_k_accuracy together.

In [None]:
# Early stopping monitors the 'accuracy' metric.
# NOTE(review): min_delta=0.1 presumably means an epoch only counts as an
# improvement if accuracy rises by 0.1; that is a large step — confirm it is intended.
early_stop = EarlyStoppingCallback(monitor='accuracy', min_delta=0.1, patience=5)
# Records activation statistics (with histograms) for the plots further below.
act_stats = ActivationStats(with_hist=True)

learn = cnn_learner(
    dls,
    xresnet18,
    metrics=[accuracy, top_k_accuracy],
    cbs=[early_stop, act_stats],
)

**What is XResnet**
> 1. It applies a number of tricks that modify things like the training or model.
> 2. Heuristics to increase the parallelism of training and decrease the computational cost through lower precision computing and modifying the learning rate or biases
> 3. Tweaking the models by modifying the network architecture. They explore several modifications they call ResNet A, ResNet B, ResNet C and ResNet D. These modify the stride length in particular convolutional layers.
> 4. Training refinements to improve accuracy
> * Learning Rate Decay
> * Label Smoothing
> * Knowledge Distillation
> * Mixup Training
> * Transfer learning to see if they benefit from any downstream learning improvements to improve accuracy.

Ref: https://www.quora.com/What-is-the-difference-between-ResNeXt-XResnet-and-ResNet

In [None]:
# Show the network architecture as the cell output.
learn.model

-------------------------------
# Finding the proper learning rate

In [None]:
# Run the learning-rate finder on an 8x6 figure and display its `valley` suggestion.
plt.rcParams["figure.figsize"] = (8,6)
sr = learn.lr_find()
sr.valley

--------------------------------------------
# Training

![](https://miro.medium.com/max/973/1*nhmPdWSGh3ziatQKOmVq0Q.png)

Picture Credit: https://miro.medium.com

Let's train until early stopping kicks in!

In [None]:
# Train with the one-cycle policy for up to 100 epochs at the suggested learning
# rate; the early-stopping callback attached to the learner may end training sooner.
learn.fit_one_cycle(100,sr.valley)

-------------------------------------------
# Checking Activation

![](https://forums.fast.ai/uploads/default/optimized/3X/5/7/57a02a03d86a56561484aee9e88222ecbb7c1cf5_2_690x251.jpeg)

> The idea of the colorful dimension is to express with colors the mean and standard deviation of activations for each batch during training. Vertical axis represents a group (bin) of activation values. Each column in the horizontal axis is a batch. The colours represent how many activations for that batch have a value in that bin.

Ref: https://forums.fast.ai/t/the-colorful-dimension/42908

In [None]:
def plot_layer_stats(self, idx):
    """Plot the three statistic series recorded for one layer.

    Args:
        self: object exposing `layer_stats(idx)`; the notebook passes
            `learn.activation_stats`.
        idx: layer index forwarded to `layer_stats` and shown in each panel title.

    The three series returned by `layer_stats(idx)` are drawn side by side and
    titled 'mean', 'std' and '% near zero'.
    """
    # `subplots` returns (figure, axes). Bind the figure to `fig` so the
    # pyplot module name `plt` is not shadowed inside this function.
    fig, axs = subplots(1, 3, figsize=(15,3))
    fig.subplots_adjust(wspace=0.5)
    for o,ax,title in zip(self.layer_stats(idx),axs,('mean','std','% near zero')):
        ax.plot(o)
        # Title from the `idx` argument, not from a global loop variable, so
        # the function also works when called outside that loop.
        ax.set_title(f"{idx}th layer {title}")

In [None]:
# Plot the recorded statistics for indices -1, -2 and -3.
for layer in range(1,4):
    plot_layer_stats(learn.activation_stats,-1*layer)

In [None]:
def color_dim(self, idx):
    """Draw the activation histogram of layer `idx` as an image.

    Args:
        self: object exposing `hist(idx)`; the notebook passes
            `learn.activation_stats`.
        idx: layer index forwarded to `hist` and shown in the title.

    The figure is created under temporary rc settings (10x40 size, 600 dpi),
    with the image origin at the bottom and the axes hidden.
    """
    rc_overrides = {"figure.figsize": (10,40), "figure.dpi": 600}
    with plt.rc_context(rc_overrides):
        histogram = self.hist(idx)
        axes = plt.figure().add_subplot(111)
        axes.imshow(histogram, origin='lower')
        axes.set_title(f"{idx}th activation histogram")
        axes.axis('off')

In [None]:
# Switch the default image colormap for the histogram images drawn below.
# `plt.rcParams` is the same settings object as `matplotlib.rcParams`; using it
# relies only on the explicit `matplotlib.pyplot` import.
# NOTE: this is a global setting and stays in effect for later cells.
plt.rcParams['image.cmap'] = 'rainbow_r'
for layer in (1, 2, 3):
    color_dim(learn.activation_stats, -layer)

----------------------------------------------
# Checking Results

In [None]:
# Reset the default figure size to 8x6, then plot the losses recorded during training.
plt.rcParams["figure.figsize"] = (8,6)
learn.recorder.plot_loss()

In [None]:
# Show sample predictions from the learner on a 20x20 figure.
learn.show_results(figsize=(20,20))

In [None]:
# Build an interpretation object from the learner and draw its confusion matrix.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(15,15),dpi=480)

In [None]:
# List the most confused class pairs, using a minimum count of 5.
interp.most_confused(min_val=5)

-----------------------------
# Interpreting

We want to know where our model is looking and which regions it has decided are important.
This lets us check whether training went well and whether more work is needed to improve performance further.

In [None]:
class CaptumInterpretation():
    """Captum Interpretation for Resnet

    Pushes a single input item through the learner's data pipeline, computes
    attributions with a Captum algorithm and plots them next to the image.

    Supported `metric` values (see `supported_metrics`):
      * 'IG'   - IntegratedGradients
      * 'NT'   - NoiseTunnel wrapped around IntegratedGradients
      * 'Occl' - Occlusion
    """
    def __init__(self,learn,cmap_name='viridis',colors=None,N=256,methods=('original_image','heat_map'),
                 signs=("all", "positive"),outlier_perc=1):
        # Default colormap stops: white at 0, black from 0.25 upwards.
        if colors is None: colors = [(0, '#ffffff'),(0.25, '#000000'),(1, '#000000')]
        # store_attr() saves every constructor argument on `self`
        # (self.learn, self.cmap_name, self.colors, ...).
        store_attr()
        self.dls,self.model = learn.dls,self.learn.model
        self.supported_metrics=['IG','NT','Occl']

    def get_baseline_img(self, img_tensor,baseline_type):
        """Return a baseline tensor shaped like `img_tensor`, on `self.dls.device`.

        Args:
            img_tensor: encoded input tensor.
            baseline_type: one of
                'zeros'   - `img_tensor * 0`
                'uniform' - `torch.rand` noise
                'gauss'   - average of `torch.rand` noise and `img_tensor`
                            (note: the noise is uniform, despite the name)

        Raises:
            ValueError: if `baseline_type` is not one of the values above.
                Previously this fell through and failed with an
                AttributeError on `None`.
        """
        if baseline_type=='zeros':
            baseline_img = img_tensor*0
        elif baseline_type=='uniform':
            baseline_img = torch.rand(img_tensor.shape)
        elif baseline_type=='gauss':
            baseline_img = (torch.rand(img_tensor.shape).to(self.dls.device)+img_tensor)/2
        else:
            raise ValueError(f"Baseline type {baseline_type} is not supported. "
                             "Use one of 'zeros', 'uniform' or 'gauss'")
        return baseline_img.to(self.dls.device)

    def visualize(self,inp,metric='IG',n_steps=200,baseline_type='zeros',nt_type='smoothgrad', strides=(3,4,4), sliding_window_shapes=(3,15,15)):
        """Compute and plot attributions for a single input item.

        Args:
            inp: raw item accepted by the learner's item pipeline (the notebook
                passes an image file path).
            metric: 'IG', 'NT' or 'Occl'.
            n_steps: number of integration steps, used by 'IG' only. The
                default is 200 because that is the value that was always
                applied before (the argument used to be ignored), so default
                calls behave exactly as they did.
            baseline_type: forwarded to `get_baseline_img`.
            nt_type: noise-tunnel type, used by 'NT' only.
            strides, sliding_window_shapes: used by 'Occl' only.

        Raises:
            ValueError: if `metric` is not in `self.supported_metrics`.
        """
        if metric not in self.supported_metrics:
            # ValueError is a subclass of Exception, so existing handlers still work.
            raise ValueError(f"Metric {metric} is not supported. Currently {self.supported_metrics} are only supported")
        # Echo the item being explained (the notebook picks it at random).
        print(inp)
        # One TfmdLists per type pipeline of the DataLoaders (x pipeline, y pipeline).
        tls = L([TfmdLists(inp, t) for t in L(ifnone(self.dls.tfms,[None]))])
        inp_data=list(zip(*(tls[0],tls[1])))[0]
        enc_data,dec_data=self._get_enc_dec_data(inp_data)
        attributions=self._get_attributions(enc_data,metric,n_steps,nt_type,baseline_type,strides,sliding_window_shapes)
        self._viz(attributions,dec_data,metric)

    def _viz(self,attributions,dec_data,metric):
        "Plot the item image and its attribution map side by side."
        default_cmap = LinearSegmentedColormap.from_list(self.cmap_name,self.colors, N=self.N)
        # Captum's visualizer expects channel-last arrays, hence the (1,2,0) transposes.
        _ = viz.visualize_image_attr_multiple(np.transpose(attributions.squeeze().cpu().detach().numpy(), (1,2,0)),
                                              np.transpose(dec_data[0].numpy(), (1,2,0)),
                                              methods=self.methods,
                                              cmap=default_cmap,
                                              show_colorbar=True,
                                              signs=self.signs,
                                              outlier_perc=self.outlier_perc, titles=[f'Original Image - ({dec_data[1]})', metric])

    def _get_enc_dec_data(self,inp_data):
        "Return `(enc_data, dec_data)`: the batch-transformed item on the device, and the item after item transforms only."
        dec_data=self.dls.after_item(inp_data)
        enc_data=self.dls.after_batch(to_device(self.dls.before_batch(dec_data),self.dls.device))
        return(enc_data,dec_data)

    def _get_attributions(self,enc_data,metric,n_steps,nt_type,baseline_type,strides,sliding_window_shapes):
        """Run the Captum algorithm selected by `metric` on `enc_data`.

        The Captum objects are created on first use and cached on the instance.
        Returns the attribution tensor, or None for a metric this method does
        not handle (`visualize` rejects those beforehand).
        """
        # Get Baseline
        baseline=self.get_baseline_img(enc_data[0],baseline_type)
        if metric == 'IG':
            self._int_grads = self._int_grads if hasattr(self,'_int_grads') else IntegratedGradients(self.model)
            # Honour the caller's `n_steps` (was hard-coded to 200).
            return self._int_grads.attribute(enc_data[0],baseline, target=enc_data[1], n_steps=n_steps)
        elif metric == 'NT':
            self._int_grads = self._int_grads if hasattr(self,'_int_grads') else IntegratedGradients(self.model)
            self._noise_tunnel= self._noise_tunnel if hasattr(self,'_noise_tunnel') else NoiseTunnel(self._int_grads)
            return self._noise_tunnel.attribute(enc_data[0].to(self.dls.device), n_samples=1, nt_type=nt_type, target=enc_data[1])
        elif metric == 'Occl':
            self._occlusion = self._occlusion if hasattr(self,'_occlusion') else Occlusion(self.model)
            return self._occlusion.attribute(enc_data[0].to(self.dls.device),
                                       strides = strides,
                                       target=enc_data[1],
                                       sliding_window_shapes=sliding_window_shapes,
                                       baselines=baseline)

In [None]:
# Build a folder-based DataLoaders over one image sub-directory: items are the
# image files found under `path`, labels come from each file's parent folder.
# NOTE: this rebinds `splits` and `dls`, which were defined earlier for the
# DataFrame-based pipeline.
path = '../input/herbarium-2022-fgvc9/train_images/002'
fnames = get_image_files(path)
splits = RandomSplitter(valid_pct=0.2)
folder_block = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=splits,
    item_tfms=item_tfms,
)
dls = folder_block.dataloaders(path, bs=16)

In [None]:
# NOTE(review): this rebinds `learn` to a newly constructed learner on the
# folder-based `dls`. No fit call on it is visible before it is handed to
# CaptumInterpretation, so the attributions below appear to come from a
# model that was not trained in this notebook — confirm this is intended.
learn = cnn_learner(
    dls,
    xresnet18,
    metrics=[accuracy, top_k_accuracy],
    cbs=[
        EarlyStoppingCallback(monitor='accuracy', min_delta=0.1, patience=5),
        ActivationStats(with_hist=True),
    ],
)

In [None]:
# NOTE: the name `captum` rebinds the imported `captum` package to this
# instance; it is kept because the next cell uses it under that name.
captum=CaptumInterpretation(learn,colors=['green','red','yellow'])
# randint includes its upper bound, so the largest valid index is len(fnames)-1;
# the previous bound could select an out-of-range index.
idx=randint(0,len(fnames)-1)
captum.visualize(fnames[idx])

In [None]:
# randint includes its upper bound, so the largest valid index is len(fnames)-1;
# the previous bound could select an out-of-range index.
idx=randint(0,len(fnames)-1)
# Integrated Gradients again, this time against a uniform-noise baseline.
captum.visualize(fnames[idx],metric='IG',baseline_type='uniform')

<hr style="border: solid 3px blue;">