In [None]:
import sys, os
from pathlib import Path
# Repository layout: ROOT_DIR is the parent of the current working directory,
# BASE_DIR is the "pytorch-lightning" folder inside it.
ROOT_DIR = Path(os.path.abspath(os.path.join(os.getcwd(), "..")))
BASE_DIR = ROOT_DIR / "pytorch-lightning"
# sys.path entries must be plain strings: the import system ignores any other
# type, so appending Path objects would silently have no effect. The membership
# check keeps the cell idempotent when it is re-run in the same kernel.
for _search_dir in (str(ROOT_DIR), str(BASE_DIR)):
    if _search_dir not in sys.path:
        sys.path.append(_search_dir)

In [None]:
# Project sub-directories derived from BASE_DIR / ROOT_DIR.
INPUT_DIR = BASE_DIR / "input"
OUTPUT_DIR = ROOT_DIR / "working"
CONFIG_DIR = BASE_DIR / "config"

# Folder of the "hubmap-organ-segmentation" competition data under INPUT_DIR.
COMPETITION_DATA_DIR = INPUT_DIR / "hubmap-organ-segmentation"

# YAML file holding the default run configuration.
CONFIG_YAML_PATH = CONFIG_DIR / "default.yaml"

In [None]:
from typing import Any, Callable, Dict, List, Union, Optional, Tuple, Mapping
import matplotlib.pyplot as plt

import yaml
import logging 
import hashlib
import json
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import rasterio
import numpy as np
import pandas as pd
import albumentations as A
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
from ast import literal_eval
from multimethod import multimethod
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
from tqdm.notebook import tqdm
from rasterio.windows import Window
from torchvision import transforms
from albumentations.pytorch import ToTensorV2
from segmentation_models_pytorch.base import modules as md
from fastai.vision.all import PixelShuffle_ICNR

from timm.optim import create_optimizer_v2
from pl_bolts.optimizers import lr_scheduler
from losses_metrics import SymmetricLovaszLoss, Dice_soft, Dice_threshold, Dice_soft_func, Dice_threshold_func

from segmentation_models_pytorch.decoders.deeplabv3.decoder import ASPP
from segmentation_models_pytorch.decoders.fpn.decoder import FPNDecoder
from segmentation_models_pytorch.decoders.unetplusplus.decoder import CenterBlock
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base import (
    SegmentationModel,
    SegmentationHead,
    ClassificationHead,
)

In [None]:
bs = 64
sz = 512    # the size of tiles
reduce = 4  # reduce the original images by 4 times
TH = 0.225  # threshold for positive predictions (re-assigned by the inference cell)
# Paths below are relative to the notebook's working directory.
DATA = './input/hubmap-organ-segmentation/test_images/'
TRAIN_CSV = "./input/hubmap-organ-segmentation/train.csv"
TEST_CSV = './input/hubmap-organ-segmentation/test.csv'
# MODELS = [f'./input/hubmap-models/hubmap_models/run_0/model_{i}.pth' for i in range(4)] + [f'../input/hubmap-models/hubmap_models/run_1/model_{i}.pth' for i in range(4)] + [f'../input/training-fastai-baseline/model_{i}.pth' for i in range(4)]
df_sample = pd.read_csv('./input/hubmap-organ-segmentation/sample_submission.csv')
# Keep using the second GPU when the machine has one, but fall back to the first
# GPU (and then the CPU) instead of selecting an invalid ordinal on single-GPU hosts.
device = torch.device(
    'cuda:1' if torch.cuda.device_count() > 1
    else 'cuda:0' if torch.cuda.is_available()
    else 'cpu'
)

In [None]:
# https://www.kaggle.com/datasets/thedevastator/hubmap-2022-256x256
# Three-element mean / std vectors (presumably per-channel statistics of the
# dataset linked above — confirm before re-enabling normalisation).
mean = np.array([0.7720342, 0.74582646, 0.76392896])
std = np.array([0.24745085, 0.26182273, 0.25782376])

