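## summary

* 3D segmentation
    *  encoder: 3D ResNet (`resnet3d.generate_model`, depth 18 or 34); its feature maps are averaged over the slice axis
    *  decoder: small 2D U-Net-style decoder (`segmentation_models_pytorch` is imported but not used)
* use only 20 slices (surface-volume indices 20–39)
* slide inference with non-overlapping 256 px and 512 px tiles
* ensemble of 12 checkpoints (3 folds × 2 depths × 2 tile sizes) combined by majority vote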

In [None]:

from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, log_loss
import pickle
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import warnings
import sys
import pandas as pd
import os
import gc
import sys
import math
import time
import random
import shutil
from pathlib import Path
from contextlib import contextmanager
from collections import defaultdict, Counter
import cv2

import scipy as sp
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from functools import partial

import argparse
import importlib
import torch
import torch.nn as nn
from torch.optim import Adam, SGD, AdamW
import torch.nn.functional as F


import datetime

In [None]:
sys.path.append('/kaggle/input/pretrainedmodels/pretrainedmodels-0.7.4')
sys.path.append('/kaggle/input/efficientnet-pytorch/EfficientNet-PyTorch-master')
sys.path.append('/kaggle/input/timm-pytorch-image-models/pytorch-image-models-master')
sys.path.append('/kaggle/input/segmentation-models-pytorch/segmentation_models.pytorch-master')


import segmentation_models_pytorch as smp

In [None]:
import numpy as np
from torch.utils.data import DataLoader, Dataset
import cv2
import torch
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

In [None]:
sys.path.append("/kaggle/input/resnet3d")

from resnet3d import generate_model

In [None]:
# fragment, model, threshold, image_size

## config

In [None]:
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2

# 'test' -> predict every fragment under <dataset>/test/ for submission;
# any other value -> only fragment 3 of that split is processed (see fragment_ids cell).
# read_image / the main loop read this module-level name directly.
mode = 'test'


class CFG:
    """Inference configuration shared (and partly mutated) by the cells below."""

    # ============== paths =============
    # NOTE(review): assumed Kaggle mount point of the competition data — confirm.
    comp_name = 'vesuvius'
    comp_dir_path = '/kaggle/input/'
    comp_folder_name = 'vesuvius-challenge-ink-detection'
    comp_dataset_path = f'{comp_dir_path}{comp_folder_name}/'

    # NOTE(review): not read in this notebook; read_image loads 20 slices (20..39).
    in_chans = 16 # 65
    # ============== tiling / inference cfg =============
    # size / tile_size / stride are overwritten per tile size in the main loop.
    size = 256
    tile_size = 256
    stride = tile_size // 2

    batch_size = 24
    num_workers = 2

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Inference only converts HWC numpy tiles to CHW tensors.
    valid_aug_list = [
        ToTensorV2(transpose_mask=True),
    ]


In [None]:
# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# To force CPU inference: device = torch.device('cpu')

## helper

In [None]:
# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle(img):
    '''
    Run-length encode a binary mask, scanning pixels in row-major order.

    img: numpy array, 1 - mask, 0 - background
    Returns a space-separated string of "start length" pairs; starts are 1-based.
    '''
    # Zero sentinels on both ends guarantee each run has an opening and a closing edge.
    padded = np.concatenate([[0], img.flatten(), [0]])
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Even slots hold run starts; turn every odd slot (run end) into a run length.
    edges[1::2] -= edges[::2]
    return ' '.join(map(str, edges))

## dataset

In [None]:
def read_image(fragment_id, start=20, end=40):
    """Load surface-volume slices ``start .. end-1`` of a fragment as an (H, W, C) float32 stack.

    Each slice is read in grayscale, min-max normalised to [0, 1] (left unscaled
    when it is constant) and zero-padded at the bottom/right by
    ``CFG.tile_size - size % CFG.tile_size`` pixels. This padding rule must match
    the one used for the mask in the main loop.

    The path is ``CFG.comp_dataset_path + f"{mode}/{fragment_id}/surface_volume/NN.tif"``,
    so the module-level ``mode`` must be defined.

    Raises:
        ValueError: if the slice range is empty.
        FileNotFoundError: if OpenCV cannot read a slice.
    """
    if end <= start:
        raise ValueError(f"empty slice range [{start}, {end})")

    images = []
    for i in tqdm(range(start, end)):
        path = CFG.comp_dataset_path + f"{mode}/{fragment_id}/surface_volume/{i:02}.tif"
        image = cv2.imread(path, 0)
        if image is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(path)

        # float32 halves memory vs float64 and matches the model's weight dtype
        # (a float64 batch would be rejected by the float32 conv layers).
        image = image.astype(np.float32)
        lo, hi = image.min(), image.max()
        if abs(hi - lo) > 1e-3:
            image = (image - lo) / (hi - lo)

        # Adds a full extra tile when a side is already a multiple of tile_size;
        # kept as-is so it agrees with the mask padding in the main loop.
        pad0 = CFG.tile_size - image.shape[0] % CFG.tile_size
        pad1 = CFG.tile_size - image.shape[1] % CFG.tile_size
        image = np.pad(image, [(0, pad0), (0, pad1)], constant_values=0)

        images.append(image)

    return np.stack(images, axis=2)

In [None]:
def get_transforms(data, cfg):
    """Build the albumentations pipeline for one split.

    Args:
        data: ``'train'`` or ``'valid'``.
        cfg: config object exposing ``train_aug_list`` / ``valid_aug_list``.

    Returns:
        ``A.Compose`` over the matching augmentation list.

    Raises:
        ValueError: for any other ``data`` value (previously an UnboundLocalError).
    """
    if data == 'train':
        return A.Compose(cfg.train_aug_list)
    if data == 'valid':
        return A.Compose(cfg.valid_aug_list)
    raise ValueError(f"unknown transform split {data!r}; expected 'train' or 'valid'")

class CustomDataset(Dataset):
    """Map-style dataset over pre-cut image tiles.

    Args:
        images: sequence of tiles; ``len(images)`` is the dataset length.
        cfg: config object (stored, not read here).
        labels: optional labels (stored, not returned by ``__getitem__``).
        transform: albumentations-style callable used as ``transform(image=tile)['image']``;
            when ``None`` the tile is returned unchanged.
    """

    def __init__(self, images, cfg, labels=None, transform=None):
        self.images = images
        self.cfg = cfg
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        # transform defaults to None, so guard before calling it.
        if self.transform is not None:
            image = self.transform(image=image)['image']
        return image


In [None]:
def make_test_dataset(fragment_id):
    """Cut a fragment into square tiles and wrap them in an ordered DataLoader.

    Tiles are ``CFG.tile_size`` pixels wide, taken every ``CFG.stride`` pixels,
    row by row (y outer, x inner).

    Returns:
        (loader, xyxys): ``loader`` yields the tiles in that order (no shuffling);
        ``xyxys`` is an (N, 4) array of ``(x1, y1, x2, y2)`` per tile, same order.
    """
    volume = read_image(fragment_id)
    height, width = volume.shape[0], volume.shape[1]
    tile, step = CFG.tile_size, CFG.stride

    coords = [
        (x1, y1, x1 + tile, y1 + tile)
        for y1 in range(0, height - tile + 1, step)
        for x1 in range(0, width - tile + 1, step)
    ]
    crops = [volume[y1:y2, x1:x2] for (x1, y1, x2, y2) in coords]
    xyxys = np.stack(coords)

    dataset = CustomDataset(crops, CFG, transform=get_transforms(data='valid', cfg=CFG))
    loader = DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=CFG.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    return loader, xyxys

## model

In [None]:
class Decoder(nn.Module):
    """Top-down 2D decoder that fuses a feature pyramid into a 1-channel logit map.

    Args:
        encoder_dims: channel count of each feature map, finest level first. Level
            ``i`` must have half the spatial size of level ``i-1`` (it is upsampled
            x2 and concatenated with it).
        upscale: bilinear upsampling factor applied to the final logits.
    """

    def __init__(self, encoder_dims, upscale):
        super().__init__()
        # convs[i-1] fuses (level i-1 ++ upsampled level i) down to encoder_dims[i-1]
        # channels. Attribute names form the checkpoint's state_dict keys; keep them.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(encoder_dims[i]+encoder_dims[i-1], encoder_dims[i-1], 3, 1, 1, bias=False),
                nn.BatchNorm2d(encoder_dims[i-1]),
                nn.ReLU(inplace=True)
            ) for i in range(1, len(encoder_dims))])

        self.logit = nn.Conv2d(encoder_dims[0], 1, 1, 1, 0)
        # align_corners=False is the bilinear default; spelled out to avoid the warning.
        self.up = nn.Upsample(scale_factor=upscale, mode="bilinear", align_corners=False)

    def forward(self, feature_maps):
        """Decode ``feature_maps`` (finest first) into logits of shape (B, 1, H*upscale, W*upscale)."""
        # Work on a copy so the caller's list is not overwritten with decoded maps.
        feature_maps = list(feature_maps)
        for i in range(len(feature_maps) - 1, 0, -1):
            f_up = F.interpolate(feature_maps[i], scale_factor=2, mode="bilinear", align_corners=False)
            f = torch.cat([feature_maps[i - 1], f_up], dim=1)
            feature_maps[i - 1] = self.convs[i - 1](f)

        x = self.logit(feature_maps[0])
        return self.up(x)


class SegModel(nn.Module):
    """3D-ResNet encoder + 2D ``Decoder`` producing a 1-channel logit mask.

    The encoder is built with ``generate_model(model_depth=..., n_input_channels=1)``.
    Every feature map it returns is averaged over dim 2 before decoding. Presumably
    that is the slice/depth axis of 5D (B, C, D, H, W) features — TODO confirm
    against resnet3d. The decoder expects channel counts [64, 128, 256, 512].
    """

    def __init__(self, model_depth):
        super().__init__()
        self.encoder = generate_model(model_depth=model_depth, n_input_channels=1)
        self.decoder = Decoder(encoder_dims=[64, 128, 256, 512], upscale=4)

    def forward(self, x):
        feat_maps = self.encoder(x)
        feat_maps_pooled = [torch.mean(f, dim=2) for f in feat_maps]
        pred_mask = self.decoder(feat_maps_pooled)
        return pred_mask

    def load_pretrained_weights(self, state_dict):
        """Load encoder weights, summing ``conv1.weight`` over its input channels.

        The caller's ``state_dict`` is left untouched. The result of
        ``load_state_dict(..., strict=False)`` is printed and returned.
        """
        from collections import OrderedDict

        # Shallow copy (keeping torch's _metadata) so the caller's dict is not mutated.
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            state_dict._metadata = metadata

        # Convert 3 channel weights to single channel
        # ref - https://timm.fast.ai/models#Case-1:-When-the-number-of-input-channels-is-1
        conv1_weight = state_dict['conv1.weight']
        state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
        result = self.encoder.load_state_dict(state_dict, strict=False)
        print(result)
        return result

In [None]:
# class EnsembleModel:
#     def __init__(self):
#         self.models = []
#         self.thresholds = []
#         self.image_sizes = []

#     def __call__(self, x):

#         img_wh = x.shape[1]
        
#         outputs = []
#         for model, threshold in zip(self.models, self.thresholds):
#             for img_size in self.image_sizes:
#                 if img_wh == img_size:
#                     out = (torch.sigmoid(model(x)).to('cpu').numpy()>=threshold).astype(np.float32)
#                     outputs.append(out)
        
#         avg_preds = np.mean(outputs, axis=0)
        
#         return avg_preds

#     def add_model(self, model, threshold, image_size):
#         self.models.append(model)
#         self.thresholds.append(threshold)
#         self.image_sizes.append(image_size)

# def build_ensemble_model():
#     model = EnsembleModel()
    
#     # "bce-dice-models/fold-1/depth-18/image-256/ckpts/ # resnet34_3d_seg_best_0.900.pt
#     folds = ["fold-1", "fold-2", "fold-3"]
#     depths = ["depth-18", "depth-34"]
#     images = ["image-256", "image-512"]
#     for fold in folds:
#         for depth in depths:
#             for image in images:
#                 dir_path = f"bce-dice-models/{fold}/{depth}/{image}/ckpts/"
#                 # get .pt model from the dir_path
#                 model_path = dir_path + [file for file in os.listdir(dir_path) if file.endswith(".pt")][0]
#                 # get the model threshold: it is of the format resnet34_3d_seg_best_0.900.pt
#                 model_threshold = float(model_path.split("_")[-1].split(".")[0])
#                 model_path = dir_path + model_path
                
#                 model.add_model(SegModel().to(CFG.device).eval(), 0.5)
    
    
#     thresholds = np.array([0.45, 0.4, 0.5]) + 0.2
    
#     for fold in range(len(model_paths)):
#         # _model = build_model(CFG, weight=None)
#         _model = SegModel()
#         model_path = model_paths[fold]
#         _model.load_state_dict(torch.load(model_path))
#         _model.to(device)
#         _model.eval()
#         threshold = thresholds[fold]
        
#         model.add_model(_model, threshold)
    
#     return model

In [None]:
# Test mode: every fragment directory under <dataset>/test/, in sorted order.
# Otherwise: only fragment 3.
fragment_ids = (
    sorted(os.listdir(CFG.comp_dataset_path + mode))
    if mode == 'test'
    else [3]
)

In [None]:
# model = build_ensemble_model()

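The trained checkpoints are laid out as shown below. The main loop loads every checkpoint in this tree: 3 folds × 2 ResNet depths × 2 tile sizes. Each model's prediction is binarized with the threshold encoded in its file name (e.g. `..._best_0.900.pt` → 0.9), and the 12 binary masks are combined by majority vote.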

```bash
├── fold-1
│   ├── depth-18
│   │   ├── image-256
│   │   │   ├── ckpts
│   │   │   │   └── resnet34_3d_seg_best_0.900.pt
│   │   │   ├── vesuvius-challenge-3d-resnet-training-step-2-with-augm.ipynb
│   │   └── image-512
│   │       ├── ckpts
│   │       │   └── resnet34_3d_seg_best_0.900.pt
│   └── depth-34
│       ├── image-256
│       │   ├── ckpts
│       │   │   └── resnet34_3d_seg_best_0.800.pt
│       └── image-512
│           ├── ckpts
│           │   └── resnet34_3d_seg_best_0.800.pt
├── fold-2
│   ├── depth-18
│   │   ├── image-256
│   │   │   ├── ckpts
│   │   │   │   └── resnet34_3d_seg_best_0.200.pt
│   │   └── image-512
│   │       ├── ckpts
│   │       │   └── resnet34_3d_seg_best_0.200.pt
│   └── depth-34
│       ├── image-256
│       │   ├── ckpts
│       │   │   └── resnet34_3d_seg_best_0.200.pt
│       └── image-512
│           ├── ckpts
│           │   └── resnet34_3d_seg_best_0.200.pt
├── fold-3
│   ├── depth-18
│   │   ├── image-256
│   │   │   ├── ckpts
│   │   │   │   └── resnet34_3d_seg_best_0.850.pt
│   │   └── image-512
│   │       ├── ckpts
│   │       │   └── resnet34_3d_seg_best_0.650.pt
│   └── depth-34
│       ├── image-256
│       │   ├── ckpts
│       │   │   └── resnet34_3d_seg_best_0.800.pt
│       └── image-512
│           ├── ckpts
│           │   └── resnet34_3d_seg_best_0.800.pt
└── prepare-experiment.sh
```


## main

In [None]:
# Ensemble inference: for each fragment, run every checkpoint under model_root,
# binarize each model's probability map with its own threshold, and majority-vote.
results = []

folds = ["fold-1", "fold-2", "fold-3"]
depths = ["depth-18", "depth-34"]
image_sizes = ["image-256", "image-512"]
model_root = "bce-dice-models"
# Tile sizes that need a loader, parsed from the "image-<px>" directory names.
tile_sizes = sorted({int(s.split("-")[-1]) for s in image_sizes})
# A pixel is ink when strictly more than 8/12 of the models vote ink
# (i.e. at least 9 of the 12 checkpoints).
final_threshold = 8. / 12


def find_checkpoint(dir_path):
    """Return ``(path, threshold)`` for the first ``.pt`` file (sorted by name) in ``dir_path``.

    The threshold is encoded in the file name after the last underscore,
    e.g. ``resnet34_3d_seg_best_0.900.pt`` -> ``0.9``.

    Raises:
        FileNotFoundError: if ``dir_path`` contains no ``.pt`` file.
    """
    ckpt_names = sorted(f for f in os.listdir(dir_path) if f.endswith(".pt"))
    if not ckpt_names:
        raise FileNotFoundError(f"no .pt checkpoint found in {dir_path}")
    ckpt_name = ckpt_names[0]
    # Path.stem drops ".pt" so the decimal point of the threshold survives.
    threshold = float(Path(ckpt_name).stem.rsplit("_", 1)[-1])
    return os.path.join(dir_path, ckpt_name), threshold


def predict_probability_map(model, test_loader, xyxys, padded_shape):
    """Run ``model`` over every tile and stitch the sigmoid outputs into one map.

    Tiles must come out of ``test_loader`` in the same order as the rows of
    ``xyxys`` (the loader is built with ``shuffle=False``). Overlapping tiles
    are averaged; pixels covered by no tile stay 0.
    """
    mask_pred = np.zeros(padded_shape)
    mask_count = np.zeros(padded_shape)
    offset = 0
    for images in tqdm(test_loader, total=len(test_loader)):
        # Add a singleton channel axis for the 3D encoder; .float() guards against float64 tiles.
        images = images.to(device).float().unsqueeze(1)
        with torch.no_grad():
            y_preds = torch.sigmoid(model(images))
        # Bring predictions to host memory before mixing them with numpy arrays.
        y_preds = y_preds.squeeze(1).cpu().numpy()
        batch_xyxys = xyxys[offset:offset + len(y_preds)]
        for pred, (x1, y1, x2, y2) in zip(y_preds, batch_xyxys):
            mask_pred[y1:y2, x1:x2] += pred
            mask_count[y1:y2, x1:x2] += 1
        offset += len(y_preds)
    return np.divide(mask_pred, mask_count, out=np.zeros_like(mask_pred), where=mask_count > 0)


for fragment_id in fragment_ids:
    # One tiling per input size, shared by every model trained on that size.
    # stride == tile size, so tiles do not overlap.
    loaders = {}
    for tile in tile_sizes:
        CFG.size = CFG.tile_size = CFG.stride = tile
        loaders[tile] = make_test_dataset(fragment_id)

    mask_path = CFG.comp_dataset_path + f"{mode}/{fragment_id}/mask.png"
    binary_mask = cv2.imread(mask_path, 0)
    if binary_mask is None:
        raise FileNotFoundError(mask_path)
    binary_mask = (binary_mask / 255).astype(int)
    ori_h, ori_w = binary_mask.shape

    # Running vote count instead of keeping one full-size mask per model.
    vote_sum = np.zeros((ori_h, ori_w), dtype=np.int64)
    n_models = 0
    for fold in folds:
        for depth in depths:
            for image_size in image_sizes:
                dir_path = f"{model_root}/{fold}/{depth}/{image_size}/ckpts/"
                model_path, model_threshold = find_checkpoint(dir_path)

                image_wh = int(image_size.split("-")[-1])
                CFG.size = CFG.tile_size = CFG.stride = image_wh
                test_loader, xyxys = loaders[image_wh]

                model_depth = int(depth.split("-")[-1])
                model = SegModel(model_depth)
                # map_location lets GPU-saved checkpoints load on a CPU-only machine.
                model.load_state_dict(torch.load(model_path, map_location=device))
                model = model.to(device).eval()

                # Same padding rule as read_image, so the map lines up with the padded volume.
                padded_shape = (ori_h + image_wh - ori_h % image_wh,
                                ori_w + image_wh - ori_w % image_wh)
                mask_pred = predict_probability_map(model, test_loader, xyxys, padded_shape)
                mask_pred = mask_pred[:ori_h, :ori_w]

                # Binary vote of this model, restricted to the fragment mask.
                vote_sum += (mask_pred >= model_threshold) * binary_mask
                n_models += 1

                del model, mask_pred
                torch.cuda.empty_cache()

    if n_models == 0:
        raise RuntimeError("no models were evaluated; check folds/depths/image_sizes")

    final_mask_pred = (vote_sum / n_models > final_threshold).astype(int)

    plt.imshow(final_mask_pred)
    plt.show()

    inklabels_rle = rle(final_mask_pred)
    results.append((fragment_id, inklabels_rle))

    del final_mask_pred, vote_sum, loaders, test_loader, xyxys
    gc.collect()
    torch.cuda.empty_cache()


## submission

In [None]:
# One row per processed fragment: its Id and the RLE of the predicted ink mask.
sub = pd.DataFrame(data=results, columns=['Id', 'Predicted'])

In [None]:
# Preview the (Id, Predicted) table built from `results`.
sub

In [None]:
sample_sub = pd.read_csv(CFG.comp_dataset_path + 'sample_submission.csv')
# Left merge keeps every Id of the sample file in its original order; Ids without
# a prediction get NaN. Both key columns are cast to str because fragment ids are
# ints outside test mode (fragment_ids = [3]), and pandas refuses to merge
# object keys with int64 keys.
sample_sub = pd.merge(
    sample_sub[['Id']].astype({'Id': str}),
    sub.astype({'Id': str}),
    on='Id',
    how='left',
)

In [None]:
# Final submission frame: rows follow the Id order of sample_submission.csv.
sample_sub

In [None]:
# Write the submission file without the DataFrame index column.
sample_sub.to_csv(path_or_buf="submission.csv", index=False)