In [124]:
%reload_ext autoreload
%autoreload 2

import torch
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNET

from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs,
)

### Hyperparameters

In [125]:
LEARNING_RATE = 1e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 16
NUM_EPOCHS = 5
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240
LOAD_MODEL = True
TRAIN_IMG_DIR = 'data/train'
TRAIN_MASK_DIR = 'data/train_masks'
VAL_IMG_DIR = 'data/val'
VAL_MASK_DIR = 'data/val_masks'

In [126]:
# 1 EPOCH of Training
def train_epoch(loader, model, optimizer, loss_fn):
    loop = tqdm(loader)
    
    for batch_idx, (x, y) in enumerate(loop):
        x = x.to(device=DEVICE)
        y = y.float().unsqueeze(1).to(device=DEVICE)
        
        # forward
        out = model(x)
        loss = loss_fn(out, y)
        
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # update progress bar
        loop.set_postfix(loss = loss.item())

In [127]:
def main():
    train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=40, p=0.9, border_mode=cv2.BORDER_CONSTANT),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0,0.0,0.0],
            std=[1.0,1.0,1.0],
            max_pixel_value=255.0 # default: 255
        ),
        ToTensorV2()
    ])
    
    val_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )
        
    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss() # binary classification. 
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transform
    )
    
    if LOAD_MODEL:
        load_checkpoint(torch.load('my_checkpoint.pth.tar'), model, optimizer)
    
    for epoch in range(NUM_EPOCHS):
        train_epoch(train_loader, model, optimizer, loss_fn)
        
        # save model
        checkpoint = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)
        
        # check accuracy
        check_accuracy(train_loader, model, device=DEVICE)
        
        # print some examples to a folder
        save_predictions_as_imgs(train_loader, model, folder='saved_images/', device=DEVICE)
        
    

In [128]:
main()

=> Loading checkpoint


100%|██████████| 3/3 [02:11<00:00, 43.78s/it, loss=0.199]


