# Skipping Slices
In scans with a very large number of slices, consecutive slices differ very little.
For example, in a scan with 400 slices there wouldn't be much difference between slices 50, 51 and 52. Therefore, we can skip a few slices after every selected slice.

Here I have calculated the number of slices to skip based on the shape of the scan and required stack size.
Using this approach alone, I trained 4 different models on different MRI sequences using different CV splits and then ensembled them.

In [None]:
!pip install monai

In [None]:
import os
import requests
import glob
from tqdm import tqdm_notebook as tqdm
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms, utils
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
from sklearn.metrics import classification_report, accuracy_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import StratifiedKFold
import albumentations as A
from albumentations.pytorch import ToTensorV2
import gc
from monai import transforms as T

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Root folder of the competition data.
path = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification'

# Read the labels, drop incomplete rows, add a placeholder fold column and
# shuffle the rows reproducibly before the folds are assigned.
train_data = (
    pd.read_csv(os.path.join(path, 'train_labels.csv'))
    .dropna()
    .assign(kfold=-1)
    .sample(frac=1, random_state=7)
    .reset_index(drop=True)
)

# Four folds stratified on the MGMT label. Each row stores the number of the
# fold in which it serves as validation data.
kf = StratifiedKFold(n_splits=4)
targets = train_data.MGMT_value.values
for fold, (trn_, val_) in enumerate(kf.split(X=train_data, y=targets)):
    print(len(trn_), len(val_))
    train_data.loc[val_, 'kfold'] = fold

print('Num of train samples:', len(train_data))
# train_data.head()

In [None]:
# Side length in pixels that every slice is resized to (used by dicom2array).
img_size = 256
# Number of slices every loaded volume ends up with (used by load_sequence).
stack_size = 64

def dicom2array(path, voi_lut=True, fix_monochrome=True):
    """Read one DICOM file and return its pixels resized to a square image.

    Args:
        path: Path of the DICOM file to read.
        voi_lut: When true, pass the pixel data through ``apply_voi_lut``
            (VOI LUT, if provided by the DICOM device, gives a
            "human-friendly" view of the raw data).
        fix_monochrome: When true, invert images whose
            ``PhotometricInterpretation`` is "MONOCHROME1", which would
            otherwise look inverted.

    Returns:
        The pixel array resized by ``cv2.resize`` to
        ``(img_size, img_size)``.
    """
    # `pydicom.read_file` is a deprecated alias of `dcmread` that newer
    # pydicom releases no longer ship, so call `dcmread` directly.
    dicom = pydicom.dcmread(path)

    pixels = dicom.pixel_array
    if voi_lut:
        pixels = apply_voi_lut(pixels, dicom)

    if fix_monochrome and dicom.PhotometricInterpretation == "MONOCHROME1":
        pixels = np.amax(pixels) - pixels

    return cv2.resize(pixels, (img_size, img_size))

def load_sequence(paths):
    """Turn an ordered list of slice files into a fixed-depth volume.

    Slices whose maximum is not positive are dropped. Long sequences are
    thinned by keeping every n-th slice from a random start; short ones are
    repeated until the depth equals ``stack_size``.

    Args:
        paths: Slice file paths, already in the desired order.

    Returns:
        Array of shape ``(img_size, img_size, stack_size)``. It is all zeros
        when no slice survives the filter.
    """
    # Keep only slices that are not entirely non-positive.
    slices = [arr for arr in map(dicom2array, paths) if not (arr.max() <= 0)]

    if not slices:
        return np.zeros((img_size, img_size, stack_size))

    volume = np.dstack(slices)

    # Step between kept slices, derived from the depth and the target size.
    step = volume.shape[2] // stack_size + 1
    # Random starting slice among the first `step` slices.
    offset = np.random.choice(step)
    volume = volume[:, :, offset::step]

    # Tile the (possibly short) volume: whole copies first, then the leading
    # slices needed to reach exactly `stack_size`. Repeating is preferred
    # over zero padding.
    kept = volume.shape[2]
    repeats = stack_size // kept
    remainder = volume[:, :, :stack_size - kept * repeats]
    return np.concatenate([volume] * repeats + [remainder], axis=2)

