In [1]:
import os
import sys
import cv2
import json
import random
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torchvision.transforms import functional as FF
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import PolynomialLR
from torch.utils.data import Dataset, DataLoader
#from torchvision.ops import FocalLoss as  FocalLossTV
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors

sys.path.append('libs/torchgeo')
sys.path.append('libs/mmengine')
sys.path.append('libs/mmcv')
sys.path.append('libs/mmsegmentation')
sys.path.append('libs/mmdetection')

from mmengine.config import Config, DictAction
from mmengine.logging import print_log
from mmengine.optim.scheduler.lr_scheduler import PolyLR

from mmseg.registry import RUNNERS
from mmseg.models.losses import MSELoss, KLDistLoss

In [2]:
# Run parameters are the comma-separated fields on the first line of params.txt,
# in the order: dataset, arch, decode_head, seed, lw0, lw1, lw2.
with open('params.txt') as file:
    params = file.read().split('\n')[0].split(',')

dataset, arch, decode_head = params[0], params[1], params[2]  # arch e.g. 'swint', 'resnet50', 'mnetv2', 'mixvit'
seed = int(params[3])
# Three float loss weights; they are later stored on cfg as aux1_lw / aux2_lw / aux3_lw.
lw0, lw1, lw2 = float(params[4]), float(params[5]), float(params[6])

#dataset, arch, decode_head, seed , lw0, lw1, lw2 = 'lcai', 'resnet18', 'dlv3', 1, 0.0, 0.0, 0.0

# Per-dataset settings: input channels, class list/weights, colour map names,
# crop size and training length. An unknown dataset now fails here with a clear
# message instead of surfacing later as a NameError on one of these variables.
if dataset == 'lcai':
    inp_channel  = 3
    num_class    = 5
    class_names  = 'bg,infra,water,green,road'
    class_weight = [1.0,1.5,1.0,1.0,1.5]
    # NOTE(review): position-wise this maps 'water' -> forestgreen and 'green' -> cyan,
    # whereas the naip/lsat branch maps water -> cyan and green -> forestgreen.
    # Confirm which of class_names / cnames matches the label ids.
    cnames    = ['black', 'gray', 'forestgreen', 'cyan', 'blue']
    img_size  = 128
    epochs    = 15
    pat       = 10
elif dataset in ['naip','lsat']:
    inp_channel  = 5
    num_class    = 4
    class_names  = "bg,water,green,infra"
    class_weight = [1.0,0.1,0.1,1.0] if dataset == 'naip' else [1.0,1.0,1.0,1.0]
    cnames = ['black','cyan','forestgreen','red']
    img_size  = 128
    epochs    = 10
    pat       = 5
    lr        = 5e-5  # overwritten by the optimizer-specific lr chosen further down in this cell
else:
    raise ValueError(f"Unsupported dataset {dataset!r}; expected 'lcai', 'naip' or 'lsat'")

# Architecture family decides the input mean offset, optimizer, batch size and
# gradient clipping. An unknown arch now raises immediately; previously it left
# `optim` undefined and crashed on the next `if` with a NameError.
if arch in ['unet','swint','mixvit']:
    mean  = 0.0
    optim = 'adam'
    batch_size = 4
    clip_grad  = None
elif arch in ['resnet18', 'resnet34', 'resnet50', 'mnetv2']:
    mean  = 0.5
    optim = 'sgd'
    batch_size = 4
    clip_grad  = 1.0
else:
    raise ValueError(
        f"Unsupported arch {arch!r}; expected one of "
        "'unet', 'swint', 'mixvit', 'resnet18', 'resnet34', 'resnet50', 'mnetv2'")

# Learning rate and extra regularization are tied to the optimizer choice.
if optim == 'adam':
    lr = 5e-5
    regularization = None  # 'L2'
else:  # optim == 'sgd'
    lr = 0.01
    regularization = 'L2'

# Seed torch, numpy and the stdlib RNG with the same value, then ask torch for
# deterministic kernels (warn_only=True: warn instead of raising when an op has
# no deterministic implementation).
for _seed_fn in (torch.random.manual_seed, np.random.seed, random.seed):
    _seed_fn(seed)
torch.use_deterministic_algorithms(True, warn_only=True)

# CUDA allocator / cuBLAS workspace settings exported through the environment.
os.environ.update({
    'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:21',
    'CUBLAS_WORKSPACE_CONFIG': ':4096:8',  # alternative: ':16:8'
})

# Build train/validation datasets and loaders for the selected dataset.
# Validation batches are three times the training batch size.
if dataset == 'lcai':
    from lcai_dataset import LandCoverDataset, CustomDataLoader
    train_dataset = LandCoverDataset(
        split = 'train',
        img_size = img_size,
        aug_data = True,
        filter_data = True,
        mean = mean,
    )
    val_dataset = LandCoverDataset(
        split = 'test',
        img_size = img_size,
        filter_data = False,
        mean = mean,
    )
    # NOTE(review): the training loader does not shuffle; confirm this is intended
    # (e.g. the dataset randomizes internally) rather than an oversight.
    TrainDataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = False, num_workers = 0)
    ValDataloader   = DataLoader(val_dataset, batch_size = 3*batch_size, num_workers = 0)

