In [None]:
import torch
import torch.utils.data as tdata
import torchvision
from torchvision.transforms.functional import resize
from PIL import Image
import albumentations
from albumentations.pytorch import ToTensorV2
import os
import numpy as np
import _thread
import tqdm
import re
import matplotlib.image as mpimg

# Run on the GPU when one is available, otherwise fall back to the CPU
# (a hard-coded "cuda" makes every .to(DEVICE) fail on CPU-only machines).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Models

In [None]:
class LinkNetDecoder(torch.nn.Module):
    """One LinkNet decoder stage.

    Three conv -> batch-norm -> ReLU steps: a 1x1 conv that shrinks the channels
    to ``in_channel // 4``, a stride-2 3x3 transposed conv that doubles the
    spatial size, and a 1x1 conv that maps to ``out_channel`` channels.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        bottleneck = in_channel // 4
        self.conv1 = torch.nn.Conv2d(in_channel, bottleneck, kernel_size=1)
        self.bn1 = torch.nn.BatchNorm2d(bottleneck)
        # padding=1 with output_padding=1 gives an output exactly twice the input size.
        self.up = torch.nn.ConvTranspose2d(
            bottleneck,
            bottleneck,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.bn2 = torch.nn.BatchNorm2d(bottleneck)
        self.conv2 = torch.nn.Conv2d(bottleneck, out_channel, kernel_size=1)
        self.bn3 = torch.nn.BatchNorm2d(out_channel)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1),
            (self.up, self.bn2),
            (self.conv2, self.bn3),
        )
        for layer, norm in stages:
            x = torch.nn.functional.relu(norm(layer(x)))
        return x


class LinkNet(torch.nn.Module):
    """LinkNet segmentation network built on a torchvision-style ResNet encoder.

    The ResNet stem (conv1 / bn1 / maxpool) and its four residual stages form the
    encoder. Four LinkNetDecoder stages upsample back; the outputs of the first
    three are added to the matching encoder feature maps (resized first when the
    spatial sizes differ). A stride-2 transposed conv, a 2x2 conv and a 1x1 conv
    then produce ``out_channel`` maps.

    Args:
        in_channel: number of input channels. When it is not 3, the ResNet's first
            conv is replaced by a new 7x7 / stride-2 conv with 64 output channels.
        out_channel: number of output channels.
        filters: channel widths of the four encoder stages (default
            [64, 128, 256, 512]); they must agree with ``resnet``.
        resnet: model supplying conv1, bn1, maxpool and layer1..layer4
            (default: torchvision resnet34).
        pretrained: forwarded to resnet34 when ``resnet`` is None.
    """

    def __init__(
        self, in_channel, out_channel, filters=None, resnet=None, pretrained=False
    ):
        super().__init__()
        filters = [64, 128, 256, 512] if filters is None else filters
        if resnet is None:
            resnet = torchvision.models.resnet34(pretrained=pretrained)
        assert len(filters) == 4

        if in_channel == 3:
            self.conv1 = resnet.conv1
        else:
            self.conv1 = torch.nn.Conv2d(
                in_channel, 64, kernel_size=7, stride=2, padding=3
            )
        self.bn1 = resnet.bn1
        self.maxpool1 = resnet.maxpool
        self.encoders = torch.nn.ModuleList(
            [resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4]
        )

        # Decoders run from the deepest stage to the shallowest; the last keeps its width.
        widths = list(reversed(filters))
        stages = [
            LinkNetDecoder(widths[k], widths[k + 1]) for k in range(len(widths) - 1)
        ]
        stages.append(LinkNetDecoder(widths[3], widths[3]))
        self.decoders = torch.nn.ModuleList(stages)

        self.up = torch.nn.ConvTranspose2d(widths[3], 32, kernel_size=3, stride=2)
        self.conv2 = torch.nn.Conv2d(32, 32, kernel_size=2)
        self.conv3 = torch.nn.Conv2d(32, out_channel, kernel_size=1)

    def forward(self, x):
        x = torch.nn.functional.relu(self.bn1(self.conv1(x)))
        x = self.maxpool1(x)

        skips = []
        for encoder in self.encoders:
            x = encoder(x)
            skips.append(x)
        skips.reverse()  # deepest feature map first

        for depth in range(3):
            skip = skips[depth + 1]
            x = self.decoders[depth](x)
            # Odd encoder sizes can make the upsampled map one pixel off.
            if x.shape[2:] != skip.shape[2:]:
                x = resize(x, skip.shape[2:])
            x = x + skip
        x = self.decoders[3](x)

        x = torch.nn.functional.relu(self.up(x))
        x = torch.nn.functional.relu(self.conv2(x))
        return self.conv3(x)



# Data

In [None]:
def img_crop(images, size, stride=16, padding=0):
    """Cut every image into square patches.

    Each image is zero-padded by ``padding`` pixels on every spatial side. Patches
    are centred on ``size x size`` windows of the original image whose top-left
    corners lie on a ``stride`` grid; every window must fit inside the original
    image, so a border narrower than ``stride`` can be left uncovered. Each
    patch spans ``size + 2 * padding`` pixels per side. Patches are emitted column
    offset by column offset (outer loop), row offset by row offset (inner loop).

    Args:
        images: iterable of (H, W) or (H, W, C) arrays.
        size: window side length.
        stride: step between consecutive windows.
        padding: context margin added around each window.

    Returns:
        float32 array of all patches; 2-D inputs give 2-D patches.
    """
    patches = []
    for img in images:
        is_2d = img.ndim == 2
        if is_2d:
            img = img[:, :, np.newaxis]
        rows, cols, channels = img.shape
        padded = np.zeros((rows + 2 * padding, cols + 2 * padding, channels))
        padded[padding : padding + rows, padding : padding + cols, :] = img

        span = size + 2 * padding
        for col in range(0, cols - size + 1, stride):
            for row in range(0, rows - size + 1, stride):
                patch = padded[row : row + span, col : col + span, :]
                patches.append(patch[:, :, 0] if is_2d else patch)
    return np.array(patches, dtype=np.float32)


class RoadDataset(torch.utils.data.Dataset):
    """Dataset of (image, mask) pairs for road segmentation.

    Images and masks are read from two directories. Files are sorted by name, and
    pairs are formed by position, so both directories are expected to hold
    matching file names. Pairs can optionally be rotated by every angle in
    ``rotations`` and cut into patches with ``img_crop``. Every sample is
    transformed once at construction; each time a sample is served, a background
    thread re-transforms it, so a later access returns a new variant.

    Args:
        image_path: directory with the input images.
        mask_path: directory with the masks.
        transform: callable ``transform(image=..., mask=...)`` returning a dict with
            ``"image"`` and ``"mask"`` tensors. If None, the arrays are only converted
            to tensors (image (H, W, C) -> (C, H, W), mask unchanged).
        one_hot: if True, masks are returned as 2-channel one-hot tensors
            (channel 1 marks raw mask values > 100) instead of a 1-channel binary mask.
        rotations: optional list of angles passed to albumentations.rotate.
        crop, crop_size, stride, padding: patch extraction options (see img_crop).
    """

    def __init__(
        self,
        image_path,
        mask_path,
        transform=None,
        one_hot=False,
        rotations=None,
        crop=False,
        crop_size=224,
        stride=16,
        padding=0,
    ):
        self.transform = transform
        self.one_hot = one_hot
        self.images = self.load_images(image_path)
        self.masks = self.load_images(mask_path)
        self.images_augmented = []
        self.masks_augmented = []

        if rotations is not None:
            self.images, self.masks = self.rotate(self.images, self.masks, rotations)

        # Crop the images into patches with respect to the given size
        if crop:
            self.images = img_crop(self.images, crop_size, stride, padding)
            self.masks = img_crop(self.masks, crop_size, stride, padding)

        # Initial augmentation of every sample; __getitem__ refreshes them later.
        for i in range(len(self.images)):
            image, mask = self._transform_pair(i)
            self.images_augmented.append(image)
            self.masks_augmented.append(mask)

    def _transform_pair(self, index):
        """Return the transformed ``(image, mask)`` for sample ``index``.

        Without a transform, the arrays are only converted to tensors: the image
        from (H, W, C) to (C, H, W) (a 2-D image gets one channel) and the mask
        as-is. This means ``__getitem__`` always works on tensors.
        """
        image = self.images[index]
        mask = self.masks[index]
        if self.transform is None:
            image = np.asarray(image)
            if image.ndim == 2:
                image = image[:, :, np.newaxis]
            image_t = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1)))
            mask_t = torch.from_numpy(np.ascontiguousarray(mask))
            return image_t, mask_t
        output = self.transform(image=image, mask=mask)
        return output["image"], output["mask"]

    @staticmethod
    def rotate(images, masks, rotations):
        """Return every (image, mask) pair rotated by every angle in ``rotations``."""
        ims = []
        msks = []
        for im, msk in zip(images, masks):
            for rotation in rotations:
                ims.append(albumentations.rotate(im, rotation))
                msks.append(albumentations.rotate(msk, rotation))
        return np.asarray(ims), np.asarray(msks)

    def get_images(self):
        """Return the (non-augmented) image and mask arrays."""
        return self.images, self.masks

    @staticmethod
    def load_images(image_path):
        """Load every file in ``image_path``, in sorted name order, into one array.

        ``os.listdir`` order is arbitrary, so sorting is what keeps images and masks
        from two directories aligned. Each file is closed after it is read.
        """
        images = []
        for file_name in sorted(os.listdir(image_path)):
            with Image.open(os.path.join(image_path, file_name)) as img:
                images.append(np.asarray(img))
        return np.asarray(images)

    def augment(self, index):
        """Replace the stored augmented copy of sample ``index`` with a fresh one."""
        image, mask = self._transform_pair(index)
        self.images_augmented[index] = image
        self.masks_augmented[index] = mask

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``(image / 255, mask)`` for sample ``index``.

        A leading channel dimension is added to the mask. It is then either a float
        binary mask (1 where the raw value is > 100) or, with ``one_hot``, a
        2-channel one-hot float tensor (channel 0 = rest, channel 1 = value > 100).
        """
        image = self.images_augmented[index]
        mask = self.masks_augmented[index]

        # Re-augment this sample in the background for its next access.
        # NOTE(review): with DataLoader worker processes (num_workers > 0) this update
        # happens in the worker's copy of the dataset and may not persist across epochs.
        _thread.start_new_thread(self.augment, (index,))

        mask = mask.reshape((1,) + tuple(mask.shape))
        # Binarize before building the one-hot mask: raw mask values go up to 255
        # and cannot be used directly as class indices.
        binary = mask > 100
        if self.one_hot:
            one_hot_mask = torch.zeros((2,) + tuple(mask.shape[1:]))
            one_hot_mask.scatter_(0, binary.long(), 1)
            return image / 255, one_hot_mask
        return image / 255, binary.float()


