In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim

from unet import UNet
from dataset import SegmentationDataset

In [2]:
# Define the data transforms
transform = transforms.Compose([
    transforms.ToTensor(),         # Convert the images to PyTorch tensors
    transforms.Grayscale()
])

In [3]:
# Define the data directories
data_dir = './data'
train_dir = data_dir + '/train'
test_dir = data_dir + '/test'
val_dir = data_dir + '/val'

In [4]:
# Load the datasets
train_data = SegmentationDataset(train_dir, transform=transform)
test_data = SegmentationDataset(test_dir, transform=transform)
val_data = SegmentationDataset(val_dir, transform=transform)

# Create data loaders
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
test_loader = DataLoader(test_data, batch_size=8)
val_loader = DataLoader(val_data, batch_size=8)

In [5]:
# Load the pre-trained ResNet50 model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet(n_channels=1, n_classes=1).to(device)

In [6]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.002, weight_decay=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

In [7]:
# Define the validation function
def validate(model, testloader, criterion, device):
    model.eval()
    total_dice_coeff = 0
    total_loss = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            loss = criterion(outputs.squeeze(1), labels.squeeze(1))
            total_loss += loss.item()
            dice_coeff = dice_coefficient(outputs, labels)
            total_dice_coeff += dice_coeff.item()
            
    val_loss = total_loss
    val_dice_coeff = total_dice_coeff / len(testloader) * 100
    return val_dice_coeff, val_loss

def dice_coefficient(outputs, labels):
    smooth = 1.
    num = outputs.size(0)
    outputs = outputs.view(num, -1)
    labels = labels.view(num, -1)
    intersection = (outputs * labels).sum(1)
    dice = (2. * intersection + smooth) / (outputs.sum(1) + labels.sum(1) + smooth)
    return dice.mean()

In [8]:
best_loss = float('inf')
patience = 15
counter = 0

# Train the model
for epoch in range(200):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.squeeze(1), labels.squeeze(1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Validate the model every 1 epoch
    if (epoch + 1) % 1 == 0:
        val_accuracy, val_loss = validate(model, val_loader, criterion, device)
        print(f'Epoch {epoch + 1}, Loss {running_loss / len(train_loader):.4f}, Validation Loss {val_loss:.4f}, Validation Accuracy {val_accuracy:.2f}%')
        running_loss = 0.0

        if val_loss < best_loss:
            best_loss = val_loss
            model_name = f'models/unet_best_epoch.pth'
            torch.save(model.state_dict(), model_name)
            counter = 0

        else:
            counter += 1
            if counter >= patience:
                print(f'Validation loss did not improve for {patience} epochs, stopping training')
                break

Epoch 1, Loss 1105.4352, Validation Loss 2249.5673, Validation Accuracy 56.12%
Epoch 2, Loss 1955.6062, Validation Loss 2235.2441, Validation Accuracy 74.32%
Epoch 3, Loss 1099.8966, Validation Loss 2233.0930, Validation Accuracy 1.06%
Epoch 4, Loss 1098.4537, Validation Loss 2231.1566, Validation Accuracy 49.74%
Epoch 5, Loss 1097.8810, Validation Loss 2230.1912, Validation Accuracy 40.43%
Epoch 6, Loss 1097.4888, Validation Loss 2229.2991, Validation Accuracy 36.03%
Epoch 7, Loss 1097.0534, Validation Loss 2227.8453, Validation Accuracy 35.14%
Epoch 8, Loss 1096.3098, Validation Loss 2229.0775, Validation Accuracy -164.67%
Epoch 9, Loss 1095.2589, Validation Loss 2223.8154, Validation Accuracy 62.43%
Epoch 10, Loss 1094.2521, Validation Loss 2222.1034, Validation Accuracy 67.01%
Epoch 11, Loss 1093.3228, Validation Loss 2226.1320, Validation Accuracy 70.16%
Epoch 12, Loss 1093.6902, Validation Loss 2222.8967, Validation Accuracy 71.80%
Epoch 13, Loss 1092.9665, Validation Loss 2220.4