In [7]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import os
import matplotlib.pyplot as plt
import skimage.measure
import scipy.ndimage

from src.datasets import BiosensorDataset, create_datasets
from src.unet.unet_model import UNet
from src.train import train_model
from src.evaluate import evaluate

In [2]:
# Device selection, seeding, dataset/loader construction and model creation.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device {device}')

# Seed the torch and NumPy RNGs so repeated runs are comparable.
torch.manual_seed(42)
np.random.seed(42)

# The data directory can be overridden with the BIOSENSOR_DATA_PATH environment variable;
# the default is the original machine-specific path, so existing behaviour is unchanged.
data_path = os.environ.get('BIOSENSOR_DATA_PATH', 'C:/onlab_git/Onlab/data_with_centers/')
checkpoint_dir = 'unet-checkpoints'
train_percent = 0.495  # presumably the fraction of samples in the training split — see create_datasets
test_percent = 0.30    # passed to create_datasets as test_percent
batch_size = 4
bio_len = 8            # passed as biosensor_length and used as the UNet's input channel count
mask_size = 80         # passed to create_datasets as mask_size

# NOTE(review): the third positional argument is the builtin *type* `bool`, not a value.
# A type object is truthy, so it likely behaves like True — confirm against the
# signature of create_datasets which parameter this is meant to set.
train_dataset, val_dataset, test_dataset = create_datasets(data_path, train_percent, bool, test_percent=test_percent, biosensor_length=bio_len, mask_size=mask_size, augment=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation loaders keep a fixed order so per-batch inspection is reproducible.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Single-output-channel UNet taking bio_len input channels.
model = UNet(n_channels=bio_len, n_classes=1)
model = model.to(device)

Using device cuda


In [3]:
# Train the UNet. train_model saves a checkpoint into checkpoint_dir after each epoch
# (see the "Checkpoint N saved!" lines in its log).
try:
    train_model(
        model,
        device,
        train_loader,
        val_loader,
        learning_rate=0.01,
        epochs=15,
        checkpoint_dir=checkpoint_dir,
        amp=True,
        wandb_logging=False
    )
except torch.cuda.OutOfMemoryError:
    # Release cached GPU memory so the kernel remains usable, then re-raise.
    # The following cells load a checkpoint from disk; continuing after a failed run
    # would silently evaluate whatever checkpoint a previous run left behind.
    torch.cuda.empty_cache()
    print('Detected OutOfMemoryError!')
    raise

Starting training:
        Epochs:          15
        Batch size:      4
        Learning rate:   0.01
        Training size:   80
        Validation size: 35
        Device:          cuda
        Mixed Precision: True
        Dilatation:      0
    


Epoch 1/15: 100%|██████████| 80/80 [00:14<00:00,  5.67img/s, loss (batch)=0.807]


Validation Dice score: 0.002604305511340499, Detection rate: 0.005486466715435259
Checkpoint 1 saved!


Epoch 2/15: 100%|██████████| 80/80 [00:13<00:00,  5.86img/s, loss (batch)=0.978]


Validation Dice score: 0.21715934574604034, Detection rate: 0.2607900512070227
Checkpoint 2 saved!


Epoch 3/15: 100%|██████████| 80/80 [00:13<00:00,  6.05img/s, loss (batch)=0.747]


Validation Dice score: 0.3683970868587494, Detection rate: 0.633138258961229
Checkpoint 3 saved!


Epoch 4/15: 100%|██████████| 80/80 [00:12<00:00,  6.30img/s, loss (batch)=0.784]


Validation Dice score: 0.35669973492622375, Detection rate: 0.7136064374542794
Checkpoint 4 saved!


Epoch 5/15: 100%|██████████| 80/80 [00:12<00:00,  6.48img/s, loss (batch)=0.675]


Validation Dice score: 0.3310006260871887, Detection rate: 0.7388441843452817
Checkpoint 5 saved!


Epoch 6/15: 100%|██████████| 80/80 [00:12<00:00,  6.28img/s, loss (batch)=0.896]


Validation Dice score: 0.40452465415000916, Detection rate: 0.7553035844915874
Checkpoint 6 saved!


Epoch 7/15: 100%|██████████| 80/80 [00:12<00:00,  6.31img/s, loss (batch)=0.672]


Validation Dice score: 0.4081939160823822, Detection rate: 0.7286027798098025
Checkpoint 7 saved!


Epoch 8/15: 100%|██████████| 80/80 [00:12<00:00,  6.33img/s, loss (batch)=0.78] 


Validation Dice score: 0.38368427753448486, Detection rate: 0.7750548646671543
Checkpoint 8 saved!


Epoch 9/15: 100%|██████████| 80/80 [00:12<00:00,  6.17img/s, loss (batch)=0.782]


Validation Dice score: 0.4055405855178833, Detection rate: 0.8057790782735919
Checkpoint 9 saved!


Epoch 10/15: 100%|██████████| 80/80 [00:13<00:00,  6.11img/s, loss (batch)=0.696]


Validation Dice score: 0.4229012131690979, Detection rate: 0.8094367227505487
Checkpoint 10 saved!


Epoch 11/15: 100%|██████████| 80/80 [00:12<00:00,  6.16img/s, loss (batch)=0.758]


Validation Dice score: 0.40817418694496155, Detection rate: 0.7900512070226774
Checkpoint 11 saved!


Epoch 12/15: 100%|██████████| 80/80 [00:13<00:00,  5.83img/s, loss (batch)=0.928]


Validation Dice score: 0.395841121673584, Detection rate: 0.8017556693489393
Checkpoint 12 saved!


Epoch 13/15: 100%|██████████| 80/80 [00:13<00:00,  6.09img/s, loss (batch)=0.966]


Validation Dice score: 0.4042311906814575, Detection rate: 0.8178493050475494
Checkpoint 13 saved!


Epoch 14/15: 100%|██████████| 80/80 [00:13<00:00,  6.12img/s, loss (batch)=0.765]


Validation Dice score: 0.4182581603527069, Detection rate: 0.7970007315288954
Checkpoint 14 saved!


Epoch 15/15: 100%|██████████| 80/80 [00:13<00:00,  6.15img/s, loss (batch)=0.676]

Validation Dice score: 0.4142612814903259, Detection rate: 0.8269934162399415
Checkpoint 15 saved!





In [12]:
# from torchsummary import summary
# Print the model summary
# summary(model, input_size=(bio_len, mask_size, mask_size))

# Load the final-epoch checkpoint written by the training cell. map_location lets a
# checkpoint saved during a CUDA run be loaded on a CPU-only machine as well.
checkpoint = torch.load(os.path.join(checkpoint_dir, 'checkpoint_epoch15.pth'), map_location=device)
# The checkpoint stores the learning rate next to the weights; pop it so the
# remainder is a plain state dict that load_state_dict accepts.
lr = checkpoint.pop('learning_rate')
model.load_state_dict(checkpoint)
model = model.to(device)

In [26]:
# Score the loaded model on the validation split and on the held-out test split.
val_dice_score, val_detection_rate = evaluate(model, val_loader, device)
dice_score, detection_rate = evaluate(model, test_loader, device)
print(f'Validation dice score: {val_dice_score}, Detection rate: {val_detection_rate}')
# This line reports the test split; it was previously mislabelled as "Validation".
print(f'Test dice score: {dice_score}, Detection rate: {detection_rate}')

                                                                    

Validation dice score: 0.4142612814903259, Detection rate: 0.8269934162399415
Validation dice score: 0.389661967754364, Detection rate: 0.7946449916327994




In [15]:
# Export the current model as TorchScript for production use.
# To export a different checkpoint instead: build UNet(n_channels=bio_len, n_classes=1),
# torch.load the checkpoint, pop 'learning_rate', and load_state_dict before scripting.
os.makedirs('saved_models', exist_ok=True)  # saving fails if the target directory is missing
# Switch to inference mode so the scripted module is not exported in training mode.
model.eval()
torch.jit.script(model).save(f'saved_models/unet_len{bio_len}.pth')

In [None]:
# Take one test batch for qualitative inspection (batch_index=2 is the third batch,
# matching the three next() calls this cell originally made).
batch_index = 2
loader_iter = iter(test_loader)
for _ in range(batch_index + 1):
    data, labels = next(loader_iter)

# Keep `data` on the device: the plotting cells below call data[i].cpu() themselves.
data = data.to(device)
labels = labels.to(device)

# Inference only: eval() so layers with train/eval differences behave as at test time
# (and no running statistics are updated), no_grad() so no autograd graph is built.
model.eval()
with torch.no_grad():
    predictions = model(data)

# Mask obtained by thresholding sigmoid(output) at 0.5.
sigmoid_predictions = (torch.sigmoid(predictions) > 0.5).cpu().numpy()

# Raw outputs, plus a mask thresholded directly on them at 0.5.
# NOTE(review): if the model outputs logits, `predictions > 0.5` equals a probability
# threshold of sigmoid(0.5) ~= 0.62 rather than 0.5 — compare with the sigmoid panel.
predictions = predictions.cpu().numpy()
binary_predictions = (predictions > 0.5).astype(np.uint8)

labels = labels.cpu().numpy()

# One row of five panels per sample: input, ground truth, raw output, and both masks.
for i in range(len(data)):
    fig, axes = plt.subplots(1, 5, figsize=(25, 5))
    panels = (
        (data[i].cpu().numpy()[-1], 'Biosensor'),  # last input channel
        (np.squeeze(labels[i]), 'True Mask'),
        (np.squeeze(predictions[i]), 'Prediction'),
        (np.squeeze(binary_predictions[i]), 'Binary Prediction'),
        (np.squeeze(sigmoid_predictions[i]), 'Sigmoid Binary Prediction'),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop

In [None]:
# For each sample, draw the last input channel alone, then with the ground-truth
# mask overlaid, then with the thresholded prediction overlaid.
for i in range(len(data)):
    frame = data[i].cpu().numpy()[-1]
    overlays = (
        ('Data', None),
        ('Data with Label overlay', labels[i]),
        ('Data with Binary Prediction overlay', binary_predictions[i]),
    )

    plt.figure(figsize=(10, 10))
    for position, (title, mask) in enumerate(overlays, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(frame, cmap='gray')
        if mask is not None:
            plt.imshow(np.squeeze(mask), cmap='jet', alpha=0.5)
        plt.title(title)

    plt.show()

In [None]:
# Side by side: the ground-truth mask, and the same mask with the thresholded
# prediction drawn semi-transparently on top of it.
for i in range(len(labels)):
    truth = np.squeeze(labels[i])
    plt.figure(figsize=(10, 10))

    plt.subplot(1, 2, 1)
    plt.imshow(truth, cmap='gray')
    plt.title('Label')

    plt.subplot(1, 2, 2)
    plt.imshow(truth, cmap='gray')
    plt.imshow(np.squeeze(binary_predictions[i]), cmap='jet', alpha=0.5)
    plt.title('Label with Prediction overlay')

    plt.show()

In [None]:
from src.utils import *

In [17]:
# Pixel-level check on the test split with a 0.5 threshold. The two returned values
# are printed space-separated (presumably labelled vs. detected positive pixels —
# confirm in src.utils.pos_pixels).
label, detected = pos_pixels(
    model,
    test_loader,
    device,
    threshold=0.5,
)
print(label, detected)

9485 8665


In [32]:
# Object-level detection rate (skimage-based helper) on the validation and test splits.
# Each line is prefixed with its split name; before, the two lines were indistinguishable.
# After the loop, cell_detection_rate/total/detected hold the test-split values, as before.
for split_name, split_loader in (('Validation', val_loader), ('Test', test_loader)):
    cell_detection_rate, total, detected = cell_detection_skimage(model, split_loader, device)
    print(f'{split_name} cell detection rate: {cell_detection_rate}, total cells: {total}, detected cells: {detected}')

Cell detection rate: 0.8149231894659839, total cells: 2734, detected cells: 2228
Cell detection rate: 0.7850824766913699, total cells: 4183, detected cells: 3284


In [38]:
# Object-level detection rate (scipy-based helper) on the validation and test splits.
# Each line is prefixed with its split name; before, the two lines were indistinguishable.
# After the loop, cell_detection_rate/total/detected hold the test-split values, as before.
for split_name, split_loader in (('Validation', val_loader), ('Test', test_loader)):
    cell_detection_rate, total, detected = cell_detection_scipy(model, split_loader, device)
    print(f'{split_name} cell detection rate: {cell_detection_rate}, total cells: {total}, detected cells: {detected}')

Cell detection rate: 0.8269934162399415, total cells: 2734, detected cells: 2261
Cell detection rate: 0.7946449916327994, total cells: 4183, detected cells: 3324
