In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules (e.g. the `xv` package used below) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import yaml
import wandb 

#os.environ['WANDB_MODE'] = 'dryrun'

run_type = 'building-damage'
conf_file = 'config-damage-od.yaml'

# Parse the run configuration with the safe loader: `yaml.load` without an
# explicit Loader can construct arbitrary Python objects and is rejected by
# recent PyYAML releases. The context manager also closes the file handle,
# which the bare `open(...)` call never did.
with open(conf_file) as conf_fh:
    run_config = yaml.safe_load(conf_fh)

# Start the Weights & Biases run; `conf` is the live config object that every
# later cell reads hyper-parameters from.
wandb.init(project=run_type, config=run_config)
conf = wandb.config

  # Remove the CWD from sys.path while we load stuff.


wandb: Wandb version 0.8.14 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


In [3]:
from glob import glob
from tqdm import tqdm_notebook as tqdm
import numpy as np
import torch
from torch import nn
from xv.util import vis_im_mask
from pprint import pprint

# Dataset roots: main training set, supplementary ("tier3") set and test set.
train_dir, suppl_dir, test_dir = (
    '/home/jupyter/datasets/xview/train',
    '/home/jupyter/datasets/xview/tier3',
    '/home/jupyter/datasets/xview/test',
)

# Dump the resolved run configuration so it is recorded in the cell output.
pprint(dict(conf))

{'add_supl': True,
 'amp_opt_level': 'O1',
 'aug_prob': 0.5,
 'backbone': 'resnet101',
 'batch_size': 6,
 'class_weight': [1, 3, 2, 2],
 'data_prefix': 'post',
 'epochs': 100,
 'filter_none': True,
 'loss_weights': {'dicemulti': 1},
 'lr': 1e-05,
 'metric': 'hmean:damage:categorical:f1',
 'min_bbox_visibility': 0.2,
 'nclasses': 4,
 'optim': 'adam',
 'pretrain_weights': '/home/jupyter/sky-eye/notebooks/wandb/run-20191106_084154-35ipugq1/final_state_dict.pth',
 'run_name': 'building_damage',
 'scheduler_factor': 0.5,
 'scheduler_patience': 5,
 'train_repeat': 2,
 'training_scales': [0.5, 0.75, 1.0]}


In [4]:
import albumentations as al
from albumentations import BboxParams

# Bounding boxes are given in pascal_voc format with their classes in the
# `labels` field; boxes whose visible fraction after a transform falls below
# `conf.min_bbox_visibility` are dropped.
bbox_params = BboxParams(
    'pascal_voc',
    label_fields=['labels'],
    min_visibility=conf.min_bbox_visibility,
)

# Every transform is applied independently with probability `conf.aug_prob`.
augmentations = [
    al.HorizontalFlip(p=conf.aug_prob),
    al.VerticalFlip(p=conf.aug_prob),
    al.RandomRotate90(p=conf.aug_prob),
    al.Transpose(p=conf.aug_prob),
    al.RandomBrightnessContrast(p=conf.aug_prob),
    al.Rotate(p=conf.aug_prob),
    #al.RandomSizedBBoxSafeCrop(1024, 1024, erosion_rate=.1, p=conf.aug_prob)
]

augment = al.Compose(augmentations, bbox_params=bbox_params)

In [5]:
import random
from xv import dataset

# Seed with the string itself rather than `hash(...)` of it: Python salts str
# hashes per interpreter process (unless PYTHONHASHSEED is fixed), so the old
# seed -- and therefore the train/dev split -- changed on every kernel restart.
# `random.seed` derives a deterministic seed from a str argument.
random.seed("😂")


# glob() returns files in filesystem order, which is not guaranteed to be
# stable; sort first so the seeded shuffle always starts from the same list.
all_files = sorted(glob(f'{train_dir}/labels/*{conf.data_prefix}_disaster.json'))
random.shuffle(all_files)

# Hold out the first 20% of the shuffled label files as the dev set.
dev_ix = int(len(all_files)*.20)
dev_files = all_files[:dev_ix]
train_files = all_files[dev_ix:]

train_instances = dataset.get_instances(train_files, filter_none=conf.filter_none)

dev_instances = dataset.get_instances(dev_files, filter_none=conf.filter_none)

# Cell output: number of training / dev instances.
len(train_instances), len(dev_instances)

HBox(children=(IntProgress(value=0, max=2240), HTML(value='')))




HBox(children=(IntProgress(value=0, max=559), HTML(value='')))




(1793, 448)

