<a href="https://colab.research.google.com/github/sarayaghoubi/sara/blob/master/ImageSegmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the following cells read and write the dataset under it.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!wget https://zenodo.org/record/5706578/files/Train.zip -P /content/drive/MyDrive/data/Train
!unzip /content/drive/MyDrive/data/Train/Train.zip -d /content/drive/MyDrive/data/Train/


!wget https://zenodo.org/record/5706578/files/Test.zip -P /content/drive/MyDrive/data/Test
!unzip /content/drive/MyDrive/data/Test/Test.zip -d /content/drive/MyDrive/data/Test/

!wget https://zenodo.org/record/5706578/files/Val.zip -P /content/drive/MyDrive/data/Val
!unzip /content/drive/MyDrive/data/Val/Val.zip -d /content/drive/MyDrive/data/Val/

In [None]:
# Install the CUDA 11.1 builds of torch/torchvision, a matching mmcv-full wheel,
# and mmsegmentation from source in editable mode.
!pip install -U torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# NOTE(review): mmcv-full is not version-pinned; a newer release may not be
# compatible with the mmsegmentation checkout below — consider pinning.
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
# NOTE(review): this clones the repository's default branch without a tag or
# commit, so the code this notebook runs against can change over time — consider
# checking out the release the notebook was written for.
!git clone https://github.com/open-mmlab/mmsegmentation.git
%cd mmsegmentation
!pip install -e .
# Import check: fails right here if the installation above did not succeed.
import mmseg

In [None]:
# Last update was on  4 Mar
import os.path as osp
import numpy as np
from PIL import Image
import os
import cv2
import mmcv
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset
import random


def annotate(img, label):
    """
    Build a class-id annotation map from a colour-coded segmentation image.

    A pixel is assigned a class only when ALL of its channels equal the class
    colour (matching on any single channel would mislabel pixels that merely
    share one channel value with the colour).

    :param img: the segmentation image as an array, (H, W, C) colour or (H, W) single channel
    :param label: the colours used in the segmentation maps; the position of a
        colour in this sequence is the class id written to the output
    :return: 2D float array of shape (H, W) that maps every pixel to its class id;
        pixels that match none of the colours stay 0
    """
    img = np.asarray(img)
    annotation = np.zeros(img.shape[0:2])
    for class_id, color in enumerate(label):
        matches = img == np.asarray(color)
        if matches.ndim > 2:
            # Collapse the channel axis: every channel has to match the colour.
            matches = matches.all(axis=-1)
        annotation[matches] = class_id
    return annotation


