In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from network import LaneDataset, LaneDetectionUNet

In [None]:
DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
DEVICE

'cuda'

In [8]:
import numpy as np
def set_seed(seed=0):
    # for reproductibility
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True

In [None]:
# Creating datasets
img_folder = r"C:\javier\personal_projects\computer_vision\data\KITTI_road_segmentation\data_road\training\image_2"
gt_folder = r"data\labels"

data = LaneDataset(img_folder, gt_folder)
print(f"All training samples: {len(data)}")

train_data, val_data = random_split(data, [200, 89])

All training samples: 289


In [None]:
set_seed(0)
n_epochs = 10
lr = 0.001
batch_size = 32

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

model = LaneDetectionUNet()

# not enough memory in CUDA :(
# model.to(DEVICE)

loss_fn = F.binary_cross_entropy_with_logits

optimizer = torch.optim.Adam(model.parameters(), lr = lr)

for epoch in range(n_epochs):
    print(f"Training epoch {epoch + 1}/{n_epochs}")
    model.train()
    epoch_tr_loss = 0.
    for b, batch in enumerate(train_loader):
        img, label = batch

        optimizer.zero_grad()

        logits = model(img)
        
        loss = loss_fn(logits, label)
        epoch_tr_loss += loss.item()
        loss.backward()

        optimizer.step()
        print(f"   Batch {b} -> loss: {loss.item():.3f}")
    epoch_tr_loss /= (b + 1)

    # validation round
    model.eval()
    epoch_val_loss = 0.
    with torch.no_grad():
        for b, (img, label) in enumerate(val_loader):
            logits = model(img)
            loss = loss_fn(logits, label)
            epoch_val_loss += loss.item()
    epoch_val_loss /= (b + 1)

    if ((epoch > 0) & (epoch % 3 == 0)) | (epoch == n_epochs):
        torch.save(model.state_dict(), f"checkpoints/shallowUNET_ep{epoch}.pth")

    print(f"Train loss: {epoch_tr_loss:.3f} | Validation loss: {epoch_val_loss:.3f}")
    print(f"-----------------------------------------------------------------------")

Training epoch 1/10
   Batch 0 -> loss: 1.303
   Batch 1 -> loss: 0.788
   Batch 2 -> loss: 1.007
   Batch 3 -> loss: 0.637
   Batch 4 -> loss: 0.617
   Batch 5 -> loss: 0.631
   Batch 6 -> loss: 0.463
Train loss: 0.778 | Validation loss: 0.481
-----------------------------------------------------------------------
Training epoch 2/10
   Batch 0 -> loss: 0.561
   Batch 1 -> loss: 0.461
   Batch 2 -> loss: 0.546
   Batch 3 -> loss: 0.500
   Batch 4 -> loss: 0.451
   Batch 5 -> loss: 0.531
   Batch 6 -> loss: 0.502
Train loss: 0.507 | Validation loss: 0.437
-----------------------------------------------------------------------
Training epoch 3/10
   Batch 0 -> loss: 0.434
   Batch 1 -> loss: 0.469
   Batch 2 -> loss: 0.471
   Batch 3 -> loss: 0.440
   Batch 4 -> loss: 0.454
   Batch 5 -> loss: 0.434
   Batch 6 -> loss: 0.426
Train loss: 0.447 | Validation loss: 0.426
-----------------------------------------------------------------------
Training epoch 4/10
   Batch 0 -> loss: 0.409
   

In [21]:
img_folder = r"C:\javier\personal_projects\computer_vision\data\KITTI_road_segmentation\data_road\training\image_2"
gt_folder = r"data\labels"

dataset = LaneDataset(img_folder, gt_folder)

model = LaneDetectionUNet()
model.load_state_dict(torch.load("checkpoints/shallowUNET_ep9.pth"))
model.eval()

img, gt = dataset[0]
img = img[None, :, :, :]
pred = model(img)
pred = np.uint8(255*F.sigmoid(pred).squeeze().detach().numpy())

In [None]:
import cv2
cv2.imshow("Pred", pred)

: 