# Preprocessing data for [Stage 2 Type 1](https://www.kaggle.com/code/haqishen/rsna-2022-1st-place-solution-train-stage2-type1)

All credit goes to the original author [Qishen Ha](https://www.kaggle.com/code/haqishen/rsna-2022-1st-place-solution-train-stage2-type1)


# Mount dataset onto Google Colab

In [None]:
# Authenticate this Colab session with a Google account
# (presumably required by the gcsfuse bucket mount in a later cell — confirm).
from google.colab import auth
auth.authenticate_user()

In [None]:
# Register the gcsfuse apt repository for this Ubuntu release (`lsb_release -c -s`)
# and add its signing key.
!echo "deb http://packages.cloud.google.com/apt gcsfuse-`lsb_release -c -s` main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
# Refresh the package index, then install gcsfuse itself.
!apt-get -y -q update
!apt-get -y -q install gcsfuse

In [None]:
# Mount the bucket at ./tmp (later cells read the data from '/content/tmp').
# "ADD_GCS_PATH" is a placeholder: replace it with the real bucket name before running.
!mkdir -p tmp
!gcsfuse --implicit-dirs --limit-bytes-per-sec -1 --limit-ops-per-sec -1 "ADD_GCS_PATH" tmp

In [None]:
!mkdir ~/.kaggle && mv kaggle.json ~/.kaggle/

In [None]:
### Extra files
# Two helper datasets pulled with the Kaggle CLI. Presumably the first provides the
# `conv3d_same` module imported later and the second the wheels installed below — confirm.
! kaggle datasets download -d boliu0/covn3d-same
! kaggle datasets download -d haqishen/pylibjpeg140py3

In [None]:
# Extract both downloaded archives into the current directory.
!unzip covn3d-same.zip && unzip pylibjpeg140py3.zip

In [None]:
# Python dependencies: MONAI, segmentation-models-pytorch (pinned), and the DICOM stack.
!pip -q install monai
!pip -q install segmentation-models-pytorch==0.2.1
# Local wheels (expected in the current directory).
!pip -q install pylibjpeg-1.4.0-py3-none-any.whl
# NOTE(review): this wheel is tagged cp37, so it only installs on a Python 3.7 runtime.
!pip -q install python_gdcm-3.0.17.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
!pip -q install pydicom

In [None]:
DEBUG = False

import os
import sys

# Give the unpacked covn3d-same folder top priority on the import path.
sys.path = ['/content/covn3d-same', *sys.path]

In [None]:
import os
import sys
import gc
import ast
import cv2
import time
import timm
import pickle
import random
import pydicom
import argparse
import warnings
import threading
import numpy as np
import pandas as pd
from glob import glob
import nibabel as nib
from PIL import Image
from tqdm import tqdm
import albumentations
from pylab import rcParams
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from sklearn.model_selection import KFold, StratifiedKFold

import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from monai.transforms import Resize
import  monai.transforms as transforms

# Render matplotlib figures inline and use a wide default figure size.
%matplotlib inline
rcParams['figure.figsize'] = 20, 8
# Device that later cells move the models and input tensors to.
device = torch.device('cuda')
# Enable cuDNN's benchmark mode (auto-tunes convolution algorithms).
torch.backends.cudnn.benchmark = True

In [None]:
# Root of the competition data (presumably the gcsfuse mount point created earlier).
data_dir = '/content/tmp'
# Segmentation input size: load_dicom resizes each slice to the first two entries,
# load_dicom_line_par samples image_size_seg[2] slices per study.
image_size_seg = (128, 128, 128)
# Edge length used by load_bone to rescale mask coordinates to image/slice coordinates.
msk_size = image_size_seg[0]
# Side length of every exported crop.
image_size_cls = 224
# Number of slices exported per vertebra.
n_slice_per_c = 15
# Number of neighbouring slices stacked as channels around each sampled slice.
n_ch = 5

# Passed to timm.create_model by TimmSegModel.
drop_rate = 0.
drop_path_rate = 0.
# NOTE(review): loss_weights, p_mixup and batch_size are not referenced elsewhere
# in the visible part of this notebook.
loss_weights = [1, 1]
p_mixup = 0.1
# Batch size / worker count of the segmentation DataLoader.
batch_size_seg = 1
batch_size = 4
num_workers = 2

# Number of decoder blocks / encoder stages used by TimmSegModel.
n_blocks = 4
n_folds = 5
backbone = 'resnet18d'
# Output channels of the segmentation head.
out_dim = 7

# Destination folder for the exported .npy slices.
out_dir = './preprocess_stage2_type1'
os.makedirs(out_dir, exist_ok=True)

In [None]:
def load_dicom(path):
    """Read one DICOM file and resize its pixels to the segmentation input size.

    Args:
        path: Path of a single DICOM file.

    Returns:
        The file's pixel array resized with area interpolation to
        ``image_size_seg[0]`` x ``image_size_seg[1]`` (passed to cv2 as width, height).
    """
    # `pydicom.read_file` is a deprecated alias of `dcmread` (removed in pydicom 3.0).
    dicom = pydicom.dcmread(path)
    target_size = (image_size_seg[0], image_size_seg[1])
    return cv2.resize(dicom.pixel_array, target_size, interpolation=cv2.INTER_AREA)


def load_dicom_line_par(path):
    """Load a study folder as a fixed-depth uint8 volume.

    The files in ``path`` are ordered by the integer prefix of their file name,
    ``image_size_seg[2]`` of them are picked at evenly spaced quantiles, each is
    read with ``load_dicom`` and the stack is min-max scaled to 0..255.

    Args:
        path: Folder holding the DICOM files of one study.

    Returns:
        ``np.uint8`` array with the sampled slices stacked along the last axis.
    """

    def slice_number(file_path):
        # "<dir>/<number>.<ext>" -> <number>
        return int(file_path.split('/')[-1].split(".")[0])

    ordered_paths = sorted(glob(os.path.join(path, "*")), key=slice_number)

    # Evenly spaced positions over the whole series (may repeat for short series).
    positions = np.quantile(
        list(range(len(ordered_paths))),
        np.linspace(0., 1., image_size_seg[2]),
    ).round().astype(int)

    volume = np.stack([load_dicom(ordered_paths[pos]) for pos in positions], -1)

    # Min-max normalise over the whole volume, then quantise to 8 bit.
    shifted = volume - np.min(volume)
    scaled = shifted / (np.max(shifted) + 1e-4)
    return (scaled * 255).astype(np.uint8)

In [None]:
# Study-level table read from train.csv.
df_train = pd.read_csv(os.path.join(data_dir, 'train.csv'))

# One row per file in the segmentations folder; the study UID is the file name
# with its last four characters (the extension) removed.
mask_files = os.listdir(f'{data_dir}/segmentations')
df_mask = pd.DataFrame({
    'mask_file': mask_files,
})
df_mask['StudyInstanceUID'] = df_mask['mask_file'].apply(lambda x: x[:-4])
df_mask['mask_file'] = df_mask['mask_file'].apply(lambda x: os.path.join(data_dir, 'segmentations', x))

# Left join: studies without a mask keep an empty mask_file.
df = df_train.merge(df_mask, on='StudyInstanceUID', how='left')
df['image_folder'] = df['StudyInstanceUID'].apply(lambda x: os.path.join(data_dir, 'train_images', x))
# Assign the result back instead of calling fillna(inplace=True) on the column
# selection: the chained in-place form is a silent no-op under pandas Copy-on-Write.
df['mask_file'] = df['mask_file'].fillna('')

# Unshuffled K-fold split over the studies, stored as a 'fold' column.
kf = KFold(n_folds)
df['fold'] = -1
for fold, (train_idx, valid_idx) in enumerate(kf.split(df)):
    df.loc[valid_idx, 'fold'] = fold

df.tail()

In [None]:
# Sanity check: table size and first rows.
print(df.shape)
df.head()

In [None]:
# Expand the study table to one row per (study, vertebra): vertebra ids 1..7,
# with the label taken from the matching C1..C7 column.
sid, cs, label, fold, image_folder = [], [], [], [], []
for _, row in df.iterrows():
    sid.extend([row.StudyInstanceUID] * 7)
    image_folder.extend([row.image_folder] * 7)
    cs.extend(range(1, 8))
    label.extend(row[f'C{i}'] for i in range(1, 8))
    fold.extend([row.fold] * 7)

df = pd.DataFrame({
    'StudyInstanceUID': sid,
    'c': cs,
    'label': label,
    'image_folder': image_folder,
    'fold': fold
})

df.tail()

# Dataset

In [None]:
class SegDataset(Dataset):
    """Dataset yielding one normalised 3-channel volume per row of ``df``."""

    def __init__(self, df):
        # reset_index() keeps the previous index as an extra column.
        self.df = df.reset_index()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        record = self.df.iloc[index]

        volume = load_dicom_line_par(record.image_folder)
        if volume.ndim < 4:
            # Add a leading channel axis.
            volume = np.expand_dims(volume, 0)
        # Replicate the single channel three times and rescale to 0..1.
        volume = np.repeat(volume.astype(np.float32), 3, axis=0)
        volume = volume / 255.

        ### using local cache
        # image_file = os.path.join('./data/train_images_npy/', f'{row.StudyInstanceUID}.npy')
        # image = np.load(image_file).astype(np.float32)

        sample = {
            "StudyInstanceUID": record.StudyInstanceUID,
            "c": record.c,
            "image": torch.tensor(volume).float(),
        }
        return sample

# Saving dataset to disk to use for caching

In [None]:
# from pathlib import Path

# data = Path('data')
# data.mkdir(exist_ok=True)

# stage2_train_image_path_data = Path('data/stage2_train_images_npy')
# stage2_train_image_path_data.mkdir(exist_ok=True)

# for step, (batch) in tqdm(enumerate(dataset_seg), total=len(dataset_seg)):
#   image = batch['image'].cpu().detach().numpy()
#   study_instance_uid = batch['StudyInstanceUID'].pop()
#   c = int(batch['c'].cpu().detach().numpy())

#   np.save(os.path.join(stage2_train_image_path_data, f'{study_instance_uid}'), image)

In [None]:
# Sequential (unshuffled) loader over every (study, vertebra) row.
dataset_seg = SegDataset(df)
loader_seg = DataLoader(
    dataset_seg,
    batch_size=batch_size_seg,
    shuffle=False,
    num_workers=num_workers,
)

# check batch
batch = dataset_seg[0]
batch['StudyInstanceUID'], batch['c'], batch['image'].shape

In [None]:
class TimmSegModel(nn.Module):
    """timm feature encoder + U-Net decoder + conv segmentation head.

    Args:
        backbone: timm model name used for the encoder.
        segtype: Decoder family; only ``'unet'`` is implemented.
        pretrained: Forwarded to ``timm.create_model``.

    Raises:
        ValueError: If ``segtype`` is not ``'unet'``. Previously this left
            ``self.decoder`` undefined and only failed later inside ``forward``.
    """

    def __init__(self, backbone, segtype='unet', pretrained=False):
        super().__init__()
        if segtype != 'unet':
            raise ValueError(f"Unsupported segtype {segtype!r}; only 'unet' is implemented")

        self.encoder = timm.create_model(
            backbone,
            in_chans=3,
            features_only=True,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            pretrained=pretrained
        )
        # Dry run on a dummy input to read the channel count of every feature map.
        feature_maps = self.encoder(torch.rand(1, 3, 64, 64))
        # The leading 1 pairs with the placeholder prepended in forward().
        encoder_channels = [1] + [fmap.shape[1] for fmap in feature_maps]
        decoder_channels = [256, 128, 64, 32, 16]
        self.decoder = smp.unet.decoder.UnetDecoder(
            encoder_channels=encoder_channels[:n_blocks+1],
            decoder_channels=decoder_channels[:n_blocks],
            n_blocks=n_blocks,
        )

        self.segmentation_head = nn.Conv2d(decoder_channels[n_blocks-1], out_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self, x):
        """Return the ``out_dim``-channel segmentation logits for ``x``."""
        # Placeholder first entry: presumably discarded by UnetDecoder as the
        # input-resolution skip — confirm against the installed smp version.
        global_features = [0] + self.encoder(x)[:n_blocks]
        seg_features = self.decoder(*global_features)
        return self.segmentation_head(seg_features)

In [None]:
from timm.models.layers.conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame

def _inflate_conv_params(conv2d, conv3d):
    """Copy ``conv2d`` parameters into ``conv3d``, repeating the kernel along the new depth axis."""
    depth = conv2d.kernel_size[0]
    conv3d.weight = torch.nn.Parameter(conv2d.weight.unsqueeze(-1).repeat(1, 1, 1, 1, depth))
    # Carry the bias over too; otherwise the new layer keeps a freshly initialised one.
    if conv2d.bias is not None:
        conv3d.bias = torch.nn.Parameter(conv2d.bias.detach().clone())


def convert_3d(module):
    """Recursively replace the 2-D layers of ``module`` with 3-D counterparts.

    BatchNorm2d, Conv2dSame, Conv2d, MaxPool2d and AvgPool2d are swapped for their
    3-D versions; every other module is kept and only its children are converted.
    Batch-norm parameters and buffers are shared with the source layer. Conv
    kernels are repeated along the added depth axis, using the first entry of
    kernel_size / stride / padding / dilation for all three dimensions.

    Args:
        module: Module to convert (containers are modified in place).

    Returns:
        The converted module: a new layer for the types listed above, otherwise
        ``module`` itself.
    """
    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig

    # Conv2dSame must be tested before the plain Conv2d branch.
    elif isinstance(module, Conv2dSame):
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        _inflate_conv_params(module, module_output)

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode
        )
        _inflate_conv_params(module, module_output)

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    # Convert the children and re-attach them under the same names.
    for name, child in module.named_children():
        module_output.add_module(
            name, convert_3d(child)
        )

    return module_output

