In [2]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import os

from src.datasets import BiosensorDataset, calculate_mean_and_std
from src.unet import UNet
from src.train import train_model

In [3]:
# Pick the compute device: use the GPU whenever PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device {device}')

Using device cuda


In [5]:
# --- Reproducibility -------------------------------------------------------
# Seed torch (used by random_split and DataLoader shuffling) and NumPy's global
# RNG (in case the dataset helpers draw from it).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# --- Configuration ---------------------------------------------------------
data_path = 'data_with_centers/'
train_percent = 0.86   # fraction of files used for training; the rest is validation
bio_len = 16           # passed as biosensor_length and as the UNet's input channel count
mask_size = 80         # forwarded to BiosensorDataset as mask_size
batch_size = 4

# --- Train / validation split ---------------------------------------------
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort it so the
# seeded random_split yields the same train/val files on every machine.
files = sorted(os.listdir(data_path))
if not files:
    raise FileNotFoundError(f'No files found in {data_path!r}')
train_size = int(train_percent * len(files))
val_size = len(files) - train_size
train_files, val_files = torch.utils.data.random_split(files, [train_size, val_size])

# Normalisation statistics come from the training split only, so no validation
# information leaks into preprocessing.
mean, std = calculate_mean_and_std(data_path, train_files, biosensor_length=bio_len)

# NOTE(review): the 5th positional argument is the builtin `bool` — presumably a mask
# dtype/target type; confirm against BiosensorDataset's signature.
train_dataset = BiosensorDataset(data_path, train_files, mean, std, bool, biosensor_length=bio_len, mask_size=mask_size)
val_dataset = BiosensorDataset(data_path, val_files, mean, std, bool, biosensor_length=bio_len, mask_size=mask_size)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# --- Model -----------------------------------------------------------------
model = UNet(n_channels=bio_len, n_classes=1)
model = model.to(device)

# --- Training --------------------------------------------------------------
# NOTE(review): `batch_size` is not forwarded to train_model; the "Batch size" it logs
# presumably comes from its own default — confirm it matches the loaders above.
try:
    train_model(
        model,
        device,
        train_loader,
        val_loader,
        learning_rate=0.01,
        epochs=10,
        # Mixed precision only makes sense on the GPU; disable it on the CPU fallback.
        amp=device.type == 'cuda',
    )
except torch.cuda.OutOfMemoryError:
    # Release cached GPU memory so the kernel stays usable after an OOM.
    torch.cuda.empty_cache()
    print('Detected OutOfMemoryError!')

Starting training:
        Epochs:          10
        Batch size:      4
        Learning rate:   0.01
        Training size:   140
        Validation size: 23
        Device:          cuda
        Mixed Precision: True
    


                                                                   (batch)=1.05]

Validation Dice score: 9.142484103108472e-09


                                                                   (batch)=1.19]

Validation Dice score: 1.0246270498726062e-08


                                                                   (batch)=0.967]

Validation Dice score: 0.18845197558403015


                                                                   (batch)=1.07] 

Validation Dice score: 0.002005594316869974


                                                                   (batch)=1]   

Validation Dice score: 0.0020711319521069527


                                                                   (batch)=0.965]

Validation Dice score: 0.11602358520030975


                                                                   (batch)=0.932]

Validation Dice score: 0.24308131635189056


                                                                   (batch)=0.991]

Validation Dice score: 0.3540637493133545


                                                                  s (batch)=0.821]

Validation Dice score: 0.32293790578842163


                                                                  s (batch)=0.83] 

Validation Dice score: 0.3662812113761902


                                                                  s (batch)=0.955]

Validation Dice score: 0.3731057047843933


Epoch 1/10: 100%|██████████| 140/140 [00:42<00:00,  3.31img/s, loss (batch)=1.05] 


Checkpoint 1 saved!


                                                                   (batch)=0.859]

Validation Dice score: 0.34857892990112305


                                                                   (batch)=0.658]

Validation Dice score: 0.42682522535324097


                                                                   (batch)=0.779]

Validation Dice score: 0.4110015034675598


                                                                   (batch)=0.751]

Validation Dice score: 0.4149702489376068


                                                                   (batch)=0.624]

Validation Dice score: 0.41655421257019043


                                                                   (batch)=0.771]

Validation Dice score: 0.41103124618530273


                                                                   (batch)=0.792]

Validation Dice score: 0.41430458426475525


                                                                   (batch)=1.03] 

Validation Dice score: 0.42996153235435486


                                                                  s (batch)=0.765]