s_th = 40  # saturation blanking threshold
p_th = 1000*(sz//256)**2  # threshold for the minimum number of pixels; grows with the square of sz // 256

In [None]:
def mask2rle(img):
    """Encode a mask as a run-length string.

    Args:
        img: numpy array, 1 - mask, 0 - background.

    Returns:
        Space-separated ``start length`` pairs. Starts are 1-based positions in
        the mask flattened column by column; an all-background mask gives ``''``.
    """
    column_major = img.flatten(order="F")
    # Zero sentinels on both sides guarantee every run has a start and an end edge.
    padded = np.pad(column_major, 1)
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    starts = edges[0::2]
    lengths = edges[1::2] - starts
    pairs = np.stack([starts, lengths], axis=1).ravel()
    return ' '.join(map(str, pairs))

def rle2mask(mask_rle, shape=(1600,256)):
    """Decode a run-length string into a binary mask.

    Args:
        mask_rle: run-length string of ``start length`` pairs (1-based starts).
        shape: (width, height) of the array to return.

    Returns:
        ``uint8`` numpy array of shape ``(shape[1], shape[0])``, 1 - mask,
        0 - background.
    """
    tokens = np.asarray(mask_rle.split(), dtype=int)
    run_starts = tokens[0::2] - 1          # convert to 0-based offsets
    run_ends = run_starts + tokens[1::2]
    flat = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for begin, end in zip(run_starts, run_ends):
        flat[begin:end] = 1
    # Runs index the mask column by column, hence reshape followed by transpose.
    return flat.reshape(shape).T

In [None]:
def print_args(args, printer=logging.info):
    """Emit every attribute of ``args`` as a ``name:value`` line between banners.

    Args:
        args: object whose ``__dict__`` entries are reported.
        printer: callable receiving one string per line (defaults to ``logging.info``).
    """
    printer("==========       args      =============")
    for name, value in args.__dict__.items():
        printer(f"{name}:{value}")
    printer("==========     args END    =============")


class EasyConfig(dict):
    """``dict`` subclass whose keys are also reachable as attributes.

    Nested dictionaries merged through :meth:`update` are stored as nested
    ``EasyConfig`` objects. ``update`` is overloaded with ``multimethod``: it
    accepts either a mapping (recursive merge) or a list/tuple of command-line
    style overrides such as ``["--train.lr=0.1", "model.name", "unet"]``.
    """

    def __getattr__(self, key: str) -> Any:
        # Raise AttributeError (not KeyError) so getattr()/hasattr() behave normally.
        if key not in self:
            raise AttributeError(key)
        return self[key]

    def __setattr__(self, key: str, value: Any) -> None:
        self[key] = value

    def __delattr__(self, key: str) -> None:
        del self[key]

    def load(self, fpath: str, *, recursive: bool = False) -> None:
        """Merge the contents of a YAML file into this config.

        Args:
            fpath (str): path to the yaml file.
            recursive (bool, optional): also load ``default<ext>`` from every
                parent directory, outermost first, so that files closer to
                ``fpath`` override values from files further up. Defaults to False.

        Raises:
            FileNotFoundError: if ``fpath`` does not exist.
        """
        if not os.path.exists(fpath):
            raise FileNotFoundError(fpath)
        fpaths = [fpath]
        if recursive:
            extension = os.path.splitext(fpath)[1]
            # Walk up until dirname() stops changing (filesystem root or '').
            while os.path.dirname(fpath) != fpath:
                fpath = os.path.dirname(fpath)
                fpaths.append(os.path.join(fpath, 'default' + extension))
        for candidate in reversed(fpaths):
            if not os.path.exists(candidate):
                continue
            with open(candidate) as f:
                loaded = yaml.safe_load(f)
            # safe_load returns None for an empty file; there is nothing to merge then.
            if loaded is not None:
                self.update(loaded)

    def reload(self, fpath: str, *, recursive: bool = False) -> None:
        """Drop every current entry, then :meth:`load` ``fpath``."""
        self.clear()
        self.load(fpath, recursive=recursive)

    # multimethod makes python support function overloading
    @multimethod
    def update(self, other: Dict) -> None:
        """Recursively merge a mapping; nested dicts become nested ``EasyConfig``."""
        for key, value in other.items():
            if isinstance(value, dict):
                if key not in self or not isinstance(self[key], EasyConfig):
                    self[key] = EasyConfig()
                # recursively update
                self[key].update(value)
            else:
                self[key] = value

    @multimethod
    def update(self, opts: Union[List, Tuple]) -> None:
        """Apply command-line style overrides.

        Each option is either ``key=value`` or a ``key value`` pair of
        consecutive items; a leading ``--`` is stripped and dotted keys address
        nested configs. Values are parsed with ``ast.literal_eval`` and kept as
        plain strings when they are not valid Python literals.

        Raises:
            IndexError: if the last option is a key with no value after it.
        """
        index = 0
        while index < len(opts):
            opt = opts[index]
            if opt.startswith('--'):
                opt = opt[2:]
            if '=' in opt:
                key, value = opt.split('=', 1)
                index += 1
            else:
                if index + 1 >= len(opts):
                    raise IndexError("missing value for option '{}'".format(opt))
                key, value = opt, opts[index + 1]
                index += 2
            try:
                value = literal_eval(value)
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Not a Python literal: keep the raw value. Only the errors
                # literal_eval is documented to raise are swallowed.
                pass
            current = self
            *parents, leaf = key.split('.')
            for subkey in parents:
                current = current.setdefault(subkey, EasyConfig())
            current[leaf] = value

    def dict(self) -> Dict[str, Any]:
        """Return a copy in which nested ``EasyConfig`` objects are plain dicts."""
        configs = dict()
        for key, value in self.items():
            if isinstance(value, EasyConfig):
                value = value.dict()
            configs[key] = value
        return configs

    def hash(self) -> str:
        """SHA-256 hex digest of the config serialised as key-sorted JSON."""
        buffer = json.dumps(self.dict(), sort_keys=True)
        return hashlib.sha256(buffer.encode()).hexdigest()

    def __str__(self) -> str:
        """Render as ``key: value`` lines, nested configs indented by two spaces."""
        texts = []
        for key, value in self.items():
            separator = '\n' if isinstance(value, EasyConfig) else ' '
            lines = (key + ':' + separator + str(value)).split('\n')
            texts.append(lines[0])
            # Continuation lines (nested configs, multi-line values) are indented.
            texts.extend((' ' * 2) + line for line in lines[1:])
        return '\n'.join(texts)


In [None]:
class HuBMAPDataset(Dataset):
    """Reads the TIFF named by the ``id`` column and returns it as merged tiles.

    Images are read from the module-level ``DATA`` directory with rasterio,
    cut into ``reduce*sz`` square tiles (zero-padded at the borders), each tile
    is shrunk by ``reduce`` and the processed tiles are stitched back together.

    ``__getitem__`` returns ``(merged, flag, pad0, pad1, n0max, n1max)`` where
    ``flag`` is 0 for a stitched multi-tile image and -1 when the image fits in
    a single tile, in which case it is resized and duplicated into a 2x2 grid.
    """

    def __init__(self, dataframe: pd.DataFrame, sz=512, reduce=4, transform=None):
        self.dataframe = dataframe
        self.reduce = reduce
        self.sz = reduce*sz  # tile side measured in full-resolution pixels
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def _prepare(self, img):
        """Scale an HWC image to [0, 1], apply the transform, return an HWC tensor.

        The transform is expected to return a dict whose ``"image"`` entry is a
        CHW tensor (as ``ToTensorV2`` produces). Without a transform the scaled
        array is converted directly.
        """
        scaled = img/255.0
        if self.transform is None:
            return torch.from_numpy(scaled)
        return self.transform(image=scaled)["image"].permute(1,2,0)

    def _read_tile(self, data, layers, x0, y0):
        """Read the tile whose top-left corner is (x0, y0); out-of-image area stays zero."""
        p00,p01 = max(0,x0), min(x0+self.sz,data.shape[0])
        p10,p11 = max(0,y0), min(y0+self.sz,data.shape[1])
        window = Window.from_slices((p00,p01),(p10,p11))
        img = np.zeros((self.sz,self.sz,3),np.uint8)
        # mapping the loaded region to the tile (a view, so writes land in img)
        target = img[(p00-x0):(p01-x0),(p10-y0):(p11-y0)]
        if data.count == 3:
            target[...] = np.moveaxis(data.read([1,2,3], window=window), 0, -1)
        else:
            for channel, layer in enumerate(layers):
                target[..., channel] = layer.read(1, window=window)
        return img

    def _read_whole(self, data, layers):
        """Read the complete image as an HWC array."""
        if data.count == 3:
            return np.ascontiguousarray(np.moveaxis(data.read([1,2,3]), 0, -1))
        img = np.zeros((data.shape[0],data.shape[1],3),np.uint8)
        for channel, layer in enumerate(layers):
            img[..., channel] = layer.read(1)
        return img

    def __getitem__(self, idx):
        image_id = self.dataframe.loc[idx, "id"]
        data = rasterio.open(os.path.join(DATA,f'{image_id}.tiff'), transform = rasterio.Affine(1, 0, 0, 0, 1, 0), num_threads='all_cpus')
        layers = []
        try:
            # some images have issues with their format: their channels are
            # exposed as sub-datasets, which are opened one layer per channel
            if data.count != 3:
                for subdataset in data.subdatasets:
                    layers.append(rasterio.open(subdataset))
            pad0 = (self.sz - data.shape[0]%self.sz)%self.sz
            pad1 = (self.sz - data.shape[1]%self.sz)%self.sz
            n0max = (data.shape[0] + pad0)//self.sz
            n1max = (data.shape[1] + pad1)//self.sz
            # Use the instance's own reduction factor, not the module-level `reduce`.
            out_side = self.sz//self.reduce

            if n0max*n1max != 1:
                tiles = []
                for tile_idx in range(n0max*n1max):
                    n0,n1 = divmod(tile_idx, n1max)
                    # Padding is split between both sides, so the first tile starts at a negative offset.
                    x0,y0 = -pad0//2 + n0*self.sz, -pad1//2 + n1*self.sz
                    img = self._read_tile(data, layers, x0, y0)
                    if self.reduce != 1:
                        img = cv2.resize(img,(out_side,out_side), interpolation = cv2.INTER_AREA)
                    tiles.append(self._prepare(img))
                # Stitch row by row so any n0max x n1max grid works, not only 2 x 2.
                rows = [torch.hstack(tiles[r*n1max:(r+1)*n1max]) for r in range(n0max)]
                merged = torch.vstack(rows)
                return (merged, 0, pad0, pad1, n0max, n1max)

            img = cv2.resize(self._read_whole(data, layers),(out_side,out_side), interpolation = cv2.INTER_AREA)
            img = self._prepare(img)
            # Single-tile image: duplicate it into a 2x2 grid.
            merged = torch.vstack([torch.hstack([img,img]), torch.hstack([img,img])])
            return (merged, -1, pad0, pad1, n0max, n1max)
        finally:
            # Release the file handles; every sample opens them anew.
            for layer in layers:
                layer.close()
            data.close()


In [None]:
class LitDataModule(pl.LightningDataModule):
    """Lightning data module that wraps a dataframe in a ``HuBMAPDataset``.

    Only the "test" stage builds a dataset. ``train_dataloader`` and
    ``val_dataloader`` read ``self.train_dataset`` / ``self.val_dataset``,
    which this class itself never assigns.
    """

    def __init__(
        self,
        data_frame: pd.DataFrame,
        spatial_size: int,
        batch_size: int,
        num_workers: int,
    ):
        super().__init__()

        # Keep the scalar settings in self.hparams; the dataframe is stored separately.
        self.save_hyperparameters(ignore="data_frame")
        self.data_frame = data_frame

        (self.train_transform,
         self.val_transform,
         self.test_transform) = self._init_transforms()

    def _init_transforms(self) -> Tuple[Callable, Callable, Callable]:
        """Build the (train, val, test) albumentations pipelines."""
        side = self.hparams.spatial_size

        def resize_then_tensor():
            # Shared tail of every pipeline: square resize followed by tensor conversion.
            return [A.Resize(height=side, width=side), ToTensorV2()]

        geometric = [
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.RandomRotate90(),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.9, border_mode=cv2.BORDER_REFLECT),
        ]
        distortion = A.OneOf(
            [A.OpticalDistortion(p=0.3), A.GridDistortion(p=.1), A.PiecewiseAffine(p=0.3)],
            p=0.3,
        )
        colour = A.OneOf(
            [A.HueSaturationValue(10,15,10), A.CLAHE(clip_limit=2), A.RandomBrightnessContrast()],
            p=0.3,
        )

        train_transform = A.Compose([*geometric, distortion, colour, *resize_then_tensor()])
        val_transform = A.Compose(resize_then_tensor())
        test_transform = A.Compose(resize_then_tensor())
        return train_transform, val_transform, test_transform

    def setup(self, stage: str = None):
        """Create the test dataset for ``stage`` "test" or ``None``; "fit" creates nothing."""
        if stage in ("test", None):
            self.test_dataset = self._dataset(self.data_frame, transform=self.test_transform)

    def _dataset(self, df: pd.DataFrame, transform: Callable) -> Dataset:
        return HuBMAPDataset(dataframe=df, transform=transform)

    def train_dataloader(self) -> DataLoader:
        return self._dataloader(self.train_dataset, train=True, val=True)

    def val_dataloader(self) -> DataLoader:
        return self._dataloader(self.val_dataset, train=False, val=True)

    def test_dataloader(self) -> DataLoader:
        return self._dataloader(self.test_dataset)

    def _dataloader(self, dataset: Dataset, train: bool = False, val: bool = False) -> DataLoader:
        """Wrap ``dataset``; shuffling and drop_last are enabled only when both flags are set."""
        shuffle_and_drop = bool(train and val)
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle_and_drop,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            drop_last=shuffle_and_drop,
        )

