In [None]:
# Re-import edited modules (e.g. the project's xv.* package) automatically
# before each cell executes, so library changes apply without restarting.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from tqdm import tqdm_notebook as tqdm
# Root of the xView test split; the image lists below are globbed from its images/ folder.
test_dir = "../../datasets/xview/test"

In [None]:
from datetime import datetime
import os
# One fresh, timestamped directory per run, e.g. "../submissions/2019-11-16_12-34-56.789012"
# (spaces and colons are replaced so the name is filesystem-friendly).
submission_dir = f"../submissions/{str(datetime.now()).replace(' ', '_').replace(':', '-')}"
# makedirs also creates ../submissions on a fresh checkout, where os.mkdir would fail.
# exist_ok stays False so an existing run directory is never silently reused.
os.makedirs(submission_dir)

In [None]:
import yaml
class Config:
    """Attribute-style view of a JSON or YAML configuration file.

    Every top-level key of the parsed mapping becomes an attribute, so
    ``Config("conf.json").encoder`` returns the file's ``"encoder"`` entry.

    Args:
        file_path: Path to the config file. The format is taken from the
            extension (``.json``, ``.yaml`` or ``.yml``, case-insensitive).
        file_type: Format to assume when the extension is not one of those:
            ``'json'`` (default), ``'yaml'`` or ``'yml'``.

    Raises:
        ValueError: if the format cannot be determined, or the file's top
            level is not a mapping.
        OSError: if the file cannot be opened.
    """

    def __init__(self, file_path, file_type='json'):
        # Imported here so the class does not depend on a later cell having
        # imported json first.
        import json
        import os

        ext = os.path.splitext(file_path)[1].lower().lstrip('.')
        fmt = ext if ext in ('json', 'yaml', 'yml') else str(file_type).lower()
        if fmt not in ('json', 'yaml', 'yml'):
            raise ValueError(
                f"cannot determine config format for {file_path!r} "
                f"(file_type={file_type!r}); expected json or yaml"
            )

        with open(file_path) as f:
            if fmt == 'json':
                _conf = json.load(f)
            else:
                # safe_load: yaml.load() without an explicit Loader is deprecated
                # (PyYAML 5.1), an error in PyYAML 6, and can build arbitrary objects.
                _conf = yaml.safe_load(f)

        if _conf is None:  # an empty YAML document parses to None
            _conf = {}
        if not isinstance(_conf, dict):
            raise ValueError(
                f"config {file_path!r} must contain a mapping at the top level, "
                f"got {type(_conf).__name__}"
            )
        for k, v in _conf.items():
            setattr(self, k, v)

In [None]:
import json
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn
from xv.nn.layers import FrozenBatchNorm2d

def load_segmentation_model(conf_file, state_file):
    """Build the localization network described by a training config and load its weights.

    Args:
        conf_file: Path to the run's config, read with ``Config``. Uses the
            attributes ``segmentation_arch``, ``encoder``, ``nclasses``,
            ``attention``, ``freeze_encoder_norm`` and ``freeze_decoder_norm``.
        state_file: Path to a ``state_dict`` saved with ``torch.save``.

    Returns:
        ``(model, preprocess_fn)``: the model on the CPU, not yet switched to
        eval mode, and the encoder's input-normalisation function from
        ``get_preprocessing_fn``.

    Raises:
        KeyError: if ``conf.segmentation_arch`` is not a supported architecture.
    """
    conf = Config(conf_file)

    segmentation_types = {
        'PSPNet': smp.PSPNet,
        'FPN': smp.FPN,
        'Linknet': smp.Linknet,
        'Unet': smp.Unet
    }

    try:
        arch = segmentation_types[conf.segmentation_arch]
    except KeyError:
        raise KeyError(
            f"unsupported segmentation_arch {conf.segmentation_arch!r}; "
            f"expected one of {sorted(segmentation_types)}"
        ) from None

    model = arch(
        conf.encoder,
        # Every weight is restored from state_file below, so skip the ImageNet
        # encoder download that smp performs by default (encoder_weights='imagenet').
        encoder_weights=None,
        classes=conf.nclasses,
        activation='sigmoid',
        attention_type=conf.attention
    )

    # Replace normalisation layers with frozen variants when the run was configured
    # that way (presumably so the module layout matches the saved checkpoint).
    if conf.freeze_encoder_norm:
        model.encoder = FrozenBatchNorm2d.convert_frozen_batchnorm(model.encoder)

    if conf.freeze_decoder_norm:
        model.decoder = FrozenBatchNorm2d.convert_frozen_batchnorm(model.decoder)

    preprocess_fn = get_preprocessing_fn(conf.encoder)

    # map_location='cpu' restores the tensors regardless of which device (if any)
    # they were saved from; callers move the model to the GPU afterwards.
    state_dict = torch.load(state_file, map_location='cpu')

    model.load_state_dict(state_dict)
    return model, preprocess_fn

