In [None]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from torchinfo import summary

from src.datasets import BiosensorDataset, create_datasets
from src.test_models.test_parts import *
from src.test_models.modular_models import *
from src.train import train_model
from src.evaluate import evaluate
from src.utils import *

In [2]:
class UNet(nn.Module):
    """Four-level U-Net built from pluggable encoder/decoder conv blocks.

    Args:
        n_channels: number of input channels fed to ``down_conv``.
        n_classes: number of output channels produced by ``OutConv``.
        down_conv: conv block class used for ``inc`` and inside every ``Down``.
        up_conv: conv block class used inside every ``Up``.
        bilinear: forwarded to every ``Up``; when True the bottleneck and the
            first three decoder outputs use half as many channels.
    """

    def __init__(self, n_channels, n_classes, down_conv, up_conv, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        factor = 2 if bilinear else 1

        # Submodules are registered in the same order and under the same names as
        # before, so checkpoints and seeded weight initialisation are unaffected.
        self.inc = down_conv(n_channels, 64)
        encoder_widths = [64, 128, 256, 512, 1024 // factor]
        for level, (c_in, c_out) in enumerate(zip(encoder_widths, encoder_widths[1:]), start=1):
            setattr(self, f"down{level}", Down(c_in, c_out, down_conv))

        decoder_specs = [
            (1024, 512 // factor),
            (512, 256 // factor),
            (256, 128 // factor),
            (128, 64),
        ]
        for level, (c_in, c_out) in enumerate(decoder_specs, start=1):
            setattr(self, f"up{level}", Up(c_in, c_out, up_conv, bilinear))
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every stage output; the deepest one is the bottleneck.
        skips = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(stage(skips[-1]))

        # Decoder: consume the skip connections deepest-first.
        out = skips.pop()
        for stage in (self.up1, self.up2, self.up3, self.up4):
            out = stage(out, skips.pop())
        return self.outc(out)


class UNet4(nn.Module):
    """U-Net with two extra decoder stages fed by upscaled copies of the first encoder map.

    The extra stages take skip connections produced by ``Upscaling`` applied to
    the ``inc`` output (the "negative" levels below the input resolution). What
    ``Up``/``Upscaling`` do spatially is defined elsewhere and not verified here.

    Args:
        n_channels: number of input channels fed to ``down_conv``.
        n_classes: number of output channels produced by ``OutConv``.
        down_conv: conv block class used for ``inc`` and inside every ``Down``.
        up_conv: conv block class used inside every ``Up`` and ``Upscaling``.
        bilinear: forwarded to every ``Up``; when True most decoder widths are halved.
    """

    def __init__(self, n_channels, n_classes, down_conv, up_conv, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        factor = 2 if bilinear else 1

        # Registration order (inc, down*, up*, outc, up_s*) and names match the
        # original so checkpoints and seeded initialisation are unaffected.
        self.inc = down_conv(n_channels, 64)
        encoder_widths = [64, 128, 256, 512, 1024 // factor]
        for level, (c_in, c_out) in enumerate(zip(encoder_widths, encoder_widths[1:]), start=1):
            setattr(self, f"down{level}", Down(c_in, c_out, down_conv))

        decoder_specs = [
            (1024, 512 // factor),
            (512, 256 // factor),
            (256, 128 // factor),
            (128, 64 // factor),
            (64, 32 // factor),
            (32, 16),
        ]
        for level, (c_in, c_out) in enumerate(decoder_specs, start=1):
            setattr(self, f"up{level}", Up(c_in, c_out, up_conv, bilinear))
        self.outc = OutConv(16, n_classes)

        for level, (c_in, c_out) in enumerate([(64, 32), (32, 16)], start=1):
            setattr(self, f"up_s{level}", Upscaling(c_in, c_out, up_conv))

    def forward(self, xs):
        # Encoder: keep every stage output; the deepest one is the bottleneck.
        skips = [self.inc(xs)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(stage(skips[-1]))
        first_features = skips[0]

        # Standard U-Net decoder, consuming skips deepest-first.
        out = skips.pop()
        for stage in (self.up1, self.up2, self.up3, self.up4):
            out = stage(out, skips.pop())

        # Upsample "negative" layers: chain Upscaling on the first encoder map.
        extra_skip_1 = self.up_s1(first_features)
        extra_skip_2 = self.up_s2(extra_skip_1)

        for stage, skip in ((self.up5, extra_skip_1), (self.up6, extra_skip_2)):
            out = stage(out, skip)
        return self.outc(out)

class UNet8(nn.Module):
    """U-Net with three extra decoder stages fed by upscaled copies of the first encoder map.

    Args:
        n_channels: number of input channels fed to ``down_conv``.
        n_classes: number of output channels produced by ``OutConv``.
        down_conv: conv block class used for ``inc`` and inside every ``Down``.
        up_conv: conv block class used inside every ``Up``.
        bilinear: forwarded to every ``Up``; when True most decoder widths are halved.
    """

    def __init__(self, n_channels, n_classes, down_conv, up_conv, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        factor = 2 if bilinear else 1

        # Registration order (inc, down*, up*, outc, up_s*) and names match the
        # original so checkpoints and seeded initialisation are unaffected.
        self.inc = down_conv(n_channels, 64)
        encoder_widths = [64, 128, 256, 512, 1024 // factor]
        for level, (c_in, c_out) in enumerate(zip(encoder_widths, encoder_widths[1:]), start=1):
            setattr(self, f"down{level}", Down(c_in, c_out, down_conv))

        decoder_specs = [
            (1024, 512 // factor),
            (512, 256 // factor),
            (256, 128 // factor),
            (128, 64 // factor),
            (64, 32 // factor),
            (32, 16 // factor),
            (16, 8),
        ]
        for level, (c_in, c_out) in enumerate(decoder_specs, start=1):
            setattr(self, f"up{level}", Up(c_in, c_out, up_conv, bilinear))
        self.outc = OutConv(8, n_classes)

        # NOTE(review): unlike UNet4, Upscaling is not given `up_conv` here, so it
        # falls back to whatever default Upscaling defines. Changing this alters the
        # architecture (and checkpoint compatibility) — confirm which is intended.
        for level, (c_in, c_out) in enumerate([(64, 32), (32, 16), (16, 8)], start=1):
            setattr(self, f"up_s{level}", Upscaling(c_in, c_out))

    def forward(self, x):
        # Encoder: keep every stage output; the deepest one is the bottleneck.
        skips = [self.inc(x)]
        for stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(stage(skips[-1]))
        first_features = skips[0]

        # Standard U-Net decoder, consuming skips deepest-first.
        out = skips.pop()
        for stage in (self.up1, self.up2, self.up3, self.up4):
            out = stage(out, skips.pop())

        # Chain the Upscaling modules on the first encoder map to build extra skips.
        extra_skips = []
        feature = first_features
        for upscale in (self.up_s1, self.up_s2, self.up_s3):
            feature = upscale(feature)
            extra_skips.append(feature)

        for stage, skip in zip((self.up5, self.up6, self.up7), extra_skips):
            out = stage(out, skip)
        return self.outc(out)

In [None]:
data_path = 'C:/onlab_git/Onlab/data_with_centers/'
train_percent = 0.59
test_percent = 0.215

# Preview how many files land in each split for the chosen percentages.
files = os.listdir(data_path)
n_files = len(files)
train_size = int(train_percent * n_files)
# test_size must always be defined: with test_percent == 0 the print below
# previously raised NameError.
test_size = int(test_percent * n_files) if test_percent > 0 else 0
val_size = n_files - train_size - test_size

print('Train size:', train_size)
print('Validation size:', val_size)
print('Test size:', test_size)


# train, val, test sizes: train %, test %
# 128, 16, 19: .79, .12
# 112, 24, 27: .69, .17
# 96, 32, 35: .59, .215     this is the best - wandb logs
# 80, 40, 43: .495 .265
# 64, 48, 51: .395 .315
# 48, 48, 67: .3 .415

In [None]:
def create_datasets_from_config(config):
    """Resolve dataset-creation settings from a config dict, applying defaults.

    Deliberately NOT named ``create_datasets``: redefining that name here
    shadowed the ``create_datasets`` imported from ``src.datasets``, so any
    later call using its positional signature failed with a TypeError.

    Args:
        config: dict with required keys ``'path'``, ``'train_percent'`` and
            ``'mask_type'``; optional keys ``'test_percent'``,
            ``'biosensor_length'``, ``'mask_size'``, ``'augment'``, ``'noise'``,
            ``'dilation'``, ``'input_scaling'`` and ``'upscale_mode'``.

    Returns:
        dict containing every setting, with defaults filled in for missing
        optional keys.

    Raises:
        KeyError: if a required key is missing.
    """
    resolved = {
        'path': config['path'],
        'train_percent': config['train_percent'],
        'mask_type': config['mask_type'],
    }
    defaults = {
        'test_percent': 0,
        'biosensor_length': 8,
        'mask_size': 80,
        'augment': False,
        'noise': 0.0,
        'dilation': 0,
        'input_scaling': False,
        'upscale_mode': 'nearest',
    }
    for key, default in defaults.items():
        resolved[key] = config.get(key, default)
    return resolved


# Usage
example_config = {
    'path': 'path/to/data',
    'train_percent': 0.8,
    'mask_type': bool,  # the type itself, as used elsewhere in this notebook
    'test_percent': 0.1,
    'biosensor_length': 128,
    'mask_size': 80,
    'augment': True,
    'noise': 0.1,
    'dilation': 0,
    'input_scaling': False,
    'upscale_mode': 'bilinear'
}

create_datasets_from_config(example_config)

In [3]:
# Reference signatures of the project helpers these settings are meant for:
#   create_datasets(path, train_percent, test_percent=0, mask_type=bool,
#       biosensor_length=8, mask_size=80, augment=False, noise=0.0, dilation=0,
#       input_scaling=False, upscale_mode='nearest', tiling=False, tiling_ratio=1)
#   BiosensorDataset(path, files, mean, std, mask_type=bool, biosensor_length=8,
#       mask_size=80, augment=False, noise=0.0, dilation=0, input_scaling=False,
#       upscale_mode='nearest', tiling=False, tiling_ratio=1)

# Settings shared by every dataset instance.
config = dict(
    path='C:/onlab_git/Onlab/data_with_centers/',
    mask_type=bool,
    augment=False,
    noise=0.0,
    dilation=0,
    tiling=False,
    tiling_ratio=1,
)

# Settings that affect derived input/mask sizes.
calc_config = dict(
    biosensor_length=8,
    mask_size=80,
    input_scaling=False,
    upscale_mode='nearest',
)

# Split percentages forwarded to create_datasets.
create_dataset_args = dict(train_percent=0.59, test_percent=0.215)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device {device}')

torch.manual_seed(42)
np.random.seed(42)

data_path = 'C:/onlab_git/Onlab/data_with_centers/'
checkpoint_dir = 'test_saves'
train_percent = 0.59
test_percent = 0.215
batch_size = 16
bio_len = 8
upscale_factor = 8
mask_size = 80 * upscale_factor
input_scaling = False
upscale_mode = 'nearest'  # alternatives: 'bilinear', 'bicubic'

# mask_type is passed by keyword. Passing `bool` positionally as the third argument
# collides with `test_percent` under the signature
# create_datasets(path, train_percent, test_percent=0, mask_type=bool, ...)
# noted in this notebook ("got multiple values for argument 'test_percent'").
train_dataset, val_dataset, test_dataset = create_datasets(
    data_path, train_percent, test_percent=test_percent, mask_type=bool,
    biosensor_length=bio_len, mask_size=mask_size, augment=True, noise=0.0,
    dilation=0, input_scaling=input_scaling, upscale_mode=upscale_mode,
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print('Train size:', len(train_dataset))
print('Validation size:', len(val_dataset))
print('Test size:', len(test_dataset))

In [None]:
"""
Models to test:
- UNet with convolution and bilinear upsampling
- SRUNet4 with convolution and bilinear upsampling
- SRUNet8 with convolution and bilinear upsampling
- UNet with bigger inputs
- mid channels
- Single, Double, Triple, Quadruple Conv layers

- layer, channel, kernel size changes
___________________________________________________

- UNet_80_double_conv
- UNet_80_double_bilin

- UNet_80_single_conv

- Unet_160_double_conv
- UNet_320_double_conv

- SRUNet4_80_single_conv
- SRUNet4_80_single_bilin

- SRUNet8_80_single_conv
- SRUNet8_80_single_bilin



"""

In [None]:
# Build the model under test. Previously every construction line was commented out,
# so `model.to(device)` raised NameError on a fresh kernel.
# Swap `conv_block` for SingleConv / DoubleConv, or UNet8 for UNet / UNet4, to try
# other variants, and keep `model_name` in sync with what is actually built.
conv_block = TripleConv
model = UNet8(n_channels=bio_len, n_classes=1, down_conv=conv_block, up_conv=conv_block, bilinear=False)
model = model.to(device)
print(model.__class__.__name__)
project_name = "Testing SR models"
# NOTE(review): the "_bilin" suffix does not match bilinear=False above (and the
# dataset upscale_mode is 'nearest') — confirm what the suffix is meant to record.
model_name = "SRUNet8_80_triple_bilin"

model_summary = summary(model)
print(model_summary.trainable_params)

In [None]:
# Hyper-parameters for this run, collected in one place.
training_options = dict(
    learning_rate=0.03,
    epochs=20,
    checkpoint_dir=checkpoint_dir,
    amp=True,
    wandb_logging=True,
)

try:
    train_model(model, project_name, model_name, device, train_loader, val_loader, **training_options)
except torch.cuda.OutOfMemoryError:
    # Release cached GPU blocks so the kernel stays usable (e.g. for a smaller batch).
    torch.cuda.empty_cache()
    print('Detected OutOfMemoryError!')

In [None]:
# Layer-by-layer summary of the current model, four levels deep.
# (`summary` comes from the torchinfo import in the first cell; re-importing it here
# hid that dependency and is unnecessary.)
model_summary = summary(model, depth=4)
print(model_summary)

In [15]:
def evaluate_after_training(model, val_loader, test_loader, device):
    """Evaluate ``model`` on the validation and test loaders, print and return the scores.

    Args:
        model: the trained model passed straight to ``evaluate``.
        val_loader: loader for the validation split.
        test_loader: loader for the test split.
        device: device forwarded to ``evaluate``.

    Returns:
        dict with keys ``'val_dice'``, ``'val_detection_rate'``, ``'test_dice'`` and
        ``'test_detection_rate'``, so results can be logged or compared instead
        of only being printed.
    """
    val_dice_score, val_detection_rate = evaluate(model, val_loader, device)
    dice_score, detection_rate = evaluate(model, test_loader, device)
    print(f'Validation dice score: {val_dice_score}, Detection rate: {val_detection_rate}')
    print(f'Test dice score: {dice_score}, Detection rate: {detection_rate}')
    return {
        'val_dice': val_dice_score,
        'val_detection_rate': val_detection_rate,
        'test_dice': dice_score,
        'test_detection_rate': detection_rate,
    }


In [None]:
# Load the checkpoint directly onto the active device, so checkpoints saved from a
# GPU run also load on a CPU-only machine.
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_epoch20.pth")
checkpoint = torch.load(checkpoint_path, map_location=device)
# The checkpoint stores the learning rate next to the weights; pop it so what
# remains is a plain state dict that load_state_dict accepts.
lr = checkpoint.pop('learning_rate')
model.load_state_dict(checkpoint)
model = model.to(device)

evaluate_after_training(model, val_loader, test_loader, device)

In [23]:
# Saving the best model for production
# model = UNet8(n_channels=8, n_classes=1)
# checkpoint = torch.load("checkpoints/checkpoint_8_4_85.pth")
# lr = checkpoint.pop('learning_rate')
# model.load_state_dict(checkpoint)

# torch.jit.script(model).save('saved_models/srunet8_len8_40dice.pth')

In [17]:
def plot_results(bio, mask, prediction, binary_prediction):
    """Show one sample as three side-by-side panels.

    Panels: (1) ``bio`` in grayscale with the ground-truth mask in red,
    (2) the raw ``prediction`` in grayscale with ``binary_prediction`` in blue,
    (3) mask (red), binary prediction (blue) and their overlap (green) drawn on
    top of each other, with a legend.

    Args:
        bio: 2-D array used as the background of panel 1.
        mask: 2-D label array; pixels equal to 1 are treated as foreground.
        prediction: 2-D array of raw model outputs shown in panel 2.
        binary_prediction: array with the same 2-D shape as ``mask``; pixels
            equal to 1 are treated as predicted foreground.
    """
    red, blue, green = [1, 0, 0, 1], [0, 0, 1, 1], [0, 1, 0, 1]
    canvas_shape = (mask.shape[0], mask.shape[1], 4)

    def rgba_overlay(where, color):
        # Fully transparent RGBA image with `color` painted on the selected pixels.
        overlay = np.zeros(canvas_shape, dtype=np.float32)
        overlay[where] = color
        return overlay

    in_mask = mask == 1
    in_prediction = binary_prediction == 1
    mask_overlay = rgba_overlay(in_mask, red)
    prediction_overlay = rgba_overlay(in_prediction, blue)
    overlap_overlay = rgba_overlay(in_mask & in_prediction, green)

    # (title, [(image, imshow keyword arguments), ...]) for each panel, left to right.
    panels = (
        ('Biosensor with mask',
         ((bio, {'cmap': 'gray'}), (mask_overlay, {'alpha': 0.6}))),
        ('Prediction with the binary',
         ((prediction, {'cmap': 'gray'}), (prediction_overlay, {'alpha': 0.6}))),
        ('Label and Prediction overlap',
         ((mask_overlay, {}), (prediction_overlay, {}), (overlap_overlay, {}))),
    )

    plt.figure(figsize=(30, 10))
    for position, (title, layers) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        for image, imshow_kwargs in layers:
            plt.imshow(image, **imshow_kwargs)
        plt.title(title)

    # The legend attaches to the last (overlap) panel and sits outside it on the right.
    legend_handles = [
        mpatches.Patch(color=red, label='Mask'),
        mpatches.Patch(color=blue, label='Prediction'),
        mpatches.Patch(color=green, label='Overlap'),
    ]
    plt.legend(handles=legend_handles, loc='upper right', bbox_to_anchor=(1.5, 1))
    plt.show()

In [None]:
# Visualise predictions for the first test batch only.
# eval() + no_grad(): the forward pass previously ran in training mode with autograd
# enabled, which wastes memory and lets train-mode layers (e.g. batch norm running
# statistics) change the model just by plotting.
model.eval()
first_batch = next(iter(test_loader), None)
if first_batch is None:
    print('Test loader is empty; nothing to plot.')
else:
    data, labels = first_batch
    with torch.no_grad():
        data = data.to(device)
        predictions = model(data)
        binary_predictions = (torch.sigmoid(predictions) > 0.5).cpu().numpy()

    inputs = data.cpu().numpy()
    labels = labels.cpu().numpy()
    predictions = predictions.cpu().numpy()

    for sample_idx in range(len(inputs)):
        # The last input channel is shown as the background image.
        plot_results(
            inputs[sample_idx][-1],
            np.squeeze(labels[sample_idx]),
            np.squeeze(predictions[sample_idx]),
            np.squeeze(binary_predictions[sample_idx]),
        )


In [None]:
def plot_loader_data(loader, title, max_batches=None):
    """Plot each sample from ``loader``: input, label, and label overlaid on the input.

    Args:
        loader: iterable of ``(data, labels)`` tensor batches, e.g. a DataLoader.
        title: prefix used in the panel titles, e.g. ``'Test'``.
        max_batches: stop after this many batches; ``None`` (default) plots the
            whole loader, as before.
    """
    # Running 1-based sample index. The previous `batch_idx * len(data) + i + 1`
    # produced wrong numbers for a final batch smaller than batch_size.
    sample_number = 0
    for batch_idx, (data, labels) in enumerate(loader):
        if max_batches is not None and batch_idx >= max_batches:
            break

        # Move the data and labels to the CPU
        data = data.cpu().numpy()
        labels = labels.cpu().numpy()

        for image_stack, label in zip(data, labels):
            sample_number += 1
            # The last input channel is displayed.
            image = image_stack[-1]
            # Drop singleton axes (as the prediction plot does) so imshow receives 2-D data;
            # a no-op when labels are already 2-D.
            label = np.squeeze(label)

            plt.figure(figsize=(20, 10))

            # Plot the input image
            plt.subplot(1, 3, 1)
            plt.imshow(image, cmap='gray')
            plt.title(f'{title} - Image {sample_number} ')

            # Plot the label
            plt.subplot(1, 3, 2)
            plt.imshow(label, cmap='gray')
            plt.title(f'{title} - Label {sample_number}')

            # Label overlaid on the input
            plt.subplot(1, 3, 3)
            plt.imshow(image, cmap='gray')
            plt.imshow(label, cmap='Reds', alpha=0.25)
            plt.title(f'{title} - Overlay {sample_number}')

            plt.show()

plot_loader_data(test_loader, 'Test')