In [3]:
from cgitb import reset
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import skimage as sm
import skimage.io
from matplotlib import pyplot as plt
import tifffile
import timm
from fastai.vision.all import *


In [4]:
# Hyperparameters and run configuration, grouped by what they control.

# --- data locations: image folders and their matching mask folders ---
TRAIN_IMG_DIR = "dat_orientation/train_images/"
TRAIN_MASK_DIR = "dat_orientation/train_masks/"
VAL_IMG_DIR = "dat_orientation/val_images/"
VAL_MASK_DIR = "dat_orientation/val_masks/"

# --- input geometry ---
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512

# --- data loading (forwarded to the DataLoaders by main()) ---
BATCH_SIZE = 4
NUM_WORKERS = 2
PIN_MEMORY = True

# --- optimisation ---
LEARNING_RATE = 1e-4  # read by main() when it builds the Adam optimiser
NUM_EPOCHS = 5

# --- runtime ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU without a GPU
LOAD_MODEL = True  # when true, main() loads a saved checkpoint before training

In [7]:
# Defines the dataset that the dataloaders draw their samples from

class OriDataset(Dataset):
    """Dataset pairing each multi-frame image stack with its segmentation mask.

    Every non-hidden entry of ``image_dir`` is one sample. The mask for an
    image called ``name`` is read from ``mask_dir`` under
    ``name.replace(".tif", "_mask.tif")``.

    Args:
        image_dir: Directory containing the image stacks.
        mask_dir: Directory containing the corresponding masks.
        transform: Optional albumentations-style callable. It is invoked with
            the keyword targets ``image``, ``image0`` .. ``image8`` and
            ``mask`` and must return a mapping with the same keys.
    """

    # Number of frames handed to the transform. Frame 0 is passed as "image",
    # frames 1..NUM_FRAMES-1 as "image0", "image1", ...
    NUM_FRAMES = 10

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Skip every hidden entry, not just ".DS_Store": tools such as Jupyter
        # drop directories like ".ipynb_checkpoints" into data folders, and
        # those would otherwise be indexed as samples and fail on read.
        self.images = sorted(
            name for name in os.listdir(image_dir) if not name.startswith(".")
        )

    def __len__(self):
        """Return the number of image stacks found in ``image_dir``."""
        return len(self.images)

    def __getitem__(self, index):
        """Load the stack and mask at ``index`` and return ``(images, mask)``.

        ``images`` is a float tensor built from the stack divided by 256. When
        a transform is set, its first ``NUM_FRAMES`` slices are overwritten
        with the transformed frames and ``mask`` is the transformed mask;
        otherwise ``mask`` is the NumPy array read from disk. In both cases
        mask pixels equal to 255 have been remapped to 1.
        """
        name = self.images[index]
        img_path = os.path.join(self.image_dir, name)
        mask_path = os.path.join(self.mask_dir, name.replace(".tif", "_mask.tif"))

        image = sm.io.imread(img_path).astype(np.float32)
        mask = np.array(Image.open(mask_path), dtype=np.float32)
        # mask0 is the same array object as mask (not a copy), so it is also
        # binarised below; it is only used by the optional debug hook.
        mask0 = mask
        mask[mask == 255] = 1
        images = torch.tensor(image / 256).float()

        if self.transform is not None:
            # Normalises and transforms every frame together with the mask.
            frame_keys = ["image"] + [f"image{i}" for i in range(self.NUM_FRAMES - 1)]
            frames = {key: image[i] for i, key in enumerate(frame_keys)}
            transformed = self.transform(**frames, mask=mask)
            for i, key in enumerate(frame_keys):
                images[i] = transformed[key]

            mask = transformed["mask"]

            # Debug hook: saves the image and mask before and after the
            # transform to check the augmentations are working correctly.
            # save_transform(image, mask0, transformed)

        return images, mask

# saves the before and after transform by the augmentations
def save_transform(image, mask0, transformed):
    """Write a before/after comparison of one sample to a debug TIFF.

    The output is a 10 x 1034 x 1034 uint8 stack laid out in four 512-pixel
    quadrants separated by a 10-pixel gutter:

    * top-left: the original frames (``image``)
    * top-right: the transformed frames, scaled by 255
    * bottom-left: the original mask (``mask0``), scaled by 255
    * bottom-right: the transformed mask, scaled by 255

    Args:
        image: Original frame stack, written into the top-left quadrant.
        mask0: Original mask; broadcast over all ten planes.
        transformed: Mapping with keys ``image``, ``image0`` .. ``image8``
            and ``mask``, as returned by the augmentation pipeline.
    """
    canvas = np.zeros([10, 1034, 1034])
    canvas[:, 0:512, 0:512] = image

    # Plane 0 comes from "image", planes 1..9 from "image0".."image8".
    frame_keys = ["image"] + [f"image{n}" for n in range(9)]
    for plane, key in enumerate(frame_keys):
        canvas[plane, 0:512, 522:] = np.array(transformed[key]) * 255

    canvas[:, 522:, 0:512] = mask0 * 255
    canvas[:, 522:, 522:] = np.array(transformed["mask"]) * 255

    tifffile.imwrite("transformResults/transform.tif", np.asarray(canvas, "uint8"))


