In [1]:
import sys, os
from pathlib import Path
import torch

# GPU selection must be exported BEFORE CUDA is initialised: the CUDA runtime reads
# these variables once during initialisation, so setting them after
# torch.cuda.init() has no effect on which devices this process sees.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

torch.cuda.init()
if torch.cuda.get_device_capability()[0] >= 8:  # compute capability 8+ (Ampere or newer)
    torch.backends.cuda.matmul.allow_tf32 = True  # This flag defaults to False
    torch.backends.cudnn.allow_tf32 = True
# Cap the GPU memory this process may use: a float in [0, 1], where 1 == 100%.
torch.cuda.set_per_process_memory_fraction(0.9)

import datetime
# Timestamp string of this session; a later cell uses it to name the checkpoint folder.
now = datetime.datetime.now()
now = now.strftime("%Y-%m-%d-%H:%M:%S")

## Step 1: Load the flooding configuration file from local device or GCS

In [None]:
import pkg_resources
from ml4floods.models.config_setup import get_default_config

# Resolve the JSON template that is bundled with the installed ml4floods package.
# To use a custom file instead: config_fp = 'path/to/worldfloods_template.json'
config_fp = pkg_resources.resource_filename(
    "ml4floods",
    "models/configurations/worldfloods_template.json",
)

# Build the configuration object from the template; later cells override fields on it.
config = get_default_config(config_fp)
config  # last expression: show the loaded configuration as the cell output

## Step 2: Setup Dataloader

In [None]:
# Short handles onto the two nested sections edited below (same objects, not copies).
data_cfg = config.data_params
hparams = config.model_params.hyperparameters

# Experiment name; a later cell uses it as the sub-folder for trained models.
config.experiment_name = 'new_1015' # only_water

# --- data settings ---
data_cfg.channel_configuration = 'bgri'
data_cfg.window_size = [256,256]
data_cfg.batch_size = 24
data_cfg.bucket_id = ""
# data_cfg.target_folder = 'gt'

# --- model hyperparameters (channel setup mirrors the data settings above) ---
hparams.channel_configuration = 'bgri'
hparams.num_channels = 4
hparams.max_tile_size = 256
hparams.metric_monitor = 'val_dice_loss'
# Per-class weights. NOTE(review): class order is presumably land / water / cloud — confirm.
hparams.weight_per_class = [1.93445299, 36.60054169, 2.19400729]
# hparams.weight_per_class = [1.0, 1.0, 1.0]

In [None]:
%%time

# from ml4.dataset_setup import get_dataset
from ml4floods.models.dataset_setup import get_dataset
config.data_params.batch_size = 40 # control this depending on the space on your GPU! Hint: 8 with log, max about 20
config.data_params.loader_type = 'local'
config.data_params.path_to_splits = "/mnt/d/Flooding//worldfloods_v1_0" # local folder to download the data
config.data_params.train_test_split_file = "/mnt/d/Flooding/train_test_split.json"
config.data_params["download"] = {"train": True, "val": True, "test": True} # download only test data
# config.data_params.train_test_split_file = "2_PROD/2_Mart/worldfloods_v1_0/train_test_split.json" # use this to train with all the data
config.data_params.num_workers = 16

# If files are not in config.data_params.path_to_splits this will trigger the download of the products.
dataset = get_dataset(config.data_params)
dataset

## Verify data loader

#### Verify training data
Data format here: https://github.com/spaceml-org/ml4floods/blob/891fe602880586e7ac821d2f282bf5ec9d4c0795/ml4floods/data/worldfloods/dataset.py#L106

In [None]:
# Build the training loader and pull one batch from it as a sanity check.
train_dl = dataset.train_dataloader()
train_dl_iter = iter(train_dl)
print(len(train_dl_iter))  # length reported by the loader iterator
batch_train = next(train_dl_iter)

# Cell output: shapes of the image tensor and of its mask tensor.
tuple(batch_train[key].shape for key in ("image", "mask"))