In [None]:
def build_from_config(cfg):
    """Instantiate the segmentation network described by ``cfg``.

    ``cfg.architecture == 'unetplusplus-with-aspp-fpn'`` selects the custom
    ``UnetPlusPlus_with_ASPP_FPN``; any other value is looked up as a class of
    ``segmentation_models_pytorch`` and built with ``cfg.backbone`` only.
    """
    if cfg.architecture != 'unetplusplus-with-aspp-fpn':
        architecture_cls = getattr(smp, cfg.architecture)
        return architecture_cls(cfg.backbone)

    return UnetPlusPlus_with_ASPP_FPN(
        cfg.backbone,
        segmentation_channels=cfg.segmentation_channels,
        atrous_rates=tuple(cfg.atrous_rates),
        classes=cfg.classes,
    )

class DecoderBlock(nn.Module):
    """Decoder stage: ICNR pixel-shuffle upsampling, optional skip concat, two conv layers.

    Attribute names and creation order are kept stable because they define the
    keys of the module's ``state_dict``.
    """

    def __init__(self, in_channels, skip_channels, out_channels,
                 use_batchnorm=True, attention_type=None):
        super().__init__()
        conv_options = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)

        self.conv1 = md.Conv2dReLU(in_channels + skip_channels, out_channels, **conv_options)
        self.attention1 = md.Attention(attention_type, in_channels=in_channels + skip_channels)
        self.conv2 = md.Conv2dReLU(out_channels, out_channels, **conv_options)
        self.attention2 = md.Attention(attention_type, in_channels=out_channels)

        # Upsampling layer used in place of nearest-neighbour interpolation.
        self.shuf = PixelShuffle_ICNR(in_channels, in_channels*2//2)

    def forward(self, x, skip=None):
        """Upsample ``x``, fuse ``skip`` when given, then refine with both convs."""
        upsampled = self.shuf(x)
        if skip is not None:
            # attention1 is applied only to the concatenated features.
            upsampled = self.attention1(torch.cat([upsampled, skip], dim=1))
        refined = self.conv2(self.conv1(upsampled))
        return self.attention2(refined)

class UnetPlusPlus_with_ASPP_FPN_Decoder(nn.Module):
    """UNet++-style dense decoder with an ASPP block on the deepest feature and an FPN branch.

    ``forward`` returns the channel-wise concatenation of the last dense block
    output and the FPN output upsampled by a factor of 4.

    Args:
        encoder_channels: channel counts of the encoder features, shallowest first.
        decoder_channels: output channels of each decoder stage; its length must
            equal ``n_blocks``. It is concatenated with a tuple when the FPN is
            built, so it needs to be a tuple itself.
        aspp_out_channels: output channels of the ASPP block.
            NOTE(review): the first decoder stage is sized from the deepest
            encoder channel count, so this presumably has to equal it — confirm.
        segmentation_channels: forwarded to ``FPNDecoder``.
        atrous_rates: forwarded to ``ASPP``.
        n_blocks: number of decoder stages (also passed as the FPN encoder depth).
        use_batchnorm: forwarded to every ``DecoderBlock``.
        attention_type: forwarded to every ``DecoderBlock``.
        center: build a ``CenterBlock`` instead of ``nn.Identity``. Note that
            ``forward`` does not call ``self.center``.

    Raises:
        ValueError: if ``n_blocks`` differs from ``len(decoder_channels)``.
    """

    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        aspp_out_channels,
        segmentation_channels=32,
        atrous_rates=(6,12,18),
        n_blocks=5,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )
        # remove first skip with same spatial resolution
        encoder_channels = encoder_channels[1:]
        # reverse channels to start from head of encoder
        encoder_channels = encoder_channels[::-1]

        # ASPP runs on the deepest encoder feature; the FPN consumes that feature
        # plus the outputs of all but the last decoder stage (channel list reversed).
        self.aspp = ASPP(encoder_channels[0], aspp_out_channels, atrous_rates)
        self.fpn = FPNDecoder(tuple(list((encoder_channels[0],)+decoder_channels[:-1])[::-1]), segmentation_channels=segmentation_channels, encoder_depth=n_blocks, merge_policy="cat")

        # computing blocks input and output channels
        head_channels = encoder_channels[0]
        self.in_channels = [head_channels] + list(decoder_channels[:-1])
        self.skip_channels = list(encoder_channels[1:]) + [0]
        self.out_channels = decoder_channels
        if center:
            self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
        else:
            self.center = nn.Identity()

        # combine decoder keyword arguments
        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)

        # Dense blocks are stored under the key "x_{depth_idx}_{layer_idx}".
        blocks = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(layer_idx + 1):
                if depth_idx == 0:
                    in_ch = self.in_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1)
                    out_ch = self.out_channels[layer_idx]
                else:
                    out_ch = self.skip_channels[layer_idx]
                    skip_ch = self.skip_channels[layer_idx] * (layer_idx + 1 - depth_idx)
                    in_ch = self.skip_channels[layer_idx - 1]
                blocks[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
        # Final stage has no skip connection.
        blocks[f"x_{0}_{len(self.in_channels)-1}"] = DecoderBlock(
            self.in_channels[-1], 0, self.out_channels[-1], **kwargs
        )
        self.blocks = nn.ModuleDict(blocks)
        self.depth = len(self.in_channels) - 1

    def forward(self, *features):
        """Decode the encoder features passed as positional arguments, shallowest first."""
        features = features[1:]  # remove first skip with same spatial resolution
        features = features[::-1]  # reverse channels to start from head of encoder
        # Replace the deepest feature by its ASPP-processed version.
        features = (self.aspp(features[0]),)+features[1:]
        # start building dense connections
        dense_x = {}
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(self.depth - layer_idx):
                if layer_idx == 0:
                    # First pass: each feature is decoded with the next one as its skip.
                    output = self.blocks[f"x_{depth_idx}_{depth_idx}"](features[depth_idx], features[depth_idx + 1])
                    dense_x[f"x_{depth_idx}_{depth_idx}"] = output
                else:
                    # Later passes: the skip is every earlier dense output of this
                    # column concatenated with the matching encoder feature.
                    dense_l_i = depth_idx + layer_idx
                    cat_features = [dense_x[f"x_{idx}_{dense_l_i}"] for idx in range(depth_idx + 1, dense_l_i + 1)]
                    cat_features = torch.cat(cat_features + [features[dense_l_i + 1]], dim=1)
                    dense_x[f"x_{depth_idx}_{dense_l_i}"] = self.blocks[f"x_{depth_idx}_{dense_l_i}"](
                        dense_x[f"x_{depth_idx}_{dense_l_i-1}"], cat_features
                    )
        dense_x[f"x_{0}_{self.depth}"] = self.blocks[f"x_{0}_{self.depth}"](dense_x[f"x_{0}_{self.depth-1}"])
        # FPN branch: ASPP feature followed by the x_0_* outputs, passed in reversed order.
        fpn_input = [features[0]]+[dense_x[f"x_{0}_{i}"] for i in range(self.depth)]
        fpn_out = self.fpn(*fpn_input[::-1])
        fpn_out = F.interpolate(fpn_out,scale_factor=4,mode='bilinear')
        return torch.cat((dense_x[f"x_{0}_{self.depth}"], fpn_out), dim=1)

class UnetPlusPlus_with_ASPP_FPN(SegmentationModel):
    """Segmentation model: encoder + ``UnetPlusPlus_with_ASPP_FPN_Decoder`` + heads.

    The segmentation head receives ``decoder_channels[-1] + 4 * segmentation_channels``
    input channels. A classification head on the deepest encoder feature is
    added only when ``aux_params`` is given.
    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 5,
        encoder_weights: Optional[str] = "imagenet",
        decoder_use_batchnorm: bool = True,
        decoder_channels: List[int] = (256, 128, 64, 32, 16),
        decoder_attention_type: Optional[str] = None,
        segmentation_channels: int = 32,
        in_channels: int = 3,
        classes: int = 2,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
        atrous_rates: Tuple = (6, 12 ,18),
    ):
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )
        encoder_out_channels = self.encoder.out_channels
        deepest_channels = encoder_out_channels[-1]

        # The ASPP block keeps the channel count of the deepest encoder feature.
        self.decoder = UnetPlusPlus_with_ASPP_FPN_Decoder(
            encoder_channels=encoder_out_channels,
            decoder_channels=decoder_channels,
            aspp_out_channels=deepest_channels,
            atrous_rates=atrous_rates,
            segmentation_channels=segmentation_channels,
            n_blocks=encoder_depth,
            use_batchnorm=decoder_use_batchnorm,
            center=encoder_name.startswith("vgg"),
            attention_type=decoder_attention_type,
        )

        head_in_channels = decoder_channels[-1] + segmentation_channels * 4
        self.segmentation_head = SegmentationHead(
            in_channels=head_in_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        self.classification_head = (
            None
            if aux_params is None
            else ClassificationHead(in_channels=deepest_channels, **aux_params)
        )

        self.name = "unetplusplus-with-aspp-fpn-{}".format(encoder_name)
        self.initialize()

class LitModule(pl.LightningModule):
    """Lightning wrapper around the segmentation network built from ``cfg.model``.

    Every step (train / val / test / predict) goes through ``_step``, which
    reads ``batch["image"]`` and ``batch["mask"]``, so batches without a mask
    cannot be processed by these hooks.
    """

    def __init__(
        self,
        cfg,
    ):
        super().__init__()

        self.batch_size = cfg.data.batch_size
        self.cfg = cfg
        self.cfg_optimizer = self.cfg.train.optimizer
        self.cfg_scheduler = self.cfg.train.scheduler
        # Writes the epoch count into the scheduler section of the passed-in config.
        self.cfg_scheduler.epochs = cfg.train.epochs
        self.learning_rate = self.cfg.train.optimizer.learning_rate
        self.weight_decay = self.cfg.train.optimizer.weight_decay

        self.save_hyperparameters()

        self.model = build_from_config(cfg.model)

        self.loss_fn = self._init_loss_fn()

        # self.dice_soft, self.dice_th = self._init_metric_fn()

    def _init_loss_fn(self):
        """Return the training loss: ``SymmetricLovaszLoss`` in "binary" mode."""
        return SymmetricLovaszLoss("binary")

    # def _init_metric_fn(self):
    #     return Dice_soft(), Dice_threshold()

    def configure_optimizers(self):
        """Adam plus a StepLR schedule stepped once per epoch, both configured from ``cfg.train``."""
        # Setup the optimizer
        optimizer = torch.optim.Adam(
            params=self.parameters(), lr=self.cfg_optimizer.learning_rate, weight_decay=self.cfg_optimizer.weight_decay)

        # Setup the scheduler
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=self.cfg_scheduler.step_size,
                                                    gamma=self.cfg_scheduler.gamma)

        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on ``images`` and return its raw output."""
        return self.model(images)

    def training_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        return self._step(batch, "train")

    def validation_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        return self._step(batch, "val")

    def test_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        return self._step(batch, "test")

    def predict_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        # Same path as test_step: returns the loss, not the predictions.
        return self._step(batch, "test")

    def _step(self, batch: Dict[str, torch.Tensor], step: str) -> torch.Tensor:
        """Compute the loss and Dice metrics for one batch and log them.

        Args:
            batch: mapping with an ``"image"`` and a ``"mask"`` tensor.
            step: prefix for the logged metric names ("train", "val" or "test").

        Returns:
            The loss tensor.
        """
        images, masks = batch["image"].float(), batch["mask"].int()
        outputs = self(images)

        loss = self.loss_fn(outputs, masks)
        # dice_soft = self.dice_soft(outputs, masks)
        # dice_th = self.dice_th(outputs, masks)
        dice_soft = Dice_soft_func(outputs, masks)
        dice_th, best_th = Dice_threshold_func(outputs, masks)

        self.log(f"{step}_loss", loss, sync_dist=True)
        self.log(f"{step}_dice_soft", dice_soft, sync_dist=True)
        self.log(f"{step}_dice_th", dice_th, sync_dist=True)

        # The threshold returned alongside dice_th is logged for the test step only.
        if step == "test":
            self.log(f"{step}_best_th", best_th)

        return loss

    @classmethod
    def load_eval_checkpoint(cls, checkpoint_path: str, device: str) -> nn.Module:
        """Load a Lightning checkpoint, move the module to ``device`` and switch it to eval mode."""
        module = cls.load_from_checkpoint(checkpoint_path=checkpoint_path).to(device)
        module.eval()

        return module

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        # Plain pass-through to the parent implementation.
        return super().load_state_dict(state_dict, strict)

In [None]:
# Build the run configuration from the default YAML file (non-recursive load).
cfg = EasyConfig()
cfg.load(CONFIG_YAML_PATH)

In [None]:
# data_module = LitDataModule(
#     data_frame=pd.read_csv("./input/hubmap-organ-segmentation/test.csv"),
#     spatial_size=512,
#     batch_size=1,
#     num_workers=1,
#     )

# data_module.setup()

module = LitModule(cfg).to(device)
# map_location remaps the stored tensors onto `device`, so the checkpoint also
# loads when the device it was saved from does not exist on this machine.
module.load_state_dict(
    torch.load('./working/UnetPlusPlus_with_ASPP_FPN_efficientnet-b6_512_0.pth', map_location=device)
)
module.eval()

# test_dl = data_module.test_dataloader()

In [None]:
class EasyDataset(Dataset):
    """Loads test images with OpenCV and returns them at a fixed square size.

    Every sample is a dict with:
        image: CHW tensor of side ``limit_size`` for large images, or of side
            ``2 * size`` for small ones (the image resized to ``size`` and
            duplicated into a 2x2 grid). With the defaults both are 1024.
        h, w: height and width of the image on disk.
        flag: -1 when the image was smaller than ``limit_size`` along either
            axis (duplicated grid), 1 otherwise.

    Channels stay in the BGR order returned by ``cv2.imread``.
    NOTE(review): confirm this matches the channel order used for training.
    """

    def __init__(self, dataframe: pd.DataFrame, size=512, limit_size=512*2):
        self.dataframe = dataframe
        self.size = size
        self.limit_size = limit_size

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        image_id = self.dataframe.loc[idx, "id"]
        path = os.path.join(DATA, f'{image_id}.tiff')
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"could not read image: {path}")
        # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[0], img.shape[1]
        if (h < self.limit_size) or (w < self.limit_size):
            flag = -1
            img = cv2.resize(img, (self.size,self.size))
            # hstack / vstack take ONE sequence of arrays; build the 2x2 grid and
            # keep it, so small images have the same layout as large ones.
            img = np.vstack([np.hstack([img, img]), np.hstack([img, img])])
        else:
            flag = 1
            img = cv2.resize(img, (self.limit_size,self.limit_size))
        img = ToTensorV2().apply(img)

        return {"image":img, "h":h, "w":w, "flag": flag}

