In [None]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import os
import matplotlib.pyplot as plt
from torchinfo import summary

from src.datasets import BiosensorDataset, create_datasets
from src.model_parts import *
from src.models import *
from src.train import train_model
from src.evaluate import evaluate
from src.utils import *

In [None]:
# Select the compute device and seed the global RNGs so that DataLoader
# shuffling (and any torch/numpy-based augmentation) is reproducible.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device {device}')

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

batch_size = 16
upscale_factor = 4

# Dataset options forwarded to create_datasets.
config = {
    'path': 'C:/onlab_git/Onlab/data_with_centers/',
    'mask_type': bool,
    'augment': True,
    'noise': 0.0,
    'dilation': 0,
    'tiling': False,
    'tiling_ratio': 1,
}

# Fractions (not percentages) of the data used for the train and test splits;
# presumably the remainder becomes the validation split — TODO confirm in src.datasets.
create_dataset_args = {
    'train_percent': 0.59,
    'test_percent': 0.215,
}

calc_config = {
    'biosensor_length': 8,
    'mask_size': 80 * upscale_factor,
    'input_scaling': False,
    'upscale_mode': 'nearest',
}

train_dataset, val_dataset, test_dataset = create_datasets(config, create_dataset_args, calc_config)

# Pinned host memory only speeds up host-to-CUDA copies; on a CPU-only machine
# it is unused (recent PyTorch versions emit a warning), so enable it only for CUDA.
pin_memory = device.type == 'cuda'

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print('Train size:', len(train_dataset))
print('Validation size:', len(val_dataset))
print('Test size:', len(test_dataset))

In [None]:
# Build the UNet4 network: one input channel per biosensor frame
# (calc_config['biosensor_length']) and a single output channel.
net_kwargs = dict(
    n_channels=calc_config['biosensor_length'],
    n_classes=1,
    down_conv=DoubleConv,
    up_conv=DoubleConv,
    bilinear=False,
)
model = UNet4(**net_kwargs).to(device)
print(type(model).__name__)

# Identifiers forwarded to train_model (intentionally left empty here).
project_name = ""
model_name = ""

# Report the number of trainable parameters via torchinfo.
model_summary = summary(model)
print(model_summary.trainable_params)

In [None]:
# Hyperparameters for this training run, collected in one place for easy tweaking.
training_options = dict(
    learning_rate=0.01,
    epochs=10,
    checkpoint_dir='test_saves',
    amp=True,
    wandb_logging=False,
    tile_ratio=config['tiling_ratio'],
)

# Train the model. A CUDA out-of-memory error is caught so the notebook keeps
# running; the caching allocator is emptied so a retry starts from a clean slate.
try:
    train_model(model, project_name, model_name, device, train_loader, val_loader, **training_options)
except torch.cuda.OutOfMemoryError:
    torch.cuda.empty_cache()
    print('Detected OutOfMemoryError!')

In [None]:
# Detailed layer-by-layer summary, expanding nested modules up to depth 4.
# `summary` is already imported from torchinfo in the imports cell, so it is not re-imported here.
# The returned object also exposes `.total_params` and `.trainable_params`.
model_summary = summary(model, depth=4)
print(model_summary)

In [None]:
# Restore saved weights into `model` and evaluate on the validation and test loaders.
# NOTE(review): the training cell above runs for 10 epochs, so checkpoint_epoch18.pth
# must come from an earlier/longer run — confirm this is the intended file.
# map_location lets a checkpoint saved from a GPU be loaded on a CPU-only machine.
checkpoint = torch.load("test_saves/checkpoint_epoch18.pth", map_location=device)
# The checkpoint stores the learning rate alongside the weights. Pop it so that only
# parameter entries reach load_state_dict, which (strict by default) rejects unknown keys.
# The None default keeps plain state_dict checkpoints loadable.
lr = checkpoint.pop('learning_rate', None)
model.load_state_dict(checkpoint)
model = model.to(device)

evaluate_after_training(model, val_loader, test_loader, device)

In [23]:
# Saving the best model for production (disabled — uncomment to run).
# Loads a chosen checkpoint into a fresh UNet8, drops the stored learning rate so only
# weights reach load_state_dict, then exports the model with torch.jit.script.
# NOTE(review): uncommenting overwrites the notebook-global `model`; and if the checkpoint
# was saved on a GPU, loading on a CPU-only machine presumably needs map_location — confirm.
# model = UNet8(n_channels=8, n_classes=1)
# checkpoint = torch.load("checkpoints/checkpoint_8_4_85.pth")
# lr = checkpoint.pop('learning_rate')
# model.load_state_dict(checkpoint)

# torch.jit.script(model).save('saved_models/srunet8_len8_40dice.pth')

In [None]:
# Visual sanity check: for every sample of the FIRST test batch only, pass the last
# input slice, the label, the raw model output and the 0.5-thresholded sigmoid
# output to plot_results.
was_training = model.training
model.eval()  # inference behaviour for any BatchNorm/Dropout layers in the model
try:
    with torch.no_grad():  # no autograd graph is needed for plotting
        first_batch = next(iter(test_loader), None)
        if first_batch is not None:
            data, labels = first_batch
            predictions = model(data.to(device))
            binary_predictions = torch.sigmoid(predictions) > 0.5

            # Move everything to host memory so the plotting helper receives CPU tensors.
            data = data.cpu()
            labels = labels.cpu()
            predictions = predictions.cpu()
            binary_predictions = binary_predictions.cpu()

            for i in range(len(data)):
                label = np.squeeze(labels[i])
                binary_prediction = np.squeeze(binary_predictions[i])
                # data[i][-1] takes the last entry along the first per-sample axis
                # (presumably the last biosensor frame — TODO confirm).
                plot_results(data[i][-1], label, np.squeeze(predictions[i]), binary_prediction)
