In [1]:
# Re-import edited project modules (e.g. the local `xv` package) automatically
# before executing code, so library edits show up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import wandb
# Uncomment to run wandb in offline "dryrun" mode (nothing is synced to the server):
#os.environ['WANDB_MODE'] = 'dryrun'
wandb.init(project="sky-eye-full")
# Run config: every hyperparameter below is set as an attribute on `conf`,
# so it is recorded with the wandb run.
conf = wandb.config

In [3]:
from glob import glob
# Root of the xview dataset. Override it with the XVIEW_DATA_DIR environment
# variable instead of editing the notebook; the default is the original location.
data_dir = os.environ.get('XVIEW_DATA_DIR', '/home/jupyter/datasets/xview')
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')

In [4]:
from tqdm import tqdm_notebook as tqdm
import numpy as np
import torch
from torch import nn
from xv.util import vis_im_mask

In [5]:
# Probability with which each individual transform in `augment` is applied.
conf.aug_prob = 0.5

In [6]:
import albumentations as al

# Training-time augmentation. Each transform is gated independently by
# conf.aug_prob:
#   geometric   - flips, 90-degree rotation, transpose, grid and affine distortion
#   photometric - brightness/contrast jitter
_train_transforms = [
    al.HorizontalFlip(p=conf.aug_prob),
    al.VerticalFlip(p=conf.aug_prob),
    al.RandomRotate90(p=conf.aug_prob),
    al.Transpose(p=conf.aug_prob),
    al.GridDistortion(p=conf.aug_prob, distort_limit=.2),
    al.ShiftScaleRotate(p=conf.aug_prob),
    al.RandomBrightnessContrast(p=conf.aug_prob),
]
augment = al.Compose(_train_transforms)

In [7]:
# Output / data settings.
conf.n_dmg_classes = 4
conf.batch_size = 6
conf.image_size = 512

conf.damage_scale_mode = 'ordinal'

# Damage-head architecture (see DamageHeatmap): residual block type, number of
# blocks per stage, stride per stage, and per-stage width multiplier.
conf.blocktype = 'bottleneck'

conf.blocks = [2, 4]
conf.strides = [2, 2]
conf.growth_rate = 1

# Every stage downsamples by its stride, so the damage output is smaller than
# the input image by the product of all strides. The datasets receive this
# ratio as `dmg_downscale_ratio`.
dmg_downscale = 1
for stage_stride in conf.strides:
    dmg_downscale = dmg_downscale * stage_stride
conf.dmg_downscale_ratio = dmg_downscale

# Pretrained building-segmentation network to load as the backbone.
conf.pretrained_model = 'selimsef_spacenet4_densenet121unet'
conf.pretrained = True

In [8]:
from xv.nn.solaris.model_io import get_model
# Building-segmentation network. XVNet below reuses it for both the pre and the
# post image (the post pass skips its head and feeds the damage branch).
building_seg = get_model(conf.pretrained_model, 'torch', pretrained=conf.pretrained)

  nn.init.kaiming_normal(m.weight.data)


In [9]:
from torchvision.models.resnet import BasicBlock, Bottleneck

# Residual block classes selectable through conf.blocktype.
block_types = dict(
    bottleneck=Bottleneck,
    basic=BasicBlock,
)

