In [4]:
import gc
import os
import shutil
import sys
from glob import glob

import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import tifffile
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# config

In [2]:
class cfg:
    """Inference settings read by the cells below (used as a plain namespace)."""

    # ============== model cfg =============
    # Number of consecutive slices InferenceDataset stacks into one sample.
    in_chans = 6  # 65

    # ============== tiling cfg =============
    image_size = 256  # side length of the square chips cut from each padded slice
    stride = image_size // 1  # step between chip origins in cutout_chip
    # Border width cropped from every chip prediction. The misspelt name is
    # kept because other cells reference it.
    drop_egde_pixel = 32

    # ============== fold =============
    # `batch` used to be assigned twice (64, then 128). Only the last assignment
    # ever took effect, so the dead one was removed.
    batch = 128  # maximum number of chips per forward pass in CustomModel
    model_path = ["/kaggle/working/notebook/experiment/baseline/baseline/baseline_best_fold0.pth"]


# True when the `kaggle_web_client` module has been loaded into this process
# (presumably only inside the hosted Kaggle notebook runtime — confirm).
is_kaggle_notebook = "kaggle_web_client" in sys.modules

# Input directory: the competition test set on Kaggle, otherwise the kidney_2
# training volume.
dir_test = (
    "/kaggle/input/blood-vessel-segmentation/test"
    if is_kaggle_notebook
    else "/kaggle/input/blood-vessel-segmentation/train/kidney_2"
)

# Working directories for the stacked raw volumes and their clipped versions.
dir_raw, dir_clipped = (
    "/kaggle/working/dataset_test/stack_raw",
    "/kaggle/working/dataset_test/stack_clipped",
)

In [5]:
def stack_tifs(dir_dataset, dir_stack):
    """Stack per-slice TIF files into one ``.npy`` volume per dataset and type.

    For every entry of ``dir_dataset`` and for both the ``images`` and
    ``labels`` sub-folders, the ``*.tif`` files are read in sorted path order,
    stacked along a new leading axis and saved as
    ``{dir_stack}/{data_name}_{data_type}.npy``.

    Sub-folders without any TIF file are skipped, as are volumes whose output
    file already exists.

    Args:
        dir_dataset: Directory whose entries each hold ``images``/``labels`` folders.
        dir_stack: Output directory; created if missing.
    """
    os.makedirs(dir_stack, exist_ok=True)
    for data_name in os.listdir(dir_dataset):
        for data_type in ("images", "labels"):
            out_file = f"{dir_stack}/{data_name}_{data_type}.npy"
            slice_files = glob(f"{dir_dataset}/{data_name}/{data_type}/*.tif")

            # Nothing to do: no slices found, or already stacked by an earlier run.
            if not slice_files or os.path.exists(out_file):
                continue

            volume = np.stack([tifffile.imread(path) for path in sorted(slice_files)])
            np.save(out_file, volume)


def save_clipped_npy(dir_raw, dir_clipped, percentile):
    """Clip volumes by intensity percentile and save them for the model.

    Every ``*.npy`` file in ``dir_raw`` is converted and written under the same
    file name in ``dir_clipped``:

    * files whose name ends in ``_labels`` are cast to ``bool``;
    * files whose name ends in ``_images`` are clipped to the
      ``[percentile, 100 - percentile]`` intensity range and rescaled to
      ``float32`` values in ``[0, 1]``. The percentiles are measured on the
      middle part of the stack only (30 % to 70 % along the first axis), so
      slices near either end do not influence them.

    Files are skipped when the output already exists, when the path contains
    ``"voi"``, or when the name has neither suffix.

    Args:
        dir_raw: Directory holding the stacked ``.npy`` volumes.
        dir_clipped: Output directory; created if missing.
        percentile: Lower percentile in ``[0, 100]``; the upper one is its mirror.
    """
    os.makedirs(dir_clipped, exist_ok=True)
    for npy_path in glob(f"{dir_raw}/*.npy"):
        data_name = os.path.basename(npy_path).split(".")[0]
        data_type = data_name.split("_")[-1]
        save_path = f"{dir_clipped}/{data_name}.npy"

        if os.path.exists(save_path) or "voi" in npy_path:
            continue

        if data_type == "labels":
            npy = np.load(npy_path).astype(bool)

        elif data_type == "images":
            npy = np.load(npy_path)
            stack_len = npy.shape[0]

            # Slice bounds must be integers; float bounds raise TypeError.
            start = int(stack_len * 0.3)
            stop = int(stack_len * 0.7)

            # Exclude slices near both ends of the stack from the percentiles.
            core = npy[start:stop]
            if core.size == 0:
                # Very short stacks leave no middle section; use the whole stack.
                core = npy

            p_low = int(np.percentile(core, percentile))
            p_high = int(np.percentile(core, 100 - percentile))

            npy = np.clip(npy, p_low, p_high).astype("float32")
            scale = float(p_high - p_low)
            if scale == 0.0:
                # Constant intensity range: avoid 0/0 so the result is all zeros.
                scale = 1.0
            npy = (npy - p_low) / scale

        else:
            # Unknown suffix: skip, so a stale or undefined `npy` is never saved.
            continue

        np.save(save_path, npy)


