# Preface

This notebook is an implementation of "SegNet" for the course "02456 Deep Learning" at DTU.
The project was carried out in cooperation with Cellari under the guidance of Prof. Ole (DTU Compute) and Peter (Cellari).
The main purpose is to segment crops and weeds in drone images.

**Notice:**
1. Change the dataset paths to your own!
2. Change the path where the trained model is saved to your own!

# Dataset file path

In [None]:
# Directories holding the training images: `raw` contains the input images,
# `anno` the matching annotation images. Adjust to your own machine.
path_train_raw = '/home/renping/02456 DeepLearning Project/train/cropped/raw'
path_train_anno = '/home/renping/02456 DeepLearning Project/train/cropped/anno'

# Directories holding the test images, laid out the same way (raw / anno).
path_test_raw = '/home/renping/02456 DeepLearning Project/test/cropped/raw'
path_test_anno = '/home/renping/02456 DeepLearning Project/test/cropped/anno'

# Batch Generator for Training

## parameter in train batch generator

In [None]:
# Batch size passed to the training DataLoader.
batches = 10
# Side length (in pixels) of the square crop produced by the augmentation pipeline.
crop_size = 256
# Input / annotation directories used to build the training dataset.
path_train = path_train_raw
path_anno = path_train_anno

In [None]:
import torch
from torch.utils import data
from torch.utils.data import DataLoader
from torchvision import transforms as T
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import imageio
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from time import time

# Functions to convert rgb segmaps to 2d

# Convert rgb array to grayscale
def rgb2gray(rgb):
    """Return the weighted grey level of an RGB(A) array.

    Only the first three entries of the last axis are used, so an alpha
    channel (if present) is ignored. The result is a float array with the
    last axis removed.
    """
    channel_weights = [0.2989, 0.5870, 0.1140]
    colour_channels = rgb[..., :3]
    return np.dot(colour_channels, channel_weights)

# Grey level -> class index for the original annotation images.
# Keys are the integer-truncated outputs of rgb2gray() that create_anno()
# looks for; values are the class indices 0..3.
mapping = {
    0: 0,
    76: 1,
    149: 2,
    225: 3
}

# Grey level -> class index for segmentation maps rendered by imgaug.
# Keys are the integer-truncated outputs of rgb2gray() that seg_to_anno()
# looks for; values are the class indices 0..3.
mapping2 = {
    0: 0,
    91: 1,
    132: 2,
    211: 3
}

# Convert original segmap to 2d with classes
def create_anno(anno):
    """Turn an RGB annotation image into a 2-D array of class indices.

    The image is reduced to integer grey levels with ``rgb2gray`` and every
    grey level listed in ``mapping`` is replaced by its class index. Grey
    levels that are not keys of ``mapping`` are left unchanged.
    """
    class_map = rgb2gray(anno).astype(int)
    for grey_level, class_id in mapping.items():
        class_map[class_map == grey_level] = class_id
    return class_map

# Convert imgaug segmap to 2d with classes
def seg_to_anno(seg):
    """Turn a rendered (RGB) imgaug segmentation map into class indices.

    The image is reduced to integer grey levels with ``rgb2gray`` and every
    grey level listed in ``mapping2`` is replaced by its class index. Grey
    levels that are not keys of ``mapping2`` are left unchanged.
    """
    class_map = rgb2gray(seg).astype(int)
    for grey_level, class_id in mapping2.items():
        class_map[class_map == grey_level] = class_id
    return class_map

# ============================================================================
# Image Augmentors
# Augmentation pipeline applied jointly to an image and its segmentation map
# in the training dataset's __getitem__. Steps run in the listed order
# (random_order=False): horizontal flip, rotation drawn from [-180, 180]
# degrees, then a crop to crop_size x crop_size.
seq = iaa.Sequential([
    iaa.HorizontalFlip(0.5),
    iaa.Affine(rotate=(-180, 180)),
    # Disabled augmenters, kept for experimentation:
    # iaa.Dropout(p=(0, 0.1)),
    # iaa.Sharpen((0.0, 1.0)),
    # iaa.ElasticTransformation(alpha=50, sigma=5),
    iaa.CropToFixedSize(width=crop_size, height=crop_size)
], random_order=False)