def get_loader(
    data_path,
    mask_path,
    transform,
    batch_size,
    num_worker,
    shuffle,
    pin_memory,
    one_hot=False,
    rotations=None,
    crop=False,
    crop_size=224,
    stride=16,
    padding=0,
):
    """Build a RoadDataset from the two directories and wrap it in a DataLoader.

    Dataset options (one_hot, rotations, crop settings) are forwarded to
    RoadDataset; loader options to torch DataLoader. Shuffling uses a generator
    seeded with 127, so the shuffle order is the same on every call.
    """
    dataset_options = dict(
        one_hot=one_hot,
        rotations=rotations,
        crop=crop,
        crop_size=crop_size,
        stride=stride,
        padding=padding,
    )
    dataset = RoadDataset(data_path, mask_path, transform, **dataset_options)
    seeded_generator = torch.Generator().manual_seed(127)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=pin_memory,
        num_workers=num_worker,
        generator=seeded_generator,
    )



# Training

In [None]:
def f1(pred, label):
    """Compute the F1 score, precision and recall of a binary prediction.

    Args:
        pred: tensor of 0/1 predictions (any shape; float, integer or bool dtype).
        label: tensor of 0/1 ground truth with the same number of elements.

    Returns:
        ``(f1, precision, recall)`` as 0-d float32 tensors. A small epsilon in
        each denominator keeps the result finite (0) when there are no positives.
    """
    # reshape (unlike view) also accepts non-contiguous tensors. Casting to float32
    # first allows bool inputs (1 - bool is an error) and keeps the pixel-count
    # sums out of low-precision dtypes.
    pred = pred.reshape(-1).to(torch.float32)
    label = label.reshape(-1).to(torch.float32)
    tp = (label * pred).sum()
    fp = ((1 - label) * pred).sum()
    fn = (label * (1 - pred)).sum()
    eps = 1e-7
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return 2 * precision * recall / (precision + recall + eps), precision, recall