In [None]:
# Training metadata; the preview cells read run-length masks from its "rle" column.
train_df = pd.read_csv(TRAIN_CSV)

In [None]:
# Preview: image 62.tiff with the mask decoded from train_df row 302 on top.
# cv2.imread returns BGR channels while matplotlib expects RGB, so convert for display.
img = cv2.cvtColor(
    cv2.imread("./input/hubmap-organ-segmentation/test_images/62.tiff"),
    cv2.COLOR_BGR2RGB,
)
# NOTE(review): (3000, 3000) is a hard-coded (width, height) — confirm it matches this image.
tmp_mask = rle2mask(train_df.loc[302, "rle"], (3000,3000))
plt.imshow(img)
plt.imshow(tmp_mask, alpha=0.2)

In [None]:
# Preview: image 10044.tiff with the mask decoded from train_df row 0 on top.
# cv2.imread returns BGR channels while matplotlib expects RGB, so convert for display.
img = cv2.cvtColor(
    cv2.imread("./input/hubmap-organ-segmentation/test_images/10044.tiff"),
    cv2.COLOR_BGR2RGB,
)
# NOTE(review): (3000, 3000) is a hard-coded (width, height) — confirm it matches this image.
tmp_mask = rle2mask(train_df.loc[0, "rle"], (3000,3000))
plt.imshow(img)
plt.imshow(tmp_mask, alpha=0.2)