In [6]:
# Manually disable the supplementary-data branch for this run.
# NOTE(review): the config dump printed earlier lists this option as
# 'add_supl' (single "p"), while the code sets and reads `add_suppl`. The
# value from the YAML file therefore never reaches the `if conf.add_suppl`
# check below -- confirm which spelling is intended.
conf.add_suppl = False

wandb: Wandb version 0.8.14 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


In [7]:
# Optionally mix in the supplementary dataset. The main training instances are
# repeated `conf.train_repeat` times before the supplementary ones are added.
# NOTE: this mutates `train_instances` in place, so re-running the cell keeps
# growing it.
if conf.add_suppl:
    suppl_files = glob(f'{suppl_dir}/labels/*{conf.data_prefix}_disaster.json')
    suppl_instances = dataset.get_instances(suppl_files, filter_none=conf.filter_none)

    train_instances *= conf.train_repeat
    train_instances += suppl_instances

    print(len(train_instances))

In [8]:
# Training samples go through the augmentation pipeline; dev samples do not.
train_dataset = dataset.DamageClassificationDataset(train_instances, conf.nclasses, augment=augment)
dev_dataset = dataset.DamageClassificationDataset(dev_instances, conf.nclasses, augment=None)

In [9]:
def collate(batch):
    """Collate `(image, boxes, classes)` samples, dropping samples without boxes.

    Args:
        batch: sequence of 3-tuples `(image, boxes, classes)`.

    Returns:
        A tuple `(ims, bxs, clss)` where `ims` is the stacked image tensor of
        the kept samples, and `bxs` / `clss` are lists holding one tensor per
        kept sample. If no sample has boxes, `ims` has a leading dimension of
        zero and both lists are empty.
    """
    # Stack every image first and mask afterwards, so a batch with no boxes at
    # all still yields a (0, ...) tensor instead of failing on an empty stack.
    keep = [len(boxes) > 0 for _, boxes, _ in batch]
    images = torch.stack([torch.Tensor(image) for image, _, _ in batch])[keep]

    kept = [(boxes, labels) for _, boxes, labels in batch if len(boxes) != 0]
    box_tensors = [torch.Tensor(boxes) for boxes, _ in kept]
    label_tensors = [torch.Tensor(labels) for _, labels in kept]
    return images, box_tensors, label_tensors

# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(
    batch_size=conf.batch_size,
    num_workers=10,
    collate_fn=collate,
)

train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
dev_loader = torch.utils.data.DataLoader(dev_dataset, shuffle=False, **loader_kwargs)

In [4]:
import random
import torch
from torch import nn
from xv.nn.nets import BoxClassifier

class MultiScaleResize(nn.Module):
    """Resize an image batch (and optionally its boxes) by a random scale.

    Each call draws one scale factor from `scales` and applies it to the
    whole batch with bilinear interpolation.

    Args:
        scales: sequence of candidate scale factors.
    """

    def __init__(self, scales = (0.5, 0.75, 1.)):
        super().__init__()
        self.scales = scales

    def forward(self, x, boxes=None):
        """Rescale `x`, and `boxes` if given, by one randomly chosen factor.

        Args:
            x: image batch accepted by `torch.nn.functional.interpolate`.
            boxes: optional iterable of per-image box tensors; every
                coordinate is multiplied by the chosen scale.

        Returns:
            The resized batch alone when `boxes` is None, otherwise the tuple
            `(resized_batch, scaled_boxes)`.
        """
        scale = random.choice(self.scales)
        x = torch.nn.functional.interpolate(x, scale_factor=scale, mode='bilinear', align_corners=False)
        # Compare against None rather than truthiness: an empty list of boxes
        # must still return a 2-tuple, because callers that pass `boxes`
        # always unpack `(x, boxes)`.
        if boxes is None:
            return x
        boxes = [b*scale for b in boxes]
        return x, boxes

In [11]:
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

# Random multi-scale resizing applied to training batches only.
train_resize = MultiScaleResize(conf.training_scales)

# FPN backbone named by the config.
# NOTE(review): the positional `True` is presumably the `pretrained` flag of
# the torchvision release in use -- confirm against the installed version.
backbone = resnet_fpn_backbone(conf.backbone, True)

# Box classifier on top of the backbone, moved to the GPU.
model = BoxClassifier(backbone, conf.nclasses).cuda()

In [14]:
# Warm-start from a previous run's state dict when the config names one.
if conf.pretrain_weights:
    model.load_state_dict(torch.load(conf.pretrain_weights))

