In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("..")

import os
import argparse
import json

import torch
from torchvision.transforms import Normalize
import numpy as np

from tqdm import tqdm
import matplotlib.pyplot as plt

from src.models.base_unet import BaseUNet
from src.models.unet_pp import UNetPlus
from src.models.hr_spin import HRSPIN
from src.utils.visualizations import plot_predictions
from src.utils.io import load_image, save_mask
from src.scripts.mask_to_submission import masks_to_submission

In [None]:
# Pick the best available accelerator: Apple-Silicon MPS first (the original
# target), then CUDA, and finally CPU so the notebook still runs elsewhere.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# Rank every training run in LOGS_DIR by its best validation accuracy.
LOGS_DIR = "../logs"


def _read_json(path):
    """Parse the JSON file at `path`, making sure the file handle is closed."""
    with open(path, "r") as f:
        return json.load(f)


def _is_complete_run(run):
    """Return True if LOGS_DIR/run contains both metrics.json and config.json."""
    return all(
        os.path.isfile(os.path.join(LOGS_DIR, run, name))
        for name in ("metrics.json", "config.json")
    )


# Sorted so the listing (and tie-breaking in the ranking below) is reproducible.
# Entries that are not complete run directories (stray files, runs that died
# before writing their metrics) are reported instead of crashing the loop.
models, skipped = [], []
for run in sorted(os.listdir(LOGS_DIR)):
    if _is_complete_run(run):
        models.append(run)
    else:
        skipped.append(run)
if skipped:
    print(f"Skipping incomplete log entries: {skipped}")

scores = []
for run in models:
    metrics = _read_json(os.path.join(LOGS_DIR, run, "metrics.json"))
    config = _read_json(os.path.join(LOGS_DIR, run, "config.json"))
    # (run name, architecture, best validation accuracy over training)
    scores.append((run, config["model"], max(metrics["val_acc"])))

threshold = 0.925
# Plain int count (np.sum over an empty list would yield the float 0.0).
count = sum(1 for _, _, best_acc in scores if best_acc > threshold)
print(f"Number of models with score > {threshold}: {count}")
print()

# Best run first; list.sort is stable, so ties keep the alphabetical order above.
scores.sort(key=lambda x: x[-1], reverse=True)
for score in scores:
    print(score)


In [None]:
N = 3

# Ensemble the N best runs; `scores` is ordered by best val accuracy, descending.
checkpoints = [f"../logs/{run_name}" for run_name, _arch, _best_acc in scores[:N]]

In [None]:
metadata_path = "../metadata.json"

# Dataset statistics; predict_mask reads metadata["cil"]["img_mean"/"img_std"].
with open(metadata_path, "r") as f:
    metadata = json.load(f)

In [None]:
def load_model(checkpoint):
    """Rebuild the architecture recorded in a run directory and load its best weights.

    Args:
        checkpoint: path to a run directory containing ``config.json`` and
            ``best_model.pt``.

    Returns:
        ``(model, args)``, where ``args`` is the run config as an
        ``argparse.Namespace``. Weights are mapped onto the CPU; the caller is
        responsible for calling ``.eval()`` and moving the model to a device.

    Raises:
        ValueError: if ``config.json`` names an unknown model type.
    """
    with open(os.path.join(checkpoint, "config.json"), "r") as f:
        config = json.load(f)
    args = argparse.Namespace(**config)

    if args.model in ("unet", "unet++"):
        # Only the U-Net variants need `depth`; computing the channel list before
        # dispatching would raise AttributeError for any config lacking it.
        chs = [3] + [2 ** (i + 5) for i in range(args.depth)]
        model = BaseUNet(chs) if args.model == "unet" else UNetPlus(chs)
    elif args.model == "spin":
        model = HRSPIN(num_stacks=args.num_stacks)
    else:
        raise ValueError("Invalid model name")

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(
        os.path.join(checkpoint, "best_model.pt"),
        map_location=torch.device("cpu"),
    )
    model.load_state_dict(state_dict)

    return model, args

# Get test images

In [None]:
test_images = "../data/test/images/"

fnames = os.listdir(test_images)
fnames = [os.path.join(test_images, fname) for fname in fnames if fname.endswith(".png")]

len(fnames)

# Predict test images

