In [1]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import numpy as np
import os
import matplotlib.pyplot as plt

from src.datasets import BiosensorDataset
from src.unet import UNet
from src.train import train_model

In [11]:
# Run-wide settings.
#   BIO_LENGTH - forwarded as `biosensor_length` to the dataset and as
#                `n_channels` to the U-Net in the cells below.
#   BATCH_SIZE - batch size shared by the training and validation loaders.
BIO_LENGTH, BATCH_SIZE = 32, 8

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device_name = {True: 'cuda', False: 'cpu'}[torch.cuda.is_available()]
device = torch.device(device_name)
print(f'Using device {device}')

Using device cuda


In [12]:
# Load the biosensor dataset from disk and split it into training and
# validation subsets, each wrapped in its own DataLoader.
data_path = 'data_with_centers/'
dataset = BiosensorDataset(data_path, mask_type=bool, biosensor_length=BIO_LENGTH, mask_size=80)

# 86 % of the samples are used for training; whatever remains after the integer
# truncation is held out for validation, so the two sizes always sum to len(dataset).
TRAIN_SIZE = int(len(dataset)*0.86)
VAL_SIZE = len(dataset) - TRAIN_SIZE

# A dedicated, explicitly seeded generator makes the train/validation membership
# identical after every kernel restart, so Dice scores from different runs are
# measured on the same hold-out samples. It is local to the split and does not
# touch the global RNG that drives training-time shuffling.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_data, val_data = torch.utils.data.random_split(
    dataset, [TRAIN_SIZE, VAL_SIZE], generator=split_generator
)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
# Validation order has no influence on learning; a fixed order keeps the batch
# composition the same on every evaluation pass.
# NOTE(review): if the Dice score is averaged per batch, this also removes one
# source of jitter between consecutive validation runs - confirm in src.train.
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)

In [13]:
# Instantiate the U-Net with BIO_LENGTH input channels and a single output
# class, and move it to the selected device in the same expression.
model = UNet(n_classes=1, n_channels=BIO_LENGTH).to(device)

In [14]:
# Train the model; a CUDA out-of-memory failure is reported instead of raised
# so the notebook can keep running.
#
# The recovery step is deliberately placed AFTER the try/except: while the
# `except` clause is executing, the live exception still references the stack
# frames of the failed call, which keeps their tensors alive, so emptying the
# cache inside the handler cannot return that memory to the driver.
oom_detected = False
try:
    train_model(
        model,
        device,
        train_loader,
        val_loader,
        learning_rate=0.01,
        epochs=10,
        amp=True,
    )
except torch.cuda.OutOfMemoryError:
    # Only record the failure here; clean up once the exception is gone.
    oom_detected = True

if oom_detected:
    # The exception has been cleared, so cached blocks that are no longer
    # referenced can now be released.
    torch.cuda.empty_cache()
    print('Detected OutOfMemoryError!')

Starting training:
        Epochs:          10
        Batch size:      8
        Learning rate:   0.01
        Training size:   140
        Validation size: 23
        Device:          cuda
        Mixed Precision: True
    


                                                                  (batch)=1.59]

Validation Dice score: 8.123965500317354e-09


                                                                   (batch)=1.23]

Validation Dice score: 0.05712094530463219


                                                                   (batch)=1.13]

Validation Dice score: 0.05712568014860153


                                                                   (batch)=1.42]

Validation Dice score: 8.02481281425571e-09


                                                                   (batch)=1.05]

Validation Dice score: 8.18197065655113e-09


                                                                   (batch)=1.07]

Validation Dice score: 0.07384420931339264


                                                                   (batch)=1.05]

Validation Dice score: 0.05178339406847954


                                                                   (batch)=1.19]

Validation Dice score: 7.994119144427714e-09


                                                                   (batch)=1.05]

Validation Dice score: 0.17355875670909882


                                                                   (batch)=0.961]

Validation Dice score: 0.0005081382114440203


                                                                   (batch)=1.05] 

Validation Dice score: 0.2035689800977707


                                                                   (batch)=0.901]

Validation Dice score: 0.03852665051817894


                                                                  s (batch)=1.16] 

Validation Dice score: 0.2229548692703247


                                                                  s (batch)=0.892]

Validation Dice score: 0.2803993225097656


                                                                  s (batch)=0.997]

Validation Dice score: 0.21752914786338806


                                                                  s (batch)=0.987]

Validation Dice score: 0.24236151576042175


                                                                  s (batch)=0.972]

Validation Dice score: 0.34233757853507996


Epoch 1/10: 144img [00:55,  2.59img/s, loss (batch)=0.8]                          


Validation Dice score: 0.3241266906261444
Checkpoint 1 saved!


                                                                  (batch)=0.888]

Validation Dice score: 0.34390920400619507


                                                                   (batch)=0.777]

Validation Dice score: 0.378205269575119


                                                                   (batch)=0.881]

Validation Dice score: 0.34273749589920044


                                                                   (batch)=0.773]

Validation Dice score: 0.3208479881286621


                                                                   (batch)=0.887]

Validation Dice score: 0.23106259107589722


                                                                   (batch)=1.02] 

Validation Dice score: 0.38896387815475464


                                                                   (batch)=0.784]

Validation Dice score: 0.3262336254119873


                                                                   (batch)=0.833]

Validation Dice score: 0.32033416628837585


                                                                   (batch)=0.865]

