In [1]:
import argparse
import logging
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from utils.data_loading import BasicDataset
from utils.dice_score import dice_loss
from evaluate import evaluate
from unet import UNet
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses.dice import DiceLoss
from segmentation_models_pytorch.losses.focal import FocalLoss
from segmentation_models_pytorch.losses.tversky import TverskyLoss
from segmentation_models_pytorch.losses.jaccard import JaccardLoss
from segmentation_models_pytorch.losses.lovasz import LovaszLoss

In [2]:
dir_img = Path('.\\Dataset\\Tier1\\Post\\Image512\\')
dir_mask = Path('.\\Dataset\\Tier1\\Post\\Label512\\')
dir_val_img = Path('.\\Dataset\\Validation\\Post\\Image512\\')
dir_val_mask = Path('.\\Dataset\\Validation\\Post\\Label512\\')
dir_checkpoint = Path('./checkpoints/MobileNet_V2/')

In [3]:
closs = nn.CrossEntropyLoss()
dloss = DiceLoss(mode = 'multiclass',
                 log_loss = False,
                 from_logits = True,
                 smooth = 0,
                 eps = 1e-7)
floss = FocalLoss(mode = 'multiclass',
                alpha = None,
                gamma = 1.0,
                ignore_index = None,
                reduction = "mean",
                normalized = False,
                reduced_threshold = None)
tloss = TverskyLoss(mode = 'multiclass',
        from_logits = True,
        alpha = 0.6,
        beta = 0.4,
        gamma = 1.0)
jloss = JaccardLoss(mode = 'multiclass',
        from_logits = True,
        eps = 1e-7)
lloss = LovaszLoss(mode = 'multiclass',
        per_image = False,
        from_logits = True)

In [4]:
def train_net(net,
              device,
              start_epoch: int = 1,
              epochs: int = 5,
              batch_size: int = 1,
              learning_rate: float = 1e-5,
              val_percent: float = 0.1,
              save_checkpoint: bool = True,
              img_scale: float = 0.5,
              amp: bool = False):
    # 1. Create dataset
    try:
        dataset = BasicDataset(dir_img, dir_mask, img_scale, values =  [1, False, False, 0, None, 0, 0], probabilities = [0,0,0,0,0,0,0],increase = 0,mask_suffix = '.png')
        valdataset = BasicDataset(dir_val_img, dir_val_mask, img_scale, mask_suffix = '.png')
    except (AssertionError, RuntimeError):
        print('error')

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=1, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(valdataset, shuffle=True,batch_size=1, num_workers=1, pin_memory=True)

    # (Initialize logging)
    #experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    #experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
    #                              val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
    #                              amp=amp))

    #logging.info(f'''Starting training:
    #    Epochs:          {epochs}
    #    Batch size:      {batch_size}
    #    Learning rate:   {learning_rate}
    #    Training size:   {n_train}
    #    Validation size: {n_val}
    #    Checkpoints:     {save_checkpoint}
    #    Device:          {device}
    #    Images scaling:  {img_scale}
    #    Mixed Precision: {amp}
    #''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion =closs
    global_step = 0

    # 5. Begin training
    for epoch in range(start_epoch, start_epoch + epochs):
        print(epoch)
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images = batch['image']
                true_masks = batch['mask']
#                assert images.shape[1] == net.n_channels, \
#                    f'Network has been defined with {net.n_channels} input channels, ' \
#                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
#                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)
                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)

#                    loss = criterion(masks_pred, true_masks) \
#                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
#                                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
#                                       multiclass=True)
                    loss = criterion(masks_pred, true_masks)
                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                #experiment.log({
                #    'train loss': loss.item(),
                #    'step': global_step,
                #    'epoch': epoch
                #})
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                
#                # Evaluation round
#                division_step = (n_train // (10 * batch_size))
#                if division_step > 0:
#                    if global_step % division_step == 0:
#                        histograms = {}
#                        for tag, value in net.named_parameters():
#                            tag = tag.replace('/', '.')
#                            histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
#                            histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
#
#                        val_score = evaluate(net, val_loader, device)
#                        scheduler.step(val_score)
#
#                        logging.info('Validation Dice score: {}'.format(val_score))
#                        experiment.log({
#                            'learning rate': optimizer.param_groups[0]['lr'],
#                            'validation Dice': val_score,
#                            'images': wandb.Image(images[0].cpu()),
#                            'masks': {
#                                'true': wandb.Image(true_masks[0].float().cpu()),
#                                'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
#                            },
#                            'step': global_step,
#                            'epoch': epoch,
#                            **histograms
#                        })
#
        val_score = evaluate(net,val_loader, device)
        print(val_score)
        print(epoch_loss/n_train)
        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')