# Class to load + process data
class Custom_Data(data.Dataset):
    """Training dataset yielding augmented (image, class map) pairs.

    Raw and annotation files are paired by their sorted file names. Only
    pairs whose annotation contains at least ``min_classes`` distinct values
    and fewer than ``max_background_pixels`` pixels of class 0 are kept.
    Every kept pair is listed twice, so each epoch yields two independently
    augmented samples per source image.

    Args:
        path_train: directory containing the raw input images.
        path_anno: directory containing the annotation images.
        min_classes: minimum number of distinct class values an annotation
            must contain to be kept.
        max_background_pixels: an annotation is kept only if it has strictly
            fewer class-0 pixels than this.

    Raises:
        ValueError: if the two directories hold a different number of files,
            since pairing by sorted position would then be misaligned.
    """

    def __init__(self, path_train, path_anno, min_classes=3,
                 max_background_pixels=512 * 512 / 2):
        raw_imgs = [os.path.join(path_train, name)
                    for name in sorted(os.listdir(path_train))]
        anno_imgs = [os.path.join(path_anno, name)
                     for name in sorted(os.listdir(path_anno))]
        if len(raw_imgs) != len(anno_imgs):
            raise ValueError(
                "Number of raw images ({}) and annotations ({}) differ.".format(
                    len(raw_imgs), len(anno_imgs)))

        # Plain lists instead of repeated np.append, which copies the whole
        # array on every call.
        raw_new, anno_new = [], []
        for raw_path, anno_path in zip(raw_imgs, anno_imgs):
            with Image.open(anno_path) as anno_file:
                segmap = create_anno(np.array(anno_file))
            enough_classes = len(np.unique(segmap)) >= min_classes
            background_ok = np.count_nonzero(segmap == 0) < max_background_pixels
            if enough_classes and background_ok:
                # Each kept sample is added twice on purpose (two augmented
                # versions per epoch).
                raw_new.extend([raw_path, raw_path])
                anno_new.extend([anno_path, anno_path])

        self.raw_img = raw_new
        self.anno_img = anno_new

    def __getitem__(self, index, plots=False):
        """Return one augmented ``(image, class_map)`` pair as numpy arrays.

        Args:
            index: position in the filtered (and duplicated) sample list.
            plots: when true, show the original and augmented image/labels.
        """
        with Image.open(self.raw_img[index]) as raw_file:
            raw_img = np.array(raw_file)
        with Image.open(self.anno_img[index]) as anno_file:
            anno_img = create_anno(np.array(anno_file)).astype('int32')

        seg_map = SegmentationMapsOnImage(anno_img, shape=anno_img.shape)

        # The same random transform is applied to the image and its labels.
        raw_aug, seg_aug = seq(image=raw_img, segmentation_maps=seg_map)

        # Read the class indices straight from the augmented map instead of
        # rendering it to RGB and mapping grey levels back to classes.
        # astype() also returns a fresh, contiguous integer array.
        anno_aug = seg_aug.get_arr().astype(int)

        if plots:
            panels = (raw_img, anno_img, raw_aug, anno_aug)
            plt.figure()
            for position, panel in enumerate(panels, start=1):
                plt.subplot(2, 2, position)
                plt.imshow(panel)
            plt.show()

        return (raw_aug, anno_aug)

    def __len__(self):
        """Return the number of samples (twice the number of kept images)."""
        return len(self.raw_img)



# Build the training dataset and wrap it in a shuffling DataLoader that
# yields batches of `batches` samples.
train_data = Custom_Data(path_train, path_anno)
train_loader = DataLoader(train_data, batch_size=batches, shuffle=True)

## training dataloader shape check

In [None]:
# Pull a single batch to verify the tensor shapes the training loader yields.
train_iter = iter(train_loader)
# Use the built-in next(): iterator objects are not guaranteed to expose a
# .next() method, and current DataLoader iterators do not.
train_input, train_target = next(train_iter)
print("train_input size = {}.".format(train_input.size()))
print("train_target size = {}.".format(train_target.size()))

# Batch Generator for Test

## parameter in test batch generator

In [None]:
# Input / annotation directories used to build the test dataset.
path_test_input = path_test_raw
path_test_target = path_test_anno

# Batch size passed to the test DataLoader.
num_batch = 7