Validation Dice score: 0.3415219783782959


                                                                   (batch)=0.796]

Validation Dice score: 0.40790340304374695


                                                                   (batch)=0.875]

Validation Dice score: 0.4032191038131714


                                                                   (batch)=0.718]

Validation Dice score: 0.3259432017803192


                                                                  s (batch)=0.822]

Validation Dice score: 0.39718371629714966


                                                                  s (batch)=0.86] 

Validation Dice score: 0.21948055922985077


                                                                  s (batch)=1.03]

Validation Dice score: 0.4066433906555176


                                                                  s (batch)=0.773]

Validation Dice score: 0.3399045169353485


                                                                  s (batch)=0.848]

Validation Dice score: 0.36841917037963867


Epoch 2/10: 144img [00:51,  2.80img/s, loss (batch)=0.629]                        


Validation Dice score: 0.3775270879268646
Checkpoint 2 saved!


                                                                  (batch)=0.734]

Validation Dice score: 0.3946830630302429


                                                                   (batch)=0.716]

Validation Dice score: 0.4013294577598572


                                                                   (batch)=0.684]

Validation Dice score: 0.40349480509757996


                                                                   (batch)=0.746]

Validation Dice score: 0.41671907901763916


                                                                   (batch)=0.737]

Validation Dice score: 0.4315182864665985


                                                                   (batch)=0.676]

Validation Dice score: 0.43066680431365967


                                                                   (batch)=1.06] 

Validation Dice score: 0.4236956536769867


                                                                   (batch)=0.814]

Validation Dice score: 0.4227948784828186


                                                                   (batch)=0.81] 

Validation Dice score: 0.43142926692962646


                                                                   (batch)=0.844]

Validation Dice score: 0.4311390817165375


                                                                   (batch)=0.712]

Validation Dice score: 0.43057164549827576


                                                                   (batch)=0.789]

Validation Dice score: 0.43164190649986267


                                                                  s (batch)=0.765]

Validation Dice score: 0.43191489577293396


                                                                  s (batch)=0.731]

Validation Dice score: 0.43341195583343506


                                                                  s (batch)=0.802]

Validation Dice score: 0.4330793619155884


                                                                  s (batch)=0.739]

Validation Dice score: 0.43101081252098083


                                                                  s (batch)=0.676]

Validation Dice score: 0.43126147985458374


Epoch 3/10: 144img [00:51,  2.81img/s, loss (batch)=0.908]                        


Validation Dice score: 0.43379074335098267
Checkpoint 3 saved!


                                                                  (batch)=0.822]

Validation Dice score: 0.4321971535682678


                                                                   (batch)=0.806]

Validation Dice score: 0.4309436082839966


                                                                   (batch)=0.727]

Validation Dice score: 0.4363459050655365


                                                                   (batch)=0.775]

Validation Dice score: 0.43713295459747314


                                                                   (batch)=0.734]

Validation Dice score: 0.4283600449562073


                                                                   (batch)=0.788]

Validation Dice score: 0.434039443731308


                                                                   (batch)=0.686]

Validation Dice score: 0.4305851459503174


                                                                   (batch)=0.765]

Validation Dice score: 0.43289148807525635


                                                                   (batch)=0.907]

Validation Dice score: 0.43561476469039917


                                                                   (batch)=0.76] 

Validation Dice score: 0.4327438473701477


                                                                   (batch)=0.902]

Validation Dice score: 0.43455106019973755


                                                                   (batch)=0.663]

Validation Dice score: 0.4317592978477478


                                                                  s (batch)=0.843]

Validation Dice score: 0.4345693290233612


                                                                  s (batch)=0.772]

Validation Dice score: 0.43603453040122986


                                                                  s (batch)=0.681]

Validation Dice score: 0.4367217421531677


                                                                  s (batch)=0.71] 

Validation Dice score: 0.4378858208656311


                                                                  s (batch)=0.716]

Validation Dice score: 0.43534159660339355


Epoch 4/10: 144img [00:52,  2.75img/s, loss (batch)=0.688]                        


Validation Dice score: 0.4399131238460541
Checkpoint 4 saved!


                                                                  (batch)=0.828]

Validation Dice score: 0.43877023458480835


                                                                   (batch)=0.804]

Validation Dice score: 0.42833805084228516


                                                                   (batch)=0.767]

Validation Dice score: 0.43551039695739746


                                                                   (batch)=0.681]

Validation Dice score: 0.4349666237831116


                                                                   (batch)=0.801]

Validation Dice score: 0.43133431673049927


                                                                   (batch)=0.811]

Validation Dice score: 0.43529462814331055


                                                                   (batch)=0.646]

Validation Dice score: 0.43482136726379395


                                                                   (batch)=0.716]

Validation Dice score: 0.4332927167415619


                                                                   (batch)=0.676]

Validation Dice score: 0.4329708516597748


                                                                   (batch)=0.655]

Validation Dice score: 0.4357469975948334


                                                                   (batch)=0.758]

Validation Dice score: 0.43546104431152344


                                                                   (batch)=0.794]

Validation Dice score: 0.4361634850502014


                                                                  s (batch)=0.839]

Validation Dice score: 0.4371063709259033


Epoch 5/10:  80%|████████  | 112/140 [00:39<00:09,  2.80img/s, loss (batch)=0.692]


KeyboardInterrupt: 