# Setup

In [27]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Install packages

In [None]:
! pip install -U datasets huggingface_hub
! pip install terratorch==1.0.1 gdown tensorboard

## Import modules

In [38]:
import os
import torch

## Download the dataset

In [None]:
# Clone the BlueSky Challenge dataset from the Hugging Face Hub (Git LFS is
# enabled first because the repository is fetched with git).
# Skipped when the directory already exists so Restart & Run All is idempotent;
# check=True makes a failed git command stop the notebook instead of being
# silently ignored (which is what `!git ...` does).
import subprocess

DATASET_REPO_URL = "https://huggingface.co/datasets/hk-kaden-kim/BlueSky-Challenge-Dataset"
# git clones into a directory named after the last URL path component.
clone_dir = DATASET_REPO_URL.rstrip("/").rsplit("/", 1)[-1]

if os.path.isdir(clone_dir):
    print(f"{clone_dir} already present; skipping clone")
else:
    subprocess.run(["git", "lfs", "install"], check=True)
    subprocess.run(["git", "clone", DATASET_REPO_URL], check=True)

In [43]:
# Dataset statistics: count image (.png) and mask (.csv) files in each split.
DATASET_DIR = "BlueSky-Challenge-Dataset"
SPLITS = ("train", "test", "val")


def count_split_files(split_dir: str) -> tuple[int, int]:
    """Recursively count ``.png`` images and ``.csv`` masks under ``split_dir``.

    Args:
        split_dir: Path to one dataset split (e.g. ``<DATASET_DIR>/train``).

    Returns:
        ``(n_png, n_csv)`` file counts.

    Raises:
        FileNotFoundError: if ``split_dir`` does not exist. Without this check
            ``os.walk`` yields nothing and the split would silently report 0.
    """
    if not os.path.isdir(split_dir):
        raise FileNotFoundError(
            f"Split directory not found: {split_dir!r} (did the dataset clone succeed?)"
        )
    n_png = n_csv = 0
    for _root, _dirs, files in os.walk(split_dir):
        for name in files:
            if name.endswith(".png"):
                n_png += 1
            elif name.endswith(".csv"):
                n_csv += 1
    return n_png, n_csv


for split in SPLITS:
    png_files_cnt, csv_files_cnt = count_split_files(os.path.join(DATASET_DIR, split))
    print(f"{split}: {png_files_cnt} images and {csv_files_cnt} masks")
    # Every image is expected to have a mask; surface mismatches early.
    if png_files_cnt != csv_files_cnt:
        print(f"  WARNING: {split} has {png_files_cnt} images but {csv_files_cnt} masks")

train: 3184 images and 3184 masks
test: 1332 images and 1332 masks
val: 2128 images and 2128 masks


# Training

In [71]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import pytorch_lightning as pl

from dataset import FloodDataModule
from model import FloodRiskModel

In [76]:
MAX_EPOCHS = 10
BATCH_SIZE = 8
TiM = ["S1RTC","DEM","LULC","NDVI"]

In [77]:
# Initialize the datamodule and inspect one training sample.
# Cap worker processes at the machine's CPU count: PyTorch's DataLoader warns
# when asked for more workers than the system suggests (e.g. 8 on a 2-core
# Colab VM). Assumes FloodDataModule forwards num_workers to its DataLoaders.
num_workers = min(8, os.cpu_count() or 1)
datamodule = FloodDataModule(
    local_path=DATASET_DIR,
    batch_size=BATCH_SIZE,  # Adjust based on GPU memory
    num_workers=num_workers,
)
datamodule.setup("fit")
sample = datamodule.train_dataset[0]

print("\nDataset Summary:")
print(f"Training samples: {len(datamodule.train_dataset)}")
print(f"Validation samples: {len(datamodule.val_dataset)}")
print(f"Sample keys: {sample.keys()}")
print(f"Image shape: {sample['image'].shape}")
print(f"Mask shape: {sample['mask'].shape}")
print(f"Mask unique values: {torch.unique(sample['mask'])}")
print()

# Segmentation invariant: image and mask share the same trailing spatial size.
assert sample["image"].shape[-2:] == sample["mask"].shape[-2:], (
    f"image {tuple(sample['image'].shape)} and mask {tuple(sample['mask'].shape)} "
    "differ in spatial size"
)

Training samples: 3184
Validation samples: 2128

Dataset Summary:
Training samples: 3184
Validation samples: 2128

Testing dataset loading...
Sample keys: dict_keys(['image', 'mask'])
Image shape: torch.Size([3, 256, 256])
Mask shape: torch.Size([256, 256])
Mask unique values: tensor([0, 1])



In [None]:
# Initialize model
model = FloodRiskModel(num_classes=2, tim=TiM)
assert isinstance(model, pl.LightningModule)

# Single root for all training artifacts (checkpoints + TensorBoard logs).
OUTPUT_DIR = "output/flood_risk"

# Keep the 3 best checkpoints by validation IoU plus the most recent one.
# "val_iou" must be logged by FloodRiskModel (via self.log) for monitoring to work.
# With Lightning's default auto_insert_metric_name, files are named like
# "best-flood-iou-epoch=3-val_iou=0.812.ckpt".
checkpoint_callback = ModelCheckpoint(
    dirpath=f"{OUTPUT_DIR}/checkpoints/",
    filename="best-flood-iou-{epoch}-{val_iou:.3f}",
    monitor="val_iou",
    mode="max",
    save_top_k=3,
    save_last=True,
)

logger = TensorBoardLogger(f"{OUTPUT_DIR}/logs", name="flood_risk")

# Setup trainer
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    callbacks=[checkpoint_callback],
    logger=logger,
    log_every_n_steps=100,
    gradient_clip_val=1.0,
    enable_checkpointing=True,
    detect_anomaly=False,  # Disable anomaly detection to save memory
)

print(" Starting flood segmentation training with CSV masks...")
trainer.fit(model, datamodule=datamodule)
print(" Training completed!")

# Report where the best checkpoint landed so later cells can load it.
print(f"Best checkpoint: {checkpoint_callback.best_model_path or '(none saved)'}")
if checkpoint_callback.best_model_score is not None:
    print(f"Best val_iou: {checkpoint_callback.best_model_score.item():.3f}")

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


 Starting flood segmentation training with CSV masks...


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Training samples: 3184
Validation samples: 2128


INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type                     | Params | Mode 
---------------------------------------------------------------
0 | model     | SemanticSegmentationTask | 381 M  | train
1 | train_iou | MulticlassJaccardIndex   | 0      | train
2 | val_iou   | MulticlassJaccardIndex   | 0      | train
3 | train_acc | MulticlassAccuracy       | 0      | train
4 | val_acc   | MulticlassAccuracy       | 0      | train
5 | train_f1  | MulticlassF1Score        | 0      | train
6 | val_f1    | MulticlassF1Score        | 0      | train
---------------------------------------------------------------
128 M     Trainable params
253 M     Non-trainable params
381 M     Total params
1,525.412 Total estimated model params size (MB)
752       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]