In [None]:
# Trained localization run: its config and weights live in the same run directory.
conf_file, state_file = (
    "../weights/b-run-20191115_213743-qoijsx0h/conf.json",
    "../weights/b-run-20191115_213743-qoijsx0h/state_dict.pth",
)

model, preprocess_fn = load_segmentation_model(conf_file, state_file)
model.eval()  # switch to inference mode in place
model = model.cuda()

In [None]:
import ttach as tta
# Average the localization output over ttach's d4 set of augmentations
# (flips and 90-degree rotations, per the ttach docs).
tta_transforms = tta.aliases.d4_transform()
model = tta.SegmentationTTAWrapper(model, tta_transforms, merge_mode='mean')

In [None]:
from glob import glob
# Pre-disaster test images. glob's order depends on the filesystem, so sort it
# to make every run visit the tiles in the same order.
pre_files = sorted(glob(f"{test_dir}/images/*pre*"))

In [None]:
from PIL import Image
import numpy as np
import torch

def load_img(img_path, preprocess_fn):
    """Load an image file as a ``(1, C, H, W)`` float32 tensor for the localization model.

    Args:
        img_path: Path to the image file.
        preprocess_fn: Callable applied to the ``(H, W, 3)`` array before the
            channel axis is moved first (in this notebook, the encoder
            normalisation returned by ``get_preprocessing_fn``).

    Returns:
        A ``torch.Tensor`` with a leading batch axis of size 1.
    """
    # convert('RGB') guarantees the 3-channel (H, W, 3) layout that the transpose
    # below requires; palette, greyscale or RGBA files would otherwise break it.
    image = np.array(Image.open(img_path).convert('RGB'))
    image = preprocess_fn(image)
    image = image.transpose(2, 0, 1).astype(np.float32)
    return torch.Tensor(image[None])

In [None]:
with torch.no_grad():
    for pre_path in tqdm(pre_files):
        batch = load_img(pre_path, preprocess_fn).cuda()
        # First channel of the single batch item, binarised at 0.
        # NOTE(review): a 0 threshold presumes the model returns logits (activation
        # applied outside forward). If forward applies the configured sigmoid, every
        # pixel passes `> 0` and 0.5 is the right cut. Confirm for the installed smp version.
        building_mask = (model(batch)[0][0] > 0).cpu().numpy().astype(np.uint8)
        # Tile id: text after the last '_' with ".png" removed.
        tile_id = pre_path.rsplit('_', 1)[-1].replace(".png", "")
        Image.fromarray(building_mask).save(f"{submission_dir}/test_localization_{tile_id}_prediction.png")

In [None]:
from xv.nn.nets import BoxClassifier

In [None]:
# Damage-classification run: YAML config and weights in the same directory.
dmg_conf_file, dmg_state_file = (
    "../weights/rcndcpcc/config-damage-od.yaml",
    "../weights/rcndcpcc/state_dict.pth",
)

In [None]:
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