def compute_metrics(loader, model, device, one_hot=False):
    """Evaluate ``model`` on ``loader`` and return accuracy, F1, precision and recall.

    Args:
        loader: sized iterable of ``(x, y)`` batches. ``y`` is a binary mask shaped
            like the 1-channel prediction or, with ``one_hot``, a mask with classes
            on dim 1.
        model: network whose output has channels on dim 1.
        device: device the batches are moved to.
        one_hot: if True, predictions and targets are compared by argmax over dim 1;
            otherwise the last output channel is thresholded at sigmoid >= 0.5.

    Returns:
        dict with ``"acc"`` (pixel accuracy), ``"f1 score"``, ``"precision"`` and
        ``"recall"``, all in percent. F1/precision/recall are averaged over
        batches, not pooled over all pixels.

    The model runs in eval mode and is restored to its previous train/eval mode
    afterwards, even if an exception is raised.
    """
    was_training = model.training
    model.eval()
    log = dict()
    num_correct = 0
    num_pixels = 0
    f1_sum = 0
    precision_sum = 0
    recall_sum = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                output = model(x)
                if one_hot:
                    # The argmax must see every class channel; slicing out one channel
                    # first would make it always return 0.
                    pred: torch.Tensor = output.argmax(1)
                    y = y.argmax(1)
                else:
                    prob = torch.sigmoid(output[:, -1, :, :].unsqueeze(1))
                    pred: torch.Tensor = (prob >= 0.5).float()
                num_correct += torch.sum(pred == y).item()
                num_pixels += torch.numel(pred)
                batch_f1, batch_precision, batch_recall = f1(pred, y)
                f1_sum += batch_f1.item()
                precision_sum += batch_precision.item()
                recall_sum += batch_recall.item()
    finally:
        model.train(was_training)

    log["acc"] = num_correct / num_pixels * 100
    log["f1 score"] = f1_sum / len(loader) * 100
    log["precision"] = precision_sum / len(loader) * 100
    log["recall"] = recall_sum / len(loader) * 100
    return log



