In [1]:
import yaml
import torch
import logging
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from unet import Unet
from trainer import Trainer
from metrics import MeanIoU
from dataset import CustomSegmentationDataset

In [None]:
# Setup logging.
# `force=True` (Python 3.8+) removes any handlers already attached to the root
# logger before configuring it. Without it, basicConfig is a silent no-op when
# the root logger already has handlers (e.g. one installed implicitly by a
# module-level logging call during import, or by the notebook environment),
# and nothing would reach training.log.
logging.basicConfig(
    filename="training.log",
    level=logging.INFO,
    format="%(asctime)s %(message)s",
    force=True,
)

# Read configuration (explicit encoding so behaviour doesn't depend on the
# platform's default locale encoding).
with open("config.yaml", "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

data_cfg = config["data"]
model_cfg = config["model"]
train_cfg = config["training"]

# Optional reproducibility: when `training.seed` is set in the config, seed
# torch's global RNG before any data/model objects are built. Weight
# initialisation of standard torch layers and DataLoader shuffling draw from it.
# NOTE(review): numpy / `random` are not seeded here; if the dataset's
# augmentations use them, seed those as well.
seed = train_cfg.get("seed")
if seed is not None:
    torch.manual_seed(int(seed))

# Create datasets and dataloaders (only the training set is shuffled).
data_train = CustomSegmentationDataset(
    data_cfg["train_image_dir"], data_cfg["train_label_dir"]
)
data_val = CustomSegmentationDataset(
    data_cfg["val_image_dir"], data_cfg["val_label_dir"]
)

batch_size = train_cfg["batch_size"]
train_loader = DataLoader(data_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(data_val, batch_size=batch_size, shuffle=False)

# Initialise model.
# NOTE(review): `in_layers` presumably means the number of input image
# channels; confirm against the Unet signature.
model = Unet(model_cfg["in_layers"], model_cfg["num_classes"])

# Define criterion. Target pixels equal to -1 are excluded from the loss.
# NOTE(review): assumes CustomSegmentationDataset encodes void/unlabelled
# pixels as -1 — confirm.
criterion = nn.CrossEntropyLoss(ignore_index=-1)

# Select optimizer.
# PyYAML follows YAML 1.1, where a value such as `1e-3` (no decimal point /
# unsigned exponent) is loaded as a *string*, not a float. Cast both
# hyper-parameters explicitly so either spelling works in config.yaml.
optimizer = optim.Adam(
    model.parameters(),
    lr=float(train_cfg["learning_rate"]),
    weight_decay=float(train_cfg["weight_decay"]),
)

# Select the best available device: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
logging.info("Using device: %s", device)

# Validation metric (moved to the same device the trainer is given).
val_metric = MeanIoU(model_cfg["num_classes"]).to(device)

# Initialise trainer.
# NOTE(review): the model is not moved to `device` here; presumably Trainer
# does that — confirm.
trainer = Trainer(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    val_metric,
    checkpoint_dir=config["checkpoint"]["directory"],
)

# Train the model.
trainer.train(train_cfg["num_epochs"])

100%|██████████| 372/372 [00:53<00:00,  6.99it/s]
100%|██████████| 63/63 [00:04<00:00, 15.69it/s]
100%|██████████| 372/372 [00:50<00:00,  7.32it/s]
100%|██████████| 63/63 [00:03<00:00, 15.90it/s]
100%|██████████| 372/372 [00:49<00:00,  7.50it/s]
100%|██████████| 63/63 [00:03<00:00, 15.80it/s]
100%|██████████| 372/372 [00:50<00:00,  7.38it/s]
100%|██████████| 63/63 [00:03<00:00, 15.78it/s]
100%|██████████| 372/372 [00:49<00:00,  7.52it/s]
100%|██████████| 63/63 [00:03<00:00, 15.92it/s]
100%|██████████| 372/372 [00:48<00:00,  7.59it/s]
100%|██████████| 63/63 [00:04<00:00, 15.71it/s]
100%|██████████| 372/372 [00:48<00:00,  7.62it/s]
100%|██████████| 63/63 [00:03<00:00, 15.79it/s]
100%|██████████| 372/372 [00:48<00:00,  7.63it/s]
100%|██████████| 63/63 [00:03<00:00, 15.76it/s]
100%|██████████| 372/372 [00:48<00:00,  7.59it/s]
100%|██████████| 63/63 [00:04<00:00, 15.73it/s]
100%|██████████| 372/372 [00:49<00:00,  7.59it/s]
100%|██████████| 63/63 [00:04<00:00, 15.72it/s]
100%|██████████| 372