In [5]:
classes = 5
bilinear = False
loadstate = False
load = './checkpoints/Base-UNET-Focal1LossFlip/checkpoint_epoch20.pth'
start_epoch = 1
epochs = 10
batch_size = 1
lr = 1e-6
scale = 1
val = 50
amp = False
save_checkpoint = True
if __name__ == '__main__':

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = 'cuda'
    #logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    #net = UNet(n_channels=3, n_classes = classes, bilinear=bilinear)
    net = smp.Unet(
        encoder_name='mobilenet_v2',
        encoder_depth=5,
        encoder_weights='imagenet',
        decoder_use_batchnorm=False,
        decoder_channels=(1024,512,256, 128, 64),
        decoder_attention_type=None,
        in_channels=3,
        classes=5,
        activation=None,
        aux_params=None
    )
    
    #logging.info(f'Network:\n'
    #             f'\t{net.n_channels} input channels\n'
    #             f'\t{net.n_classes} output channels (classes)\n'
    #             f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if loadstate:
        net.load_state_dict(torch.load(load, map_location=device))
        logging.info(f'Model loaded from {load}')

    net.to(device=device)
    train_net(net=net,
                  start_epoch = start_epoch,
                  epochs=epochs,
                  batch_size=batch_size,
                  learning_rate=lr,
                  device=device,
                  img_scale=scale,
                  val_percent=val / 100,
                  amp=amp,
                  save_checkpoint = save_checkpoint)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to C:\Users\thanh/.cache\torch\hub\checkpoints\mobilenet_v2-b0353104.pth


  0%|          | 0.00/13.6M [00:00<?, ?B/s]

1


Epoch 1/10: 100%|█████████████████████████████████████████████| 1400/1400 [04:22<00:00,  5.34img/s, loss (batch)=0.659]
                                                                                                                       

tensor(0.5867, device='cuda:0')
0.2600370811059007
2


Epoch 2/10: 100%|████████████████████████████████████████████| 1400/1400 [04:05<00:00,  5.71img/s, loss (batch)=0.0634]
                                                                                                                       

tensor(0.5932, device='cuda:0')
0.18701699178324946
3


Epoch 3/10: 100%|████████████████████████████████████████████| 1400/1400 [04:09<00:00,  5.62img/s, loss (batch)=0.0674]
                                                                                                                       

tensor(0.6081, device='cuda:0')
0.1532199166743833
4


Epoch 4/10: 100%|████████████████████████████████████████████| 1400/1400 [04:04<00:00,  5.72img/s, loss (batch)=0.0306]
                                                                                                                       

tensor(0.4958, device='cuda:0')
0.1381645401022979
5


Epoch 5/10: 100%|███████████████████████████████████████████| 1400/1400 [04:05<00:00,  5.71img/s, loss (batch)=0.00213]
                                                                                                                       

tensor(0.4267, device='cuda:0')
0.12430743845325196
6


Epoch 6/10: 100%|███████████████████████████████████████████| 1400/1400 [04:04<00:00,  5.72img/s, loss (batch)=0.00797]
                                                                                                                       

tensor(0.4283, device='cuda:0')
0.11372364565471539
7


Epoch 7/10: 100%|████████████████████████████████████████████| 1400/1400 [04:04<00:00,  5.73img/s, loss (batch)=0.0113]
                                                                                                                       

tensor(0.5087, device='cuda:0')
0.1040942569213647
8


Epoch 8/10: 100%|█████████████████████████████████████████████| 1400/1400 [04:03<00:00,  5.75img/s, loss (batch)=0.377]
                                                                                                                       

tensor(0.4108, device='cuda:0')
0.09513723054521377
9


Epoch 9/10: 100%|████████████████████████████████████████████| 1400/1400 [04:05<00:00,  5.70img/s, loss (batch)=0.0211]
                                                                                                                       

tensor(0.4214, device='cuda:0')
0.08628506282238049
10


Epoch 10/10: 100%|████████████████████████████████████████████| 1400/1400 [04:06<00:00,  5.69img/s, loss (batch)=0.173]
                                                                                                                       

tensor(0.4169, device='cuda:0')
0.08063722700580456
