# Highway Scene Segmentation

This notebook demonstrates semantic segmentation of a highway scene into background pixels and foreground pixels (moving cars). Both training and validation data are taken from the highway scene of the Change Detection dataset, which consists of labeled highway camera frames (except for the first 470 frames, which are unlabeled): http://jacarini.dinf.usherbrooke.ca/dataset2014#

Input: 
  - RGB 
  - Shape (3,320,240)

Label:
  - Greyscale
  - 0:black:background, 170:grey:foreground-edges,  255:white:foreground
  - Shape (1,320,240)


![input image](showcase/in001600.jpg "Title") ![gt image](showcase/gt001600.png "Title")

## Imports

In [None]:
import os
from torch.utils.data import Dataset, random_split
from torchvision.io import read_image
from torchvision.transforms import v2
from torchvision import tv_tensors

import torch
from torch import nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F

from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score, precision_score, recall_score, jaccard_score
import numpy as np
import matplotlib.pyplot as plt


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # last expression of the cell, so the notebook displays the selected device

## Pytorch Dataset Wrapper
Inherits from **torch.utils.data.Dataset** and implements two methods.
- **def \_\_len\_\_(self)**: returns the number of samples in the dataset
- **def \_\_getitem\_\_(self, idx)**: given an integer idx returns the data x,y
    - x is the image as a float tensor of shape: $(3,H,W)$
    - y is the label image as a mask of shape: $(H,W)$ each pixel should contain the label 0 (background) or 1 (foreground). It is recommended to use the type torch.long

The image resolution is reduced so that the data fits into memory and onto the GPU at once. Color jittering adds some regularization.

In [None]:
# Augmentation pipeline that is called with an (image, mask) pair: colour jitter,
# then a random resized crop to 224x224, a random rotation and a random horizontal flip.
# NOTE(review): the geometric transforms are expected to be applied identically to the
# image and to a `tv_tensors.Mask` passed alongside it, while the colour jitter should
# leave the mask untouched -- confirm against the torchvision v2 documentation.
augmentations = v2.Compose([
    v2.ColorJitter(brightness=0.5, contrast=1.5, saturation=2.5, hue=0.5),
    v2.RandomResizedCrop(size=(224, 224),scale=(0.7, 1.0)),  # output is always 224x224
    v2.RandomRotation(30),
    v2.RandomHorizontalFlip(p=0.5),
    # Currently disabled: dtype conversion and ImageNet-style normalisation.
    #v2.ToDtype(torch.float32, scale=True),
    #v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Load one showcase frame (scaled to [0, 1]) and its ground-truth label image.
# The frame is deliberately not called `input`, which would shadow the Python builtin
# for the rest of the notebook session.
showcase_img = read_image("showcase/in001600.jpg").float()/255
label = read_image("showcase/gt001600.png")

# Run the augmentation pipeline once on the (image, mask) pair; only pixels labelled
# 255 are treated as foreground for this demonstration.
augmented_img, augmented_mask = augmentations(
    showcase_img,
    tv_tensors.Mask(label == 255),
)

# (title, image, extra imshow kwargs) for each panel, in display order.
# `label == value` has a single channel and broadcasts over the three colour channels,
# so it masks out every pixel that does not carry that label value.
# NOTE(review): the notebook header describes label value 170 as "foreground-edges",
# while the panel below is titled "shadow" -- confirm against the dataset documentation.
panels = [
    ("input", showcase_img.permute(1, 2, 0), {}),
    ("background", (showcase_img * (label == 0)).permute(1, 2, 0), {}),
    ("shadow", (showcase_img * (label == 170)).permute(1, 2, 0), {}),
    ("cars", (showcase_img * (label == 255)).permute(1, 2, 0), {}),
    ("label", label.squeeze(), {"cmap": "gray"}),
    ("aug_input", augmented_img.permute(1, 2, 0), {}),
    ("aug_label", augmented_mask.squeeze(), {"cmap": "gray"}),
]

plt.figure(figsize=(20,20))

for position, (panel_title, panel_image, imshow_kwargs) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.imshow(panel_image, **imshow_kwargs)
    plt.axis("off")
    plt.title(panel_title)

plt.show()

In [None]:
class HighwayDataset(Dataset):
    """In-memory dataset of highway frames with binary foreground masks.

    The sorted directory listings of ``highway/input`` and
    ``highway/groundtruth`` are paired position by position after skipping
    the first 470 entries of each. All pairs are loaded eagerly, stacked into
    one tensor each (``inputs`` and ``labels``) and moved to ``device``.

    Images are converted to float and divided by 255. A label pixel becomes
    1 wherever the first channel of the label image is non-zero, else 0.
    NOTE(review): this counts every non-zero label value (including 170) as
    foreground -- confirm that this is intended.

    Args:
        device: Device the stacked image and label tensors are moved to.
        augmentation: Optional callable invoked as
            ``augmentation(image, mask)`` that returns the transformed pair.
            When ``None`` (the default) the module-level ``augmentations``
            pipeline is used, which is what this class always did before the
            argument was honoured.

    Raises:
        ValueError: If there are fewer label files than input files.
    """

    def __init__(self, device, augmentation=None):
        self.input_filenames = sorted(os.listdir("highway/input"))[470:]
        self.labels_filenames = sorted(os.listdir("highway/groundtruth"))[470:]
        self.augmentation = augmentation
        self.device = device

        # Files are paired purely by position, so a missing label would
        # otherwise surface as an unhelpful IndexError half-way through loading.
        if len(self.labels_filenames) < len(self.input_filenames):
            raise ValueError(
                "found {} input frames but only {} label images".format(
                    len(self.input_filenames), len(self.labels_filenames)))

        images = []
        masks = []
        for input_name, label_name in zip(self.input_filenames, self.labels_filenames):
            image = read_image(os.path.join("highway/input", input_name)).float() / 255
            label_image = read_image(os.path.join("highway/groundtruth", label_name))
            # Only the first channel is used; any non-zero value counts as foreground.
            mask = (label_image[0] > 0).long()

            images.append(image)
            masks.append(mask)

        self.inputs = torch.stack(images).to(device)
        self.labels = torch.stack(masks).to(device)

    def __len__(self):
        """Return the number of loaded frames."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the transformed ``(image, mask)`` pair stored at ``idx``.

        The mask is wrapped in ``tv_tensors.Mask`` before the transform is
        called. The transform runs on every access, so a random pipeline
        yields a different result each time the same index is requested.
        """
        # Honour the constructor argument; fall back to the notebook-wide
        # pipeline so that existing ``HighwayDataset(device)`` calls behave as before.
        transform = self.augmentation if self.augmentation is not None else augmentations
        augmented_img, augmented_mask = transform(
            self.inputs[idx],
            tv_tensors.Mask(self.labels[idx]),
        )
        return augmented_img, augmented_mask
    
# Load the complete dataset onto the selected device and display the shape of the
# stacked image tensor.
highwayDataset = HighwayDataset(device)
highwayDataset.inputs.shape

In [None]:
# NOTE(review): this cell repeats the instantiation at the end of the previous cell, so
# running both reads the whole dataset from disk twice -- one of the two can probably go.
highwayDataset = HighwayDataset(device)
highwayDataset.inputs.shape

## Fully-Convolutional Neural Network
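The U-Net-inspired CNN is fully convolutional, so it is flexible with respect to the input and output resolution (height and width must be divisible by 8 because of the three pooling stages), and it uses skip connections that concatenate encoder features into the decoder.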

- input: a batch of images $(B,3,H,W)$
- output: a batch of pixel-wise class predictions $(B,C,H,W)$, where $C=2$

In [None]:
class SmallUNet(nn.Module):
    """Small U-Net style fully-convolutional network for pixel-wise classification.

    The encoder has three conv + ReLU + 2x max-pool stages, followed by a
    bottleneck convolution. The decoder has three stages, each consisting of
    a 2x transposed convolution, a concatenation with the encoder activation
    of the same resolution along the channel axis, and a 3x3 convolution.

    Args:
        in_channels: Number of input channels (default 3).
        num_classes: Number of output channels ``C``, one per class
            (default 2).

    Shapes:
        Input ``(B, in_channels, H, W)`` -> output ``(B, num_classes, H, W)``.
        ``H`` and ``W`` must be divisible by 8, otherwise the upsampled
        decoder tensors no longer match the skip tensors they are
        concatenated with.

    The output holds raw, unnormalised scores (logits). ``F.cross_entropy``
    applies the softmax itself, so no sigmoid/softmax is applied here; use
    ``softmax(dim=1)`` for probabilities or ``argmax(dim=1)`` for labels.
    """

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()

        self.sideconv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)
        self.sideconv2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.sideconv3 = nn.Conv2d(128, 256, 3, 1, 1)
        self.sideconv4 = nn.Conv2d(256, 256, 3, 1, 1)
        self.upconv1 = nn.ConvTranspose2d(256, 256, 2, 2)
        self.sideconv5 = nn.Conv2d(512, 128, 3, 1, 1)
        self.upconv2 = nn.ConvTranspose2d(128, 128, 2, 2)
        self.sideconv6 = nn.Conv2d(256, 64, 3, 1, 1)
        self.upconv3 = nn.ConvTranspose2d(64, 64, 2, 2)
        # One output channel per class; a single channel cannot be used with
        # cross_entropy on {0, 1} targets and makes argmax(dim=1) constant.
        self.sideconv7 = nn.Conv2d(128, num_classes, 3, 1, 1)

    def forward(self, x):
        """Return per-pixel class logits for a batch ``x`` of shape ``(B, C_in, H, W)``."""
        # Encoder: the activated feature map is both kept for the skip
        # connection and pooled for the next stage (pooling the pre-ReLU
        # tensor would leave the down path without any activation).
        x_res1 = F.relu(self.sideconv1(x))
        x = F.max_pool2d(x_res1, 2)

        x_res2 = F.relu(self.sideconv2(x))
        x = F.max_pool2d(x_res2, 2)

        x_res3 = F.relu(self.sideconv3(x))
        x = F.max_pool2d(x_res3, 2)

        # Bottleneck
        x = F.relu(self.sideconv4(x))

        # Decoder: upsample, concatenate the matching encoder features on the
        # channel axis (dim=1), then fuse them with a 3x3 convolution.
        x = F.relu(self.upconv1(x))
        x = torch.cat((x, x_res3), dim=1)
        x = F.relu(self.sideconv5(x))

        x = F.relu(self.upconv2(x))
        x = torch.cat((x, x_res2), dim=1)
        x = F.relu(self.sideconv6(x))

        x = F.relu(self.upconv3(x))
        x = torch.cat((x, x_res1), dim=1)

        # Raw logits, no sigmoid/softmax: cross_entropy expects unnormalised scores.
        return self.sideconv7(x)

## Train and Test (Validation) Functions
- Loss: Cross Entropy
- Metrics:
    - Pixel-wise: Accuracy, Precision, Recall
    - Image-wise: Intersection over Union

In [None]:
def train_classifier(model, device, train_loader, optimizer, epoch, loss_list):
    """Train ``model`` for one epoch and record the mean batch loss.

    Args:
        model: Network whose output is passed to ``F.cross_entropy`` together
            with the batch targets.
        device: Device every batch is moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes a ``dataset`` attribute supporting ``len()``.
        optimizer: Optimizer that updates the parameters of ``model``.
        epoch: Epoch number, only used in the progress line.
        loss_list: List to which the mean per-batch loss of this epoch is
            appended.
    """
    model.train()

    epoch_loss = 0
    samples_seen = 0
    num_batches = len(train_loader)
    num_samples = len(train_loader.dataset)

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        # Count the samples actually processed; multiplying the batch index by
        # the current batch size is wrong when the last batch is smaller.
        samples_seen += len(data)

        # '\r' keeps rewriting the same console line while the epoch runs.
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, samples_seen, num_samples,
            100. * (batch_idx + 1) / num_batches, batch_loss), end='\r')

    loss_list.append(epoch_loss / num_batches)


def test_classifier(model, device, test_loader, loss_list):
    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)

            last_loss = F.cross_entropy(output, target).item()#, reduction='sum').item()
            #if batch_idx%2 == 0:
            #    print("test batch id:", batch_idx," data len:",len(data), " test batch loss:", last_loss)

            test_loss += last_loss
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()



    test_loss /= len(test_loader)#.dataset * 60 * 80 )

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset) *60*80 ,
    100. * correct / (len(test_loader.dataset)* 60 * 80 ), end='\r')) # / 60 * 80  wegen number of images * height *width  pixel insgesamt

    loss_list.append(test_loss)

## Training Loop
- 80/20 Data Split
- PyTorch DataLoaders
- Adam Optimizer

In [None]:
# 80/20 random split. Both subsets wrap the same `highwayDataset` object, so training
# and test samples go through the same `__getitem__` (including its augmentation).
# NOTE(review): frames are assigned to the splits individually and at random; if they
# come from one continuous video, near-identical neighbouring frames can end up in both
# splits -- worth checking in light of the suspicious loss curves noted further down.
# NOTE(review): no seed is set here, so the split differs between runs.
train_dataset, test_dataset = random_split(highwayDataset, [0.8, 0.2])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=512)

model = SmallUNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.002)  # previously lr=0.001

# Step decay of the learning rate: factor 0.1 every 10 scheduler steps (one step per epoch below).
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)  # previously step_size=5

# Mean per-batch losses, one entry per epoch; filled by the two functions called below.
train_losses = []
test_losses = []

epochs = 20

for epoch in range(1,epochs + 1):
    train_classifier(model, device, train_loader, optimizer, epoch, train_losses)
    test_classifier(model, device, test_loader, test_losses)
    scheduler.step()

In [None]:
def get_segmentation(model, device, test_loader):
    """Predict a label map for every sample of ``test_loader`` and score it.

    Each prediction is the ``argmax`` of the model output over the channel
    axis. Every image is scored individually against its target on the
    flattened pixels; precision, recall and IoU (Jaccard) are macro-averaged
    over the classes.

    Args:
        model: Network returning per-pixel class scores ``(B, C, H, W)``.
        device: Device every batch is moved to before the forward pass.
        test_loader: Iterable of ``(data, target)`` batches.

    Returns:
        Tuple ``(segmentations, accuracies, precisions, recalls, ious)`` of
        NumPy arrays: all predicted label maps concatenated along the first
        axis, followed by one score per image for each metric.
    """
    model.eval()

    batch_predictions = []
    accuracies, precisions, recalls, ious = [], [], [], []

    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)

            # Move the whole batch to NumPy once and slice per image afterwards.
            predicted = model(data).argmax(dim=1).detach().cpu().numpy()
            expected = target.cpu().numpy()
            batch_predictions.append(predicted)

            for image_pred, image_target in zip(predicted, expected):
                y_pred = image_pred.flatten()
                y_true = image_target.flatten()

                accuracies.append(accuracy_score(y_true, y_pred))
                precisions.append(precision_score(y_true, y_pred, average='macro'))
                recalls.append(recall_score(y_true, y_pred, average='macro'))
                ious.append(jaccard_score(y_true, y_pred, average='macro'))

    return (
        np.concatenate(batch_predictions),
        np.array(accuracies),
        np.array(precisions),
        np.array(recalls),
        np.array(ious),
    )

In [None]:
# Predicted label maps for the test loader plus per-image accuracy, precision, recall and IoU.
results, acc, prec, rec, iou = get_segmentation(model, device, test_loader)

In [None]:
# Sanity check: display the shapes of the prediction array and of the four per-image metric arrays.
results.shape, acc.shape, prec.shape, rec.shape, iou.shape

## Report

In [None]:
# Visualise one sample: input image, per-class probabilities, predicted label map and
# ground truth.
# NOTE(review): index 10030 may exceed the number of loaded frames; it is clamped to the
# last valid index so that the cell still renders -- confirm which frame is wanted.
sample_idx = min(10030, len(highwayDataset) - 1)

# Fetch the sample exactly once: every `highwayDataset[...]` access re-runs the
# augmentation, so separate accesses could show a differently transformed frame per panel.
example_img, example_gt = highwayDataset[sample_idx]

model.eval()
with torch.no_grad():
    # The network concatenates its skip connections along dim=1 (the channel axis of a
    # batched tensor), so a batch axis is added for the forward pass and removed again.
    example_output = model(example_img.unsqueeze(0).to(device))[0]

# Channel 0 is the background class, channel 1 the foreground class.
example_probabilities = F.softmax(example_output, dim=0)
example_segmentation = example_probabilities.argmax(dim=0)

# (title, image, extra imshow kwargs) for each panel, in display order. The two
# "Probs" panels show the softmax probabilities rather than the raw model output;
# the fourth panel shows the argmax label map (1 = foreground).
panels = [
    ("Image", example_img.detach().cpu().permute(1, 2, 0), {}),
    ("Back Probs", example_probabilities[0].cpu(), {"cmap": "gray"}),
    ("Front Probs", example_probabilities[1].cpu(), {"cmap": "gray"}),
    ("Back Pred", example_segmentation.cpu(), {"cmap": "gray"}),
    ("GT", example_gt.detach().cpu(), {"cmap": "gray"}),
]

plt.figure()

for position, (panel_title, panel_image, imshow_kwargs) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.imshow(panel_image, **imshow_kwargs)
    plt.axis("off")
    plt.title(panel_title)

plt.show()

# Training and Test Error over each Epoch
Later comment (TODO): The curves look very suspicious and must be revisited. EDIT: The data split and loader setup seem correct, and no leakage has been found yet. Maybe compare training vs. validation/test loss without training-data augmentation? Still suspicious.

In [None]:
# Plot the mean training and test loss of every epoch against the 1-based epoch number.
x = list(range(1, epochs + 1))

for loss_series, curve_label in ((train_losses, 'training loss'),
                                 (test_losses, 'test loss')):
    plt.plot(x, loss_series, label=curve_label)

plt.title('training and test loss over epochs')
plt.xlabel('epoch')
plt.ylabel('loss')

plt.legend()

plt.show()

## Evaluation Metrics for Final Model

In [None]:
# Pixel-wise classification report over the whole test split.
# `HighwayDataset` stores its masks in `.labels`; it has no `.gts` attribute.
# NOTE(review): these are the stored masks, whereas `results` was predicted from whatever
# `test_loader` yielded. If the loader applies random crops/rotations, the two are not
# pixel-aligned (and may differ in size) -- confirm before trusting this report.
test_labels = highwayDataset.labels[test_dataset.indices]
print(test_labels.shape)
print(classification_report(test_labels.detach().cpu().flatten(), results.flatten()))

In [None]:
# Per-image pixel accuracy from `get_segmentation`: raw values, then their distribution.
print(acc)
plt.hist(acc)
plt.title("Accuracy for each image")
plt.show()

In [None]:
# Per-image macro-averaged precision from `get_segmentation`: raw values, then their distribution.
print(prec)
plt.hist(prec)
plt.title("Precision for each image")
plt.show()

In [None]:
# Per-image macro-averaged recall from `get_segmentation`: raw values, then their distribution.
print(rec)
plt.hist(rec)
plt.title("Recall for each image")
plt.show()

In [None]:
# Per-image macro-averaged IoU (Jaccard score) from `get_segmentation`: raw values, then their distribution.
print(iou)
plt.hist(iou)
plt.title("Intersection over Union for each image")
plt.show()