In [None]:
# Automatically reload edited project modules (e.g. `xv`) before each cell runs,
# so changes to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import wandb
# W&B public API client. Not used by the cells visible here; presumably kept for
# fetching run objects interactively (e.g. to pass to get_score) — verify.
api = wandb.Api()
def get_score(r, metric):
    """Return the best (maximum) logged value of ``metric`` for a run.

    Args:
        r: run object exposing ``scan_history()``, which yields one dict per
            logged step (e.g. a W&B public-API run).
        metric: history key to read from each row.

    Returns:
        The maximum value over all history rows that contain ``metric``, or
        ``None`` if no row logged it.
    """
    # Stream the history instead of materializing it, and read the `metric`
    # argument (not a module-level METRIC global).
    scores = [row[metric] for row in r.scan_history() if metric in row]
    return max(scores) if scores else None

In [None]:
import wandb
from xv import io
from xv.io import Config
import torch
from tqdm.notebook import tqdm
from datetime import datetime
import os

In [None]:
# Run configuration: W&B run ids for the localization (segmentation) model and
# the damage model, plus whether the damage model is a segmentation network
# (True) or a per-building box classifier (False).
seg_run_id, dmg_run_id = "qoijsx0h", "0gvvydkt"
dmg_is_seg = True

In [None]:
from datetime import datetime
import os
test_dir = '../../datasets/xview/test'
# One fresh output folder per run, named "<seg>_<dmg>_<timestamp>"; spaces and
# ':' in the timestamp are replaced so the name is filesystem- and shell-safe.
submission_dir = f"../submissions/{seg_run_id}_{dmg_run_id}_{str(datetime.now()).replace(' ', '_').replace(':', '-')}"

# Localization masks are cached per segmentation run id so later runs can reuse them.
cache_dir = f'../submissions/cache/{seg_run_id}'
# makedirs (not mkdir) so a missing ../submissions parent does not crash the
# first run; it still raises if this exact folder already exists.
os.makedirs(submission_dir)
print(submission_dir)

In [None]:
import shutil

In [None]:
# Display the segmentation cache directory used by the localization cell below.
cache_dir

In [None]:
import ttach as tta
from glob import glob
from PIL import Image
import numpy as np

# Localization stage: produce one binary building mask per pre-disaster image in
# submission_dir. USE_CACHED_SEG copies masks cached for this seg run instead of
# running inference; CACHE_SEG (used only when inference runs) also stores the
# fresh masks in cache_dir.
USE_CACHED_SEG = True
CACHE_SEG = False

if USE_CACHED_SEG:
    assert os.path.isdir(cache_dir), f"missing segmentation cache: {cache_dir}"
    # Pure-Python copy instead of `!cp`: a failing shell command only prints an
    # error and lets the notebook continue with missing masks.
    cached_masks = [p for p in glob(f"{cache_dir}/*") if os.path.isfile(p)]
    assert cached_masks, f"segmentation cache is empty: {cache_dir}"
    for cached_mask in cached_masks:
        shutil.copy(cached_mask, submission_dir)

else:
    if CACHE_SEG:
        # Start from an empty cache so masks from an older run cannot mix in.
        if os.path.isdir(cache_dir):
            shutil.rmtree(cache_dir)
        # makedirs also creates ../submissions/cache on first use.
        os.makedirs(cache_dir)

    seg_run_path = f"xvr-hlt/sky-eye-full/{seg_run_id}"
    conf_file = wandb.restore('config.yaml', run_path=seg_run_path, replace=True).name
    state_file = wandb.restore('state_dict.pth', run_path=seg_run_path, replace=True).name

    conf = Config(conf_file)
    model, preprocess_fn = io.load_segmentation_model(conf, state_file)
    model = model.eval().cuda()
    # Test-time augmentation: average predictions over the D4 flip/rotate transforms.
    model = tta.SegmentationTTAWrapper(model, tta.aliases.d4_transform(), merge_mode='mean')
    pre_files = glob(f"{test_dir}/images/*pre*")

    with torch.no_grad():
        for f in tqdm(pre_files):
            i = io.load_img(f, preprocess_fn)
            # First batch item, first channel, thresholded at 0 (presumably
            # logits, i.e. p > 0.5 — confirm against the model head).
            out = model(i.cuda())[0][0] > 0.
            out = out.cpu().numpy().astype(np.uint8)
            fid = f.split('_')[-1].replace(".png", "")
            mask_name = f"test_localization_{fid}_prediction.png"
            # Build the image once and save it to every destination.
            mask_img = Image.fromarray(out)
            mask_img.save(f"{submission_dir}/{mask_name}")
            if CACHE_SEG:
                mask_img.save(f"{cache_dir}/{mask_name}")