#### Verify validation data

In [None]:
# Build the validation loader and pull one batch from it as a sanity check.
# `batch_val` is reused by the prediction and plotting cells further down.
val_dl = dataset.val_dataloader()
val_dl_iter = iter(val_dl)
print(len(val_dl_iter))  # length reported by the loader iterator
batch_val = next(val_dl_iter)

# Cell output: shapes of the image tensor and of its mask tensor.
tuple(batch_val[key].shape for key in ("image", "mask"))

In [None]:
# Build the test loader and pull one batch from it as a sanity check.
test_dl = dataset.test_dataloader()
test_dl_iter = iter(test_dl)
print(len(test_dl_iter))  # length reported by the loader iterator
batch_test = next(test_dl_iter)

# Cell output: shapes of the image tensor and of its mask tensor.
tuple(batch_test[key].shape for key in ("image", "mask"))

### Plot a batch using the ml4floods model
See the details here: https://github.com/spaceml-org/ml4floods/blob/891fe602880586e7ac821d2f282bf5ec9d4c0795/ml4floods/data/worldfloods/dataset.py#L106

In [None]:
import importlib
import matplotlib.pyplot as plt
from models import flooding_model
# Re-execute the project-local module so that edits made to it on disk are picked up
# without restarting the kernel. reload() returns the module, which is re-bound here.
flooding_model = importlib.reload(flooding_model)

## Step 3: Setup Model

In [None]:
 # Display the current model parameters for inspection; the next cell overrides some of
 # them, e.g. the folder to store the trained model (it will create a subfolder with the
 # name of the experiment).
config.model_params

In [None]:
model_cfg = config.model_params  # same object as config.model_params, edited in place

# Root folder for trained models; create it up front so later cells can write into it.
model_cfg.model_folder = "train_models"
os.makedirs(model_cfg.model_folder, exist_ok=True)

# This notebook trains; it does not run in test mode.
model_cfg.train, model_cfg.test = True, False

# Architecture to build.
# Currently implemented: unet2, unet_xception, unet_3+, unet_3+_deepsub, unet_s2, unet_sep_s2, hunet
# res2_unet, res2_daunet, attunet, res2_attunet, daunet, res2_saunet, res2rdn_attunet, res2_attunet_sup, simp_res2unet
# Transformer: transunet, cvtunet, res2vtunet, deeplabv3+
# Compare: malunet, deeplabv3+, swinunet, mtunet, utnet
model_cfg.hyperparameters.model_type = "res2vtunet"
# model_cfg.hyperparameters.num_channels = 3

In [None]:
import copy

# Reload the module FIRST and import the classes afterwards. importlib.reload() does not
# rebind names that were already pulled in with `from ... import ...`, so importing before
# the reload would leave the class names below pointing at the stale, pre-reload classes.
importlib.reload(flooding_model)
from models.flooding_model import WorldFloodsModel, DistilledTrainingModel, WorldFloodsModel2, WorldFloodsModel3,WorldFloodsModel1,WorldFloodsModel0, WorldFloodsModel_Sup, WorldFloodsModel_AdaCh

# Independent copy of the model parameters; only used by the (currently disabled)
# DistilledTrainingModel alternative below.
simple_model_params = copy.deepcopy(config.model_params)
simple_model_params['hyperparameters']['model_type']="res2vtunet"

# Alternatives kept for reference:
# gt ={"train": batch_train['mask'], "val":batch_val['mask']}
# model = DistilledTrainingModel(config.model_params, simple_model_params)
# model = WorldFloodsModel(config.model_params)
model = WorldFloodsModel_Sup(config.model_params)
net = model.network  # underlying network, inspected by the complexity cell
net

In [None]:
# Computational complexity of the network
from ptflops import get_model_complexity_info

hparams = config.model_params.hyperparameters
# Input resolution handed to ptflops: (channels, tile size, tile size).
input_res = (hparams.num_channels, hparams.max_tile_size, hparams.max_tile_size)

