In [1]:
%load_ext autoreload
%autoreload 2
import argparse
from typing import Dict

import torch
import os
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from model.metrics import Metrics
from model.unet import Unet, DEFAULT_UNET_LAYERS
from model.dice_loss import DiceLoss, DiceBCELoss
from datasets.dataset import RetinaSegmentationDataset
from utils.resultPrinter import ResultPrinter
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def train_model(model, dataloader, criterion, optimizer, device):
    metrics_tracker = Metrics(device)
    model.train()
    train_running_loss = 0.0
    for ind, (img, lbl) in enumerate(tqdm(dataloader, desc="Training")):
        # Copy to device
        img = img.to(device)
        lbl = lbl.to(device)
        # Make the prediction
        lbl_pred = model(img)
        optimizer.zero_grad()
        # Compute loss
        loss = criterion(lbl_pred, lbl)
        # compute metrics
        metrics_tracker.calculate(lbl_pred, lbl)
        # Running tally
        train_running_loss += loss.item() * img.shape[0]
        # Backward step
        loss.backward()
        optimizer.step()

    # Compute the loss for this epoch
    train_loss = train_running_loss / (ind + 1)
    # Compute the metrics for this epoch
    metrics = metrics_tracker.get_mean_metrics(ind + 1)
    metrics['loss'] = train_loss
    return metrics

In [3]:
def eval_model(model, dataloader, criterion, device):
    metrics_tracker = Metrics(device)
    model.eval()
    eval_running_loss = 0.0
    with torch.no_grad():
        for ind, (img, lbl) in enumerate(tqdm(dataloader, desc="Validation")):
            # Copy to device
            img = img.to(device)
            lbl = lbl.to(device)
            # Make the prediction
            lbl_pred = model(img)
            # Compute loss
            loss = criterion(lbl_pred, lbl)
            # compute metrics
            metrics_tracker.calculate(lbl_pred, lbl)
            # Running tally
            eval_running_loss += loss.item() * img.shape[0]

    # Compute the loss for this epoch
    eval_loss = eval_running_loss / (ind + 1)
    # Compute the metrics for this epoch
    metrics = metrics_tracker.get_mean_metrics(ind + 1)
    metrics['loss'] = eval_loss
    return metrics

In [5]:
rootdir: str = "C:/Users/shawn/Desktop/Development/CS7643/data/DATA_4D_Patches/DATA_4D_Patches"
workers: int = 8
load_encoder_weights: str = None
load_bt_checkpoint: str = None
anneal_tmax: int = 10
anneal_eta: int = 0
run_name: str = "test-drive"
checkpoint_dir: str = "C:/Users/shawn/Desktop/Development/CS7643/checkpoint/"

args = {
    "learning_rate": 0.01,
    "unet_layers": "64-128-256",
    "epochs": 10,
    "batch_size": 64,
    "scheduler": "CosineAnnealing",
    "loss_function": "DiceLoss",
    "dropout": 0.2
}

# Get the device
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# Determine the layer sizes of the U-Net
unet_layers = DEFAULT_UNET_LAYERS
if args["unet_layers"]:
    unet_layers = [int(x) for x in args["unet_layers"].split("-")]

# Initialize the model on the GPU
model = Unet(dropout=args["dropout"], hidden_channels=unet_layers).to(device)
if load_encoder_weights:
    model.encoder.load_state_dict(torch.load(load_encoder_weights))
elif load_bt_checkpoint:
    model.encoder.load_state_dict(torch.load(load_bt_checkpoint)["encoder"])
optimizer = torch.optim.Adam(model.parameters(), lr=args["learning_rate"])

# Define scheduler (if necessary)
scheduler = None
if args["scheduler"] == 'CosineAnnealing':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, anneal_tmax, anneal_eta)
elif args["scheduler"] == 'ReduceOnPlateau':
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# Select the Loss function
loss_functions = {
    "BCEWithLogitsLoss": torch.nn.BCEWithLogitsLoss(),
    "CrossEntropyLoss": torch.nn.CrossEntropyLoss(),
    "DiceLoss": DiceLoss(),
    "DiceBCELoss": DiceBCELoss()
}
criterion = loss_functions[args["loss_function"]]