In [None]:
# Loaded models keyed by checkpoint path. Defined in this cell, so re-running the
# cell clears it and picks up any checkpoints that changed on disk.
_model_cache = {}


def _get_model(ckpt):
    """Return ``(model, args)`` for `ckpt`, loading it in eval mode on DEVICE only once."""
    if ckpt not in _model_cache:
        model, args = load_model(ckpt)
        model.eval()
        model.to(DEVICE)
        _model_cache[ckpt] = (model, args)
    return _model_cache[ckpt]


def predict_mask(path, checkpoints):
    """Predict a mask for one image by averaging an ensemble of checkpoints.

    Args:
        path: image file path, read with ``load_image``.
        checkpoints: non-empty sequence of run directories accepted by ``load_model``.

    Returns:
        CPU tensor: the element-wise mean of every model's prediction, with the
        batch dimension removed.

    Raises:
        ValueError: if `checkpoints` is empty.
    """
    if not checkpoints:
        raise ValueError("predict_mask needs at least one checkpoint")

    transform = Normalize(mean=metadata["cil"]["img_mean"], std=metadata["cil"]["img_std"])

    # Move the last (channel) axis to the front and add a batch axis before
    # normalizing; assumes load_image returns a channels-last array.
    image = torch.tensor(load_image(path))
    image = image.permute(2, 0, 1).unsqueeze(0).float()
    image = transform(image).to(DEVICE)

    preds = []
    with torch.no_grad():
        for ckpt in checkpoints:
            model, args = _get_model(ckpt)
            if args.model == "spin":
                # HRSPIN returns a sequence of stacked outputs plus an extra value;
                # the last stack's output is used.
                prediction, _ = model(image)
                prediction = prediction[-1].squeeze(0)
            else:
                prediction = model(image).squeeze(0)
            preds.append(prediction.cpu())

    return torch.stack(preds).mean(0)


predictions = torch.stack([predict_mask(fname, checkpoints) for fname in tqdm(fnames)])
predictions.shape

In [None]:
N = 5

# Show the first N test images next to their ensemble predictions. The masks are
# all-zero placeholders shaped like the predictions (presumably because the test
# split has no labels).
images = torch.stack([torch.tensor(load_image(fname)) for fname in fnames[:N]])
masks = torch.zeros_like(predictions[:N])

plot_predictions(
    images=images,
    masks=masks,
    predictions=predictions[:N],
)

# Save masks

In [None]:
pred_path = "../data/preds/"
# Cut-off used to binarize the averaged predictions (assumed to lie in [0, 1]).
MASK_THRESHOLD = 0.5

# Create the output directory on a fresh checkout; a no-op if it already exists.
os.makedirs(pred_path, exist_ok=True)

# NOTE: relies on `fnames` still holding the test-image paths in the same order
# used to build `predictions`.
for fname, prediction in tqdm(zip(fnames, predictions), total=len(fnames), ncols=80):
    out_fname = os.path.join(pred_path, os.path.basename(fname))
    binary_mask = (prediction.numpy() > MASK_THRESHOLD).astype(np.uint8) * 255
    # Replicate the 0/255 mask into three identical trailing channels.
    rgb_mask = np.stack([binary_mask, binary_mask, binary_mask], axis=-1)
    save_mask(rgb_mask, out_fname)

# Create submission file

In [None]:
# NOTE: this rebinds `fnames` from the test-image paths to the saved mask paths;
# re-running the prediction/saving cells after this one would use the masks.
fnames = [
    os.path.join(pred_path, entry)
    for entry in os.listdir(pred_path)
    if entry.endswith(".png")
]

In [None]:
submission_csv = "../data/submission.csv"
mask_paths = sorted(fnames)

# NOTE(review): the second positional argument is passed as an empty string;
# its meaning is defined by masks_to_submission — confirm there.
masks_to_submission(submission_csv, "", *mask_paths)

# Expected grade

In [19]:
# Leaderboard scores: ours, the best submission, and the baseline.
our_score = 0.91513
max_score = 0.94186
baseline_score = 0.86380

# Linear interpolation: baseline -> 4, best score -> 6.
gain = our_score - baseline_score
span = max_score - baseline_score
4 + 2 * gain / span

5.315142198308992