def load_damage_model(conf_file, state_file):
    """Build the damage ``BoxClassifier`` from its config and load trained weights.

    Args:
        conf_file: Path to the config read with ``Config``; uses the
            ``backbone`` and ``nclasses`` attributes.
        state_file: Path to a ``state_dict`` saved with ``torch.save``.

    Returns:
        The model in eval mode, moved to the current CUDA device.
    """
    conf = Config(conf_file)
    # Skip the ImageNet download: the strict load_state_dict below overwrites the
    # backbone weights (assumes BoxClassifier registers the backbone as a submodule,
    # so its tensors are part of the checkpoint; TODO confirm). Passing `pretrained`
    # by keyword keeps working on torchvision versions that map it onto `weights`.
    backbone = resnet_fpn_backbone(conf.backbone, pretrained=False)
    model = BoxClassifier(backbone, conf.nclasses)
    # Restore on the CPU first so the checkpoint loads regardless of the device it was saved from.
    state_dict = torch.load(state_file, map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.eval().cuda()
    return model

def load_dmg_img(img_path, image_mean = (0.485, 0.456, 0.406), image_std = (0.229, 0.224, 0.225)):
    """Load an image as a normalised ``(1, 3, H, W)`` float32 tensor for the damage model.

    Pixels are scaled to [0, 1] and then standardised per channel with
    ``image_mean`` / ``image_std`` (defaults are torchvision's ImageNet statistics).

    Args:
        img_path: Path to the image file.
        image_mean: Per-channel means, in RGB order.
        image_std: Per-channel standard deviations, in RGB order.
    """
    # convert('RGB') guarantees three channels so the image broadcasts against
    # the 3-element mean/std (palette/greyscale/RGBA files would not).
    image = np.array(Image.open(img_path).convert('RGB'), dtype=np.float32)
    image /= 255.
    # float32 statistics keep the arithmetic in float32; the original tuples
    # promoted the whole image to float64 before it was cast back.
    mean = np.asarray(image_mean, dtype=np.float32)
    std = np.asarray(image_std, dtype=np.float32)
    image = (image - mean) / std
    image = np.ascontiguousarray(image.transpose(2, 0, 1))
    return torch.from_numpy(image)[None]

In [None]:
# Post-disaster test images, sorted for a reproducible processing order
# (glob's order depends on the filesystem).
post_files = sorted(glob(f"{test_dir}/images/*post*"))

In [None]:
from imantics import Polygons

In [None]:
# Damage classifier, already in eval mode on the GPU (load_damage_model does both).
dmg_model = load_damage_model(conf_file=dmg_conf_file, state_file=dmg_state_file)

In [None]:
import cv2
for f in tqdm(post_files):
    # Tile id = text after the last '_' minus ".png", the same rule the localization
    # cell used, so each post image is paired with its tile's localization mask.
    fid = f.split('_')[-1].replace(".png", "")
    mask_file = f"{submission_dir}/test_localization_{fid}_prediction.png"
    msk = np.array(Image.open(mask_file))
    polys = Polygons.from_mask(msk)

    polypoints = polys.points
    if polypoints:
        with torch.no_grad():
            # One box per polygon: per-column minima then maxima of its (n_points, 2)
            # vertex array, i.e. (x1, y1, x2, y2) assuming points are (x, y).
            boxes = torch.Tensor([[*p.min(axis=0), *p.max(axis=0)] for p in polypoints])
            inp = load_dmg_img(f)
            out = dmg_model(inp.cuda(), [boxes.cuda()])
            # +1 so predicted class 0 does not collide with the 0 background fill below.
            classes = (out.argmax(1) + 1).cpu().numpy()
    else:
        classes = []

    # Same size as the localization mask the polygons were traced from,
    # instead of a hard-coded 1024x1024 canvas.
    msk_dmg = np.zeros(msk.shape[:2], dtype=np.uint8)

    for poly, cls in zip(polypoints, classes):
        # cv2.fillPoly requires 32-bit integer (CV_32S) vertex coordinates.
        cv2.fillPoly(msk_dmg, [np.asarray(poly, dtype=np.int32)], int(cls))

    Image.fromarray(msk_dmg).save(f"{submission_dir}/test_damage_{fid}_prediction.png")

In [None]:
import os
import shutil

# Archive the run directory as <root>/<subdir>.zip with entries stored under
# "<subdir>/" (same layout as `cd root && zip -r subdir.zip subdir/`), in pure
# Python: no external `zip` binary and no shell-string interpolation.
root, subdir = os.path.split(submission_dir)
archive_path = shutil.make_archive(
    os.path.join(root, subdir), 'zip', root_dir=root or '.', base_dir=subdir
)
archive_path