In [12]:
import _thread
import os

import albumentations
import numpy as np
import torch
import torch.utils.data
import torchvision
import tqdm
from PIL import Image
from albumentations.pytorch import ToTensorV2
from data.mask_to_submission import masks_to_submission
from torchvision.transforms.functional import resize

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE=", DEVICE)
# Dataset directories (relative paths): training / validation images and
# their ground-truth masks.
Train_image_path = "data/training/images"
Train_mask_path = "data/training/groundtruth"
Validation_image_path = "data/validating/images"
Validation_mask_path = "data/validating/groundtruth"

DEVICE= cuda


# Data processing

In [6]:
class RoadDataset(torch.utils.data.Dataset):
    """
    In-memory dataset of (image, mask) pairs with repeated data augmentation.

    All images and masks are loaded eagerly at construction time and one
    augmented version of every pair is pre-computed. Each time a pair is
    fetched, the cached augmented version is returned and a background thread
    re-applies the transform so that the next fetch of the same index sees a
    new augmentation.

    Args:
        image_path: directory containing the input images.
        mask_path: directory containing the ground-truth masks. Images and
            masks are paired by their position in the *sorted* directory
            listings, so matching files must sort identically in both folders.
        transform: callable invoked as ``transform(image=..., mask=...)`` and
            returning a mapping with ``"image"`` and ``"mask"`` entries.
            ``__getitem__`` calls ``mask.unsqueeze(0)``, so the transform is
            expected to return tensors (e.g. by ending with ``ToTensorV2``).
    """

    def __init__(
        self,
        image_path: str,
        mask_path: str,
        transform,
    ):
        self.transform = transform
        self.images = self.load_images(image_path)
        self.masks = self.load_images(mask_path)
        self.images_augmented = []
        self.masks_augmented = []

        # Pre-compute one augmented version of every pair so the first epoch
        # can be served straight from the cache.
        for i in range(len(self.images)):
            output = self.transform(image=self.images[i], mask=self.masks[i])
            self.images_augmented.append(output["image"])
            self.masks_augmented.append(output["mask"])

    def get_images(self):
        """Return the raw (non-augmented) images and masks as loaded from disk."""
        return self.images, self.masks

    @staticmethod
    def load_images(image_path):
        """
        Load every file of ``image_path`` into a single numpy array.

        The directory listing is sorted: ``os.listdir`` returns entries in
        arbitrary order, and image/mask pairing relies on both folders being
        read in the same order.
        """
        images = []
        for img in sorted(os.listdir(image_path)):
            path = os.path.join(image_path, img)
            # The context manager releases the file handle once the pixel
            # data has been copied into the array.
            with Image.open(path) as image:
                images.append(np.asarray(image))
        return np.asarray(images)

    def augment(self, index):
        """Re-apply the transform to pair ``index`` and refresh the cached result."""
        output = self.transform(image=self.images[index], mask=self.masks[index])
        self.images_augmented[index] = output["image"]
        self.masks_augmented[index] = output["mask"]

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.images)

    def __getitem__(self, index):
        """
        Return the cached augmented pair at ``index``.

        The image is scaled by 1/255 and the mask is binarised (values > 100
        become 1.0) with a leading channel dimension added.
        """
        image = self.images_augmented[index]
        mask = self.masks_augmented[index]
        # Refresh the cache in the background so the next fetch of this index
        # gets a new augmentation without blocking the current one.
        # NOTE(review): the two cache writes in `augment` are not atomic as a
        # pair; a fetch racing with the thread could in principle observe a
        # new image with the previous mask.
        _thread.start_new_thread(self.augment, (index,))
        return (image / 255), (mask.unsqueeze(0) > 100).float()


def get_loader(
    data_path: str,
    mask_path: str,
    transform,
    batch_size: int = 4,
) -> torch.utils.data.DataLoader:
    """
    Build a shuffling DataLoader over a RoadDataset.

    The dataset is created from `data_path` / `mask_path` with `transform`;
    shuffling uses a dedicated generator seeded with 127.
    """
    road_dataset = RoadDataset(data_path, mask_path, transform)
    shuffle_rng = torch.Generator().manual_seed(127)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        generator=shuffle_rng,
    )
    return torch.utils.data.DataLoader(road_dataset, **loader_options)



# Models

In [7]:
# ================================================================================================
# UNET
# ================================================================================================


