# Notebook to train the WorldFloods model
    Steps:
        1. Create the dataloader and transforms
        2. Create the model and the PyTorch Lightning trainer
        3. Set up logging
        4. Save the model to a bucket

In [28]:
import sys, os
from pathlib import Path
from pyprojroot import here
# spyder up to find the root
root = here(project_files=[".here"])
# append to path
sys.path.append(str(here()))

os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/opt/creds/ML4CC_creds.json"

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Step 1: Set up the configuration file

In [29]:
from src.models.config_setup import get_default_config

# Location of the experiment template, as path segments relative to the project root.
template_parts = ('src', 'models', 'configurations', 'worldfloods_template.json')
config_fp = os.path.join(root, *template_parts)

# Parse the template into `config`; later cells read and modify this object.
config = get_default_config(config_fp)

Loaded Config for experiment:  worldfloods_demo_test
{   'data_params': {   'bands': 'all',
                       'batch_size': 32,
                       'bucket_id': 'ml4floods',
                       'image_count': 3,
                       'input_folder': 'S2',
                       'loader_type': 'local',
                       'path_to_splits': '/worldfloods/public',
                       'target_folder': 'gt',
                       'test_transformation': {   'normalize': True,
                                                  'num_classes': 3,
                                                  'totensor': True,
                                                  'use_channels': 'all'},
                       'train_transformation': {   'normalize': True,
                                                   'num_classes': 3,
                                                   'totensor': True,
                                                   'use_channels': 'all'},
             

In [30]:
from pytorch_lightning import seed_everything

# Seed the global random number generators from the configured value so runs are
# reproducible; the seed that was applied is returned and shown as the cell output.
seed_everything(seed=config.seed)

Global seed set to 12


12

### Step 1.b: Make it a unique experiment

In [31]:
# Name used both for the W&B run and for the experiment folder under model_folder.
experiment_name = 'worldfloods-notebook-training-demo'
config.experiment_name = experiment_name

### Step 2: Set up the dataloader

In [32]:
from src.models.dataset_setup import get_dataset

# Force the local-filesystem loader explicitly, so this run does not depend on the
# template's loader_type value.
data_params = config.data_params
data_params.loader_type = 'local'
dataset = get_dataset(data_params)

Using local dataset for this run
Folder '/worldfloods/public/train/S2' Is Already There.
Folder '/worldfloods/public/train/gt' Is Already There.
Folder '/worldfloods/public/val/S2' Is Already There.
Folder '/worldfloods/public/val/gt' Is Already There.
Folder '/worldfloods/public/test/S2' Is Already There.
Folder '/worldfloods/public/test/gt' Is Already There.
train 196648  tiles
val 1284  tiles
test 11  tiles


### Step 3: Set up the model

In [45]:
from src.models.model_setup import get_model

# Build the model for training: the train flag on, the test flag off.
model_params = config.model_params
model_params.train = True
model_params.test = False
model = get_model(model_params)

13 3


### Step 4: W&B logger (replace the entity and project with your own W&B account)

In [34]:
import wandb

# Authenticate with W&B only; do NOT start a run here. WandbLogger re-uses any run that
# is already active, so a bare `wandb.init()` would make it ignore the name, project and
# entity configured in the next cell and log to wandb's default project instead.
wandb.login()

In [41]:
import os

import wandb
from pytorch_lightning.loggers import WandbLogger

# W&B account and project to log to. Set WANDB_ENTITY / WANDB_PROJECT in the
# environment to use your own account without editing the notebook.
config['wandb_entity'] = os.environ.get('WANDB_ENTITY', 'sambuddinc')
config['wandb_project'] = os.environ.get('WANDB_PROJECT', 'worldfloods-notebook-demo')

# WandbLogger re-uses a run that is already active (e.g. one started by a bare
# `wandb.init()`), which would silently ignore the name/project/entity below.
if wandb.run is not None:
    wandb.finish()

wandb_logger = WandbLogger(
    name=config.experiment_name,
    project=config.wandb_project,
    entity=config.wandb_entity,
)

### Step 5: Set up Lightning callbacks

In [42]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Root folder for everything this run writes (checkpoints, final weights, config).
experiment_path = f"{config.model_params.model_folder}/{config.experiment_name}"

# Keep only the single best checkpoint, ranked by the lowest logged 'dice_loss'.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"{experiment_path}/checkpoint",
    save_top_k=1,  # an int count; True only worked because True == 1
    verbose=True,
    monitor='dice_loss',
    mode='min',
)

# Stop when 'dice_loss' has not improved for 10 checks. strict=False: do not crash
# if 'dice_loss' is missing from the logged metrics.
early_stop_callback = EarlyStopping(
    monitor='dice_loss',
    patience=10,
    strict=False,
    verbose=False,
    mode='min',
)

callbacks = [checkpoint_callback, early_stop_callback]

print(experiment_path)

gs://ml4cc_data_lake/0_DEV/2_Mart/2_MLModelMart/worldfloods-notebook-training-demo


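### Step 6: Set up the Lightning Trainer
    -- more Trainer flags are documented at
    https://pytorch-lightning.readthedocs.io/en/0.7.5/trainer.html
    (these are the 0.7.5 docs; check the docs for your installed PyTorch Lightning version)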

In [43]:
from pytorch_lightning import Trainer

# Number of GPUs to train on. An int is unambiguous; string values for `gpus` are
# parsed differently (count vs. device index) depending on their exact form.
config.gpus = 1

trainer = Trainer(
    fast_dev_run=False,
    logger=wandb_logger,
    callbacks=callbacks,
    # Same experiment folder the checkpoint callback writes under.
    default_root_dir=experiment_path,
    accumulate_grad_batches=1,
    gradient_clip_val=0.0,  # 0.0 disables gradient clipping
    auto_lr_find=False,
    benchmark=False,
    # `distributed_backend` is deprecated (replaced by `accelerator`, default None), so it is
    # no longer passed; the default single-process behaviour is unchanged.
    gpus=config.gpus,
    max_epochs=config.model_params.hyperparameters.max_epochs,
    check_val_every_n_epoch=config.model_params.hyperparameters.val_every,
    log_gpu_memory=None,
    resume_from_checkpoint=None,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


In [46]:
# Train `model` on `dataset` with the callbacks/limits configured on `trainer`
# (stops at max_epochs or when early stopping triggers).
# NOTE(review): `dataset` is passed positionally; if get_dataset returns a
# LightningDataModule, `datamodule=dataset` is the explicit form — TODO confirm type.
trainer.fit(model, dataset)


  | Name    | Type         | Params
-----------------------------------------
0 | network | SimpleLinear | 42    
-----------------------------------------
42        Trainable params
0         Non-trainable params
42        Total params
0.000     Total estimated model params size (MB)


Epoch 0:   0%|          | 0/6187 [00:00<?, ?it/s]                     



Epoch 0:  13%|█▎        | 779/6187 [20:34<2:22:49,  1.58s/it, loss=0.562, v_num=3p2p]

Epoch 0, global step 778: dice_loss reached 0.85572 (best 0.85572), saving model to "gs://ml4cc_data_lake/0_DEV/2_Mart/2_MLModelMart/worldfloods-notebook-training-demo/checkpoint/epoch=0-step=778.ckpt" as top True


Epoch 0:  13%|█▎        | 779/6187 [20:35<2:22:56,  1.59s/it, loss=0.562, v_num=3p2p]


1

### Step 7: Save trained model

In [None]:
import json

import torch
from pytorch_lightning.utilities.cloud_io import atomic_save

# Final weights under the experiment folder (a local path or a gs:// URL).
atomic_save(model.state_dict(), f"{experiment_path}/model.pt")

# Attach a copy of the weights to the W&B run. Depending on the Lightning version,
# WandbLogger.save_dir can be None when no save_dir was passed (as here), so write
# into the active run's own directory instead.
wandb_run_dir = wandb_logger.experiment.dir
wandb_model_path = os.path.join(wandb_run_dir, "model.pt")
torch.save(model.state_dict(), wandb_model_path)
wandb.save(wandb_model_path, base_path=wandb_run_dir)
wandb.finish()


def save_config_json(config_obj, path):
    """Serialize ``config_obj`` to JSON and write it to ``path``.

    Args:
        config_obj: JSON-serializable mapping (here, the experiment config).
        path: Local file path, or a ``gs://bucket/blob`` URL on Google Cloud Storage.
    """
    # Serialize first so a non-serializable config fails before anything is written
    # (json.dump straight into the file could leave it truncated).
    payload = json.dumps(config_obj)
    if path.startswith("gs://"):
        from google.cloud import storage
        # Strip only the scheme prefix, then split into bucket and object name.
        bucket_name, _, blob_name = path[len("gs://"):].partition("/")
        bucket = storage.Client().get_bucket(bucket_name)
        bucket.blob(blob_name).upload_from_string(
            data=payload,
            content_type='application/json'
        )
    else:
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        with open(path, "w") as fh:
            fh.write(payload)


# Save the config file in experiment_path, next to the weights.
save_config_json(config, f"{experiment_path}/config.json")