In [1]:
import os

import albumentations
import numpy as np
import torch
import torchvision
import tqdm
from PIL import Image
from albumentations.pytorch import ToTensorV2

from data.mask_to_submission import masks_to_submission
from data_processing import get_loader
from models import LinkNet, UNet

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("DEVICE=", DEVICE)

# All dataset locations are resolved relative to this root directory.
prefix = "./"
Train_image_path = os.path.join(prefix, "data/training/images")
Train_mask_path = os.path.join(prefix, "data/training/groundtruth")
Validation_image_path = os.path.join(prefix, "data/validating/images")
Validation_mask_path = os.path.join(prefix, "data/validating/groundtruth")

DEVICE= cuda


# Training

In [5]:
def f1(pred, label):
    """Compute the F1 score, precision and recall of a binary prediction.

    Args:
        pred: tensor of predictions, expected to be 0/1 valued; any shape
            (it is flattened here).
        label: tensor of ground-truth values, expected to be 0/1 valued, with
            the same number of elements as ``pred``.

    Returns:
        Tuple ``(f1, precision, recall)`` of 0-d float32 tensors. A small
        epsilon in every denominator keeps the result finite (0) when there are
        no positive predictions or labels.
    """
    # reshape (not view) so non-contiguous inputs such as slices or transposes work.
    pred = pred.reshape(-1)
    label = label.reshape(-1)
    tp = (label * pred).sum().to(torch.float32)
    fp = ((1 - label) * pred).sum().to(torch.float32)
    fn = (label * (1 - pred)).sum().to(torch.float32)
    eps = 1e-7
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return 2 * precision * recall / (precision + recall + eps), precision, recall


