In [None]:
# some basic imports
import math
import os
import re
import cv2
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from random import sample
from PIL import Image

In [None]:
# Project-wide constants. The task is scored as patch-wise classification, so
# predictions and masks are both summarised per PATCH_SIZE x PATCH_SIZE patch.
CUTOFF = 0.25      # a patch counts as "road" when its mean mask value exceeds this
PATCH_SIZE = 16    # side length of a square patch, in pixels
VAL_SIZE = 10      # number of images held out by the random validation split

In [None]:
# Debug: list the working directory through WSL (requires the Windows Subsystem for Linux).
!wsl ls

In [None]:
# K-Fold

import numpy as np
from sklearn.model_selection import KFold

# Partition the 144 training image ids into k folds once, with a fixed seed,
# so every run sees the same train/validation partition.
k = 5
arr = np.arange(144)  # ids of the training images (satimage_<id>.png)
kf = KFold(n_splits=k, shuffle=True, random_state=100)

# fold_subsets[i] == (training ids, validation ids) of fold i
fold_subsets = [(arr[fit_idx], arr[holdout_idx]) for fit_idx, holdout_idx in kf.split(arr)]

In [None]:
# Index of the k-fold subset held out for validation (0 .. k-1); change by hand per run.
SUBSET_ID = 1
_, validation_ids = fold_subsets[SUBSET_ID]  # (training ids, validation ids) -> keep the validation ids
print(validation_ids)

In [None]:
# Rebuild training/ and validation/ from scratch for the fold selected by SUBSET_ID.
# Pure-Python file handling (instead of `!rm` / `!unzip`) so that failures raise
# Python exceptions the `except` clause below can catch, and so the cell does not
# depend on a Unix shell (other cells go through `wsl`).
import shutil
import zipfile

DATASET_ZIP = 'ethz-cil-road-segmentation-2023.zip'

# remove the outputs of a previous run; missing paths are ignored
for folder in ('validation', 'training', 'test'):
    shutil.rmtree(folder, ignore_errors=True)
for script in ('mask_to_submission.py', 'submission_to_mask.py'):
    if os.path.isfile(script):
        os.remove(script)

# file names of the images in the selected validation fold (O(1) membership test)
validation_names = {f'satimage_{val_id}.png' for val_id in validation_ids}

try:
    with zipfile.ZipFile(DATASET_ZIP) as archive:
        archive.extractall()
    shutil.rmtree(os.path.join('training', 'training'), ignore_errors=True)
    os.makedirs(os.path.join('validation', 'images'), exist_ok=True)
    os.makedirs(os.path.join('validation', 'groundtruth'), exist_ok=True)
    for img in glob("training/images/*.png"):
        if os.path.basename(img) in validation_names:
            os.rename(img, img.replace('training', 'validation'))
            mask = img.replace('images', 'groundtruth')  # matching ground-truth path
            os.rename(mask, mask.replace('training', 'validation'))
except (OSError, zipfile.BadZipFile) as err:
    print('Please upload a .zip file containing your datasets.')
    print(f'({type(err).__name__}: {err})')