# util

# save model parameters
def save_checkpoint(state, filename="models/UNetOrientation_new.pth.tar"):
    """Serialise ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Object to save; main() passes a dict holding the model and
            optimiser state dicts.
        filename: Destination path, or any file-like object accepted by
            ``torch.save``.
    """
    print("=> Saving checkpoint")
    # Create the destination directory first so a missing "models/" folder
    # does not throw away a finished epoch. File-like targets are passed
    # straight through.
    if isinstance(filename, (str, os.PathLike)):
        directory = os.path.dirname(os.fspath(filename))
        if directory:
            os.makedirs(directory, exist_ok=True)
    torch.save(state, filename)

# load model parameters
def load_checkpoint(checkpoint, model):
    """Copy the weights stored under ``checkpoint["state_dict"]`` into ``model``.

    Args:
        checkpoint: Mapping with a ``"state_dict"`` entry (the layout written
            by main() through save_checkpoint).
        model: Module whose parameters are overwritten in place.
    """
    print("=> Loading checkpoint")
    saved_weights = checkpoint["state_dict"]
    model.load_state_dict(saved_weights)

# Make the dataloader 
def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True
):
    """Build the training and validation DataLoaders.

    Args:
        train_dir / train_maskdir: Image and mask folders for training.
        val_dir / val_maskdir: Image and mask folders for validation.
        batch_size: Batch size used by both loaders.
        train_transform / val_transform: Transform handed to each dataset.
        num_workers: Worker process count for both loaders.
        pin_memory: Forwarded to both loaders.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    # Options shared by both loaders; only `shuffle` differs between them.
    shared_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    train_ds = OriDataset(
        image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform
    )
    val_ds = OriDataset(
        image_dir=val_dir, mask_dir=val_maskdir, transform=val_transform
    )

    train_loader = DataLoader(train_ds, shuffle=True, **shared_options)
    val_loader = DataLoader(val_ds, shuffle=False, **shared_options)
    return train_loader, val_loader

# define metric to assess model performance 
def check_accuracy(loader, model, device="cuda"):
    """Print pixel accuracy and mean per-batch Dice score of ``model`` on ``loader``.

    Predictions are the sigmoid of the model output thresholded at 0.5. The
    Dice score is computed for each batch and averaged over the number of
    batches. The model is put in eval mode for the pass and switched back to
    train mode before returning.

    Args:
        loader: Iterable of ``(inputs, targets)`` batches.
        model: Network to evaluate.
        device: Device the batches are moved to.
    """
    correct_total = 0
    pixel_total = 0
    dice_total = 0
    model.eval()
    progress = tqdm(loader)

    with torch.no_grad():
        for inputs, target in progress:
            inputs = inputs.to(device)
            # add a channel axis so the target lines up with the prediction
            target = target.to(device).unsqueeze(1)
            predicted = (torch.sigmoid(model(inputs)) > 0.5).float()

            correct_total += (predicted == target).sum()
            pixel_total += torch.numel(predicted)

            overlap = (predicted * target).sum()
            # the small epsilon keeps the ratio finite when both are empty
            dice_total += (2 * overlap) / ((predicted + target).sum() + 1e-8)

    accuracy_percent = correct_total / pixel_total * 100
    print(f"Accuracy {accuracy_percent}%")
    print(f"Dice score {dice_total / len(loader)}")
    model.train()

# saves the ground truth alongside the model prediction in the folder saved_images
def save_predictions_as_imgs(loader, model, folder="saved_images/", device="cuda"):
    """Save thresholded predictions and ground-truth masks for the first batch.

    For each sample ``i`` of the first batch, the prediction (sigmoid output
    thresholded at 0.5) is written to ``{folder}pred_{i}.png`` and the target
    to ``{folder}img_{i}.png``. ``folder`` is used as a plain string prefix.
    The model is put in eval mode and switched back to train mode afterwards.

    Args:
        loader: Iterable of ``(inputs, targets)`` batches.
        model: Network producing the predictions.
        folder: Path prefix for the written files.
        device: Device the inputs are moved to.
    """
    model.eval()
    with torch.no_grad():
        for inputs, targets in loader:
            probabilities = torch.sigmoid(model(inputs.to(device)))
            preds = (probabilities > 0.5).float()
            for i, prediction in enumerate(preds):
                torchvision.utils.save_image(prediction, f"{folder}pred_{i}.png")
                torchvision.utils.save_image(
                    targets.unsqueeze(1)[i], f"{folder}img_{i}.png"
                )
            # only the first batch is exported
            break

    model.train()

# train the model and update parameters of model
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one pass over ``loader``, updating ``model`` after every batch.

    The forward pass and loss run under CUDA autocast; the backward pass and
    optimiser step go through ``scaler``. The current batch loss is shown on
    the progress bar.

    Args:
        loader: Iterable of ``(data, targets)`` batches.
        model: Network being trained.
        optimizer: Optimiser stepped through ``scaler``.
        loss_fn: Callable ``loss_fn(predictions, targets)``.
        scaler: Gradient scaler providing ``scale``, ``step`` and ``update``.
    """
    progress = tqdm(loader)

    for data, targets in progress:
        data = data.to(device=DEVICE)
        # add a channel axis so the targets line up with the model output
        targets = targets.unsqueeze(1).to(device=DEVICE)

        # forward pass in mixed precision
        with torch.cuda.amp.autocast():
            loss = loss_fn(model(data), targets)

        # backward pass through the gradient scaler
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # show the latest loss on the progress bar
        progress.set_postfix(loss=loss.item())