In [15]:
import apex
# Registry of supported optimizers, keyed by the `optim` config value.
optims = dict(adam=torch.optim.Adam)
optim = optims[conf.optim](
    model.parameters(),
    lr=conf.lr,
)

In [16]:
# Multiply the learning rate by `scheduler_factor` once the monitored value
# (passed to `scheduler.step(...)` in the training loop) has stopped improving
# for `scheduler_patience` epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optim,
    factor=conf.scheduler_factor,
    patience=conf.scheduler_patience,
)

In [17]:
from pytorch_toolbelt import losses
from torch.nn.modules.loss import CrossEntropyLoss
from xv.nn.losses import loss_dict, WeightedLoss
from torch.nn.modules.loss import CrossEntropyLoss


#loss_fn = WeightedLoss({loss_dict[l]():w for l, w in conf.loss_weights.items()})

#loss_fn = losses.JaccardLoss('multiclass')
#loss_fn = CrossEntropyLoss(weights)

# Class-weighted cross entropy when the config provides per-class weights.
# Otherwise fall back to the unweighted loss, so `loss_fn` is always defined;
# previously a config without `class_weight` left it undefined (or stale from
# an earlier execution of this cell).
if 'class_weight' in dict(conf):
    weights = torch.Tensor(conf.class_weight).float().cuda()
    loss_fn = CrossEntropyLoss(weights)
else:
    loss_fn = CrossEntropyLoss()

In [18]:
from tqdm import tqdm_notebook as tqdm
import pdb
import logging

def run(model, optim, data, train_resize, loss_fn):
    """Train `model` for one pass over `data` and report the mean batch loss.

    Args:
        model: network called as `model(images, boxes)`; it is put in
            training mode.
        optim: optimizer stepped once per non-empty batch.
        data: iterable with a length that yields `(images, boxes, classes)`
            batches.
        train_resize: callable applied as `train_resize(images, boxes)` that
            returns the resized `(images, boxes)`.
        loss_fn: callable `loss_fn(outputs, targets)` returning a scalar loss.

    Returns:
        dict with the single key `'train_loss'`: the sum of the batch losses
        divided by `len(data)`, as a Python float.
    """
    model = model.train()
    total_loss = 0.
    for im, boxes, clss in tqdm(data):
        if im.shape[0] == 0:
            # Every sample in this batch was dropped for having no boxes.
            logging.warning("Empty batch.")
            continue
        im, boxes = train_resize(im, boxes)
        optim.zero_grad()
        out = model(im.cuda(), [b.cuda() for b in boxes])
        loss = loss_fn(out, torch.cat(clss).long().cuda())
        loss.backward()
        optim.step()
        # Accumulate a plain float: adding the loss tensor itself kept a
        # graph-attached GPU tensor alive for the whole epoch.
        total_loss += loss.item()
    # Average over the loader that was actually passed in (this used to read
    # the global `train_loader`); guard against an empty loader.
    return {'train_loss': total_loss / max(len(data), 1)}

import scipy
from collections import defaultdict
from xv.run import get_metrics_for_counts

def weighted_tp_fp_fn(pred, targ, weights, c):
    """Weighted true-positive, false-positive and false-negative totals for class `c`.

    Args:
        pred: array of predicted class indices.
        targ: array of target class indices, aligned with `pred`.
        weights: per-element weights, aligned with `pred`.
        c: the class index being scored.

    Returns:
        Tuple `(tp, fp, fn)` of weight sums:
        tp -- predicted `c` and target is `c`;
        fp -- predicted `c` but target is not `c`;
        fn -- target is `c` but predicted something else.
    """
    pred_is_c = pred == c
    targ_is_c = targ == c
    tp = (np.logical_and(pred_is_c, targ_is_c) * weights).sum()
    # The fp / fn conditions used to be swapped, which exchanged precision and
    # recall in every metric derived from these counts.
    fp = (np.logical_and(pred_is_c, targ != c) * weights).sum()
    fn = (np.logical_and(pred != c, targ_is_c) * weights).sum()
    return tp, fp, fn