# Load the training datasets
training_path = os.path.join(rootdir, "Training")
training_file_basenames = os.listdir(os.path.join(training_path, "images"))
training_dataset = RetinaSegmentationDataset(training_path, training_file_basenames)
training_dataloader = torch.utils.data.DataLoader(
    training_dataset, batch_size=args["batch_size"], num_workers=workers,
    pin_memory=True, shuffle=True)

# Load the validation datasets
validation_path = os.path.join(rootdir, "Validation")
validation_file_basenames = os.listdir(os.path.join(validation_path, "images"))
validation_dataset = RetinaSegmentationDataset(validation_path, validation_file_basenames)
validation_dataloader = torch.utils.data.DataLoader(
    validation_dataset, batch_size=args["batch_size"], num_workers=workers,
    pin_memory=True, shuffle=False)

# Load the validation datasets
test_path = os.path.join(rootdir, "Testing")
test_file_basenames = os.listdir(os.path.join(test_path, "images"))
test_dataset = RetinaSegmentationDataset(test_path, test_file_basenames)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, num_workers=1,
    pin_memory=True, shuffle=False)

# Train / Val loop
training_losses = []
validation_losses = []

# Create a descriptive name for the checkpoints
temp_dict = dict(args)
descrip_name = ""
for key in temp_dict.keys():
    if (key != "rootdir" and
            "load" not in key and
            "checkpoint" not in key and
            "workers" not in key and
            "save_freq" not in key):
        descrip_name += "--" + key + "=" + str(temp_dict[key])
descrip_name = descrip_name.replace(' ', '_').replace('[', '').replace(']', '').replace('\'', '')

# runs dict should be passed to each instance of a results printer. It is only appended to so should be thread safe.
runs: Dict[str, Dict[str, float]] = {}
# create a new results printer for each param setting tested
result_printer = ResultPrinter(descrip_name, runs, run_name=run_name)

epoch_pbar = tqdm(total=args["epochs"], desc="Epochs")

prev_validation_loss = None

Epochs:   0%|                                                                                                                                     | 0/10 [00:00<?, ?it/s]

In [6]:
for i in range(args["epochs"]):

    train_metrics = train_model(model, training_dataloader, criterion, optimizer, device)
    result_printer.print(f'Training metrics: {str(train_metrics)}')
    train_loss = train_metrics['loss']

    validation_metrics = eval_model(model, validation_dataloader, criterion, device)
    result_printer.print(f'Validation metrics: {str(validation_metrics)}')
    validation_loss = validation_metrics['loss']

    result_printer.rankAndSave(validation_metrics)

    training_losses.append(train_loss)
    validation_losses.append(validation_loss)
    epoch_pbar.write("=" * 80)
    epoch_pbar.write("Epoch: {}".format(i))
    epoch_pbar.write("Train Loss : {:.4f}".format(train_loss))
    epoch_pbar.write("Validation Loss : {:.4f}".format(validation_loss))
    epoch_pbar.write("=" * 80)
    epoch_pbar.update(1)

    # Save plot of Train/Validation Loss Per Epoch
    result_printer.makePlots(training_losses, validation_losses, i)

    # Take appropriate scheduler step (if necessary)
    if args["scheduler"] == 'CosineAnnealing':
        scheduler.step()
    elif args["scheduler"] == 'ReduceOnPlateau':
        scheduler.step(validation_loss)

    if prev_validation_loss is not None:
        if abs(prev_validation_loss - validation_loss) / prev_validation_loss < 0.01:
            break

    # if i % args.save_freq == 0:
    # save the model
    state = dict(epoch=i + 1,
                 model=model.state_dict(),
                 optimizer=optimizer.state_dict(),
                 unet_layer_sizes=unet_layers,
                 args=temp_dict)
    torch.save(state, checkpoint_dir + f'unet-best-1208-{descrip_name}-epoch-{i}.pth')

    prev_validation_loss = validation_loss