if dataset in ['naip','lsat']:
    from naip_dataset import NDVIDataset, GridGeoSampler, CustomDataLoader, Units
    train_dataset, val_dataset = (
        NDVIDataset(split = _split, name = dataset, img_size = img_size)
        for _split in ('train', 'val')
    )

    # Grid samplers use img_size for both positional arguments, measured in pixels.
    train_sampler = GridGeoSampler(train_dataset, img_size, img_size, units=Units.PIXELS)
    val_sampler   = GridGeoSampler(val_dataset, img_size, img_size, units=Units.PIXELS)

    TrainDataloader = CustomDataLoader(
        train_dataset,
        sampler = train_sampler,
        batch_size = batch_size,
    )
    ValDataloader = CustomDataLoader(
        val_dataset,
        sampler = val_sampler,
        batch_size = 3*batch_size,
        split = 'valid',
    )
    
# Auxiliary heads are enabled only when both lw1 and lw2 are positive; the flag
# carries the decode-head name when enabled and is the empty string otherwise.
use_aux_head = decode_head if (lw1 > 0 and lw2 > 0) else ''

# Start from the stock DeepLabV3+ config; backbone and heads are replaced later.
config = 'libs/mmsegmentation/configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-40k_cityscapes-769x769.py'
cfg = Config.fromfile(config)

# Loss weights of the three auxiliary heads and the softmax/sigmoid temperature.
cfg.aux1_lw, cfg.aux2_lw, cfg.aux3_lw = lw0, lw1, lw2
cfg.T = 1.0

# Directory component derived from the learning rate: the text after the decimal
# point (0.01 -> '01'), or the whole string when str(lr) has no '.' (5e-05 -> '5e-05').
# An explicit length check replaces the former bare `except:`, which would also
# have hidden unrelated errors.
_lr_text  = str(lr)
_lr_parts = _lr_text.split('.')
lr_pfx = _lr_parts[1] if len(_lr_parts) > 1 else _lr_text

# All artefacts of this run (logs, checkpoints, figures, source snapshot) go below exp_path.
exp_path = f"logs/v5/{optim}/{lr_pfx}/{dataset}_{arch}_{decode_head}_{num_class}cls_{lr}_{cfg.aux1_lw}_{cfg.aux2_lw}_{cfg.aux3_lw}_s{seed}"

# Per-architecture feature widths (`channels`), plus the depth / init_cfg used by
# the ResNet backbone dict, which is built for every arch.
#
# NOTE(review): the original cell first set channels = [64, 64, 64, 64, 64] for
# 'unet', but the non-ResNet fallback below always overwrote it, so 'unet' ends
# up with [64, 128, 256, 512]. That behaviour is kept here; confirm whether the
# five-stage list was the intended value.
_resnet_specs = {
    'resnet18': (18, [64, 128, 256, 512]),
    'resnet34': (34, [64, 128, 256, 512]),
    'resnet50': (50, [256, 512, 1024, 2048]),
}

if arch in _resnet_specs:
    resnet_depth, channels = _resnet_specs[arch]
    init_cfg = None  # pretrained init via init_cfg is disabled for the ResNet archs
else:
    # Fallback values for every non-ResNet arch.
    resnet_depth, channels = 34, [64, 128, 256, 512]
    init_cfg = {'type': 'Pretrained', 'checkpoint': 'open-mmlab://resnet50_v1c'}

# Architectures with their own stage widths override the fallback.
if arch == 'mnetv2':
    channels = [32, 96, 160, 320]
elif arch == 'swint':
    swint_ind = 1  # selects the second entry of every Swin variant table
    channels = [[96, 192, 384, 768], [128, 256, 512, 1024]][swint_ind]
elif arch == 'mixvit':
    channels = [64, 128, 320, 512]

cfg.model.feat_channel = channels

# Discrete colour map for label images: one colour per class id, with bin edges
# at half-integers so each integer label falls in its own bin.
cmap   = colors.ListedColormap(cnames)
bounds = np.arange(num_class + 1) - 0.5
norm   = colors.BoundaryNorm(bounds, cmap.N)

# Class-weighted cross-entropy configuration for the decoder.
decoder_loss = {
    'type': 'CrossEntropyLoss',
    'use_sigmoid': False,
    'loss_weight': 1.0,
    'class_weight': class_weight,
    'ignore_index': None,
}
# Number of iterations treated as one "epoch" by the iteration-based train config.
itr_per_epoch = 250

  check_for_updates()


Using the whole train set --> 7470


  res['data_samples']['gt_sem_seg']  = torch.tensor(gt,dtype=torch.uint8)
 14%|█▍        | 1080/7470 [00:51<05:06, 20.88it/s]


Using the whole test set --> 1602


 17%|█▋        | 270/1602 [00:10<00:51, 25.66it/s]


In [5]:
# Snapshot the code used for this run next to its logs (vendored libs, local
# scripts and this notebook). Exit codes of `cp` are not checked.
_snapshot_dir = f'{exp_path}/src'
os.makedirs(f'{_snapshot_dir}/libs/', exist_ok = True)
for _lib in ('mmengine', 'mmsegmentation'):
    os.system(f'cp -r libs/{_lib} {_snapshot_dir}/libs/')
os.system(f'cp -r *py {_snapshot_dir}/')
os.system(f'cp -r train_v2.ipynb {_snapshot_dir}/')

0

In [7]:
# Sanity check: number of batches in the training and validation loaders.
print(len(TrainDataloader), len(ValDataloader))

270 23


In [8]:
#fig.savefig('inputs_cpk.pdf')

