In [None]:
import torch
from torch.utils.data import DataLoader
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from torchinfo import summary

from src.datasets import BiosensorDataset, create_datasets
from src.model_parts import *
from src.models import *
# from src.train_tiling import train_model, evaluate
from src.train import train_model
from src.evaluate import evaluate
from src.utils import *

In [None]:
# Run on the GPU when CUDA is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device {device}')

# Fix the NumPy and PyTorch global RNG seeds so repeated runs start from the same state.
np.random.seed(42)
torch.manual_seed(42)

# batch_size feeds the DataLoaders; upscale_factor scales the mask size in calc_config.
batch_size, upscale_factor = 16, 4

# Dataset options handed to create_datasets() together with the two dicts below.
config = {
    # Data root. Set the BIOSENSOR_DATA_DIR environment variable to point the
    # notebook at another location; the original hard-coded path stays the default.
    'path': os.environ.get('BIOSENSOR_DATA_DIR', 'C:/onlab_git/Onlab/data_with_centers/'),
    'mask_type': bool,
    'augment': False,
    'noise': 0.0,
    'dilation': 0,
    'tiling': True,
    # Also forwarded to train_model() as tile_ratio.
    'tiling_ratio': 2,
}

# Split fractions passed to create_datasets().
# NOTE(review): the remaining share presumably becomes the validation split — confirm in src.datasets.
create_dataset_args = dict(train_percent=0.59, test_percent=0.215)

# Derived settings passed to create_datasets(); the mask size grows with upscale_factor.
calc_config = dict(
    biosensor_length=8,
    mask_size=80 * upscale_factor,
    input_scaling=False,
    upscale_mode='nearest',
)

# Build the three dataset splits from the option dicts defined earlier in this cell.
train_dataset, val_dataset, test_dataset = create_datasets(config, create_dataset_args, calc_config)

# Pinned host memory only speeds up host-to-GPU copies. On a CPU-only machine it
# brings no benefit and makes DataLoader emit a warning, so enable it only for CUDA.
pin_memory = device.type == 'cuda'

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=pin_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print('Train size:', len(train_dataset))
print('Validation size:', len(val_dataset))
print('Test size:', len(test_dataset))

In [None]:
# Alternative architecture taking the same arguments:
# model = UNet(n_channels=calc_config['biosensor_length'], n_classes=1, down_conv=SingleConv, up_conv=SingleConv, bilinear=False)

# Instantiate the network with one input channel per biosensor frame and a
# single output class, then move it to the selected device.
model = SRUNet4(
    n_channels=calc_config['biosensor_length'],
    n_classes=1,
    down_conv=SingleConv,
    up_conv=SingleConv,
    bilinear=False,
).to(device)
print(type(model).__name__)

# Identifiers passed on to train_model().
project_name, model_name = "Testing tiling", ""

# Report how many parameters will be optimised.
model_summary = summary(model)
print(model_summary.trainable_params)

In [None]:
# Keyword options for this training run, gathered in one place.
train_options = dict(
    learning_rate=0.03,
    epochs=5,
    checkpoint_dir='test_saves',
    amp=True,
    wandb_logging=False,
    tile_ratio=config['tiling_ratio'],
)

try:
    train_model(model, project_name, model_name, device, train_loader, val_loader, **train_options)
except torch.cuda.OutOfMemoryError:
    # Best effort: release cached GPU blocks and report instead of aborting the notebook.
    torch.cuda.empty_cache()
    print('Detected OutOfMemoryError!')

In [None]:
# Detailed layer-by-layer summary, descending four levels into nested modules.
# `summary` comes from the import cell at the top of the notebook, so it is not
# re-imported here.
model_summary = summary(model, depth=4)
print(model_summary)

In [None]:
# Restore saved weights and evaluate them on the validation and test loaders.
# NOTE(review): the training cell above runs only 5 epochs, so an epoch-14
# checkpoint presumably comes from an earlier, longer run — confirm the file exists.
checkpoint_path = "test_saves/checkpoint_epoch14.pth"

# map_location lets a checkpoint written on a GPU load on a CPU-only machine
# (and the other way round) instead of failing during deserialisation.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)

# The stored learning rate is not part of the state dict, so take it out before
# loading. The default keeps checkpoints saved without that key loadable.
lr = checkpoint.pop('learning_rate', None)

# Load the remaining entries as the model's state dictionary.
model.load_state_dict(checkpoint)
model = model.to(device)

evaluate_after_training(model, val_loader, test_loader, device)

In [23]:
# Saving the best model for production
# model = UNet8(n_channels=8, n_classes=1)
# checkpoint = torch.load("checkpoints/checkpoint_8_4_85.pth")
# lr = checkpoint.pop('learning_rate')
# model.load_state_dict(checkpoint)

# torch.jit.script(model).save('saved_models/srunet8_len8_40dice.pth')

In [None]:
# Qualitative check: run the model on the first test batch and plot, for every
# sample, the merged input, label, raw prediction and thresholded prediction.

# Inference only: switch layers such as dropout/batch-norm to evaluation
# behaviour and skip building the autograd graph.
model.eval()
with torch.no_grad():
    for data, labels in test_loader:
        data = data.to(device)
        labels = labels.to(device)

        # Use a dedicated name so the global `batch_size` used by the loaders is
        # not overwritten by the size of this particular batch.
        n_samples, tile_num, channels, height, width = data.shape

        # Fold the tile axis into the batch axis so the model sees ordinary 4-D input.
        flat_tiles = data.view(n_samples * tile_num, channels, height, width)
        predictions = model(flat_tiles)

        # Unfold the tile axis again. The channel/spatial sizes are taken from
        # the model output rather than the input, because the output resolution
        # may differ from the input (mask_size is 80 * upscale_factor).
        predictions = predictions.view(n_samples, tile_num, *predictions.shape[1:])

        # Logits -> probabilities -> boolean mask at the 0.5 threshold.
        binary_predictions = torch.sigmoid(predictions) > 0.5

        for i in range(n_samples):
            # Drop singleton axes, then stitch the tiles of each sample together.
            label = merge_tiles(labels[i].squeeze())
            binary_prediction = merge_tiles(binary_predictions[i].squeeze())
            # Last input channel of every tile, used as the background image.
            last_frame = merge_tiles(data[i, :, -1])
            prediction = merge_tiles(predictions[i, :, 0])

            plot_results(last_frame, label, prediction, binary_prediction)

        # Only the first batch is visualised.
        break


In [None]:
# Visualise tiled samples from the test loader using a helper pulled in by the
# star imports at the top of the notebook.
# NOTE(review): 'Test' is presumably used as a plot label — confirm in the helper.
plot_loader_tiles_data(test_loader, 'Test')