In [1]:
%load_ext autoreload
%autoreload 2

import os
import csv
import numpy as np
import shapely
import pytorch_lightning as pl

from torch.utils.data import DataLoader, RandomSampler, ConcatDataset
from pytorch_lightning.utilities import CombinedLoader
from pathlib import Path
from datetime import datetime

import dl_toolbox.callbacks as callbacks
import dl_toolbox.modules as modules 
import dl_toolbox.networks as networks
import dl_toolbox.datasets as datasets
import dl_toolbox.torch_collate as collate
import dl_toolbox.utils as utils

import rasterio.windows as windows

# Machine-specific roots, selected from the hostname:
#   data_root - folder holding the datasets (data_path = data_root / dataset_name)
#   home      - folder holding the dl_toolbox checkout (used to locate the split CSV)
#   save_root - folder where logs/checkpoints are written (save_dir = save_root / dataset_name)
hostname = os.uname().nodename
if hostname == 'WDTIS890Z':
    data_root = Path('/mnt/d/pfournie/Documents/data')
    home = Path('/home/pfournie')
    save_root = data_root / 'outputs'
elif hostname == 'qdtis056z':
    data_root = Path('/data')
    home = Path('/d/pfournie')
    save_root = data_root / 'outputs'
else:
    # Any other host is treated as the compute cluster, where data is read from
    # $TMPDIR (presumably staged there by the job script — confirm).
    # Former shared location: /work/OT/ai4geo/DATA/DATASETS
    if 'TMPDIR' not in os.environ:
        # Fail early with an actionable message instead of a bare KeyError('TMPDIR').
        raise KeyError(
            f"Unknown host {hostname!r} and TMPDIR is not set: cannot resolve "
            "data_root. Add this host to the path configuration."
        )
    data_root = Path(os.environ['TMPDIR'])
    home = Path('/home/eh/fournip')
    save_root = Path('/work/OT/ai4usr/fournip') / 'outputs'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Dataset params ---
dataset_name = 'DIGITANIE'
data_path = data_root / dataset_name
nomenclature = datasets.DigitanieNomenclatures['building'].value
num_classes = len(nomenclature)
crop_size = 256
# NOTE(review): crop_step is not used by the following cells; the validation
# tiling uses crop_size // 2 instead — confirm which stride is intended.
crop_step = 256
bands = [1, 2, 3]  # band indices passed to the datasets (presumably 1-based raster bands — confirm)

# --- Split params ---
# CSV listing the Toulouse images; the *_idx lists below select entries from it
# through datasets.datasets_from_csv.
split = home / 'dl_toolbox/dl_toolbox/datamodules/digitanie_Toulouse.csv'

train_idx = [1, 2, 3, 4, 5]
train_aug = 'd4_color-3'

val_idx = [6, 7]
val_aug = 'd4'

# Presumably the unlabelled split; not used by the supervised run below.
unsup_idx = [0]
unsup_aug = 'd4'

# --- Dataloader params ---
batch_size = 8
epoch_steps = 500
# Indices drawn per epoch by the training RandomSampler, so that one epoch is
# exactly epoch_steps full batches.
num_samples = epoch_steps * batch_size
num_workers = 6

# --- Network params ---
in_channels = len(bands)
out_channels = num_classes
pretrained = False
encoder = 'efficientnet-b0'

# --- Module params ---
mixup = 0.  # incompatible with ignore_zero=True
class_weights = [1.] * num_classes  # uniform weighting of all classes
initial_lr = 0.001
ttas = []
# Semi-supervised settings: defined here but not passed to the Supervised
# module built in the next cell.
alpha_ramp = utils.SigmoidRamp(2, 4, 0., 0.)
pseudo_threshold = 0.9
consist_aug = 'color-5'
ema_ramp = utils.SigmoidRamp(2, 4, 0.9, 0.99)

# --- Trainer params ---
num_epochs = 30
# Alternative stopping criterion (unused): max_steps = num_epochs * epoch_steps
accelerator = 'gpu'
devices = 1
multiple_trainloader_mode = 'min_size'  # mode of the CombinedLoader over train loaders
limit_train_batches = 1.
limit_val_batches = 1.
save_dir = save_root / dataset_name
log_name = 'toulouse'
ckpt_path = None  # checkpoint to resume from in trainer.fit (None = start fresh)

In [3]:
### Network: U-Net with the configured encoder, in_channels bands -> out_channels classes
network = networks.SmpUnet(
    in_channels=in_channels,
    out_channels=out_channels,
    encoder=encoder,
    pretrained=pretrained,
)

