In [1]:
import argparse
import logging
import os
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from evaluate import evaluate

from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss
from unet import UNet
from unet import MultiResUnet
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')

# model = MultiResUnet(input_channels=1, num_classes=2)
model = UNet(n_channels=1, n_classes=2)
model = model.to(memory_format=torch.channels_last)

logging.info(f'Network:\n'
             f'\t{model.n_channels} input channels\n'
             f'\t{model.n_classes} output channels (classes)\n')
model = model.to(memory_format=torch.channels_last)

model = model.to(device=device)

INFO: Using device cuda
INFO: Network:
	1 input channels
	2 output channels (classes)



In [9]:
epochs: int = 15
batch_size: int = 1
learning_rate: float = 1e-3
val_percent: float = 0.1
save_checkpoint: bool = True
img_scale: float = 0.25
amp: bool = False
weight_decay: float = 1e-8
momentum: float = 0.999
gradient_clipping: float = 1.0

dir_img = Path('./data/imgs/')
dir_mask = Path('./data/new_AC3_masks/')
dir_checkpoint = Path('./checkpoints/')

# 1. Create dataset
# try:
#     dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
# except (AssertionError, RuntimeError, IndexError):
#     dataset = BasicDataset(dir_img, dir_mask, img_scale)

dataset = BasicDataset(dir_img, dir_mask, img_scale)

# 2. Split into train / validation partitions
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

# 3. Create data loaders
loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)


logging.info(f'''Starting training:
    Epochs:          {epochs}
    Batch size:      {batch_size}
    Learning rate:   {learning_rate}
    Training size:   {n_train}
    Validation size: {n_val}
    Checkpoints:     {save_checkpoint}
    Device:          {device.type}
    Images scaling:  {img_scale}
    Mixed Precision: {amp}
''')



INFO: Creating dataset with 1264 examples
INFO: Scanning mask files to determine unique values
100%|██████████████████████████████████████████████████████████████████████████████| 1264/1264 [00:07<00:00, 169.64it/s]
INFO: Unique mask values: [0, 255]
INFO: Starting training:
    Epochs:          15
    Batch size:      1
    Learning rate:   0.001
    Training size:   1138
    Validation size: 126
    Checkpoints:     True
    Device:          cuda
    Images scaling:  0.25
    Mixed Precision: False



In [4]:
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
opt = optim.Adam(model.parameters(),
                          lr=learning_rate)
criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
global_step = 0