# Build the stacked volumes, derive their clipped versions, then delete the
# intermediate raw stacks.
stack_tifs(dir_dataset=dir_test, dir_stack=dir_raw)
save_clipped_npy(dir_raw=dir_raw, dir_clipped=dir_clipped, percentile=0.05)
shutil.rmtree(dir_raw)

In [4]:
class CustomModel(nn.Module):
    """Segmentation model wrapper that predicts with 4-way rotation averaging.

    The wrapped network is created with ``smp.create_model``. ``forward`` runs
    it on the input rotated by 0/90/180/270 degrees, rotates each prediction
    back and averages the four sigmoid outputs.
    """

    def __init__(self, model_arch, backbone, in_chans, target_size, weight):
        """Build the underlying network.

        Args:
            model_arch: Architecture name passed to ``smp.create_model``.
            backbone: Encoder name.
            in_chans: Number of input channels.
            target_size: Number of output classes.
            weight: Encoder weights identifier, or ``None``.
        """
        super().__init__()

        self.model = smp.create_model(
            model_arch,
            encoder_name=backbone,
            encoder_weights=weight,
            in_channels=in_chans,
            classes=target_size,
            activation=None,
        )
        # Maximum number of samples sent through the network in one call.
        self.batch = cfg.batch
        self.in_chans = in_chans

    def forward_(self, image):
        """Run the network and return the first output channel, shape (N, H, W)."""
        output = self.model(image)
        return output[:, 0]

    def forward(self, image):
        """Return rotation-averaged probabilities.

        Args:
            image: Tensor of shape (batch, c, h, w). The rotated copies are
                concatenated along the batch axis, so h must equal w.

        Returns:
            Tensor of shape (batch, h, w) holding the mean sigmoid output.
        """
        image = image.to(torch.float32)
        shape = image.shape

        rotated = torch.cat([torch.rot90(image, k=k, dims=(-2, -1)) for k in range(4)], dim=0)

        with autocast():
            with torch.no_grad():
                # torch.split yields ceil(n / batch) chunks. The former
                # `range(n // batch + 1)` loop pushed an extra empty chunk
                # through the network whenever n was a multiple of batch.
                logits = [self.forward_(chunk) for chunk in torch.split(rotated, self.batch, dim=0)]
                logits = torch.cat(logits, dim=0)

        probs = logits.sigmoid()
        probs = probs.reshape(4, shape[0], *shape[2:])
        # Undo each rotation so the four predictions are aligned before averaging.
        aligned = [torch.rot90(probs[k], k=-k, dims=(-2, -1)) for k in range(4)]
        return torch.stack(aligned, dim=0).mean(0)