In [None]:
# Dataset location (Kaggle input mount) and the train/test sub-directories.
prefix = "/kaggle/input/d/tudoroancea/road-segmentation-dataset/"
Train_image_path = f"{prefix}data/training/images"
Train_mask_path = f"{prefix}data/training/groundtruth"
Test_image_path = f"{prefix}data/testing/images"
Test_mask_path = f"{prefix}data/testing/groundtruth"

# Optimisation and data-loading settings
LEARNING_RATE = 1e-4
BATCH_SIZE = 4
NUM_EPOCHS = 1000
NUM_WORKERS = 0
IMAGE_SIZE = 400  # side length the transforms resize images and masks to
PIN_MEMORY = True
ONE_HOT = False  # use 2-channel one-hot masks instead of 1-channel binary masks

# Patch extraction settings (only used when CROP is True)
CROP = False
CROP_SIZE = 224
STRIDE = 16
PADDING = 0

T = 50  # T_max of the cosine-annealing learning-rate schedule (stepped once per epoch)

# Seed PyTorch's CPU and CUDA random generators.
torch.manual_seed(0)
torch.cuda.manual_seed(0)


def train_model(model, loader, optimizer, criterion, scaler):
    """Run one mixed-precision training epoch and return the summed loss.

    For each batch, the loss is the sum of ``criterion`` applied to every output
    channel separately (as a 1-channel map) against the same ``target``. The
    backward pass and optimizer step go through ``scaler``.

    Args:
        model: network producing outputs with channels on dim 1.
        loader: iterable of ``(data, target)`` batches; both are moved to DEVICE.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function called as ``criterion(pred, target)``.
        scaler: gradient scaler used for backward / step / update.

    Returns:
        Sum of the per-batch losses over the epoch (not averaged).
    """
    epoch_loss = 0
    for data, target in loader:
        data, target = data.to(DEVICE), target.to(DEVICE)
        with torch.cuda.amp.autocast():
            output = model(data)
            loss = sum(
                criterion(output[:, channel, :, :].unsqueeze(1), target)
                for channel in range(output.shape[1])
            )
        epoch_loss += loss.item()

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    return epoch_loss


