In [None]:
from tqdm.notebook import tqdm

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

from torch.utils.data import Dataset, DataLoader
import torch
import ttach as tta
import torch.nn as nn
import segmentation_models_pytorch as smp

In [None]:
data_folder = Path.cwd() / 'sentinel_one_test_data' / 'tiles'

# Path.glob yields files in arbitrary (filesystem-dependent) order, so sort each
# band's listing. Rows line up only if the vv / vh / flood_label file names sort
# identically -- assumed from the folder layout, not verified here.
vv_paths = sorted(data_folder.glob('vv/*'))
vh_paths = sorted(data_folder.glob('vh/*'))
label_paths = sorted(data_folder.glob('flood_label/*'))

if not len(vv_paths) == len(vh_paths) == len(label_paths):
    raise ValueError(
        f"Mismatched tile counts in {data_folder}: "
        f"vv={len(vv_paths)}, vh={len(vh_paths)}, flood_label={len(label_paths)}"
    )

df_paths = pd.DataFrame({
    'vv_image_path': vv_paths,
    'vh_image_path': vh_paths,
    'flood_label_path': label_paths,
})


In [None]:
def s1_to_rgb(vv_image, vh_image):
    """Stack Sentinel-1 VV and VH bands into a 3-channel image.

    The channels are, in order: VV, VH, and ``1 - clip(VH / VV, 0, 1)``.

    Args:
        vv_image: VV-polarisation array. In this notebook it is a 2-D
            grayscale tile scaled to [0, 1].
        vh_image: VH-polarisation array with the same shape as ``vv_image``.

    Returns:
        For 2-D inputs of shape ``(H, W)``, an array of shape ``(H, W, 3)``.
    """
    # VV may be 0 for some pixels. Silence the divide/invalid warnings that
    # produces; NaN (0/0) becomes 0, and +inf (x/0) is clipped to 1 below.
    # (nan_to_num's 2nd positional arg is `copy`, so pass `nan=` by keyword.)
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = vh_image / vv_image
    ratio_image = np.clip(np.nan_to_num(ratio, nan=0.0), 0, 1)
    rgb_image = np.stack((vv_image, vh_image, 1 - ratio_image), axis=2)
    return rgb_image

class ETCIDataset(Dataset):
    """Map-style dataset of Sentinel-1 tiles (and flood masks) listed in a DataFrame.

    Each row must provide ``vv_image_path`` and ``vh_image_path``. For any
    split other than ``'test'`` it must also provide ``flood_label_path``.

    Args:
        dataframe: Table of file paths, one row per tile, indexed positionally.
        split: ``'test'`` yields images only. Any other value also loads masks.
        transform: Optional callable, invoked as
            ``transform(image=..., mask=...)``, that returns a dict with
            ``'image'`` and ``'mask'`` keys (an albumentations-style API is
            presumed from this call). It is applied only when masks are loaded.
    """

    def __init__(self, dataframe, split, transform=None):
        self.split = split
        self.dataset = dataframe
        self.transform = transform

    def __len__(self):
        return self.dataset.shape[0]

    @staticmethod
    def _read_grayscale(path):
        """Read ``path`` as a single-channel image scaled from [0, 255] to [0, 1].

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        # cv2.imread does not accept pathlib.Path on every OpenCV version, and
        # it signals failure by returning None instead of raising.
        image = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        return image / 255.0

    def __getitem__(self, index):
        """Return ``{'image': float32 array}``, plus ``'mask'`` (int64) unless split is 'test'."""
        example = {}

        df_row = self.dataset.iloc[index]

        # load vv and vh images
        vv_image = self._read_grayscale(df_row['vv_image_path'])
        vh_image = self._read_grayscale(df_row['vh_image_path'])

        # convert vv and vh images to an (H, W, 3) rgb image
        rgb_image = s1_to_rgb(vv_image, vh_image)

        if self.split == 'test':
            # no flood mask should be available; channels-first for the model
            example['image'] = rgb_image.transpose((2, 0, 1)).astype('float32')
        else:
            # load ground truth flood mask
            flood_mask = self._read_grayscale(df_row['flood_label_path'])

            # compute transformations
            if self.transform:
                augmented = self.transform(image=rgb_image, mask=flood_mask)
                rgb_image = augmented['image']
                flood_mask = augmented['mask']

            example['image'] = rgb_image.transpose((2, 0, 1)).astype('float32')
            example['mask'] = flood_mask.astype('int64')

        return example
    
# 'test' split: the dataset yields images only, never masks.
etci_dataset = ETCIDataset(df_paths, split='test', transform=None)

# The original NVIDIA solution used batch_size = 96 * torch.cuda.device_count()
# and num_workers = os.cpu_count(). Here tiles are processed one at a time,
# in-process.
batch_size = 1
num_workers = 0

data_loader = DataLoader(
    etci_dataset,
    batch_size=batch_size,
    # Keep tile order so prediction i corresponds to row i of df_paths.
    shuffle=False,
    num_workers=num_workers,
    # Pinned memory only helps host-to-GPU copies. On CPU-only machines,
    # recent PyTorch versions warn that it is unused.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Use the first visible GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def get_predictions_single(model_def, weights, loader=None, target_device=None):
    """Run one segmentation model, with ttach test-time augmentation, over a loader.

    Args:
        model_def: An instantiated ``nn.Module`` matching the checkpoint. It is
            modified in place: weights are loaded into it.
        weights: Path to a ``state_dict`` checkpoint readable by ``torch.load``.
        loader: Iterable of batches, each a dict with an ``'image'`` tensor.
            Defaults to the notebook-level ``data_loader``.
        target_device: ``torch.device`` to run inference on. Defaults to the
            notebook-level ``device``.

    Returns:
        np.ndarray: The model outputs of every batch, concatenated along axis 0.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    if loader is None:
        loader = data_loader
    if target_device is None:
        target_device = device

    # Load onto the CPU first so GPU-saved checkpoints also work on CPU-only
    # hosts. The wrapped model is moved to target_device below.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model_def.load_state_dict(torch.load(weights, map_location=torch.device('cpu')))
    # Average the outputs over ttach's d4 augmentations (mean yields the best results).
    model = tta.SegmentationTTAWrapper(model_def, tta.aliases.d4_transform(), merge_mode='mean')
    model.to(target_device)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    batch_predictions = []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(loader):
            image = batch['image'].to(target_device)
            pred = model(image)
            # Under no_grad the output carries no autograd graph, so .detach()
            # is unnecessary before converting to numpy.
            batch_predictions.append(pred.cpu().numpy())

    if not batch_predictions:
        raise ValueError("loader yielded no batches; check that input tiles were found")

    return np.concatenate(batch_predictions, axis=0)