def compute_metrics(loader, model, device):
    """Evaluate ``model`` on ``loader`` and return accuracy / F1 / precision / recall.

    The model's last output channel is used as the prediction, thresholded at
    probability 0.5 (sigmoid of the logit). The model is put in eval mode for
    the evaluation and its previous train/eval mode is restored afterwards.

    Args:
        loader: iterable of ``(inputs, targets)`` batches that supports ``len()``.
        model: module whose output is indexed as ``output[:, -1, :, :]``.
        device: device every batch is moved to before the forward pass.

    Returns:
        dict with keys ``"acc"`` (pixel accuracy over the whole set),
        ``"f1 score"``, ``"precision"`` and ``"recall"``. The last three are the
        mean of the per-batch values (not a single global score). All values
        are percentages.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    num_correct = 0
    num_pixels = 0
    f1_total = 0.0
    precision_total = 0.0
    recall_total = 0.0
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                # Only the last output channel is scored as the prediction.
                output = model(x)[:, -1, :, :].unsqueeze(1)
                pred: torch.Tensor = (torch.sigmoid(output) >= 0.5).float()
                num_correct += torch.sum(pred == y).item()
                num_pixels += torch.numel(pred)
                batch_f1, batch_precision, batch_recall = f1(pred, y)
                f1_total += batch_f1.item()
                precision_total += batch_precision.item()
                recall_total += batch_recall.item()
    finally:
        # Restore whatever mode the caller had instead of forcing train mode.
        model.train(was_training)

    if num_pixels == 0 or len(loader) == 0:
        raise ValueError("compute_metrics needs a loader that yields at least one batch")
    num_batches = len(loader)
    return {
        "acc": num_correct / num_pixels * 100,
        "f1 score": f1_total / num_batches * 100,
        "precision": precision_total / num_batches * 100,
        "recall": recall_total / num_batches * 100,
    }



In [6]:
def epoch(model, loader, optimizer, criterion, scaler):
    """Run one mixed-precision training pass over ``loader``.

    Each output channel of the model is scored against the same target with
    ``criterion``, and the per-channel losses are summed before backprop. The
    optimizer step goes through ``scaler``.

    Returns:
        The sum over all batches of the channel-summed loss (0 for an empty loader).
    """
    running_total = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        with torch.cuda.amp.autocast():
            logits = model(inputs)
            batch_loss = sum(
                criterion(logits[:, channel, :, :].unsqueeze(1), targets)
                for channel in range(logits.shape[1])
            )
            running_total += batch_loss.item()

        optimizer.zero_grad()
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()
    return running_total


def train(
    model,
    model_name,
    train_loader,
    validation_loader,
    lr: float = 1.0e-4,
    epochs: int = 10,
    t_max: int = 50,
    save_threshold: float = 80.0,
):
    """Train ``model``, log per-epoch metrics as CSV and save checkpoints.

    Files written (the directories are created if missing):
        logs/<model_name>: CSV with header
            ``epoch,loss,f1,iou,accuracy,precision,recall`` and one row per
            epoch. ``iou`` is not computed and is always written as 0.
        checkpoints/<model_name>.pth.maxf1: the whole model, re-saved every time
            the validation F1 reaches a new maximum above ``save_threshold``.
        checkpoints/<model_name>.pth: the whole model after the last epoch.

    Args:
        model: network to train. It is moved to ``DEVICE`` here.
        model_name: base name used for the log and checkpoint files.
        train_loader: batches used for optimisation.
        validation_loader: batches used by ``compute_metrics`` after every epoch.
        lr: Adam learning rate.
        epochs: number of passes over ``train_loader``.
        t_max: ``T_max`` (in epochs) of the cosine-annealing LR schedule.
        save_threshold: minimum validation F1 (percent) before the best model
            is checkpointed.
    """
    os.makedirs("logs", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)
    log_file_name = os.path.join("logs", model_name)
    model_file_name = os.path.join("checkpoints", model_name + ".pth")

    # epoch() moves every batch to DEVICE, so the parameters must live there too.
    model = model.to(DEVICE)

    with open(log_file_name, "w") as f:
        f.write("epoch,loss,f1,iou,accuracy,precision,recall\n")

    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Stepped once per epoch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)
    scaler = torch.cuda.amp.GradScaler()

    loop = tqdm.tqdm(range(epochs))
    max_f1 = 0
    for e in loop:
        loss = epoch(model, train_loader, optimizer, criterion, scaler)
        scheduler.step()
        metrics = compute_metrics(validation_loader, model, DEVICE)
        if metrics["f1 score"] > max_f1:
            max_f1 = metrics["f1 score"]
            if max_f1 > save_threshold:
                torch.save(model, model_file_name + ".maxf1")

        with open(log_file_name, "a") as f:
            f.write(
                "{},{},{},{},{},{},{}\n".format(
                    e,
                    loss,
                    metrics["f1 score"],
                    0,  # IoU placeholder: not computed
                    metrics["acc"],
                    metrics["precision"],
                    metrics["recall"],
                )
            )

        loop.set_postfix(loss=loss, f1_score=metrics["f1 score"], max_f1=max_f1)

    # Save the final model; the log file has already been written incrementally.
    torch.save(model, model_file_name)



In [None]:
# Training-time augmentation (applied to each sample by the loader).
train_transform = albumentations.Compose(
    [
        # Random geometric changes: flips, transposition and rotation.
        albumentations.Flip(),
        albumentations.Transpose(),
        albumentations.Rotate(),
        # Random rectangular occlusions.
        albumentations.CoarseDropout(max_holes=8, max_height=8, max_width=8),
        # Half of the time, exactly one of these non-rigid distortions.
        albumentations.OneOf(
            [
                albumentations.OpticalDistortion(),
                albumentations.GridDistortion(),
                albumentations.ElasticTransform(),
            ],
            p=0.5,
        ),
        ToTensorV2(),
    ]
)

# NOTE(review): the validation transform is random too (Flip/Transpose), so the
# validation metrics, and therefore best-checkpoint selection, vary between runs.
val_transform = albumentations.Compose(
    [albumentations.Flip(), albumentations.Transpose(), ToTensorV2()]
)

train_loader = get_loader(
    data_path=Train_image_path,
    mask_path=Train_mask_path,
    transform=train_transform,
    batch_size=4,
)
val_loader = get_loader(
    data_path=Validation_image_path,
    mask_path=Validation_mask_path,
    transform=val_transform,
    batch_size=4,
)

# Actual training

In [None]:
# U-Net baseline.
train(
    model=UNet(),
    model_name="unet",
    train_loader=train_loader,
    validation_loader=val_loader,
    epochs=1000,
)

In [None]:
# LinkNet on an ImageNet-pretrained ResNet-18 encoder.
train(
    model=LinkNet(
        channels=(64, 128, 256, 512),
        encoder=torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        ),
    ),
    model_name="linknet18",
    train_loader=train_loader,
    validation_loader=val_loader,
    epochs=1000,
)

In [None]:
# LinkNet on an ImageNet-pretrained ResNet-34 encoder.
train(
    model=LinkNet(
        channels=(64, 128, 256, 512),
        encoder=torchvision.models.resnet34(
            weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1
        ),
    ),
    model_name="linknet34",
    train_loader=train_loader,
    validation_loader=val_loader,
    epochs=1000,
)

In [None]:
# LinkNet on an ImageNet-pretrained ResNet-50 encoder (wider bottleneck stages).
train(
    model=LinkNet(
        channels=(256, 512, 1024, 2048),
        encoder=torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2
        ),
    ),
    model_name="linknet50",
    train_loader=train_loader,
    validation_loader=val_loader,
    epochs=1000,
)

In [None]:
# LinkNet on an ImageNet-pretrained ResNet-101 encoder.
train(
    model=LinkNet(
        channels=(256, 512, 1024, 2048),
        encoder=torchvision.models.resnet101(
            weights=torchvision.models.ResNet101_Weights.IMAGENET1K_V2
        ),
    ),
    model_name="linknet101",
    train_loader=train_loader,
    validation_loader=val_loader,
    epochs=1000,
)

In [None]:
# LinkNet on an ImageNet-pretrained ResNet-152 encoder.
train(
    model=LinkNet(
        channels=(256, 512, 1024, 2048),
        encoder=torchvision.models.resnet152(
            weights=torchvision.models.ResNet152_Weights.IMAGENET1K_V2
        ),
    ),
    model_name="linknet152",
    train_loader=train_loader,
    validation_loader=val_loader,
    epochs=1000,
)

# Create submissions

In [7]:
def create_postprocessing_images(images, rotations, transposes):
    """Build the test-time-augmentation views (different perspectives) of each image.

    For every image and every angle in ``rotations``, this emits the image
    rotated by that angle. When ``transposes`` is true, it also emits the
    horizontally flipped image rotated by the same angle (despite the name, the
    extra view is an hflip, not a transpose). Views are ordered image by image,
    then angle by angle, with the flipped view right after its unflipped one.

    Returns:
        float32 tensor with the last axis moved to position 1, i.e.
        ``(num_views, C, H, W)`` for ``(H, W, C)`` input images.
    """

    def views_of(image):
        for angle in rotations:
            yield albumentations.rotate(image, angle)
            if transposes:
                yield albumentations.rotate(albumentations.hflip(image), angle)

    batch = np.array([view for image in images for view in views_of(image)])
    # (N, H, W, C) -> (N, C, W, H) -> (N, C, H, W)
    channels_first = torch.tensor(batch).transpose(1, -1).transpose(2, -1)
    return channels_first.float()


def combine_postprocessing_images(images, rotations, transposes):
    """Undo the test-time augmentation and average the predictions of each image.

    ``images`` holds model outputs for consecutive groups of views, one group
    per original image. Within a group, each angle in ``rotations`` contributes
    the rotated view, followed (when ``transposes`` is true) by the flipped and
    rotated view. This is the order ``create_postprocessing_images`` produces.
    Channel 0 of every view is rotated back by ``-angle`` (and un-flipped), and
    the views of a group are averaged.

    Args:
        images: array indexed as ``images[view, channel, H, W]``.
        rotations: rotation angles (degrees) that were used to create the views.
        transposes: whether every rotation also has a horizontally flipped view.

    Returns:
        Array with one entry per original image, each of shape
        ``images[0].shape``. The averaged channel-0 map is broadcast into every
        channel slot.

    Raises:
        ValueError: if ``images`` is non-empty and its length is not a positive
            multiple of the number of views per image.
    """
    views_per_image = len(rotations) * (2 if transposes else 1)
    if len(images) > 0 and (views_per_image == 0 or len(images) % views_per_image != 0):
        raise ValueError(
            "expected a positive multiple of {} views, got {}".format(
                views_per_image, len(images)
            )
        )

    outputs = []
    # `or 1` only matters when images is empty (range would reject a 0 step).
    for start in range(0, len(images), views_per_image or 1):
        output = np.zeros(images[start].shape)
        index = start
        for rotation in rotations:
            output += albumentations.rotate(images[index, 0], -rotation)
            index += 1
            if transposes:
                unrotated = albumentations.rotate(images[index, 0], -rotation)
                output += albumentations.hflip(unrotated)
                index += 1
        # Average over this image's views only, not over every view in the batch.
        outputs.append(output / views_per_image)
    return np.array(outputs)


def create_submission(model_name: str):
    """Predict the test masks with the best checkpoint and write a submission CSV.

    The function runs these steps:
        1. Loads ``checkpoints/<model_name>.pth.maxf1``, the whole pickled model
           that ``train`` saves at the best validation F1.
        2. Predicts every ``data/test_set_images/<name>/<name>.png`` with
           test-time augmentation (4 rotations, each with and without an hflip),
           averages the views and thresholds the result at 0.5.
        3. Writes each 0/255 mask to ``predictions/<model_name>/<name>.png``.
        4. Passes ``test_1.png`` ... ``test_50.png`` from that directory to
           ``masks_to_submission``, producing ``submission_<model_name>.csv``.

    Args:
        model_name: base name of the checkpoint, without the ``.pth`` extension.
    """
    rotations = [0, 90, 180, 270]
    model_file_name = os.path.join("checkpoints", model_name + ".pth.maxf1")
    # map_location lets a checkpoint saved during a GPU run load on a CPU-only box.
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    model = torch.load(model_file_name, map_location=DEVICE).to(DEVICE)
    model.eval()

    path = "data/test_set_images"
    pred_path = os.path.join("predictions", model_name)
    os.makedirs(pred_path, exist_ok=True)

    for image in tqdm.tqdm(os.listdir(path)):
        img_path = os.path.join(path, image, image + ".png")
        with Image.open(img_path) as img:
            im = np.asarray(img) / 255
        ims = create_postprocessing_images([im], rotations=rotations, transposes=True)

        with torch.no_grad():
            predicts = torch.sigmoid(model(ims.to(DEVICE))).cpu()

        # One input image -> one averaged map, reshaped to that image's H x W.
        predict = combine_postprocessing_images(
            predicts.numpy(), rotations=rotations, transposes=True
        ).reshape(im.shape[:2])
        mask = (predict >= 0.5).astype(np.uint8) * 255
        Image.fromarray(mask).save(os.path.join(pred_path, image) + ".png")

    submission_filename = "submission_{}.csv".format(model_name)
    image_filenames = [
        os.path.join(pred_path, "test_" + str(i) + ".png") for i in range(1, 51)
    ]
    masks_to_submission(submission_filename, *image_filenames)

