In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

In [3]:
import os
import wandb

# Uncomment to switch W&B to its 'dryrun' mode.
#os.environ['WANDB_MODE'] = 'dryrun'
# Tell W&B which notebook file this run belongs to.
os.environ['WANDB_NOTEBOOK_NAME'] = 'data.ipynb'
# Pass the project by keyword: the first positional parameter of wandb.init
# is not `project` (it is `job_type` in the 0.8.x releases), so a bare
# positional string would not select the intended project.
wandb.init(project="sky-eye-full")
# Shorthand for the run's config; every hyper-parameter below is stored on it.
conf = wandb.config

In [4]:
import os
from glob import glob

# Root folder of the xView dataset. Override it by setting the XVIEW_DATA_DIR
# environment variable; the default keeps the paths this notebook always used.
xview_root = os.environ.get('XVIEW_DATA_DIR', '/home/jupyter/datasets/xview')
train_dir = os.path.join(xview_root, 'train')
test_dir = os.path.join(xview_root, 'test')

In [5]:
from tqdm import tqdm_notebook as tqdm
import numpy as np
import torch
from torch import nn
from xv.util import vis_im_mask
import segmentation_models_pytorch as smp

In [6]:
# --- Augmentation / training-stage switches --------------------------------
# Probability handed to every augmentation transform.
conf.aug_prob = .5
# Which parts of the network are trained; these flags also decide the
# pre/post loss weights set further down.
conf.train_pre = True
conf.train_post = False

# --- Data ------------------------------------------------------------------
conf.n_dmg_classes = 4
conf.batch_size = 6
# Images are used at image_size x image_size resolution.
conf.image_size = 512

conf.damage_scale_mode = 'ordinal'

# --- Damage head (only built when conf.train_post is set) -------------------
# Key into the block_types table of the model cell.
conf.blocktype = 'bottleneck'

conf.blocks = [2, 4]
conf.strides = [2, 2]
conf.growth_rate = 1

# --- Building segmentation backbone ----------------------------------------
# Optional pretrained alternative; when 'pretrained_model' is present in the
# config the model cell loads it instead of building a fresh network.
#conf.pretrained_model = 'selimsef_spacenet4_densenet121unet'
#conf.pretrained = True

conf.segmentation_arch = 'Linknet'
conf.encoder = 'efficientnet-b7'

# The damage downscale ratio is the product of all configured strides.
dmg_downscale = 1
for stride in conf.strides:
    dmg_downscale = dmg_downscale * stride

conf.dmg_downscale_ratio = dmg_downscale


In [7]:
import albumentations as al

# Every transform fires independently with the same configured probability.
aug_p = conf.aug_prob

# Transforms are applied in list order: flips / 90-degree rotations /
# transpose first, then grid distortion and shift-scale-rotate, and finally
# a brightness/contrast change.
aug_transforms = [
    al.HorizontalFlip(p=aug_p),
    al.VerticalFlip(p=aug_p),
    al.RandomRotate90(p=aug_p),
    al.Transpose(p=aug_p),
    al.GridDistortion(p=aug_p, distort_limit=.2),
    al.ShiftScaleRotate(p=aug_p),
    al.RandomBrightnessContrast(p=aug_p),
]

augment = al.Compose(aug_transforms)

In [8]:
from xv.nn.solaris.model_io import get_model
from xv.nn.nets import DownscaleLayer, XVNet
from torchvision.models.resnet import BasicBlock, Bottleneck
import segmentation_models_pytorch as smp


# Lookup tables that turn the string settings stored in the config into classes.
block_types = dict(bottleneck=Bottleneck, basic=BasicBlock)

segmentation_types = dict(
    PSPNet=smp.PSPNet,
    FPN=smp.FPN,
    Linknet=smp.Linknet,
    Unet=smp.Unet,
)

# A 'pretrained_model' entry in the config selects the pretrained network;
# otherwise a fresh segmentation_models_pytorch model is built from
# conf.segmentation_arch / conf.encoder.
use_pretrained = 'pretrained_model' in dict(conf)
if not use_pretrained:
    seg_cls = segmentation_types[conf.segmentation_arch]
    building_seg = seg_cls(conf.encoder, classes=1, activation='sigmoid')
    # Input preprocessing matching the chosen encoder.
    preprocess_fn = get_preprocessing_fn(conf.encoder)
else:
    building_seg = get_model(conf.pretrained_model, 'torch', pretrained=conf.pretrained)
    # No encoder-specific preprocessing is used for the pretrained model.
    preprocess_fn = None