In [9]:
class SoftCrossEntropyLoss(torch.nn.Module):
   """Class-weighted cross-entropy against soft (probability) targets.

   Args:
      loss_weight: scalar multiplier applied to the final loss.
      class_weight: per-class weights, broadcast over dim 1 of the target.
      T: temperature dividing the logits before the softmax.
   """
   def __init__(self, loss_weight = 1.0, class_weight = [1.0], T = 1.0):
      super().__init__()
      self.lw = loss_weight
      # Shaped (1, C, 1, 1) so it broadcasts against (N, C, H, W) targets.
      self.cw = torch.tensor(class_weight)[None,:,None,None]
      self.T  = T

   def forward(self, y_hat, y, **kwargs):
      """Return lw * mean over all non-class dims of -sum_c(cw_c * y_c * log p_c).

      y_hat holds logits and y the soft targets, with classes on dim 1.
      """
      # log_softmax avoids the -inf (and the resulting 0 * -inf = NaN) that
      # softmax(...).log() produces when a probability underflows to zero.
      log_p = F.log_softmax(y_hat/self.T, dim = 1)
      y     = self.cw.to(y.device)*y
      loss  = -(y*log_p).sum(dim = 1)
      return self.lw * loss.mean()

   @property
   def loss_name(self):
      return 'soft_ce_loss'


class CrossEntropyLoss(torch.nn.Module):
   """Unweighted cross-entropy against soft (probability) targets.

   Args:
      loss_weight: scalar multiplier applied to the final loss.
      class_weight: stored as a numpy array but not applied in forward().
      T: temperature dividing the logits before the softmax.
   """
   def __init__(self, loss_weight = 1.0, class_weight = [1.0], T = 1.0):
      super().__init__()
      self.lw = loss_weight
      self.cw = np.array(class_weight)
      self.T  = T

   def forward(self, y_hat, y, **kwargs):
      """Return lw * mean of -sum_c(y_c * log p_c), with classes on dim 1."""
      # log_softmax keeps the result finite when a softmax probability
      # underflows to zero (softmax(...).log() would give 0 * -inf = NaN).
      log_p = F.log_softmax(y_hat/self.T, dim = 1)
      loss  = -(y*log_p).sum(dim = 1)
      return self.lw * loss.mean()

   @property
   def loss_name(self):
      # NOTE(review): same name as SoftCrossEntropyLoss; kept unchanged in case
      # callers rely on this string.
      return 'soft_ce_loss'

class SoftFocalLoss(torch.nn.Module):
   """Sigmoid focal loss that accepts soft targets in [0, 1].

   Args:
      loss_weight: scalar multiplier applied to the final loss.
      class_weight: stored as a numpy array but not applied in forward().
      T: temperature dividing the logits before the sigmoid.
   """
   def __init__(self, loss_weight = 1.0, class_weight = [1.0], T = 1.0):
      super().__init__()
      self.lw = loss_weight
      self.cw = np.array(class_weight)
      self.T  = T

   def forward(self, y_hat, y, alpha =  0.75, gamma = 2, **kwargs):
      """Return lw * mean(alpha_t * (1 - p_t)**gamma * BCE(y_hat, y)).

      A negative alpha disables the alpha_t balancing term.
      """
      z = y_hat/self.T
      p = torch.sigmoid(z)
      # log(sigmoid(z)) and log(1 - sigmoid(z)) = logsigmoid(-z) are computed
      # directly in log space; sigmoid(z).log() gives -inf once the sigmoid
      # saturates, which turns into NaN when multiplied by a zero target.
      ce_loss = -(y * F.logsigmoid(z) + (1 - y) * F.logsigmoid(-z))

      p_t  = p * y + (1 - p) * (1 - y)
      loss = ce_loss * ((1 - p_t) ** gamma)

      if alpha >= 0:
          alpha_t = alpha * y + (1 - alpha) * (1 - y)
          loss    = alpha_t * loss
      return self.lw * loss.mean()

   @property
   def loss_name(self):
      # NOTE(review): reports 'soft_ce_loss' although this is a focal loss;
      # kept unchanged in case callers rely on this string.
      return 'soft_ce_loss'


class SoftBinaryCrossEntropyLoss(torch.nn.Module):
   """Binary cross-entropy on logits that accepts soft targets in [0, 1].

   Args:
      loss_weight: scalar multiplier applied to the final loss.
      T: temperature dividing the logits before the sigmoid.
   """
   def __init__(self, loss_weight = 1.0, T = 1.0):
      super().__init__()
      self.lw = loss_weight
      self.T  = T

   def forward(self, y_hat, y, **kwargs):
      """Return lw * mean(-(y * log p + (1 - y) * log(1 - p))), p = sigmoid(y_hat / T)."""
      z = y_hat/self.T
      # logsigmoid(z) = log p and logsigmoid(-z) = log(1 - p); working in log
      # space avoids the 0 * -inf = NaN that appears once the sigmoid saturates.
      loss = -(y * F.logsigmoid(z) + (1 - y) * F.logsigmoid(-z))
      return self.lw * loss.mean()

   @property
   def loss_name(self):
      return 'soft_bce_loss'
       