Training:   1%|▋                                                                                                                         | 1/180 [00:10<31:24, 10.53s/it][A
Training:   1%|█▎                                                                                                                        | 2/180 [00:10<13:13,  4.46s/it][A
Training:   2%|██                                                                                                                        | 3/180 [00:10<07:22,  2.50s/it][A
Training:   2%|██▋                                                                                                                       | 4/180 [00:11<04:38,  1.58s/it][A
Training:   3%|███▍                                                                                                                      | 5/180 [00:11<03:08,  1.08s/it][A
Training:   3%|████                                                                                                                  

Training:  51%|█████████████████████████████████████████████████████████████▊                                                           | 92/180 [00:26<00:15,  5.67it/s][A
Training:  52%|██████████████████████████████████████████████████████████████▌                                                          | 93/180 [00:26<00:15,  5.63it/s][A
Training:  52%|███████████████████████████████████████████████████████████████▏                                                         | 94/180 [00:27<00:15,  5.68it/s][A
Training:  53%|███████████████████████████████████████████████████████████████▊                                                         | 95/180 [00:27<00:14,  5.69it/s][A
Training:  53%|████████████████████████████████████████████████████████████████▌                                                        | 96/180 [00:27<00:15,  5.57it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▏                                                      

Training metrics: {'f1_score': 0.537265965094169, 'sensitivity': 0.6199684855838616, 'specificity': 0.8661241778896914, 'accuracy': 0.8355407343970405, 'auc_roc': 0.8290151261621052, 'mean_iou': 0.38118391945544217, 'ssim': 0.06557724135410455, 'loss': 29.785734346177843}



Validation:   0%|                                                                                                                                 | 0/75 [00:00<?, ?it/s][A
Validation:   1%|█▌                                                                                                                       | 1/75 [00:07<09:30,  7.71s/it][A
Validation:   4%|████▊                                                                                                                    | 3/75 [00:07<02:27,  2.05s/it][A
Validation:   7%|████████                                                                                                                 | 5/75 [00:08<01:12,  1.04s/it][A
Validation:   9%|███████████▎                                                                                                             | 7/75 [00:08<00:42,  1.60it/s][A
Validation:  12%|██████████████▌                                                                                                      

Validation metrics: {'f1_score': 0.5052895200252533, 'sensitivity': 0.433935759862264, 'specificity': 0.9684198800722758, 'accuracy': 0.9020887247721354, 'auc_roc': 0.8584955024719239, 'mean_iou': 0.3452802722652753, 'ssim': 0.026333794235057818, 'loss': 31.776187235514325}
Epoch: 0
Train Loss : 29.7857
Validation Loss : 31.7762
training loss per epoch: [29.785734346177843]
validation loss per epoch: [31.776187235514325]



Training:   0%|                                                                                                                                  | 0/180 [00:00<?, ?it/s][A
Training:   1%|▋                                                                                                                         | 1/180 [00:07<23:19,  7.82s/it][A
Training:   1%|█▎                                                                                                                        | 2/180 [00:08<09:54,  3.34s/it][A
Training:   2%|██                                                                                                                        | 3/180 [00:08<05:34,  1.89s/it][A
Training:   2%|██▋                                                                                                                       | 4/180 [00:08<03:33,  1.21s/it][A
Training:   3%|███▍                                                                                                                   

Training:  52%|███████████████████████████████████████████████████████████████▏                                                         | 94/180 [00:24<00:15,  5.71it/s][A
Training:  53%|███████████████████████████████████████████████████████████████▊                                                         | 95/180 [00:24<00:14,  5.72it/s][A
Training:  53%|████████████████████████████████████████████████████████████████▌                                                        | 96/180 [00:24<00:14,  5.69it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▏                                                       | 97/180 [00:24<00:14,  5.69it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▉                                                       | 98/180 [00:25<00:14,  5.72it/s][A
Training:  55%|██████████████████████████████████████████████████████████████████▌                                                     

Training metrics: {'f1_score': 0.7284178336461385, 'sensitivity': 0.7299765768978331, 'specificity': 0.9612428900268343, 'accuracy': 0.9322289043002658, 'auc_roc': 0.9387574381298489, 'mean_iou': 0.5730911569462882, 'ssim': 0.012208043305306799, 'loss': 17.45008659362793}



Validation:   0%|                                                                                                                                 | 0/75 [00:00<?, ?it/s][A
Validation:   1%|█▌                                                                                                                       | 1/75 [00:07<09:23,  7.62s/it][A
Validation:   4%|████▊                                                                                                                    | 3/75 [00:07<02:26,  2.03s/it][A
Validation:   7%|████████                                                                                                                 | 5/75 [00:07<01:11,  1.03s/it][A
Validation:   9%|███████████▎                                                                                                             | 7/75 [00:08<00:42,  1.61it/s][A
Validation:  12%|██████████████▌                                                                                                      

Validation metrics: {'f1_score': 0.6716115438938141, 'sensitivity': 0.572603687842687, 'specificity': 0.9868585483233134, 'accuracy': 0.9336744944254557, 'auc_roc': 0.9287132938702901, 'mean_iou': 0.5107053542137145, 'ssim': 0.002369157588885476, 'loss': 21.066141408284505}
Epoch: 1
Train Loss : 17.4501
Validation Loss : 21.0661
training loss per epoch: [29.785734346177843, 17.45008659362793]
validation loss per epoch: [31.776187235514325, 21.066141408284505]



Training:   0%|                                                                                                                                  | 0/180 [00:00<?, ?it/s][A
Training:   1%|▋                                                                                                                         | 1/180 [00:07<23:05,  7.74s/it][A
Training:   1%|█▎                                                                                                                        | 2/180 [00:07<09:51,  3.32s/it][A
Training:   2%|██                                                                                                                        | 3/180 [00:08<05:34,  1.89s/it][A
Training:   2%|██▋                                                                                                                       | 4/180 [00:08<03:33,  1.21s/it][A
Training:   3%|███▍                                                                                                                   

Training:  52%|███████████████████████████████████████████████████████████████▏                                                         | 94/180 [00:24<00:15,  5.71it/s][A
Training:  53%|███████████████████████████████████████████████████████████████▊                                                         | 95/180 [00:24<00:15,  5.64it/s][A
Training:  53%|████████████████████████████████████████████████████████████████▌                                                        | 96/180 [00:24<00:14,  5.68it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▏                                                       | 97/180 [00:24<00:14,  5.60it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▉                                                       | 98/180 [00:24<00:14,  5.55it/s][A
Training:  55%|██████████████████████████████████████████████████████████████████▌                                                     

Training metrics: {'f1_score': 0.7632329626215829, 'sensitivity': 0.7635153008831872, 'specificity': 0.9664132323529985, 'accuracy': 0.9409656948513455, 'auc_roc': 0.9435118284490374, 'mean_iou': 0.6172806166940266, 'ssim': 0.0059481751057319345, 'loss': 15.192716725667317}



Validation:   0%|                                                                                                                                 | 0/75 [00:00<?, ?it/s][A
Validation:   1%|█▌                                                                                                                       | 1/75 [00:07<09:23,  7.62s/it][A
Validation:   4%|████▊                                                                                                                    | 3/75 [00:07<02:26,  2.03s/it][A
Validation:   7%|████████                                                                                                                 | 5/75 [00:07<01:11,  1.02s/it][A
Validation:   9%|███████████▎                                                                                                             | 7/75 [00:08<00:42,  1.61it/s][A
Validation:  12%|██████████████▌                                                                                                      

Validation metrics: {'f1_score': 0.6338929137587548, 'sensitivity': 0.5303311551610629, 'specificity': 0.9873927736282349, 'accuracy': 0.9287370808919271, 'auc_roc': 0.8487732474009196, 'mean_iou': 0.4754012970626354, 'ssim': 0.0015235987900329444, 'loss': 23.455560251871745}
Epoch: 2
Train Loss : 15.1927
Validation Loss : 23.4556
training loss per epoch: [29.785734346177843, 17.45008659362793, 15.192716725667317]
validation loss per epoch: [31.776187235514325, 21.066141408284505, 23.455560251871745]



Training:   0%|                                                                                                                                  | 0/180 [00:00<?, ?it/s][A
Training:   1%|▋                                                                                                                         | 1/180 [00:07<23:14,  7.79s/it][A
Training:   1%|█▎                                                                                                                        | 2/180 [00:07<09:52,  3.33s/it][A
Training:   2%|██                                                                                                                        | 3/180 [00:08<05:34,  1.89s/it][A
Training:   2%|██▋                                                                                                                       | 4/180 [00:08<03:33,  1.21s/it][A
Training:   3%|███▍                                                                                                                   

Training:  52%|███████████████████████████████████████████████████████████████▏                                                         | 94/180 [00:24<00:15,  5.70it/s][A
Training:  53%|███████████████████████████████████████████████████████████████▊                                                         | 95/180 [00:24<00:14,  5.72it/s][A
Training:  53%|████████████████████████████████████████████████████████████████▌                                                        | 96/180 [00:24<00:14,  5.69it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▏                                                       | 97/180 [00:24<00:14,  5.66it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▉                                                       | 98/180 [00:25<00:14,  5.66it/s][A
Training:  55%|██████████████████████████████████████████████████████████████████▌                                                     

Training metrics: {'f1_score': 0.7831970115502676, 'sensitivity': 0.7812965369886822, 'specificity': 0.9697412629922231, 'accuracy': 0.9460964361826579, 'auc_roc': 0.9386427538262473, 'mean_iou': 0.6437198940250609, 'ssim': 0.004635723469416714, 'loss': 13.902807299296061}



Validation:   0%|                                                                                                                                 | 0/75 [00:00<?, ?it/s][A
Validation:   1%|█▌                                                                                                                       | 1/75 [00:07<09:38,  7.81s/it][A
Validation:   3%|███▏                                                                                                                     | 2/75 [00:07<04:00,  3.29s/it][A
Validation:   5%|██████▍                                                                                                                  | 4/75 [00:08<01:30,  1.28s/it][A
Validation:   8%|█████████▋                                                                                                               | 6/75 [00:08<00:48,  1.41it/s][A
Validation:  11%|████████████▉                                                                                                        

Validation metrics: {'f1_score': 0.7416694025198619, 'sensitivity': 0.7101433918873469, 'specificity': 0.9776849834124247, 'accuracy': 0.9423751958211263, 'auc_roc': 0.9061345831553141, 'mean_iou': 0.5970049065351486, 'ssim': 0.0018701187953896199, 'loss': 16.552221374511717}
Epoch: 3
Train Loss : 13.9028
Validation Loss : 16.5522
training loss per epoch: [29.785734346177843, 17.45008659362793, 15.192716725667317, 13.902807299296061]
validation loss per epoch: [31.776187235514325, 21.066141408284505, 23.455560251871745, 16.552221374511717]



Training:   0%|                                                                                                                                  | 0/180 [00:00<?, ?it/s][A
Training:   1%|▋                                                                                                                         | 1/180 [00:07<23:06,  7.75s/it][A
Training:   1%|█▎                                                                                                                        | 2/180 [00:07<09:48,  3.31s/it][A
Training:   2%|██                                                                                                                        | 3/180 [00:08<05:32,  1.88s/it][A
Training:   2%|██▋                                                                                                                       | 4/180 [00:08<03:32,  1.21s/it][A
Training:   3%|███▍                                                                                                                   

Training:  52%|███████████████████████████████████████████████████████████████▏                                                         | 94/180 [00:24<00:15,  5.64it/s][A
Training:  53%|███████████████████████████████████████████████████████████████▊                                                         | 95/180 [00:24<00:14,  5.67it/s][A
Training:  53%|████████████████████████████████████████████████████████████████▌                                                        | 96/180 [00:24<00:14,  5.68it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▏                                                       | 97/180 [00:24<00:14,  5.59it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▉                                                       | 98/180 [00:24<00:14,  5.54it/s][A
Training:  55%|██████████████████████████████████████████████████████████████████▌                                                     

Training metrics: {'f1_score': 0.7959170278575686, 'sensitivity': 0.7940204213062922, 'specificity': 0.9715805335177315, 'accuracy': 0.949281104405721, 'auc_roc': 0.9357597612672381, 'mean_iou': 0.661137228541904, 'ssim': 0.0039901549255268445, 'loss': 13.084696430630155}



Validation:   0%|                                                                                                                                 | 0/75 [00:00<?, ?it/s][A
Validation:   1%|█▌                                                                                                                       | 1/75 [00:07<09:36,  7.79s/it][A
Validation:   4%|████▊                                                                                                                    | 3/75 [00:07<02:29,  2.08s/it][A
Validation:   7%|████████                                                                                                                 | 5/75 [00:08<01:13,  1.05s/it][A
Validation:   9%|███████████▎                                                                                                             | 7/75 [00:08<00:43,  1.57it/s][A
Validation:  12%|██████████████▌                                                                                                      

Validation metrics: {'f1_score': 0.744367814262708, 'sensitivity': 0.7051703834533691, 'specificity': 0.980185931523641, 'accuracy': 0.9442005920410156, 'auc_roc': 0.8842561348279318, 'mean_iou': 0.6041455227136612, 'ssim': 0.0024636638117954135, 'loss': 16.375254974365234}
Epoch: 4
Train Loss : 13.0847
Validation Loss : 16.3753
training loss per epoch: [29.785734346177843, 17.45008659362793, 15.192716725667317, 13.902807299296061, 13.084696430630155]
validation loss per epoch: [31.776187235514325, 21.066141408284505, 23.455560251871745, 16.552221374511717, 16.375254974365234]



Training:   0%|                                                                                                                                  | 0/180 [00:00<?, ?it/s][A
Training:   1%|▋                                                                                                                         | 1/180 [00:07<23:30,  7.88s/it][A
Training:   1%|█▎                                                                                                                        | 2/180 [00:08<09:59,  3.37s/it][A
Training:   2%|██                                                                                                                        | 3/180 [00:08<05:39,  1.92s/it][A
Training:   2%|██▋                                                                                                                       | 4/180 [00:08<03:37,  1.24s/it][A
Training:   3%|███▍                                                                                                                   

Training:  52%|███████████████████████████████████████████████████████████████▏                                                         | 94/180 [00:24<00:15,  5.63it/s][A
Training:  53%|███████████████████████████████████████████████████████████████▊                                                         | 95/180 [00:24<00:15,  5.66it/s][A
Training:  53%|████████████████████████████████████████████████████████████████▌                                                        | 96/180 [00:24<00:14,  5.68it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▏                                                       | 97/180 [00:25<00:14,  5.70it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▉                                                       | 98/180 [00:25<00:14,  5.70it/s][A
Training:  55%|██████████████████████████████████████████████████████████████████▌                                                     

Training metrics: {'f1_score': 0.8060128576225705, 'sensitivity': 0.8028398328357272, 'specificity': 0.9732430547475814, 'accuracy': 0.9518612384796142, 'auc_roc': 0.9340860267480214, 'mean_iou': 0.675111378563775, 'ssim': 0.00359028495537738, 'loss': 12.437855381435819}



Validation:   0%|                                                                                                                                 | 0/75 [00:00<?, ?it/s][A
Validation:   1%|█▌                                                                                                                       | 1/75 [00:08<10:00,  8.11s/it][A
Validation:   3%|███▏                                                                                                                     | 2/75 [00:08<04:09,  3.42s/it][A
Validation:   5%|██████▍                                                                                                                  | 4/75 [00:08<01:33,  1.32s/it][A
Validation:   8%|█████████▋                                                                                                               | 6/75 [00:08<00:50,  1.36it/s][A
Validation:  11%|████████████▉                                                                                                        

Validation metrics: {'f1_score': 0.7568294258912405, 'sensitivity': 0.7509014968077342, 'specificity': 0.9727875932057699, 'accuracy': 0.9434708658854166, 'auc_roc': 0.8880440568923951, 'mean_iou': 0.6166144466400146, 'ssim': 0.0025645809142345872, 'loss': 15.573962809244792}
Epoch: 5
Train Loss : 12.4379
Validation Loss : 15.5740
training loss per epoch: [29.785734346177843, 17.45008659362793, 15.192716725667317, 13.902807299296061, 13.084696430630155, 12.437855381435819]
validation loss per epoch: [31.776187235514325, 21.066141408284505, 23.455560251871745, 16.552221374511717, 16.375254974365234, 15.573962809244792]



Training:   0%|                                                                                                                                  | 0/180 [00:00<?, ?it/s][A
Training:   1%|▋                                                                                                                         | 1/180 [00:07<22:48,  7.64s/it][A
Training:   1%|█▎                                                                                                                        | 2/180 [00:07<09:38,  3.25s/it][A
Training:   2%|██                                                                                                                        | 3/180 [00:07<05:26,  1.85s/it][A
Training:   2%|██▋                                                                                                                       | 4/180 [00:08<03:28,  1.19s/it][A
Training:   3%|███▍                                                                                                                   

Training:  52%|███████████████████████████████████████████████████████████████▏                                                         | 94/180 [00:24<00:16,  5.16it/s][A
Training:  53%|███████████████████████████████████████████████████████████████▊                                                         | 95/180 [00:24<00:16,  5.03it/s][A
Training:  53%|████████████████████████████████████████████████████████████████▌                                                        | 96/180 [00:24<00:16,  5.12it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▏                                                       | 97/180 [00:24<00:15,  5.26it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▉                                                       | 98/180 [00:24<00:15,  5.26it/s][A
Training:  55%|██████████████████████████████████████████████████████████████████▌                                                     

Training metrics: {'f1_score': 0.8133988658587138, 'sensitivity': 0.8107381304105122, 'specificity': 0.9741475817230013, 'accuracy': 0.9536510944366455, 'auc_roc': 0.9344069149759081, 'mean_iou': 0.6855834540393617, 'ssim': 0.003890601032051361, 'loss': 11.961821746826171}



Validation:   0%|                                                                                                                                 | 0/75 [00:00<?, ?it/s][A
Validation:   1%|█▌                                                                                                                       | 1/75 [00:07<09:30,  7.71s/it][A
Validation:   4%|████▊                                                                                                                    | 3/75 [00:07<02:27,  2.05s/it][A
Validation:   7%|████████                                                                                                                 | 5/75 [00:08<01:12,  1.03s/it][A
Validation:   9%|███████████▎                                                                                                             | 7/75 [00:08<00:42,  1.60it/s][A
Validation:  12%|██████████████▌                                                                                                      

Validation metrics: {'f1_score': 0.7662706263860066, 'sensitivity': 0.7118966199954351, 'specificity': 0.9840280214945475, 'accuracy': 0.9488493601481119, 'auc_roc': 0.8903099060058594, 'mean_iou': 0.6280786621570588, 'ssim': 0.002022048933819557, 'loss': 14.981948852539062}
Epoch: 6
Train Loss : 11.9618
Validation Loss : 14.9819
training loss per epoch: [29.785734346177843, 17.45008659362793, 15.192716725667317, 13.902807299296061, 13.084696430630155, 12.437855381435819, 11.961821746826171]
validation loss per epoch: [31.776187235514325, 21.066141408284505, 23.455560251871745, 16.552221374511717, 16.375254974365234, 15.573962809244792, 14.981948852539062]



Training:   0%|                                                                                                                                  | 0/180 [00:00<?, ?it/s][A
Training:   1%|▋                                                                                                                         | 1/180 [00:07<23:15,  7.80s/it][A
Training:   1%|█▎                                                                                                                        | 2/180 [00:07<09:50,  3.32s/it][A
Training:   2%|██                                                                                                                        | 3/180 [00:08<05:33,  1.88s/it][A
Training:   2%|██▋                                                                                                                       | 4/180 [00:08<03:32,  1.21s/it][A
Training:   3%|███▍                                                                                                                   

Training:  52%|███████████████████████████████████████████████████████████████▏                                                         | 94/180 [00:24<00:15,  5.53it/s][A
Training:  53%|███████████████████████████████████████████████████████████████▊                                                         | 95/180 [00:24<00:15,  5.56it/s][A
Training:  53%|████████████████████████████████████████████████████████████████▌                                                        | 96/180 [00:24<00:14,  5.60it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▏                                                       | 97/180 [00:24<00:14,  5.65it/s][A
Training:  54%|█████████████████████████████████████████████████████████████████▉                                                       | 98/180 [00:25<00:14,  5.67it/s][A
Training:  55%|██████████████████████████████████████████████████████████████████▌                                                     

Training metrics: {'f1_score': 0.8201951172616747, 'sensitivity': 0.816516426205635, 'specificity': 0.9753460698657566, 'accuracy': 0.9554017861684163, 'auc_roc': 0.9359944621721904, 'mean_iou': 0.6952499293618732, 'ssim': 0.004105333435452647, 'loss': 11.526486290825737}



Validation:   0%|                                                                                                                                 | 0/75 [00:00<?, ?it/s][A
Validation:   1%|█▌                                                                                                                       | 1/75 [00:08<10:00,  8.11s/it][A
Validation:   4%|████▊                                                                                                                    | 3/75 [00:08<02:35,  2.16s/it][A
Validation:   7%|████████                                                                                                                 | 5/75 [00:08<01:16,  1.09s/it][A
Validation:   9%|███████████▎                                                                                                             | 7/75 [00:08<00:44,  1.51it/s][A
Validation:  12%|██████████████▌                                                                                                      

Validation metrics: {'f1_score': 0.7664717721939087, 'sensitivity': 0.7136205937465032, 'specificity': 0.9839075096448262, 'accuracy': 0.9485255432128906, 'auc_roc': 0.8942509396870931, 'mean_iou': 0.6280761444568634, 'ssim': 0.002221194242980952, 'loss': 14.970980173746744}
Epoch: 7
Train Loss : 11.5265
Validation Loss : 14.9710
training loss per epoch: [29.785734346177843, 17.45008659362793, 15.192716725667317, 13.902807299296061, 13.084696430630155, 12.437855381435819, 11.961821746826171, 11.526486290825737]
validation loss per epoch: [31.776187235514325, 21.066141408284505, 23.455560251871745, 16.552221374511717, 16.375254974365234, 15.573962809244792, 14.981948852539062, 14.970980173746744]


In [141]:
from datasets.dataset import RetinaSegmentationDataset
import torchvision.transforms as transforms
from img_transform.transforms import EyeMaskCustomTransform, EyeDatasetCustomTransform


DRIVE_TRANSFORMS = transforms.Compose([
    transforms.ToTensor(),
    torch.nn.ConstantPad2d((0, 75, 0, 56), 0),
    EyeDatasetCustomTransform(mask_threshold=0.25),
])


# Load the validation datasets
test_path = os.path.join(rootdir, "Testing")
test_file_basenames = os.listdir(os.path.join(test_path, "images"))
test_dataset = RetinaSegmentationDataset(test_path, test_file_basenames, has_labels=False, img_transforms=DRIVE_TRANSFORMS)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, num_workers=1,
    pin_memory=True, shuffle=False)

In [172]:
def predict_model(model, dataloader, device):
    model.eval()
    with torch.no_grad():
        for ind, (img, lbl) in enumerate(tqdm(dataloader, desc="Testing")):
            # Copy to device
            img = img.to(device)
            # Make the prediction
            lbl_pred = model(img)
            im = np.round(F.sigmoid(lbl_pred).squeeze(0).squeeze(0).cpu().detach().numpy()[:584, :565])
            im = Image.fromarray(im.astype(np.uint8) * 255)
            im.save(f"C:/Users/shawn/Desktop/Development/CS7643/drive_predicted/{ind}.png")

In [173]:
predict_model(model, test_dataloader, device)

result_printer.close()


Testing:   0%|                                                                                                                                    | 0/20 [00:00<?, ?it/s][A
Testing:   5%|██████▏                                                                                                                     | 1/20 [00:01<00:20,  1.08s/it][A
Testing:  10%|████████████▍                                                                                                               | 2/20 [00:01<00:09,  1.95it/s][A
Testing:  30%|█████████████████████████████████████▏                                                                                      | 6/20 [00:01<00:02,  6.99it/s][A
Testing:  50%|█████████████████████████████████████████████████████████████▌                                                             | 10/20 [00:01<00:00, 11.84it/s][A
Testing:  70%|██████████████████████████████████████████████████████████████████████████████████████                                  

In [71]:
import numpy as np
test = torch.Tensor(np.ones((3, 3)))

In [174]:
import pickle

with open(os.path.join(test_path, "images", test_file_basenames[0]), "rb") as f:
    test_img = pickle.load(f)

array([[[ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        ...,
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0]],

       [[ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        ...,
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0]],

       [[ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        ...,
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0]],

       ...,

       [[ 8,  6,  7,  0],
        [ 8,  6,  7,  0],
        [ 8,  6,  7,  0],
        ...,
        [ 8,  8, 10,  0],
        [ 8,  8, 10,  0],
        [ 7,  7,  7,  0]],

       [[ 8,  6,  7,  0],
        [ 8,  6,  7,  0],
        [ 8,  6,  7,  0],
        ...,
        [ 7,  7,  9,  0],
        [ 6,  6,  8,  0],
        [ 6,  6,  6,  0]],

       [[ 9,  7,  8,  0],
        [ 9,  7,  8,  0],
        [ 9,  7,  8,  0],
        ...,
        [ 1,  0,  5,  0],
        [ 0,  0

In [108]:
565-640

-75