class DamageHeatmap(nn.Module):
    """Residual head that turns backbone features into per-pixel damage logits.

    One residual stage is built per ``(stride, nblock)`` pair, followed by a 1x1
    convolution with ``nclasses`` output channels. The first block of each stage
    uses that stage's stride, so the output is spatially smaller than the input
    by the product of ``strides``.

    Args:
        inplanes: channel count of the input feature map.
        blocks: number of residual blocks in each stage.
        strides: stride of the first block of each stage (zipped with ``blocks``).
        block: residual block class with an ``expansion`` attribute and an
            ``(inplanes, planes, stride, downsample)`` constructor, e.g.
            torchvision's ``BasicBlock`` or ``Bottleneck``.
        nclasses: number of output channels of the final 1x1 head.
        growth_rate: factor applied to the stage width before each stage.
    """

    def __init__(self, inplanes, blocks, strides, block, nclasses, growth_rate=2):
        super().__init__()
        self.block = block
        features = []
        planes = inplanes
        for stride, nblock in zip(strides, blocks):
            planes *= growth_rate
            features.append(self._make_layer(inplanes, planes, stride, nblock))
            # The next stage consumes this stage's (expanded) output channels.
            inplanes = planes * block.expansion

        self.features = nn.ModuleList(features)
        self.head = nn.Conv2d(inplanes, nclasses, kernel_size=1, padding=0)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init convs, unit/zero-init norm layers, zero each block's last BN.

        Zeroing the scale of the final BN in every residual block makes each
        block's residual branch start out contributing zero, the same trick as
        torchvision's ``zero_init_residual``.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        for m in self.modules():
            if isinstance(m, Bottleneck):
                nn.init.constant_(m.bn3.weight, 0)
            elif isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, inplanes, planes, stride, nblocks):
        """Build one stage of ``nblocks`` residual blocks; only the first may downsample.

        Returns:
            ``nn.Sequential`` whose output has ``planes * block.expansion`` channels.
        """
        outplanes = planes * self.block.expansion
        # A projection shortcut is needed only when the first block changes the
        # spatial size or channel count. Otherwise the block uses its identity
        # shortcut, so `downsample` must default to None rather than be unbound.
        downsample = None
        if stride != 1 or inplanes != outplanes:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outplanes),
            )
        layers = [self.block(inplanes, planes, stride, downsample)]
        layers.extend(self.block(outplanes, planes) for _ in range(1, nblocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply every stage in order, then the 1x1 classification head."""
        for feature in self.features:
            x = feature(x)
        return self.head(x)

In [10]:
# Damage head that consumes the backbone's final feature map
# (building_seg.final_filters channels).
damage = DamageHeatmap(
    inplanes=building_seg.final_filters,
    blocks=conf.blocks,
    strides=conf.strides,
    block=block_types[conf.blocktype],
    nclasses=conf.n_dmg_classes,
    growth_rate=conf.growth_rate,
)

In [11]:
class XVNet(nn.Module):
    """Two-branch model sharing one segmentation network.

    The ``pre`` image goes through the full segmentation network (with its head)
    to give building logits. The ``post`` image goes through the same network
    without its head, and the resulting features feed the damage head.
    """

    def __init__(self, building_seg, dmg_heatmap):
        super().__init__()
        self.building_seg = building_seg
        self.dmg_heatmap = dmg_heatmap

    def forward(self, pre, post):
        """Return ``(building_logits, damage_logits)`` for a pre/post image pair."""
        building_logits = self.building_seg(pre)
        post_features = self.building_seg(post, apply_head=False)
        return building_logits, self.dmg_heatmap(post_features)

In [12]:
# Combined model, moved to the GPU. The train/evaluate code below also moves
# every batch to 'cuda', so a CUDA device is required.
model = XVNet(building_seg, damage).to('cuda')

In [13]:
from xv.nn import dataset
from xv import util
import random

instances = dataset.XViewSegmentationDataset.get_instances(train_dir)

# Fixed, logged seed for the train/dev split. `hash()` of a str is salted per
# interpreter process (PYTHONHASHSEED), so seeding with `hash("...")` produced a
# different split after every kernel restart. Dev images could then leak into
# training when training is resumed or a checkpoint is evaluated in a new kernel.
# A private RNG also leaves the global `random` state untouched.
conf.split_seed = 42
conf.dev_fraction = .20
random.Random(conf.split_seed).shuffle(instances)

dev_ix = int(len(instances) * conf.dev_fraction)
dev_instances = instances[:dev_ix]
train_instances = instances[dev_ix:]

train_dataset = dataset.XViewSegmentationDataset(
    instances=train_instances,
    resolution=(conf.image_size, conf.image_size),
    dmg_downscale_ratio=conf.dmg_downscale_ratio,
    augment=augment,
    damage_scale_mode=conf.damage_scale_mode
)
dev_dataset = dataset.XViewSegmentationDataset(
    instances=dev_instances,
    resolution=(conf.image_size, conf.image_size),
    dmg_downscale_ratio=conf.dmg_downscale_ratio,
    augment=None,  # evaluate on un-augmented images
    damage_scale_mode=conf.damage_scale_mode
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=conf.batch_size,
    shuffle=True,
    num_workers=10,
    pin_memory=True,
)

# Evaluation only averages metrics, so shuffling the dev set buys nothing and
# makes batch composition (and thus per-batch averages) vary between epochs.
dev_loader = torch.utils.data.DataLoader(
    dev_dataset,
    batch_size=conf.batch_size,
    shuffle=False,
    num_workers=10,
    pin_memory=True,
)

# Split sizes; this is the cell's last expression, so the notebook displays it.
len(train_instances), len(dev_instances)

HBox(children=(IntProgress(value=0, max=2799), HTML(value='')))




In [14]:
from xv.nn.losses import loss_dict, WeightedLoss

# Loss terms combined by WeightedLoss, mapped to their relative weights. Keys
# index `loss_dict`, which presumably also offers 'bcewithlogits' and 'jaccard'.
conf.loss_weights = {
    'dice': 1,
    'focal': 1,
}

loss = WeightedLoss({
    loss_dict[loss_name](): weight
    for loss_name, weight in conf.loss_weights.items()
})

In [15]:
import apex

# Optimizer classes selectable through conf.optim.
optims = dict(adam=torch.optim.Adam, adamw=torch.optim.AdamW)

conf.optim = 'adam'
conf.lr = 5e-4

optim_cls = optims[conf.optim]
optim = optim_cls(model.parameters(), lr=conf.lr)

In [16]:
from apex import amp

In [17]:
from apex import amp
# apex mixed precision. 'O1' inserts automatic casts around PyTorch functions
# and uses dynamic loss scaling (see the printed defaults below).
# NOTE(review): apex.amp is deprecated upstream in favour of torch.cuda.amp.
conf.amp_opt_level = 'O1'
model, optim = amp.initialize(model, optim, opt_level=conf.amp_opt_level)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


In [18]:
# Have wandb track the (amp-wrapped) model; by default it logs gradient histograms.
wandb.watch(model);

In [19]:
# Multiply the LR by `factor` once the dev loss (passed to scheduler.step in
# the training loop) has not improved for `patience` epochs.
conf.scheduler_factor = 0.5
conf.scheduler_patience = 5

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim,
    factor=conf.scheduler_factor,
    patience=conf.scheduler_patience,
)

In [20]:
# Relative weights of the building (pre) and damage (post) loss terms.
conf.pre_weight = 1.0
conf.post_weight = 1.0

In [21]:
from collections import defaultdict
def train(model, optim, data, loss_fn):
    """Run one training epoch and return per-batch mean losses.

    Args:
        model: module called as ``model(pre, post)`` that returns
            ``(building_logits, damage_logits)``.
        optim: optimizer already passed through apex ``amp.initialize``.
        data: sized iterable of batches (e.g. a DataLoader).
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.

    Returns:
        dict of Python floats: the mean over batches of 'loss', 'loss_pre'
        and 'loss_post'.
    """
    model = model.train()
    loss_sum, loss_pre_sum, loss_post_sum = 0., 0., 0.

    for batch in tqdm(data):
        optim.zero_grad()

        pre = batch['images']['image'].to('cuda')
        post = batch['images']['post'].to('cuda')
        pre_targets = batch['masks']['buildings'].to('cuda')
        post_targets = batch['masks']['damage'].to('cuda')
        pre_out, post_out = model(pre, post)

        loss_pre = conf.pre_weight * loss_fn(pre_out, pre_targets)
        # The damage loss is evaluated per element of the leading (batch) dim.
        # It is divided by the number of terms summed, so it is a mean on the
        # same scale as loss_pre, not a sum divided by the class count.
        loss_post = conf.post_weight * sum(
            loss_fn(sample_out, sample_target)
            for sample_out, sample_target in zip(post_out, post_targets)
        ) / len(post_out)

        loss = (loss_pre + loss_post) / (conf.pre_weight + conf.post_weight)

        with amp.scale_loss(loss, optim) as scaled_loss:
            scaled_loss.backward()
        optim.step()

        # Accumulate Python floats: summing the tensors themselves would chain
        # every batch's autograd graph into the running totals.
        loss_sum += loss.item()
        loss_pre_sum += loss_pre.item()
        loss_post_sum += loss_post.item()

    n_batches = len(data)
    return {
        'loss': loss_sum / n_batches,
        'loss_pre': loss_pre_sum / n_batches,
        'loss_post': loss_post_sum / n_batches,
    }

def batch_metrics(outputs, targets, threshold=0.5):
    """Mean per-sample precision, recall and F1 of thresholded logits.

    Each element along the first dim of ``outputs``/``targets`` is scored on its
    own, and the scores are averaged. ``outputs`` are logits: they go through a
    sigmoid and are compared with ``threshold``. ``targets`` are masks of the same
    shape, interpreted as booleans.

    A sample with no positive target pixels gets recall 1, and a sample with no
    positive predictions gets precision 1, since there is nothing to miss or
    nothing predicted wrongly.

    Returns:
        dict of Python floats under 'recall', 'precision' and 'f1'.
    """
    pr_sum, re_sum, f_sum = 0., 0., 0.
    for output, target in zip(outputs, targets):
        target_bool = target.bool()
        output_bool = output.sigmoid() > threshold

        # Recall: fraction of positive target pixels that are predicted positive.
        recall = output_bool[target_bool].float().mean().item() if target_bool.any() else 1.
        # Precision: fraction of predicted-positive pixels that are truly positive.
        precision = target_bool[output_bool].float().mean().item() if output_bool.any() else 1.

        pr_sum += precision
        re_sum += recall
        f_sum += 2 * precision * recall / (precision + recall) if (precision + recall) > 0. else 0.

    n = len(outputs)
    return {
        'recall': re_sum / n,
        'precision': pr_sum / n,
        'f1': f_sum / n,
    }
    

def evaluate(model, optim, data, loss_fn, threshold=0.5):
    """Evaluate on ``data`` without gradients; return per-batch mean losses and metrics.

    Args:
        model: module called as ``model(pre, post)`` that returns
            ``(building_logits, damage_logits)``.
        optim: unused; kept so the call signature mirrors ``train``.
        data: sized iterable of batches (e.g. a DataLoader).
        loss_fn: callable ``loss_fn(output, target)`` returning a scalar tensor.
        threshold: sigmoid cut-off forwarded to ``batch_metrics``.

    Returns:
        dict of Python floats averaged over batches: 'loss', 'loss_pre' and
        'loss_post'; per-damage-class 'dmg_<class>_<metric>'; their unweighted
        mean 'dmg_macro_<metric>'; and 'building_<metric>'.
    """
    model = model.eval()
    metrics = defaultdict(float)
    damage_classes = train_dataset.DAMAGE_CLASSES

    with torch.no_grad():
        for batch in tqdm(data):
            pre = batch['images']['image'].to('cuda')
            post = batch['images']['post'].to('cuda')
            pre_out, post_out = model(pre, post)
            pre_targets = batch['masks']['buildings'].to('cuda')
            post_targets = batch['masks']['damage'].to('cuda')

            loss_pre = conf.pre_weight * loss_fn(pre_out, pre_targets)
            # Mean over the per-sample damage losses, matching the
            # normalisation used in training.
            loss_post = conf.post_weight * sum(
                loss_fn(sample_out, sample_target)
                for sample_out, sample_target in zip(post_out, post_targets)
            ) / len(post_out)
            loss = (loss_pre + loss_post) / (conf.pre_weight + conf.post_weight)

            metrics['loss_pre'] += loss_pre.item()
            metrics['loss_post'] += loss_post.item()
            metrics['loss'] += loss.item()

            # Channel `ix` of the damage output/target corresponds to class `dmg_type`.
            macro_metrics = defaultdict(float)
            for dmg_type, ix in damage_classes.items():
                class_metrics = batch_metrics(post_out[:, ix], post_targets[:, ix], threshold=threshold)
                for k, v in class_metrics.items():
                    metrics[f'dmg_{dmg_type}_{k}'] += float(v)
                    macro_metrics[k] += float(v)
            for k, v in macro_metrics.items():
                metrics[f'dmg_macro_{k}'] += v / len(damage_classes)
            for k, v in batch_metrics(pre_out, pre_targets, threshold=threshold).items():
                metrics[f'building_{k}'] += float(v)

    n_batches = len(data)
    return {k: v / n_batches for k, v in metrics.items()}

In [22]:
conf.epochs = 140
# Best dev loss seen so far. Using +inf guarantees the first epoch is
# checkpointed whatever the loss scale; a finite sentinel such as 1e5 could
# be exceeded, and then no checkpoint would ever be written.
best_loss = float('inf')
# Starting epoch. The training loop reads it, so re-running that cell after an
# interruption continues from the epoch that was in progress.
epoch = 0

In [None]:
for epoch in range(epoch, conf.epochs):
    # One training epoch (metrics prefixed with 'train_'), then dev evaluation.
    train_metrics = train(model, optim, train_loader, loss)
    metrics = {f'train_{name}': value for name, value in train_metrics.items()}
    metrics.update(evaluate(model, optim, dev_loader, loss))

    dev_loss = metrics['loss']
    scheduler.step(dev_loss)
    wandb.log(metrics)

    # Keep a single checkpoint: the weights with the lowest dev loss so far.
    if dev_loss < best_loss:
        torch.save(model.state_dict(), os.path.join(wandb.run.dir, "state_dict.pth"))
        best_loss = dev_loss

HBox(children=(IntProgress(value=0, max=374), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 16384.0


In [None]:
# Sanity check: overlay the damage mask on one (augmented) training sample.
ix = 1000
sample = train_dataset[ix]
post_image = np.array(train_dataset.inverse_transform_image(sample['images']['post']))
util.vis_im_mask(post_image, sample['masks']['damage'], size=(512 * 2, 512 * 2), opacity=.3);

In [None]:
from collections import Counter
# Histogram mapping "number of entries in an instance's pre['features']" to
# how many training instances have that count.
counts = Counter(
    len(instance['pre']['features'])
    for instance in train_dataset.instances
)

In [None]:
# Fraction of training instances with no pre-image features (presumably images without buildings).
counts[0]/sum(counts.values())