=> Saving checkpoint
Got 1795913/1843200 with accuracy 97.43
Dice score:0.9310169219970703
tensor([[[[0.3719, 0.3719, 0.3373,  ..., 0.3269, 0.3415, 0.3707],
          [0.3226, 0.3163, 0.2724,  ..., 0.2780, 0.3046, 0.3518],
          [0.2871, 0.2578, 0.2180,  ..., 0.2108, 0.2455, 0.3333],
          ...,
          [0.2694, 0.2284, 0.1998,  ..., 0.2085, 0.2473, 0.3424],
          [0.2927, 0.2397, 0.2319,  ..., 0.2319, 0.2557, 0.3610],
          [0.3399, 0.2697, 0.2643,  ..., 0.2700, 0.3084, 0.4135]]],


        [[[0.3718, 0.3685, 0.3356,  ..., 0.3324, 0.3403, 0.3676],
          [0.3204, 0.3179, 0.2709,  ..., 0.2625, 0.3051, 0.3552],
          [0.2854, 0.2548, 0.2160,  ..., 0.2059, 0.2423, 0.3363],
          ...,
          [0.2721, 0.2301, 0.2022,  ..., 0.2047, 0.2432, 0.3405],
          [0.2939, 0.2459, 0.2337,  ..., 0.2243, 0.2501, 0.3588],
          [0.3388, 0.2809, 0.2826,  ..., 0.2639, 0.3036, 0.4118]]],


        [[[0.3716, 0.3723, 0.3375,  ..., 0.3208, 0.3378, 0.3693],
          [0.

  0%|          | 0/3 [00:00<?, ?it/s]

tensor([[[[0.3710, 0.3714, 0.3366,  ..., 0.3219, 0.3385, 0.3692],
          [0.3218, 0.3166, 0.2714,  ..., 0.2730, 0.2996, 0.3505],
          [0.2861, 0.2566, 0.2166,  ..., 0.2069, 0.2387, 0.3296],
          ...,
          [0.2736, 0.2378, 0.2062,  ..., 0.2107, 0.2491, 0.3445],
          [0.2933, 0.2460, 0.2370,  ..., 0.2333, 0.2567, 0.3620],
          [0.3380, 0.2737, 0.2699,  ..., 0.2705, 0.3091, 0.4141]]],


        [[[0.3717, 0.3723, 0.3375,  ..., 0.3208, 0.3378, 0.3693],
          [0.3228, 0.3174, 0.2734,  ..., 0.2725, 0.2994, 0.3503],
          [0.2875, 0.2588, 0.2192,  ..., 0.2060, 0.2377, 0.3291],
          ...,
          [0.2732, 0.2374, 0.2054,  ..., 0.2106, 0.2492, 0.3444],
          [0.2931, 0.2459, 0.2367,  ..., 0.2331, 0.2563, 0.3618],
          [0.3378, 0.2735, 0.2695,  ..., 0.2701, 0.3090, 0.4141]]],


        [[[0.3717, 0.3723, 0.3379,  ..., 0.3210, 0.3379, 0.3693],
          [0.3229, 0.3174, 0.2736,  ..., 0.2727, 0.2994, 0.3503],
          [0.2877, 0.2589, 0.2196,  ..

100%|██████████| 3/3 [02:23<00:00, 47.81s/it, loss=0.19] 


=> Saving checkpoint
Got 1800439/1843200 with accuracy 97.68
Dice score:0.9365747570991516
tensor([[[[0.3587, 0.3551, 0.3200,  ..., 0.3092, 0.3262, 0.3635],
          [0.3058, 0.2950, 0.2499,  ..., 0.2509, 0.2797, 0.3385],
          [0.2718, 0.2318, 0.1963,  ..., 0.1876, 0.2198, 0.3151],
          ...,
          [0.2624, 0.2163, 0.1866,  ..., 0.1913, 0.2279, 0.3303],
          [0.2831, 0.2261, 0.2170,  ..., 0.2152, 0.2372, 0.3457],
          [0.3314, 0.2570, 0.2539,  ..., 0.2524, 0.2891, 0.3982]]],


        [[[0.3599, 0.3557, 0.3200,  ..., 0.3128, 0.3299, 0.3655],
          [0.3063, 0.2944, 0.2507,  ..., 0.2533, 0.2844, 0.3404],
          [0.2721, 0.2323, 0.1968,  ..., 0.1909, 0.2264, 0.3201],
          ...,
          [0.2573, 0.2044, 0.1795,  ..., 0.1880, 0.2255, 0.3280],
          [0.2819, 0.2208, 0.2106,  ..., 0.2136, 0.2361, 0.3447],
          [0.3333, 0.2523, 0.2497,  ..., 0.2525, 0.2886, 0.3974]]],


        [[[0.3592, 0.3558, 0.3208,  ..., 0.3090, 0.3255, 0.3639],
          [0.

image saved


  0%|          | 0/3 [00:00<?, ?it/s]

tensor([[[[0.3589, 0.3554, 0.3204,  ..., 0.3091, 0.3257, 0.3635],
          [0.3061, 0.2952, 0.2505,  ..., 0.2504, 0.2795, 0.3384],
          [0.2722, 0.2323, 0.1970,  ..., 0.1872, 0.2193, 0.3151],
          ...,
          [0.2620, 0.2157, 0.1859,  ..., 0.1909, 0.2279, 0.3303],
          [0.2830, 0.2255, 0.2165,  ..., 0.2150, 0.2368, 0.3456],
          [0.3312, 0.2566, 0.2538,  ..., 0.2520, 0.2888, 0.3981]]],


        [[[0.3596, 0.3562, 0.3211,  ..., 0.3082, 0.3255, 0.3637],
          [0.3069, 0.2958, 0.2517,  ..., 0.2502, 0.2794, 0.3382],
          [0.2733, 0.2338, 0.1989,  ..., 0.1866, 0.2187, 0.3144],
          ...,
          [0.2619, 0.2161, 0.1858,  ..., 0.1907, 0.2279, 0.3299],
          [0.2830, 0.2259, 0.2167,  ..., 0.2145, 0.2365, 0.3453],
          [0.3312, 0.2569, 0.2534,  ..., 0.2516, 0.2885, 0.3980]]],


        [[[0.3586, 0.3551, 0.3200,  ..., 0.3092, 0.3262, 0.3635],
          [0.3058, 0.2949, 0.2499,  ..., 0.2510, 0.2797, 0.3385],
          [0.2718, 0.2317, 0.1963,  ..

  0%|          | 0/3 [00:14<?, ?it/s]


KeyboardInterrupt: 