

<center><h2>How to Train a Neural Network for Image Segmentation Using Fast.ai and Transfer Learning</h2></center>


***

<center><img src="https://github.com/shadab4150/Aerial_drone_image_segmentation/raw/master/image_drone/drone1.png"></center>



## What is semantic segmentation?

* Source: **https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html**

* **Semantic image segmentation is the task of classifying each pixel in an image from a predefined set of classes.**

***

In the following example, different entities are classified.

![kd](https://divamgupta.com/assets/images/posts/imgseg/image15.png?style=centerme)

***


In the above example, the pixels belonging to the bed are classified in the class “bed”, the pixels corresponding to the walls are labeled as “wall”, etc.

In particular, our goal is to take an image of size W × H × 3 and generate a W × H matrix containing the predicted class ID for every pixel.
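
As a minimal illustration, assume a model that outputs one score per class for every pixel. The class-ID matrix is then the per-pixel argmax over the class axis. The scores below are random stand-ins, not real model output. Note that NumPy stores images as (height, width, channels).

```python
import numpy as np

H, W, C = 4, 6, 3  # hypothetical image height, width and number of classes

rng = np.random.default_rng(0)
image = rng.random((H, W, 3))    # an RGB image: H x W x 3
scores = rng.random((H, W, C))   # stand-in for per-pixel class scores from a model

# Pick the highest-scoring class for every pixel -> an H x W matrix of class IDs
class_map = scores.argmax(axis=-1)

print(image.shape, "->", class_map.shape)  # (4, 6, 3) -> (4, 6)
```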

***
![kd](https://divamgupta.com/assets/images/posts/imgseg/image14.png?style=centerme)

***

Usually, in an image with various entities, we want to know which pixel belongs to which entity. For example, in an outdoor image we can segment the sky, ground, trees, people, etc.

## Importing useful libraries

In [None]:
from fastai import *
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.callbacks import *
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.image as immg
import gc
import numpy as np
import random
from PIL import Image
import warnings
warnings.filterwarnings("ignore")

In [None]:
open_image('../input/semantic-drone-dataset/semantic_drone_dataset/original_images/001.jpg').data.shape

## Path to the dataset
* The images in the original dataset are very large, so they were resized to two sizes:
    * 1800x1200
    * 600x400
* This notebook uses the small (600x400) version.
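
The notebook does not show how the resized copies were made. Below is a minimal sketch using Pillow and a hypothetical helper `resize_pair`. The photo can use a smooth filter, but the mask must use nearest-neighbour resampling. That way every output pixel keeps a valid class ID instead of an interpolated in-between value.

```python
import numpy as np
from PIL import Image

def resize_pair(img, mask, size):
    """Resize an image/mask pair to `size` = (width, height).

    Hypothetical helper: bilinear for the photo, nearest-neighbour for the
    mask so that interpolation cannot create new (invalid) class IDs.
    """
    img_small = img.resize(size, resample=Image.BILINEAR)
    mask_small = mask.resize(size, resample=Image.NEAREST)
    return img_small, mask_small

# Tiny synthetic example: a 6x4 mask, left half class 0, right half class 5
mask_arr = np.zeros((4, 6), dtype=np.uint8)
mask_arr[:, 3:] = 5
img = Image.fromarray(np.stack([mask_arr * 40] * 3, axis=-1))  # dummy RGB image
mask = Image.fromarray(mask_arr)                                # single-channel mask

img_small, mask_small = resize_pair(img, mask, (3, 2))
print(img_small.size, mask_small.size)            # (3, 2) (3, 2)
print(np.unique(np.asarray(mask_small)).tolist())  # only values from {0, 5}
```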

In [None]:
path = Path('../input/drone-images-mask-resized/drone_data_small')  

In [None]:
path.ls()

In [None]:
fnames = get_files(path/'train_small')
fnames_mask = get_files(path/'label_small')

In [None]:
open_image(fnames[4]).data.shape

## Data Preprocessing

In [None]:
from tqdm.notebook import tqdm,tnrange

* Each mask pixel stores the class ID of the corresponding image pixel, so the function below collects the unique IDs across all masks to find the set of classes.

In [None]:
path_im = path/'train_small'
path_lb = path/'label_small'
get_y_fns = lambda x: path_lb/f'{x.stem}.png'       # Map an image file to its mask file (same name, .png)

In [None]:
fnames[30],get_y_fns(fnames[30])

In [None]:
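def get_classes(fnames):
    """Return the sorted class IDs that appear in the masks of the given images."""
    class_codes = set()
    for f in tqdm(fnames):   # iterate over all files instead of a hard-coded 400
        class_codes.update(np.unique(np.asarray(Image.open(get_y_fns(f)))).tolist())
    return np.array(sorted(class_codes))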

In [None]:
# Scan every mask once to collect the class IDs (slow: it opens each mask file)
codes = get_classes(fnames)

In [None]:
codes = np.array(codes)
codes

In [None]:
sns.set_style('darkgrid')

## Function to show a drone image with its mask

In [None]:
def drone_mask(f):  # f = path to a drone image
    img_a = immg.imread(f)                   # raw drone image
    img_a_mask = immg.imread(get_y_fns(f))   # corresponding mask
    plt.figure(1, figsize=(20, 8))
    plt.subplot(121)
    plt.imshow(img_a); plt.title('Raw drone footage'); plt.axis('off')
    plt.subplot(122)
    plt.imshow(img_a, alpha=0.8)
    plt.imshow(img_a_mask, alpha=0.8); plt.title('Drone image with mask'); plt.axis('off')  # overlay the mask
    plt.show()

## Sample drone images with masks

In [None]:
for i in range(3):
    img_num = random.randint(10,200)
    drone_mask(fnames[img_num])

## Creating a DataBunch for the model

In [None]:
src = np.array([400, 600])   # (height, width) of the 600x400 images
#src = src//2
src

In [None]:
data = (SegmentationItemList.from_folder(path=path_im)  # Load images from this folder
        .split_by_rand_pct(0.2)                          # Random 80/20 train/validation split
        .label_from_func(get_y_fns, classes=codes)      # Get each mask with the function defined above
        .transform(get_transforms(), size=src//2, tfm_y=True)   # Augment; tfm_y=True applies the same transform to the mask
        .databunch(bs=4)                                # Batch size: decrease it if CUDA runs out of memory
        .normalize(imagenet_stats))                     # Normalize with ImageNet statistics
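
`tfm_y=True` tells fastai to apply the same geometric transforms to the mask as to the image, so pixels and labels stay aligned. The NumPy sketch below shows the idea with a hypothetical `random_flip_pair` helper. A single random draw decides the flip for both arrays.

```python
import numpy as np

def random_flip_pair(img, mask, rng, p=0.5):
    """Flip an image (H x W x 3) and its mask (H x W) horizontally with probability p.

    One random draw decides for both arrays, so they stay pixel-aligned.
    """
    if rng.random() < p:
        img = img[:, ::-1]
        mask = mask[:, ::-1]
    return img, mask

rng = np.random.default_rng(0)
img = np.arange(2 * 3 * 3).reshape(2, 3, 3)
mask = img[..., 0] % 4   # a mask derived from the image, so alignment is checkable

for _ in range(5):
    img_t, mask_t = random_flip_pair(img, mask, rng)
    # The mask is still the same function of the (possibly flipped) image
    assert np.array_equal(mask_t, img_t[..., 0] % 4)
print("image and mask stay aligned")
```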

In [None]:
data.show_batch(rows=2,figsize=(20,10));

In [None]:
len(data.train_ds), len(data.valid_ds), data.c  

# Model

* **Metric for the drone masks: pixel accuracy**
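
The metric defined below is pixel accuracy. It is the fraction of pixels whose predicted class (the argmax over the class dimension) matches the mask, ignoring pixels labelled `void_code`. The NumPy sketch below performs the same computation on made-up arrays. The helper name `masked_pixel_accuracy` is hypothetical.

```python
import numpy as np

def masked_pixel_accuracy(scores, target, void_code=-1):
    """scores: (N, C, H, W) class scores; target: (N, H, W) class IDs."""
    pred = scores.argmax(axis=1)          # (N, H, W) predicted class IDs
    keep = target != void_code            # pixels that count toward accuracy
    return float((pred[keep] == target[keep]).mean())

# Made-up example: 1 image, 2 classes, 2x2 pixels
scores = np.array([[[[0.9, 0.2],
                     [0.1, 0.7]],     # class 0 scores
                    [[0.1, 0.8],
                     [0.9, 0.3]]]])   # class 1 scores
target = np.array([[[0, 1],
                    [1, 1]]])

print(masked_pixel_accuracy(scores, target))               # 0.75
print(masked_pixel_accuracy(scores, target, void_code=1))  # 1.0
```

With `void_code=-1`, as in this notebook, no pixel is ignored because -1 never occurs in the masks.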


In [None]:
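name2id = {v:k for k,v in enumerate(codes)}   # class ID -> index (not used below)
void_code = -1   # label to ignore; -1 never occurs in these masks, so every pixel counts

def drone_accuracy_mask(input, target):
    target = target.squeeze(1)                # (bs, 1, H, W) -> (bs, H, W)
    mask = target != void_code                # pixels that count toward accuracy
    return (input.argmax(dim=1)[mask]==target[mask]).float().mean()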

In [None]:
metrics = drone_accuracy_mask
wd=1e-2    # wd = weight decay

### Fastai's unet_learner
* Source: [**Fast.ai**](https://www.fast.ai)

* This module builds a dynamic U-Net from any backbone **pretrained on ImageNet**, automatically inferring the intermediate sizes.

![kd](https://www.researchgate.net/profile/Alan_Jackson9/publication/323597886/figure/fig2/AS:601386504957959@1520393124691/Convolutional-neural-network-CNN-architecture-based-on-UNET-Ronneberger-et-al.png)

* **The figure shows the original U-Net. The difference here is that the left (encoder) part is a pretrained model.**

* **This U-Net sits on top of an encoder (which can be a pretrained model, e.g. resnet50) and has a final output with num_classes channels.**
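
The decoder must know the spatial size of each encoder feature map so it can upsample and concatenate the skip connections. fastai infers these sizes automatically. The worked sketch below computes them by hand for a ResNet-34 encoder on the 200 × 300 input used here (`src//2`). It uses the standard output-size formula out = ⌊(in + 2·padding − kernel) / stride⌋ + 1. The stage list assumes the standard torchvision ResNet-34 layout.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a conv/pool layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def resnet34_feature_sizes(h, w):
    """Spatial sizes after each downsampling step of a ResNet-34 encoder.

    layer1 uses stride 1, so it does not change the resolution and is omitted.
    """
    # (name, kernel, stride, padding) of the layer that halves the resolution
    stages = [
        ("conv1 (7x7, s2)", 7, 2, 3),
        ("maxpool (3x3, s2)", 3, 2, 1),
        ("layer2 (3x3, s2)", 3, 2, 1),
        ("layer3 (3x3, s2)", 3, 2, 1),
        ("layer4 (3x3, s2)", 3, 2, 1),
    ]
    sizes = []
    for name, k, s, p in stages:
        h, w = conv_out(h, k, s, p), conv_out(w, k, s, p)
        sizes.append((name, h, w))
    return sizes

for name, h, w in resnet34_feature_sizes(200, 300):
    print(f"{name:18s} -> {h} x {w}")
```

This yields 100 × 150, 50 × 75, 25 × 38, 13 × 19 and 7 × 10. Odd sizes do not halve exactly, which is why the decoder cannot simply assume each stage is half the size of the previous one.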

In [None]:
arch = models.resnet34
learn = unet_learner(data,                         # DataBunch
                     arch,                         # pretrained backbone architecture
                     metrics = [metrics],          # metrics
                     wd = wd, bottle=True,         # weight decay (wd)
                     model_dir = '/kaggle/working/')  # directory where models are saved

## Model Summary

## Finding a suitable learning rate for our model

* Using fast.ai's **learning rate finder**
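
A learning rate finder trains briefly while raising the learning rate exponentially from a very small value to a very large one, recording the loss at each step. A common heuristic is to pick a value on the steepest downward part of the curve, before the loss blows up. The sketch below shows that schedule and a steepest-slope pick. The start/end values and the loss curve are illustrative assumptions, not fastai's internals.

```python
import numpy as np

def exponential_lrs(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Learning rates growing geometrically from start_lr to end_lr."""
    return start_lr * (end_lr / start_lr) ** (np.arange(num_it) / (num_it - 1))

def steepest_lr(lrs, losses):
    """Learning rate where the loss falls fastest with respect to log10(lr)."""
    slopes = np.gradient(np.asarray(losses, dtype=float), np.log10(lrs))
    return lrs[np.argmin(slopes)]

lrs = exponential_lrs(1e-6, 1.0, num_it=7)      # 1e-6, 1e-5, ..., 1e0
losses = [2.0, 1.99, 1.9, 1.4, 0.9, 1.5, 4.0]   # made-up loss curve
print(steepest_lr(lrs, losses))                  # close to 1e-3
```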

In [None]:
learn.lr_find()
learn.recorder.plot()

In [None]:
gc.collect() # run Python's garbage collector to free unreferenced objects

In [None]:
callbacks = SaveModelCallback(learn, monitor = 'drone_accuracy_mask', every = 'improvement', mode='max', name = 'best_model' )
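
With `every='improvement'` and `mode='max'`, the callback saves the weights under `name` only when the monitored metric reaches a new maximum at the end of an epoch. The sketch below shows that bookkeeping with a hypothetical `BestTracker` class and made-up per-epoch accuracies.

```python
class BestTracker:
    """Track the best value of a metric where higher is better (mode='max')."""

    def __init__(self):
        self.best = float("-inf")

    def update(self, value):
        """Return True if `value` is a new best, i.e. the model should be saved."""
        if value > self.best:
            self.best = value
            return True
        return False

tracker = BestTracker()
accuracies = [0.61, 0.70, 0.68, 0.74, 0.74]   # made-up per-epoch metric values
saved_epochs = [epoch for epoch, acc in enumerate(accuracies) if tracker.update(acc)]
print(saved_epochs)   # [0, 1, 3]
print(tracker.best)   # 0.74
```

A tie (epoch 4) does not count as an improvement, so the earlier checkpoint is kept.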

In [None]:
lr = 1e-3           # Learning Rate

In [None]:
learn.fit(10, lr,callbacks = [callbacks] )

In [None]:
learn.freeze()
learn.lr_find()
learn.recorder.plot()

In [None]:
gc.collect()

In [None]:
learn.load('best_model');
callbacks2 = SaveModelCallback(learn, monitor = 'drone_accuracy_mask', every = 'improvement', mode='max', name = 'best_model_ft' )

In [None]:
learn.unfreeze()
learn.fit_one_cycle(10,max_lr= slice(1e-5,1e-3/2),callbacks = [callbacks2] )
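
`max_lr=slice(1e-5, 1e-3/2)` gives each layer group its own maximum learning rate. The earliest (pretrained encoder) layers get the smallest value and the last group gets the largest, with the values in between spread geometrically. The sketch below assumes three layer groups; check `len(learn.layer_groups)` for the actual count. The helper name `geometric_lrs` is hypothetical.

```python
import numpy as np

def geometric_lrs(start, stop, n_groups):
    """n_groups learning rates spaced evenly on a log scale from start to stop."""
    step = (stop / start) ** (1 / (n_groups - 1))
    return np.array([start * step**i for i in range(n_groups)])

lrs = geometric_lrs(1e-5, 1e-3 / 2, n_groups=3)
print(lrs)   # approximately [1e-05, 7.07e-05, 5e-04]
```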

## Results
* The initial dynamic U-Net on top of an encoder (resnet34, pretrained on ImageNet), trained for 30 epochs, gave a pixel **accuracy** of **80.00%**.

## Checking the results of the trained model

In [None]:
learn.show_results(rows = 4, figsize=(16,18))

In [None]:
learn.save('stage-1-big')  # saving the model 

## Export the model

In [None]:
learn.export('/kaggle/working/drone_mask.pkl')

### Predict with the trained model

* **Function to make a prediction and overlay the predicted mask on the drone image**
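
The overlay shown in the middle panel is an alpha blend between the photo and a colour-coded version of the mask. The NumPy sketch below shows that blend with a hypothetical `overlay_mask` helper. It assumes a random palette (one RGB colour per class ID) and uses `alpha` as the mask opacity.

```python
import numpy as np

def overlay_mask(img, mask, n_classes, alpha=0.5, seed=0):
    """Blend an RGB image (H x W x 3, floats in [0, 1]) with a class-ID mask (H x W)."""
    rng = np.random.default_rng(seed)
    palette = rng.random((n_classes, 3))   # one colour per class ID
    colour_mask = palette[mask]            # H x W x 3 via fancy indexing
    return (1 - alpha) * img + alpha * colour_mask

img = np.full((2, 2, 3), 0.5)              # flat grey dummy image
mask = np.array([[0, 1],
                 [2, 1]])
blended = overlay_mask(img, mask, n_classes=3, alpha=0.6)
print(blended.shape)   # (2, 2, 3)
```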

In [None]:
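def drone_predict(f):
    img = open_image(f).resize((3,200,300))             # (channels, height, width) = training size src//2
    mask = learn.predict(img)[0]                         # predicted segmentation mask
    _,axs = plt.subplots(1,3, figsize=(24,10))
    img.show(ax=axs[0], title='no mask')                 # raw image
    img.show(ax=axs[1], y=mask, title='masked')          # image with the predicted mask overlaid
    mask.show(ax=axs[2], title='mask only', alpha=1.)    # predicted mask alone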

## Prediction

In [None]:
for i in range(3):
    n = random.randint(20,200)
    drone_predict(fnames[n])