class Data:
    """
    Prepare a custom segmentation dataset for mmsegmentation.

    The class scans the annotation folder to size the palette, computes the
    per-channel statistics of the input images, optionally resizes the images
    to one common size, writes the split files and registers the
    ``AerialDataset`` dataset type.

    Arguments:
        data_config: dict with the keys
            'root': the folder that contains the dataset,
            'image': (postfix, folder) of the training images,
            'annotation': (postfix, folder) of the training annotations,
            'test_address': (image folder, annotation folder) of the test set,
            'validation_address': (image folder, annotation folder) of the validation set.
            Folders are joined onto 'root' with ``os.path.join``.
    """

    def __init__(self,
                 data_config
                 ):
        self.std = None
        self.mean = None
        self.n_classes = None
        self.classes = []
        self.palette = None
        self.data_root = data_config['root']
        self.test_root, self.test_annotation = data_config['test_address']
        self.validation_root, self.val_annotation = data_config['validation_address']
        self.img_postfix, self.img_dir = data_config['image']
        self.ann_postfix, self.ann_dir = data_config['annotation']

    def process_data(self, new_folder, size, require_resize):
        """
        Fill in palette, classes, mean, std and n_classes for the dataset.

        The palette is determined first (so that resizing can reuse it), then
        the images are optionally resized, and finally the mean and the std of
        the input images are computed.

        Arguments:
            new_folder : the folder that all the new images (after resizing) will be saved to
            size : target size handed to cv2.resize; should match the backbone input
            require_resize : boolean; False if the images were already resized and prepared
        """
        self.palette, self.classes = self.spc()
        if require_resize:
            self.resize(new_folder, size)
            self.img_dir = new_folder
        self.mean, self.std = self.statistics()
        self.n_classes = len(self.palette)

    def spc(self):
        """
        Scan the annotation folder and build a palette and class names.

        :return: (palette, classes); the palette holds one random (R, G, B)
            tuple of ints per label and classes holds one name ('A', 'B', ...)
            per label, so both always have the same length
        """
        ann_root = os.path.join(self.data_root, self.ann_dir)
        labels = 0
        # NOTE(review): the label count is the largest number of distinct values
        # found in any single annotation file; this undercounts if no file
        # contains every label — confirm against the dataset.
        for file in os.listdir(ann_root):
            annotation = np.array(Image.open(os.path.join(ann_root, file)))
            labels = max(labels, len(np.unique(annotation)))
        color = [tuple(random.randint(0, 255) for _ in range(3)) for _ in range(labels)]
        classes = [chr(ord('A') + i) for i in range(labels)]
        return color, classes

    def statistics(self):
        """
        Compute the per-channel mean and standard deviation of the input images.

        The sums are accumulated image by image, so the images may differ in
        size and are never all held in memory at the same time.

        :return: ([mean_R, mean_G, mean_B], [std_R, std_G, std_B]) as plain
            floats (these are required for the config file)
        :raises ValueError: if no file with the configured image postfix is found
        """
        img_root = osp.join(self.data_root, self.img_dir)
        pixel_count = 0
        channel_sum = np.zeros(3, dtype=np.float64)
        channel_sq_sum = np.zeros(3, dtype=np.float64)
        for file in os.listdir(img_root):
            if not file.endswith(self.img_postfix):
                continue
            pixels = np.array(Image.open(osp.join(img_root, file)).convert('RGB'))
            pixels = pixels.reshape(-1, 3).astype(np.float64)
            pixel_count += pixels.shape[0]
            channel_sum += pixels.sum(axis=0)
            channel_sq_sum += np.square(pixels).sum(axis=0)
        if pixel_count == 0:
            raise ValueError(
                f'no "{self.img_postfix}" images found in {img_root}; cannot compute mean/std')
        mean = channel_sum / pixel_count
        # E[x^2] - E[x]^2, clamped at 0 to absorb floating point round-off.
        variance = np.maximum(channel_sq_sum / pixel_count - np.square(mean), 0.0)
        return mean.tolist(), np.sqrt(variance).tolist()

    def resize(self, new_directory, size):
        """
        Resize every jpg/png file of the image folder and save it to new_directory.

        jpg files are treated as input images and png files as segmentation
        maps; any other entry of the folder is skipped.

        Arguments:
            new_directory : output folder (joined onto the data root); created if missing
            size : target size handed to cv2.resize
        """
        # Reuse the palette of process_data when available instead of scanning
        # the annotation folder a second time.
        label = self.palette if self.palette is not None else self.spc()[0]
        source_root = osp.join(self.data_root, self.img_dir)
        target_root = osp.join(self.data_root, new_directory)
        mmcv.mkdir_or_exist(target_root)
        # NOTE(review): only self.img_dir is walked and self.ann_dir is not
        # redirected to new_directory afterwards — confirm where the masks live.
        for file in os.listdir(source_root):
            address = osp.join(source_root, file)
            if file.endswith('jpg'):
                pixels = np.array(Image.open(address).convert('RGB'), dtype=np.uint8)
                resized = cv2.resize(pixels, size, interpolation=cv2.INTER_NEAREST)
                out = Image.fromarray(resized).convert('RGB')
            elif file.endswith('png'):
                mask = np.array(Image.open(address))
                if mask.ndim == 2:
                    # Single-channel mask: the values already are the labels, and
                    # nearest-neighbour resizing keeps them intact.
                    resized = cv2.resize(mask.astype(np.uint8), size,
                                         interpolation=cv2.INTER_NEAREST)
                    out = Image.fromarray(resized).convert('P')
                else:
                    # Colour-coded mask: map the colours to class ids.
                    pixels = np.array(Image.open(address).convert('RGB'), dtype=np.uint8)
                    resized = cv2.resize(pixels, size, interpolation=cv2.INTER_NEAREST)
                    out = Image.fromarray(annotate(resized, label).astype(np.uint8)).convert('P')
            else:
                continue
            out.save(osp.join(target_root, file))

    def reconstruct(self, data_config, ignore):
        """
        Process the data, write the split files and register ``AerialDataset``.

        Arguments:
            data_config : dict providing 'new_directory', 'size' and 'resize'
                for process_data
            ignore : if True, the dataset type is not registered again when it
                is already in the registry; if False it is (re-)registered with
                the current classes and palette
        """
        self.process_data(data_config['new_directory'], data_config['size'], data_config['resize'])
        img_suffix = '.' + self.img_postfix.lstrip('.')
        seg_map_suffix = '.' + self.ann_postfix.lstrip('.')
        split_dir = 'splits'
        mmcv.mkdir_or_exist(osp.join(self.data_root, split_dir))
        write_file = {
            'train': self.img_dir,
            'test': self.test_root,
            'val': self.validation_root
        }
        for split_name, folder in write_file.items():
            filename_list = [osp.splitext(filename)[0] for filename in mmcv.scandir(
                osp.join(self.data_root, folder), suffix=img_suffix)]
            with open(osp.join(self.data_root, split_dir, f'{split_name}.txt'), 'w') as f:
                f.writelines(line + '\n' for line in filename_list)
        classes = self.classes
        palette = self.palette
        # On a fresh kernel the dataset type has to be registered even when the
        # caller asks to skip it; force=True lets a repeated registration
        # replace the previous one instead of raising.
        if not ignore or 'AerialDataset' not in DATASETS.module_dict:
            @DATASETS.register_module(force=True)
            class AerialDataset(CustomDataset):
                CLASSES = classes
                PALETTE = palette

                def __init__(self, split, **kwargs):
                    super().__init__(img_suffix=img_suffix, seg_map_suffix=seg_map_suffix,
                                     split=split, **kwargs)
                    assert osp.exists(self.img_dir) and self.split is not None


