In [1]:
import torch
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torch.optim import RMSprop
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import GradScaler
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import Dataset, DataLoader, random_split

import numpy as np
import os
import matplotlib.pyplot as plt

from src.datasets import BiosensorDataset, calculate_mean_and_std
from src.unet import UNet


In [None]:
class DiceLoss(torch.nn.Module):
    """Soft Dice loss, ``1 - (2*sum(pred*target) + smooth) / (sum(pred) + sum(target) + smooth)``.

    All sums run over every element of the inputs, so one scalar is produced for
    the whole batch. ``pred`` is multiplied and summed directly, so it is intended
    to hold values in [0, 1] (e.g. sigmoid probabilities), not raw logits.
    ``smooth`` keeps the ratio defined when both inputs are all zeros.
    """

    def forward(self, pred, target, smooth = 1.):
        overlap = torch.sum(pred * target)
        denominator = torch.sum(pred) + torch.sum(target) + smooth
        return 1 - (2. * overlap + smooth) / denominator

class IoULoss(torch.nn.Module):
    """Soft Jaccard (IoU) loss, ``1 - (|P∩T| + smooth) / (|P∪T| + smooth)``.

    The intersection is ``sum(pred * target)`` and the union is
    ``sum(pred + target) - intersection``. Both are summed over all elements, so the
    whole batch gives a single scalar. ``pred`` is intended to hold values in [0, 1]
    (probabilities, not logits).
    """

    def forward(self, pred, target, smooth = 1.):
        overlap = torch.sum(pred * target)
        combined = torch.sum(pred + target) - overlap
        return 1 - (overlap + smooth) / (combined + smooth)

In [None]:
class UNetLightningModule(LightningModule):
    """LightningModule that trains a ``UNet`` with a pluggable loss.

    Args:
        learning_rate: Initial learning rate for RMSprop.
        channels: Forwarded to ``UNet(n_channels=...)``.
        classes: Forwarded to ``UNet(n_classes=...)``.
        loss_func: Loss module (or any callable with the same signature) called as
            ``loss_func(logits, target)``. It receives the raw UNet output, so a loss
            that expects probabilities must apply the sigmoid itself.
        amp: Stored only; mixed precision is actually controlled by the Trainer's
            ``precision`` argument.
        bilinear: Forwarded to ``UNet(bilinear=...)``.
    """

    def __init__(self, learning_rate: float, channels: int, classes: int, loss_func: torch.nn.Module, amp: bool = False, bilinear: bool = False):
        super().__init__()
        self.model = UNet(n_channels=channels, n_classes=classes, bilinear=bilinear)
        self.learning_rate = learning_rate
        self.amp = amp
        self.criterion = loss_func

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        """Run the model on one ``(images, masks)`` batch and return the loss."""
        images, true_masks = batch
        masks_pred_logits = self(images)
        # Insert a channel axis so the target lines up with the logits
        # (assumes masks arrive as (B, H, W) and logits as (B, 1, H, W) — TODO confirm),
        # and cast to float: bool masks are rejected by losses such as BCEWithLogitsLoss.
        target = true_masks.unsqueeze(1).float()
        return self.criterion(masks_pred_logits, target)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        optimizer = RMSprop(self.parameters(), lr=self.learning_rate)
        # mode='min': the monitored metric is a loss, so a plateau means it has stopped
        # *decreasing*. With 'max' the LR would be cut while the loss is still improving.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }

In [None]:
# Configuration, reproducible train/validation split, datasets and loaders.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device {device}')

torch.manual_seed(42)

data_path = 'data_with_centers/'
train_percent = 0.86
bio_len = 16
mask_size = 80
batch_size = 4

# os.listdir order is filesystem-dependent; sort it so the seeded split below
# assigns the same files to train/val on every machine.
files = sorted(os.listdir(data_path))
train_size = int(train_percent * len(files))
val_size = len(files) - train_size
assert train_size > 0 and val_size > 0, f"need files for both splits, found {len(files)} in {data_path}"

train_subset, val_subset = torch.utils.data.random_split(files, [train_size, val_size])
# Materialize the Subset objects as plain lists of file names.
train_files = [files[i] for i in train_subset.indices]
val_files = [files[i] for i in val_subset.indices]

# Normalisation statistics come from the training files only, so no validation data leaks in.
mean, std = calculate_mean_and_std(data_path, train_files, biosensor_length=bio_len)

# NOTE(review): the positional `bool` argument is passed through unchanged; presumably
# the mask dtype — confirm against BiosensorDataset's signature.
train_dataset = BiosensorDataset(data_path, train_files, mean, std, bool, biosensor_length=bio_len, mask_size=mask_size)
# Validation must use the held-out files (it previously reused train_files).
val_dataset = BiosensorDataset(data_path, val_files, mean, std, bool, biosensor_length=bio_len, mask_size=mask_size)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation order does not change the metric; keep it fixed for reproducible runs.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Define the hyperparameters
from argparse import Namespace  # not imported in the imports cell; needed for `args`

args = Namespace(
    lr=0.001,
    epochs=10,
    batch_size=batch_size,
    amp=False,
    bilinear=False,
    data_path=data_path,
)

# Define the loss function
criterion = DiceLoss()


def dice_loss_from_logits(logits, target):
    """Soft Dice loss on raw UNet outputs.

    The LightningModule passes the criterion raw logits, while DiceLoss multiplies
    and sums its input directly and so needs values in [0, 1]. On raw logits the Dice
    ratio is unbounded, so squash the logits with a sigmoid first.
    """
    return criterion(torch.sigmoid(logits), target)


# Initialize the model. Data comes from train_loader / val_loader built in the
# previous cell (they already use `batch_size` and `data_path`).
model = UNetLightningModule(learning_rate=args.lr, channels=bio_len, classes=1, loss_func=dice_loss_from_logits, amp=args.amp, bilinear=args.bilinear)

# Initialize the trainer
trainer = Trainer(max_epochs=args.epochs, accelerator='gpu' if torch.cuda.is_available() else 'cpu', precision=16 if args.amp else 32)

# Train the model
trainer.fit(model, train_loader, val_loader)

In [None]:
# Model loading and prediction visualization