def load_3d_dicom_images(scan_id, split = "train", mri_type = "FLAIR"):
    """Load one MRI sequence of a scan as a float32 volume.

    The slice files are ordered by the integer that follows the last '-' in
    their file name. Plain string sorting would give
    1->10->100->101->102 instead.

    Args:
        scan_id: Name of the scan folder below ``{path}/{split}``.
        split: Dataset sub-folder to read from ("train" by default).
        mri_type: Sequence sub-folder to load. Defaults to "FLAIR", which is
            what this function always loaded before. The other sequence
            folders are "T1w", "T1wCE" and "T2w".

    Returns:
        The volume returned by ``load_sequence`` cast to ``np.float32``.
    """
    def slice_number(file_path):
        # ".../<prefix>-<number>.dcm" -> <number>
        return int(file_path.split('/')[-1].split('-')[-1].split('.')[0])

    slice_paths = sorted(
        glob.glob(f"{path}/{split}/{scan_id}/{mri_type}/*.dcm"),
        key=slice_number,
    )
    return load_sequence(slice_paths).astype(np.float32)

In [None]:
# let's write a simple pytorch dataloader


class BrainTumor(Dataset):
    """Dataset that yields one preprocessed MRI volume (and label) per scan.

    Args:
        df: Label table with the columns ``BraTS21ID``, ``MGMT_value`` and
            ``kfold``. Only read for the "train" and "val" splits; it is
            never modified.
        path: Root folder that contains one sub-folder per split.
        split: "train" keeps the rows whose ``kfold`` differs from
            ``validation_fold``; "val" keeps the rows of that fold (their
            files live in the "train" folder); any other value is used as
            the folder name and yields unlabeled samples.
        validation_fold: Fold number held out for validation.
    """

    def __init__(self, df, path = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification', split = "train", validation_fold = 0):
        self.labels = {}
        # Random augmentation is only applied to training samples, so that
        # validation and test volumes are evaluated undistorted.
        self.augment = split == "train"
        # Validation scans are stored in the "train" folder on disk.
        self.split = 'train' if split == "val" else split

        scan_dirs = sorted(glob.glob(os.path.join(path, self.split, "*")))
        scan_ids = [os.path.basename(scan_dir) for scan_dir in scan_dirs]

        if split in ("train", "val"):
            # Scan folders are named with zero-padded 5-character ids. The
            # padded ids are built on a new Series so that the caller's
            # DataFrame is left untouched.
            padded_ids = df["BraTS21ID"].map(lambda x: str(x).zfill(5))
            in_val_fold = df["kfold"] == validation_fold
            selected = in_val_fold if split == "val" else ~in_val_fold
            self.labels = dict(zip(padded_ids[selected], df.loc[selected, "MGMT_value"]))
            # Dict lookup instead of scanning an array for every folder.
            scan_ids = [scan_id for scan_id in scan_ids if scan_id in self.labels]

        self.ids = scan_ids

    def __len__(self):
        """Return the number of scans in this split."""
        return len(self.ids)

    def preprocess(self, stack_dim):
        """Build the MONAI dictionary transform applied to every volume.

        Args:
            stack_dim: Only referenced by the disabled ``Spacingd`` step
                below; the active transforms do not use it.
        """
        return T.Compose(
            [
                T.AddChanneld(keys=['image'],), # Below steps won't work without adding channel
                T.NormalizeIntensityd(keys=['image'],),
#                 T.Spacingd(keys=['image'], pixdim=(1, 1, stack_dim)),
#                 T.ResizeWithPadOrCropd(keys=['image'], mode='reflect', spatial_size=(img_size, img_size, stack_size)),
            ]
        )

    def get_transforms(self):
        """Build the augmentation: with p=0.7, one flip or affine variant."""
        return A.Compose([
                    A.OneOf([
                        A.HorizontalFlip(),
                        A.VerticalFlip(),
                        A.IAAAffine(shear=(-15,15), mode='constant', cval=0),
                        A.IAAAffine(scale=(0.9,1.2), mode='constant', cval=0),
                        A.IAAAffine(translate_percent=(0,0.15), mode='constant', cval=0),
                        A.IAAAffine(rotate=(-20,20), mode='constant', cval=0),
                    ], p=0.7),
                ])

    def __getitem__(self, idx):
        """Return ``(volume, label)``, or only ``volume`` for the "test" split.

        The volume gets a leading channel axis; the label is a one-element
        float64 tensor.
        """
        scan_id = self.ids[idx]
        imgs = load_3d_dicom_images(scan_id, self.split)

        # Preprocess stack
        stack_dim = imgs.shape[2]//stack_size*2 + 1
        preprocess = self.preprocess(stack_dim)
        # Drop the channel axis again so the stack can be handed to albumentations.
        imgs = preprocess({'image':imgs})['image'][0,:,:,:]

        # Augment the processed stack (training samples only).
        if self.augment:
            imgs = self.get_transforms()(image=imgs)['image']
        imgs = torch.unsqueeze(torch.from_numpy(imgs), 0) # Add channel dimension again.

        if self.split != "test":
            label = self.labels[scan_id]
            return imgs, torch.tensor([label], dtype = torch.float64)
        return imgs

In [None]:
bs = 4
# Settings shared by both loaders; only the shuffling differs.
loader_kwargs = dict(batch_size=bs, num_workers=4, pin_memory=True)

# Fold 0 is held out for validation, the remaining folds are used for training.
train_dataset = BrainTumor(train_data, split='train', validation_fold=0)
val_dataset = BrainTumor(train_data, split='val', validation_fold=0)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
def plot_imgs(imgs, cmap='gray'):
    """Show every slice of a volume in a near-square grid of subplots.

    The grid is sized from the number of slices instead of assuming 64, so
    the function keeps working when the stack size changes. A 64-slice
    volume still gives the original 8x8 layout.

    Args:
        imgs: Volume indexed as ``imgs[:, :, i]`` with the slices along
            axis 2.
        cmap: Colormap passed to ``plt.imshow``.
    """
    n_slices = imgs.shape[2]
    # max(1, ...) keeps the grid valid for an empty volume.
    n_cols = max(1, int(np.ceil(np.sqrt(n_slices))))
    n_rows = max(1, int(np.ceil(n_slices / n_cols)))

    fig = plt.figure(figsize=(8,8))
    for i in range(n_slices):
        fig.add_subplot(n_rows, n_cols, i+1)
        plt.imshow(imgs[:,:,i], cmap=cmap)
    plt.show()

# Preview the first sample of the first validation batch, then stop.
for img, label in val_loader:
    plot_imgs(img[0, 0])
    break

In [None]:
# Using DenseNet because the dataset size is small, dense connections can help in this case.

from monai.networks.nets import DenseNet264

# File the training loop saves the best model to.
save_model_name = 'DenseNet-264-flair-fold-0.pt'

# Optimisation settings read by the training loop.
lr = 1e-4                # learning rate of the training optimizer
min_lr = 3e-6            # lower bound handed to the LR scheduler
n_epochs = 1             # Trial run
accumulation_steps = 16  # batches accumulated per optimizer step
early_stop = 7           # epochs without improvement before training stops
patience = 2             # patience handed to the LR scheduler

# 3D DenseNet-264 with one input channel and one output value per volume.
model = DenseNet264(spatial_dims=3, in_channels=1, out_channels=1)
# Optional replacement head (disabled):
# model.out = nn.Sequential(
#                 nn.Dropout(p=0.5)
#                 nn.Linear(1920, 1, bias=True)
#             )

criterion = nn.BCEWithLogitsLoss()

In [None]:
!pip install torch_lr_finder
from torch_lr_finder import LRFinder
# LR range test: 25 iterations on the training loader, starting from the
# optimizer's learning rate (1e-6) with end_lr set to 1e-3.
gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6, weight_decay=1e-3)

