In [7]:
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models
import torchvision.utils
import os, fnmatch
from affine import Affine
import rasterio
from datetime import datetime
import cv2
from skimage import color
import json
from tifffile import imwrite
import torch
import torch.nn as nn
from collections import defaultdict
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy
import math
import csv
import albumentations as A
from albumentations import augmentations
import sys
import warnings
warnings.filterwarnings('ignore')
# warnings.filterwarnings(action='once')
from tqdm import tqdm

In [8]:
sys.path.append('../../mmsegmentation')

from mmseg.apis import inference_segmentor, init_segmentor

In [3]:
# DATASET

import mmcv
sys.path.append('../../mmsegmentation/')
import mmseg
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.core.evaluation import get_palette
from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset

data_root = '/home/jovyan/fominaav/CloudsDetection/masker/mmsegmentation/data/my_dataset/'
img_dir = 'images'
ann_dir = 'annotations'
classes = ('Background', 'Clouds', 'Shadows', 'Snow')
palette = [[128, 128, 128], [129, 127, 38], [120, 69, 125], [53, 125, 34]]

@DATASETS.register_module()
class ALCD(CustomDataset):
    CLASSES = classes
    PALETTE = palette
    def __init__(self, split, **kwargs):
        super().__init__(img_suffix='.jpg', seg_map_suffix='.png', 
                     split=split, **kwargs)
        assert osp.exists(self.img_dir) and self.split is not None

In [9]:
# CONFIG

from mmcv import Config
cfg = Config.fromfile('/home/jovyan/fominaav/CloudsDetection/masker/mmsegmentation/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py')

In [10]:
from mmseg.apis import set_random_seed
from mmseg.utils import get_device

# add CLASSES and PALETTE to checkpoint
cfg.checkpoint_config.meta = dict(
    CLASSES=classes,
    PALETTE=palette)

# Since we use only one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
# modify num classes of the model in decode/auxiliary head
cfg.model.decode_head.num_classes = 4

# Modify dataset type and path
cfg.dataset_type = 'ALCD'
cfg.data_root = data_root

cfg.data.samples_per_gpu = 32
cfg.data.workers_per_gpu=8

cfg.img_norm_cfg = dict(
    mean=[0.3965, 0.4287, 0.3811], std=[0.2790, 0.2724, 0.2750], to_rgb=True)
cfg.crop_size = (512, 512)
cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(512, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **cfg.img_norm_cfg),
    dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(512, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **cfg.img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]


cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = img_dir
cfg.data.train.ann_dir = ann_dir
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.train.split = 'splits/train.txt'

cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = img_dir
cfg.data.val.ann_dir = ann_dir
cfg.data.val.pipeline = cfg.test_pipeline
cfg.data.val.split = 'splits/val.txt'

cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = img_dir
cfg.data.test.ann_dir = ann_dir
cfg.data.test.pipeline = cfg.test_pipeline
cfg.data.test.split = 'splits/val.txt'

# Set up working dir to save files and logs.
cfg.work_dir = '/home/jovyan/fominaav/CloudsDetection/masker/notebooks/work_dirs/vit'

cfg.runner.max_iters = 5000
cfg.log_config.interval = 100
cfg.evaluation.interval = 100
cfg.checkpoint_config.interval = 100

# Set seed to facitate reproducing the result
cfg.seed = 0
cfg.gpu_ids = [0]
cfg.device = get_device()

# Let's have a look at the final config used for training
print(f'Config:\n{cfg.pretty_text}')

Config:
norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='pretrain/vit_base_patch16_224.pth',
    backbone=dict(
        type='VisionTransformer',
        img_size=(512, 512),
        patch_size=16,
        in_channels=3,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4,
        out_indices=(2, 5, 8, 11),
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        with_cls_token=True,
        norm_cfg=dict(type='BN', requires_grad=True),
        act_cfg=dict(type='GELU'),
        norm_eval=False,
        interpolate_mode='bicubic'),
    neck=dict(
        type='MultiLevelNeck',
        in_channels=[768, 768, 768, 768],
        out_channels=768,
        scales=[4, 2, 1, 0.5]),
    decode_head=dict(
        type='UPerHead',
        in_channels=[768, 768, 768, 768],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=

In [None]:
with open('vit_config.py', 'w') as f:
    f.write(cfg.pretty_text)

In [None]:
# TRAIN
import os.path as osp
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor


# Build the dataset
datasets = [build_dataset(cfg.data.train)]

# Build the detector
model = build_segmentor(cfg.model)
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_segmentor(model, datasets, cfg, distributed=False, validate=True, 
                meta=dict())

In [None]:
img = mmcv.imread('/home/jovyan/fominaav/CloudsDetection/masker/mmsegmentation/data/low_res_dataset/images/20171013T081959_204.jpg')
model.cfg = cfg
result = inference_segmentor(model, img)
show_result_pyplot(model, img, result, palette, fig_size=(6, 4))