# load and train the deep learning model
def main():
    """Build the model and data pipeline, then train for ``NUM_EPOCHS`` epochs.

    Reads the module-level configuration (``LEARNING_RATE``, ``DEVICE``,
    ``BATCH_SIZE``, the data directories, ...) at call time. After every
    epoch it saves a checkpoint, prints validation accuracy and Dice score,
    and writes example predictions to ``saved_images/``.

    NOTE(review): the checkpoint is loaded from
    ``models/UNetOrientation.pth.tar`` while save_checkpoint() writes to
    ``models/UNetOrientation_new.pth.tar`` by default, so a later call does
    not pick up the weights saved by an earlier one unless the file is
    renamed -- confirm this is intended.
    """
    # Register the extra frames as image targets so every frame and the mask
    # receive the same randomly drawn augmentation.
    target10 = {f"image{i}": "image" for i in range(10)}
    target10["mask"] = "mask"

    # augmentations for training the model
    train_transform = A.Compose(
        [
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.GaussianBlur(blur_limit=(3, 5), p=0.3),
            A.Normalize(
                mean=0,
                std=1,
                max_pixel_value=255.0,
            ),
            A.RandomBrightnessContrast(p=0.3),
            A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.3),
            ToTensorV2(),
        ],
        additional_targets=target10,
    )
    # validation data is only normalised and converted to tensors
    val_transform = A.Compose(
        [
            A.Normalize(
                mean=0,
                std=1,
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
        additional_targets=target10,
    )

    # Build the UNetOrientation model: a ResNet-34 encoder wrapped in a
    # fastai DynamicUnet.
    resnet = timm.create_model("resnet34", pretrained=True)
    # Replace the first layer so the encoder accepts 10 input channels.
    resnet.conv1 = nn.Conv2d(
        10, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # Drop the last two children of the ResNet to keep only the encoder body.
    encoder = nn.Sequential(*list(resnet.children())[:-2])
    model = DynamicUnet(encoder, 1, (120, 120), norm_type=None).to(DEVICE)

    loss_fn = nn.BCEWithLogitsLoss()  # if out_channels > 1 => cross entropy loss
    optimizer = optim.Adam(
        model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), eps=1e-08
    )

    # Make the dataloaders
    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transform,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    # Resume from a saved model if requested. map_location lets a checkpoint
    # saved on a GPU be loaded when DEVICE has fallen back to "cpu".
    if LOAD_MODEL:
        checkpoint = torch.load(
            "models/UNetOrientation.pth.tar", map_location=DEVICE
        )
        load_checkpoint(checkpoint, model)

    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        # train model
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # save model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        # check accuracy
        check_accuracy(val_loader, model, device=DEVICE)
        save_predictions_as_imgs(
            val_loader, model, folder="saved_images/", device=DEVICE)


In [22]:
# Train model at a learning rate of 1e-4.
# main() reads the module-level LEARNING_RATE when it builds the optimiser,
# so the value must be assigned before the call.
LEARNING_RATE = 1e-4
main()


  0%|          | 0/42 [00:00<?, ?it/s]

=> Loading checkpoint


100%|██████████| 42/42 [00:22<00:00,  1.85it/s, loss=0.242]


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  5.09it/s]

Accuracy 89.76678466796875%
Dice score 0.6154643297195435



100%|██████████| 42/42 [00:22<00:00,  1.84it/s, loss=0.234]


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  5.28it/s]

Accuracy 90.25120544433594%
Dice score 0.6295679211616516



100%|██████████| 42/42 [00:40<00:00,  1.05it/s, loss=0.189]


=> Saving checkpoint


100%|██████████| 11/11 [00:24<00:00,  2.27s/it]

