In [None]:
import os
import glob
from skimage import io

import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import matplotlib.pyplot as plt
import numpy as np
from skimage.color import label2rgb
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNET
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    show_images,
)

In [None]:
# ---------------------------------------------------------------------------
# Hyperparameters and run configuration
# ---------------------------------------------------------------------------

# Optimisation settings
LEARNING_RATE = 1e-3
BATCH_SIZE = 2
NUM_EPOCHS = 200

# Compute device: use CUDA when PyTorch can see a GPU, otherwise the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Data-loader settings (forwarded to get_loaders in main())
NUM_WORKERS = 2
PIN_MEMORY = True

# Crop size, in pixels, used by the augmentation pipelines and the preview plot
IMAGE_HEIGHT = 1024
IMAGE_WIDTH = 1024

# Dataset locations
TRAIN_IMG_DIR = "data/train_images/"
TRAIN_MASK_DIR = "data/train_masks/"
VAL_IMG_DIR = "data/val_images/"
VAL_MASK_DIR = "data/val_masks/"

# Run control
LOAD_MODEL = False        # resume from "my_checkpoint.pth.tar" instead of training from scratch
SHOW_VAL_INTERVAL = 100   # show validation results every N epochs
DEBUG = False             # when True, train_fn displays per-class probability maps for each batch

In [None]:
# Check GPU access
print(torch.version.cuda)
if torch.cuda.is_available():
    print("GPU available")
    print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory}")
else:
    print("No GPU")

!nvidia-smi

In [None]:
# Collect the base names of all training images (TIFF files).
# os.path.join builds a separator-correct pattern: the previous "\*.tif"
# suffix only matched on Windows, where "\" is a path separator; elsewhere
# the backslash is a literal character and the list came back empty.
# Sorting makes the chosen example deterministic across runs.
all_images = sorted(
    os.path.basename(path)
    for path in glob.glob(os.path.join(TRAIN_IMG_DIR, "*.tif"))
)
if not all_images:
    raise FileNotFoundError(f"No .tif images found in {TRAIN_IMG_DIR!r}")

# Load image example
img_filename = os.path.join(TRAIN_IMG_DIR, all_images[0])
image_example = io.imread(img_filename)

# Load mask example (the mask is looked up under the same file name as the image)
mask_filename = os.path.join(TRAIN_MASK_DIR, all_images[0])
mask_example = io.imread(mask_filename)

# define RGB colors from Tableau 10 (default in matplotlib)
t10_blue = [31/255, 119/255, 180/255]
t10_orange = [255/255, 127/255, 14/255]
t10_green = [44/255, 160/255, 44/255]
t10_red = [214/255, 39/255, 40/255]

# list of colors
colors = [t10_blue, t10_orange, t10_green, t10_red]

# Restrict the preview to the top-left IMAGE_HEIGHT x IMAGE_WIDTH window
image_crop = image_example[0:IMAGE_HEIGHT, 0:IMAGE_WIDTH]
mask_crop = mask_example[0:IMAGE_HEIGHT, 0:IMAGE_WIDTH]

# Overlay the label colors on the image; label 0 is treated as background
rgb_labels = label2rgb(
    label=mask_crop,
    image=image_crop,
    colors=colors,
    alpha=0.7,
    bg_label=0,
    bg_color=None,
)

# plot pair example and overlay
fig, ax = plt.subplots(1, 3, figsize=(18, 6))
fig.suptitle("Training data example\n" + all_images[0])
ax[0].imshow(image_crop, cmap='gray')
ax[0].set_title("Image")
ax[1].imshow(mask_crop)
ax[1].set_title("Mask")
ax[2].imshow(rgb_labels)
ax[2].set_title("Overlay")
plt.tight_layout()  # was referenced without parentheses, so it never ran
plt.show()