def load_model(model_path):
    """Load a checkpoint and return the model in eval mode on the chosen device.

    The checkpoint is expected to be a dict with the keys ``model_arch``,
    ``backbone``, ``in_chans``, ``target_size`` and ``model`` (the state dict).

    Args:
        model_path: Path to the ``.pth`` checkpoint file.

    Returns:
        A ``CustomModel`` on CUDA when available, otherwise on CPU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # map_location keeps the CPU fallback usable for checkpoints saved from a GPU.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    pth = torch.load(model_path, map_location=device)

    print("model_name", pth["model_arch"])
    print("backbone", pth["backbone"])
    model = CustomModel(pth["model_arch"], pth["backbone"], pth["in_chans"], pth["target_size"], weight=None)
    model.load_state_dict(pth["model"])

    model.to(device)
    model.eval()
    return model


class InferenceDataset(Dataset):
    """Serve windows of ``in_chan`` consecutive slices, one per slice of the volume.

    The volume is zero-padded with ``in_chan // 2`` slices at both ends of the
    first axis. Item ``z_`` is ``padded[z_ : z_ + in_chan]`` together with ``z_``.
    """

    def __init__(self, stack, in_chan):
        """
        Args:
            stack: Tensor whose first axis is the slice axis.
            in_chan: Number of consecutive slices per window.
        """
        self.in_chan = in_chan
        # Remember the unpadded depth so the length is right for any in_chan.
        self.num_slices = stack.shape[0]

        pad = torch.zeros(self.in_chan // 2, *stack.shape[1:], dtype=stack.dtype, device=stack.device)
        self.stack = torch.cat((pad, stack, pad), dim=0)

    def __len__(self):
        # One window per original slice. The former `padded_len - in_chan`
        # gave the same value for even in_chan but dropped the last slice
        # when in_chan was odd.
        return self.num_slices

    def __getitem__(self, z_):
        stack = self.stack[z_ : z_ + self.in_chan]
        return stack, z_

In [5]:
def add_pad(stack: torch.Tensor, pad: int):
    """Surround the two trailing axes of ``stack`` with a constant border.

    The fill value is the mean of ``stack`` truncated to an integer by
    ``int()``.

    Args:
        stack: Tensor of shape (C, H, W).
        pad: Border width added on every side of H and W.

    Returns:
        Tensor of shape (C, H + 2 * pad, W + 2 * pad).
    """
    fill = int(stack.to(torch.float32).mean())

    def border(shape):
        # Constant block matching the dtype and device of the input.
        return torch.ones(shape, dtype=stack.dtype, device=stack.device) * fill

    channels, height, width = stack.shape[0], stack.shape[1], stack.shape[2]

    # Rows above and below first, then columns on the left and right.
    rows = border([channels, pad, width])
    stack = torch.cat([rows, stack, rows], dim=1)
    cols = border([channels, height + 2 * pad, pad])
    stack = torch.cat([cols, stack, cols], dim=2)
    return stack


def shift_axis(tensor, axis):
    """Cyclically rotate the axes of a 3-D tensor so that ``axis`` comes first.

    Calling it again with ``-axis`` restores the original axis order.
    """
    # Shift the axis order: axis, axis + 1, axis + 2 (all modulo 3).
    order = [(axis + offset) % 3 for offset in range(3)]
    return tensor.permute(*order)


def remove_pad(pred: torch.Tensor, pad: int):
    """Crop ``pad`` elements from every side of the last two axes of ``pred``.

    Args:
        pred: Tensor with at least two axes.
        pad: Border width to remove; ``0`` returns ``pred`` unchanged.

    Returns:
        The cropped view of ``pred``.
    """
    if pad == 0:
        # `pad:-pad` would be `0:0` here and return an empty tensor.
        return pred
    return pred[..., pad:-pad, pad:-pad]


def cutout_chip(img, stack_shape, stride, img_size, edge):
    """Cut ``img`` into square chips on a regular grid.

    Chip origins run from 0 up to and including ``stack_shape[-2]`` and
    ``stack_shape[-1]`` in steps of ``stride``. The last axis is the outer
    loop and the second-to-last axis the inner one.

    Args:
        img: Array or tensor sliced on its last two axes.
        stack_shape: Shape whose last two entries bound the chip origins.
        stride: Step between neighbouring chip origins.
        img_size: Side length of each chip.
        edge: Margin by which the returned index box is shrunk on every side.

    Returns:
        ``(chip, xy_indexs)``: the list of chips, and for each chip
        ``[x1 + edge, x2 - edge, y1 + edge, y2 - edge]`` in ``img`` coordinates.
    """
    row_starts = np.arange(0, stack_shape[-2] + 1, stride)
    col_starts = np.arange(0, stack_shape[-1] + 1, stride)

    # Both output lists are built from the same origin list so they stay aligned.
    origins = [(x1, y1) for y1 in col_starts for x1 in row_starts]

    chip = [img[..., x1 : x1 + img_size, y1 : y1 + img_size] for x1, y1 in origins]
    xy_indexs = [
        [x1 + edge, x1 + img_size - edge, y1 + edge, y1 + img_size - edge]
        for x1, y1 in origins
    ]
    return chip, xy_indexs


def infer_each_z(model, img, stack_shape):
    """Predict one slice by tiling it into chips and blending the chip outputs.

    Args:
        model: Module whose ``forward`` maps (N, C, h, w) chips to (N, h, w)
            probabilities.
        img: Tensor of shape (1, C, H, W) as yielded by the DataLoader.
        stack_shape: Shape of the volume; its last two entries are (H, W).

    Returns:
        ``uint8`` CPU tensor of shape (H, W): probabilities scaled to 0..255.
    """
    # Run on the device the model lives on, not a hard-coded "cuda:0", so
    # the CPU fallback in load_model keeps working.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else img.device

    img = img.to(device)
    img = add_pad(img[0], cfg.image_size // 2)[None]

    chip, xy_indexs = cutout_chip(img, stack_shape, cfg.stride, cfg.image_size, cfg.drop_egde_pixel)

    preds = model.forward(torch.cat(chip)).to(device)
    # Discard the border of every chip prediction.
    preds = remove_pad(preds, cfg.drop_egde_pixel)

    pred = torch.zeros_like(img[:, 0], dtype=torch.float32, device=img.device)
    count = torch.zeros_like(img[:, 0], dtype=torch.float32, device=img.device)
    for i, (x1, x2, y1, y2) in enumerate(xy_indexs):
        pred[..., x1:x2, y1:y2] += preds[i]
        count[..., x1:x2, y1:y2] += 1

    # Pixels not covered by any cropped chip have count == 0. Clamping keeps
    # them at 0 and avoids casting NaN from 0/0 to uint8.
    pred /= count.clamp(min=1)
    pred = remove_pad(pred, cfg.image_size // 2)

    pred = (pred[0] * 255).to(torch.uint8).cpu()
    return pred


def get_output(model, stack_path):
    """Predict a whole volume along each of its three axes and save the results.

    For ``axis`` in 0, 1, 2 the volume is loaded from ``stack_path``, its axes
    are rotated so that ``axis`` is the slice axis, every slice is predicted
    with ``infer_each_z``, and the ``uint8`` predictions are rotated with
    ``shift_axis(preds, -axis)`` before being saved to
    ``/kaggle/working/output/{kidney}_{axis}.npy``.

    Args:
        model: Model handed to ``infer_each_z``.
        stack_path: Path of the ``.npy`` volume; its base name up to the first
            dot is used in the output file names.
    """
    os.makedirs("/kaggle/working/output", exist_ok=True)
    kidney = stack_path.split("/")[-1].split(".")[0]

    for axis in range(3):
        # The volume is reloaded for each axis and released at the end of the
        # iteration.
        stack = shift_axis(torch.tensor(np.load(stack_path)), axis)
        preds = torch.zeros_like(stack, dtype=torch.uint8)

        dataset = InferenceDataset(stack, cfg.in_chans)
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)
        for img, z_ in tqdm(dataloader):  # img=(1,C,H,W)
            preds[z_] = infer_each_z(model, img, stack.shape)

        preds = shift_axis(preds, -axis)
        np.save(f"/kaggle/working/output/{kidney}_{axis}.npy", preds)

        del stack, preds, dataset, dataloader
        gc.collect()

In [9]:
# Load the first configured checkpoint and predict the kidney_2 volume along
# all three axes.
model = load_model(model_path=cfg.model_path[0])
get_output(model=model, stack_path="/kaggle/working/dataset/stack_clipped/kidney_2_images.npy")

model_name Unet
backbone efficientnet-b0


100%|██████████| 2217/2217 [02:53<00:00, 12.76it/s]
100%|██████████| 1041/1041 [02:26<00:00,  7.12it/s]
100%|██████████| 1511/1511 [03:17<00:00,  7.66it/s]


In [10]:
# Load the kidney_2 image volume as a tensor.
# NOTE(review): same path as the get_output call above, but under
# /kaggle/working/dataset, not under dir_clipped (/kaggle/working/dataset_test)
# — confirm which directory is intended.
stack_path = "/kaggle/working/dataset/stack_clipped/kidney_2_images.npy"
stack = torch.tensor(np.load(stack_path))

In [None]:
# Wrap the volume so that indexing yields windows of cfg.in_chans consecutive slices.
dataset = InferenceDataset(stack, cfg.in_chans)

In [None]:
def rle_encode(mask):
    """Run-length encode a binary mask as a ``"start length start length ..."`` string.

    The mask is flattened in row-major order and starts are 1-based. A mask
    without any foreground run is encoded as ``"1 0"``.
    """
    flat = mask.flatten()
    # Zero sentinels at both ends make every run begin and end at a value change.
    padded = np.concatenate([[0], flat, [0]])
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Entries alternate run start / run end; turn each end into a length.
    boundaries[1::2] -= boundaries[::2]
    rle = " ".join(map(str, boundaries))
    return rle if rle != "" else "1 0"


def get_id(img_path):
    """Build a submission id from an image path.

    The last three ``/``-separated components are taken, the middle one is
    dropped, the remaining two are joined with ``_`` and the final four
    characters are removed, e.g. ``.../kidney_2/images/0000.tif`` becomes
    ``kidney_2_0000``.
    """
    parts = img_path.split("/")[-3:]
    del parts[1]  # drop the middle component
    return "_".join(parts)[:-4]


def get_ids(img_paths):
    """Return the submission id of every path in ``img_paths``, in the same order."""
    return [get_id(img_path) for img_path in img_paths]


# One submission id per kidney_2 training image, in sorted path order.
img_paths = sorted(glob("/kaggle/input/blood-vessel-segmentation/train/kidney_2/images/*.tif"))
ids = get_ids(img_paths)

In [None]:
####################################
# NOTE(review): `outputs` and `cfg.th_percentile` are not defined in the
# visible notebook. Presumably `outputs` is a sequence of per-volume
# prediction tensors indexed [slice, H, W], and `th_percentile` is the
# fraction of pixels to mark as foreground — confirm before running.
flat_preds = np.concatenate([x.flatten().numpy() for x in outputs])

# Threshold at the k-th largest prediction so that roughly the top
# `th_percentile` fraction of all pixels ends up above it. k is at least 1:
# with k == 0 the old `-0` index selected the minimum value as the threshold.
n_top = max(1, int(len(flat_preds) * cfg.th_percentile))
TH = np.partition(flat_preds, -n_top)[-n_top]
print(TH)

####################################
records = []
for slice_idx, slice_id in enumerate(ids):
    # Map the global slice index onto (volume, index within that volume).
    vol_idx, local_idx = 0, slice_idx
    for vol in outputs:
        if local_idx < len(vol):
            break
        local_idx -= len(vol)
        vol_idx += 1

    mask_pred = (outputs[vol_idx][local_idx] > TH).numpy()
    ####################################

    records.append({"id": slice_id, "rle": rle_encode(mask_pred)})

# Build the frame once from plain records. This is linear in the number of
# rows and also works when `ids` is empty, unlike pd.concat on an empty list.
submission_df = pd.DataFrame(records, columns=["id", "rle"])
submission_df.to_csv("submission.csv", index=False)
submission_df.head(6)