Validation Dice score: 0.4012256860733032


                                                                  s (batch)=0.66] 

Validation Dice score: 0.4377637803554535


                                                                  s (batch)=0.852]

Validation Dice score: 0.4245809018611908


Epoch 2/10: 100%|██████████| 140/140 [00:40<00:00,  3.45img/s, loss (batch)=0.708]


Checkpoint 2 saved!


                                                                   (batch)=0.997]

Validation Dice score: 0.4205368459224701


                                                                   (batch)=0.989]

Validation Dice score: 0.4339176416397095


                                                                   (batch)=0.635]

Validation Dice score: 0.4327954053878784


                                                                   (batch)=0.608]

Validation Dice score: 0.43062031269073486


                                                                   (batch)=0.58] 

Validation Dice score: 0.4447163939476013


                                                                   (batch)=0.758]

Validation Dice score: 0.44358283281326294


                                                                   (batch)=1.03] 

Validation Dice score: 0.4198544919490814


                                                                   (batch)=0.725]

Validation Dice score: 0.4284883439540863


                                                                  s (batch)=0.613]

Validation Dice score: 0.44039666652679443


                                                                  s (batch)=0.718]

Validation Dice score: 0.4232310652732849


                                                                  s (batch)=0.787]

Validation Dice score: 0.44028162956237793


Epoch 3/10: 100%|██████████| 140/140 [00:40<00:00,  3.48img/s, loss (batch)=0.771]


Checkpoint 3 saved!


                                                                   (batch)=0.77] 

Validation Dice score: 0.4468948543071747


                                                                   (batch)=0.778]

Validation Dice score: 0.44872960448265076


                                                                   (batch)=0.654]

Validation Dice score: 0.4211677014827728


                                                                   (batch)=0.628]

Validation Dice score: 0.43103498220443726


                                                                   (batch)=0.609]

Validation Dice score: 0.43762850761413574


                                                                   (batch)=0.762]

Validation Dice score: 0.44370847940444946


                                                                   (batch)=0.66] 

Validation Dice score: 0.4475911259651184


                                                                   (batch)=0.56] 

Validation Dice score: 0.4507848024368286


                                                                  s (batch)=0.697]

Validation Dice score: 0.45025938749313354


                                                                  s (batch)=0.591]

Validation Dice score: 0.4484519362449646


                                                                  s (batch)=0.756]

Validation Dice score: 0.44979482889175415


Epoch 4/10: 100%|██████████| 140/140 [00:41<00:00,  3.39img/s, loss (batch)=0.925]


Checkpoint 4 saved!


                                                                   (batch)=1.03] 

Validation Dice score: 0.4279215335845947


                                                                   (batch)=0.671]

Validation Dice score: 0.4307696521282196


                                                                   (batch)=0.937]

Validation Dice score: 0.44016849994659424


                                                                   (batch)=0.669]

Validation Dice score: 0.44376787543296814


                                                                   (batch)=0.688]

Validation Dice score: 0.4523976445198059


                                                                   (batch)=0.651]

Validation Dice score: 0.4555555284023285


                                                                   (batch)=0.651]

Validation Dice score: 0.45276281237602234


                                                                   (batch)=0.7]  

Validation Dice score: 0.455926775932312


                                                                  s (batch)=0.742]

Validation Dice score: 0.45220503211021423


                                                                  s (batch)=0.599]

Validation Dice score: 0.45264944434165955


                                                                  s (batch)=0.636]

Validation Dice score: 0.4529517590999603


Epoch 5/10: 100%|██████████| 140/140 [00:41<00:00,  3.36img/s, loss (batch)=0.759]


Checkpoint 5 saved!


                                                                   (batch)=0.621]

Validation Dice score: 0.45137202739715576


                                                                   (batch)=0.703]

Validation Dice score: 0.4586901366710663


                                                                   (batch)=0.639]

Validation Dice score: 0.45658180117607117


                                                                   (batch)=0.601]

Validation Dice score: 0.45842108130455017


                                                                   (batch)=0.635]

Validation Dice score: 0.45950207114219666


                                                                   (batch)=0.589]

Validation Dice score: 0.45477670431137085


                                                                   (batch)=0.787]

Validation Dice score: 0.45118236541748047


                                                                   (batch)=0.762]

Validation Dice score: 0.4386764466762543


                                                                  s (batch)=0.901]

Validation Dice score: 0.44612324237823486


                                                                  s (batch)=0.568]

Validation Dice score: 0.4483659267425537


                                                                  s (batch)=0.754]

Validation Dice score: 0.4547818899154663