class SoftKLDivLoss(torch.nn.Module):
   """KL divergence KL(y || softmax(y_hat / T)) for soft targets, classes on dim 1.

   The previous version could not run: it called the non-existent `F.log`,
   divided by an undefined `w_labels`, and discarded `T`.

   Args:
      weights: scalar or broadcastable tensor multiplied into the targets.
      T: temperature dividing the logits before the softmax.
   """
   def __init__(self, weights = 1.0, T = 1.0):
      super().__init__()
      self.weights = weights
      self.T       = T

   def forward(self, y_hat, y):
      """Return sum(y_w * log(y_w / p)) / sum(y_w), where y_w = weights * y.

      The denominator is assumed to be the sum of the weighted targets (the
      original referenced an undefined `w_labels`) - TODO confirm.
      """
      log_p = F.log_softmax(y_hat/self.T, dim = 1)
      y = self.weights*y
      # xlogy(y, y) is defined as 0 where y == 0, so empty target bins add nothing.
      kl = torch.xlogy(y, y) - y*log_p
      # Clamp the normalizer so an all-zero target yields 0 rather than NaN.
      return kl.sum() / y.sum().clamp_min(1e-12)
       
if arch == 'swint':
    # Swin Transformer backbone config. embed_dims, depths and num_heads are each
    # picked from a two-entry table by swint_ind.
    swint = {
        'type': 'SwinTransformer',
        'pretrain_img_size': 224,
        'embed_dims': [96, 128][swint_ind],
        'patch_size': 4,
        'window_size': 7,
        'mlp_ratio': 4,
        'depths': [[2, 2, 6, 2], [2, 2, 18, 2]][swint_ind],
        'num_heads': [[3, 6, 12, 24], [4, 8, 16, 32]][swint_ind],
        'strides': (4, 2, 2, 2),
        'out_indices': (0, 1, 2, 3),
        'qkv_bias': True,
        'qk_scale': None,
        'patch_norm': True,
        'drop_rate': 0.,
        'attn_drop_rate': 0.,
        'drop_path_rate': 0.3,
        'use_abs_pos_embed': False,
        'act_cfg': {'type': 'GELU'},
        'norm_cfg': {'type': 'LN', 'requires_grad': True},
    }

# MixVisionTransformer backbone config (four stages, all stage outputs exposed).
# in_channels is 3 here; the first projection is replaced after the model is built.
mixvit = {
    'type': 'MixVisionTransformer',
    'in_channels': 3,
    'embed_dims': 64,
    'num_stages': 4,
    'num_layers': [3, 4, 6, 3],  # alternative: 2, 2, 2, 2
    'num_heads': [1, 2, 5, 8],
    'patch_sizes': [7, 3, 3, 3],
    'sr_ratios': [8, 4, 2, 1],
    'out_indices': (0, 1, 2, 3),
    'mlp_ratio': 4,
    'qkv_bias': True,
    'drop_rate': 0.0,
    'attn_drop_rate': 0.0,
    'drop_path_rate': 0.1,
}
  #init_cfg       = dict(type = 'Pretrained',
  #checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b1_20220624-02e5a6a1.pth'))

# ResNetV1c backbone config; depth and init_cfg come from the architecture
# selection earlier in the notebook. Built for every arch, used only for ResNets.
resnet = {
    'type': 'ResNetV1c',
    'depth': resnet_depth,
    'num_stages': 4,
    'out_indices': (0, 1, 2, 3),
    'dilations': (1, 1, 2, 4),  # alternative tried: 1, 4, 8, 16
    'strides': (1, 2, 2, 2),
    'norm_cfg': {'type': 'SyncBN', 'requires_grad': True},
    'norm_eval': False,
    'style': 'pytorch',
    'contract_dilation': True,
    'init_cfg': init_cfg,
}

# MobileNetV2 backbone config; four of the seven stages are exposed as features.
mobilenet = {
    'type': 'MobileNetV2',
    'widen_factor': 1.,
    'strides': (1, 1, 1, 2, 2, 1, 1),  # alternative: (1, 2, 2, 1, 1, 1, 1)
    'dilations': (1, 1, 1, 2, 2, 4, 4),
    'out_indices': (2, 4, 5, 6),  # alternative: 1, 2, 4, 6
    'norm_cfg': {'type': 'SyncBN', 'requires_grad': True},
}

# UNet backbone config: five encoder stages with stride 1, four decoder stages.
# in_channels is 3 here; the first conv is replaced after the model is built.
unet = {
    'type': 'UNet',
    'in_channels': 3,
    'base_channels': 64,
    'num_stages': 5,
    'strides': (1, 1, 1, 1, 1),
    'enc_num_convs': (2, 2, 2, 2, 2),
    'dec_num_convs': (2, 2, 2, 2),
    'downsamples': (True, True, True, True),
    'enc_dilations': (1, 1, 1, 1, 1),
    'dec_dilations': (1, 1, 1, 1),
    'with_cp': False,
    'conv_cfg': None,
    'norm_cfg': {'type': 'SyncBN', 'requires_grad': True},
    'act_cfg': {'type': 'ReLU'},
    'upsample_cfg': {'type': 'InterpConv'},
    'norm_eval': False,
}

if decode_head == 'fcn':
    # Single-conv FCN decode head reading a 64-channel feature map at index 4.
    # Unlike the other decode heads, its loss carries no class_weight.
    fcn_head = {
        'type': 'FCNHead',
        'in_channels': 64,
        'in_index': 4,
        'channels': 64,
        'num_convs': 1,
        'concat_input': False,
        'dropout_ratio': 0.1,
        'num_classes': num_class,
        'norm_cfg': {'type': 'SyncBN', 'requires_grad': True},
        'align_corners': False,
        'loss_decode': {'type': 'CrossEntropyLoss', 'use_sigmoid': False, 'loss_weight': 1.0},
    }
