<a href="https://colab.research.google.com/github/sarayaghoubi/sara/blob/master/ImageSegmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Mount Google Drive (the dataset, checkpoints and results are stored there).

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the data, checkpoints and results used
# below are read from / written to paths under /content/drive/MyDrive.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


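Download the data from Zenodo (the commands are commented out; uncomment the pair for any split that is not on Drive yet).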

In [None]:
# Each split is downloaded into data/<Split>/ and then extracted.
# NOTE(review): the unzip targets differ (Train/Val extract into their own folder,
# Test into data/); check that the result matches the
# data/<Split>/<Rural|Urban>/{images_png,masks_png} layout used by the config cell.
# !wget https://zenodo.org/record/5706578/files/Train.zip -P /content/drive/MyDrive/data/Train
# !unzip /content/drive/MyDrive/data/Train/Train.zip -d /content/drive/MyDrive/data/Train/

#!wget https://zenodo.org/record/5706578/files/Val.zip -P /content/drive/MyDrive/data/Val
#!unzip /content/drive/MyDrive/data/Val/Val.zip -d /content/drive/MyDrive/data/Val/

# wget above saves Test.zip into data/Test, so unzip must read it from there.
# !wget https://zenodo.org/record/5706578/files/Test.zip -P /content/drive/MyDrive/data/Test
# !unzip /content/drive/MyDrive/data/Test/Test.zip -d /content/drive/MyDrive/data

Install libraries and dependencies (PyTorch, mmcv-full, MMSegmentation).

In [None]:
# PyTorch 1.9.0 / torchvision 0.10.0 built for CUDA 11.1; the mmcv-full wheel index
# below is the one published for that same torch/CUDA combination.
!pip install -U torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
# Install MMSegmentation from source in editable mode. %cd also changes the
# notebook's working directory to the clone.
# NOTE(review): the code below uses the MMSegmentation 0.x API (CustomDataset,
# train_segmentor, mmcv-full); if the default branch of the clone is a newer major
# version, check out a 0.x tag before installing.
!git clone https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
!pip install -e .
import mmseg
# NOTE(review): this global ``address`` is not read anywhere in the visible notebook.
address = '/content/drive'

Data handling: define the dataset helper class and, if needed, resize the images to a common size.

In [None]:
# Last update was on  1 May
import os.path as osp
import numpy as np
from PIL import Image
import os
import cv2
import mmcv
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset


def annotate(img, label):
    """
    Convert a colour segmentation map into a map of class ids.

    Call this when the segmentation maps must be stored as single-channel
    (class-id) images rather than colour images.

    :param img: H x W x C array holding the colour segmentation map; each pixel is
        compared with every colour in ``label`` along the last axis.
    :param label: sequence of colours; the colour at index ``i`` becomes class id ``i``.
    :return: float array of shape ``img.shape[:2]`` holding each pixel's class id.
        Pixels that match no colour in ``label`` stay 0 (the first class).
    """
    annotation = np.zeros(img.shape[:2])
    for class_id, colour in enumerate(label):
        # A pixel belongs to the class only when *all* of its channels match the
        # colour; comparing channel-by-channel with np.where would also pick up
        # pixels that merely share one channel value (e.g. white vs. red).
        mask = np.all(img == np.asarray(colour), axis=-1)
        annotation[mask] = class_id
    return annotation


class Data:
    """
    Describe the custom aerial dataset and prepare it for MMSegmentation.

    The original images come in different sizes, so they can optionally be resized
    to one common shape. :meth:`reconstruct` also writes the split files read by
    ``CustomDataset`` and can register the ``AerialDataset`` dataset type.

    Arguments:
        data_config: dict with the keys ``test_address``, ``validation_address``
            (``(image_dir, annotation_dir)``), ``train_image`` and
            ``train_annotation`` (each ``(extension, directory)``).
            Defaults to the notebook-level ``data_spc``.
        data_root: root folder of the dataset. Defaults to the notebook-level ``root``.
    """

    def __init__(self, data_config=None, data_root=None):
        # Fall back to the notebook-level globals so that ``Data()`` keeps working.
        if data_root is None:
            data_root = root
        if data_config is None:
            data_config = data_spc
        self.std = None
        self.mean = None
        self.n_classes = None
        self.classes = []
        self.palette = None
        self.data_root = data_root
        self.test_root = data_config['test_address']
        self.validation_root, self.val_annotation = data_config['validation_address']
        self.img_postfix, self.img_dir = data_config['train_image']
        self.ann_postfix, self.ann_dir = data_config['train_annotation']
        self.data_type = 'AerialDataset'

    def process_data(self, new_folder, size, require_resize):
        """
        Optionally resize the images, then set the palette, the class names, the
        normalisation statistics and the number of classes.

        Arguments:
            new_folder: folder (relative to ``data_root``) that receives the resized
                images; ``img_dir`` is switched to it after resizing.
            size: target ``(width, height)``; should match the backbone input size.
            require_resize: True to resize the images in ``img_dir``; pass False
                when the images were already resized and prepared.
        """
        if require_resize:
            self.resize(new_folder, size)
            self.img_dir = new_folder
        self.palette, self.classes = self.spc()
        self.mean, self.std = self.statistics()
        self.n_classes = len(self.palette)

    @staticmethod
    def spc():
        """Return ``(palette, class_names)``: one RGB colour per class, in class-id order."""
        color = [[255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
                 [159, 129, 183], [0, 255, 0], [255, 195, 128]]
        classes = ('background', 'building', 'road', 'water', 'barren', 'forest',
                   'agricultural')
        return color, classes

    @staticmethod
    def statistics():
        """
        Return the per-channel ``(mean, std)`` used for image normalisation in the config.

        These are fixed constants, not statistics computed from this dataset (the
        images have different shapes); they match the ImageNet values used by the
        stock MMSegmentation configs.
        """
        return [123.675, 116.28, 103.53], [58.395, 57.12, 57.375]

    def resize(self, new_directory, size):
        """
        Resize every ``.jpg``/``.png`` file in ``img_dir`` and save it under ``new_directory``.

        ``.jpg`` files are saved as RGB images. ``.png`` files are treated as colour
        segmentation maps and saved as class-id ('P' mode) images via :func:`annotate`.
        Files with any other extension are skipped.

        NOTE(review): in this notebook's config the training *images* are also
        ``.png``, so they would be converted as masks; confirm before enabling resize.

        Arguments:
            new_directory: output folder relative to ``data_root``; created if missing.
            size: target ``(width, height)`` passed to ``cv2.resize``.
        """
        palette, _ = self.spc()
        src_dir = osp.join(self.data_root, self.img_dir)
        dst_dir = osp.join(self.data_root, new_directory)
        mmcv.mkdir_or_exist(dst_dir)
        for file in os.listdir(src_dir):
            is_image = file.endswith('jpg')
            is_mask = file.endswith('png')
            if not (is_image or is_mask):
                continue  # previously crashed on ``ndarray.save`` for other files
            pixels = np.array(Image.open(osp.join(src_dir, file)).convert('RGB'), dtype=np.uint8)
            # Nearest-neighbour keeps mask colours exact (it is also applied to photos).
            pixels = cv2.resize(pixels, size, interpolation=cv2.INTER_NEAREST)
            if is_image:
                out = Image.fromarray(pixels).convert('RGB')
            else:
                class_ids = annotate(pixels, palette).astype(np.uint8)
                out = Image.fromarray(class_ids).convert('P')
            out.save(osp.join(dst_dir, file))

    def reconstruct(self, data_config, ignore):
        """
        Prepare the data and write the split files used by ``AerialDataset``.

        Calls :meth:`process_data`, then writes ``splits/train.txt``,
        ``splits/test.txt`` and ``splits/val.txt`` under ``data_root``. Each file
        lists the stems of the ``.png`` files in the train-image, test and
        validation-image folders respectively.

        Arguments:
            data_config: dict providing ``new_directory``, ``size`` and ``resize``.
            ignore: when False, register ``AerialDataset`` in mmseg's ``DATASETS``.
        """
        self.process_data(data_config['new_directory'], data_config['size'], data_config['resize'])
        split_dir = 'splits'
        mmcv.mkdir_or_exist(osp.join(self.data_root, split_dir))
        write_file = {
            'train': self.img_dir,
            'test': self.test_root,
            'val': self.validation_root
        }
        for split_name, folder in write_file.items():
            filename_list = [osp.splitext(filename)[0] for filename in mmcv.scandir(
                osp.join(self.data_root, folder), suffix='.png')]
            with open(osp.join(self.data_root, split_dir, f'{split_name}.txt'), 'w') as f:
                f.writelines(line + '\n' for line in filename_list)
        classes = self.classes
        palette = self.palette
        if not ignore:
            # force=True: re-running this with ignore=False replaces the previous
            # registration instead of raising "already registered".
            @DATASETS.register_module(force=True)
            class AerialDataset(CustomDataset):
                CLASSES = classes
                PALETTE = palette

                def __init__(self, split, **kwargs):
                    super().__init__(img_suffix='.png', seg_map_suffix='.png',
                                     split=split, **kwargs)
                    assert osp.exists(self.img_dir) and self.split is not None


Configuration
Paths to the dataset and checkpoint files, plus the training parameters.

In [None]:
import os.path as pt


# ---- Locations --------------------------------------------------------------
root = '/content'
mydrive = root + '/drive/MyDrive'
output_address = mydrive + '/dlv3-res'
data_subfolders = mydrive + '/data'

# Dataset layout: <data_subfolders>/<split>/<branch>/<img | ann>
img, ann = 'images_png', 'masks_png'
train, test = 'Train', 'Test'
branch = 'Rural'

# ---- Checkpoints saved on Drive, keyed by model name ----------------------------
saved_checkpoints_deeplabv3plus = mydrive + '/deeplabv3plus.pth'
saved_checkpoints_vit = mydrive + '/Vit.pth'
checkpoint_files = dict(Vit=saved_checkpoints_vit, dlv3=saved_checkpoints_deeplabv3plus)

# ---- Dataset description consumed by Data ---------------------------------------
data_spc = dict(
    size=(1024, 1024),
    resize=False,
    # (file extension, folder) pairs
    train_annotation=('png', pt.join(data_subfolders, train, branch, ann)),
    train_image=('png', pt.join(data_subfolders, train, branch, img)),
    new_directory='',
    test_address=pt.join(data_subfolders, test, branch, img),
    # (image folder, annotation folder)
    # NOTE(review): validation currently reads the *training* folders, so validation
    # scores are measured on training data; point this at data/Val once it exists.
    validation_address=(pt.join(data_subfolders, train, branch, img),
                        pt.join(data_subfolders, train, branch, ann)),
)
# Paths used on a previous local machine:
# root/rt = '/home/aminre/Optics/optics-segmentation/dataset/splited', validation = 'Val'

# ---- MMSegmentation base configs per model ---------------------------------------
config_files = dict(
    Vit='/content/mmsegmentation/configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py',
    Swin='/content/mmsegmentation/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py',
    dlv3='/content/mmsegmentation/configs/deeplabv3plus/deeplabv3plus_r18b-d8_512x1024_80k_cityscapes.py',
)

# RGB palette and class names used for inference/plots; unlike Data.spc(), index 0
# is an extra 'null' class.
color = [
    [0, 0, 0], [255, 255, 255], [255, 0, 0], [255, 255, 0],
    [0, 0, 255], [159, 129, 183], [0, 255, 0], [255, 195, 128],
]
classes = ('null', 'background', 'building', 'road', 'water', 'barren',
           'forest', 'agricultural')

# ---- Training parameters ----------------------------------------------------------
train_spc = dict(
    scale=(1024, 1024),
    crop_size=(512, 512),  # earlier experiments used (128, 128)
    batch=6,
    # With EpochBasedRunner this is a number of epochs; MMCV's iteration-based
    # runner expects max_iters instead.
    max=500000,
    log_int=1000,
    eval=1000,
    checkpoint=1000,
    lr_rate=0.001,
    work_directory='',
    train_type='EpochBasedRunner',
    # Checkpoint to start from; set to None to train from scratch.
    pretrain=saved_checkpoints_deeplabv3plus,
    # Learning-rate multiplier for the decoder ('head') parameters.
    lr_mul=10,
    # Default loss is cross entropy; other options: 'DiceLoss', 'FocalLoss', 'LovaszLoss'.
    loss_a=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.5),
    loss_d=dict(type='DiceLoss', use_sigmoid=False, loss_weight=0.5),
)


In [None]:
from mmseg.apis import set_random_seed
from mmcv import Config
import mmcv
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor
from mmseg.apis import inference_segmentor,show_result_pyplot
import matplotlib.patches as m_patches
import os.path as osp
import numpy as np
from matplotlib import pyplot as plt
import os
import torch.distributed as dist
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)
from mmcv.cnn.utils import revert_sync_batchnorm
from mmseg.apis import init_segmentor


class MMSEGAerialAnalysis:
    """Drive MMSegmentation training and inference on the aerial dataset.

    Holds a ``Data`` description of the dataset, the mmcv ``Config`` produced by
    :meth:`prepare_config` and the segmentor produced by :meth:`init_model`.
    """

    def __init__(self, data_config):
        # NOTE(review): ``data_config`` is accepted but not used; ``Data()`` reads
        # the notebook-level configuration itself.
        self.cfg = None
        self.model = None
        self.data = Data()

    def init_data(self, *args):
        """Forward ``args`` (``data_config, ignore``) to ``Data.reconstruct``."""
        self.data.reconstruct(*args)

    @staticmethod
    def _auxiliary_heads(model_cfg):
        """Return the model's auxiliary head config(s) as a list (empty if it has none)."""
        aux = model_cfg.get('auxiliary_head')
        if aux is None:
            return []
        return aux if isinstance(aux, list) else [aux]

    @staticmethod
    def _runner_cfg(train):
        """Build the runner dict; iteration-based runners take ``max_iters``, others ``max_epochs``."""
        runner_type = train['train_type']
        limit_key = 'max_iters' if runner_type == 'IterBasedRunner' else 'max_epochs'
        return {'type': runner_type, limit_key: train['max']}

    @staticmethod
    def _configure_split(split_cfg, cfg, img_dir, ann_dir, pipeline, split_file):
        """Point one ``cfg.data.<split>`` entry at the aerial dataset folders and split file."""
        split_cfg.type = cfg.dataset_type
        split_cfg.data_root = cfg.data_root
        split_cfg.img_dir = img_dir
        split_cfg.ann_dir = ann_dir
        split_cfg.pipeline = pipeline
        split_cfg.split = split_file

    def prepare_config(self, config_dir, train):
        """Load ``config_dir`` and adapt it to the aerial dataset and the ``train`` settings.

        Arguments:
            config_dir: path to an MMSegmentation config file.
            train: training settings (same keys as the notebook's ``train_spc``).
        """
        cfg = Config.fromfile(config_dir)
        n_classes = self.data.n_classes
        # Since only one GPU is used, plain BN is used instead of SyncBN.
        cfg.norm_cfg = dict(type='BN')
        cfg.model.decode_head.num_classes = n_classes
        cfg.model.decode_head.norm_cfg = dict(type='BN')
        for head in self._auxiliary_heads(cfg.model):
            head.num_classes = n_classes
            head.norm_cfg = dict(type='BN')

        cfg.dataset_type = self.data.data_type
        cfg.data_root = self.data.data_root
        cfg.img_norm_cfg = dict(
            mean=self.data.mean, std=self.data.std)
        cfg.crop_size = train['crop_size']
        cfg.train_pipeline = [
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations', reduce_zero_label=True),
            dict(type='Resize', img_scale=train['scale'], ratio_range=(0.5, 2.0)),
            dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
            dict(type='RandomFlip', flip_ratio=0.5),
            dict(type='PhotoMetricDistortion'),
            dict(type='Normalize', **cfg.img_norm_cfg),
            dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_semantic_seg']),
        ]
        cfg.test_pipeline = [
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=train['scale'],
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(type='Normalize', **cfg.img_norm_cfg),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img']),
                ])
        ]
        self._configure_split(cfg.data.train, cfg, self.data.img_dir, self.data.ann_dir,
                              cfg.train_pipeline, 'splits/train.txt')
        self._configure_split(cfg.data.val, cfg, self.data.validation_root,
                              self.data.val_annotation, cfg.test_pipeline, 'splits/val.txt')
        # NOTE(review): the test split reuses the validation folders and split file.
        self._configure_split(cfg.data.test, cfg, self.data.validation_root,
                              self.data.val_annotation, cfg.test_pipeline, 'splits/val.txt')
        cfg.load_from = None

        cfg.model.decode_head.loss_decode = train['loss_d']
        for head in self._auxiliary_heads(cfg.model):
            head.loss_decode = train['loss_a']
        cfg.runner = self._runner_cfg(train)
        cfg.log_config.interval = train['log_int']
        cfg.evaluation.interval = train['eval']
        cfg.checkpoint_config.interval = train['checkpoint']
        cfg.optimizer.lr = train['lr_rate']
        # Parameters whose name contains 'head' use lr * lr_mul.
        cfg.optimizer.paramwise_cfg = dict(
            custom_keys={
                'head': dict(lr_mult=train['lr_mul'])})
        # Working dir for checkpoints and logs.
        cfg.work_dir = train['work_directory']
        cfg.data.workers_per_gpu = 1
        cfg.data.samples_per_gpu = train['batch']
        # Fixed seed to make results easier to reproduce.
        cfg.seed = 0
        set_random_seed(0, deterministic=False)
        cfg.gpu_ids = [0]
        if train['pretrain'] is not None:
            cfg.load_from = train['pretrain']
        else:
            cfg.model.pretrained = None
        self.cfg = cfg

    def init_model(self):
        """Build the segmentor described by ``self.cfg`` and attach the class names and palette."""
        self.model = build_segmentor(
            self.cfg.model, train_cfg=self.cfg.get('train_cfg'), test_cfg=self.cfg.get('test_cfg'))
        # CLASSES holds the class *names* (it was previously set to the class count).
        self.model.CLASSES = self.data.classes
        self.model.PALETTE = self.data.palette

    def train_model(self, init):
        """Train ``self.model`` on the training split, validating during training.

        Arguments:
            init: when True, first initialise a single-process NCCL process group.
        """
        if init:
            dist.init_process_group(backend='nccl', init_method='tcp://localhost:23456', rank=0, world_size=1)
        datasets = [build_dataset(self.cfg.data.train)]
        mmcv.mkdir_or_exist(osp.abspath(self.cfg.work_dir))
        train_segmentor(self.model, datasets, self.cfg, distributed=False, validate=True,
                        meta=dict())

    def test_model(self, img_direction, checkpoint, cfg, classes, pallete, model_name):
        """Segment every image in ``img_direction`` on CPU and save one figure per image.

        Arguments:
            img_direction: folder containing the images to segment.
            checkpoint: path of the weights loaded into the model.
            cfg: path of the MMSegmentation config describing the model.
            classes: class names; their count sets the number of output classes.
            pallete: one RGB colour per class, for the colour mask and the legend.
            model_name: sub-folder of ``<mydrive>/result`` receiving the figures.
        """
        cfg = Config.fromfile(cfg)
        cfg.model.train_cfg = None
        n_classes = len(classes)
        cfg.model.decode_head.num_classes = n_classes
        for head in self._auxiliary_heads(cfg.model):
            head.num_classes = n_classes
        model = build_segmentor(
            cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
        model.CLASSES = classes
        model.PALETTE = pallete
        model.pretrained = None
        load_checkpoint(model, checkpoint, map_location='cpu')
        # Convert SyncBN layers to BN so the model runs on CPU without a process group.
        model = revert_sync_batchnorm(model)
        model.cfg = cfg  # inference_segmentor reads the test pipeline from model.cfg
        model.to('cpu')
        model.eval()
        out_dir = os.path.join(mydrive, 'result', model_name)
        mmcv.mkdir_or_exist(out_dir)  # savefig fails when the folder is missing
        for file in os.listdir(img_direction):
            image = mmcv.imread(os.path.join(img_direction, file))
            result = inference_segmentor(model, image)
            self.show_result(image, result, pallete, classes, os.path.join(out_dir, file))

    def color_img(self, seg, color):
        """Map a 2-D array of class ids to an RGB image.

        Arguments:
            seg: array of class ids (H x W).
            color: sequence of RGB colours; class ``i`` is painted ``color[i]``.
        Returns:
            New ``uint8`` array of shape (H, W, 3); ids without a colour stay black.
        """
        seg = np.asarray(seg)
        # Allocate per call: the previous version wrote into an undefined/global array.
        color_seg = np.zeros(seg.shape + (3,), dtype=np.uint8)
        for label, rgb in enumerate(color):
            color_seg[seg == label, :] = rgb
        return color_seg

    def show_result(self, img, res, colors, labels, path):
        """Plot the coloured prediction next to the input image, save it to ``path`` and show it.

        Arguments:
            img: input image as returned by ``mmcv.imread``/``cv2.imread`` (BGR order).
            res: output of ``inference_segmentor``; ``res[0]`` is the class-id map.
            colors: RGB colour per class.
            labels: class names, in the same order as ``colors``.
            path: file the figure is written to.
        """
        mask = self.color_img(res[0], colors)
        fig = plt.figure(figsize=(16, 16))
        fig.add_subplot(1, 2, 1)
        plt.imshow(mask)
        fig.add_subplot(1, 2, 2)
        # matplotlib expects RGB, while mmcv.imread / cv2.imread return BGR.
        plt.imshow(img[..., ::-1] if img.ndim == 3 else img)
        patches = [m_patches.Patch(color=np.array(rgb) / 255., label=name)
                   for rgb, name in zip(colors, labels)]
        # Use the patches as legend handles, placed to the right of the axes.
        plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.,
                   fontsize='large')
        plt.savefig(path)
        plt.show()
        plt.close(fig)  # one figure is created per image; release it


In [None]:
# RGB colour per class id used for plotting; entry 0 is the 'null' class.
color = [
    [0, 0, 0],
    [255, 255, 255],
    [255, 0, 0],
    [255, 255, 0],
    [0, 0, 255],
    [159, 129, 183],
    [0, 255, 0],
    [255, 195, 128],
]

In [None]:
#@title
import gc

if __name__ == '__main__':
    # Test-image folders of both branches of the Test split.
    dt = [pt.join(data_subfolders, test, area, img) for area in ('Rural', 'Urban')]
    gc.collect()
    ignore_define = True  # do not (re-)register AerialDataset
    model = ['Vit', 'dlv3']
    initial_process = False  # passed to train_model when training is re-enabled
    for m in model:
        for d in dt:
            train_spc['work_directory'] = output_address
            Agent = MMSEGAerialAnalysis(data_spc)
            Agent.init_data(data_spc, ignore_define)
            Agent.prepare_config(config_files[m], train_spc)
            Agent.init_model()
            # Agent.train_model(initial_process)
            Agent.test_model(d, checkpoint_files[m], config_files[m], classes, color, m)

In [None]:
from google.colab.patches import cv2_imshow
a = cv2.imread('/content/drive/MyDrive/data/Train/Rural/masks_png/0.png')
if a is None:  # cv2.imread returns None instead of raising on a missing/unreadable file
    raise FileNotFoundError('/content/drive/MyDrive/data/Train/Rural/masks_png/0.png')
color = [[0, 0, 0], [255, 255, 255], [255, 0, 0], [255, 255, 0], [0, 0, 255],
         [159, 129, 183], [0, 255, 0], [255, 195, 128]]
classes = ('null', 'background', 'building', 'road', 'water', 'barren', 'forest',
           'agricultural')
# cv2.imread loads 3 channels; assumes the mask stores the class id as a grey
# level so channel 0 holds the id — TODO confirm for this dataset.
seg = a[:, :, 0]
# Paint every class id with its colour so the plot actually shows the mask.
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for class_id, rgb in enumerate(color):
    color_seg[seg == class_id, :] = rgb

plt.figure(figsize=(10, 10))
plt.imshow(color_seg)
patches = [m_patches.Patch(color=np.array(rgb) / 255., label=name)
           for rgb, name in zip(color, classes)]
# Use the patches as legend handles, placed to the right of the axes.
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.,
           fontsize='large')
plt.show()
# Agent.show_result(cv2.imread('/content/drive/MyDrive/data/Train/Rural/images_png/0.png'),color_seg,color,classes,m)