# Notebook to train world floods model
    steps:
        1. Create dataloader and transforms
        2. Create model and PyTorch Lightning trainer
        3. Set up logging
        4. Save model to bucket

In [17]:
import sys, os
from pathlib import Path
from pyprojroot import here
# Search upwards for the ".here" marker file; the directory containing it is
# treated as the project root.
root = here(project_files=[".here"])
# Make the project's `src` package importable. Reuse `root` (instead of calling
# `here()` again with its default markers) so the directory put on sys.path is
# the same one the rest of the notebook builds paths from, and skip the append
# when the cell is re-run so sys.path does not accumulate duplicates.
if str(root) not in sys.path:
    sys.path.append(str(root))

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Step 1: Setup Configuration file

In [18]:
from src.models.config_setup import get_default_config

# Template experiment configuration, located relative to the project root.
config_fp = os.path.join(
    root,
    'src',
    'models',
    'configurations',
    'worldfloods_template.json',
)
# Parsed experiment configuration used by every later cell.
config = get_default_config(config_fp)

Loaded Config for experiment:  worldfloods_demo_test2
{   'data_params': {   'bands': 'all',
                       'batch_size': 32,
                       'bucket_id': 'ml4floods',
                       'image_count': 3,
                       'input_folder': 'S2',
                       'loader_type': 'bucket',
                       'path_to_splits': 'worldfloods/public',
                       'target_folder': 'gt',
                       'test_transformation': {   'normalize': True,
                                                  'num_classes': 3,
                                                  'totensor': True,
                                                  'use_channels': 'all'},
                       'train_transformation': {   'normalize': True,
                                                   'num_classes': 3,
                                                   'totensor': True,
                                                   'use_channels': 'all'},
            

In [19]:
from pytorch_lightning import seed_everything

# Fix the global random seed from the experiment config so runs are repeatable.
seed_everything(seed=config.seed)

Global seed set to 12


12

### Step 1.b: Make it a unique experiment

In [20]:
# Override the experiment name loaded from the template. The W&B run name and
# the experiment output path built in later cells are derived from this value.
config.experiment_name = 'worldfloods-notebook-training-demo'

### Step 2: Setup Dataloader

In [22]:
from src.models.dataset_setup import get_dataset

# Select the bucket-backed loader explicitly, then build the dataset object
# that is later handed to `trainer.fit`.
# NOTE(review): the recorded run of this cell failed with a RasterioIOError on
# a /vsigs/ path -- presumably a GDAL/GCS access or credentials problem;
# confirm bucket access before relying on this loader.
config.data_params.loader_type = 'bucket'
dataset = get_dataset(config.data_params)

Using remote bucket storate dataset for this run


RasterioIOError: '/vsigs/ml4floods/worldfloods/public/train/S2/EMSR265_18ARCISSURAUBE_DEL_v1_observed_event_a.tif' not recognized as a supported file format.

### Step 3: Setup Model

In [6]:
from src.models.model_setup import get_model

# Flag the model parameters for training (not testing) before building the
# model.
# NOTE(review): how get_model interprets the `test`/`train` flags is not
# visible here -- confirm in src.models.model_setup.
config.model_params.test = False
config.model_params.train = True
model = get_model(config.model_params)

13 3


### Step 4: WandB Logger (Replace with your wandb info)

In [7]:
import wandb
from pytorch_lightning.loggers import WandbLogger

# Weights & Biases destination for this run -- replace the entity (account)
# and project with your own.
config['wandb_entity'] = 'sambuddinc'
config['wandb_project'] = 'worldfloods-notebook-demo'

# Logger handed to the Lightning Trainer; the run is named after the experiment.
wandb_logger = WandbLogger(
    entity=config.wandb_entity,
    project=config.wandb_project,
    name=config.experiment_name,
)

### Step 5: Setup Lightning Callbacks

In [8]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Folder (local path or gs:// URI, depending on the configured model_folder)
# that receives the artefacts of this experiment.
experiment_path = f"{config.model_params.model_folder}/{config.experiment_name}"

# Keep only the single best checkpoint, ranked by the lowest 'dice_loss'.
# `dirpath`/`filename` replace the deprecated `filepath` argument and produce
# the same "<experiment_path>/checkpoint*.ckpt" files. `save_top_k` expects an
# int, so pass 1 rather than True. The deprecated `prefix=''` argument is
# dropped because the empty string is already its default.
checkpoint_callback = ModelCheckpoint(
    dirpath=experiment_path,
    filename="checkpoint",
    save_top_k=1,
    verbose=True,
    monitor='dice_loss',
    mode='min',
)

# Stop training once 'dice_loss' has failed to improve for 10 consecutive
# checks. strict=False keeps training alive if the monitored metric is missing
# instead of raising.
early_stop_callback = EarlyStopping(
    monitor='dice_loss',
    patience=10,
    strict=False,
    verbose=False,
    mode='min'
)

callbacks = [checkpoint_callback, early_stop_callback]

# Show where checkpoints and other outputs of this experiment will be written.
print(experiment_path)

gs://ml4cc_data_lake/0_DEV/2_Mart/2_MLModelMart/worldfloods-notebook-training-demo




### Step 6: Set up Lightning Trainer
    -- add flags from 
    https://pytorch-lightning.readthedocs.io/en/0.7.5/trainer.html 

In [9]:
from pytorch_lightning import Trainer

# Build the Lightning Trainer. Arguments are grouped by concern; the values are
# the same ones the notebook has always used.
trainer = Trainer(
    # --- logging, callbacks and output location ---
    logger=wandb_logger,
    callbacks=callbacks,
    default_root_dir=f"{config.model_params.model_folder}/{config.experiment_name}",
    resume_from_checkpoint=None,
    # --- training schedule (taken from the experiment configuration) ---
    max_epochs=config.model_params.hyperparameters.max_epochs,
    check_val_every_n_epoch=config.model_params.hyperparameters.val_every,
    fast_dev_run=False,
    # --- optimisation ---
    accumulate_grad_batches=1,
    gradient_clip_val=0.0,
    auto_lr_find=False,
    # --- hardware / backend ---
    benchmark=False,
    distributed_backend=None,
    log_gpu_memory=None,
    # gpus=config.gpus,  # GPU selection is intentionally left disabled here
)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores


In [10]:
# Run the training loop with the model and dataset built in the earlier cells.
# NOTE(review): `dataset` is passed positionally as the second argument of
# `fit`; presumably get_dataset returns a LightningDataModule -- confirm.
trainer.fit(model, dataset)

Failed to query for notebook name, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable
[34m[1mwandb[0m: Currently logged in as: [33msambuddinc[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.21 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name    | Type         | Params
-----------------------------------------
0 | network | SimpleLinear | 42    
-----------------------------------------
42        Trainable params
0         Non-trainable params
42        Total params


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]



Epoch 0:   0%|          | 0/6187 [00:00<?, ?it/s]                     



Epoch 0:   0%|          | 11/6187 [03:16<30:39:58, 17.88s/it, loss=0.799, v_num=y7l7]

  s = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
Epoch 0, global step 10: dice_loss reached 0.78169 (best 0.78169), saving model to "gs://ml4cc_data_lake/0_DEV/2_Mart/2_MLModelMart/worldfloods-notebook-training-demo/checkpoint-v0.ckpt" as top True


Epoch 0:   0%|          | 11/6187 [03:28<32:26:31, 18.91s/it, loss=0.799, v_num=y7l7]


RasterioIOError: '/vsigs/ml4floods/worldfloods/public/train/S2/EMSR271_05THESSALYOVERVIEW_DEL_MONIT01_v2_observed_event_a0000000000-0000000000.tif' not recognized as a supported file format.

### Step 7: Save trained model

In [None]:
# `json` and `torch` are imported here because no earlier cell imports them;
# without these lines the cell fails with NameError on a fresh kernel.
import json

import torch
from pytorch_lightning.utilities.cloud_io import atomic_save

# 1) Store the trained weights next to the checkpoints of this experiment.
atomic_save(model.state_dict(), f"{experiment_path}/model.pt")

# 2) Write a local copy and attach it to the W&B run.
#    `wandb_logger.save_dir` may be None when the logger was created without an
#    explicit save_dir, which would make os.path.join raise; fall back to the
#    current working directory in that case.
local_dir = wandb_logger.save_dir or os.getcwd()
local_model_path = os.path.join(local_dir, 'model.pt')
torch.save(model.state_dict(), local_model_path)
wandb.save(local_model_path)
wandb.finish()

# 3) Save the config file in experiment_path.
config_file_path = f"{experiment_path}/config.json"
# Serialise up front so a non-serialisable config fails before any file or
# blob is created.
config_json = json.dumps(config)

gcs_prefix = "gs://"
if config_file_path.startswith(gcs_prefix):
    from google.cloud import storage

    # "gs://<bucket>/<blob path>": strip only the leading scheme, then split
    # on the first "/" to separate the bucket from the object name.
    bucket_name, _, blob_name = config_file_path[len(gcs_prefix):].partition("/")
    bucket = storage.Client().get_bucket(bucket_name)
    blob = bucket.blob(blob_name)
    blob.upload_from_string(
        data=config_json,
        content_type='application/json'
    )
else:
    with open(config_file_path, "w") as fh:
        fh.write(config_json)