_nc = num_class  # class count reused by the aux-aware channel arithmetic below
if decode_head == 'dlv3':
    _aux_on = len(use_aux_head) > 0
    # With auxiliary heads enabled the head reads index 1 and its input width is
    # derived from channels[0] plus (_nc + 1) extra channels; otherwise it reads
    # the last backbone stage directly.
    if _aux_on:
        dlv3_in_channels = 4*max(channels[0]//2, 32) + _nc + 1
    else:
        dlv3_in_channels = channels[3]

    deeplab_head = {
        'type': 'DepthwiseSeparableASPPHead',
        'in_channels': dlv3_in_channels,
        'in_index': 1 if _aux_on else 3,
        'channels': min(channels[1], 256),
        'dilations': (1, 6, 12, 24),  # alternative: 1, 12, 24, 36
        'c1_in_channels': channels[0],
        'c1_channels': channels[0]//4,  # stock value: 48
        'dropout_ratio': 0.1,
        'num_classes': num_class,
        'activation': None,
        'norm_cfg': {'type': 'SyncBN', 'requires_grad': True},
        'align_corners': True,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
            'class_weight': class_weight,
        },
    }
         #loss_decode   = dict(type = 'DiceLoss', use_sigmoid= False, loss_weight= 1.0))
         #loss_decode    = dict(type = 'LovaszLoss', loss_type = 'multi_class', per_image = True))

if decode_head == 'upn':
    # Input widths: fixed 256-wide maps (plus 1 and num_class extra channels on the
    # first two) when auxiliary heads are on, otherwise the raw backbone widths.
    if len(use_aux_head) > 0:
        upn_in_channels = [256 + 1, 256 + num_class, 256, 256]
    else:
        upn_in_channels = channels

    upper_head = {
        'type': 'UPerHead',
        'in_channels': upn_in_channels,
        'in_index': [0, 1, 2, 3],
        'pool_scales': (1, 2, 3, 6),
        'channels': 512,
        'dropout_ratio': 0.1,
        'num_classes': num_class,
        'norm_cfg': {'type': 'SyncBN', 'requires_grad': True},
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
            'class_weight': class_weight,
        },
    }

if decode_head == 'sfm':
    # Same input-width rule as the UPerHead variant: fixed 256-wide maps (with
    # extra channels on the first two) when auxiliary heads are on.
    if len(use_aux_head) > 0:
        sfm_in_channels = [256 + 1, 256 + num_class, 256, 256]
    else:
        sfm_in_channels = channels

    segfm_head = {
        'type': 'SegformerHead',
        'in_channels': sfm_in_channels,
        'in_index': [0, 1, 2, 3],
        'dilations': [1, 1, 1, 1],
        'channels': 256,
        'dropout_ratio': 0.1,
        'num_classes': num_class,
        'activation': None,
        'norm_cfg': {'type': 'SyncBN', 'requires_grad': True},
        'align_corners': False,
        'loss_decode': {
            'type': 'CrossEntropyLoss',
            'use_sigmoid': False,
            'loss_weight': 1.0,
            'class_weight': class_weight,
        },
    }

if len(use_aux_head):
    def _aux_fcn_head(loss_key, in_index, num_classes, loss_decode):
        """Build one auxiliary FCNHead config fed by backbone stage `in_index`.

        The head's input width is channels[in_index] and its hidden width half of that.
        """
        stage_width = channels[in_index]
        return {
            'type': 'FCNHead',
            'loss_key': loss_key,
            'in_channels': stage_width,
            'in_index': in_index,
            'channels': stage_width//2,
            'num_convs': 1,
            'concat_input': False,
            'dropout_ratio': 0.1,
            'num_classes': num_classes,
            'activation': None,
            'norm_cfg': {'type': 'SyncBN', 'requires_grad': True},
            'align_corners': True,
            'loss_decode': loss_decode,
        }

    # Head 1: MSE against the 'gt_hres_img' target.
    # NOTE(review): the output width is hard-coded to 5 - confirm it matches that target.
    auxiliary_head_1 = _aux_fcn_head(
        'gt_hres_img', 0, 5,
        dict(type = 'MSELoss', loss_weight = cfg.aux1_lw))

    # Head 2: single-channel output trained with soft BCE against 'gt_hres_edge'.
    auxiliary_head_2 = _aux_fcn_head(
        'gt_hres_edge', 0, 1,
        SoftBinaryCrossEntropyLoss(loss_weight = cfg.aux2_lw, T = cfg.T))

    # Head 3: per-class output on stage 1, soft cross-entropy against 'gt_hres_mask'.
    auxiliary_head_3 = _aux_fcn_head(
        'gt_hres_mask', 1, _nc,
        SoftCrossEntropyLoss(loss_weight = cfg.aux3_lw, class_weight = class_weight, T = cfg.T))

    cfg.model.auxiliary_head_1 = auxiliary_head_1
    cfg.model.auxiliary_head_2 = auxiliary_head_2
    cfg.model.auxiliary_head_3 = auxiliary_head_3
    

In [10]:
# Remove the entries inherited from the stock config that this notebook replaces.
# pop(..., None) instead of `del` keeps the cell re-runnable: a second execution
# no longer fails because the key is already gone.
for _stale_key in ('auxiliary_head', 'pretrained'):
    cfg.model.pop(_stale_key, None)
cfg.model.use_aux_head = use_aux_head

# Select the backbone config matching `arch`. The branches stay separate because
# some of these dicts (e.g. swint) only exist when their arch is selected.
if arch == 'swint':
    cfg.model.backbone = swint
elif arch == 'mixvit':
    cfg.model.backbone = mixvit
elif arch in ['resnet18','resnet34','resnet50']:
    cfg.model.backbone = resnet
