In [13]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

import rasterio
import numpy as np
from rasterio import plot as rasterioplt
import matplotlib.pyplot as plt
from matplotlib import colors
import matplotlib.patches as mpatches

from typing import Optional, Tuple, Union

import os
import sys
sys.path.append('../..')
from src.models.modelmodule import WorldFloodsModel
from src.models.architectures.unets import UNet
from src.models.utils import model_setup, metrics
from src.models.utils.configuration import AttrDict


# Init wandb. Project and entity can be overridden through the WANDB_PROJECT / WANDB_ENTITY
# environment variables so other users can run this notebook; the defaults are the values
# this notebook was originally run with.
import wandb
wandb.init(project=os.environ.get("WANDB_PROJECT", "ml4floods-test"),
           entity=os.environ.get("WANDB_ENTITY", "sambuddinc"))

0,1
bce_loss,0.16456
dice_loss,0.5807
epoch,9.0
_runtime,6945.0
_timestamp,1613673187.0
_step,19.0


0,1
bce_loss,▁▇█▆▄▁▁▁▁▁
dice_loss,█▅▁▂▆▅▅▅▅▄
epoch,▁▂▃▃▄▅▆▆▇█
_runtime,▁▁▁▁▁▁▁▁▁█
_timestamp,▁▁▁▁▁▁▁▁▁█
_step,▁▂▃▃▄▅▆▆▇█


In [5]:
@torch.no_grad()
def read_inference_pair(layer_name:str, window:Optional[Union[rasterio.windows.Window, Tuple[slice,slice]]], 
                        return_ground_truth: bool=False, channels: Optional[list]=None, 
                        return_permanent_water=True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, rasterio.Affine]:
    """
    Read a layer (plus optional ground truth / permanent water) from the worldfloods bucket as Tensors.

    Args:
        layer_name: filename (without extension) of the layer in the worldfloods bucket
        window: window of the layer to read; None reads the whole raster
        return_ground_truth: if True, read the paired gt layer; otherwise return a zero placeholder
        channels: 0-based indexes of the S2 channels to read; None reads every band
        return_permanent_water: if True, read the PERMANENTWATERJRC layer; otherwise return a zero placeholder

    Returns:
        (torch_inputs, torch_targets, torch_permanent_water, transform):
            torch_inputs: float32 tensor of shape (1, C, H, W)
            torch_targets: tensor of shape (1, H, W) (zeros when return_ground_truth is False)
            torch_permanent_water: tensor of shape (H, W) (zeros when return_permanent_water is False)
            transform: affine transform of the read window, for plotting with lat/long
    """
    tiff_inputs = f"gs://ml4floods/worldfloods/tiffimages/S2/{layer_name}.tif"
    tiff_targets = f"gs://ml4floods/worldfloods/tiffimages/gt/{layer_name}.tif"

    with rasterio.open(tiff_inputs, "r") as rst:
        # rasterio band indexes are 1-based while `channels` is 0-based; None reads all bands.
        band_indexes = None if channels is None else (np.asarray(channels) + 1).tolist()
        inputs = rst.read(band_indexes, window=window)
        # Shifted transform based on the given window (used for plotting)
        transform = rst.transform if window is None else rasterio.windows.transform(window, rst.transform)
    torch_inputs = torch.from_numpy(inputs.astype(np.float32)).unsqueeze(0)
    height, width = inputs.shape[-2:]

    if return_permanent_water:
        tiff_permanent_water = f"gs://ml4floods/worldfloods/tiffimages/PERMANENTWATERJRC/{layer_name}.tif"
        with rasterio.open(tiff_permanent_water, "r") as rst:
            permanent_water = rst.read(1, window=window)
        torch_permanent_water = torch.tensor(permanent_water)
    else:
        # Placeholder with the same (H, W) shape as a real single-band read.
        torch_permanent_water = torch.zeros((height, width), dtype=torch_inputs.dtype)

    if return_ground_truth:
        with rasterio.open(tiff_targets, "r") as rst:
            targets = rst.read(1, window=window)
        torch_targets = torch.tensor(targets).unsqueeze(0)
    else:
        # Placeholder with the same (1, H, W) shape as a real ground-truth read.
        torch_targets = torch.zeros((1, height, width), dtype=torch_inputs.dtype)

    return torch_inputs, torch_targets, torch_permanent_water, transform


