In [None]:
import sys
sys.path.append("..")

import os
import argparse
import json

import torch
from torchvision.transforms import Normalize
import numpy as np

from tqdm import tqdm
import matplotlib.pyplot as plt

from src.models.base_unet import BaseUNet
from src.models.unet_pp import UNetPlus
from src.utils.visualizations import plot_predictions
from src.utils.io import load_image, save_mask
from src.scripts.mask_to_submission import masks_to_submission

In [None]:
# Training run to evaluate: its directory holds config.json and best_model.pt (both read below).
RUN_NAME = "2023-07-11_13-32-43"
checkpoint = f"../logs/{RUN_NAME}"

In [None]:
# Rebuild the argparse namespace the checkpoint was trained with, from its saved config.json.
# Bound to `config` rather than `vars` so the `vars` builtin is not shadowed for the
# rest of the kernel session.
args_path = os.path.join(checkpoint, "config.json")
with open(args_path, "r") as f:
    config = json.load(f)

args = argparse.Namespace(**config)

In [None]:
# Dataset metadata file named in the training config. It is joined onto ".." because the
# notebook sits one level below the project root (cf. sys.path.append("..")); presumably
# args.metadata is root-relative — TODO confirm.
# The path and the parsed dict use separate names, and the file is closed via `with`.
metadata_path = os.path.join("..", args.metadata)
with open(metadata_path, "r") as f:
    metadata = json.load(f)

In [None]:
# Instantiate the architecture named in the training config; unknown names are rejected.
model_classes = {"unet": BaseUNet, "unet++": UNetPlus}
if args.model not in model_classes:
    raise ValueError("Invalid model name")
model = model_classes[args.model]()

In [None]:
# Load the best weights from the run; map_location forces CPU so no GPU is required.
weights_file = os.path.join(checkpoint, "best_model.pt")
state_dict = torch.load(weights_file, map_location=torch.device("cpu"))
model.load_state_dict(state_dict)

# Get test images

In [None]:
test_images = "../data/test/images/"

# os.listdir returns entries in arbitrary order. Sorting makes the prediction order,
# and hence any slice of the predictions, reproducible across runs and machines.
fnames = sorted(
    os.path.join(test_images, fname)
    for fname in os.listdir(test_images)
    if fname.endswith(".png")
)

len(fnames)

# Predict test images

In [None]:
model.eval()  # inference mode (affects layers such as dropout / batch-norm, if present)
pred_path = "../data/preds/"
os.makedirs(pred_path, exist_ok=True)

# Normalize inputs with the image statistics stored under the "cil" key of the metadata.
mean = metadata["cil"]["img_mean"]
std = metadata["cil"]["img_std"]
normalize = Normalize(mean=mean, std=std)

predictions = []
with torch.no_grad():
    for fname in tqdm(fnames, desc="Predicting", total=len(fnames), ncols=80):
        # The permute below implies load_image returns channels-last (H, W, C);
        # convert to a (1, C, H, W) float batch for the model.
        image = torch.tensor(load_image(fname))
        image = image.permute(2, 0, 1).unsqueeze(0).float()
        image = normalize(image)

        prediction = model(image).squeeze(0)
        predictions.append(prediction)

        # Binarize at 0.5 (assumes the model outputs probabilities, not logits — TODO confirm),
        # then save as a 3-channel 0/255 image under the same file name in pred_path.
        # No .detach() needed: tensors created under no_grad() do not track gradients.
        out_fname = os.path.join(pred_path, os.path.basename(fname))
        pred = (prediction.numpy() > 0.5).astype(np.uint8) * 255
        pred = np.stack([pred, pred, pred], axis=-1)
        save_mask(pred, out_fname)

predictions = torch.stack(predictions)

In [None]:
N = 5  # number of test images to visualize

# The original cell referenced `images` and `masks` without defining them (their definitions
# were commented out), which raised NameError on a fresh kernel. They are rebuilt here the
# same way those commented lines did. The test set has no ground truth, so all-zero masks
# shaped like the predictions are used as placeholders.
# NOTE(review): images are left channels-last as load_image returns them — confirm this is
# the layout plot_predictions expects.
images = torch.stack([torch.tensor(load_image(fname)) for fname in fnames[:N]])
masks = torch.zeros_like(predictions[:N])

plot_predictions(
    images=images,
    masks=masks,
    predictions=predictions[:N],
    weights=predictions[:N] > .5,
)

# Create submission file

In [None]:
# Collect the PNG masks that the prediction loop wrote into pred_path.
fnames = [
    os.path.join(pred_path, mask_name)
    for mask_name in os.listdir(pred_path)
    if mask_name.endswith(".png")
]

In [None]:
# Encode the predicted masks, in sorted file-name order, into the submission CSV.
# NOTE(review): the second argument is passed as "" — its meaning is defined by
# masks_to_submission; confirm it is intended.
submission_csv = "../data/submission.csv"
masks_to_submission(submission_csv, "", *sorted(fnames))