Epoch 6/10: 100%|██████████| 140/140 [00:40<00:00,  3.50img/s, loss (batch)=0.629]


Checkpoint 6 saved!


                                                                   (batch)=0.638]

Validation Dice score: 0.4539562165737152


                                                                   (batch)=0.651]

Validation Dice score: 0.44866400957107544


                                                                   (batch)=0.781]

Validation Dice score: 0.4544830322265625


                                                                   (batch)=0.697]

Validation Dice score: 0.45624494552612305


                                                                   (batch)=0.58] 

Validation Dice score: 0.433244526386261


                                                                   (batch)=0.582]

Validation Dice score: 0.439034104347229


                                                                   (batch)=0.686]

Validation Dice score: 0.44757866859436035


                                                                   (batch)=0.894]

Validation Dice score: 0.45342814922332764


                                                                  s (batch)=0.94] 

Validation Dice score: 0.4585075378417969


                                                                  s (batch)=0.718]

Validation Dice score: 0.45952218770980835


                                                                  s (batch)=0.603]

Validation Dice score: 0.45791196823120117


Epoch 7/10: 100%|██████████| 140/140 [00:40<00:00,  3.46img/s, loss (batch)=0.658]


Checkpoint 7 saved!


                                                                   (batch)=0.621]

Validation Dice score: 0.45665308833122253


                                                                   (batch)=0.641]

Validation Dice score: 0.45155948400497437


                                                                   (batch)=0.793]

Validation Dice score: 0.4566790759563446


                                                                   (batch)=0.702]

Validation Dice score: 0.44223928451538086


                                                                   (batch)=0.806]

Validation Dice score: 0.45017069578170776


                                                                   (batch)=0.62] 

Validation Dice score: 0.451468288898468


                                                                   (batch)=0.598]

Validation Dice score: 0.45633286237716675


                                                                   (batch)=0.67] 

Validation Dice score: 0.4553402066230774


                                                                  s (batch)=0.671]

Validation Dice score: 0.4539473056793213


                                                                  s (batch)=0.6]  

Validation Dice score: 0.4549179971218109


                                                                  s (batch)=0.9]  

Validation Dice score: 0.4549688696861267


Epoch 8/10: 100%|██████████| 140/140 [00:42<00:00,  3.31img/s, loss (batch)=0.886]


Checkpoint 8 saved!


                                                                   (batch)=0.905]

Validation Dice score: 0.45732587575912476


                                                                   (batch)=0.668]

Validation Dice score: 0.4545321464538574


                                                                   (batch)=0.672]

Validation Dice score: 0.4555615782737732


                                                                   (batch)=0.798]

Validation Dice score: 0.45393046736717224


                                                                   (batch)=0.615]

Validation Dice score: 0.457345187664032


                                                                   (batch)=0.858]

Validation Dice score: 0.453481525182724


                                                                   (batch)=0.667]

Validation Dice score: 0.433146208524704


                                                                   (batch)=0.648]

Validation Dice score: 0.44496411085128784


                                                                  s (batch)=0.884]

Validation Dice score: 0.44374558329582214


                                                                  s (batch)=0.611]

Validation Dice score: 0.45411843061447144


                                                                  s (batch)=0.732]

Validation Dice score: 0.45759958028793335


Epoch 9/10: 100%|██████████| 140/140 [00:39<00:00,  3.53img/s, loss (batch)=0.586]


Checkpoint 9 saved!


                                                                  s (batch)=0.846]

Validation Dice score: 0.45537418127059937


                                                                  s (batch)=0.686]

Validation Dice score: 0.4561426043510437


                                                                  s (batch)=0.72] 

Validation Dice score: 0.45842301845550537


                                                                  s (batch)=0.931]

Validation Dice score: 0.43563181161880493


                                                                  s (batch)=0.973]

Validation Dice score: 0.4397374391555786


                                                                  s (batch)=0.56] 

Validation Dice score: 0.4487903118133545


                                                                  s (batch)=1.04] 

Validation Dice score: 0.45190393924713135


                                                                  s (batch)=0.675]

Validation Dice score: 0.4537050724029541


                                                                  ss (batch)=0.647]

Validation Dice score: 0.45015329122543335


                                                                  ss (batch)=0.68] 

Validation Dice score: 0.45625266432762146


                                                                  ss (batch)=0.593]

Validation Dice score: 0.44918954372406006


Epoch 10/10: 100%|██████████| 140/140 [00:40<00:00,  3.50img/s, loss (batch)=0.777]

Checkpoint 10 saved!