# The damage head is only created when the post stage is trained.
damage = None
if conf.train_post:
    damage = DownscaleLayer(
        inplanes=building_seg.final_filters,
        blocks=conf.blocks,
        strides=conf.strides,
        block=block_types[conf.blocktype],
        nclasses=conf.n_dmg_classes,
        growth_rate=conf.growth_rate,
    )

# Combine both parts and move the result to the GPU.
model = XVNet(building_seg, damage).to('cuda')

In [9]:
from xv.nn import dataset
from xv import util
import random

instances = dataset.XViewSegmentationDataset.get_instances(train_dir)

# Seed with the string itself, not hash("😂"): str hashes are salted per
# interpreter process, so the hash-based seed changed on every kernel restart
# and silently produced a different train/dev split each time. random.seed
# accepts a str directly and derives the same seed from it on every run.
# NOTE(review): the split is only reproducible if get_instances returns the
# instances in a stable order -- confirm.
random.seed("😂")
random.shuffle(instances)

# Hold out the first 20% of the shuffled instances as the dev set.
dev_ix = int(len(instances)*.20)
dev_instances = instances[:dev_ix]
train_instances = instances[dev_ix:]
# A bare expression in the middle of a cell is not displayed, so print the sizes.
print(f"train instances: {len(train_instances)}, dev instances: {len(dev_instances)}")

# Settings shared by the train and dev datasets; only the instance list and
# the augmentation differ between the two.
dataset_kwargs = dict(
    resolution=(conf.image_size, conf.image_size),
    dmg_downscale_ratio=conf.dmg_downscale_ratio,
    damage_scale_mode=conf.damage_scale_mode,
    preprocess_fn=preprocess_fn,
)

train_dataset = dataset.XViewSegmentationDataset(
    instances=train_instances,
    augment=augment,
    **dataset_kwargs,
)

# The dev set is evaluated without augmentation.
dev_dataset = dataset.XViewSegmentationDataset(
    instances=dev_instances,
    augment=None,
    **dataset_kwargs,
)

# Settings shared by both data loaders.
loader_kwargs = dict(
    batch_size=conf.batch_size,
    num_workers=10,
    pin_memory=True,
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    **loader_kwargs,
)

# Evaluation does not need a random order; keeping it fixed makes dev runs
# comparable from epoch to epoch.
dev_loader = torch.utils.data.DataLoader(
    dev_dataset,
    shuffle=False,
    **loader_kwargs,
)

HBox(children=(IntProgress(value=0, max=2799), HTML(value='')))




In [10]:
from xv.nn.losses import loss_dict, WeightedLoss

# Loss terms keyed by their name in loss_dict, each with its weight.
# Commented-out entries are alternatives that are currently disabled.
# NOTE(review): the segmentation model is built with activation='sigmoid';
# if 'bcewithlogits' expects raw logits and that activation is applied during
# training, the sigmoid would be applied twice -- confirm.
conf.loss_weights = {
    #'dice': 1,
    #'focal': 1,
    'bcewithlogits': 1,
    #'jaccard': 1
}

# Instantiate every configured loss and pair the instance with its weight.
weighted_terms = {}
for loss_name, weight in conf.loss_weights.items():
    weighted_terms[loss_dict[loss_name]()] = weight

loss = WeightedLoss(weighted_terms)

In [11]:
import apex

# Optimizer choice and learning rate are recorded in the W&B config.
conf.optim = 'adam'
conf.lr = 0.0005

# Maps the optimizer name stored in the config to its torch class.
optims = dict(adam=torch.optim.Adam)

optimizer_cls = optims[conf.optim]
optim = optimizer_cls(model.parameters(), lr=conf.lr)

In [12]:
from apex import amp
# Mixed-precision optimisation level handed to apex; the log printed by this
# cell describes 'O1' as inserting automatic casts around PyTorch functions
# and Tensor methods.
conf.amp_opt_level = 'O1'
# Both names are rebound to the objects returned by amp.initialize, so every
# later cell works with the amp-prepared model and optimizer.
# NOTE(review): re-running this cell passes the already-initialised objects
# back into amp.initialize -- confirm that is safe before doing so.
model, optim = amp.initialize(model, optim, opt_level=conf.amp_opt_level)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [13]:
# Hand the model to wandb.watch; the trailing ';' suppresses the cell's output.
wandb.watch(model);

In [14]:
# Plateau scheduler settings, recorded in the W&B config.
conf.scheduler_patience = 5
conf.scheduler_factor = 0.5