lr_finder = LRFinder(model, optimizer, criterion, device=gpu)
lr_finder.range_test(train_loader, end_lr=1e-3, num_iter=25)
lr_finder.plot()
lr_finder.reset()

# Suggested LR: 2.74e-4

In [None]:
# Send updates on Telegram. Really useful for keeping track of training while the notebook is committed.
def send_message(msg, chat_id, bot_token):
    """
    Send a text message through the Telegram Bot API (best effort).

    params:
    -------
    msg: message you want to receive
    chat_id: CHAT_ID
    bot_token: API_KEY of your bot

    Nothing is sent when `chat_id` is None or `bot_token` is empty, so the
    function is a no-op until real credentials are filled in. A network
    failure is reported with a printed warning instead of an exception, so a
    failed notification cannot abort a training run.
    """
    if chat_id is None or not bot_token:
        return

    url  = f'https://api.telegram.org/bot{bot_token}/sendMessage'
    data = {'chat_id': str(chat_id), 'text': f'{msg}'}
    try:
        # Without a timeout a stalled connection would block training forever.
        requests.post(url, data, timeout=10)
    except requests.RequestException as err:
        # Only the exception type is printed: the message may contain the
        # request URL, which includes the bot token.
        print(f'Telegram notification failed: {type(err).__name__}')