In [None]:
class Custom_Data(data.Dataset):
    """Test dataset yielding un-augmented images with their class maps.

    Note that this class re-uses the name ``Custom_Data`` of the training
    dataset defined earlier in the file; after this definition the name
    refers to the test dataset.

    Input and target files are paired by their sorted paths.

    Args:
        path_test_input: directory containing the raw test images.
        path_test_target: directory containing the annotation images.
    """

    def __init__(self, path_test_input, path_test_target):
        self.input = sorted(
            os.path.join(path_test_input, name)
            for name in os.listdir(path_test_input))
        self.target = sorted(
            os.path.join(path_test_target, name)
            for name in os.listdir(path_test_target))

        # Grey level (after conversion to mode "L") -> class index.
        # NOTE(review): the grey levels produced by PIL's convert("L") may
        # depend on the installed Pillow version's rounding - confirm these
        # keys match the values actually present in the annotations.
        self.mapping = {
            0:0,
            76:1,
            149:2,
            225:3
        }

    def mapping_to_class(self, target):
        """Return a copy of ``target`` with grey levels replaced by class ids.

        The argument is left untouched so the caller can keep the original
        grey-level image next to the mapped one. Values that are not keys of
        ``self.mapping`` are copied unchanged.
        """
        mapped = target.clone() if torch.is_tensor(target) else np.array(target)
        for grey_level, class_id in self.mapping.items():
            mapped[target == grey_level] = class_id
        return mapped

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Returns:
            Tuple ``(test_input, test_target, test_target_RGB,
            test_target_grey, test_input_path, test_target_path)``:
            the input image as produced by ``T.ToTensor()``, the class-index
            map, the annotation image as produced by ``T.ToTensor()``, the
            annotation's grey-level image, and the two file paths.
        """
        # test input data
        test_input_path = self.input[index]
        with Image.open(test_input_path) as input_img:
            test_input = T.ToTensor()(input_img)

        # test target data
        test_target_path = self.target[index]
        with Image.open(test_target_path) as target_img:
            test_target_RGB = T.ToTensor()(target_img)

        # Grey-level version of the annotation, then its class-index map.
        grey_img = T.ToPILImage()(test_target_RGB).convert("L")
        test_target_grey = torch.from_numpy(np.array(grey_img))
        # mapping_to_class copies, so test_target_grey keeps the grey levels.
        test_target = self.mapping_to_class(test_target_grey)

        return test_input, test_target, test_target_RGB, test_target_grey, test_input_path, test_target_path

    def __len__(self):
        """Return the number of test images."""
        return len(self.input)

# Build the test dataset and wrap it in a DataLoader yielding batches of
# `num_batch` samples (the order is shuffled).
test_data = Custom_Data(path_test_input, path_test_target)
test_loader = DataLoader(test_data, batch_size=num_batch, shuffle=True)

## test dataloader shape check

In [None]:
# Pull a single batch to verify the tensor shapes the test loader yields.
test_iter = iter(test_loader)
# Use the built-in next(): iterator objects are not guaranteed to expose a
# .next() method, and current DataLoader iterators do not.
test_input, test_target, test_target_RGB, test_target_grey, test_input_path, test_target_path = next(test_iter)
print("test_input size = {}.".format(test_input.size()))
print("test_target size = {}.".format(test_target.size()))
print("test_target_grey size = {}.".format(test_target_grey.size()))

# Network Definition

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class Encode_conv_bn_x2(nn.Module):
    """Encoder block: two (3x3 conv -> batch norm -> ReLU) stages.

    Args:
        in_: number of input channels.
        out: number of output channels of both convolutions.
    """

    def __init__(self, in_, out):
        super().__init__()
        bn_momentum = 0.1
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_, out, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out, momentum=bn_momentum)
        self.conv2 = nn.Conv2d(out, out, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out, momentum=bn_momentum)

    def forward(self, x):
        """Apply both stages in order; spatial size is preserved (padding=1)."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x
    