def evaluate(model, data, nclasses, loss_fn):
    """Evaluate `model` on `data` and return loss plus per-class metrics.

    Per-box predictions are weighted by box area when counting true
    positives, false positives and false negatives for each class.

    Args:
        model: network called as `model(images, boxes)`; it is put in eval
            mode.
        data: iterable with a length that yields `(images, boxes, classes)`
            batches.
        nclasses: number of classes to score (class indices `0..nclasses-1`).
        loss_fn: callable `loss_fn(outputs, targets)` returning a scalar loss.

    Returns:
        dict containing
        - `'loss'`: summed batch loss divided by `len(data)` (Python float);
        - `'damage:categorical:<class>:<name>'`: every entry returned by
          `get_metrics_for_counts` for that class;
        - `'hmean:damage:categorical:<name>'`: harmonic mean over classes, or
          0.0 if any class value is zero;
        - `'mean:damage:categorical:<name>'`: arithmetic mean over classes.
    """
    # Imported here so the `stats` submodule is guaranteed to be loaded; a bare
    # `import scipy` does not promise that on every SciPy version.
    from scipy import stats

    model.eval()
    loss_sum = 0.
    tps = [0. for _ in range(nclasses)]
    fps = [0. for _ in range(nclasses)]
    fns = [0. for _ in range(nclasses)]
    with torch.no_grad():
        for im, boxes, clss in tqdm(data):
            if im.shape[0] == 0:
                # Nothing to score; `torch.cat` on the empty lists would raise.
                logging.warning("Empty batch.")
                continue
            out = model(im.cuda(), [b.cuda() for b in boxes])
            clss = torch.cat(clss).long()
            # Keep the running loss as a float instead of a CUDA tensor.
            loss_sum += loss_fn(out, clss.cuda()).item()

            out_ix = np.array(out.argmax(1).cpu())
            clss = clss.cpu()
            # Box area, computed as (col 2 - col 0) * (col 3 - col 1), is the
            # weight of each prediction.
            boxes_flat = torch.cat(boxes)
            areas = (boxes_flat[:,2] - boxes_flat[:,0]) * (boxes_flat[:,3] - boxes_flat[:,1])

            for cl in range(nclasses):
                tp, fp, fn = weighted_tp_fp_fn(out_ix, clss, areas, cl)
                tps[cl] += tp
                fps[cl] += fp
                fns[cl] += fn

    metrics = {}
    metrics['loss'] = loss_sum / max(len(data), 1)

    # Collect each metric across classes so it can be averaged below.
    aggregate = defaultdict(list)
    for ix in range(nclasses):
        categorical_ix_metrics =  get_metrics_for_counts(tps[ix], fps[ix], fns[ix])
        for k,v in categorical_ix_metrics.items():
            metrics[f'damage:categorical:{ix}:{k}'] = v
            aggregate[f'damage:categorical:{k}'].append(v)

    # The harmonic mean is only computed when every value is non-zero.
    hmean = {f'hmean:{k}': stats.hmean(v) if all(v) else 0. for k,v in aggregate.items()}
    metrics.update(hmean)

    # `scipy.mean` was a deprecated alias of `numpy.mean`; call numpy directly.
    mean = {f'mean:{k}': np.mean(v) for k,v in aggregate.items()}
    metrics.update(mean)

    return metrics

In [19]:
# Loop state lives in its own cell so the training cell can be interrupted and
# re-run: it resumes from the current `epoch` and keeps `best_score`.
epoch = 0
best_score = 0

In [None]:
# Main loop: one training pass and one dev evaluation per epoch. Starts from
# the current value of `epoch`.
for epoch in range(epoch, conf.epochs):
    metrics = {'epoch': epoch}

    train_metrics = run(model, optim, train_loader, train_resize, loss_fn)
    dev_metrics = evaluate(model, dev_loader, conf.nclasses, loss_fn)
    for partial in (train_metrics, dev_metrics):
        metrics.update(partial)

    wandb.log(metrics)

    # 'loss' is the dev loss from `evaluate`; `run` reports 'train_loss'.
    scheduler.step(metrics['loss'])

    # Keep the weights of the best epoch according to the configured metric.
    score = metrics[conf.metric]
    if score > best_score:
        checkpoint_path = os.path.join(wandb.run.dir, "state_dict.pth")
        torch.save(model.state_dict(), checkpoint_path)
        best_score = score

HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




wandb: Wandb version 0.8.14 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))




HBox(children=(IntProgress(value=0, max=75), HTML(value='')))




HBox(children=(IntProgress(value=0, max=299), HTML(value='')))

In [19]:
# Save the weights as they are now, separately from the best-epoch checkpoint.
final_path = os.path.join(wandb.run.dir, "final_state_dict.pth")
torch.save(model.state_dict(), final_path)

In [23]:
# Echo the path of the final state dict so it can be copied into the
# `pretrain_weights` entry of a follow-up run's config.
f'{wandb.run.dir}/final_state_dict.pth'

'/home/jupyter/sky-eye/notebooks/wandb/run-20191106_084154-35ipugq1/final_state_dict.pth'