In [None]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        loader: iterable yielding ``(data, targets)`` batches.
        model: network invoked as ``model(data)``.
        optimizer: optimizer whose step is driven through ``scaler``.
        loss_fn: callable evaluated as ``loss_fn(predictions, targets)``.
        scaler: gradient scaler exposing ``scale``, ``step`` and ``update``.

    Returns:
        float: arithmetic mean of the loss values of all batches.

    Raises:
        ValueError: if ``loader`` yields no batches (previously this surfaced
            as a bare ZeroDivisionError).
    """
    loop = tqdm(loader)
    train_loss_all = []

    for data, targets in loop:
        data = data.to(device=DEVICE)
        # Targets are cast to int64 and moved to the training device in one step.
        targets = targets.to(device=DEVICE, dtype=torch.long)

        # forward (mixed precision)
        with torch.cuda.amp.autocast():
            predictions = model(data)
            train_loss = loss_fn(predictions, targets)
            if DEBUG:
                print("DEBUG!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
                # Classes are indexed on dim 1 below (argmax and smax[:, k]),
                # so normalise over that dimension explicitly instead of
                # relying on the deprecated implicit choice of dim=None.
                m = nn.Softmax(dim=1)
                smax = m(predictions[0:1])
                argmax = torch.argmax(smax, dim=1)
                show_images(data[0:1], targets[0:1], argmax, smax[:,0], smax[:,1], smax[:,2], smax[:,3], smax[:,4], smax[:,5], titles=["Image", "Target", "Prediction", "Background Probability", "Myelin Probability", "Tongue Probability", "AxonM Probability", "AxonNM Probability", "Mitochondria Probability"],n_cols=3)

        # backward
        optimizer.zero_grad()
        scaler.scale(train_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop
        batch_loss = train_loss.item()
        loop.set_postfix(train_loss=batch_loss)
        train_loss_all.append(batch_loss)

    if not train_loss_all:
        raise ValueError("train_fn received an empty loader: no batches to train on")

    return sum(train_loss_all)/len(train_loss_all)

def transforms_fn():
    """Build the albumentations pipelines for training and validation.

    Both pipelines take a random IMAGE_HEIGHT x IMAGE_WIDTH crop, apply
    ``A.Normalize`` with mean 0, std 1 and ``max_pixel_value=200000.0``, and
    finish with ``ToTensorV2``. The training pipeline additionally applies
    horizontal and vertical flips, each with probability 0.5.

    Returns:
        tuple: ``(train_transform, val_transforms)`` as ``A.Compose`` objects.
    """

    def _random_crop():
        # Fresh instance per pipeline so the two Compose objects share no transform
        return A.RandomCrop(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)

    def _normalize_to_tensor():
        # Common tail of both pipelines: intensity normalisation, then tensor conversion
        return [
            A.Normalize(mean=[0.0], std=[1.0], max_pixel_value=200000.0),
            ToTensorV2(),
        ]

    # Augmentations tried earlier and currently disabled: RandomScale, Rotate,
    # and a OneOf over GaussianBlur / MedianBlur / GaussNoise / Defocus.
    train_steps = [_random_crop(), A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5)]
    train_steps.extend(_normalize_to_tensor())
    train_transform = A.Compose(train_steps)

    val_steps = [_random_crop()]
    val_steps.extend(_normalize_to_tensor())
    val_transforms = A.Compose(val_steps)

    return train_transform, val_transforms

def loss_plot_fn(train_loss, val_loss):
    """Plot training and validation loss per epoch and save the figure.

    The plot is drawn on a dedicated figure, written to ``loss_plot.png``
    (300 dpi, tight bounding box, white background) and then shown.

    Args:
        train_loss: sequence of per-epoch training losses.
        val_loss: sequence of per-epoch validation losses.
    """
    # plot train and val loss
    print("\n----------------------------------------------------------------------------")
    print("\nTraining and validation loss")

    # Use a fresh figure rather than whatever figure happens to be current.
    fig_loss, ax = plt.subplots()
    fig_loss.set_facecolor('white')

    ax.plot(list(range(1, len(train_loss) + 1)), train_loss, label="Training loss")
    ax.plot(list(range(1, len(val_loss) + 1)), val_loss, label="Validation loss")
    ax.set_title('Training and validation loss vs epoch number (linear)')
    ax.set_ylabel("Loss")
    ax.set_ylim(0, 2)
    ax.set_xlabel("Epoch number")

    # About four integer-spaced ticks. The step is clamped to >= 1 so that
    # empty or short histories neither crash (step 0) nor produce fractional
    # epoch numbers.
    n_epochs = max(len(train_loss), len(val_loss))
    tick_step = max(1, n_epochs // 4)
    ax.set_xticks(list(range(0, n_epochs + 1, tick_step)))
    ax.legend()

    # Save before showing so the file is written from the fully drawn figure.
    fig_loss.savefig('loss_plot.png', bbox_inches='tight', dpi=300)
    plt.show()

In [None]:
import copy
import random
from dataset import BinaryDataset


def visualize_augmentations(dataset, idx=0, samples=10, cols=5):
    """Show repeated augmented versions of one dataset item in a grid.

    A deep copy of ``dataset`` is used, with ``A.Normalize`` and ``ToTensorV2``
    stripped from its transform, so the displayed images are not normalised
    or converted to tensors and the caller's dataset is left untouched.

    Args:
        dataset: object with a ``transform`` attribute (iterable of transforms)
            whose ``dataset[idx]`` returns a 2-item sequence; the first item
            is displayed.
        idx: index of the item to draw repeatedly.
        samples: number of augmented draws to display.
        cols: number of grid columns.

    Raises:
        ValueError: if ``cols`` is smaller than 1.
    """
    if cols < 1:
        raise ValueError("cols must be at least 1")

    dataset = copy.deepcopy(dataset)
    dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))])

    # Ceil division so that a sample count that is not a multiple of `cols`
    # still gets enough axes; at least one row so the grid is never empty.
    rows = max(1, (samples + cols - 1) // cols)
    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid.
    figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(16, 8), squeeze=False)
    axes = ax.ravel()

    for i, axis in enumerate(axes):
        axis.set_axis_off()  # also blanks any unused trailing cells
        if i < samples:
            image, _ = dataset[idx]
            axis.imshow(image)

    plt.tight_layout()
    plt.show()


    
# Build the training pipeline and wrap the training data for the preview.
train_transform, val_transforms = transforms_fn()
train_dataset = BinaryDataset(
    image_dir=TRAIN_IMG_DIR,
    mask_dir=TRAIN_MASK_DIR,
    transform=train_transform,
)

# Seed Python's `random` module before drawing the preview samples.
# NOTE(review): whether albumentations draws from `random` depends on its version — confirm.
random.seed(42)
visualize_augmentations(train_dataset)

In [None]:
def main():
    """Train the U-Net, validate and checkpoint every epoch, then plot the losses.

    Builds the transforms, a 1-input-channel / 6-output-channel ``UNET``,
    a cross-entropy loss and an Adam optimizer, and obtains the train and
    validation loaders from ``get_loaders``. When ``LOAD_MODEL`` is set,
    training resumes from ``my_checkpoint.pth.tar``, including its loss
    history and epoch counter. After ``NUM_EPOCHS`` epochs the loss curves
    are plotted with ``loss_plot_fn``.
    """
    train_transform, val_transforms = transforms_fn()

    model = UNET(in_channels=1, out_channels=6).to(DEVICE)
    loss_fn = nn.CrossEntropyLoss()
    # for multilabel segmentation change out_channels (e.g., 3) and use cross entropy loss instead of BCEWithLogitsLoss
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine (and vice versa).
        checkpoint_state = torch.load("my_checkpoint.pth.tar", map_location=DEVICE)
        last_epoch, train_loss, val_loss = load_checkpoint(checkpoint_state, model, optimizer)
        check_accuracy(val_loader, model, device=DEVICE, show_results=True)
        print("Model successfully loaded")
    else:
        train_loss = []
        val_loss = []
        last_epoch = 0
        print("Training model from scratch")

    scaler = torch.cuda.amp.GradScaler()

    # Zero-based epochs of this run after which validation results are shown.
    # A non-positive interval disables the display entirely; the previous
    # sentinel value 0 compared equal to epoch 0 and showed results anyway.
    if SHOW_VAL_INTERVAL > 0:
        display_epochs = set(range(SHOW_VAL_INTERVAL - 1, NUM_EPOCHS, SHOW_VAL_INTERVAL))
    else:
        display_epochs = set()

    for epoch in range(NUM_EPOCHS):
        print("\n----------------------------------------------------------------------------")
        print(f"\nepoch {epoch+1}/{NUM_EPOCHS}")
        if LOAD_MODEL:
            print(f"total epoch {last_epoch+epoch+1}/{last_epoch+NUM_EPOCHS}")

        # train model
        t_loss = train_fn(train_loader, model, optimizer, loss_fn, scaler)
        train_loss.append(t_loss)

        # check accuracy
        show_results = epoch in display_epochs
        v_loss = check_accuracy(val_loader, model, device=DEVICE, show_results=show_results)
        val_loss.append(v_loss)

        # save model checkpoint (the epoch counter includes epochs from a resumed run)
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch + 1 + last_epoch,
            "train_loss": train_loss,
            "val_loss": val_loss
        }
        save_checkpoint(checkpoint)

    loss_plot_fn(train_loss, val_loss)

In [None]:
# Run the full training pipeline when this cell/script is executed directly.
if __name__ == "__main__":
    main()