In [None]:
from mmseg.apis import set_random_seed
from mmcv import Config
import mmcv
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor
# from mmcv.runner import load_checkpoint
from mmseg.apis import inference_segmentor
import matplotlib.patches as m_patches
import os.path as osp
import numpy as np
from matplotlib import pyplot as plt
import os


class MMSEGAerialAnalysis:
    """
    Drive one mmsegmentation experiment on the custom aerial dataset:
    prepare the data, adapt a stock config, build, train and test the model.
    """

    def __init__(self, data_config):
        """
        :param data_config: dict handed to ``Data`` (dataset root and folders)
        """
        self.cfg = None
        self.model = None
        self.data = Data(data_config)

    def init_data(self, *args):
        """Forward the arguments to ``Data.reconstruct`` (data_config, ignore)."""
        self.data.reconstruct(*args)

    def prepare_config(self, config_dir, train):
        """
        Load a config file and adapt it to the custom dataset; stores it in self.cfg.

        :param config_dir: path of the mmsegmentation config file
        :param train: dict with the keys 'max' (iterations), 'log_int',
            'eval', 'checkpoint' (intervals), 'lr_rate' and 'work_directory'
        """
        cfg = Config.fromfile(config_dir)
        # Since we use only one GPU, BN is used instead of SyncBN
        cfg.norm_cfg = dict(type='BN')
        # The backbone is only touched when it really uses SyncBN, so backbones
        # with a different normalisation layer keep their own setting.
        backbone = cfg.model.get('backbone')
        backbone_norm = backbone.get('norm_cfg') if backbone is not None else None
        if backbone_norm is not None and backbone_norm.get('type') == 'SyncBN':
            backbone.norm_cfg = dict(type='BN',
                                     requires_grad=backbone_norm.get('requires_grad', True))
        # modify norm layer and num classes of the model in decode/auxiliary head;
        # a head may be missing or given as a list of heads
        for head_name in ('decode_head', 'auxiliary_head'):
            head = cfg.model.get(head_name)
            if head is None:
                continue
            for single_head in (head if isinstance(head, (list, tuple)) else [head]):
                single_head.norm_cfg = dict(type='BN')
                single_head.num_classes = self.data.n_classes
        cfg.dataset_type = 'AerialDataset'
        cfg.data_root = self.data.data_root
        cfg.img_norm_cfg = dict(
            mean=self.data.mean, std=self.data.std)
        cfg.crop_size = (256, 256)
        cfg.train_pipeline = [
            dict(type='LoadImageFromFile'),
            dict(type='LoadAnnotations'),
            dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
            dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
            dict(type='RandomFlip', flip_ratio=0.5),
            dict(type='PhotoMetricDistortion'),
            dict(type='Normalize', **cfg.img_norm_cfg),
            dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img', 'gt_semantic_seg']),
        ]
        cfg.test_pipeline = [
            dict(type='LoadImageFromFile'),
            dict(
                type='MultiScaleFlipAug',
                img_scale=(512, 512),
                # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(type='Normalize', **cfg.img_norm_cfg),
                    dict(type='ImageToTensor', keys=['img']),
                    dict(type='Collect', keys=['img']),
                ])
        ]
        cfg.data.train.type = cfg.dataset_type
        cfg.data.train.data_root = cfg.data_root
        cfg.data.train.img_dir = self.data.img_dir
        cfg.data.train.ann_dir = self.data.ann_dir
        cfg.data.train.pipeline = cfg.train_pipeline
        cfg.data.train.split = 'splits/train.txt'

        cfg.data.val.type = cfg.dataset_type
        cfg.data.val.data_root = cfg.data_root
        cfg.data.val.img_dir = self.data.validation_root
        cfg.data.val.ann_dir = self.data.val_annotation
        cfg.data.val.pipeline = cfg.test_pipeline
        cfg.data.val.split = 'splits/val.txt'

        # The test split uses the test folders and the test split file written
        # by Data.reconstruct (not the validation ones).
        cfg.data.test.type = cfg.dataset_type
        cfg.data.test.data_root = cfg.data_root
        cfg.data.test.img_dir = self.data.test_root
        cfg.data.test.ann_dir = self.data.test_annotation
        cfg.data.test.pipeline = cfg.test_pipeline
        cfg.data.test.split = 'splits/test.txt'
        cfg.load_from = None

        cfg.runner.max_iters = train['max']
        cfg.log_config.interval = train['log_int']
        cfg.evaluation.interval = train['eval']
        cfg.checkpoint_config.interval = train['checkpoint']
        cfg.optimizer.lr = train['lr_rate']
        # Working dir to save files and logs.
        cfg.work_dir = train['work_directory']
        # Set seed to facilitate reproducing the result
        cfg.seed = 0
        set_random_seed(0, deterministic=False)
        cfg.gpu_ids = range(1)
        cfg.model.pretrained = None
        self.cfg = cfg

    def init_model(self):
        """Build the segmentor from self.cfg and attach class names and palette."""
        self.model = build_segmentor(
            self.cfg.model, train_cfg=self.cfg.get('train_cfg'), test_cfg=self.cfg.get('test_cfg'))
        # CLASSES holds the class names (not the number of classes).
        self.model.CLASSES = self.data.classes
        self.model.PALETTE = self.data.palette

    def train_model(self):
        """Build the training dataset and train self.model with validation enabled."""
        datasets = [build_dataset(self.cfg.data.train)]
        mmcv.mkdir_or_exist(osp.abspath(self.cfg.work_dir))
        train_segmentor(self.model, datasets, self.cfg, distributed=False, validate=True,
                        meta=dict())

    def test_model(self, img_direction, checkpoint_file):
        """
        Run inference on every file of a folder and show each result.

        :param img_direction: folder that holds the images to segment
        :param checkpoint_file: weights to load before inference; anything that
            is not an existing file is ignored and the model is used as it is
        """
        if self.model is None:
            self.init_model()
        if checkpoint_file and osp.isfile(checkpoint_file):
            from mmcv.runner import load_checkpoint
            load_checkpoint(self.model, checkpoint_file, map_location='cpu')
        self.model.eval()
        self.model.cfg = self.cfg
        for file in os.listdir(img_direction):
            image = os.path.join(img_direction, file)
            result = inference_segmentor(self.model, image)
            self.show_result(result)

    def show_result(self, img):
        """
        Plot a segmentation result coloured with the palette, plus a class legend.

        :param img: the segmentation result: a 2D array of class ids, or a
            list/tuple whose first element is such an array
        """
        colors = self.data.palette
        labels = self.data.classes
        seg_map = np.asarray(img[0] if isinstance(img, (list, tuple)) else img)
        plt.figure(figsize=(8, 6))
        if seg_map.ndim == 2 and len(colors) > 0:
            palette = np.asarray(colors, dtype=np.uint8)
            # Look the colour of every class id up in the palette.
            plt.imshow(palette[seg_map.astype(np.intp)])
            plt.axis('off')
        patches = [m_patches.Patch(color=np.array(colors[i]) / 255.,
                                   label=labels[i]) for i in range(min(len(colors), len(labels)))]
        # put those patches as legend-handles into the legend
        plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.,
                   fontsize='large')
        plt.show()


