In [None]:
import sys
import ttach as tta

In [None]:
import os

# Root of the sennet checkout; its `src` and `configs` sub-directories are used below.
# Override with the SENNET_PKG_DIR environment variable instead of editing this cell,
# e.g. on Kaggle: "/kaggle/input/sumo-sennet-2024-02-10-19-53-46".
pkg_dir = os.environ.get("SENNET_PKG_DIR", "/home/clay/research/kaggle/sennet")


# Root of the competition data.
# Override with the SENNET_DATASET_FOLDER environment variable,
# e.g. on Kaggle: "/kaggle/input/blood-vessel-segmentation".
DATASET_FOLDER = os.environ.get(
    "SENNET_DATASET_FOLDER",
    "/home/clay/research/kaggle/sennet/data/blood-vessel-segmentation",
)

In [None]:
# Make the project's `src` directory importable. The membership check keeps the
# cell idempotent: re-running it no longer appends a duplicate entry each time.
src_dir = f"{pkg_dir}/src"
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [None]:
# Display the interpreter's module search path (e.g. to check the entry appended above).
sys.path

In [None]:
# Project-local imports; presumably resolved through the `src` directory appended
# to sys.path above — verify if the package is installed some other way.
from sennet.core.submission_utils import load_model_from_dir
from sennet.environments.constants import MODEL_OUT_DIR


# Show the base directory that model folders are loaded from in the model-loading cell.
print(f"{MODEL_OUT_DIR=}")

In [None]:
import os
import numpy as np
import pandas as pd
import cv2
from glob import glob
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
import segmentation_models_pytorch as smp
import gc
import monai

import re

In [None]:
import yaml
import json


# Read the submission configuration; the model-loading cell iterates over
# submission_cfg["predictors"]["models"].
submission_path = f"{pkg_dir}/configs/submission.yaml"
with open(submission_path, "rb") as f:
    # NOTE(review): FullLoader accepts more than plain YAML types. If this config
    # only holds scalars/lists/dicts, yaml.safe_load is the stricter choice — confirm
    # before switching.
    submission_cfg = yaml.load(f, Loader=yaml.FullLoader)

# Pretty-print the parsed config so the run settings are visible in the notebook.
pretty_cfg = json.dumps(submission_cfg, indent=4)
print(pretty_cfg)

In [None]:
# Ensemble members and their blending weights; both lists are appended to by the
# model-loading cell further down.
tta_models = []
weights = []

# Checkpoint selection flags. Only the `use_top_only` branch is implemented in the
# model-loading cell; the other branches raise NotImplementedError.
use_top_only = True
use_best = False
# Folds iterated by the model-loading cell.
folds2predict = [0]
# folds2predict = [1]

# When True, each model is wrapped in a flip test-time-augmentation wrapper.
use_tta = True

# Threshold applied to the accumulated prediction volume before RLE encoding.
TH = 0.01

# Forwarded to BuildDataset: when True its length covers slices along all three
# volume axes, otherwise only along the first axis.
is_test = True
# is_test = not len(glob(os.path.join(DATASET_FOLDER, "test", "*", "*", "*.tif"))) == 6
# is_test = not len(glob(os.path.join(DATASET_FOLDER, "train", "*", "*", "*.tif"))) == 6

In [None]:
# def rename_keys(original_dict, pattern):
#     new_dict = {}
    
#     for old_key, value in original_dict.items():
#         new_key = re.sub(pattern, '', old_key)
        
#         new_dict[new_key] = value
    
#     return new_dict


def rle_encode(img):
    """Run-length encode a binary mask.

    Args:
        img: numpy array where 1 marks the mask and 0 the background. It is
            flattened in row-major order before encoding.

    Returns:
        A space-separated string of ``start length`` pairs, with 1-based starts
        into the flattened mask. A mask with no runs is encoded as ``'1 0'``.
    """
    flat = img.flatten()
    # Zero sentinels on both ends so runs touching the borders still produce
    # a rising and a falling edge.
    padded = np.concatenate([[0], flat, [0]])
    edges = np.where(padded[1:] != padded[:-1])[0] + 1
    # Edges alternate (start, end); turn every end into a run length.
    edges[1::2] -= edges[::2]
    encoded = ' '.join(map(str, edges))
    return encoded if encoded != '' else '1 0'