In [None]:
from xv.tta import BoxClassifierTTA
import ttach as tta

# Damage runs live in one of two W&B projects, depending on whether the damage
# model is a segmentation network or a per-building box classifier.
dmg_run_path = "xvr-hlt/{}/{}".format(
    "building-seg-damage" if dmg_is_seg else "building-damage", dmg_run_id)

# Fetch the run's config and weights (config first, then weights).
dmg_conf_file, dmg_state_file = (
    wandb.restore(fname, run_path=dmg_run_path, replace=True).name
    for fname in ('config.yaml', 'state_dict.pth')
)
dmg_conf = Config(dmg_conf_file)

if not dmg_is_seg:
    # Box classifier wrapped in the project's box-level TTA.
    dmg_model = BoxClassifierTTA(io.load_damage_model(dmg_conf, dmg_state_file))
else:
    # Segmentation model (also yields its preprocessing fn) with D4 TTA averaging.
    dmg_model, dmg_preprocess_fn = io.load_segmentation_model(dmg_conf, dmg_state_file)
    dmg_model = tta.SegmentationTTAWrapper(dmg_model, tta.aliases.d4_transform(), merge_mode='mean')

dmg_model = dmg_model.eval().cuda()

In [None]:
from glob import glob 
from PIL import Image
import numpy as np
import cv2
from imantics import Polygons

# All post-disaster test images; each gets a damage prediction below.
post_files = glob(test_dir + "/images/*post*")

In [None]:
@torch.no_grad()
def get_box_damage(dmg_model, inp_file, mask_file):
    """Classify damage per predicted building using a box classifier.

    Building regions in the localization mask are converted to polygons. The
    bounding box of each polygon is scored by ``dmg_model``, and the polygon is
    filled with the predicted class (argmax + 1, so 0 remains background).

    Args:
        dmg_model: box classifier called as ``dmg_model(image, [boxes])``;
            presumably returns (num_boxes, num_classes) scores — TODO confirm.
        inp_file: path to the post-disaster image.
        mask_file: path to the localization mask PNG for the same tile.

    Returns:
        uint8 PIL image with the mask's height/width: 0 for background,
        1..C for the predicted damage class of each building polygon.
    """
    msk = np.array(Image.open(mask_file))
    # Size the output from the mask rather than hard-coding 1024x1024.
    msk_dmg = np.zeros(msk.shape[:2], dtype=np.uint8)

    polypoints = Polygons.from_mask(msk).points
    if not polypoints:
        return Image.fromarray(msk_dmg)

    inp = io.load_dmg_img(inp_file)
    # One (x_min, y_min, x_max, y_max) box per polygon.
    boxes = torch.Tensor([
        [p[:, 0].min(), p[:, 1].min(), p[:, 0].max(), p[:, 1].max()]
        for p in polypoints
    ])
    out = dmg_model(inp.cuda(), [boxes.cuda()])
    classes = (out.argmax(1) + 1).cpu().numpy()

    for poly, cls in zip(polypoints, classes):
        # cv2.fillPoly requires int32 point arrays; polygon points may be int64.
        cv2.fillPoly(msk_dmg, [np.asarray(poly, dtype=np.int32)], int(cls))

    return Image.fromarray(msk_dmg)