def main(name):
    """Train a LinkNet with a pretrained ResNet-152 encoder and return the network.

    Each epoch runs one training pass, steps the LR scheduler, and evaluates on
    the validation loader. Whenever the validation F1 improves and exceeds 80 %,
    the whole model is saved to ``checkpoints/<name>_max_f1.pth``. After the last
    epoch, the per-epoch logs are written to ``logs/<name>_results`` (one dict per
    line).

    Args:
        name: run name used in the checkpoint and log file names;
            "_with_one_hot" is appended when ONE_HOT is set.

    Returns:
        The network after the final epoch (not necessarily the best checkpoint).
    """
    if ONE_HOT:
        name += "_with_one_hot"

    # torch.save / open below raise FileNotFoundError if these directories are missing.
    os.makedirs("checkpoints", exist_ok=True)
    os.makedirs("logs", exist_ok=True)

    # Create the network
    net = LinkNet(
        in_channel=3,
        out_channel=1,
        resnet=torchvision.models.resnet152(pretrained=True),
        filters=[256, 512, 1024, 2048],
    ).to(DEVICE)
    # Define the criterion and optimizer
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

    # Define the scheduler and scaler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T)
    scaler = torch.cuda.amp.GradScaler()

    # Data augmentation for the training set, and its data loader
    train_transform = albumentations.Compose(
        [
            albumentations.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
            albumentations.Flip(),
            albumentations.Transpose(),
            albumentations.Rotate(),
            albumentations.CoarseDropout(max_holes=8, max_height=8, max_width=8),
            albumentations.OneOf(
                [
                    albumentations.OpticalDistortion(),
                    albumentations.GridDistortion(),
                    albumentations.ElasticTransform(),
                ],
                p=0.5,
            ),
            ToTensorV2(),
        ]
    )

    train_loader = get_loader(
        Train_image_path,
        Train_mask_path,
        train_transform,
        BATCH_SIZE,
        NUM_WORKERS,
        shuffle=True,
        pin_memory=PIN_MEMORY,
        one_hot=ONE_HOT,
        crop=CROP,
        crop_size=CROP_SIZE,
        stride=STRIDE,
        padding=PADDING,
    )

    # Transform and data loader for the validation set.
    # NOTE(review): Flip/Transpose are random, and RoadDataset re-augments samples
    # after each access, so the validation data changes between epochs. This makes
    # the "best F1" checkpoint selection noisy; consider a deterministic val transform.
    val_transform = albumentations.Compose(
        [
            albumentations.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
            albumentations.Flip(),
            albumentations.Transpose(),
            ToTensorV2(),
        ]
    )
    val_loader = get_loader(
        Test_image_path,
        Test_mask_path,
        val_transform,
        BATCH_SIZE,
        NUM_WORKERS,
        shuffle=False,
        pin_memory=PIN_MEMORY,
        rotations=[0, 90, 180, 270],
        one_hot=ONE_HOT,
        crop=CROP,
        crop_size=CROP_SIZE,
        stride=STRIDE,
        padding=PADDING,
    )

    # Train the model, keeping per-epoch logs and the best checkpoint
    logs = []
    loop = tqdm.tqdm(range(NUM_EPOCHS))
    max_f1 = 0

    for e in loop:
        loss = train_model(net, train_loader, optimizer, criterion, scaler)
        scheduler.step()
        log = compute_metrics(val_loader, net, DEVICE, one_hot=ONE_HOT)
        log["epochs"] = e
        log["loss"] = loss
        logs.append(log)
        if log["f1 score"] > max_f1:
            max_f1 = log["f1 score"]
            if max_f1 > 80.0:
                torch.save(net, "checkpoints/" + name + "_max_f1.pth")
        loop.set_postfix(
            loss=loss, acc=log["acc"], f1_score=log["f1 score"], max_f1=max_f1
        )

    # Save the logs into a file (closed even if a write fails)
    with open("logs/" + name + "_results", mode="w") as f:
        for log in logs:
            f.write(str(log) + "\n")
    return net



