## Setup

In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes, so edits to imported modules (e.g. the `dataset`
# and `model` modules used below) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Optional, Google Colab only: uncomment the two lines below to mount Google
# Drive at /content/drive.
# from google.colab import drive
# drive.mount('/content/drive')

## Install packages

In [None]:
! pip install terratorch gdown tensorboard

## Import modules

In [None]:
import os
import torch
from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.metrics import jaccard_score
from sklearn.metrics import f1_score
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt

from dataset import FloodDataModule
from model import FloodRiskModel

## Download dataset

In [None]:
# Clone the challenge dataset repository from the Hugging Face Hub into
# ./BlueSky-Challenge-Dataset (the directory the statistics cell reads).
# `git lfs install` enables Git LFS first; presumably the data files are
# stored through LFS -- confirm against the dataset repository.
! git lfs install
! git clone https://huggingface.co/datasets/hk-kaden-kim/BlueSky-Challenge-Dataset

In [None]:
# Dataset statistics
# For every split, walk its directory tree once and report how many .png
# files (reported as images) and .csv files (reported as masks) it contains.
DATASET_DIR = "BlueSky-Challenge-Dataset"
for split in ("train", "test", "val"):
    split_dir = os.path.join(DATASET_DIR, split)
    # Flatten the recursive walk into one list of file names.
    filenames = [name for _, _, names in os.walk(split_dir) for name in names]
    n_images = sum(name.endswith('.png') for name in filenames)
    n_masks = sum(name.endswith('.csv') for name in filenames)
    print(f"{split}: {n_images} images and {n_masks} masks")

train: 3184 images and 3184 masks
test: 1332 images and 1332 masks
val: 2128 images and 2128 masks


# Inference

In [None]:
# Inference settings used by the cells below.
BATCH_SIZE = 4       # batch size handed to the data module
CHECKPOINT = "..."   # placeholder: replace with the path of the model checkpoint to load
# Modality names passed to the model as `tim=` when the checkpoint is loaded.
TiM = ["S1RTC", "DEM", "LULC", "NDVI"]


In [None]:
# Initialize datamodule
# Only the test split is set up; lower BATCH_SIZE if GPU memory is tight.
datamodule = FloodDataModule(local_path=DATASET_DIR, batch_size=BATCH_SIZE, num_workers=8)
datamodule.setup("test")

# Inspect the first test sample as a quick sanity check of shapes and labels.
sample = datamodule.test_dataset[0]
sample_image = sample['image']
sample_mask = sample['mask']

summary_lines = [
    "\nDataset Summary:",
    f"Test samples: {len(datamodule.test_dataset)}",
    f"Sample keys: {sample.keys()}",
    f"Image shape: {sample_image.shape}",
    f"Mask shape: {sample_mask.shape}",
    f"Mask unique values: {torch.unique(sample_mask)}",
    "",  # trailing blank line, as a bare print() would emit
]
print("\n".join(summary_lines))

Validation samples: 1332

Dataset Summary:
Test samples: 1332
Sample keys: dict_keys(['image', 'mask'])
Image shape: torch.Size([3, 256, 256])
Mask shape: torch.Size([256, 256])
Mask unique values: tensor([0, 1])



In [None]:
# Initialize model
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Restore the weights from CHECKPOINT and move the model onto the chosen device.
model = FloodRiskModel.load_from_checkpoint(CHECKPOINT, tim=TiM).to(device)

# Switch to evaluation mode; binding the return value to `_` keeps the
# notebook from echoing the model repr as the cell output.
_ = model.eval()

In [None]:
# Inference
# Run the model over the test split and collect, for every image, its path,
# pixel-wise metrics, predicted/ground-truth flood coverage, the predicted
# mask and the class-1 probability map.
#
# NOTE: predictions are matched to file paths by position, so this assumes the
# test dataloader yields samples in the same order as `test_dataset.samples`
# (i.e. no shuffling) -- confirm against the data module.
all_results = []
dataset_samples = datamodule.test_dataset.samples

# Running index of the first sample of the current batch. It is advanced by the
# number of samples actually received, so the path lookup stays aligned even if
# the loader's real batch size differs from `datamodule.batch_size` or the last
# batch is smaller than the others.
sample_offset = 0

with torch.no_grad():
    for batch in tqdm(datamodule.test_dataloader()):

        imgs = batch['image'].to(device)
        logits = model(imgs).output

        # Class probabilities along the channel axis; channel 1 is used as the
        # per-pixel confidence of the positive (flood) class.
        probs = torch.softmax(logits, dim=1)
        confs = probs[:, 1].cpu().numpy()

        # Convert once per batch instead of once per sample.
        pred_masks = torch.argmax(probs, dim=1).cpu().numpy()
        true_masks = batch['mask'].cpu().numpy()

        n_in_batch = imgs.size(0)
        for i in range(n_in_batch):

            img_path = dataset_samples[sample_offset + i][0]

            pred_flatten = pred_masks[i].flatten()
            true_flatten = true_masks[i].flatten()

            all_results.append({
                'image_path':     img_path,
                'accuracy':       accuracy_score(true_flatten, pred_flatten),
                'iou':            jaccard_score(true_flatten, pred_flatten, average='binary'),
                'f1':             f1_score(true_flatten, pred_flatten, average='binary'),
                'flood_pct_pred': pred_flatten.sum() / pred_flatten.size * 100,
                'flood_pct_gt':   true_flatten.sum() / true_flatten.size * 100,
                'flood_mask':     pred_masks[i],
                'confidence':     confs[i]
            })

        sample_offset += n_in_batch

# Metrics are averaged per image (each image weighs the same regardless of
# how many flood pixels it contains).
print(f"Tested {len(all_results)} samples.")
print(f"Mean Acc: {np.mean([r['accuracy'] for r in all_results]):.3f}")
print(f"Mean IoU: {np.mean([r['iou']     for r in all_results]):.3f}")
print(f"Mean F1:  {np.mean([r['f1']      for r in all_results]):.3f}")

In [None]:
# Zone IDs selected for visualization. Each ID is matched against the file
# name `<zone_id>.png` of an entry in `all_results` by the plotting cell.
TEST_SAMPLES = ['acbe', 'ahaw', 'ayvf', 'bctf', 'biqb', 'blad', 'bmwg', 'cafp', 'cbre', 'ccea',
                'cdhx', 'cnyv', 'crgs', 'cxsh', 'cyrz', 'czjn', 'ddwt', 'dhkk', 'dwge', 'dwgx',
                'dxqd', 'emqi', 'fajg', 'faqi', 'fdjr', 'fixa', 'fnho', 'gfky', 'gjeh', 'glsx',
                'gmwz', 'gnrt', 'gtrx', 'gvqo', 'gykv', 'hehl', 'hghn', 'hueg', 'igwf', 'ikfi',
                'iusn', 'iyqz', 'kmfc', 'lgyh', 'ljwq', 'lxlx', 'mjqy', 'mjwv', 'mkdo', 'mmnr',
                'mmwp', 'neqf', 'nezm', 'nrai', 'ntjc', 'ntxe', 'nwhe', 'ogpq', 'omqk', 'ourg',
                'paty', 'pdfj', 'pevf', 'pfmg', 'qasf', 'qdta', 'qszb', 'qxrg', 'rhcz', 'riwa',
                'roic', 'rxrr', 'skos', 'soul', 'uhce', 'wrex', 'xbqv', 'xfon', 'xsfm', 'ymhx',
                'ynon', 'yxeg', 'zean', 'zjkl', 'zkmf', 'zoan',]

# Working copy of the list; slice it here (e.g. TEST_SAMPLES[:5]) to plot fewer zones.
test_samples = TEST_SAMPLES[:]

In [None]:
# Create subplots
# One row per selected zone with three panels: the satellite image, the image
# with the ground-truth overlay, and the image with the predicted overlay.
rows = len(test_samples)
cols = 3
# squeeze=False keeps `axes` two-dimensional even when there is a single row,
# so axes[i][j] indexing works for any number of samples.
fig, axes = plt.subplots(rows, cols, figsize=(cols * 3, rows * 3), squeeze=False)

# Build a file-name -> result-index table once instead of scanning the list
# for every zone. setdefault keeps the first occurrence, like list.index().
zone_id_all = [r['image_path'].name for r in all_results]
result_idx_by_name = {}
for result_idx, file_name in enumerate(zone_id_all):
    result_idx_by_name.setdefault(file_name, result_idx)

for i, zone_id in enumerate(test_samples):

    file_name = zone_id + '.png'
    if file_name not in result_idx_by_name:
        raise ValueError(f"{file_name!r} was not found among the inference results")
    idx = result_idx_by_name[file_name]

    # Prediction - Flood Risky Zone
    pred_arr = all_results[idx]['flood_mask']
    pred_rgba = np.zeros((*pred_arr.shape, 4), dtype=np.float32)
    pred_rgba[pred_arr == 1] = [1, 0, 0, 0.3]     # Red

    # Satellite RGB Imagery, resized to the prediction's (width, height)
    img_path = str(all_results[idx]['image_path'])
    with Image.open(img_path) as src_image:
        image = src_image.convert('RGB')
    image_resized = image.resize((pred_arr.shape[1], pred_arr.shape[0]), Image.NEAREST)
    image_resized = np.array(image_resized)

    # Ground Truth - Flood / Water Body / Water & Flood
    # The mask path mirrors the image path with 'image' replaced by 'mask';
    # only the file extension is swapped to .csv, so a "png" substring elsewhere
    # in the path (directory name, zone id) is left untouched.
    mask_path = os.path.splitext(img_path.replace('image', 'mask'))[0] + '.csv'
    mask_arr = pd.read_csv(mask_path, header=None).to_numpy()
    # Upsample the mask by an integer factor (nearest neighbour) to the
    # prediction's resolution; assumes the prediction height is a multiple of
    # the mask height and that both are square -- TODO confirm.
    K = pred_arr.shape[0] // mask_arr.shape[0]
    mask_arr = mask_arr.repeat(K, axis=0).repeat(K, axis=1)
    mask_rgba = np.zeros((*mask_arr.shape, 4), dtype=np.float32)
    mask_rgba[mask_arr == 1] = [1, 0, 0, 0.3]     # Red
    mask_rgba[mask_arr == 2] = [0, 0, 1, 0.3]     # Blue
    mask_rgba[mask_arr == 3] = [1, 1, 0, 0.3]     # Yellow

    # Plot
    ax_image, ax_truth, ax_pred = axes[i]

    ax_image.imshow(image_resized)
    ax_image.set_title(f"{zone_id}")
    ax_image.axis('off')

    ax_truth.imshow(image_resized)
    ax_truth.imshow(mask_rgba)
    ax_truth.set_title('Ground Truth')
    ax_truth.axis('off')

    ax_pred.imshow(image_resized)
    ax_pred.imshow(pred_rgba)
    ax_pred.set_title('Predicted')
    ax_pred.axis('off')

plt.tight_layout()
plt.show()