In [None]:
# All ensemble members share one MobileNetV2 encoder configuration. There are
# 3 input channels, matching the 3-channel images from s1_to_rgb, and 2 output
# classes. No pretrained encoder weights are requested, because the trained
# weights come from the checkpoints listed below.
_mobilenet_kwargs = dict(
    encoder_name="mobilenet_v2",
    encoder_weights=None,
    in_channels=3,
    classes=2,
)

unet_mobilenet = smp.Unet(**_mobilenet_kwargs)
upp_mobilenet = smp.UnetPlusPlus(**_mobilenet_kwargs)
unet_pseudo_round2 = smp.Unet(**_mobilenet_kwargs)

# model_defs[i] is loaded from model_paths[i]; keep the two lists in the same order.
model_defs = [unet_mobilenet, upp_mobilenet, unet_pseudo_round2]

models_dir = Path.cwd() / 'models'
model_paths = [
    models_dir / 'unet_mobilenet_v2_0.pth',
    models_dir / 'upp_mobilenetv2_0.pth',
    models_dir / 'unet_pseudo_mobilenetv2_round2_0.pth',
]

In [None]:
# Average the raw outputs of every ensemble member, then take the argmax over
# axis 1 (the models' output-channel axis) to obtain a per-pixel class map.
if len(model_defs) != len(model_paths):
    # zip() would otherwise silently drop the unmatched models or checkpoints.
    raise ValueError(
        f"{len(model_defs)} model definitions but {len(model_paths)} checkpoints"
    )

per_model_preds = [
    get_predictions_single(model_def, weights_path)
    for model_def, weights_path in zip(model_defs, model_paths)
]

all_preds = np.mean(np.stack(per_model_preds), axis=0)
class_preds = all_preds.argmax(axis=1).astype('uint8')

save_path = Path.cwd() / 'sentinel_one_test_data' / 'output' / 'submission.npy'
# np.save does not create missing parent directories.
save_path.parent.mkdir(parents=True, exist_ok=True)
np.save(save_path, class_preds, allow_pickle=False)

In [None]:
# Index of the tile to inspect. Rows of df_paths and class_preds are aligned
# because the data loader does not shuffle.
image_num = 4
if not 0 <= image_num < len(df_paths):
    raise IndexError(f"image_num={image_num} is out of range for {len(df_paths)} tiles")

tile_paths = df_paths.iloc[image_num]

fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))

ax1.set_title('vh')
ax1.imshow(cv2.imread(str(tile_paths['vh_image_path']), 0) / 255.0)

ax2.set_title('vv')
ax2.imshow(cv2.imread(str(tile_paths['vv_image_path']), 0) / 255.0)

flood_labels = cv2.imread(str(tile_paths['flood_label_path']), 0) / 255
ax3.set_title('label')
ax3.imshow(flood_labels)

flood_preds = class_preds[image_num]
ax4.set_title('Pred')
ax4.imshow(flood_preds)

# IoU of the non-zero (flooded) pixels. If both masks are empty the ratio is
# 0/0; report NaN explicitly instead of triggering a numpy RuntimeWarning.
intersection = np.logical_and(flood_labels, flood_preds)
union = np.logical_or(flood_labels, flood_preds)
union_size = np.sum(union)
iou_score = np.sum(intersection) / union_size if union_size else float('nan')
print(f"IOU score: {iou_score:.2}")

In [None]:
# IoU of flooded (non-zero) pixels for the tile selected by `image_num`.
flood_labels = cv2.imread(str(df_paths['flood_label_path'][image_num]), 0) / 255
flood_preds = class_preds[image_num]

intersection = np.logical_and(flood_labels, flood_preds)
union = np.logical_or(flood_labels, flood_preds)
union_size = np.sum(union)
# Both masks empty -> 0/0: return NaN explicitly rather than via a numpy warning.
iou_score = np.sum(intersection) / union_size if union_size else float('nan')
iou_score