In [None]:
# Launch training; checkpoints and logs are written under the run name "bruh".
main(name="bruh")


# Creating Submission


In [None]:
def create_different_prospective(images, rotations, transposes):
    """Build the test-time-augmentation views of each image as one batch tensor.

    For each image, and each angle in ``rotations``, the rotated image is added.
    When ``transposes`` is truthy, the horizontally flipped image rotated by the
    same angle is added right after it. (Despite the parameter name, the extra
    view is a horizontal flip, not a transpose.)

    Args:
        images: iterable of (H, W, C) arrays.
        rotations: angles passed to albumentations.rotate.
        transposes: whether to also add the flipped views.

    Returns:
        float32 tensor of shape (N, C, H, W), ordered image -> angle ->
        (plain, flipped).
    """
    views = []
    for image in images:
        flipped = albumentations.hflip(image) if transposes else None
        for rotation in rotations:
            views.append(albumentations.rotate(image, rotation))
            if transposes:
                views.append(albumentations.rotate(flipped, rotation))
    batch = torch.tensor(np.array(views))
    # (N, H, W, C) -> (N, C, H, W)
    return batch.permute(0, 3, 1, 2).float()


def combine_different_prospective(images, rotations, transposes):
    """Undo create_different_prospective's augmentations and average the views.

    ``images`` holds the predictions for all views in the order produced by
    create_different_prospective. For each original image, for each angle in
    ``rotations``, there is the rotated view, followed (if ``transposes``) by the
    flipped one. Channel 0 of every prediction is rotated back by ``-angle``, and
    flipped back for flipped views. Each original image's views are then averaged.

    Args:
        images: array of shape (N, C, H, W), N = num_images * views per image.
        rotations, transposes: the same values given to create_different_prospective.

    Returns:
        float64 array of shape (num_images, C, H, W) with the averaged predictions.

    Raises:
        ValueError: if ``rotations`` is empty, or if N is not a multiple of the
            number of views per image.
    """
    views_per_image = len(rotations) * (2 if transposes else 1)
    if views_per_image == 0:
        # Without this check the loop below never advances (infinite loop).
        raise ValueError("rotations must not be empty")
    if len(images) % views_per_image != 0:
        raise ValueError(
            f"got {len(images)} predictions, not a multiple of "
            f"{views_per_image} views per image"
        )

    outputs = []
    index = 0
    while index < len(images):
        output = np.zeros(images[0].shape)
        for rotation in rotations:
            output += albumentations.rotate(images[index, 0], -rotation)
            index += 1
            if transposes:
                restored = albumentations.rotate(images[index, 0], -rotation)
                output += albumentations.hflip(restored)
                index += 1
        # Average over this image's own views, not over every prediction passed in.
        outputs.append(output / views_per_image)
    return np.array(outputs)



