# SpaceNet8 
## Foundation Baseline

1. foundation


## overview
- `data_prep` までは公開ベースラインと同じ
- 学習部分のみの改善を試みる
- pytorch lightning + wandb + SMP or TimmUNet の導入


## Env

In [4]:
%%writefile requirements.txt
joblib
python-box
tqdm
timm
ttach
adabelief-pytorch
albumentations
segmentation-models-pytorch
wandb
tensorboard
tensorboardX
pytorch-lightning

Overwriting requirements.txt


In [9]:
# !pip install -q -r ../docker/requirements.txt
!pip install -q -r requirements.txt
!conda install -y gdal
!pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html

Collecting package metadata (current_repodata.json): done
Solving environment: done


  current version: 4.9.2
  latest version: 4.13.0

Please update conda by running

    $ conda update -n base -c defaults conda



# All requested packages already installed.

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.8.0+cu111
  Downloading https://download.pytorch.org/whl/cu111/torch-1.8.0%2Bcu111-cp38-cp38-linux_x86_64.whl (1982.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 GB[0m [31m20.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:04[0m
[?25hCollecting torchvision==0.9.0+cu111
  Downloading https://download.pytorch.org/whl/cu111/torchvision-0.9.0%2Bcu111-cp38-cp38-linux_x86_64.whl (17.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.6/17.6 MB[0m [31m24.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting tor

## Import

In [1]:
import os
import warnings
import random
from pprint import pprint
import copy
from typing import List, Tuple
import glob
import json
import csv
# import dataclasses
from joblib import Parallel, delayed

from tqdm import tqdm
import numpy as np
import pandas as pd
from box import Box
import matplotlib.pyplot as plt

import tifffile
from osgeo import gdal

from sklearn.model_selection import StratifiedKFold, KFold
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from timm import create_model
from adabelief_pytorch import AdaBelief
import segmentation_models_pytorch as smp

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import callbacks
from pytorch_lightning.callbacks.progress import ProgressBarBase
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning import LightningDataModule, LightningModule

import wandb
wandb.login(key='****')

warnings.filterwarnings("ignore")

torch.autograd.set_detect_anomaly(True)
pd.options.display.max_colwidth = 250
pd.options.display.max_rows = 30

# インライン表示
%matplotlib inline

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msyuchimu[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/syu/.netrc


## Config

In [2]:
from typing import List, Set, Dict, Any

class CFG(object):
    # basic
    debug: bool = False
    debug_sample: int = 32
    folds: int  = 5
    seed: int   = 417
    eps: float  = 1e-12
    outdir: str = '../../train/output/foundation/'
    
    # data
    PATH_FOLD_CSV: str  =  f'../../data/folds/'
    
    # train
    epoch: int  = 00
    trainer: Dict[str, Any]   = {
        'gpus': 1,
        'accumulate_grad_batches': 1,
        'progress_bar_refresh_rate': 1,
        'stochastic_weight_avg': False,
        'fast_dev_run': False,
        'num_sanity_val_steps': 0,
        'resume_from_checkpoint': None,
        'check_val_every_n_epoch': 2,
        'val_check_interval': 1.0,
        'precision' : 16,
        'gradient_clip_val': 10., 
        'gradient_clip_algorithm': "value"
    }
    optimizer: Dict[str, Any] = {
        'name': 'optim.AdamW',
        'params': {
            'lr': 1e-3,
            },
    }
    scheduler: Dict[str, Any] = {
        'name': 'optim.lr_scheduler.CosineAnnealingWarmRestarts',
        'params':{
            'T_0': 20,
            'eta_min': 1e-5,
            }
    }
    model: Dict[str, Any] = {
        "architecture": 'smp', # timmu, smp
        "threshold": 0.2,
        
        'loss': 'MultiBCEDiceLoss(raito=0.5)',
        # smp loss mode: https://smp.readthedocs.io/en/latest/_modules/segmentation_models_pytorch/losses/dice.html
        'loss_mode': 'multilabel', # 'binary', 'multiclass' ,'multilabel'
        
        'in_channels': 0,
        'out_channels': 0,
        
        # unet++ :https://smp.readthedocs.io/en/latest/_modules/segmentation_models_pytorch/decoders/unetplusplus/model.html
        'decoder_channels': [int(512 / 2**i) for i in range(5)],
        'encoder_name': 'efficientnet-b6',
        'act': None,
        # 'dropout_rato': 0.1,
    }
    train_loader: Dict[str, Any] = {
        'batch_size': 8,
        'shuffle': True,
        'num_workers': 8,
        'pin_memory': False,
        'drop_last': True,
    }
    val_loader :Dict[str, Any]= {
        'batch_size': 8,
        'shuffle': False,
        'num_workers': 8,
        'pin_memory': False,
        'drop_last': False
    }
    
    # preprocess
    features :List[str] = ["preimg", "building","road","roadspeed","flood"]
    # ["preimg","postimg","building","road","roadspeed","flood"]
    
    preprocess: Dict = {
        "input_size": 512,
    }
    
    # logging
    project: str = "SpaceNet8_foundation"
    runname: str = "3090"
    group: str   = f'3090_V1_IMG{preprocess["input_size"]}_effb6_decoder512_flip-trans'
    notebook: str = 'baseline_foundation.ipynb'
    
    # post info
    augmentation: str =  ''
    fold: int = -1
    
    
    # channels
    for f in features:
        if f == 'preimg':
            model['in_channels'] += 3
        elif f == 'postimg':
            model['in_channels'] += 3

        if f == 'building':
            model['out_channels'] += 1
        # elif f == 'road':
        #     model['out_channels']  += 1
        elif f == 'roadspeed':
            model['out_channels']  += 8
        # elif f == 'flood':
        #     model['out_channels']  += 4
        

    if debug:
        epoch = 2
        group = 'DEBUG'


# box
cfg = Box({k:v for k, v in dict(vars(CFG)).items() if '__' not in k})
    
# 乱数のシードを設定
seed_everything(cfg.seed)
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)
random.seed(cfg.seed)
    
pprint(cfg)

Global seed set to 417


{'PATH_FOLD_CSV': '../../data/folds/',
 'augmentation': '',
 'debug': False,
 'debug_sample': 32,
 'epoch': 100,
 'eps': 1e-12,
 'f': 'flood',
 'features': BoxList(['preimg', 'building', 'road', 'roadspeed', 'flood']),
 'fold': -1,
 'folds': 5,
 'group': '3090_V1_IMG512_effb6_decoder512_flip-trans',
 'model': {'act': None,
           'architecture': 'smp',
           'decoder_channels': BoxList([512, 256, 128, 64, 32]),
           'encoder_name': 'efficientnet-b6',
           'in_channels': 3,
           'loss': 'MultiBCEDiceLoss(raito=0.5)',
           'loss_mode': 'multilabel',
           'out_channels': 9,
           'threshold': 0.2},
 'notebook': 'baseline_foundation.ipynb',
 'optimizer': Box({'name': 'optim.AdamW', 'params': {'lr': 0.001}}),
 'outdir': '../../train/output/foundation/',
 'preprocess': Box({'input_size': 512}),
 'project': 'SpaceNet8_foundation',
 'runname': '3090',
 'scheduler': {'name': 'optim.lr_scheduler.CosineAnnealingWarmRestarts',
               'params': Bo

## Augmentation

In [3]:
# augmentation
tf_dict = {
    
    'train': A.Compose(
        [

            # A.CoarseDropout(max_holes=4, max_height=4, max_width=4, 
            #                     min_holes=None, min_height=None, min_width=None, 
            #                     fill_value=0.15, mask_fill_value=0.0, always_apply=False, p=0.25),
            # A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, interpolation=1,
            #                     border_mode=4, value=None, mask_value=None, always_apply=False,
            #                     approximate=False, same_dxdy=False, p=0.25),
            # A.GridDistortion(num_steps=5, distort_limit=0.4, interpolation=1, 
            #                     border_mode=4, value=None, mask_value=None, always_apply=False, p=0.25),
            # A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, interpolation=1, 
            #                     border_mode=4, value=0.01, mask_value=0.0, shift_limit_x=None, shift_limit_y=None, 
            #                     p=0.5),
            # A.OneOf([
            #     # A.GaussNoise(var_limit=(1e-3, 8e-1), mean=0.15, p=0.5),
            #     A.Blur(blur_limit=9, p=0.25),
            #     A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, brightness_by_max=True, p=0.5),
            # ], p=0.9),
            A.Transpose(p=0.25),
            A.Flip(p=0.5),
            # A.HueSaturationValue (hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=5, p=0.5),
            # A.Rotate(limit=30, p=0.5),
            A.Resize(cfg.preprocess.input_size, cfg.preprocess.input_size),
    #         A.Normalize(mean=(0.485), std=(0.229)),
            ToTensorV2(),
            ]
        ),
    'val': A.Compose(
        [
            A.Resize(cfg.preprocess.input_size, cfg.preprocess.input_size),
            # A.Normalize(mean=(0.485), std=(0.229)),
            ToTensorV2(),
        ]
    ),
}

cfg.augmentation = str(tf_dict).replace('\n', '').replace(' ', '')
cfg.augmentation

"{'train':Compose([Transpose(always_apply=False,p=0.25),Flip(always_apply=False,p=0.5),Resize(always_apply=False,p=1,height=512,width=512,interpolation=1),ToTensorV2(always_apply=True,p=1.0,transpose_mask=False),],p=1.0,bbox_params=None,keypoint_params=None,additional_targets={}),'val':Compose([Resize(always_apply=False,p=1,height=512,width=512,interpolation=1),ToTensorV2(always_apply=True,p=1.0,transpose_mask=False),],p=1.0,bbox_params=None,keypoint_params=None,additional_targets={})}"

## Dataset

In [4]:
class SpaceNnet8Dataset(Dataset):
    def __init__(self,
                 fold: int,
                 phase: str,
                 ):
        """ pytorch dataset for spacenet-8 data. loads images from a csv that contains filepaths to the images
        
        Parameters:
        ------------
        fold: (int) 
            preimg column contains filepaths to the pre-event image tiles (.tif)
            postimg column contains filepaths to the post-event image tiles (.tif)
            building column contains the filepaths to the binary building labels (.tif)
            road column contains the filepaths to the binary road labels (.tif)
            roadspeed column contains the filepaths to the road speed labels (.tif)
            flood column contains the filepaths to the flood labels (.tif)
        data_to_load (list): a list that defines which of the images and labels to load from the .csv. 
        
        """
        self.all_data_types = ["preimg", "postimg", "building", "road", "roadspeed", "flood"]
        
        self.data_to_load = cfg.features
        self.phase = phase
        csv_filename = os.path.join(cfg.PATH_FOLD_CSV, f'fold{fold}_seed{cfg.seed}_{self.phase}.csv')
        self.transform = tf_dict[self.phase]
        
        self.files = []

        dict_template = {}
        for i in self.all_data_types:
            dict_template[i] = None
        
        with open(csv_filename, newline='') as csvfile:
            reader = csv.DictReader(csvfile)
            for k, row in enumerate(reader):
                in_data = copy.copy(dict_template)
                for j in self.data_to_load:
                    in_data[j]=row[j]
                self.files.append(in_data)
                
                if cfg.debug and k > cfg.debug_sample:
                    break
        
        print("loaded", len(self.files), "image filepaths")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        data_dict = self.files[index]

        imgs, masks = [], []
        
        # gather
        for i in self.all_data_types:
            filepath = data_dict[i]
            if filepath is not None:
                # need to resample postimg to same spatial resolution/extent as preimg and labels.
                if i == "postimg":
                    ds = self.get_warped_ds(data_dict["postimg"])
                else:
                    ds = gdal.Open(filepath)
                image = ds.ReadAsArray()
                ds = None
            
                if i in ['preimg' ,'postimg']:
                    imgs.append(image.transpose(1, 2, 0))
                else:
                    # 1 channel
                    if len(image.shape) <= 2:
                        masks.append(image[:,:, np.newaxis])
                    else:
                        masks.append(image.transpose(1, 2, 0))
                        
                
        
        
        # align channel last
        imgs = np.concatenate(imgs, axis=2)
        masks = np.concatenate(masks, axis=2)
        
        # augmentation
        transformed = self.transform(image=imgs, mask=masks)
        
        imgs = transformed["image"]
        masks = transformed["mask"].permute(2, 0, 1) # torch channel fast
        
        return imgs, masks

    def get_image_filename(self, index: int) -> str:
        """ return pre-event image absolute filepath at index """
        data_dict = self.files[index]
        return data_dict["preimg"]

    def get_warped_ds(self, post_image_filename: str) -> gdal.Dataset:
        """ gdal warps (resamples) the post-event image to the same spatial resolution as the pre-event image and masks 
        
        SN8 labels are created from referencing pre-event image. Spatial resolution of the post-event image does not match the spatial resolution of the pre-event imagery and therefore the labels.
        In order to align the post-event image with the pre-event image and mask labels, we must resample the post-event image to the resolution of the pre-event image. Also need to make sure
        the post-event image covers the exact same spatial extent as the pre-event image. this is taken care of in the the tiling"""
        ds = gdal.Warp("", post_image_filename,
                       format='MEM', width=1300, height=1300,
                       resampleAlg=gdal.GRIORA_Bilinear,
                       outputType=gdal.GDT_Byte)
        return ds
    
class SpaceNnet8Module(LightningDataModule):
    def __init__(
        self,
        fold,
        cfg,
    ):
        super().__init__()
        self.fold = fold
        self._cfg = cfg

    def train_dataloader(self):
        dataset = SpaceNnet8Dataset(self.fold, phase='train')
        return DataLoader(dataset, **self._cfg.train_loader)

    def val_dataloader(self):
        dataset = SpaceNnet8Dataset(self.fold, phase='val')
        return DataLoader(dataset, **self._cfg.val_loader)

## Model

In [5]:
class DiceLoss(smp.utils.base.Loss):
    """DiceLoss which supports ignore mask.
    """
    def __init__(self, eps=cfg.eps, beta=0.5, ignore_mask_channel=None, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.beta = beta
        self.ignore_mask_channel = ignore_mask_channel

    def forward(self, y_pr, y_gt):
        # y_pr, y_gt = _apply_ignore_mask(y_pr, y_gt, self.ignore_mask_channel)

        return 1 - smp.utils.functional.f_score(
            y_pr,
            y_gt,
            beta=self.beta,
            eps=self.eps,
            threshold=None,
            ignore_channels=None,
        )

class BCEDiceLoss(torch.nn.Module):
    def __init__(self, raito=0.5):
        super(BCEDiceLoss, self).__init__()
        assert 0 <= raito <= 1, "loss raito invalid."
        
        self.raito = raito
        self.bce_criterion = torch.nn.BCEWithLogitsLoss()
        self.dice_criterion = DiceLoss()
        
    def forward(self, y_pr, y_gt):
        loss = self.raito * self.bce_criterion(y_pr, y_gt) + (1 - self.raito) * self.dice_criterion(torch.sigmoid(y_pr), y_gt)
        return loss
    
class MultiBCEDiceLoss(torch.nn.Module):
    def __init__(self, raito=0.5):
        super(MultiBCEDiceLoss, self).__init__()
        assert 0 <= raito <= 1, "loss raito invalid."
        
        self.raito = raito
        self.bce_criterion = torch.nn.BCEWithLogitsLoss()
        self.dice_criterion = smp.losses.DiceLoss(mode=cfg.model.loss_mode)
        
    def forward(self, y_pr, y_gt):
        loss_bce = self.raito * self.bce_criterion(y_pr, y_gt)

        if cfg.model.loss_mode == 'multiclass':
            y_gt = y_gt.long()
            y_gt = torch.argmax(y_gt, dim=1)
        loss_dice = (1 - self.raito) * self.dice_criterion(torch.sigmoid(y_pr), y_gt)
        
        loss = loss_bce + loss_dice
        return loss

class MultiBCETverskyLoss(torch.nn.Module):
    def __init__(self, raito=0.5):
        super(MultiBCETverskyLoss, self).__init__()
        assert 0 <= raito <= 1, "loss raito invalid."
        
        self.raito = raito
        self.bce_criterion = torch.nn.BCEWithLogitsLoss()
        self.tvrsky_criterion = smp.losses.TverskyLoss(mode=cfg.model.loss_mode, log_loss=False)
        
    def forward(self, y_pr, y_gt):
        loss_bce = self.raito * self.bce_criterion(y_pr, y_gt)

        if cfg.model.loss_mode == 'multiclass':
            y_gt = y_gt.long()
            y_gt = torch.argmax(y_gt, dim=1)
        loss_dice = (1 - self.raito) * self.tvrsky_criterion(torch.sigmoid(y_pr), y_gt)
        
        loss = loss_bce + loss_dice
        return loss

In [6]:
def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
    assert alpha > 0, "alpha should be larger than 0"
    assert x.size(0) > 1, "Mixup cannot be applied to a single instance."

    lam = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(x.size()[0])
    mixed_x = lam * x + (1 - lam) * x[rand_index, :]
    target_a, target_b = y, y[rand_index]
    return mixed_x, target_a, target_b, lam


class SpaceNet8Model(LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.__build_model()
        self._criterion = eval(cfg.model.loss)
        
    def __build_model(self):
        self.backbone = smp.UnetPlusPlus(encoder_name=cfg.model.encoder_name,
                                              encoder_weights="imagenet",
                                      decoder_attention_type='scse',
                                      in_channels=cfg.model.in_channels, activation=cfg.model.act,
                                      decoder_channels=cfg.model.decoder_channels,
                                      classes=cfg.model.out_channels)

    def forward(self, x):
        feat = self.backbone(x)
        
        return feat

    def training_step(self, batch, batch_idx):
        return self.__share_step(batch, 'train')
        
    def validation_step(self, batch, batch_idx):
        return self.__share_step(batch, 'val')
    
    def __share_step(self, batch, mode):
        images, labels = batch
        labels = labels.float()
        images = images.float()

        feat = self.forward(images)
        build = feat[:, 0, :, :].unsqueeze(dim=1)
        speed = feat[:, 1:9, :, :]

        # loss
        loss_build = self._criterion(build, labels[:, 0,...].unsqueeze(dim=1))
        # loss_road = self._criterion(road, labels[:, 1,...].unsqueeze(dim=1))
        loss_speed = self._criterion(speed, labels[:, 1:9,...])
        # loss_flood = self._criterion(flood, labels[:, 10:14,...])
        # sum
        loss = loss_build + loss_speed
        
        logits = torch.sigmoid(torch.cat([build, speed], dim=1))
        preds = (logits > cfg.model.threshold).float()
        return_dict = {'loss': loss , 
                       'loss_build': loss_build, 'loss_speed': loss_speed}
        
        # metrics
        for c in range(cfg.model.out_channels):
            preds_c, labels_c = preds[:, c, :, :], labels[:, c, :, :]
            tp = (preds_c * labels_c).sum().to(torch.float32)
            tn = ((1. - preds_c) * (1. - labels_c)).sum().to(torch.float32)
            fp = (preds_c * (1. - labels_c)).sum().to(torch.float32)
            fn = ((1. - preds_c) * labels_c).sum().to(torch.float32)
            return_dict[f'TP_{c}'] = tp.unsqueeze(dim=0).detach().cpu()
            return_dict[f'TN_{c}'] = tn.unsqueeze(dim=0).detach().cpu()
            return_dict[f'FP_{c}'] = fp.unsqueeze(dim=0).detach().cpu()
            return_dict[f'FN_{c}'] = fn.unsqueeze(dim=0).detach().cpu()
            
            precision = tp / (tp + fp + cfg.eps)
            recall = tp / (tp + fn + cfg.eps)
            f1 = 2 * (precision*recall) / (precision + recall + cfg.eps)
            iou = tp / (tp + fp + fn + cfg.eps)
            
            return_dict[f'Precision_{c}'] = precision.unsqueeze(dim=0).detach().cpu()
            return_dict[f'Recall_{c}'] = recall.unsqueeze(dim=0).detach().cpu()
            return_dict[f'F1_{c}'] = f1.unsqueeze(dim=0).detach().cpu()
            return_dict[f'IoU_{c}'] = iou.unsqueeze(dim=0).detach().cpu()
            
            # logging 
            self.log(f'{mode}/iter_TP_{c}', tp)
            self.log(f'{mode}/iter_TN_{c}', tn)
            self.log(f'{mode}/iter_FP_{c}', fp)
            self.log(f'{mode}/iter_FN_{c}', fn)
            
            self.log(f'{mode}/iter_Precision_{c}', precision)
            self.log(f'{mode}/iter_Recall_{c}', recall)
            
            self.log(f'{mode}/iter_F1_{c}', f1)
            self.log(f'{mode}/iter_IoU_{c}', iou) 
            
        self.log(f'{mode}/iter_loss', loss)
        for target in ['build', 'speed',]:
            self.log(f'{mode}/iter_loss_{target}', eval(f'loss_{target}'))

        return return_dict
        
    def training_epoch_end(self, outputs):
        self.__share_epoch_end(outputs, 'train')

    def validation_epoch_end(self, outputs):
        self.__share_epoch_end(outputs, 'val')    
        
    def __share_epoch_end(self, outputs, mode):
        
        # loss
        losses = []
        for target in ['build', 'speed',]:
            exec(f'losses_{target} = []')
            
        for out in outputs:
            losses.append(out['loss'].cpu().detach().numpy())
        losses = np.mean(losses)
        self.log(f'{mode}/epoch_loss', losses)
        
        for target in ['build', 'speed',]:
            for out in outputs:
                exec(f"losses_{target}.append(out[f'loss_{target}'].cpu().detach().numpy())")
            exec(f'losses_{target} = np.mean(losses_{target})')
            exec(f'self.log(f"{mode}/epoch_loss_{target}", losses_{target})')
        
        mean_iou = 0
        mean_f1 = 0
    
        # metrics
        for c in range(cfg.model.out_channels):
            tps, tns, fps, fns, precisions, recalls, f1s, IoUs = \
                [], [], [], [], [], [], [], []
            for out in outputs:
                # assert False, (out[f'TP_{c}'], out[f'TP_{c}'].shape)
                for (tp, tn, fp, fn, precision, recall, f1, iou) in zip(out[f'TP_{c}'], 
                                         out[f'TN_{c}'],
                                         out[f'FP_{c}'],
                                         out[f'FN_{c}'],
                                         out[f'Precision_{c}'],
                                         out[f'Recall_{c}'],
                                         out[f'F1_{c}'],
                                         out[f'IoU_{c}'],):
                    
                    tps.append(tp.unsqueeze(dim=0))
                    tns.append(tn.unsqueeze(dim=0))
                    fps.append(fp.unsqueeze(dim=0))
                    fns.append(fn.unsqueeze(dim=0))
                    
                    precisions.append(precision.unsqueeze(dim=0))
                    recalls.append(recall.unsqueeze(dim=0))
                    f1s.append(f1.unsqueeze(dim=0))
                    IoUs.append(iou.unsqueeze(dim=0))
                    
            tps = torch.cat(tps, dim=0).squeeze()
            tns = torch.cat(tns, dim=0).squeeze()
            fps = torch.cat(fps, dim=0).squeeze()
            fns = torch.cat(fns, dim=0).squeeze()
            
            precisions = torch.cat(precisions, dim=0).squeeze()
            recalls = torch.cat(recalls, dim=0).squeeze()
            f1s = torch.cat(f1s, dim=0).squeeze()
            IoUs = torch.cat(IoUs, dim=0).squeeze()
            
            
            # logging 
            self.log(f'{mode}/epoch_TP_{c}', tps)
            self.log(f'{mode}/epoch_TN_{c}', tns)
            self.log(f'{mode}/epoch_FP_{c}', fps)
            self.log(f'{mode}/epoch_FN_{c}', fns)
            
            self.log(f'{mode}/epoch_Precision_{c}', precisions)
            self.log(f'{mode}/epoch_Recall_{c}', recalls)
            
            self.log(f'{mode}/epoch_F1_{c}', f1s)
            self.log(f'{mode}/epoch_IoU_{c}', IoUs)
            
            mean_iou += np.mean(IoUs.numpy()).item()
            mean_f1 += np.mean(f1s.numpy()).item()
        
        mean_iou /= cfg.model.out_channels
        mean_f1 /= cfg.model.out_channels
        self.log(f'{mode}/mean_IoU', mean_iou)
        self.log(f'{mode}/mean_F1', mean_f1)

    def configure_optimizers(self):
        optimizer = eval(self.cfg.optimizer.name)(
            self.parameters(), **self.cfg.optimizer.params
        )
        scheduler = eval(self.cfg.scheduler.name)(
            optimizer,
            **self.cfg.scheduler.params
        )
        return [optimizer], [scheduler]

## Train

In [7]:
for fold in range(cfg.folds):
    print(f'#'*60)
    print(f'### Fold: {fold}')
    print(f'#'*60)
    
    # Setting   
    cfg.fold = fold
    wandb_logger = WandbLogger(
        config=cfg,
        name=f"{cfg.runname}_fold{fold}",
        project=cfg.project,
        group=cfg.group,
        tags=[f'fold{fold}', '3090', 'notebook', 'foundation', 'build', 'speed'],
        # entity='spaceshift',
    )
    
    # Data
    datamodule = SpaceNnet8Module(fold, cfg)
    
    # Model
    model = SpaceNet8Model(cfg)
    
    # PATH
    dirpath = f'{cfg.outdir}{cfg.group}/{cfg.runname}_fold-{fold}/'
    filename = f"best_fold-{fold}"
    best_model_path = dirpath + filename + ".ckpt"


    # Logging
    lr_monitor = callbacks.LearningRateMonitor()
    loss_checkpoint = callbacks.ModelCheckpoint(
        dirpath=dirpath,
        filename=filename,
        monitor="val/mean_IoU",
        save_top_k=1,
        mode="max",
        save_last=True,
    )
    wandb.save(cfg.notebook)
    # logger = TensorBoardLogger()
    
    print(f'### Start Trainig')
    # Train
    trainer = Trainer(
        logger=wandb_logger,
        max_epochs=cfg.epoch,
        callbacks=[lr_monitor, loss_checkpoint],
        **cfg.trainer,
    )
    # 実行
    trainer.fit(model, datamodule=datamodule)
    
    # saving
    with open(f'{dirpath}cfg.json', 'w') as f:
        json.dump(cfg.to_dict(), f, indent=4)
        
    wandb.save(cfg.notebook)
    wandb.finish()
    break
      

############################################################
### Fold: 0
############################################################


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


### Start Trainig


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type             | Params
------------------------------------------------
0 | backbone   | UnetPlusPlus     | 50.5 M
1 | _criterion | MultiBCEDiceLoss | 0     
------------------------------------------------
50.5 M    Trainable params
0         Non-trainable params
50.5 M    Total params
101.057   Total estimated model params size (MB)


loaded 640 image filepaths
loaded 161 image filepaths


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='1.063 MB of 1.124 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.946009…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lr-AdamW,██▇▆▄▃▂▁██▇▆▄▃▂▁██▇▆▄▃▂▁██▇▆▄▃▂▁██▇▆▄▃▂▁
train/epoch_F1_0,▁▅▅▆▆▇▇▇▆▆▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇█▇████▇▇██████
train/epoch_F1_1,▁▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇▇▇████▇▇▇█████
train/epoch_F1_2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/epoch_F1_3,▁▂▃▃▄▄▅▅▄▄▅▅▆▆▆▆▆▅▆▆▆▆▇▇▆▆▆▇▇▇▇█▆▇▇▇████
train/epoch_F1_4,▁▁▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▅▆▆▆▆▇▇▆▆▆▇▇▇▇█▆▇▇▇▇███
train/epoch_F1_5,▁▁▁▁▂▃▄▄▄▄▅▆▆▆▆▆▅▅▅▆▆▇▇▇▆▆▇▇▇▇▇▇▇▇▇▇▇███
train/epoch_F1_6,▁▁▁▁▁▁▁▂▂▄▅▆▆▆▆▆▆▇▆▇▇▇▇▇▇▆▇▇▇▇▇▆▇▇▇▇▇▇██
train/epoch_F1_7,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,99.0
lr-AdamW,2e-05
train/epoch_F1_0,0.82772
train/epoch_F1_1,0.79131
train/epoch_F1_2,0.0
train/epoch_F1_3,0.73161
train/epoch_F1_4,0.72683
train/epoch_F1_5,0.67084
train/epoch_F1_6,0.52681
train/epoch_F1_7,0.0


In [None]:
try:
    wandb.finish()
except Exception as e:
    print(f'[Error] {e} --> OK')

VBox(children=(Label(value='1.100 MB of 1.100 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
lr-AdamW,▁
trainer/global_step,▁

0,1
lr-AdamW,0.001
trainer/global_step,0.0