elif arch == 'mnetv2':
    cfg.model.backbone = mobilenet
elif arch == 'unet':
    cfg.model.backbone = unet

cfg.model.num_class = num_class

# Select the decode head config; each one is defined only when it was requested.
if decode_head == 'dlv3':
    cfg.model.decode_head = deeplab_head
elif decode_head == 'upn':
    cfg.model.decode_head = upper_head
elif decode_head == 'sfm':
    cfg.model.decode_head = segfm_head
elif decode_head == 'fcn':
    cfg.model.decode_head = fcn_head

In [11]:
#itr_per_epoch = 281

In [13]:
# Build the segmentation model from cfg and adapt its first layer to the number
# of input channels of the selected dataset.
from mmseg.registry import MODELS, EVALUATOR
from mmseg.utils import register_all_modules
from mmseg.evaluation import IoUMetric

#from mmdet.utils import register_all_modules as mmengine_register_all_modules
#mmengine_register_all_modules()
# Register the mmseg modules before MODELS.build is called below.
register_all_modules()
#cfg.train_cfg = {'type' : 'EpochBasedTrainLoop' , 'max_epochs' :20, 'val_interval' :1 } #IterBasedTrainLoop
# Iteration-based schedule: `epochs` blocks of itr_per_epoch iterations, validating after each block.
cfg.train_cfg = {'type' : 'IterBasedTrainLoop' , 'max_iters' : epochs*itr_per_epoch, 'val_interval' : itr_per_epoch, 'output_dir' : exp_path}
# NOTE(review): default_hooks is defined here but not attached to cfg anywhere in the visible notebook.
default_hooks = dict(checkpoint=dict(type='CheckpointHook', by_epoch = False, interval = -1, save_best = ['mIoU','mIoU1'], rule = ['greater','greater']))
# NOTE(review): the preprocessor mean is fixed at 0.5 per channel, independent of the `mean` variable
# chosen per architecture - confirm this is intended.
cfg.model.data_preprocessor['mean'] = [0.5] * inp_channel #[0.5, 0.5, 0.5]
cfg.model.data_preprocessor['std']  = [1.0] * inp_channel
cfg.model.data_preprocessor['size'] = (img_size,img_size)
#cfg.model.test_cfg
cfg.output_dir = exp_path