In [None]:
!mv  /content/drive/MyDrive/data/Val/Val/Rural/images_png/* /content/drive/MyDrive/data/Val/images_png

In [None]:
import os.path as opt
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/model_Experiment')

# Every dataset folder below is relative to this root (the folder the download
# cell unpacks into); the split files are written to '<root>/splits'.
DATA_ROOT = '/content/drive/MyDrive/data'

configs = {
    'Vit': '/content/mmsegmentation/configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py',
    'Swin': '/content/mmsegmentation/configs/swin/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_22K.py',
    'dlv3': '/content/mmsegmentation/configs/deeplabv3plus/deeplabv3plus_r18b-d8_512x1024_80k_cityscapes.py'
}
data_spc = {
    'root': DATA_ROOT,
    'size': (512, 512),
    'annotation': ('png', 'Train/masks_png'),
    'image': ('png', 'Train/images_png'),
    'resize': False,
    'new_directory': '',
    'test_address': ('Val/images_png', 'Val/masks_png'),
    'validation_address': ('Val/images_png', 'Val/masks_png')
}
# NOTE(review): with 'eval': 10 the whole validation set is evaluated every 10
# iterations; the captured run was interrupted during evaluation (13/1669 images,
# ETA 26436s) — consider a larger evaluation interval.
train_spc = {
    'max': 200,
    'log_int': 10,
    'eval': 10,
    'checkpoint': 5000,
    'lr_rate': 6e-6,
    'work_directory': ''
}
for model_name, config_path in configs.items():
    train_spc['work_directory'] = f'./{model_name}'
    analysis = MMSEGAerialAnalysis(data_spc)
    # Register the dataset type on the first pass (fresh kernel) and skip the
    # registration once it is in the registry.
    analysis.init_data(data_spc, 'AerialDataset' in DATASETS.module_dict)
    analysis.prepare_config(config_path, train_spc)
    analysis.init_model()
    analysis.train_model()
    # 'test_address' holds (image folder, annotation folder); only the images are needed here.
    analysis.test_model(opt.join(data_spc['root'], data_spc['test_address'][0]),
                        train_spc['work_directory'])


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
  keepdims=keepdims, where=where)
  subok=False)
  ret = ret.dtype.type(ret / rcount)
  'Default ``avg_non_ignore`` is False, if you would like to '
2022-04-18 15:56:41,663 - mmseg - INFO - Loaded 2522 images
  cpuset_checked))
2022-04-18 15:56:41,694 - mmseg - INFO - Loaded 1669 images
2022-04-18 15:56:41,700 - mmseg - INFO - Start running, host: root@728bbc3fc1e5, work_dir: /content/mmsegmentation/Vit
2022-04-18 15:56:41,701 - mmseg - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) PolyLrUpdaterHook                  
(NORMAL      ) CheckpointHook                     
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) PolyLrUpdaterHook                  
(LOW         ) IterTimerHook                      
(LOW         ) EvalHook                           
(VERY_LOW    ) TextLoggerHook      

[                             ] 13/1669, 0.1 task/s, elapsed: 208s, ETA: 26436s

KeyboardInterrupt: ignored