# Smoke test: build the model, convert it to 3-D and run a random 128^3 volume
# through it to inspect the output shape.
m = TimmSegModel(backbone)
m = convert_3d(m)
m(torch.rand(1, 3, 128,128,128)).shape

In [None]:
def load_bone(msk, cid, t_paths, cropped_images):
    """Crop ``n_slice_per_c`` slices around one vertebra and store them in ``cropped_images[cid]``.

    The mask channel ``cid`` is thresholded (0.2, falling back to 0.05 when that
    leaves nothing) to get a bounding box. Slices are sampled evenly over the box
    depth; each sample stacks ``n_ch`` neighbouring DICOM slices, is min-max scaled
    to uint8, cropped to the box, resized to ``image_size_cls`` and gets the resized
    mask slice appended as an extra channel.

    Args:
        msk: Mask volume indexed ``msk[cid, x, y, z]``; presumably sigmoid
            probabilities on a ``msk_size`` grid — confirm against the caller.
        cid: Zero-based mask channel to process.
        t_paths: DICOM file paths of the study in slice order.
        cropped_images: Mutable sequence; slot ``cid`` receives a tensor of shape
            ``(n_slice_per_c, image_size_cls, image_size_cls, n_ch + 1)``.

    Any failure (for example an empty mask) is deliberately not raised: the slot is
    filled with an all-ones uint8 placeholder of the same shape instead.
    """
    n_scans = len(t_paths)
    bone = []
    try:
        msk_b = msk[cid] > 0.2
        msk_c = msk[cid] > 0.05

        x = np.where(msk_b.sum(1).sum(1) > 0)[0]
        y = np.where(msk_b.sum(0).sum(1) > 0)[0]
        z = np.where(msk_b.sum(0).sum(0) > 0)[0]

        # Nothing above the strict threshold: retry with the loose one.
        if len(x) == 0 or len(y) == 0 or len(z) == 0:
            x = np.where(msk_c.sum(1).sum(1) > 0)[0]
            y = np.where(msk_c.sum(0).sum(1) > 0)[0]
            z = np.where(msk_c.sum(0).sum(0) > 0)[0]

        # Bounding box with a one-voxel margin (raises IndexError on an empty mask,
        # which is handled by the placeholder fallback below).
        x1, x2 = max(0, x[0] - 1), min(msk.shape[1], x[-1] + 1)
        y1, y2 = max(0, y[0] - 1), min(msk.shape[2], y[-1] + 1)
        z1, z2 = max(0, z[0] - 1), min(msk.shape[3], z[-1] + 1)
        # Same depth range expressed in original slice indices.
        zz1, zz2 = int(z1 / msk_size * n_scans), int(z2 / msk_size * n_scans)

        inds = np.linspace(zz1 ,zz2-1 ,n_slice_per_c).astype(int)
        inds_ = np.linspace(z1 ,z2-1 ,n_slice_per_c).astype(int)
        for ind, ind_ in zip(inds, inds_):

            msk_this = msk[cid, :, :, ind_]

            # Read the n_ch neighbouring slices. Positions outside the series stay
            # None: a negative list index would otherwise silently wrap around to
            # the end of the series instead of being zero-padded.
            images = []
            for i in range(-n_ch//2+1, n_ch//2+1):
                pos = ind + i
                pixels = None
                if 0 <= pos < n_scans:
                    try:
                        pixels = pydicom.dcmread(t_paths[pos]).pixel_array
                    except Exception:
                        # Best effort: an unreadable neighbour is zero-padded too.
                        pixels = None
                images.append(pixels)

            # Pad with zeros shaped like a slice that was read (512x512 if none was).
            template = next((p for p in images if p is not None), None)
            blank = np.zeros((512, 512)) if template is None else np.zeros(template.shape)
            images = [blank if p is None else p for p in images]

            data = np.stack(images, -1)
            data = data - np.min(data)
            data = data / (np.max(data) + 1e-4)
            data = (data * 255).astype(np.uint8)
            msk_this = msk_this[x1:x2, y1:y2]
            # Bounding box rescaled from mask resolution to image resolution.
            xx1 = int(x1 / msk_size * data.shape[0])
            xx2 = int(x2 / msk_size * data.shape[0])
            yy1 = int(y1 / msk_size * data.shape[1])
            yy2 = int(y2 / msk_size * data.shape[1])
            data = data[xx1:xx2, yy1:yy2]
            data = np.stack([cv2.resize(data[:, :, i], (image_size_cls, image_size_cls), interpolation = cv2.INTER_LINEAR) for i in range(n_ch)], -1)
            msk_this = (msk_this * 255).astype(np.uint8)
            msk_this = cv2.resize(msk_this, (image_size_cls, image_size_cls), interpolation = cv2.INTER_LINEAR)

            data = np.concatenate([data, msk_this[:, :, np.newaxis]], -1)

            bone.append(torch.tensor(data))

    except Exception:
        # Drop any partial result so the slot always holds exactly n_slice_per_c
        # slices, and use uint8 like the real crops.
        bone = [
            torch.ones((image_size_cls, image_size_cls, n_ch+1), dtype=torch.uint8)
            for _ in range(n_slice_per_c)
        ]

    cropped_images[cid] = torch.stack(bone, 0)


def load_cropped_images(msk, image_folder, n_ch=n_ch):
    """Crop all seven vertebrae of one study in parallel threads.

    Args:
        msk: Mask volume of the study, handed to ``load_bone`` unchanged.
        image_folder: Folder with the study's DICOM files, named ``<number>.<ext>``.
        n_ch: Unused; kept for backward compatibility.

    Returns:
        The seven per-vertebra tensors written by ``load_bone``, concatenated
        along dim 0 in vertebra order.
    """
    t_paths = sorted(glob(os.path.join(image_folder, "*")), key=lambda x: int(x.split('/')[-1].split(".")[0]))

    # Per-call result slots and thread handles. These used to be module-level
    # lists created by the caller, which broke on the second call once the caller
    # rebound `cropped_images` to this function's return value.
    results = [None] * 7
    workers = [
        threading.Thread(target=load_bone, args=(msk, cid, t_paths, results))
        for cid in range(7)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    return torch.cat(results, 0)

In [None]:
# Load one segmentation model per fold; the export loop averages their predictions.
models_seg = []

kernel_type = 'timm3d_res18d_unet4b_128_128_128_dsv2_flip12_shift333p7_gd1p5_bs4_lr3e4_20x50ep'
backbone = 'resnet18d'
model_dir_seg = './drive/MyDrive/models/'
for fold in range(n_folds):
    model = TimmSegModel(backbone, pretrained=False)
    model = convert_3d(model)
    model = model.to(device)
    # Use this fold's checkpoint. The name used to hard-code "fold0", so every
    # model in the ensemble loaded the same weights.
    load_model_file = os.path.join(model_dir_seg, f'{kernel_type}_fold{fold}_best.pth')
    # map_location keeps loading independent of the device the checkpoint was saved from.
    sd = torch.load(load_model_file, map_location=device)
    if 'model_state_dict' in sd:
        sd = sd['model_state_dict']
    # Strip the "module." prefix found in DataParallel checkpoints.
    sd = {k[7:] if k.startswith('module.') else k: v for k, v in sd.items()}
    model.load_state_dict(sd, strict=True)
    model.eval()
    models_seg.append(model)

len(models_seg)

# Save preprocessed data to .npy

In [None]:
# For every (study, vertebra) row: predict the vertebra masks, crop the study
# around each vertebra and save the n_slice_per_c slices of the row's vertebra as
# <StudyInstanceUID>_<cid>_<ind>.npy, each a float32 array of shape (n_ch + 1, H, W).
# NOTE(review): every study is segmented and cropped once per row, i.e. seven
# times; caching per study would avoid the repeated work.
if batch_size_seg != 1:
    # One uid / cid is read per batch below, so larger batches would be mislabelled.
    raise ValueError('This export loop requires batch_size_seg == 1')

bar = tqdm(loader_seg)
with torch.no_grad():
    for batch_id, batch in enumerate(bar):
        images = batch['image'].to(device)
        study_instance_uid = batch['StudyInstanceUID'].pop()
        cid = int(batch['c'].item())  # vertebra id of this row, 1..7

        # SEG: average the sigmoid outputs of all fold models.
        pred_masks = torch.stack(
            [model(images).sigmoid() for model in models_seg], 0
        ).mean(0).cpu().numpy()

        # Build cls input
        cls_inp = []
        for i in range(pred_masks.shape[0]):
            row = df.iloc[batch_id*batch_size_seg+i]
            # Fresh module-level scratch lists for load_cropped_images, which may
            # use them as per-vertebra buffers. The result goes to a separate name
            # so these lists are never clobbered.
            threads = [None] * 7
            cropped_images = [None] * 7
            crops = load_cropped_images(pred_masks[i], row.image_folder)
            cls_inp.append(crops.permute(0, 3, 1, 2).float() / 255.)
        # Stays on the CPU: it is only written to disk.
        cls_inp = torch.stack(cls_inp, 0)  # (1, 105, 6, 224, 224)

        # The crops are stacked vertebra-major, so the slices of vertebra `cid`
        # start at (cid - 1) * n_slice_per_c. The old stride-7 indexing wrote the
        # same mixed-vertebra slices for every cid.
        first = (cid - 1) * n_slice_per_c
        for ind in range(n_slice_per_c):
            filepath = os.path.join(out_dir, f'{study_instance_uid}_{cid}_{ind}.npy')
            image_slice = cls_inp[0, first + ind].numpy()
            np.save(filepath, image_slice)

# Check output

In [None]:
# Load one exported slice (hard-coded study / vertebra 1 / slice 0) and show its
# first channel; the array is saved channel-first.
img = np.load('./preprocess_stage2_type1/1.2.826.0.1.3680043.6200_1_0.npy')
print(img.shape)
single_img = img[0]
plt.imshow(single_img)