model = MODELS.build(cfg.model)
# Replace the first convolution of the backbone with a freshly initialised one that
# accepts inp_channel input channels. This happens for every inp_channel, including 3.
if arch == 'mnetv2':
    model.backbone.conv1.conv = torch.nn.Conv2d(inp_channel, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
if arch in ['resnet18','resnet34','resnet50']:
    model.backbone.stem[0] = torch.nn.Conv2d(inp_channel, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
if arch == 'swint':
    model.backbone.patch_embed.projection = torch.nn.Conv2d(inp_channel, channels[0], kernel_size=(4, 4), stride=(4, 4))
if arch == 'mixvit':
    # NOTE(review): stride equals the kernel size (7) and there is no padding - confirm this matches
    # the patch embedding the MixVisionTransformer backbone was configured with.
    model.backbone.layers[0][0].projection = torch.nn.Conv2d(inp_channel, 64, kernel_size=(7, 7), stride=(7, 7))
if arch == 'unet':
    model.backbone.encoder[0][0].convs[0].conv = torch.nn.Conv2d(inp_channel, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

# Evaluator reporting mIoU and mFscore; outputs are written below exp_path.
val_evaluator = IoUMetric(ignore_index = None, iou_metrics = ['mIoU','mFscore'], output_dir= exp_path, num_class =num_class, class_names = class_names, save_fig_int = 0)





In [None]:
def _load_pretrained(model, path, skip_substrings = (), skip_keys = ()):
    """Load the compatible part of a pretrained checkpoint into `model`.

    Args:
        model: module receiving the weights.
        path: checkpoint file; either a raw state dict or a dict that stores the
            state dict under 'state_dict'.
        skip_substrings: entries whose key contains any of these are dropped.
        skip_keys: entries whose key equals any of these are dropped.

    Entries whose shape differs from the model's tensor of the same name are
    also dropped. load_state_dict(strict=False) tolerates missing and unexpected
    keys but still raises on a size mismatch, which is what happens to the first
    convolution whenever inp_channel != 3.

    Returns:
        The result of model.load_state_dict (missing / unexpected keys).
    """
    checkpoint = torch.load(path)
    state = checkpoint.get('state_dict', checkpoint)
    target = model.state_dict()

    kept, mismatched = {}, []
    for key, value in state.items():
        if key in skip_keys or any(part in key for part in skip_substrings):
            continue
        if key in target and target[key].shape != value.shape:
            mismatched.append(key)
            continue
        kept[key] = value

    if mismatched:
        print(f'Skipped {len(mismatched)} pretrained tensors with mismatched shapes: {mismatched}')
    return model.load_state_dict(kept, strict = False)


# Classifier layer of the pretrained decode head; its class count differs from ours.
_conv_seg_keys = ('decode_head.conv_seg.weight', 'decode_head.conv_seg.bias')

if arch == 'resnet50':
    # ResNet variants take backbone weights only (both head families are dropped).
    _load_pretrained(
        model,
        'weights/deeplabv3plus_r50-d8_769x769_80k_cityscapes_20200606_210233-0e9dfdc4.pth',
        skip_substrings = ('auxiliary_head', 'decode_head'))
elif arch in ['resnet18', 'resnet34']:
    # The ResNet-18 checkpoint is used for both depths.
    _load_pretrained(
        model,
        'weights/deeplabv3plus_r18-d8_769x769_80k_cityscapes_20201226_083346-f326e06a.pth',
        skip_substrings = ('auxiliary_head', 'decode_head'))
elif arch == 'unet':
    _load_pretrained(
        model,
        'weights/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes_20211210_145204-6860854e.pth',
        skip_substrings = ('auxiliary_head',),
        skip_keys = _conv_seg_keys)
elif arch == 'swint':
    paths =  ['weights/upernet_swin_tiny_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210531_112542-e380ad3e.pth']
    paths += ['weights/upernet_swin_base_patch4_window7_512x512_160k_ade20k_pretrain_224x224_1K_20210526_192340-593b0e13.pth']
    # This branch used to pass the whole checkpoint dict (without ['state_dict'])
    # to load_state_dict; the helper unwraps it when the key is present.
    _load_pretrained(
        model,
        paths[swint_ind],
        skip_substrings = ('auxiliary_head',),
        skip_keys = _conv_seg_keys)
elif arch == 'mixvit':
    _load_pretrained(
        model,
        'weights/segformer_mit-b2_512x512_160k_ade20k_20210726_112103-cbd414ac.pth',
        skip_substrings = ('auxiliary_head',),
        skip_keys = _conv_seg_keys)


In [None]:
# The auxiliary losses are Python objects, which json cannot serialise, so they
# are replaced by their `loss_name` in cfg before dumping (the model is already
# built at this point). getattr with a default leaves an already-converted string
# untouched, so re-running this cell no longer raises.
if len(use_aux_head):
    for _head_key in ('auxiliary_head_2', 'auxiliary_head_3'):
        _loss = cfg.model[_head_key].loss_decode
        cfg.model[_head_key].loss_decode = getattr(_loss, 'loss_name', _loss)

with open(f'{exp_path}/config.json', 'w') as file:
    json.dump(cfg.to_dict(), file)

# Move the model and, when present, the separately held auxiliary heads to the GPU.
model.to('cuda')
if len(use_aux_head):
    for head in model.auxiliary_heads:
        head.to('cuda')

In [None]:
# NOTE(review): the star import hides which names train.py provides (and may
# shadow earlier ones); `train` is the only name used from it in this cell.
from train import *

# No LR scheduler is used with either optimizer.
scheduler = None
if optim == 'sgd':
    optimizer = SGD(model.parameters(), lr = lr, momentum = 0.9, weight_decay = 0.0005)
elif optim == 'adam':
    optimizer = Adam(model.parameters(), lr=lr)

train(
    model,
    TrainDataloader,
    ValDataloader,
    epochs,
    optimizer,
    scheduler,
    evaluator = val_evaluator,
    clip_grad = clip_grad,
    regularization = regularization,
    reg_lambda = 1e-6,
    patience = pat,
    verbose = True,
    device = 'cuda',
    output_dir = exp_path,
)

# Keep the weights from the end of training.
torch.save(model.state_dict(), f'{exp_path}/final_checkpoint.pth')

In [14]:
import glob

# Fresh evaluator for the final test pass.
val_evaluator = IoUMetric(
    ignore_index = None,
    iou_metrics = ['mIoU','mFscore'],
    output_dir = 'test',
    num_class = num_class,
    class_names = class_names,
)
if dataset == 'lcai':
    # Rebuild the test loader for lcai.
    val_dataset   = LandCoverDataset(split = 'test', img_size = img_size, filter_data = False, mean = mean)
    ValDataloader = DataLoader(val_dataset, batch_size = 3*batch_size)
model.train()

# Load the weights stored as checkpoint.pt in the experiment directory.
path = f'{exp_path}/checkpoint.pt'
mdl  = torch.load(path)
model.load_state_dict(mdl)
model.to('cuda')
for head in (model.auxiliary_heads if len(use_aux_head) else ()):
    head.to('cuda')
model.eval()

Using the whole test set --> 1602


EncoderDecoder(
  (data_preprocessor): SegDataPreProcessor()
  (backbone): ResNetV1c(
    (stem): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): SyncBatchNorm(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): SyncBatchNorm(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
      (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (7): SyncBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU(inplace=True)
    )
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): ResLayer(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): Sy

In [15]:
def sm(x):
    """Softmax of `x` over dim 1."""
    return x.softmax(dim=1)
    
def sg(x):
    """Element-wise sigmoid of `x`."""
    return torch.sigmoid(x)
    
def img_norm(x):
    """Return `x` unchanged.

    The per-image min/max rescaling that used to live here is disabled, so this
    is currently an identity function.
    """
    unchanged = x
    return unchanged

In [None]:
import pickle

# Output folders for rendered figures and pickled per-batch results.
for _subdir in ('figs', 'res'):
    os.makedirs(f'{exp_path}/{_subdir}', exist_ok = True)

# Panel titles, in the column order used by the qualitative figure.
titles = ['Image', 'Ground Truth', 'Prediction', 'Edges', 'Saliency']

def dec_io(data_batch, outputs, aux_feat, atn):
    """Convert one batch of inputs, labels and predictions to numpy lists.

    Args:
        data_batch: dict with 'inputs' (indexable by sample, channels first) and
            'data_samples' -> 'gt_sem_seg'.
        outputs: per-sample predictions.
        aux_feat: pair of auxiliary feature batches; read only when `atn` is non-empty.
        atn: per-sample attention maps, or an empty sequence.

    Returns:
        (imgs, masks, preds, aux_feats_0, aux_feats_1, atns), each a list with one
        numpy array per exported sample. Images keep their first three channels,
        are moved to channels-last and shifted by the global `mean`.

    At most the global `batch_size` samples are exported, as before, but the
    count is now capped by the real batch length so that a short final batch
    no longer raises IndexError.
    NOTE(review): loaders built with 3*batch_size therefore export only the first
    third of each batch - confirm that this subsampling is intended.
    """
    def _to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    n_samples = min(batch_size, len(outputs))
    has_aux = len(atn) > 0

    imgs, masks, preds = [], [], []
    aux_feats_0, aux_feats_1, atns = [], [], []
    for j in range(n_samples):
        imgs.append(_to_numpy(data_batch['inputs'][j][:3].permute(1,2,0)) + mean)
        masks.append(_to_numpy(data_batch['data_samples']['gt_sem_seg'][j]))
        preds.append(_to_numpy(outputs[j]))
        if has_aux:
            aux_feats_0.append(_to_numpy(aux_feat[0][j]))
            aux_feats_1.append(_to_numpy(aux_feat[1][j]))
            atns.append(_to_numpy(atn[j]))
    return imgs, masks, preds, aux_feats_0, aux_feats_1, atns
       
def dec(x):
    """Detach tensor `x` from the graph, move it to the CPU and return a numpy array."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

# Qualitative pass over the validation loader: run the backbone and decode head,
# pickle the decoded batch, and save one five-panel figure per exported sample
# (image, ground truth, prediction, auxiliary edge map, auxiliary saliency).
for idx, data_batch in tqdm(enumerate(ValDataloader)):
    with torch.no_grad():
        feats, aux_feat, atn,_ = model.extract_feat(data_batch['inputs'].float().to('cuda'))
        outputs = model.decode_head(feats)
        # Bring the head output to the input resolution when it is smaller or larger.
        if outputs.shape[-1] != img_size:
           outputs = FF.resize(outputs,(img_size, img_size), interpolation = FF.InterpolationMode.BILINEAR)
        # Explicit class dimension: the implicit-dim form of softmax is deprecated.
        probs = F.softmax(outputs, dim = 1)
        preds = torch.max(probs, dim=1)[1]
        with open(f'{exp_path}/res/res_{idx}.pkl','wb') as file:
            pickle.dump(dec_io(data_batch, preds, aux_feat, atn), file)

    # As before, at most `batch_size` samples per batch are plotted (which keeps
    # the res_{idx*batch_size + j} file numbering collision-free); the cap on the
    # real batch length prevents an IndexError on a short final batch.
    n_plot = min(batch_size, len(preds))
    for j in range(n_plot):
        fig, axes = plt.subplots(nrows = 1, ncols = 5, figsize = (35,4))
        axes[0].imshow(data_batch['inputs'][j][:3].permute(1,2,0).detach().cpu().numpy()+mean)
        axes[1].imshow(data_batch['data_samples']['gt_sem_seg'][j].detach().cpu().numpy(), cmap = cmap, norm = norm)
        axes[2].imshow(preds[j].detach().cpu().numpy(), cmap = cmap, norm = norm)
        if len(aux_feat):
            aux_edge, aux_sal = aux_feat[0][j][0], aux_feat[1][j]
            # Zero a 2-pixel border of the edge map, keeping only the interior.
            interior = aux_edge[2:-2,2:-2]
            aux_edge = aux_edge * 0.0
            aux_edge[2:-2,2:-2] = interior
            # Scale every saliency channel by its own spatial maximum.
            _max    = aux_sal.max(axis=-1)[0].max(axis=-1)[0][...,None,None]
            aux_sal = aux_sal/ _max
            axes[3].imshow(aux_edge.detach().cpu().numpy())
            axes[4].imshow(aux_sal.max(axis=0)[0].detach().cpu().numpy())
        for ax in axes:
            ax.set_xticks([])
            ax.set_yticks([])
        fig.savefig(f'{exp_path}/figs/res_{idx*batch_size + j}.png')
        # Release the figure; otherwise every figure of the run stays in memory.
        plt.close(fig)

In [None]:
# Reset the evaluator, then feed it the model predictions batch by batch.
val_evaluator.results = []
num_classes = 5  # NOTE(review): hard-coded and not read anywhere in the visible notebook
for idx, data_batch in tqdm(enumerate(ValDataloader)):
    with torch.no_grad():
        batch_inputs  = data_batch['inputs'].float().to('cuda')
        batch_samples = data_batch['data_samples']
        preds = model.predict(batch_inputs, batch_samples)
        val_evaluator.process(data_samples=preds, data_batch = None)
        

  area_intersect = torch.histc(
  area_pred_label = torch.histc(
  area_label = torch.histc(
109it [00:25,  3.91it/s]

In [None]:
import pandas as pd

# Aggregate the collected results and store them as a one-row CSV.
res = val_evaluator.compute_metrics(val_evaluator.results)
_best_frame = pd.DataFrame(res, index = [''])
_best_frame.to_csv(f'{exp_path}/best_result.csv')

In [None]:
#test_scores(model, ValDataloader, device = 'cuda')

In [None]:
# Dump the evaluator's full log, in its CSV string form, next to the summary.
with open(f'{exp_path}/best_result_all.txt','w') as file:
    _csv_text = val_evaluator.log.get_csv_string()
    file.write(_csv_text)

In [None]:
# Display the aggregated metrics as a one-row table.
pd.DataFrame(res, index = [''])