# Disabled experiments kept for reference.
# torch.ones() returns a tensor filled with the scalar value 1, with the shape given by `size`.
# x = torch.ones(input_res)
# gt ={"train": batch_train['mask'], "val":batch_val['mask']}
# mask_train = (batch_train['mask'][1:])

macs, params = get_model_complexity_info(
    net,
    input_res,
    as_strings=True,
    print_per_layer_stat=True,
    verbose=True,
)

# Two-column summary: left-aligned label, then the value.
for label, value in (('Computational complexity: ', macs),
                     ('Number of parameters: ', params)):
    print('{:<30}  {:<8}'.format(label, value))

In [None]:
setup_weights_and_biases = True #False

### new: switch account
wandb_switch = True
# SECURITY: API keys must never be stored in the notebook. They are read from the
# environment instead (export them in the shell that starts Jupyter):
#   WANDB_API_KEY      -> account used for this run
#   WANDB_API_KEY_OLD  -> account to switch back to after the run (used by a later cell)
# Any key that was previously committed in this notebook should be revoked and regenerated.
# A missing variable yields None, in which case wandb.login() falls back to its own
# credential lookup (existing login / prompt).
new_api = os.environ.get("WANDB_API_KEY")
old_api = os.environ.get("WANDB_API_KEY_OLD")
#####

if setup_weights_and_biases:
    import wandb
    from pytorch_lightning.loggers import WandbLogger

    # Where the run is recorded on the Weights & Biases server.
    config_wb = {'wandb_entity': 'ntustyuyu',
                 'wandb_project': 'new_1012',
                 'experiment_name': '32_res2vtunet_dice+biou'}

    wandb.login(key=new_api)
    wandb.init(name=config_wb['experiment_name'],
               project=config_wb['wandb_project'],
               entity=config_wb['wandb_entity'])

    # Logger handed to the Lightning Trainer in a later cell.
    wandb_logger = WandbLogger(name=config_wb['experiment_name'],
                               project=config_wb['wandb_project'],
                               entity=config_wb['wandb_entity']
                               # resume = 'allow'
                               )
else:
    # Trainer(logger=None): train without Weights & Biases logging.
    wandb_logger = None

In [None]:
import glob
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
# ModelCheckpoint is the PyTorch Lightning callback used to save the model: it watches a
# metric and stores a checkpoint whenever that metric reaches a new best value.

experiment_path = f"{config.model_params.model_folder}/{config.experiment_name}"
checkpoint_path = f"{experiment_path}/checkpoint/{config.model_params.hyperparameters.model_type}-{now}" #checkpoint

# Remove empty checkpoint folders left behind by earlier runs.
# os.rmdir deletes only the folder itself; os.removedirs would additionally prune every
# empty parent directory, which could remove the experiment folder as well.
for folder in glob.glob(f"{experiment_path}/checkpoint/*"):
    if not os.path.isdir(folder):
        continue
    try:
        with os.scandir(folder) as entries:  # context manager closes the directory handle
            folder_is_empty = not any(entries)
        if folder_is_empty:
            os.rmdir(folder)
    except OSError as err:
        # Clean-up is best effort: report the problem and carry on with the next folder.
        print(f"Could not clean up {folder}: {err}")

# exist_ok avoids the check-then-create race of a separate isdir() test.
os.makedirs(checkpoint_path, exist_ok=True)

# Metric watched by both callbacks; lower is better, hence mode='min'.
monitored_metric = 'val_dice_loss'

checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_path,  # directory the checkpoints are written to
    save_top_k=1,             # keep only the single best checkpoint (was the bool True)
    verbose=True,
    monitor=monitored_metric,
    mode='min',
#     prefix=''
)

early_stop_callback = EarlyStopping(
    monitor=monitored_metric,
    patience=10,
    strict=False,
    verbose=False,
    mode='min'
)

callbacks = [checkpoint_callback, early_stop_callback]

print(f"The trained model will be stored in {config.model_params.model_folder}/{config.experiment_name}")