class DummyWorldFloodsDataset(torch.utils.data.Dataset):
    """
    Small in-memory dataset that eagerly reads a fixed list of (layer, window) tiles.

    Args:
        layer_names: worldfloods layer names, each passed to ``read_inference_pair``
        windows: one window per layer (same length as ``layer_names``)
        channels: one list of 0-based channel indexes per layer (same length as ``layer_names``)

    Raises:
        ValueError: if ``layer_names``, ``windows`` and ``channels`` have different lengths.

    Each item is an ``(input, target)`` pair: the input with its leading batch axis removed and
    the target converted to int64. Permanent-water rasters and plotting transforms are kept on
    the instance (``permanent_water``, ``plot_transforms``) but are not returned by ``__getitem__``.
    """
    def __init__(self, layer_names, windows, channels):
        if not (len(layer_names) == len(windows) == len(channels)):
            raise ValueError(
                "layer_names, windows and channels must have the same length, got "
                f"{len(layer_names)}, {len(windows)} and {len(channels)}")

        self.inputs = []
        self.targets = []
        self.permanent_water = []
        self.plot_transforms = []

        for layer_name, window, layer_channels in zip(layer_names, windows, channels):
            torch_inputs, torch_targets, torch_permanent_water, transform = read_inference_pair(
                layer_name, window, return_ground_truth=True, channels=layer_channels)

            self.inputs.append(torch_inputs)
            self.targets.append(torch_targets)
            self.permanent_water.append(torch_permanent_water)
            self.plot_transforms.append(transform)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # squeeze(0) drops only the batch axis added when reading; a bare squeeze() would also
        # collapse a size-1 channel axis (single-band channel configurations).
        return self.inputs[idx].squeeze(0), self.targets[idx].squeeze(0).long()

In [6]:
# Define some options
# The training cell instantiates a UNet, so 'unet' keeps `opt.model` and `opt.model_folder`
# consistent with the network that is actually trained.
model_name = 'unet' # options: 'unet', 'linear', 'simplecnn'
channel_configuration_name = 'all'

opt = {
    'model': model_name,
    'device': 'cpu',
    'model_folder': f'../../src/models/checkpoints/{model_name}/', # TODO different channel configuration means different model
    'max_tile_size': 128,
    'num_class': 3,
    'channel_configuration' : channel_configuration_name,
    'num_channels': len(model_setup.CHANNELS_CONFIGURATIONS[channel_configuration_name]),
}
opt = AttrDict.from_nested_dicts(opt)

In [7]:
layer_names = ["EMSR333_02PORTOPALO_DEL_MONIT01_v1_observed_event_a", "EMSR347_07ZOMBA_DEL_v2_observed_event_a"]
# (row slice, column slice) read from every layer
tile_window = (slice(256, 256 + 256), slice(0, 256))
tile_channels = model_setup.CHANNELS_CONFIGURATIONS[opt.channel_configuration]
# One window / channel list per layer, derived from layer_names so the lists cannot drift out of sync.
windows = [tile_window] * len(layer_names)
channels = [tile_channels] * len(layer_names)

ds = DummyWorldFloodsDataset(layer_names, windows, channels)

dl = torch.utils.data.DataLoader(ds, batch_size=1)

2


In [14]:

# Per-class weights handed to WorldFloodsModel (one entry per class, opt.num_class entries).
# NOTE(review): presumably loss weights; the first entry sums two values, likely two merged
# categories — confirm the provenance of these numbers.
WEIGHT_PER_CLASS = [0.120252 + 0.396639, 0.027322, .455787]
MAX_EPOCHS = 10

model = WorldFloodsModel(network_architecture=UNet(opt.num_channels, opt.num_class), num_class=opt.num_class, weight_per_class=WEIGHT_PER_CLASS)

wandb_logger = WandbLogger(name="floodbusters-test")

trainer = pl.Trainer(logger=wandb_logger, max_epochs=MAX_EPOCHS)
# NOTE: the same loader is passed for training and validation, so validation metrics measure
# fit on the training tiles rather than generalization.
trainer.fit(model, dl, dl)

# Save model weights into the active wandb run directory so they are synced with the run
torch.save(model.state_dict(), os.path.join(wandb.run.dir, 'model.pt'))

GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name    | Type | Params
---------------------------------
0 | network | UNet | 7.8 M 
---------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params


Epoch 0:  50%|█████     | 2/4 [00:15<00:15,  7.62s/it, loss=1.31, v_num=l99p]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/2 [00:00<?, ?it/s][A
Epoch 0: 100%|██████████| 4/4 [00:17<00:00,  4.50s/it, loss=1.31, v_num=l99p]
Epoch 0: 100%|██████████| 4/4 [00:20<00:00,  5.24s/it, loss=1.31, v_num=l99p]
Epoch 1:  50%|█████     | 2/4 [00:14<00:14,  7.16s/it, loss=1.21, v_num=l99p]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/2 [00:00<?, ?it/s][A
Epoch 1: 100%|██████████| 4/4 [00:17<00:00,  4.29s/it, loss=1.21, v_num=l99p]
Epoch 1: 100%|██████████| 4/4 [00:20<00:00,  5.03s/it, loss=1.21, v_num=l99p]
Epoch 2:  50%|█████     | 2/4 [00:14<00:14,  7.18s/it, loss=1.14, v_num=l99p]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/2 [00:00<?, ?it/s][A
Epoch 2: 100%|██████████| 4/4 [00:17<00:00,  4.29s/it, loss=1.14, v_num=l99p]
Epoch 2: 100%|██████████| 4/4 [00:20<00:00,  5.02s/it, loss=1.14, v_num=l99p]
Epoch 3:  50%|█████     | 2/4 [00:14<