In [None]:
# Mean patch value above which a patch is labelled foreground (1).
foreground_threshold = 0.5


def patch_to_label(patch):
    """Return 1 if the patch's mean value exceeds ``foreground_threshold``, else 0."""
    return 1 if np.mean(patch) > foreground_threshold else 0


def mask_to_submission_strings(image_filename):
    """Yield the submission CSV lines for one predicted mask image.

    The image is split into 16x16 patches. Each patch yields ``"NNN_x_y,label"``,
    where NNN is the image number zero-padded to three digits, x / y are the
    patch's column / row pixel offsets, and label comes from patch_to_label.

    The image number is the last run of digits in the file's base name, without
    its extension. Searching the whole path would pick up digits from directory
    names, e.g. the "1" in ".../bruh_max_f1.pth/test_7.png".

    Raises:
        ValueError: if the file's base name contains no digits.
    """
    stem = os.path.splitext(os.path.basename(image_filename))[0]
    numbers = re.findall(r"\d+", stem)
    if not numbers:
        raise ValueError("no image number in file name: " + image_filename)
    img_number = int(numbers[-1])
    im = mpimg.imread(image_filename)
    patch_size = 16
    for j in range(0, im.shape[1], patch_size):
        for i in range(0, im.shape[0], patch_size):
            patch = im[i : i + patch_size, j : j + patch_size]
            label = patch_to_label(patch)
            yield ("{:03d}_{}_{},{}".format(img_number, j, i, label))


def masks_to_submission(submission_filename, *image_filenames):
    """Write a submission CSV: the ``id,prediction`` header, then every patch line of every image."""
    with open(submission_filename, "w") as out_file:
        out_file.write("id,prediction\n")
        for filename in image_filenames:
            for line in mask_to_submission_strings(filename):
                out_file.write(f"{line}\n")



In [None]:
IMAGE_SIZE = 608
ROTATIONS = [0, 90, 180, 270]
TRANSPOSE = True
ONE_HOT = False
model_name = "bruh_max_f1.pth"

# map_location loads the checkpoint directly onto DEVICE, so a model saved on a GPU
# can also be restored where DEVICE differs.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which cannot
# unpickle a whole nn.Module; pass weights_only=False there (trusted files only).
model = torch.load(
    "/kaggle/working/checkpoints/" + model_name, map_location=DEVICE
).to(DEVICE)
model.eval()

# Create the directory to store the predictions
path = prefix + "data/test_set_images"
pred_path = "/kaggle/working/prediction/" + model_name
os.makedirs(pred_path, exist_ok=True)

# For each test image: predict on every rotated/flipped view (test-time augmentation),
# map the predictions back to the original orientation, average them, and save the
# result as an 8-bit grayscale PNG.
images = os.listdir(path)
for image in tqdm.tqdm(images):
    img_path = os.path.join(path, image, image + ".png")
    with Image.open(img_path) as img:  # release the file handle after reading
        im = np.asarray(img) / 255
    ims = create_different_prospective([im], ROTATIONS, TRANSPOSE)

    with torch.no_grad():
        output = model(ims.to(DEVICE))
        if ONE_HOT:
            # Class decision over ALL channels; argmax over one sliced channel is always 0.
            predicts = output.argmax(1, keepdim=True).float().cpu()
        else:
            predicts = torch.sigmoid(output[:, 0].unsqueeze(1)).cpu()

    # Reshape to this image's own height/width instead of a hard-coded 608x608.
    predict = combine_different_prospective(
        predicts.numpy(), ROTATIONS, TRANSPOSE
    ).reshape(im.shape[:2])
    predict *= 255
    Image.fromarray(predict).convert("L").save(os.path.join(pred_path, image) + ".png")


In [None]:
# Generate the submission file from the 50 saved predictions test_1.png ... test_50.png
submission_filename = "dummy_submission.csv"
image_filenames = [f"{pred_path}/test_{i}.png" for i in range(1, 51)]
masks_to_submission(submission_filename, *image_filenames)