In [13]:
step=0
train_losses = []
val_losses = []
for epoch in range(epochs):
    print("Epoch",epoch+1,"/",epochs)
    model.train()
    for batch in tqdm(train_loader):
        frame, mask = batch['image'], batch['mask']
        frame=frame.to(device=device,dtype=torch.float32)
        mask=mask.to(device=device,dtype=torch.long)
        pred=model(frame)
        loss=criterion(pred, mask)
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_losses.append(loss.item())
        step+=1
        # update()
    
    model.eval()
    val_loss_avg=0
    n_val=0
    with torch.no_grad():
        for batch in tqdm(val_loader):
            frame, mask = batch['image'], batch['mask']
            frame=frame.to(device=device,dtype=torch.float32)
            mask=mask.to(device=device,dtype=torch.long)
            pred=model(frame)
            loss=criterion(pred,mask)
            val_loss_avg+=loss.item()
            n_val+=1
    val_loss_avg=val_loss_avg/n_val
    print("Train Loss - ", 
    # t_vals.append(step)
    val_losses.append(val_loss_avg)
    torch.save(model.state_dict(),os.path.join("checkpoints","unet_{epoch:d}_{val_loss:.2e}.pth".format(epoch=epoch+1,val_loss=val_loss_avg)))
    # update()
    print()

Epoch 1 / 15


100%|███████████████████████████████████████████████████████████████████████████████| 1138/1138 [04:58<00:00,  3.81it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 126/126 [00:06<00:00, 18.61it/s]



Epoch 2 / 15


100%|███████████████████████████████████████████████████████████████████████████████| 1138/1138 [04:57<00:00,  3.82it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 126/126 [00:06<00:00, 18.40it/s]



Epoch 3 / 15


100%|███████████████████████████████████████████████████████████████████████████████| 1138/1138 [04:59<00:00,  3.80it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 126/126 [00:06<00:00, 18.45it/s]



Epoch 4 / 15


100%|███████████████████████████████████████████████████████████████████████████████| 1138/1138 [04:58<00:00,  3.82it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 126/126 [00:06<00:00, 19.10it/s]



Epoch 5 / 15


100%|███████████████████████████████████████████████████████████████████████████████| 1138/1138 [04:57<00:00,  3.83it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 126/126 [00:06<00:00, 18.58it/s]



Epoch 6 / 15


100%|███████████████████████████████████████████████████████████████████████████████| 1138/1138 [04:58<00:00,  3.82it/s]
100%|█████████████████████████████████████████████████████████████████████████████████| 126/126 [00:06<00:00, 18.47it/s]



Epoch 7 / 15


 54%|███████████████████████████████████████████                                     | 612/1138 [02:41<02:18,  3.79it/s]


KeyboardInterrupt: 

In [18]:
# print(val_losses)

[0.2138034257269095, 0.19865679619685997, 0.20998530548125033, 0.19740086711115307, 0.2258514946119653, 0.19894285405439044]


In [11]:
train_loss = []

for epoch in range(1, epochs + 1):
    model.train()
    epoch_loss = 0
    print(model.n_classes)
    with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            images, true_masks = batch['image'], batch['mask']

            assert images.shape[1] == model.n_channels, \
                f'Network has been defined with {model.n_channels} input channels, ' \
                f'but loaded images have {images.shape[1]} channels. Please check that ' \
                'the images are loaded correctly.'

            images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            true_masks = true_masks.to(device=device, dtype=torch.long)

            with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                masks_pred = model(images)
                # print(masks_pred.shape,true_masks.shape)
                if model.n_classes == 1:
                    loss = criterion(masks_pred.squeeze(1), true_masks.float())
                    loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                else:
                    loss = criterion(masks_pred, true_masks)
                    # loss += dice_loss(
                    #     F.softmax(masks_pred, dim=1).float(),
                    #     F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                    #     multiclass=True
                    # )

            optimizer.zero_grad(set_to_none=True)
            grad_scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
            grad_scaler.step(optimizer)
            grad_scaler.update()

            pbar.update(images.shape[0])
            global_step += 1
            epoch_loss += loss.item()
            pbar.set_postfix(**{'loss (batch)': loss.item()})

            # Evaluation round
            division_step = (n_train // (5 * batch_size))
            if division_step > 0:
                if global_step % division_step == 0:
                    
                    val_score = evaluate(model, val_loader, device, amp)
                    scheduler.step(val_score)

                    logging.info('Validation Dice score: {}'.format(val_score))
                    
    train_loss.append(epoch_loss)
    if save_checkpoint:
        Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
        state_dict = model.state_dict()
        state_dict['mask_values'] = dataset.mask_values
        torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
        logging.info(f'Checkpoint {epoch} saved!')

2


Epoch 1/15:  20%|█████████▌                                      | 227/1138 [01:01<03:57,  3.84img/s, loss (batch)=0.41]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:49,  2.53batch/s][A
Validation round:   3%|█▉                                                            | 4/126 [00:00<00:13,  9.09batch/s][A
Validation round:   6%|███▍                                                          | 7/126 [00:00<00:08, 13.36batch/s][A
Validation round:   8%|████▊                                                        | 10/126 [00:00<00:07, 16.17batch/s][A
Validation round:  10%|██████▎                                                      | 13/126 [00:00<00:06, 18.18batch/s][A
Validation round:  13%|███████▋                                                     | 16/126 [00:01<00:05, 19.27batch/s][A
Validation 

2


Epoch 2/15:  20%|█████████▎                                     | 224/1138 [01:00<04:09,  3.66img/s, loss (batch)=0.294]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<01:06,  1.87batch/s][A
Validation round:   3%|█▉                                                            | 4/126 [00:00<00:17,  7.09batch/s][A
Validation round:   5%|██▉                                                           | 6/126 [00:00<00:12,  9.70batch/s][A
Validation round:   6%|███▉                                                          | 8/126 [00:00<00:09, 11.88batch/s][A
Validation round:   8%|████▊                                                        | 10/126 [00:01<00:08, 13.70batch/s][A
Validation round:  10%|█████▊                                                       | 12/126 [00:01<00:07, 15.06batch/s][A
Validation 

2


Epoch 3/15:  19%|█████████▏                                     | 221/1138 [00:59<04:12,  3.63img/s, loss (batch)=0.241]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:58,  2.15batch/s][A
Validation round:   2%|█▍                                                            | 3/126 [00:00<00:19,  6.32batch/s][A
Validation round:   5%|██▉                                                           | 6/126 [00:00<00:11, 10.86batch/s][A
Validation round:   6%|███▉                                                          | 8/126 [00:00<00:09, 12.89batch/s][A
Validation round:   8%|████▊                                                        | 10/126 [00:00<00:08, 14.49batch/s][A
Validation round:  10%|█████▊                                                       | 12/126 [00:01<00:07, 15.77batch/s][A
Validation 

2


Epoch 4/15:  19%|█████████                                      | 218/1138 [00:58<03:58,  3.85img/s, loss (batch)=0.202]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:55,  2.24batch/s][A
Validation round:   2%|█▍                                                            | 3/126 [00:00<00:18,  6.53batch/s][A
Validation round:   4%|██▍                                                           | 5/126 [00:00<00:12,  9.83batch/s][A
Validation round:   6%|███▍                                                          | 7/126 [00:00<00:09, 12.37batch/s][A
Validation round:   8%|████▊                                                        | 10/126 [00:00<00:07, 15.23batch/s][A
Validation round:  10%|█████▊                                                       | 12/126 [00:01<00:07, 16.19batch/s][A
Validation 

2


Epoch 5/15:  19%|█████████                                       | 215/1138 [00:58<04:05,  3.76img/s, loss (batch)=0.19]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:47,  2.63batch/s][A
Validation round:   2%|█▍                                                            | 3/126 [00:00<00:17,  7.22batch/s][A
Validation round:   4%|██▍                                                           | 5/126 [00:00<00:11, 10.57batch/s][A
Validation round:   6%|███▍                                                          | 7/126 [00:00<00:09, 12.99batch/s][A
Validation round:   7%|████▍                                                         | 9/126 [00:00<00:07, 14.71batch/s][A
Validation round:  10%|█████▊                                                       | 12/126 [00:00<00:06, 16.58batch/s][A
Validation 

2


Epoch 6/15:  19%|████████▊                                      | 212/1138 [00:57<04:07,  3.75img/s, loss (batch)=0.197]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:59,  2.12batch/s][A
Validation round:   2%|█▍                                                            | 3/126 [00:00<00:19,  6.23batch/s][A
Validation round:   4%|██▍                                                           | 5/126 [00:00<00:12,  9.57batch/s][A
Validation round:   6%|███▍                                                          | 7/126 [00:00<00:09, 12.20batch/s][A
Validation round:   8%|████▊                                                        | 10/126 [00:00<00:07, 15.14batch/s][A
Validation round:  10%|██████▎                                                      | 13/126 [00:01<00:06, 16.93batch/s][A
Validation 

2


Epoch 7/15:  18%|████████▋                                      | 209/1138 [00:56<04:05,  3.79img/s, loss (batch)=0.232]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:52,  2.37batch/s][A
Validation round:   3%|█▉                                                            | 4/126 [00:00<00:13,  8.76batch/s][A
Validation round:   6%|███▍                                                          | 7/126 [00:00<00:09, 12.83batch/s][A
Validation round:   8%|████▊                                                        | 10/126 [00:00<00:07, 15.89batch/s][A
Validation round:  10%|██████▎                                                      | 13/126 [00:00<00:06, 17.94batch/s][A
Validation round:  13%|███████▋                                                     | 16/126 [00:01<00:05, 19.47batch/s][A
Validation 

2


Epoch 8/15:  18%|████████▌                                      | 206/1138 [00:55<04:04,  3.81img/s, loss (batch)=0.218]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:46,  2.68batch/s][A
Validation round:   3%|█▉                                                            | 4/126 [00:00<00:13,  9.11batch/s][A
Validation round:   6%|███▍                                                          | 7/126 [00:00<00:09, 13.13batch/s][A
Validation round:   8%|████▊                                                        | 10/126 [00:00<00:07, 15.72batch/s][A
Validation round:  10%|██████▎                                                      | 13/126 [00:00<00:06, 17.26batch/s][A
Validation round:  13%|███████▋                                                     | 16/126 [00:01<00:06, 18.27batch/s][A
Validation 

2


Epoch 9/15:  18%|████████▍                                      | 203/1138 [00:54<04:16,  3.64img/s, loss (batch)=0.189]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:45,  2.73batch/s][A
Validation round:   3%|█▉                                                            | 4/126 [00:00<00:13,  9.09batch/s][A
Validation round:   5%|██▉                                                           | 6/126 [00:00<00:10, 11.75batch/s][A
Validation round:   6%|███▉                                                          | 8/126 [00:00<00:08, 13.93batch/s][A
Validation round:   8%|████▊                                                        | 10/126 [00:00<00:07, 15.32batch/s][A
Validation round:  10%|█████▊                                                       | 12/126 [00:00<00:07, 16.28batch/s][A
Validation 

2


Epoch 10/15:  18%|████████                                      | 200/1138 [00:54<04:13,  3.69img/s, loss (batch)=0.205]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:47,  2.61batch/s][A
Validation round:   3%|█▉                                                            | 4/126 [00:00<00:13,  9.00batch/s][A
Validation round:   5%|██▉                                                           | 6/126 [00:00<00:10, 11.77batch/s][A
Validation round:   6%|███▉                                                          | 8/126 [00:00<00:08, 13.90batch/s][A
Validation round:   9%|█████▎                                                       | 11/126 [00:00<00:07, 16.41batch/s][A
Validation round:  10%|██████▎                                                      | 13/126 [00:00<00:06, 17.30batch/s][A
Validation 

2


Epoch 11/15:  17%|███████▉                                      | 197/1138 [00:53<04:06,  3.81img/s, loss (batch)=0.185]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:46,  2.70batch/s][A
Validation round:   3%|█▉                                                            | 4/126 [00:00<00:13,  9.25batch/s][A
Validation round:   6%|███▍                                                          | 7/126 [00:00<00:08, 13.33batch/s][A
Validation round:   8%|████▊                                                        | 10/126 [00:00<00:07, 16.14batch/s][A
Validation round:  10%|██████▎                                                      | 13/126 [00:00<00:06, 18.11batch/s][A
Validation round:  13%|███████▋                                                     | 16/126 [00:01<00:05, 19.23batch/s][A
Validation 

2


Epoch 12/15:  17%|███████▊                                      | 194/1138 [00:52<04:12,  3.74img/s, loss (batch)=0.195]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:48,  2.59batch/s][A
Validation round:   3%|█▉                                                            | 4/126 [00:00<00:13,  9.16batch/s][A
Validation round:   5%|██▉                                                           | 6/126 [00:00<00:10, 11.77batch/s][A
Validation round:   6%|███▉                                                          | 8/126 [00:00<00:08, 13.89batch/s][A
Validation round:   9%|█████▎                                                       | 11/126 [00:00<00:07, 16.14batch/s][A
Validation round:  10%|██████▎                                                      | 13/126 [00:00<00:06, 16.83batch/s][A
Validation 

2


Epoch 13/15:  17%|███████▋                                      | 191/1138 [00:51<04:10,  3.77img/s, loss (batch)=0.225]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:43,  2.88batch/s][A
Validation round:   3%|█▉                                                            | 4/126 [00:00<00:12,  9.58batch/s][A
Validation round:   6%|███▍                                                          | 7/126 [00:00<00:08, 13.34batch/s][A
Validation round:   8%|████▊                                                        | 10/126 [00:00<00:07, 16.19batch/s][A
Validation round:  10%|█████▊                                                       | 12/126 [00:00<00:06, 17.13batch/s][A
Validation round:  11%|██████▊                                                      | 14/126 [00:00<00:06, 17.85batch/s][A
Validation 

2


Epoch 14/15:  17%|███████▌                                      | 188/1138 [00:50<04:24,  3.59img/s, loss (batch)=0.175]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:46,  2.69batch/s][A
Validation round:   2%|█▍                                                            | 3/126 [00:00<00:16,  7.40batch/s][A
Validation round:   4%|██▍                                                           | 5/126 [00:00<00:11, 10.94batch/s][A
Validation round:   6%|███▍                                                          | 7/126 [00:00<00:08, 13.40batch/s][A
Validation round:   7%|████▍                                                         | 9/126 [00:00<00:07, 15.14batch/s][A
Validation round:   9%|█████▎                                                       | 11/126 [00:00<00:07, 16.43batch/s][A
Validation 

2


Epoch 15/15:  16%|███████▊                                        | 185/1138 [00:50<04:10,  3.80img/s, loss (batch)=0.2]
Validation round:   0%|                                                                      | 0/126 [00:00<?, ?batch/s][A
Validation round:   1%|▍                                                             | 1/126 [00:00<00:41,  3.03batch/s][A
Validation round:   2%|█▍                                                            | 3/126 [00:00<00:15,  8.15batch/s][A
Validation round:   5%|██▉                                                           | 6/126 [00:00<00:09, 13.02batch/s][A
Validation round:   7%|████▍                                                         | 9/126 [00:00<00:07, 15.57batch/s][A
Validation round:   9%|█████▎                                                       | 11/126 [00:00<00:06, 16.65batch/s][A
Validation round:  11%|██████▊                                                      | 14/126 [00:00<00:06, 17.98batch/s][A
Validation 