class DoubleConv(torch.nn.Module):
    """Two consecutive (3x3 convolution -> batch norm -> ReLU) stages."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        nn = torch.nn

        def stage(stage_in):
            # padding=1 keeps the spatial size; bias is dropped because the
            # convolution is followed directly by batch normalisation.
            return [
                nn.Conv2d(
                    stage_in, out_channels, kernel_size=3, padding=1, bias=False
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        self.double_conv = nn.Sequential(
            *stage(in_channels),
            *stage(out_channels),
        )

    def forward(self, x: torch.Tensor):
        """Run both stages on `x`."""
        features = self.double_conv(x)
        return features


class Down(torch.nn.Module):
    """Encoder step: 2x2 max-pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = torch.nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = torch.nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool `x`, then convolve the pooled map."""
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(torch.nn.Module):
    """
    Decoder step: transposed-conv upsampling, skip concatenation, double conv.

    Args:
        in_channels: channels of the incoming (deeper) feature map. The
            upsampling halves this, and the skip connection is expected to
            carry the other ``in_channels // 2`` channels.
        out_channels: channels produced by the final DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = torch.nn.ConvTranspose2d(
            in_channels, in_channels // 2, kernel_size=2, stride=2
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        Upsample ``x1``, align it with the skip tensor ``x2`` and fuse them.

        Args:
            x1: feature map coming from the deeper level.
            x2: skip-connection feature map from the encoder.
        """
        x1 = self.up(x1)

        # When the encoder pooled an odd-sized map, the upsampled tensor is
        # smaller than the skip connection; pad (or crop, for negative
        # differences) so the two can be concatenated. This is a no-op when
        # the sizes already agree.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        if diff_h != 0 or diff_w != 0:
            x1 = torch.nn.functional.pad(
                x1,
                [
                    diff_w // 2,
                    diff_w - diff_w // 2,
                    diff_h // 2,
                    diff_h - diff_h // 2,
                ],
            )

        # No torch.cuda.empty_cache() here: calling it on every forward pass
        # only adds allocator overhead, and `del` on the local names cannot
        # release tensors that the caller still references.
        return self.conv(torch.cat([x2, x1], dim=1))


class UNet(torch.nn.Module):
    """
    U-Net with four down-sampling and four up-sampling levels.

    Takes a 3-channel input and produces a single-channel map of logits
    (no sigmoid is applied here).
    """

    def __init__(self):
        super().__init__()
        self.inconv = DoubleConv(3, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outconv = torch.nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        """Return the logits for input batch ``x``."""
        # Encoder: keep every intermediate map for the skip connections.
        x1 = self.inconv(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: each step fuses the upsampled map with the matching skip.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        # torch.cuda.empty_cache() is deliberately not called per forward
        # pass; it does not free live tensors and slows down every step.
        return self.outconv(x)


# ================================================================================================
# LINKNET
# ================================================================================================


class LinkNetDecoderBlock(torch.nn.Module):
    """
    LinkNet decoder block.

    A 1x1 convolution squeezes the input to a quarter of its channels, a
    stride-2 transposed convolution doubles the spatial size, and a second
    1x1 convolution expands to `out_channel`. Each convolution is followed by
    batch normalisation and ReLU.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        squeezed = in_channel // 4
        self.conv1 = torch.nn.Conv2d(in_channel, squeezed, kernel_size=1)
        self.bn1 = torch.nn.BatchNorm2d(squeezed)
        self.up = torch.nn.ConvTranspose2d(
            squeezed,
            squeezed,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.bn2 = torch.nn.BatchNorm2d(squeezed)
        self.conv2 = torch.nn.Conv2d(squeezed, out_channel, kernel_size=1)
        self.bn3 = torch.nn.BatchNorm2d(out_channel)

    def forward(self, x):
        """Squeeze, upsample and expand `x`."""
        relu = torch.nn.functional.relu
        x = relu(self.bn1(self.conv1(x)))
        x = relu(self.bn2(self.up(x)))
        return relu(self.bn3(self.conv2(x)))


class LinkNet(torch.nn.Module):
    """
    LinkNet segmentation network built on top of a ResNet-style encoder.

    `encoder` must expose `conv1`, `bn1`, `maxpool` and `layer1`..`layer4`;
    `channels` lists the channel counts used to size the four decoder blocks,
    from the shallowest encoder stage to the deepest. The network outputs a
    single-channel map of logits.
    """

    def __init__(
        self,
        encoder,
        channels=(64, 128, 256, 512),
    ):
        super().__init__()
        assert len(channels) == 4

        # Stem and the four encoder stages are reused from the given encoder.
        self.conv1 = encoder.conv1
        self.bn1 = encoder.bn1
        self.maxpool1 = encoder.maxpool
        self.encoders = torch.nn.ModuleList()
        for stage in (encoder.layer1, encoder.layer2, encoder.layer3, encoder.layer4):
            self.encoders.append(stage)

        # Decoders run from the deepest stage back to the shallowest one.
        deep_to_shallow = channels[::-1]
        self.decoders = torch.nn.ModuleList()
        for wide, narrow in zip(deep_to_shallow, deep_to_shallow[1:]):
            self.decoders.append(LinkNetDecoderBlock(wide, narrow))
        shallowest = deep_to_shallow[-1]
        self.decoders.append(LinkNetDecoderBlock(shallowest, shallowest))

        # Final head: upsample once more, then reduce to a single logit map.
        self.up = torch.nn.ConvTranspose2d(shallowest, 32, kernel_size=3, stride=2)
        self.conv2 = torch.nn.Conv2d(32, 32, kernel_size=2)
        self.conv3 = torch.nn.Conv2d(32, 1, kernel_size=1)

    def forward(self, x):
        """Return the logits for input batch `x`."""
        relu = torch.nn.functional.relu

        x = self.maxpool1(relu(self.bn1(self.conv1(x))))

        skips = []
        for stage in self.encoders:
            x = stage(x)
            skips.append(x)

        # The first three decoders each add the encoder output of the next
        # shallower stage; the deepest encoder output is `x` itself.
        for decoder, skip in zip(self.decoders, reversed(skips[:-1])):
            x = decoder(x)
            if x.shape[2:] != skip.shape[2:]:
                x = resize(x, skip.shape[2:])
            x = x + skip
        x = self.decoders[3](x)

        x = relu(self.up(x))
        x = relu(self.conv2(x))
        return self.conv3(x)



# Training

In [8]:
def f1(pred, label):
    """
    Compute the F1 score, precision and recall of a prediction.

    Both tensors are flattened and compared element-wise; they are meant to
    hold 0/1 values and must contain the same number of elements.

    Args:
        pred: predicted labels.
        label: ground-truth labels.

    Returns:
        A tuple ``(f1, precision, recall)`` of scalar tensors.
    """
    # reshape() instead of view(): view() raises on non-contiguous tensors
    # (e.g. slices or transposes), reshape() copies only when it has to.
    pred = pred.reshape(-1)
    label = label.reshape(-1)
    tp = (label * pred).sum().to(torch.float32)
    fp = ((1 - label) * pred).sum().to(torch.float32)
    fn = (label * (1 - pred)).sum().to(torch.float32)
    # eps keeps the ratios finite when there are no positive predictions
    # and/or no positive labels.
    eps = 1e-7
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return 2 * precision * recall / (precision + recall + eps), precision, recall


def compute_metrics(loader, model, device):
    """
    Evaluate ``model`` on every batch of ``loader``.

    Only the last output channel of the model is scored; it is passed through
    a sigmoid and thresholded at 0.5.

    Args:
        loader: iterable of ``(input, target)`` batches.
        model: the network to evaluate. It is switched to eval mode for the
            duration of the call and put back into its previous mode
            afterwards, even if evaluation raises.
        device: device the batches are moved to.

    Returns:
        A dict with the keys ``"acc"`` (pixel accuracy over all batches) and
        ``"f1 score"``, ``"precision"``, ``"recall"`` (each the mean of the
        per-batch values), all expressed as percentages.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    was_training = model.training
    model.eval()

    num_correct = 0
    num_pixels = 0
    num_batches = 0
    f1_score = 0
    precision = 0
    recall = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                output = model(x)
                output = output[:, -1, :, :].unsqueeze(1)
                pred: torch.Tensor = (torch.sigmoid(output) >= 0.5).float()
                num_correct += torch.sum(pred == y).item()
                num_pixels += torch.numel(pred)
                a, b, c = f1(pred, y)
                f1_score += a.item()
                precision += b.item()
                recall += c.item()
                num_batches += 1
    finally:
        # Restore the mode the caller had instead of forcing train mode.
        model.train(was_training)

    if num_batches == 0:
        raise ValueError("compute_metrics received an empty loader")

    log = dict()
    log["acc"] = num_correct / num_pixels * 100
    log["f1 score"] = f1_score / num_batches * 100
    log["precision"] = precision / num_batches * 100
    log["recall"] = recall / num_batches * 100
    return log



In [25]:
def epoch(model, loader, optimizer, criterion, scaler):
    """
    Run one pass over `loader`, updating `model` after every batch.

    The loss of a batch is the sum of `criterion` applied to each output
    channel against the same target. Returns the sum of the batch losses.
    """
    running_loss = 0
    for batch, labels in loader:
        batch = batch.to(DEVICE)
        labels = labels.to(DEVICE)

        # Forward pass under CUDA mixed precision.
        with torch.cuda.amp.autocast():
            logits = model(batch)
            batch_loss = sum(
                criterion(logits[:, channel, :, :].unsqueeze(1), labels)
                for channel in range(logits.shape[1])
            )
            running_loss += batch_loss.item()

        # Scaled backward pass and optimizer step.
        optimizer.zero_grad()
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()
    return running_loss


def train(
    model,
    model_name,
    train_loader,
    validation_loader,
    lr: float = 1.0e-4,
    epochs: int = 10,
):
    """
    Train ``model`` and record its validation metrics after every epoch.

    Side effects:
        * ``logs/<model_name>.csv`` is (re)created and receives one row per
          epoch. The ``iou`` column is a placeholder and is always 0.
        * ``checkpoints/<model_name>.pth.maxf1`` stores the model with the
          best validation F1 so far, but only once that F1 exceeds 80.
        * ``checkpoints/<model_name>.pth`` stores the model after the last
          epoch.

    Args:
        model: network to train, already moved to ``DEVICE``.
        model_name: base name used for the log and checkpoint files.
        train_loader: loader providing the training batches.
        validation_loader: loader used to compute the per-epoch metrics.
        lr: learning rate of the Adam optimizer.
        epochs: number of training epochs.
    """
    log_file_name = os.path.relpath(os.path.join("logs", model_name + ".csv"))
    model_file_name = os.path.join("checkpoints", model_name + ".pth")

    # Create the output folders up front so that neither the log write nor
    # the checkpoint save (possibly hours later) fails on a fresh checkout.
    for output_file in (log_file_name, model_file_name):
        os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

    with open(log_file_name, "w") as f:
        f.write("epoch,loss,f1,iou,accuracy,precision,recall\n")

    # Define the criterion and optimizer
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Define the scheduler and scaler. T_max is fixed at 50 regardless of
    # `epochs`.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    scaler = torch.cuda.amp.GradScaler()

    # Train the model, logging the metrics and keeping the best checkpoint
    loop = tqdm.tqdm(range(epochs))
    max_f1 = 0
    for e in loop:
        loss = epoch(model, train_loader, optimizer, criterion, scaler)
        scheduler.step()
        metrics = compute_metrics(validation_loader, model, DEVICE)
        if metrics["f1 score"] > max_f1:
            max_f1 = metrics["f1 score"]
            # Skip checkpointing the early, low-quality improvements.
            if max_f1 > 80.0:
                torch.save(model, model_file_name + ".maxf1")

        with open(log_file_name, "a") as f:
            f.write(
                "{},{},{},{},{},{},{}\n".format(
                    e,
                    loss,
                    metrics["f1 score"],
                    0,
                    metrics["acc"],
                    metrics["precision"],
                    metrics["recall"],
                )
            )

        loop.set_postfix(loss=loss, f1_score=metrics["f1 score"], max_f1=max_f1)

    # Save the model as it is after the last epoch
    torch.save(model, model_file_name)



In [19]:
# Training augmentation: flips, transposes, rotations and coarse dropout,
# plus (with probability 0.5) one of three distortions, then conversion to
# tensors.
train_transform = albumentations.Compose(
    [
        albumentations.Flip(),
        albumentations.Transpose(),
        albumentations.Rotate(),
        albumentations.CoarseDropout(max_holes=8, max_height=8, max_width=8),
        albumentations.OneOf(
            [
                albumentations.OpticalDistortion(),
                albumentations.GridDistortion(),
                albumentations.ElasticTransform(),
            ],
            p=0.5,
        ),
        ToTensorV2(),
    ]
)

# Building the loader loads the whole training set and applies the transform
# to every pair once (see RoadDataset.__init__).
train_loader = get_loader(
    data_path=Train_image_path,
    mask_path=Train_mask_path,
    transform=train_transform,
    batch_size=4,
)

# Define the data augmentation used for the validation set, and also create the data loader for it
# NOTE(review): the validation transform still contains Flip/Transpose, so
# validation samples are augmented as well.
val_transform = albumentations.Compose(
    [
        albumentations.Flip(),
        albumentations.Transpose(),
        ToTensorV2(),
    ]
)
val_loader = get_loader(
    data_path=Validation_image_path,
    mask_path=Validation_mask_path,
    transform=val_transform,
    batch_size=4,
)

# Actual training

In [26]:
# Train the U-Net for 1000 epochs. The model is created inline, so no
# notebook-level reference to it remains after the call; results are written
# under the name "unet" (see train()).
train(
    model=UNet().to(DEVICE),
    model_name="unet",
    epochs=1000,
    train_loader=train_loader,
    validation_loader=val_loader,
)

100%|██████████| 1000/1000 [1:54:06<00:00,  6.85s/it, f1_score=87.2, loss=1.09, max_f1=89.5]


In [None]:
# Train LinkNet on an ImageNet-pretrained ResNet-18 encoder. `channels` sizes
# the decoder blocks whose outputs are added to the encoder's layer outputs,
# so it has to match the channel counts of layer1..layer4.
train(
    model=LinkNet(
        encoder=torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        ),
        channels=(64, 128, 256, 512),
    ).to(DEVICE),
    model_name="linknet18",
    epochs=1000,
    train_loader=train_loader,
    validation_loader=val_loader,
)

 17%|█▋        | 168/1000 [18:06<1:28:40,  6.40s/it, f1_score=85.7, loss=1.73, max_f1=87.3]

In [None]:
# Same setup as the ResNet-18 run, with an ImageNet-pretrained ResNet-34
# encoder; results are written under the name "linknet34".
train(
    model=LinkNet(
        encoder=torchvision.models.resnet34(
            weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1
        ),
        channels=(64, 128, 256, 512),
    ).to(DEVICE),
    model_name="linknet34",
    epochs=1000,
    train_loader=train_loader,
    validation_loader=val_loader,
)

In [None]:
# LinkNet on an ImageNet-pretrained ResNet-50 encoder. The decoder is sized
# with the wider channel tuple (256, 512, 1024, 2048) for this encoder.
train(
    model=LinkNet(
        encoder=torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2
        ),
        channels=(256, 512, 1024, 2048),
    ).to(DEVICE),
    model_name="linknet50",
    epochs=1000,
    train_loader=train_loader,
    validation_loader=val_loader,
)

In [None]:
# LinkNet on an ImageNet-pretrained ResNet-101 encoder; results are written
# under the name "linknet101".
train(
    model=LinkNet(
        encoder=torchvision.models.resnet101(
            weights=torchvision.models.ResNet101_Weights.IMAGENET1K_V2
        ),
        channels=(256, 512, 1024, 2048),
    ).to(DEVICE),
    model_name="linknet101",
    epochs=1000,
    train_loader=train_loader,
    validation_loader=val_loader,
)

In [None]:
# LinkNet on an ImageNet-pretrained ResNet-152 encoder; results are written
# under the name "linknet152".
train(
    model=LinkNet(
        encoder=torchvision.models.resnet152(
            weights=torchvision.models.ResNet152_Weights.IMAGENET1K_V2
        ),
        channels=(256, 512, 1024, 2048),
    ).to(DEVICE),
    model_name="linknet152",
    epochs=1000,
    train_loader=train_loader,
    validation_loader=val_loader,
)

# Create submissions

In [32]:
def create_postprocessing_images(images, rotations, transposes):
    """
    Build the test-time-augmentation batch for the given images.

    For every image and every rotation angle the rotated image is emitted,
    followed (when `transposes` is truthy) by the rotated horizontal flip of
    the image. The stacked views are returned as a float tensor with the last
    axis moved to position 1.
    """
    views = []
    for source in images:
        for angle in rotations:
            views.append(albumentations.rotate(source, angle))
            if transposes:
                mirrored = albumentations.hflip(source)
                views.append(albumentations.rotate(mirrored, angle))

    batch = torch.tensor(np.array(views))
    # Two swaps move the last axis to position 1: (N, H, W, C) -> (N, C, H, W).
    channels_first = batch.transpose(1, -1).transpose(2, -1)
    return channels_first.float()


def combine_postprocessing_images(images, rotations, transposes):
    """
    Average the predictions made on the augmented views of each source image.

    ``images`` must hold the predictions in the order produced by
    ``create_postprocessing_images``: for every source image and every
    rotation, the rotated view followed (if ``transposes``) by the rotated
    horizontal flip. Each view is mapped back to the original orientation
    before being accumulated.

    Args:
        images: array indexed as ``images[view, 0]``; only channel 0 of each
            view is used.
        rotations: the rotation angles that were used to create the views.
        transposes: whether flipped views were generated as well.

    Returns:
        An array with one averaged prediction per source image, each with the
        shape of ``images[0]``.

    Raises:
        ValueError: if ``rotations`` is empty while ``images`` is not (there
            would be no way to consume the views).
    """
    # Each source image contributes this many consecutive views. The average
    # must use this count, not len(images), which also counts the views of
    # every other source image.
    views_per_image = len(rotations) * (2 if transposes else 1)
    if views_per_image == 0 and len(images) > 0:
        raise ValueError("rotations must contain at least one angle")

    outputs = []
    index = 0
    while index < len(images):
        output = np.zeros(images[0].shape)
        for rotation in rotations:
            # Undo the rotation of the plain view.
            output += albumentations.rotate(images[index, 0], -rotation)
            index += 1
            if transposes:
                # The flipped view was flipped first and rotated second, so
                # it is restored in the opposite order.
                restored = albumentations.rotate(images[index, 0], -rotation)
                output += albumentations.hflip(restored)
                index += 1
        outputs.append(output / views_per_image)
    return np.array(outputs)


def create_submission(model_name: str):
    """
    Predict the test set with a trained model and write the submission file.

    The best-F1 checkpoint ``checkpoints/<model_name>.pth.maxf1`` is loaded,
    every test image is predicted with test-time augmentation (four rotations,
    each with and without a horizontal flip), the averaged prediction is
    thresholded at 0.5 and saved as ``predictions/<model_name>/<image>.png``.
    Finally ``submissions/submission_<model_name>.csv`` is generated from the
    masks ``test_1.png`` .. ``test_50.png``.

    Args:
        model_name: name the model was trained under.
    """
    model_file_name = os.path.join("checkpoints", model_name + ".pth.maxf1")
    # map_location lets a checkpoint saved from the GPU be loaded on a
    # CPU-only machine as well.
    # NOTE: the checkpoint is a fully pickled model; only load trusted files.
    # NOTE(review): recent PyTorch releases may require `weights_only=False`
    # to load such checkpoints - confirm against the installed version.
    model = torch.load(model_file_name, map_location=DEVICE).to(DEVICE)
    model.eval()

    # Create the directory to store the predictions
    path = "data/test_set_images"
    pred_path = "predictions/" + model_name
    os.makedirs(pred_path, exist_ok=True)

    rotations = [0, 90, 180, 270]

    # For each image, apply postprocessing augmentation, make predictions and save predictions
    for image in tqdm.tqdm(os.listdir(path)):
        # Each test image lives in its own folder: <path>/<name>/<name>.png
        img_path = os.path.join(path, image, image + ".png")
        with Image.open(img_path) as raw_image:
            im = np.asarray(raw_image) / 255
        ims = create_postprocessing_images([im], rotations=rotations, transposes=True)

        with torch.no_grad():
            output = model(ims.to(DEVICE))
            predicts = torch.sigmoid(output).cpu().detach()

        # A single source image goes in, so exactly one averaged map comes
        # back; give it the spatial size of the input image instead of a
        # hard-coded 608x608.
        predict = combine_postprocessing_images(
            predicts.numpy(), rotations=rotations, transposes=True
        ).reshape(im.shape[:2])
        predict[predict < 0.5] = 0
        predict[predict >= 0.5] = 1
        predict *= 255
        Image.fromarray(predict).convert("L").save(
            os.path.join(pred_path, image) + ".png"
        )

    # Generate the submission file
    submission_filename = os.path.join(
        "submissions", "submission_{}.csv".format(model_name)
    )
    os.makedirs(os.path.dirname(submission_filename), exist_ok=True)
    # The submission covers the masks test_1.png .. test_50.png, in this order.
    image_filenames = []
    for i in range(1, 51):
        image_filename = pred_path + "/test_" + str(i) + ".png"
        image_filenames.append(image_filename)
    masks_to_submission(submission_filename, *image_filenames)



In [33]:
# create_submission("unet")
# The names must be the `model_name` values passed to train(): the checkpoint
# path is derived from the name, and the training cells of this notebook save
# the LinkNet models as "linknet18" .. "linknet152".
create_submission("linknet18")
create_submission("linknet34")
create_submission("linknet50")
create_submission("linknet101")
create_submission("linknet152")

100%|██████████| 50/50 [00:34<00:00,  1.45it/s]