@torch.no_grad()
def get_seg_damage_pixelwise(dmg_model, preprocess_fn, inp_file, mask_file):
    """Per-pixel damage prediction from a segmentation damage model.

    Every pixel gets class ``argmax over channels + 1``. The localization mask
    is NOT applied, so background pixels also receive a damage class.
    ``mask_file`` is accepted only for signature parity with
    ``get_seg_damage`` and is unused.

    This function has a distinct name because a same-named definition later in
    the cell would silently shadow it.

    Args:
        dmg_model: segmentation model; its output is indexed ``[0]`` (first
            batch item) and reduced with ``argmax(0)`` over channels.
        preprocess_fn: preprocessing passed to ``io.load_img``.
        inp_file: path to the post-disaster image.
        mask_file: unused.

    Returns:
        uint8 PIL image of per-pixel damage classes.
    """
    # Use the arguments, not the notebook globals `f` / `dmg_preprocess_fn`.
    inp = io.load_img(inp_file, preprocess_fn)
    out = dmg_model(inp.cuda())[0]
    out = (out.argmax(0) + 1).cpu().numpy()
    return Image.fromarray(out.astype(np.uint8))

@torch.no_grad()
def get_seg_damage(dmg_model, preprocess_fn, inp_file, mask_file):
    """Assign one damage class per building polygon using a segmentation model.

    The model's per-pixel sigmoid scores are normalised to sum to 1 across
    channels. For each polygon found in the localization mask, the scores are
    averaged over the polygon's pixels, and the polygon is filled with
    ``argmax + 1``. Pixels outside every polygon stay 0.

    Args:
        dmg_model: segmentation model; its output is indexed ``[0]`` (first
            batch item) and treated as (channels, H, W) — presumably matching
            the mask size; TODO confirm.
        preprocess_fn: preprocessing passed to ``io.load_img``.
        inp_file: path to the post-disaster image.
        mask_file: path to the localization mask PNG for the same tile.

    Returns:
        uint8 PIL image with the mask's height/width.
    """
    mb = np.array(Image.open(mask_file))
    height, width = mb.shape[:2]

    # Use the arguments, not the notebook globals `f` / `dmg_preprocess_fn`.
    inp = io.load_img(inp_file, preprocess_fn)
    out = dmg_model(inp.cuda())[0]

    # Sigmoid computed once, then normalised across channels (dim 0).
    probs = out.sigmoid()
    probs = (probs / probs.sum(0)).cpu().numpy()

    damage_map = np.zeros((height, width), dtype=np.uint8)

    for poly in Polygons.from_mask(mb):
        # Same (width, height) argument order as the original (1024, 1024) call
        # — TODO confirm imantics' order before relying on non-square tiles.
        poly_mask = Polygons.create([poly]).mask(width, height).array
        if not poly_mask.any():
            # Degenerate polygon covering no pixels: its mean would be NaN.
            continue
        damage_map[poly_mask] = probs[:, poly_mask].mean(1).argmax() + 1

    return Image.fromarray(damage_map)

In [None]:
# Predict a damage mask for every post-disaster image, pairing it with the
# localization mask (same tile id) already written into submission_dir.
for f in tqdm(post_files):
    fid = f.rsplit('_', 1)[-1].replace(".png", "")
    mask_file = f"{submission_dir}/test_localization_{fid}_prediction.png"
    img = (
        get_seg_damage(dmg_model, dmg_preprocess_fn, f, mask_file)
        if dmg_is_seg
        else get_box_damage(dmg_model, f, mask_file)
    )
    img.save(f"{submission_dir}/test_damage_{fid}_prediction.png")

In [None]:
# Archive the submission folder as "<submission_dir>.zip" beside it, with all
# entries stored under "<subdir>/" (same layout as `cd root && zip -r subdir.zip subdir/`).
# shutil.make_archive needs no `zip` binary and raises on failure instead of
# printing a shell error and continuing.
root, subdir = os.path.split(os.path.normpath(submission_dir))
root = root or "."
shutil.make_archive(os.path.join(root, subdir), "zip", root_dir=root, base_dir=subdir)

In [None]:
# Remove the config/weights files that wandb.restore is expected to have left in
# the working directory. Pure Python (instead of `!rm`), and files that are
# already absent are skipped rather than raising an error.
for _restored in ("config.yaml", "state_dict.pth"):
    if os.path.exists(_restored):
        os.remove(_restored)