# Telegram credentials passed to send_message: fill in your own chat id and bot token.
# NOTE(review): keep a real token out of any published copy of this notebook.
chat_id = None
token = ''
# send_message('Training Started...', chat_id, token)

In [None]:
# Training with gradient accumulation, LR reduction on plateau, checkpointing
# of the best validation loss and early stopping.
optimizer = torch.optim.AdamW(model.parameters(),lr = lr, weight_decay=0.001, amsgrad=False)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=1/3, patience=patience, min_lr=min_lr, verbose=True)

gpu = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(gpu)

best_loss = np.inf
early_stopping_counter = 0

# Start from clean gradients: the optimizer is new, but the model parameters
# may still carry `.grad` values from earlier cells, and those would be added
# to the first accumulation window.
optimizer.zero_grad()

for epoch in range(n_epochs):  # loop over the dataset multiple times

    # ---- training ----
    running_loss = 0.0
    model.train()
    train_bar = tqdm(enumerate(train_loader, 0), total=len(train_loader), desc = 'Training')
    for i, (x, y) in train_bar:
        x = x.to(gpu)
        y = y.to(gpu)

        outputs = model(x)

        loss = criterion(outputs, y)
        running_loss += loss.item()  # track the unscaled loss

        # Scale so the accumulated gradient is the mean over the window.
        # NOTE(review): a shorter final window is still divided by
        # accumulation_steps, so its update is proportionally smaller.
        loss = loss / accumulation_steps
        loss.backward()

        # Gradient Accumulation: step every `accumulation_steps` batches and
        # on the last batch of the epoch.
        if (i+1)%accumulation_steps == 0 or (i+1)==len(train_loader):
            optimizer.step()
            optimizer.zero_grad()
        train_bar.set_description(f"Epoch: {epoch+1}, loss: {running_loss/(i+1)}")

    print(f"epoch {epoch+1} train: {running_loss/len(train_loader)}")
#     send_message(f"epoch {epoch+1} train: {running_loss/len(train_loader)}", chat_id, token)

    # ---- validation ----
    v_loss = 0.0
    model.eval()
    with torch.no_grad():
        for x, y in tqdm(val_loader, total=len(val_loader)):
            x = x.to(gpu)
            y = y.to(gpu)

            # forward
            outputs = model(x)
            loss = criterion(outputs, y)

            # accumulate the per-batch loss
            v_loss += loss.item()

    v_loss = v_loss/len(val_loader)
    print(f"epoch {epoch+1} val: {v_loss}")
#     send_message(f"epoch {epoch+1} val: {v_loss}", chat_id, token)
    scheduler.step(v_loss)

    # Save on improvement; otherwise count towards early stopping.
    if (v_loss) < best_loss:
        best_loss = v_loss
        torch.save(model,save_model_name)
        print('Loss decreased, model saved.')
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
    if early_stopping_counter==early_stop:
        print('Early Stopping!!!')
        break
    print('\n')