Accuracy 91.2807846069336%
Dice score 0.6977839469909668



100%|██████████| 42/42 [00:52<00:00,  1.26s/it, loss=0.205]


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  5.43it/s]

Accuracy 93.46248626708984%
Dice score 0.7872283458709717



100%|██████████| 42/42 [00:22<00:00,  1.85it/s, loss=0.135]


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  5.14it/s]

Accuracy 94.22380828857422%
Dice score 0.8135465383529663





In [24]:
# Run training again with the learning rate lowered to 5e-5.
# main() rebuilds the model and, when LOAD_MODEL is true, loads the checkpoint
# at models/UNetOrientation.pth.tar before training.
LEARNING_RATE = 5e-5
main()


  0%|          | 0/42 [00:00<?, ?it/s]

=> Loading checkpoint


100%|██████████| 42/42 [00:22<00:00,  1.89it/s, loss=0.103] 


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  5.08it/s]

Accuracy 95.23346710205078%
Dice score 0.8492602109909058



100%|██████████| 42/42 [00:22<00:00,  1.86it/s, loss=0.104] 


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  5.14it/s]

Accuracy 95.6044921875%
Dice score 0.8634843230247498



100%|██████████| 42/42 [00:22<00:00,  1.84it/s, loss=0.0797]


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  5.27it/s]

Accuracy 95.51014709472656%
Dice score 0.8615202903747559



100%|██████████| 42/42 [00:22<00:00,  1.84it/s, loss=0.112] 


=> Saving checkpoint


100%|██████████| 11/11 [00:15<00:00,  1.45s/it]

Accuracy 95.40090942382812%
Dice score 0.8561145663261414



100%|██████████| 42/42 [00:58<00:00,  1.40s/it, loss=0.0815]


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  5.07it/s]

Accuracy 95.71009826660156%
Dice score 0.8668902516365051





In [25]:
# Run training again with the learning rate lowered to 1e-5.
# main() picks up the new module-level LEARNING_RATE when it builds the optimiser.
LEARNING_RATE = 1e-5
main()


  0%|          | 0/42 [00:00<?, ?it/s]

=> Loading checkpoint


100%|██████████| 42/42 [00:22<00:00,  1.86it/s, loss=0.106] 


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  5.21it/s]

Accuracy 95.67807006835938%
Dice score 0.865170955657959



100%|██████████| 42/42 [00:22<00:00,  1.86it/s, loss=0.072] 


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  5.14it/s]

Accuracy 95.71603393554688%
Dice score 0.8670840263366699



100%|██████████| 42/42 [00:22<00:00,  1.83it/s, loss=0.0973]


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  4.96it/s]

Accuracy 95.7632827758789%
Dice score 0.8682851195335388



100%|██████████| 42/42 [00:22<00:00,  1.84it/s, loss=0.103] 


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  5.36it/s]

Accuracy 95.76087951660156%
Dice score 0.8681668639183044



100%|██████████| 42/42 [00:22<00:00,  1.86it/s, loss=0.0923]


=> Saving checkpoint


100%|██████████| 11/11 [00:02<00:00,  5.26it/s]


Accuracy 95.76690673828125%
Dice score 0.8683835864067078


In [5]:
# Displays the layers and parameter counts of the model.
# The network is rebuilt here the same way main() builds it: a ResNet-34
# whose first convolution takes 10 input channels, wrapped in a DynamicUnet.
resnet = timm.create_model("resnet34", pretrained=True)
resnet.conv1 = nn.Conv2d(
    10,
    64,
    kernel_size=(7, 7),
    stride=(2, 2),
    padding=(3, 3),
    bias=False,
)

# keep every child of the ResNet except the last two as the encoder body
m = nn.Sequential(*list(resnet.children())[:-2])
model = DynamicUnet(m, 1, (120, 120), norm_type=None).to(DEVICE)

from torchsummary import summary
# per-layer output shapes and parameter counts for a 10 x 120 x 120 input
summary(model, (10, 120, 120))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 60, 60]          31,360
       BatchNorm2d-2           [-1, 64, 60, 60]             128
              ReLU-3           [-1, 64, 60, 60]               0
         MaxPool2d-4           [-1, 64, 30, 30]               0
            Conv2d-5           [-1, 64, 30, 30]          36,864
       BatchNorm2d-6           [-1, 64, 30, 30]             128
          Identity-7           [-1, 64, 30, 30]               0
              ReLU-8           [-1, 64, 30, 30]               0
          Identity-9           [-1, 64, 30, 30]               0
           Conv2d-10           [-1, 64, 30, 30]          36,864
      BatchNorm2d-11           [-1, 64, 30, 30]             128
             ReLU-12           [-1, 64, 30, 30]               0
       BasicBlock-13           [-1, 64, 30, 30]               0
           Conv2d-14           [-1, 64,