class Encode_conv_bn_x3(nn.Module):
    """Encoder block: three (3x3 conv -> batch norm -> ReLU) stages.

    Args:
        in_: number of input channels.
        out: number of output channels of all three convolutions.
    """

    def __init__(self, in_, out):
        super().__init__()
        bn_momentum = 0.1
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_, out, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out, momentum=bn_momentum)
        self.conv2 = nn.Conv2d(out, out, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out, momentum=bn_momentum)
        self.conv3 = nn.Conv2d(out, out, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(out, momentum=bn_momentum)

    def forward(self, x):
        """Apply the three stages in order; spatial size is preserved."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x

class Dencode_conv_bn_x1(nn.Module):
    """Final decoder block: conv -> batch norm -> ReLU, then a plain conv.

    The last convolution is followed by neither batch norm nor ReLU, so the
    block returns raw (unbounded) per-channel scores.

    Args:
        in_: number of input channels (kept by the first convolution).
        out: number of output channels of the last convolution.
    """

    def __init__(self, in_, out):
        super().__init__()
        bn_momentum = 0.1
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_, in_, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_, momentum=bn_momentum)
        self.conv2 = nn.Conv2d(in_, out, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the un-activated output of the second convolution."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.conv2(hidden)
    
class Dencode_conv_bn_x2(nn.Module):
    """Decoder block: two (3x3 conv -> batch norm -> ReLU) stages.

    The first stage keeps ``in_`` channels; the second maps to ``out``.

    Args:
        in_: number of input channels.
        out: number of output channels.
    """

    def __init__(self, in_, out):
        super().__init__()
        bn_momentum = 0.1
        self.relu = nn.ReLU(inplace=True)

        # Stage 1: in_ -> in_
        self.conv1 = nn.Conv2d(in_, in_, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_, momentum=bn_momentum)

        # Stage 2: in_ -> out
        self.conv2 = nn.Conv2d(in_, out, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out, momentum=bn_momentum)

    def forward(self, x):
        """Apply both stages in order; spatial size is preserved."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x

class Dencode_conv_bn_x3(nn.Module):
    """Decoder block: three (3x3 conv -> batch norm -> ReLU) stages.

    The first two stages keep ``in_`` channels; the third maps to ``out``.

    Args:
        in_: number of input channels.
        out: number of output channels.
    """

    def __init__(self, in_, out):
        super().__init__()
        bn_momentum = 0.1
        self.relu = nn.ReLU(inplace=True)

        # Stage 1: in_ -> in_
        self.conv1 = nn.Conv2d(in_, in_, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_, momentum=bn_momentum)

        # Stage 2: in_ -> in_
        self.conv2 = nn.Conv2d(in_, in_, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(in_, momentum=bn_momentum)

        # Stage 3: in_ -> out
        self.conv3 = nn.Conv2d(in_, out, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(out, momentum=bn_momentum)

    def forward(self, x):
        """Apply the three stages in order; spatial size is preserved."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = self.relu(norm(conv(x)))
        return x

class SegNet(nn.Module):
    """SegNet-style encoder/decoder for per-pixel classification.

    Five encoder blocks, each followed by 2x2 max pooling that records the
    pooling indices, are mirrored by five decoder blocks, each preceded by
    max unpooling with the indices of the matching encoder stage.

    Sub-module names (``encode1`` ... ``dencode1``) are part of the saved
    ``state_dict`` and must not be renamed.
    """

    def __init__(self,input_nbr = 3,label_nbr = 4):
        '''
        input_nbr: the number of channels of input image;
        label_nbr: the number of classes need to be segmented
        '''
        super(SegNet, self).__init__()

        self.encode1 = Encode_conv_bn_x2(input_nbr, 64)
        self.encode2 = Encode_conv_bn_x2(64, 128)
        self.encode3 = Encode_conv_bn_x3(128, 256)
        self.encode4 = Encode_conv_bn_x3(256, 512)
        self.encode5 = Encode_conv_bn_x3(512, 512)

        self.dencode5 = Dencode_conv_bn_x3(512, 512)
        self.dencode4 = Dencode_conv_bn_x3(512, 256)
        self.dencode3 = Dencode_conv_bn_x2(256, 128)
        self.dencode2 = Dencode_conv_bn_x2(128, 64)
        self.dencode1 = Dencode_conv_bn_x1(64, label_nbr)

    def forward(self, x):
        """Return raw class scores of shape (N, label_nbr, H, W).

        Intermediate activations and pooling indices are kept in local
        variables (not on ``self``) so they are released after the call.
        The encoder/decoder blocks already end with a ReLU, so no extra
        activation is applied here (ReLU of a ReLU output is a no-op).
        Passing ``output_size`` to the unpooling makes each decoder stage
        restore exactly the size of its encoder counterpart, which also
        covers inputs whose height/width is not a multiple of 32.
        """
        # Encoder: stage 1-5
        x1 = self.encode1(x)
        x1p, id1 = F.max_pool2d(x1, kernel_size=2, stride=2, return_indices=True)

        x2 = self.encode2(x1p)
        x2p, id2 = F.max_pool2d(x2, kernel_size=2, stride=2, return_indices=True)

        x3 = self.encode3(x2p)
        x3p, id3 = F.max_pool2d(x3, kernel_size=2, stride=2, return_indices=True)

        x4 = self.encode4(x3p)
        x4p, id4 = F.max_pool2d(x4, kernel_size=2, stride=2, return_indices=True)

        x5 = self.encode5(x4p)
        x5p, id5 = F.max_pool2d(x5, kernel_size=2, stride=2, return_indices=True)

        # Decoder: stage 5d-1d
        d5 = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2, output_size=x5.size())
        d5 = self.dencode5(d5)

        d4 = F.max_unpool2d(d5, id4, kernel_size=2, stride=2, output_size=x4.size())
        d4 = self.dencode4(d4)

        d3 = F.max_unpool2d(d4, id3, kernel_size=2, stride=2, output_size=x3.size())
        d3 = self.dencode3(d3)

        d2 = F.max_unpool2d(d3, id2, kernel_size=2, stride=2, output_size=x2.size())
        d2 = self.dencode2(d2)

        d1 = F.max_unpool2d(d2, id1, kernel_size=2, stride=2, output_size=x1.size())
        # The last block returns un-activated scores (logits).
        return self.dencode1(d1)

# Instantiate the network with its default arguments (3 input channels, 4 classes).
net = SegNet()

## network summary

In [None]:
from torchsummary import summary
# Only move the network to the GPU when one is available, so the summary
# also works on CPU-only machines (the training cell makes the same check).
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
net = net.to(summary_device)
summary(net, (3, 256, 256), batch_size=5, device=summary_device)

# Training and Testing

## parameter in train loop

In [None]:
# Number of passes over the training set.
num_epoch = 100
# Learning rate handed to the Adam optimizer.
learning_rate=0.0001
# File the trained weights (state_dict) are written to and later reloaded from.
model_path = '/home/renping/02456 DeepLearning Project/net_trained/net_trained.pt'

In [None]:
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


# mapping = {
#     0: 0, # no data
#     76: 1, # soil
#     149: 2, # crops
#     225: 3 # weeds
# }


# Set up the network, the loss and the optimizer; everything is moved to the
# GPU when CUDA is available.
use_cuda = torch.cuda.is_available()
# use_cuda = False
print("Running on GPU!" if use_cuda else "Running on CPU!")

criterion = nn.CrossEntropyLoss()
if use_cuda:
    net = net.cuda()
    criterion = criterion.cuda()

optimizer = optim.Adam(net.parameters(), lr=learning_rate)


print("Start training!")
print("Total training epoch: {}.".format(num_epoch))

# Per-epoch mean losses; the accuracy lists are kept for later use.
total_loss = []
total_accuracy = []

total_loss_test = []
total_accuracy_test = []

# Average over the real number of batches instead of hard-coded counts, so
# the reported losses stay correct when the dataset or batch size changes.
num_train_batches = max(len(train_loader), 1)
num_test_batches = max(len(test_loader), 1)

for epoch in tqdm(range(0, num_epoch)):

    # ---- training pass -------------------------------------------------
    train_epoch_loss = 0.0
    net.train()
    # Unpack into fresh names so the `train_data` dataset is not overwritten.
    for inputs, targets in train_loader:
        # Batches arrive channels-last (N, H, W, C); the network expects
        # channels-first (N, C, H, W).
        inputs = inputs.permute(0, 3, 1, 2).type(torch.FloatTensor)
        targets = targets.type(torch.LongTensor)
        if use_cuda:
            inputs = inputs.cuda()
            targets = targets.cuda()

        optimizer.zero_grad()
        # Call the module (not .forward) so registered hooks are honoured.
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        train_epoch_loss = train_epoch_loss + loss.item()

    mean_train_loss = train_epoch_loss / num_train_batches
    total_loss.append(mean_train_loss)

    # ---- evaluation pass -----------------------------------------------
    test_epoch_loss = 0.0
    net.eval()
    # Unpack into fresh names so the `test_data` dataset is not overwritten.
    for test_batch in test_loader:
        test_input, test_target = test_batch[0], test_batch[1]
        # The test dataset builds its inputs with ToTensor, i.e. scaled to
        # [0, 1], whereas the training inputs above are raw 0-255 pixel
        # values. Rescale so both losses are measured on the same input range.
        test_input = test_input.type(torch.FloatTensor) * 255.0
        test_target = test_target.type(torch.LongTensor)
        if use_cuda:
            test_input = test_input.cuda()
            test_target = test_target.cuda()
        with torch.no_grad():
            test_output = net(test_input)
            loss_test = criterion(test_output, test_target)

        test_epoch_loss = test_epoch_loss + loss_test.item()

    mean_test_loss = test_epoch_loss / num_test_batches
    total_loss_test.append(mean_test_loss)

    # print train epoch loss every epoch
    print("epoch {}/{}, train loss {}".format(epoch+1, num_epoch, mean_train_loss))

    # print test epoch loss every epoch
    print("epoch {}/{}, test loss {}".format(epoch+1, num_epoch, mean_test_loss))


# Make sure the target directory exists so a full training run is not lost
# to a missing folder at save time.
model_dir = os.path.dirname(model_path)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)
torch.save(net.state_dict(), model_path)

print("Trained Model Saved!")
print("Training Finished!")


# Plot the per-epoch training and test loss curves on one figure.
plt.figure()
plt.title("loss")
plt.plot(total_loss, label="train loss")
plt.plot(total_loss_test, label= "test loss")
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend(loc=1)
plt.grid()


# Save before show(): the figure is written to a file named "loss" (plus the
# default image extension) in the working directory.
plt.savefig("loss")
plt.show()
print("loss figure saved!")

# Model Evaluation

In [None]:
import random
import glob
import matplotlib.patches as mpatches
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
import seaborn as sbn


# mapping = {
#     0: 0, # no data
#     76: 1, # soil
#     149: 2, # crops
#     225: 3 # weeds
# }

# Pick the device first so the checkpoint can be loaded on CPU-only machines
# too (map_location remaps tensors that were saved from the GPU).
use_cuda = torch.cuda.is_available()
eval_device = torch.device("cuda" if use_cuda else "cpu")

net.load_state_dict(torch.load(model_path, map_location=eval_device))
net = net.to(eval_device)
net.eval()
if use_cuda:
    print("Running on GPU!")

path_to_images = path_test_raw
path_to_annotations = path_test_anno

# Sorted so images and annotations line up by position.
img_list = sorted(glob.glob(path_to_images+os.sep+"*.png"))
anno_list = sorted(glob.glob(path_to_annotations+os.sep+"*.png"))


# Predict a class map for every test image, in sorted file order.
preds = []
with torch.no_grad():  # inference only: no autograd graph is needed
    for img_idx, img_path in enumerate(img_list):
        raw_img = np.array(Image.open(img_path))
        show_anno = np.array(Image.open(anno_list[img_idx]).convert("L"))

        show_raw = raw_img

        # (H, W, C) pixel values -> float tensor of shape (1, C, H, W) on the
        # same device as the network.
        img = torch.from_numpy(raw_img).type(torch.FloatTensor).to(eval_device)
        img = img.permute(2, 0, 1).unsqueeze(0)

        # Drop the batch axis, leaving (num_classes, H, W).
        pred = net(img).squeeze(0)

        # Most likely class per pixel.
        pred_img = torch.argmax(pred, dim=0).cpu().numpy()
        preds.append(pred_img)

# Utilities
def color_mapping(img):
    """Replace grey levels by class indices, in place.

    The grey levels 0, 76, 149 and 225 are rewritten to 0, 1, 2 and 3
    respectively; any other value is left as it is. ``img`` itself is
    modified and also returned.
    """
    grey_to_class = (
        (0, 0),    # no data
        (76, 1),   # soil
        (149, 2),  # crops
        (225, 3),  # weeds
    )
    for grey_level, class_id in grey_to_class:
        img[img == grey_level] = class_id
    return img

# Load every test annotation as a grey-level image, convert it to class
# indices and stack the results into one array.
annotations_path = path_test_anno
annotations_list = glob.glob(annotations_path+os.sep+'*.png')
annotations_list = sorted(annotations_list)

annos = np.array([
    color_mapping(np.array(Image.open(anno_file).convert('L')))
    for anno_file in annotations_list
])
print('Annotations shape: {}'.format(annos.shape))


# Class names, listed in class-index order (index 0..3).
classes = ['No data', 'Soil', 'Crops', 'Weeds']
# One RGBA colour per class, in the same order as `classes`; used for the
# figure legend.
colors = [(1.0, 1.0, 0.8980392156862745, 1.0), # no data
         (0.7359477124183007, 0.8915032679738563, 0.5843137254901961, 1.0), # soil
         (0.21568627450980393, 0.6196078431372549, 0.330718954248366, 1.0), # crops
         (0.0, 0.27058823529411763, 0.1607843137254902, 1.0) # weeds
         ]
# Legend handles: one coloured patch per class, labelled with the class name.
patches = [mpatches.Patch(color=colors[i], label="{l}".format(l=classes[i])) for i in range(len(classes))]


# 2x2 grid: left column shows annotations, right column the matching
# predictions; one row per example.
fig, ax = plt.subplots(2, 2, figsize=(10, 10))

# Indices (into annos / preds) of the two examples to display.
ex1 = 16
ex2 = 10

# All panels share the YlGn colour map with a fixed 0..3 value range so the
# same class gets the same colour everywhere.

# Anno ex 1
ax[0][0].imshow(annos[ex1], cmap=plt.cm.YlGn, vmin = 0, vmax = 3)
ax[0][0].set_title('Annotation example 1')

# Anno ex 2
ax[1][0].imshow(annos[ex2], cmap=plt.cm.YlGn, vmin = 0, vmax = 3)
ax[1][0].set_title('Annotation example 2')

# Pred ex 1
ax[0][1].imshow(preds[ex1], cmap=plt.cm.YlGn, vmin = 0, vmax = 3)
ax[0][1].set_title('Prediction example 1')

# Pred ex 2
ax[1][1].imshow(preds[ex2], cmap=plt.cm.YlGn, vmin = 0, vmax = 3)
ax[1][1].set_title('Prediction example 2')

fig.legend(handles=patches, loc='right')
# plt.show()

preds = np.array(preds)

# Flatten all images into one long vector of per-pixel labels.
cm_preds = preds.reshape(-1)
cm_annos = annos.reshape(-1)
from sklearn.utils.class_weight import compute_sample_weight
# NOTE(review): the balancing weights are derived from the *predicted*
# labels, which equalises the total weight per predicted class - confirm
# this is intended rather than balancing by the annotated class.
weights = compute_sample_weight(class_weight='balanced', y=cm_preds)
# Fix the label set so the matrix always has one row/column per entry of
# `classes`, even when a class is absent from the data.
cm = confusion_matrix(cm_annos, cm_preds,
                      labels=list(range(len(classes))),
                      sample_weight=weights)


plt.figure()
sbn.heatmap(cm, xticklabels=classes, yticklabels=classes) # annot=True prints the actual numbers in the heatmap matrix
# confusion_matrix(y_true, y_pred) puts the true labels on the rows and the
# predicted labels on the columns; the heatmap draws rows along y and
# columns along x.
plt.xlabel('Prediction')
plt.ylabel('Annotation')

plt.savefig("confusion matrix")
plt.show()
print("confusion matrix figure saved!!")

# Mean F1 score over every (image, class) pair, skipping class 0 ("no data").
# Each class present in an annotation is scored as a one-vs-rest problem.
f1 = 0
n_scored = 0
for idx in range(len(annos)):
    # `class_id` (not `classes`) so the list of class names is not overwritten.
    for class_id in np.unique(annos[idx]):
        if class_id == 0:
            continue
        # Scoring happens inside the class loop so every class present in
        # the annotation contributes, not only the last one.
        y_pred = (preds[idx] == class_id) * 1
        y_true = (annos[idx] == class_id) * 1
        f1 += f1_score(y_true.reshape(-1), y_pred.reshape(-1), average='macro')
        n_scored += 1

# Guard against an empty evaluation set (or annotations with class 0 only).
f1 /= max(n_scored, 1)
print("F1 score: {}".format(f1))