In [None]:
from pytorch_lightning import Trainer

hparams = config.model_params.hyperparameters

config.gpus = 2 # value passed as `devices` to the Trainer below
# config.gpus = None # to not use GPU
hparams.max_epochs = 30 # upper bound on the number of training epochs
# checkpoint_pth = "/home/viplab/VipLabProjects/yuyu/yuyu_38/train_models/new_1012/checkpoint/res2vtunet-2023-10-12-16:13:48/epoch=22-step=111642.ckpt"

# Collect the Trainer options in one mapping, then build the Trainer from it.
trainer_kwargs = dict(
    # logging / callbacks / output location
    logger=wandb_logger,
    callbacks=callbacks,
    default_root_dir=f"{config.model_params.model_folder}/{config.experiment_name}",
    # schedule
    max_epochs=hparams.max_epochs,
    check_val_every_n_epoch=hparams.val_every,
    # optimisation behaviour
    fast_dev_run=False,
    accumulate_grad_batches=1,
    gradient_clip_val=0.0,
    auto_lr_find=False,
    benchmark=False,
    # hardware
    strategy='dp',
    accelerator='gpu',
    devices=config.gpus,
    # resume_from_checkpoint=checkpoint_pth
)
trainer = Trainer(**trainer_kwargs)


In [None]:
# Run training with the model and dataset objects built in the earlier cells; the
# checkpoint and early-stopping callbacks were attached when `trainer` was created.
trainer.fit(model, dataset)

In [None]:
# Inference on the validation batch fetched earlier.
# eval() puts layers such as dropout / batch-norm into inference behaviour, and no_grad()
# stops autograd from recording the forward pass, which would otherwise keep the whole
# graph (and its activations) alive in GPU memory for no benefit.
model.eval()
with torch.no_grad():
    logits = model(batch_val["image"].to(model.device))
    # Class scores -> probabilities along dim 1, then the most likely class per position.
    probs = torch.softmax(logits, dim=1)
    prediction = torch.argmax(probs, dim=1).long().cpu()

print(f"Shape of logits: {logits.shape}")
print(f"Shape of probs: {probs.shape}")
print(f"Shape of prediction: {prediction.shape}")

In [None]:
# Copy the tile size from the hyperparameters onto model_params itself.
# NOTE(review): presumably the inference helper used in the next cell reads it from
# there — confirm against ml4floods' model_setup.
config.model_params.max_tile_size = config.model_params.hyperparameters.max_tile_size
# Display the full configuration as the cell output.
config

In [None]:
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader,TensorDataset

from ml4floods.models.utils import metrics
from ml4floods.models.model_setup import get_model_inference_function
from models.dataset_rgbih import RGBIH_Dataset

model.to("cuda")
# print(model[0])
# logits, all_loss = model
inference_function = get_model_inference_function(
    model,
    config,
    apply_normalization=False,
    activation="softmax",
)
# print(inference_function.dim)

dl = dataset.val_dataloader() # pytorch Dataloader
# print(type(dl))

# Water thresholds to sweep: a few values near 0, a 0.05-step grid starting at 0.5,
# and a few values near 1.
low_thresholds = [0,1e-3,1e-2]
grid_thresholds = np.arange(0.5,.96,.05).tolist()
high_thresholds = [.99,.995,.999]
thresholds_water = low_thresholds + grid_thresholds + high_thresholds

mets = metrics.compute_metrics(
    dl,
    inference_function,
    thresholds_water=thresholds_water,
    plot=False,
    convert_targets=False,
)

label_names = ["land", "water", "cloud"]
metrics.plot_metrics(mets, label_names)

In [None]:
# The dataset exposes its files either directly (`image_files`) or through window
# objects carrying a `file_name`; collect the paths first, then derive the event code,
# i.e. the part of the file name before the first underscore.
val_dataset = dl.dataset
if hasattr(val_dataset, "image_files"):
    file_paths = val_dataset.image_files