# The scheduler is stepped once per epoch with the dev loss in the training loop.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim,
    patience=conf.scheduler_patience,
    factor=conf.scheduler_factor,
)

In [15]:
# A stage that is trained gets weight 1.0; a stage that is not trained gets
# None. Both values are passed on to run.train / run.evaluate.
if conf.train_pre:
    conf.pre_weight = 1.
else:
    conf.pre_weight = None

if conf.train_post:
    conf.post_weight = 1.
else:
    conf.post_weight = None

[autoreload of xv.nn.losses failed: Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/opt/anaconda3/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 450, in superreload
    update_generic(old_obj, new_obj)
  File "/opt/anaconda3/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 387, in update_generic
    update(a, b)
  File "/opt/anaconda3/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 357, in update_class
    update_instances(old, new)
  File "/opt/anaconda3/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 315, in update_instances
    if hasattr(obj, 'items') or (hasattr(obj, '__contains__')
  File "/opt/anaconda3/lib/python3.7/site-packages/wandb/wandb_config.py", line 189, in __getattr__
    return self.__getitem__(key)
  File "/opt/anaconda3/lib/python3.7/site-pack

In [16]:
# Total number of epochs to train for.
conf.epochs = 100
# Best dev loss seen so far. Start at +inf rather than a finite sentinel, so
# the first epoch always writes a checkpoint however large its loss is.
best_loss = float('inf')
# First epoch of the training loop. The loop rebinds `epoch`, so re-running
# only the loop cell continues from the last epoch it reached; re-running
# this cell starts over from 0.
epoch = 0

In [None]:
from xv import run

# Starts from the current value of `epoch`, so re-running this cell after an
# interruption continues from the epoch that was in progress.
for epoch in range(epoch, conf.epochs):
    # One pass over the training data, then one over the dev data.
    train_stats = run.train(
        model, optim, train_loader, loss,
        pre_weight=conf.pre_weight, post_weight=conf.post_weight,
    )
    dev_stats = run.evaluate(
        model, dev_loader, loss,
        pre_weight=conf.pre_weight, post_weight=conf.post_weight,
    )

    # Training metrics get a 'train_' prefix; dev metrics keep their own names.
    metrics = {'epoch': epoch}
    for key, value in train_stats.items():
        metrics[f'train_{key}'] = value
    metrics.update(dev_stats)
    wandb.log(metrics)

    # The dev loss drives both the LR schedule and checkpoint selection.
    dev_loss = metrics['loss']
    scheduler.step(dev_loss)
    if dev_loss < best_loss:
        # Keep only the best weights, stored inside the W&B run directory.
        torch.save(model.state_dict(), os.path.join(wandb.run.dir, "state_dict.pth"))
        best_loss = dev_loss

100%|██████████| 374/374 [05:59<00:00,  1.04it/s]
100%|██████████| 94/94 [00:27<00:00,  3.36it/s]
100%|██████████| 374/374 [05:52<00:00,  1.06it/s]
100%|██████████| 94/94 [00:27<00:00,  3.39it/s]
100%|██████████| 374/374 [05:53<00:00,  1.06it/s]
100%|██████████| 94/94 [00:28<00:00,  3.33it/s]
100%|██████████| 374/374 [05:54<00:00,  1.06it/s]
100%|██████████| 94/94 [00:28<00:00,  3.32it/s]
100%|██████████| 374/374 [05:55<00:00,  1.05it/s]
100%|██████████| 94/94 [00:28<00:00,  3.32it/s]
100%|██████████| 374/374 [05:54<00:00,  1.05it/s]
100%|██████████| 94/94 [00:28<00:00,  3.35it/s]
100%|██████████| 374/374 [05:51<00:00,  1.06it/s]
100%|██████████| 94/94 [00:27<00:00,  3.38it/s]
 11%|█         | 41/374 [00:42<05:13,  1.06it/s]

In [None]:
# Index of the training sample to inspect.
ix = 1000
sample = train_dataset[ix]
masks = sample['masks']

# Undo the dataset's image transform on the 'post' image so it can be shown
# as a regular array.
image = np.array(train_dataset.inverse_transform_image(sample['images']['post']))

# Show the damage mask on top of the image; ';' suppresses the cell's output repr.
util.vis_im_mask(image, masks['damage'], size=(1024, 1024), opacity=.3);

In [None]:
from collections import Counter

# Tally how many training instances have each number of 'pre' features.
counts = Counter()
for instance in train_dataset.instances:
    counts[len(instance['pre']['features'])] += 1