### Lightning module: plain supervised training.
# The semi-supervised settings from the config cell (alpha_ramp,
# pseudo_threshold, consist_aug, ema_ramp) are deliberately not passed here;
# presumably they belong to a semi-supervised module variant.
module = modules.Supervised(
    network=network,
    num_classes=num_classes,
    class_weights=class_weights,
    initial_lr=initial_lr,
    mixup=mixup,  # incompatible with ignore_zero=True
    ttas=ttas,
)

In [4]:
# Training data: one Raster dataset per image selected by train_idx in the split CSV.
train_data_src = list(datasets.datasets_from_csv(data_path, split, train_idx))
train_sets = [
    datasets.Raster(
        data_src=data_src,
        crop_size=crop_size,
        aug=train_aug,
        bands=bands,
        nomenclature=nomenclature,
    )
    for data_src in train_data_src
]
train_set = ConcatDataset(train_sets)

# Validation data: one pre-tiled dataset per image selected by val_idx, with a
# tiling step of crop_size // 2.
# NOTE(review): this does not use the `crop_step` config value — confirm intended.
val_data_src = list(datasets.datasets_from_csv(data_path, split, val_idx))
val_sets = [
    datasets.PretiledRaster(
        data_src=data_src,
        crop_size=crop_size,
        crop_step=crop_size // 2,
        aug=val_aug,
        bands=bands,
        nomenclature=nomenclature,
    )
    for data_src in val_data_src
]
val_set = ConcatDataset(val_sets)

In [8]:
%matplotlib inline

pl_aug = 'd4'

pred_dir = Path('/data/outputs/DIGITANIE/train=[1, 2, 3, 4, 5]_val=[6, 7]/19Apr23-18h24m50/DIGITANIE')

pl_data_src = datasets.Digitanie(
    image_path=data_path/'Toulouse/Toulouse_EPSG32631_7.tif',
    label_path=pred_dir/'Toulouse_EPSG32631_7_0_0.tif',
    zone=windows.Window(0,0,2048,2048)
) 

pl_set = datasets.Raster(
    data_src=pl_data_src,
    crop_size=crop_size,
    aug=pl_aug,
    bands=bands,
    nomenclature=nomenclature
)

pl_train_set = ConcatDataset([train_set, pl_set])

pl_train_dataloaders = {}

pl_train_dataloaders['sup'] = DataLoader(
    dataset=pl_train_set,
    #sampler=pl_sampler,
    sampler=RandomSampler(
        data_source=pl_train_set,
        replacement=True,
        num_samples=num_samples
    ),
    collate_fn=collate.CustomCollate(),
    batch_size=batch_size,
    num_workers=num_workers,
    drop_last=True
)

val_dataloader = DataLoader(
    dataset=val_set,
    sampler=RandomSampler(
        data_source=val_set,
        replacement=True,
        num_samples=num_samples//10
    ),
    collate_fn=collate.CustomCollate(),
    batch_size=batch_size,
    num_workers=num_workers
)
    
metrics_from_confmat = callbacks.MetricsFromConfmat(        
    num_classes=num_classes,
    class_names=[label.name for label in nomenclature]
)

logger = pl.loggers.TensorBoardLogger(
    save_dir=save_dir,
    name=log_name,
    version=f'{datetime.now():%d%b%y-%Hh%Mm%S}'
)

### Trainer instance
pl_trainer = pl.Trainer(
    max_epochs=num_epochs,
    accelerator=accelerator,
    devices=devices,
    num_sanity_val_steps=0,
    limit_train_batches=limit_train_batches,
    limit_val_batches=limit_val_batches,
    logger=logger,
    callbacks=[
        pl.callbacks.ModelCheckpoint(),
        pl.callbacks.EarlyStopping(
            monitor='Val_loss',
            patience=10
        ),
        metrics_from_confmat,
        callbacks.MyProgressBar()
    ]
)

pl_trainer.fit(
    model=module,
    train_dataloaders=CombinedLoader(pl_train_dataloaders, mode=multiple_trainloader_mode),
    val_dataloaders=val_dataloader,
    ckpt_path=ckpt_path
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | network | SmpUnet          | 6.3 M 
1 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
6.3 M     Trainable params
0         Non-trainable params
6.3 M     Total params
25.007    Total estimated model params size (MB)


Epoch 0:   0%|▎                                                                                                                                              | 1/500 [09:11<76:26:30, 551.48s/it, v_num=3m49]
Epoch 0:  37%|█████████████████████████████████████████████████████▎                                                                                           | 184/500 [00:47<01:22,  3.85it/s, v_num=2m30]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