else:
    file_paths = [window.file_name for window in val_dataset.list_of_windows]
cems_code = [os.path.basename(path).split("_")[0] for path in file_paths]

class_names = ["land", "water", "cloud"]

# Aggregate the confusion matrices per event code, once per metric.
iou_per_code = pd.DataFrame(
    metrics.group_confusion(
        mets["confusions"],
        cems_code,
        metrics.calculate_iou,
        label_names=[f"IoU_{l}" for l in class_names],
    )
)
recall_per_code = pd.DataFrame(
    metrics.group_confusion(
        mets["confusions"],
        cems_code,
        metrics.calculate_recall,
        label_names=[f"Recall_{l}" for l in class_names],
    )
)

# One row per event code, recall columns followed by IoU columns, scaled by 100.
join_data_per_code = pd.merge(recall_per_code, iou_per_code, on="code")
join_data_per_code = join_data_per_code.set_index("code") * 100
print(f"Mean values across flood events: {join_data_per_code.mean(axis=0).to_dict()}")
join_data_per_code

model_rgbnir_worldflood_transunet_epoch10# import torch
from pytorch_lightning.utilities.cloud_io import atomic_save
from ml4floods.models.config_setup import save_json

# Save in the cloud and in the wandb logger save dir
atomic_save(model.state_dict(), f"{experiment_path}/model_rgbnir_worldflood_unet3+_epoch20.pt")
# Save config file in experiment_path
config_file_path = f"{experiment_path}/config_rgbnir_worldflood_unet3+_epoch20.json"
save_json(config, config_file_path)

In [None]:
import json

# Store the trained weights inside the experiment folder.
weights_file_path = f"{experiment_path}/model_rgbnir_worldfloodsup_res2vtunet.pt"
torch.save(model.state_dict(), weights_file_path)

# Save the config file next to the weights in experiment_path.
config_file_path = f"{experiment_path}/config_rgbnir_worldfloodsup_res2vtunet.json"
with open(config_file_path, 'w') as config_file:
    json.dump(config, config_file)

In [None]:
if setup_weights_and_biases:
    # Write the weights into the logger's save directory, then upload that file.
    wandb_weights_path = os.path.join(wandb_logger.save_dir, 'model_rgbnir_worldfloodsup_res2vtunet.pt')
    torch.save(model.state_dict(), wandb_weights_path)
    wandb.save(wandb_weights_path) # Copy weights to weights and biases server
    wandb.finish()

    # Log back in with the previous account's key once the run has been closed.
    wandb.login(key = old_api)
    wandb.finish()

In [None]:
# Range of validation-batch samples to show: one column per sample.
n_image_start = 0
n_images = 8
count = int(n_images - n_image_start)
sample_range = slice(n_image_start, n_images)

fig, axs = plt.subplots(4, count, figsize=(18,14), tight_layout=True)
importlib.reload(flooding_model)

images = batch_val["image"][sample_range]

# Row 0: images with the default band selection.
flooding_model.plot_batch(images, channel_configuration="bgri", axs=axs[0], max_clip_val=3500.)
# Row 1: band B8 repeated on all three display channels.
flooding_model.plot_batch(images, channel_configuration="bgri", bands_show=["B8","B8", "B8"], axs=axs[1], max_clip_val=3500.)
# flooding_model.plot_batch(batch_val["image"][:n_images],bands_show=["B11","B8", "B4"],axs=axs[1],max_clip_val=4500.)
# Row 2: reference masks (channel 0 of the mask tensor).
flooding_model.plot_batch_output_v1(batch_val["mask"][sample_range, 0], axs=axs[2], show_axis=True)
# Row 3: model predictions, shifted by +1.
# NOTE(review): presumably aligns predicted class ids with the mask encoding — confirm.
flooding_model.plot_batch_output_v1(prediction[sample_range] + 1, axs=axs[3], show_axis=True)

# Grid lines are not wanted on image panels.
for ax in axs.ravel():
    ax.grid(False)