def to_device(x: torch.Tensor, cuda_id: int = 0) -> torch.Tensor:
    """Move ``x`` to CUDA device ``cuda_id`` when CUDA is available.

    Returns ``x`` itself, untouched, when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        return x
    return x.cuda(cuda_id)


# def load_jit_model(model_path: str, cuda_id: int = 0) -> torch.nn.Module:
#     model = torch.jit.load(
#         model_path,
#         map_location=f"cuda:{cuda_id}" if torch.cuda.is_available() else "cpu",
#     )
#     return model

## Dataset

In [None]:
class BuildDataset(torch.utils.data.Dataset):
    """Serve normalized stacks of neighbouring slices from a 3D image volume.

    The volume is read from ``<dataset>/images/*.tif`` (sorted by path) and
    zero-padded by ``window = in_channels // 2`` voxels on every side. Each
    sample is a stack of ``2 * window + 1`` consecutive slices centred on the
    requested position, cropped back to the unpadded extent of the other two
    axes.

    Index layout, for an unpadded volume of shape ``(d0, d1, d2)``:

    * ``0 <= idx < d0``            -> axis ``"X"``, slices taken along axis 0
    * ``d0 <= idx < d0 + d1``      -> axis ``"Y"``, slices taken along axis 1
    * ``d0 + d1 <= idx < d0+d1+d2`` -> axis ``"Z"``, slices taken along axis 2

    Args:
        dataset: directory containing an ``images`` sub-folder of ``.tif`` files.
        in_channels: requested channel count; only ``in_channels // 2`` is used,
            so the returned stack always has an odd number of channels.
        is_test: when True ``len()`` covers all three axes, otherwise only
            the first axis.
    """

    def __init__(self, dataset, in_channels=3, is_test=False):
        self.window = in_channels // 2
        self.is_test = is_test
        # One "<volume folder>_<file stem>" id per image, in load order.
        self.ids = []

        self.data_tensor = self.load_volume(dataset)
        self.shape_orig = self.data_tensor.shape

        # Zero-pad every axis by `window` so border positions still get a full
        # stack of neighbours.
        self.padding = ((self.window, self.window),) * self.data_tensor.ndim
        self.data_tensor = np.pad(
            self.data_tensor, self.padding, mode="constant", constant_values=0
        )

    def __len__(self):
        """Number of samples: all three axes when ``is_test``, else axis 0 only."""
        return sum(self.shape_orig) if self.is_test else self.shape_orig[0]

    def normilize(self, image):
        """Scale ``image`` with the volume's percentile range, then standardize.

        The method name keeps its original spelling so existing callers keep working.
        Returns a float32 array.
        """
        image = (image - self.xmin) / (self.xmax - self.xmin)
        # image = np.clip(image, 0, 1)
        image = (image - 0.5) / 0.235

        return image.astype(np.float32)

    @staticmethod
    # An earlier setting, kept for reference, was low=10, high=99.8.
    def norm_by_percentile(volume, low=1, high=99):
        """Return ``(xmin, xmax)`` percentiles used by ``normilize``.

        Percentiles are taken over the central 60% of the volume along axis 0
        (20% is skipped at each end). ``xmax`` is never smaller than 1. When the
        central band is empty (e.g. a single-slice volume) the whole volume is
        used instead of taking a percentile of an empty array.
        """
        channel_margin = 0.2
        channel_lb = int(channel_margin * volume.shape[0])
        channel_ub = int((1 - channel_margin) * volume.shape[0])

        core = volume[channel_lb:channel_ub]
        if core.size == 0:
            core = volume

        xmin = np.percentile(core, low)
        xmax = np.max([np.percentile(core, high), 1])
        print(f"{xmin=}, {xmax=}")
        return xmin, xmax

    def load_volume(self, dataset):
        """Read every ``images/*.tif`` under ``dataset`` into one uint16 volume.

        Also fills ``self.ids`` and sets ``self.xmin`` / ``self.xmax``.

        Raises:
            FileNotFoundError: no ``.tif`` file matches the expected pattern.
            ValueError: an image file could not be decoded.
        """
        pattern = os.path.join(dataset, "images", "*.tif")
        image_paths = sorted(glob(pattern))
        if not image_paths:
            raise FileNotFoundError(f"No .tif images found matching {pattern}")

        volume = None
        for z, image_path in enumerate(tqdm(image_paths)):
            parts = image_path.split(os.path.sep)
            slice_id, _ = os.path.splitext(parts[-1])
            # parts[-3] is the folder that contains "images".
            self.ids.append(f"{parts[-3]}_{slice_id}")

            image = cv2.imread(image_path, cv2.IMREAD_ANYDEPTH)
            if image is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise ValueError(f"Could not read image: {image_path}")
            image = np.array(image, dtype=np.uint16)
            if volume is None:
                # Allocate once, sized from the first image.
                volume = np.zeros((len(image_paths), *image.shape[-2:]), dtype=np.uint16)
            volume[z] = image

        self.xmin, self.xmax = self.norm_by_percentile(volume)
        return volume

    def __getitem__(self, idx):
        """Return the normalized slice stack for flat index ``idx``.

        Returns:
            dict with
            ``"slice"``: float32 tensor of shape ``(2 * window + 1, A, B)`` where
            ``(A, B)`` are the two unpadded axes not being sliced;
            ``"slice_index"``: position along the sliced axis in the unpadded volume;
            ``"axis"``: ``"X"``, ``"Y"`` or ``"Z"`` for axis 0, 1 or 2.

        Raises:
            IndexError: ``idx`` is outside ``[0, sum(shape_orig))``.
        """
        depth, height, width = self.shape_orig
        if not 0 <= idx < depth + height + width:
            raise IndexError(f"index {idx} is out of range for volume shape {self.shape_orig}")

        w = self.window
        span = 2 * w + 1
        volume = self.data_tensor

        # In padded coordinates the stack centred on unpadded position `local`
        # covers [local, local + span). The other two axes are cropped with
        # explicit end points so that window == 0 also works.
        if idx < depth:
            axis, local = "X", idx
            stack = volume[local:local + span, w:w + height, w:w + width]
        elif idx < depth + height:
            axis, local = "Y", idx - depth
            stack = volume[w:w + depth, local:local + span, w:w + width].transpose(1, 0, 2)
        else:
            axis, local = "Z", idx - depth - height
            stack = volume[w:w + depth, w:w + height, local:local + span].transpose(2, 0, 1)

        # `stack` is channel-first at this point.
        slice_data = torch.tensor(self.normilize(stack))

        return {
            "slice": slice_data,
            "slice_index": local,
            "axis": axis
        }

## Model

In [None]:
# def find_highest_score_filename(file_list):
#     highest_score = float('-inf')
#     highest_score_filename = None

#     for filename in file_list:
#         # Extract the score from the filename using regular expression
#         match = re.search(r'dice_(\d+\.\d+)', filename)
#         if match:
#             current_score = float(match.group(1))
#             if current_score > highest_score:
#                 highest_score = current_score
#                 highest_score_filename = filename

#     return highest_score_filename

In [None]:
class ModelWrapper(torch.nn.Module):
    """Expose a model's ``predict(...).pred`` through a plain ``forward`` call.

    This lets wrappers that call the module directly (such as the TTA wrapper
    and the sliding-window inferer used in this notebook) drive the model.
    """

    def __init__(self, our_model):
        super().__init__()
        self.model = our_model

    def forward(self, img):
        """Insert a singleton dimension at position 1 and return the ``pred`` field.

        Presumably ``img`` is (batch, channels, H, W) and the wrapped model wants
        (batch, 1, channels, H, W) — confirm against the model's ``predict`` contract.
        """
        expanded = torch.unsqueeze(img, 1)
        result = self.model.predict(expanded)
        return result.pred

In [None]:
# Build the ensemble: one (optionally TTA-wrapped) predictor per configured model
# and fold, with a blending weight and the model's input channel count.
#
# All three lists are reset here, together. Previously only `in_chans` was reset,
# so re-running this cell kept growing `tta_models` and `weights`; the inference
# cell divides by sum(weights), which silently scaled the logits down.
tta_models = []
weights = []
in_chans = []

for model_config in tqdm(submission_cfg["predictors"]["models"]):
    # NOTE(review): `fold` is not passed to the loader, so every fold loads the
    # same model directory — confirm this is intended before using several folds.
    for fold in folds2predict:
        if not use_top_only:
            # Neither the `use_best` path nor the fallback path is implemented.
            raise NotImplementedError("don't use this")

        cfg, raw_model = load_model_from_dir(MODEL_OUT_DIR / model_config)
        raw_model = raw_model.cuda().eval()
        model_in_chans = raw_model.kw["in_channels"]
        model = ModelWrapper(raw_model)

        if use_tta:
            # Flip-based test-time augmentation, merged by mean
            # (d4_transform is the heavier alternative).
            predictor = tta.SegmentationTTAWrapper(
                model, tta.aliases.flip_transform(), merge_mode='mean'
            )
        else:
            predictor = model

        tta_models.append(predictor)
        weights.append(1.0)
        in_chans.append(model_in_chans)


print(f"{in_chans=}")

In [None]:
# del state_dict

## Inference

In [None]:
# Volumes to run inference on. Built from DATASET_FOLDER so the data root is
# defined in exactly one place instead of being repeated as an absolute path.
datasets = [f"{DATASET_FOLDER}/train/kidney_3_dense/"]
# datasets = sorted(glob(f"{DATASET_FOLDER}/test/*"))[::-1]
print(f"{datasets=}")

In [None]:
# Run the ensemble over every volume, accumulate per-voxel probabilities from each
# sampled axis, threshold, and RLE-encode one mask per slice along axis 0.
rles, ids = [], []

if not tta_models:
    raise RuntimeError("No models loaded; run the model-loading cell first.")

# Loop-invariant: total ensemble weight used to normalize the blended logits.
total_weight = sum(weights)
# Each voxel is averaged over the axes actually sampled: BuildDataset yields all
# three axes only when is_test is True, otherwise just the first one. Dividing by
# a fixed 3 in the single-axis case scaled every probability down by a third.
n_axes = 3 if is_test else 1

with torch.no_grad():
    for dataset in datasets:
        test_dataset = BuildDataset(dataset, is_test=is_test, in_channels=3)  # TODO: refactor this
        test_loader = DataLoader(test_dataset, batch_size=1, num_workers=4, shuffle=False, pin_memory=False)
        # Index of the centre slice in the channel stack; fed to 1-channel models.
        center = test_dataset.window

        y_preds = np.zeros(test_dataset.shape_orig, dtype=np.half)
        ids += test_dataset.ids

        pbar = tqdm(test_loader, total=len(test_loader), desc=f'Inference {dataset}')
        for batch in pbar:
            images = to_device(batch["slice"])

            # batch_size is 1, so take the single entry of each collated field.
            axis = batch["axis"][0]
            idx = batch["slice_index"].numpy()[0]

            preds = 0
            for tta_model, weight, in_chan in zip(tta_models, weights, in_chans):
                # Single-channel models only see the centre slice; slicing with a
                # range keeps the channel dimension for any batch size.
                model_input = images if in_chan != 1 else images[:, center:center + 1, ...]
                # Alternatives tried earlier: overlap 0.25 / 0.9 and mode="constant".
                preds += weight * monai.inferers.sliding_window_inference(
                    inputs=model_input,
                    predictor=tta_model,
                    sw_batch_size=8,
                    roi_size=(800, 800),
                    overlap=0.5,
                    padding_mode="reflect",
                    mode="gaussian",
                    sw_device="cuda",
                    device="cuda",
                    progress=False,
                )

            # Weighted-mean logits -> probability -> this axis' share of the average.
            probs = ((preds / total_weight).squeeze().sigmoid().cpu().numpy() / n_axes).astype(np.half)
            if axis == "X":
                y_preds[idx, :, :] += probs
            elif axis == "Y":
                y_preds[:, idx, :] += probs
            elif axis == "Z":
                y_preds[:, :, idx] += probs

        # Optional post-processing that was tried and left disabled:
        # y_preds = cc3d.dust((y_preds > TH).astype(np.uint8), connectivity=18, threshold=-1, in_place=False)

        for pred in y_preds:
            rles.append(rle_encode((pred > TH).astype(np.uint8)))

        # The volume and prediction buffer are large; free them before the next dataset.
        del test_dataset, test_loader, y_preds
        gc.collect()

In [None]:
# Debug: rebuild the dataset for the first volume (this re-reads every image from
# disk) and display how many samples it exposes.
d = BuildDataset(datasets[0], is_test=is_test, in_channels=3)
len(d)

In [None]:
# Debug: display the first channel of the first sample's normalized slice stack.
d[0]["slice"][0]

In [None]:
# Assemble the submission table: one row per slice id with its RLE mask.
# NOTE(review): height/width are fixed constants here rather than read from the
# predicted volume — confirm they match every dataset in `datasets`.
submission = pd.DataFrame({"id": ids, "rle": rles}).assign(height=1706, width=1510)

# Write the same table to the working directory and to the shared dump location.
for csv_path in (
    "submission.csv",
    "/opt/kaggle/sennet/data_dumps/predicted/igor/kidney_3_dense/submission.csv",
):
    submission.to_csv(csv_path, index=False)

In [None]:
# Display the submission table as the cell output.
submission