finally:
    model.train(was_training)  # restore the mode the model was in before this cell


In [None]:
# Plot samples from the test loader with the project helper plot_loader_data
# (provided by one of the src.* star imports; what it draws is defined there).
plot_loader_data(test_loader, 'Test')

In [None]:
# Preview how many files land in each split, reusing the same data path and split
# fractions that were passed to create_datasets above, so the numbers cannot drift.
# NOTE(review): this mirrors the split arithmetic presumably used inside
# create_datasets — verify against src.datasets if the counts disagree.
data_path = config['path']
train_percent = create_dataset_args['train_percent']
test_percent = create_dataset_args['test_percent']

files = os.listdir(data_path)
n_files = len(files)
train_size = int(train_percent * n_files)
# An empty test split when test_percent <= 0 keeps test_size defined for the print below.
test_size = int(test_percent * n_files) if test_percent > 0 else 0
val_size = n_files - train_size - test_size

print('Train size:', train_size)
print('Validation size:', val_size)
print('Test size:', test_size)


# train, val, test sizes: train fraction, test fraction
# 128, 16, 19: .79, .12
# 112, 24, 27: .69, .17
# 96, 32, 35: .59, .215     this split gave the best results (see wandb logs)
# 80, 40, 43: .495 .265
# 64, 48, 51: .395 .315
# 48, 48, 67: .3 .415

### Results: validation/test dice score and detection rate per architecture

SRUNet4_80_single_double_conv:  
Validation dice score: 0.3975284993648529, Detection rate: 0.7234726688102894  
Test dice score: 0.37295183539390564, Detection rate: 0.6732588134135855

SRUNet4_80_single_double_bilinear:  
Validation dice score: 0.37835413217544556, Detection rate: 0.6852090032154341  
Test dice score: 0.35115525126457214, Detection rate: 0.6213814846660934

SRUNet4_80_single_triple_conv:  
Validation dice score: 0.3669853210449219, Detection rate: 0.6864951768488746  
Test dice score: 0.33828431367874146, Detection rate: 0.6423043852106621

SRUNet4_80_double_triple_conv:  
Validation dice score: 0.36824554204940796, Detection rate: 0.6520900321543408  
Test dice score: 0.3549373149871826, Detection rate: 0.6102034967039266

SRUNet4_80_double_single_conv:  
Validation dice score: 0.40089017152786255, Detection rate: 0.7331189710610932  
Test dice score: 0.37190067768096924, Detection rate: 0.6798509601605044

UNet_80_double_conv:  
Validation dice score: 0.43662238121032715, Detection rate: 0.8377581120943953  
Test dice score: 0.4221133887767792, Detection rate: 0.7871352785145889  

UNet_80_double_bilinear:  
Validation dice score: 0.4426659047603607, Detection rate: 0.8292772861356932  
Test dice score: 0.42389044165611267, Detection rate: 0.7791777188328912  

Validation dice score: 0.43791401386260986, Detection rate: 0.8547197640117994  
Test dice score: 0.42553234100341797, Detection rate: 0.8017241379310345  
epoch 17:  
Validation dice score: 0.45137089490890503, Detection rate: 0.8521386430678466  
Test dice score: 0.4385537803173065, Detection rate: 0.8090185676392573  

Validation dice score: 0.4148973822593689, Detection rate: 0.8639380530973452  
Test dice score: 0.4070836901664734, Detection rate: 0.8246021220159151

UNet_80_double_nearest:  
Validation dice score: 0.45902019739151, Detection rate: 0.8174778761061947  
Test dice score: 0.44664859771728516, Detection rate: 0.7712201591511937

Validation dice score: 0.44770652055740356, Detection rate: 0.8547197640117994  
Test dice score: 0.43240076303482056, Detection rate: 0.7967506631299734

UNet_80_triple_nearest:  
Validation dice score: 0.43449652194976807, Detection rate: 0.8488200589970502  
Test dice score: 0.42631348967552185, Detection rate: 0.7950928381962865

Validation dice score: 0.44476252794265747, Detection rate: 0.8528761061946902  
Test dice score: 0.4263116717338562, Detection rate: 0.7891246684350133

UNet_80_triple_bilinear:  
Validation dice score: 0.44418731331825256, Detection rate: 0.859882005899705  
Test dice score: 0.4293147325515747, Detection rate: 0.8252652519893899

Validation dice score: 0.45098742842674255, Detection rate: 0.8447640117994101  
Test dice score: 0.43127065896987915, Detection rate: 0.7911140583554377

UNet_80_double_triple_conv:  
Validation dice score: 0.43232494592666626, Detection rate: 0.8668879056047197  
Test dice score: 0.4190724790096283, Detection rate: 0.8272546419098143

UNet_80_double_triple_bilinear:  
Validation dice score: 0.44077542424201965, Detection rate: 0.8366519174041298  
Test dice score: 0.428361177444458, Detection rate: 0.7745358090185677

UNet_80_double_no_relu_conv:  
Validation dice score: 0.42527180910110474, Detection rate: 0.7614306784660767  
Test dice score: 0.41641998291015625, Detection rate: 0.7178381962864722