In [1]:
import torch
import pytorch_lightning as pl
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import os
import matplotlib.pyplot as plt

from src.datasets import BiosensorDataset, calculate_mean_and_std
from src.unet import UNet

from src.train_lightning import UNetLightningModule, BiosensorDataModule
from src.losses import DiceLoss, IoULoss

from argparse import Namespace
from pathlib import Path
import torch
from pytorch_lightning import Trainer
from torch.nn import BCEWithLogitsLoss

In [2]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Report the choice so the notebook output records which device was used.
print(f'Using device {device}')

Using device cuda


In [None]:
# Seed PyTorch's global RNG so the random train/validation split below is repeatable.
torch.manual_seed(42)

# Experiment configuration.
data_path = 'data_with_centers/'
train_percent = 0.86  # fraction of the files assigned to the training split
bio_len = 16          # forwarded as `biosensor_length` to the dataset helpers
mask_size = 80        # forwarded as `mask_size` to the datasets
batch_size = 4

# os.listdir returns entries in arbitrary order; sorting makes the seeded split
# select the same files on every run and on every machine.
files = sorted(os.listdir(data_path))
train_size = int(train_percent * len(files))
val_size = len(files) - train_size
train_files, val_files = torch.utils.data.random_split(files, [train_size, val_size])

# Normalisation statistics are computed from the training files only.
mean, std = calculate_mean_and_std(data_path, train_files, biosensor_length=bio_len)

# NOTE(review): the positional `bool` argument is interpreted by BiosensorDataset
# (presumably the mask dtype) — confirm against src.datasets.
train_dataset = BiosensorDataset(data_path, train_files, mean, std, bool, biosensor_length=bio_len, mask_size=mask_size)
# The validation set is built from the held-out files, not the training files,
# so validation metrics are measured on data the model has not been fitted to.
val_dataset = BiosensorDataset(data_path, val_files, mean, std, bool, biosensor_length=bio_len, mask_size=mask_size)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation gains nothing from shuffling; a fixed order keeps evaluation deterministic.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Training hyperparameters, gathered in one namespace so they can be tuned in one place.
args = Namespace(
    lr=0.001,
    epochs=10,
    batch_size=batch_size,
    amp=False,       # when True the Trainer below runs with 16-bit precision
    bilinear=False,  # forwarded to UNetLightningModule
    data_path=data_path,
)

# Loss function handed to the Lightning module.
criterion = DiceLoss()

# The model takes `bio_len` input channels and predicts a single class.
model = UNetLightningModule(
    learning_rate=args.lr,
    channels=bio_len,
    classes=1,
    loss_func=criterion,
    amp=args.amp,
    bilinear=args.bilinear,
)

# Use the notebook's own `bio_len` / `mask_size` settings here: the names
# BIO_LENGTH / MASK_SIZE are not defined anywhere in this notebook and would
# raise NameError on a fresh kernel.
data_module = BiosensorDataModule(
    data_path=args.data_path,
    batch_size=args.batch_size,
    biosensor_length=bio_len,
    mask_size=mask_size,
)

# Run on the GPU when CUDA is available; precision follows the `amp` flag.
trainer = Trainer(
    max_epochs=args.epochs,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    precision=16 if args.amp else 32,
)

# Train the model using the data supplied by the data module.
trainer.fit(model, data_module)