In [None]:
# One image per batch, in file order, served by a single worker process.
test_ds = EasyDataset(dataframe=pd.read_csv(TEST_CSV))
test_dl = DataLoader(dataset=test_ds, batch_size=1, shuffle=False, num_workers=1)

In [None]:
import gc
names,preds = [],[]

tmp_pd = pd.read_csv(TEST_CSV)

TH = 0.65  # probability threshold for a positive pixel (replaces the earlier value)
index = 0
# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    for i, batch in enumerate(test_dl):
        # NOTE(review): only the first three images are processed, so the
        # submission contains at most three rows — confirm this limit is intended.
        if index == 3:
            break
        h, w = int(batch["h"][0]), int(batch["w"][0])
        # flag == -1 marks a small image that the dataset duplicated into a 2x2 grid.
        duplicated = int(batch["flag"][0]) == -1

        # 1024x1024 HWC image -> four 512x512 tiles in row-major order.
        full_image = batch["image"].squeeze(0).permute(1,2,0)
        small_batch = full_image.reshape(2,512,2,512,3).permute(0,2,1,3,4).contiguous().view(-1, 512, 512,3)

        mask = torch.sigmoid(module(small_batch.permute(0,3,1,2).float().to(device)))
        # Stitch the four tile predictions back into one 1024x1024 probability map.
        mask = mask.view(2,2,512,512,1).permute(0,2,1,3,4).reshape(1024,1024)
        one = full_image
        if duplicated:
            # All four quadrants show the same image; keep a single copy.
            mask = mask[:512, :512]
            one = one[:512, :512]

        # cv2.resize expects dsize as (width, height).
        dsize = (w, h)
        one = cv2.resize(one.contiguous().numpy(), dsize)
        # Threshold first, then resize with nearest-neighbour so the mask stays
        # strictly binary; interpolated values would create spurious RLE runs.
        mask = cv2.resize((mask > TH).cpu().numpy().astype(np.uint8), dsize, interpolation=cv2.INTER_NEAREST)

        # One figure per image; the dataset delivers BGR, matplotlib expects RGB.
        plt.figure()
        plt.imshow(cv2.cvtColor(one, cv2.COLOR_BGR2RGB), vmin=0, vmax=255)
        plt.imshow(mask, alpha=0.2)

        rle = mask2rle(mask)
        names.append(tmp_pd.loc[i, "id"])
        preds.append(rle)

        del mask, small_batch, one, full_image
        gc.collect()
        index += 1

In [None]:
# Assemble the collected ids and run-length masks and write them to
# submission.csv in the current working directory, without the index column.
df = pd.DataFrame({'id':names,'rle':preds})
df.to_csv('submission.csv',index=False)