In [None]:
# Fallback split: if no validation/ folder exists yet, extract the dataset and move
# VAL_SIZE randomly chosen training images (and their masks) into validation/.
# NOTE(review): `sample` is not seeded, so this split differs between runs.
# NOTE(review): IPython `!` commands do not raise on failure, so the except branch
# is only reached by Python errors, e.g. `sample` raising ValueError when fewer
# than VAL_SIZE images were extracted, or os.rename failing.
if not os.path.isdir('validation'):  # make sure this has not been executed yet
  try:
    # extract (overwriting) and flatten a nested training/training folder
    !wsl unzip -o ethz-cil-road-segmentation-2023.zip
    !wsl mv training/training/* training
    !wsl rm -rf training/training
    !wsl mkdir validation
    !wsl mkdir validation/images
    !wsl mkdir validation/groundtruth
    # move each sampled image and its same-named mask from training/ to validation/
    for img in sample(glob("training/images/*.png"), VAL_SIZE):
      os.rename(img, img.replace('training', 'validation'))
      mask = img.replace('images', 'groundtruth')
      os.rename(mask, mask.replace('training', 'validation'))
  except:
    print('Please upload a .zip file containing your datasets.')
else:
  print("Validation folder exists already")

In [None]:
import zipfile

# Extract masked.zip once; skipped when the masked/ folder already exists.
# zipfile (instead of `!wsl unzip`) raises on a missing or corrupt archive,
# so the except branch below is actually reachable.
if not os.path.isdir('masked'):  # make sure this has not been executed yet
    try:
        with zipfile.ZipFile('masked.zip') as archive:
            archive.extractall()
    except (FileNotFoundError, zipfile.BadZipFile):
        print('Please upload a .zip file containing the inpainted images.')
else:
    print("Masked images folder exists already")

In [None]:
def load_all_from_path(path):
    """Load every .png in `path` (sorted by file name) into one float32 array in [0., 1.].

    Multi-channel images (e.g. RGB/RGBA) give an array of shape (n_images, H, W, C);
    single-channel images such as masks give (n_images, H, W). All images must have
    the same shape.

    Raises:
        ValueError: if `path` contains no .png files.
    """
    files = sorted(glob(os.path.join(path, '*.png')))
    if not files:
        raise ValueError(f'no .png files found in {path!r}')
    arrays = []
    for f in files:
        with Image.open(f) as img:  # close each file handle deterministically
            arrays.append(np.array(img))
    return np.stack(arrays).astype(np.float32) / 255.


def show_first_n(imgs, masks, n=5):
    """Plot the first `n` images (top row) above their segmentation masks (bottom row).

    Args:
        imgs: sequence of images accepted by `plt.imshow`.
        masks: sequence of masks aligned with `imgs`.
        n: maximum number of image/mask pairs to draw.
    """
    imgs_to_draw = min(n, len(imgs))
    if imgs_to_draw <= 0:
        return  # nothing to draw; plt.subplots rejects 0 columns
    # squeeze=False keeps `axs` 2-D even when a single column is drawn
    fig, axs = plt.subplots(2, imgs_to_draw, figsize=(18.5, 6), squeeze=False)
    for i in range(imgs_to_draw):
        top, bottom = axs[0, i], axs[1, i]
        top.imshow(imgs[i])
        bottom.imshow(masks[i])
        top.set_title(f'Image {i}')
        bottom.set_title(f'Mask {i}')
        top.set_axis_off()
        bottom.set_axis_off()
    plt.show()


# Roots of the training and validation splits created by the cells above.
train_path, val_path = 'training', 'validation'

print(os.path.join(train_path, 'images'))
train_images, train_masks = (load_all_from_path(os.path.join(train_path, sub)) for sub in ('images', 'groundtruth'))
val_images, val_masks = (load_all_from_path(os.path.join(val_path, sub)) for sub in ('images', 'groundtruth'))

# eyeball two training pairs to check that images and masks line up
show_first_n(train_images, train_masks, n=2)

In [None]:
# Split sizes and the (n, H, W, C) shape of the loaded training images.
print(f"Training samples {len(train_images)}, Validation samples {len(val_images)}, Shape: {train_images.shape}")

In [None]:
def load_all_from_path_by_id(path, idx):
    """Load the .png files in `path` that belong to image id `idx`, as float32 in [0., 1.].

    A file belongs to `idx` when the first number in its file name equals `idx`.
    For idx=1 this selects 'satimage_1.png' and 'satimage_1_transformed_0.png', but not
    'satimage_10.png' or 'satimage_3_transformed_1.png'. A substring glob such as
    '*_1*.png' would also match those last two. Matching files are loaded in sorted order.

    Raises:
        ValueError: if no file in `path` belongs to `idx`.
    """
    target = int(idx)

    def _first_number(file_path):
        # first run of digits in the file name, or None when there is none
        match = re.search(r'\d+', os.path.basename(file_path))
        return int(match.group()) if match else None

    files = [f for f in sorted(glob(os.path.join(path, '*.png'))) if _first_number(f) == target]
    if not files:
        raise ValueError(f'no .png files with id {target} found in {path!r}')
    arrays = []
    for f in files:
        with Image.open(f) as img:  # close each file handle deterministically
            arrays.append(np.array(img))
    return np.stack(arrays).astype(np.float32) / 255.


In [None]:
def get_img_idx(path):
    """Return the first run of digits in the file name of `path` as an int.

    Example: 'augmentation/images/satimage_12_transformed_3.png' -> 12.
    """
    digits = re.search(r'\d+', os.path.basename(path))
    return int(digits.group())

In [None]:
def load_all_from_path_except(path, exceptions=None):
    """Load every .png in `path` whose image id is not in `exceptions`, as float32 in [0., 1.].

    The image id is the first number in the file name (as returned by `get_img_idx`).
    Files are loaded in sorted order. Multi-channel images give (n, H, W, C);
    single-channel images give (n, H, W).

    Args:
        path: folder containing the .png files.
        exceptions: iterable of image ids to skip (default: skip nothing).

    Raises:
        ValueError: if no file remains after filtering.
    """
    excluded = set(exceptions) if exceptions is not None else set()  # O(1) membership
    files = [f for f in sorted(glob(os.path.join(path, '*.png'))) if get_img_idx(f) not in excluded]
    if not files:
        raise ValueError(f'no .png files left in {path!r} after excluding {len(excluded)} ids')
    arrays = []
    for f in files:
        with Image.open(f) as img:  # close each file handle deterministically
            arrays.append(np.array(img))
    return np.stack(arrays).astype(np.float32) / 255.

In [None]:
# DELETE AUGMENTATION FOLDER so the augmentation cell regenerates it.
# Set DELETE_AUG = True only if you really want to remove it.
DELETE_AUG = False

if DELETE_AUG and os.path.isdir('augmentation'):
    import shutil

    # shutil instead of `!rm -r`: works without a Unix shell (other cells go through `wsl`)
    shutil.rmtree('augmentation')

In [None]:
# NOTE(review): unpinned install. The augmentation cell passes `alpha_affine` to
# A.ElasticTransform and `max_holes` / `max_height` / `max_width` to A.CoarseDropout;
# newer albumentations releases may have renamed or removed these arguments.
# TODO pin the version this notebook was developed with.
%pip install --user albumentations

In [None]:
import albumentations as A
import cv2
import matplotlib.pyplot as plt

# Offline augmentation pipeline, applied jointly to an image and its mask.
# Each transform fires with its probability `p` (library default where omitted).
transform = A.Compose([
    A.HorizontalFlip(p=0.6),
    A.RandomRotate90(p=0.6),
    A.RandomBrightnessContrast(p=0.5),
    A.ElasticTransform(alpha=5, sigma=10, alpha_affine=0, p=0.5),
    A.CoarseDropout(max_holes=10, max_height=40, max_width=100),
])

# Glob of the source images to augment (swap in the inpainted / masked variants if needed).
src_images = "training/images/*.png"
# src_images = "inpainted/images/*.png"
# src_images = "masked/images/*.png"

augmentation_factor = 4  # augmented copies written per source image

# output folders for augmented images and their masks
dest_folder_img, dest_folder_mask = 'augmentation/images', 'augmentation/groundtruth'

# ids of the validation images, so that none of their variants end up in training
val_images_path = os.path.join('validation', 'images', '*.png')
val_indices = list(map(get_img_idx, glob(val_images_path)))

# Write `augmentation_factor` randomly augmented copies of every non-validation
# source image (and its mask) into augmentation/. Skipped if the folder already exists.
if not os.path.isdir('augmentation'):  # make sure this has not been executed yet
    # os.makedirs instead of `!wsl mkdir`: same folders, no WSL required
    os.makedirs(dest_folder_img, exist_ok=True)
    os.makedirs(dest_folder_mask, exist_ok=True)

    for img_path in glob(src_images):
        # don't use variations of images in the validation set in case of using masked or inpainted images
        if get_img_idx(img_path) in val_indices:
            continue

        # read the image/mask pair once instead of once per augmented copy
        image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        mask_path = img_path.replace('images', 'groundtruth')
        # mask_path = mask_path.replace('_inpainted', '')
        # mask_path = mask_path.replace('_masked', '')
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if image is None or mask is None:  # cv2.imread returns None instead of raising
            print(f'Skipping {img_path}: image or mask could not be read')
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)  # OpenCV BGRA order -> RGBA for PIL

        image_name, image_ext = os.path.splitext(os.path.basename(img_path))  # e.g. 'satimage_0', '.png'
        for i in range(augmentation_factor):
            # pass copies so every random draw starts from the untouched source pair
            transformed = transform(image=image.copy(), mask=mask.copy())
            out_name = f"{image_name}_transformed_{i}{image_ext}"
            Image.fromarray(transformed['image']).save(os.path.join(dest_folder_img, out_name))
            Image.fromarray(transformed['mask']).save(os.path.join(dest_folder_mask, out_name))

In [None]:
# Show the original training image with id `idx` next to all of its augmented copies.
aug_path = 'augmentation'
idx = 0  # image id from the file name (satimage_<idx>.png), NOT a position in train_images


def _load_exact_id(folder, img_id):
    """Load every .png in `folder` whose first file-name number equals `img_id`, scaled to [0., 1.]."""
    paths = [f for f in sorted(glob(os.path.join(folder, '*.png'))) if get_img_idx(f) == img_id]
    if not paths:
        raise ValueError(f'no .png with id {img_id} in {folder!r} (is it in the validation split?)')
    arrays = []
    for f in paths:
        with Image.open(f) as img:
            arrays.append(np.array(img))
    return np.stack(arrays).astype(np.float32) / 255.


original_image = _load_exact_id(os.path.join(train_path, 'images'), idx)
original_mask = _load_exact_id(os.path.join(train_path, 'groundtruth'), idx)
aug_images = _load_exact_id(os.path.join(aug_path, 'images'), idx)
aug_masks = _load_exact_id(os.path.join(aug_path, 'groundtruth'), idx)

disp_imgs = np.concatenate((original_image, aug_images), axis=0)
disp_masks = np.concatenate((original_mask, aug_masks), axis=0)

# one column for the original plus one per augmented copy
show_first_n(disp_imgs, disp_masks, n=len(disp_imgs))

In [None]:
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

def np_to_tensor(x, device):
    """Wrap the numpy array `x` as a torch tensor on `device`.

    For 'cpu' the tensor shares memory with `x` (torch.from_numpy). Otherwise the data
    is made contiguous and pinned, then copied to `device` asynchronously.
    """
    tensor = torch.from_numpy(x)
    if device != 'cpu':
        return tensor.contiguous().pin_memory().to(device=device, non_blocking=True)
    return tensor.cpu()


class ImageDataset(torch.utils.data.Dataset):
    """In-memory dataset of satellite images and road masks, indexable by sample.

    All data is loaded eagerly in `__init__` from `<path>/images` and `<path>/groundtruth`.
    It can optionally be extended with the `augmentation/` and `inpainted/` folders,
    and is then either split into patches or resized.

    Args:
        path: dataset root containing `images/` and `groundtruth/`.
        device: torch device each sample is moved to in `__getitem__`.
        use_patches: split images into patches via `image_to_patches`.
            NOTE(review): `image_to_patches` is not defined in the visible part of this
            notebook; confirm it exists before relying on the default True.
        use_augmentation: also load `augmentation/images` and `augmentation/groundtruth`.
        use_inpainted_images: also load `inpainted/`, skipping ids present in `validation/images`.
        resize_to: target (height, width) used when `use_patches` is False.
    """

    def __init__(self, path, device, use_patches=True, use_augmentation=False, use_inpainted_images=False, resize_to=(400, 400)):
        self.path = path
        self.device = device
        self.use_patches = use_patches
        self.use_augmentation = use_augmentation
        self.use_inpainted_images = use_inpainted_images
        self.resize_to = resize_to
        self.x, self.y, self.n_samples = None, None, None
        self._load_data()

    def _load_data(self):  # not very scalable, but good enough for now
        """(Re)load all samples into `self.x` / `self.y` and update `self.n_samples`."""
        # images may carry an alpha channel; keep RGB only
        self.x = load_all_from_path(os.path.join(self.path, 'images'))[:, :, :, :3]
        self.y = load_all_from_path(os.path.join(self.path, 'groundtruth'))

        if self.use_augmentation:
            # append the offline-augmented copies written to augmentation/
            aug_path = 'augmentation'
            aug_x = load_all_from_path(os.path.join(aug_path, 'images'))[:, :, :, :3]
            aug_y = load_all_from_path(os.path.join(aug_path, 'groundtruth'))
            self.x = np.concatenate((self.x, aug_x), axis=0)
            self.y = np.concatenate((self.y, aug_y), axis=0)

        if self.use_inpainted_images:
            # append inpainted images, except those whose id is in the validation split
            inpainted_path = 'inpainted'
            val_images_path = os.path.join('validation', 'images', '*.png')
            val_indices = [get_img_idx(f) for f in glob(val_images_path)]
            inpainted_x = load_all_from_path_except(os.path.join(inpainted_path, 'images'), exceptions=val_indices)[:, :, :, :3]
            inpainted_y = load_all_from_path_except(os.path.join(inpainted_path, 'groundtruth'), exceptions=val_indices)
            self.x = np.concatenate((self.x, inpainted_x), axis=0)
            self.y = np.concatenate((self.y, inpainted_y), axis=0)

        if self.use_patches:  # split each image into patches
            self.x, self.y = image_to_patches(self.x, self.y)
        else:
            target_h, target_w = self.resize_to  # resize_to is (height, width)
            if (target_h, target_w) != (self.x.shape[1], self.x.shape[2]):
                # cv2.resize expects dsize as (width, height)
                dsize = (target_w, target_h)
                self.x = np.stack([cv2.resize(img, dsize=dsize) for img in self.x], 0)
                self.y = np.stack([cv2.resize(mask, dsize=dsize) for mask in self.y], 0)
        self.x = np.moveaxis(self.x, -1, 1)  # pytorch works with CHW format instead of HWC
        self.n_samples = len(self.x)

    def _preprocess(self, x, y):
        # to keep things simple we will not apply transformations to each sample,
        # but it would be a very good idea to look into preprocessing
        return x, y

    def __getitem__(self, item):
        # self.y[[item]] keeps a leading axis of size 1, used as the mask's channel dimension
        return self._preprocess(np_to_tensor(self.x[item], self.device), np_to_tensor(self.y[[item]], self.device))

    def __len__(self):
        return self.n_samples


def show_val_samples(x, y, y_hat, segmentation=False):
    """Plot up to 5 validation samples next to their targets and predictions (training callback).

    Args:
        x: input batch as a numpy array in NCHW layout.
        y: targets; segmentation mode is used when `y` has the same last two
            dimensions as `x`, otherwise one scalar label per sample is assumed.
        y_hat: predictions with the same layout as `y`.
        segmentation: unused (the mode is inferred from the shapes); kept for
            backward compatibility.
    """
    imgs_to_draw = min(5, len(x))
    if imgs_to_draw == 0:
        return  # empty batch: nothing to plot
    if x.shape[-2:] == y.shape[-2:]:  # segmentation
        # squeeze=False keeps `axs` 2-D even when a single sample is drawn
        fig, axs = plt.subplots(3, imgs_to_draw, figsize=(18.5, 12), squeeze=False)
        for i in range(imgs_to_draw):
            rows = [
                (np.moveaxis(x[i], 0, -1), f'Sample {i}'),
                # repeat the single mask channel 3x to display it as a grayscale RGB image
                (np.concatenate([np.moveaxis(y_hat[i], 0, -1)] * 3, -1), f'Predicted {i}'),
                (np.concatenate([np.moveaxis(y[i], 0, -1)] * 3, -1), f'True {i}'),
            ]
            for row, (img, title) in enumerate(rows):
                axs[row, i].imshow(img)
                axs[row, i].set_title(title)
                axs[row, i].set_axis_off()
    else:  # classification
        fig, axs = plt.subplots(1, imgs_to_draw, figsize=(18.5, 6), squeeze=False)
        for i in range(imgs_to_draw):
            axs[0, i].imshow(np.moveaxis(x[i], 0, -1))
            axs[0, i].set_title(f'True: {np.round(y[i]).item()}; Predicted: {np.round(y_hat[i]).item()}')
            axs[0, i].set_axis_off()
    plt.show()

In [None]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Training using {device}")
# images are resized to 384x384 to simplify the handling of skip connections and max-pooling
train_dataset = ImageDataset('training', device, use_patches=False, resize_to=(384, 384))
len(train_dataset.x)  # number of loaded training samples

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Sanity check: enabling augmentation must add samples to the training set.
train_dataset = ImageDataset('training', device, use_patches=False, use_augmentation=False, resize_to=(384, 384))
n_plain = train_dataset.n_samples  # training images left after the validation split
print(n_plain)

train_dataset = ImageDataset('training', device, use_patches=False, use_augmentation=True, use_inpainted_images=False, resize_to=(384, 384))
print(train_dataset.n_samples)  # check if we have more
assert train_dataset.n_samples > n_plain, 'augmented dataset should be larger than the plain one'

In [None]:
def train(train_dataloader, eval_dataloader, model, loss_fn, metric_fns, optimizer, n_epochs,
          checkpoint_path='unet', logdir='./tensorboard/net', show_every=2):
    """Train `model`, validate after every epoch and checkpoint the best weights.

    Args:
        train_dataloader: yields (x, y) training batches.
        eval_dataloader: yields (x, y) validation batches; must not be empty.
        model: module mapping x to predictions y_hat.
        loss_fn: callable (y_hat, y) -> scalar loss tensor.
        metric_fns: dict name -> callable (y_hat, y) -> scalar tensor; logged as
            `name` (training) and `val_name` (validation).
        optimizer: optimizer over `model`'s parameters.
        n_epochs: number of passes over `train_dataloader`.
        checkpoint_path: file the state_dict is saved to whenever the epoch-mean
            validation loss improves.
        logdir: TensorBoard log directory.
        show_every: plot validation predictions every `show_every` epochs, starting
            with the first; 0 or None disables plotting.

    Raises:
        ValueError: if `eval_dataloader` yields no batches.
    """
    writer = SummaryWriter(logdir)  # tensorboard writer (can also log images)
    history = {}  # epoch -> {metric name: mean over that epoch's batches}
    best_val_loss = float('inf')  # any finite loss improves on this, so epoch 1 is always checkpointed

    try:
        for epoch in range(n_epochs):  # loop over the dataset multiple times
            # per-batch values of every metric for this epoch
            metrics = {'loss': [], 'val_loss': []}
            for name in metric_fns:
                metrics[name] = []
                metrics['val_' + name] = []

            # training
            pbar = tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{n_epochs}')
            model.train()
            for (x, y) in pbar:
                optimizer.zero_grad()  # zero out gradients
                y_hat = model(x)  # forward pass
                loss = loss_fn(y_hat, y)
                loss.backward()  # backward pass
                optimizer.step()  # optimize weights

                # log partial metrics
                metrics['loss'].append(loss.item())
                for name, fn in metric_fns.items():
                    metrics[name].append(fn(y_hat, y).item())
                pbar.set_postfix({name: sum(v) / len(v) for name, v in metrics.items() if len(v) > 0})

            # validation
            model.eval()
            with torch.no_grad():  # do not keep track of gradients
                for (x, y) in eval_dataloader:
                    y_hat = model(x)  # forward pass
                    loss = loss_fn(y_hat, y)
                    metrics['val_loss'].append(loss.item())
                    for name, fn in metric_fns.items():
                        metrics['val_' + name].append(fn(y_hat, y).item())
            if not metrics['val_loss']:
                raise ValueError('eval_dataloader yielded no batches')

            # summarize metrics, log to tensorboard and display
            history[epoch] = {name: sum(v) / len(v) for name, v in metrics.items()}
            for name, value in history[epoch].items():
                writer.add_scalar(name, value, epoch)
            print(' '.join(['\t- ' + str(name) + ' = ' + str(value) + '\n ' for (name, value) in history[epoch].items()]))

            if show_every and epoch % show_every == 0:
                # x, y and y_hat still hold the last validation batch
                show_val_samples(x.detach().cpu().numpy(), y.detach().cpu().numpy(), y_hat.detach().cpu().numpy())

            # checkpoint on the mean validation loss of the whole epoch, not on the loss
            # of its last (possibly smaller, shuffled) batch.
            # Alternative criterion: history[epoch]['val_patch_acc'] (higher is better).
            epoch_val_loss = history[epoch]['val_loss']
            if epoch_val_loss < best_val_loss:
                best_val_loss = epoch_val_loss
                print('saving...')
                torch.save(model.state_dict(), checkpoint_path)
    finally:
        writer.close()  # flush pending TensorBoard events

    print('Finished Training')
    # plot loss curves
    plt.plot([v['loss'] for v in history.values()], label='Training Loss')
    plt.plot([v['val_loss'] for v in history.values()], label='Validation Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epochs')
    plt.legend()
    plt.show()

In [None]:
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

class Block(nn.Module):
    """Conv3x3 -> ReLU -> BatchNorm -> Conv3x3 -> ReLU.

    padding=1 keeps the spatial size; only the channel count changes from in_ch to out_ch.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        # a single nn.Sequential named `block` keeps state_dict keys (block.0, block.2, ...) unchanged
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out


class UNet(nn.Module):
    """UNet-like encoder/decoder for single-class semantic segmentation (per-pixel probabilities).

    Args:
        chs: channel counts; chs[0] is the number of input channels and every later entry
            is the width of one encoder level. The decoder mirrors the encoder in reverse.
        p_dropout: dropout probability applied after every encoder block except the deepest.
    """

    def __init__(self, chs=(3,64,128,256,512,1024), p_dropout=0.0):
        super().__init__()
        enc_pairs = list(zip(chs[:-1], chs[1:]))
        dec_widths = chs[:0:-1]  # chs reversed, without the input channel count
        dec_pairs = list(zip(dec_widths[:-1], dec_widths[1:]))
        self.enc_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in enc_pairs])
        self.pool = nn.MaxPool2d(2)  # parameter-free, shared by every encoder level
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in dec_pairs])
        self.dec_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in dec_pairs])
        self.head = nn.Sequential(nn.Conv2d(dec_widths[-1], 1, 1), nn.Sigmoid())  # 1x1 conv -> 1-channel probability
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x):
        skips = []  # encoder outputs reused by the decoder, shallowest first
        for enc in self.enc_blocks[:-1]:
            x = self.dropout(enc(x))
            skips.append(x)
            x = self.pool(x)  # halve the resolution
        x = self.enc_blocks[-1](x)  # bottleneck: no dropout, no pooling
        for dec, up, skip in zip(self.dec_blocks, self.upconvs, reversed(skips)):
            x = dec(torch.cat([up(x), skip], dim=1))  # upsample, then fuse with the skip features
        return self.head(x)


def patch_accuracy_fn(y_hat, y, patch_size=None, cutoff=None):
    """Patch-wise accuracy (the metric used on Kaggle for evaluation), as a 0-d float tensor.

    Both tensors are split into non-overlapping patch_size x patch_size patches. A patch is
    labelled "road" when its mean value exceeds `cutoff`. The result is the fraction of
    patches whose labels agree.

    Args:
        y_hat: predictions; assumed (N, 1, H, W) with H and W multiples of `patch_size`.
        y: targets, same shape as `y_hat`.
        patch_size: patch side in pixels (default: PATCH_SIZE).
        cutoff: mean-value threshold for a road patch (default: CUTOFF).
    """
    if patch_size is None:
        patch_size = PATCH_SIZE
    if cutoff is None:
        cutoff = CUTOFF
    h_patches = y.shape[-2] // patch_size
    w_patches = y.shape[-1] // patch_size

    def to_patch_labels(t):
        # (N, 1, H, W) -> (N, 1, h_patches, patch_size, w_patches, patch_size) -> mean per patch
        return t.reshape(-1, 1, h_patches, patch_size, w_patches, patch_size).mean((-1, -3)) > cutoff

    return (to_patch_labels(y_hat) == to_patch_labels(y)).float().mean()

def dice_similarity_fn(y_hat, y, eps=1e-6):
    """Dice coefficient 2|A∩B| / (|A| + |B|) of the rounded prediction and target masks.

    Both inputs are rounded to the nearest integer (road = 1) before comparison.
    Only pixels that are road in BOTH masks count as overlap; agreement on background
    does not. `eps` keeps the value defined (≈1) when both masks are empty.
    Returns a 0-d float tensor.
    """
    pred = y_hat.round()
    true = y.round()
    intersection = (pred * true).sum()  # pixels marked as road in both masks
    return (2 * intersection + eps) / (pred.sum() + true.sum() + eps)

def accuracy_fn(y_hat, y):
    """Fraction of pixels whose rounded prediction equals the rounded target (0-d float tensor)."""
    matches = torch.eq(y_hat.round(), y.round())
    return matches.float().mean()

In [None]:
# Scratch check: the decoder channel widths UNet derives from `chs`.
chs = (3, 64, 128, 256, 512, 1024)
chs[:0:-1]  # reversed, without the input channel count -> (1024, 512, 256, 128, 64)

In [None]:
from attn_unet import UTNet
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Training using {device}")
torch.manual_seed(0)  # reproducible weight initialisation and training-batch order

# reshape the image to simplify the handling of skip connections and maxpooling
train_dataset = ImageDataset('training', device, use_patches=False, use_augmentation=True, use_inpainted_images=False, resize_to=(384, 384))
val_dataset = ImageDataset('validation', device, use_patches=False, use_augmentation=False, resize_to=(384, 384))
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
# Validation order does not change the averaged metrics. Keeping it fixed makes the plotted
# (last) validation batch and any last-batch statistics comparable across epochs.
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=10, shuffle=False)

model = UTNet(chs=(3, 64, 128, 256), p_dropout=0.0).to(device)
loss_fn = nn.BCELoss()
metric_fns = {'acc': accuracy_fn, 'patch_acc': patch_accuracy_fn, 'dice': dice_similarity_fn}
optimizer = torch.optim.Adam(model.parameters())
n_epochs = 35
train(train_dataloader, val_dataloader, model, loss_fn, metric_fns, optimizer, n_epochs)

In [None]:
def create_submission(labels, test_filenames, submission_filename, patch_size=None):
    """Write per-patch labels to a Kaggle submission CSV with header `id,prediction`.

    Args:
        labels: one 2-D array of 0/1 patch labels per test image, indexed [row, column]
            and ordered like `sorted(test_filenames)`.
        test_filenames: paths of the test images; the image number is the first number
            in each file name.
        submission_filename: output CSV path (overwritten).
        patch_size: patch side in pixels (default: PATCH_SIZE).

    Each row id is `<image number, 3 digits>_<column * patch_size>_<row * patch_size>`.

    Raises:
        ValueError: if the number of label arrays differs from the number of file names.
    """
    if patch_size is None:
        patch_size = PATCH_SIZE
    filenames = sorted(test_filenames)
    if len(filenames) != len(labels):
        raise ValueError(f'{len(labels)} label arrays for {len(filenames)} test files')
    with open(submission_filename, 'w') as f:
        f.write('id,prediction\n')
        for fn, patch_array in zip(filenames, labels):
            # parse the file name only, so digits in directory names are ignored
            img_number = int(re.search(r"\d+", os.path.basename(fn)).group(0))
            for i in range(patch_array.shape[0]):
                for j in range(patch_array.shape[1]):
                    f.write("{:03d}_{}_{},{}\n".format(img_number, j * patch_size, i * patch_size, int(patch_array[i, j])))


test_path = 'test/images'

In [None]:
# Load the best checkpoint, predict the test set and write the submission CSV.
model.load_state_dict(torch.load("unet", map_location=device))
model.eval()  # inference mode for BatchNorm / Dropout, even if the model was just rebuilt

# predict on test set
test_filenames = glob(test_path + '/*.png')
test_images = load_all_from_path(test_path)
size = test_images.shape[1:3]  # (H, W) of the original test images; the model runs at 384x384
# we also need to resize the test images. This might not be the best idea depending on their spatial resolution.
test_images = np.stack([cv2.resize(img, dsize=(384, 384)) for img in test_images], 0)
test_images = test_images[:, :, :, :3]  # keep the RGB channels only
test_images = np_to_tensor(np.moveaxis(test_images, -1, 1), device)  # NHWC -> NCHW
with torch.no_grad():  # inference only: no autograd graph
    # one image per forward pass to bound memory use
    test_pred = [model(t).detach().cpu().numpy() for t in test_images.unsqueeze(1)]
test_pred = np.concatenate(test_pred, 0)
test_pred = np.moveaxis(test_pred, 1, -1)  # NCHW -> NHWC
# resize back to the original shape; cv2.resize takes dsize as (width, height)
test_pred = np.stack([cv2.resize(img, dsize=(size[1], size[0])) for img in test_pred], 0)

# test_pred now holds per-pixel values in [0, 1], expected shape (n_test, H, W)

# compute one label per PATCH_SIZE x PATCH_SIZE patch
h_patches, w_patches = size[0] // PATCH_SIZE, size[1] // PATCH_SIZE
test_pred = test_pred.reshape((-1, h_patches, PATCH_SIZE, w_patches, PATCH_SIZE))  # split in patches
test_pred = np.moveaxis(test_pred, 2, 3)  # (n, h_patches, w_patches, P, P)
test_pred = (np.mean(test_pred, (-1, -2)) > CUTOFF).astype(np.uint8)  # 1 = road patch
create_submission(test_pred, test_filenames, submission_filename='unet_submission.csv')