In [4]:
import torch
import pandas as pd
import numpy as np
from glob import glob
import cv2
import os
from matplotlib import pyplot as plt
import pydicom
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import multiprocessing as mp
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
    
from torch import nn
from torchvision import models

In [5]:
# # !pip install dicom2nifti --upgrade
# !pip install dicom2nifti
# !pip install monai

In [9]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [9]:
config = {}
channels = ['FLAIR','T1w','T1wCE','T2w']

config['project_path'] = '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification'

config['size'] = 256
config['batch_size'] = 64
config['num_workers'] = 8
config['output_nodes'] = 1
config['epochs'] = 10
config['lr'] = 1e-3
config['save_checkpoint'] = './checkpoint'
config['min_loss'] = 1
config['weight_decay'] = 0.0001
config['step_size'] = 5
config['gamma'] = 0.85

In [10]:
glob(os.path.join(config['project_path'],'*'))

['/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv',
 '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv',
 '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/test',
 '/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train']

In [11]:
_IMG_TRAIN = os.path.join(config['project_path'],'train')
_WORKING = '../'

In [12]:
train_labels = pd.read_csv(os.path.join(config['project_path'],'train_labels.csv'))
train_labels.head()

Unnamed: 0,BraTS21ID,MGMT_value
0,0,1
1,2,1
2,3,0
3,5,1
4,6,1


In [13]:
## generate absolute path to image
def gen_fold_id(BraTS21ID):
    std_len = 5
    
    BraTS21ID_str = str(BraTS21ID)
    if len(BraTS21ID_str) < std_len:
        app_len = std_len - len(BraTS21ID_str)

        full_id = ['0']*app_len + BraTS21ID_str.split()
        full_id_str = ''.join(full_id)
        return full_id_str
    else:
        return BraTS21ID_str
    
def load_dicom(path):
    dicom = pydicom.read_file(path)
    data = dicom.pixel_array
    data = data - np.min(data)
    if np.max(data) != 0:
        data = data / np.max(data)
    data = (data * 255).astype(np.uint8)
    return data

def area_visiable(img_array,threshold=0.2):
    img_area = int(img_array.shape[0])*int(img_array.shape[1])
    non_zero_area = len(np.nonzero(img_gray)[0])
    
    ratio = float(non_zero_area/img_area)
    if ratio >= threshold:
        return True
    else:
        return False
    
    
import numpy as np
from sklearn import metrics

def f1_score(preds, targets, sigmoid=True, thresh=0.5, average='micro', idx=None):
    if sigmoid: preds = 1/(1 + np.exp(-preds))
    preds = (preds >= thresh).astype(np.uint8)
    if idx is not None:
        return metrics.fbeta_score(y_true=targets, y_pred=preds, beta=1, average=None)[idx]
    return metrics.fbeta_score(y_true=targets, y_pred=preds, beta=1, average=average)

def create_metrics_DISEASE(func, name, label_cols_list):
    return lambda preds, target: func(preds, targets, idx=label_cols_list.index(name))

def classification_counts(y_pred, y_true, log=False):
    TP,FP,TN,FN = 0,0,0,0

    for i in range(len(y_pred)): 
        if y_true[i]==y_pred[i]==1:
            TP += 1
        if y_pred[i]==1 and y_true[i]== 0:
            FP += 1
        if y_true[i]==y_pred[i]==0:
            TN += 1
        if y_pred[i]==0 and y_true[i]==1:
            FN += 1
    if log:
        print('TP\tFP\tTN\tFN\t')
        print(f'{TP}\t{FP}\t{TN}\t{FN}\t')
    return {'TP': TP, 'FP': FP, 'TN': TN, 'FN': FN}

def binary_metrics(y_pred, y_true, log=False):
    precision = metrics.precision_score(y_true, y_pred, average='binary')
    recall = metrics.recall_score(y_true, y_pred, average='binary')
    spec = metrics.recall_score(y_true, y_pred, average='binary', pos_label=0)
    f1 = metrics.f1_score(y_true, y_pred, average='binary')
    f2 = metrics.fbeta_score(y_true, y_pred, average='binary', beta=2)
    auc = metrics.roc_auc_score(y_true, y_pred)
    acc = metrics.accuracy_score(y_true, y_pred)
    if log:
        print('pre\trecall\tspec\tacc\tauc\tf2\tf1\t')
        print(f'{precision:0.4f}\t{recall:0.4f}\t{spec:0.4f}\t{acc:0.4f}\t{auc:0.4f}\t{f2:0.4f}\t{f1:0.4f}')
    return {
        'precision': precision, 'recall': recall, 'specificity': spec, 
        'accuracy': acc, 'auc': auc, 'f2': f2, 'f1': f1 
    }

def report_binary_thresholded_metrics(y_pred, y_true, thresh_step=0.1, lite=True):
    report = pd.DataFrame(columns=['precision', 'recall', 'specificity', 
                                   'accuracy', 'auc', 'f2', 'f1', 'TP', 'FP', 'TN', 'FN'])
    preds = y_pred
    for thresh in np.arange(thresh_step, 1.00, thresh_step):
        y_pred = (preds >= thresh).astype(np.uint8).squeeze()
        metrics = binary_metrics(y_pred, y_true, log=False)
        counts = classification_counts(y_pred, y_true, log=False)
        row = pd.DataFrame.from_dict({f'{thresh:0.2f}': dict(metrics, **counts)}, orient='index')
        report = pd.concat([report, row], ignore_index=False)
        report.index.name = 'threshold'
    if lite:
        report = report[['precision', 'recall', 'specificity', 
                        'f1', 'TP', 'FP', 'TN', 'FN']]
    return report

def save_checkpoint(model, save_path, channel, epoch):
    """
    checkpoint = {
            'model': best_model,
            'epoch':epoch+1,
            'model_state_dict':best_model.state_dict(),
            'optimizer_state_dict':best_optimizer.state_dict(),
            'scheduler_state_dict':best_scheduler.state_dict()
            }
    """
#     d = date.today().strftime("%m_%d_%Y") 
#     h = datetime.now().strftime("%H_%M_%S").split('_')
#     h_offset = int(datetime.now().strftime("%H_%M_%S").split('_')[0])+7
#     h[0] = str(h_offset)
#     h = '_'.join(h)
#     today_time = d +'_'+h
# #     today_time = date.today().strftime("%m_%d_%Y") + '_' + datetime.now().strftime("%H_%M")
    if not os.path.isdir(save_path):
        os.mkdir(save_path)
    
    if not os.path.isdir(os.path.join(save_path,channel)):
        os.mkdir(os.path.join(save_path,channel))
    
    checkpoint = {
            'model': model,
            'epoch':epoch,
            'model_state_dict':model.state_dict()
            }
    f = os.path.join(os.path.join(save_path,channel), 'best_checkpoint.pth')
    torch.save(checkpoint, f)
    print('Saved checkpoint')
    

def evaluation(model,epoch,test_dataloader,criterion,device):
    model.eval()
    batch_losses = []
    with torch.no_grad():
        model_result = []
        targets = []
        for imgs, targets_batch in tqdm(test_dataloader):
            imgs, targets_batch = imgs.to(device), targets_batch.to(device)

            model_batch_result = model(imgs)
            model_result.extend(model_batch_result.cpu().numpy())
            targets.extend(targets_batch.cpu().numpy())
            
            loss = criterion(model_batch_result, targets_batch.type(torch.float))
            batch_loss_value = loss.item()
            batch_losses.append(batch_loss_value)
            
    loss_value = np.mean(batch_losses)
    df_report = report_binary_thresholded_metrics(model_result,targets)
    model.train()

    return df_report,loss_value
    

def load_checkpoint(filepath):
    
    checkpoint = torch.load(filepath)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['model_state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False
    model.eval()
    return model

In [14]:
# Simple dataloader and label binarization, that is converting test labels into binary arrays of length 27 (number of classes) with 1 in places of applicable labels).
# from torch.utils.data.dataset import Dataset
# import pandas as pd
# import torch
# import numpy as np
# import os
# from PIL import Image
# from PIL import ImageFile
# ImageFile.LOAD_TRUNCATED_IMAGES = True

class customDataset(Dataset):
    def __init__(self, df_data, transforms=None):
        """
        param:
            data_path: is path of images
            anno_path: is path of dataframe
            classes_path: is path of h5 classes
            
            transform:
            
        return:
            
        """
        self.transforms = transforms
        
        self.imgs = df_data['path']
#         self.annos = df_data['target']        
        self.annos = [np.array([item],dtype=float) for item in list(df_data['target'])]
    
    def __getitem__(self, item):
        anno = self.annos[item]
#         img_path = os.path.join(self.data_path, self.imgs[item])
#         img = Image.open(img_path).convert('RGB')
        img_gray = load_dicom(self.imgs[item])
        img_3d = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
        img = Image.fromarray(img_3d)
        
        if self.transforms is not None:
            img = self.transforms(img)
        return img, anno

    def __len__(self):
        return len(self.imgs)
    
class customTestDataset(Dataset):
    def __init__(self, df_data, transforms=None):
        """
        param:
            data_path: is path of images
            anno_path: is path of dataframe
            classes_path: is path of h5 classes
            
            transform:
            
        return:
            
        """
        self.transforms = transforms
        
        self.imgs = df_data['path']
        
    def __getitem__(self, item):
        
        img_gray = load_dicom(self.imgs[item])
        img_3d = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
        img = Image.fromarray(img_3d)
        
        if self.transforms is not None:
            img = self.transforms(img)
        return img
    
    def __len__(self):
        return len(self.imgs)

# import torch
class customNet(nn.Module):
    def __init__(self, output_nodes):
        super().__init__()
        resnet = models.resnext50_32x4d(pretrained=True)
        feature_size = 256
        resnet.fc = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(in_features=resnet.fc.in_features, out_features=feature_size),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=feature_size, out_features=output_nodes)
        )
        self.base_model = resnet
        self.sigm = nn.Sigmoid()

    def forward(self, x):
        return self.sigm(self.base_model(x))

In [15]:
def single_process(idx):
    train_labels = pd.read_csv(os.path.join(config['project_path'],'train_labels.csv'))
    
    channels = ['FLAIR','T1w','T1wCE','T2w']

    fold_id = gen_fold_id(train_labels.iloc[idx]['BraTS21ID'])
    
    label = train_labels.iloc[idx]['MGMT_value']
    
#     nii_fold_id = os.path.join(_WORKING,'nifti',fold_id)
    
#     dict_anno = {}
    
#     if not os.path.isdir(nii_fold_id):
#         os.mkdir(nii_fold_id)
        
#     list_nii_file = []
    list_anno = []
    
    for channel in channels:
#         try:
#         dicom_directory = os.path.join(_IMG_TRAIN,fold_id,channel)
        dicom_directory = glob(os.path.join(_IMG_TRAIN,fold_id,channel,'*.dcm'))
        
        for img in dicom_directory:
            dict_anno = {}
            dict_anno['path'] = img
            dict_anno['target'] = label
            dict_anno['channel'] = channel
            list_anno.append(dict_anno)
    
#         nii_file = os.path.join(nii_fold_id, channel.lower()+'.nii.gz')

#         dicom2nifti.dicom_series_to_nifti(dicom_directory, nii_file, reorient_nifti=False)

#         list_nii_file.append(nii_file)
#         except:
#             continue
    
#     dict_anno['image'] = list_nii_file
#     dict_anno['label'] = label
    
#     return dict_anno
    return list_anno

In [16]:
list_anno = []

# nii_fold = os.path.join(_WORKING,'nifti')
# if not os.path.isdir(nii_fold):
#     os.mkdir(nii_fold)

pool = mp.Pool(8)
list_idx = range(len(train_labels))
for anno in tqdm(pool.imap_unordered(single_process,list_idx),total=len(list_idx)):
#     list_anno.append(anno)
    list_anno += anno

100%|██████████| 585/585 [00:11<00:00, 52.53it/s]


In [17]:
# df_anno = pd.DataFrame([])
# df_anno = pd.DataFrame(list_anno)
# df_anno['path'] = list_anno
# df_anno['target'] = list(anno_train['MGMT_value'])
# anno_train.iloc[-1]
# len(list_anno)
df_full_channel = pd.DataFrame(list_anno)
df_full_channel.head()

Unnamed: 0,path,target,channel
0,/kaggle/input/rsna-miccai-brain-tumor-radiogen...,1,FLAIR
1,/kaggle/input/rsna-miccai-brain-tumor-radiogen...,1,FLAIR
2,/kaggle/input/rsna-miccai-brain-tumor-radiogen...,1,FLAIR
3,/kaggle/input/rsna-miccai-brain-tumor-radiogen...,1,FLAIR
4,/kaggle/input/rsna-miccai-brain-tumor-radiogen...,1,FLAIR


In [18]:
df_full_channel['target'].value_counts()

1    195789
0    152852
Name: target, dtype: int64

## Pipeline for each channel

In [19]:
# channels

In [None]:
for channel in channels:
    config['channel'] = channel
    
    df_anno = df_full_channel[df_full_channel['channel']==config['channel']]

    x = list(df_anno['path'])
    y = list(df_anno['target'])
    xtrain, xval, ytrain, yval = train_test_split(x,y, test_size=0.2,random_state=25, shuffle=True)

    df_train = pd.DataFrame([])
    df_train['path'] = xtrain
    df_train['target'] = ytrain

    df_val = pd.DataFrame([])
    df_val['path'] = xval
    df_val['target'] = yval

    list_record_train = df_train.to_dict('record')
    list_record_valid = df_val.to_dict('record')
    print('#'*20)
    print('Amount current training and validation')
    print(df_train['target'].value_counts(),df_val['target'].value_counts())

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    val_transform = transforms.Compose([
        transforms.Resize((int(config['size']), int(config['size']))),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    train_transform = transforms.Compose([
        transforms.Resize((int(config['size']), int(config['size']))),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    train_ds = customDataset(df_train, train_transform)

    val_ds = customDataset(df_val, val_transform)

    train_dataloader = DataLoader(
                                train_ds, 
                                batch_size=config['batch_size'], 
                                num_workers=config['num_workers'], 
                                shuffle=True,
                                drop_last=True
                                )

    test_dataloader = DataLoader(
                        val_ds, 
                        batch_size=config['batch_size'], 
                        num_workers=config['num_workers']
                        )

    model = customNet(config['output_nodes'])
    model.train()
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    criterion = nn.BCELoss()
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, 
        step_size=int(config['step_size']), 
        gamma= config['gamma'], 
        last_epoch=-1
        )

    for epoch in range(config['epochs']):
        print('*'*10)
        print('Current channel: {}'.format(channel))
        print('Current epoch: {}'.format(str(epoch)))

        batch_losses = []
        for imgs, targets in tqdm(train_dataloader):
            imgs, targets = imgs.to(device), targets.to(device)

            optimizer.zero_grad()

            model_result = model(imgs)
            loss = criterion(model_result, targets.type(torch.float))

            batch_loss_value = loss.item()
            loss.backward()
            optimizer.step()

            batch_losses.append(batch_loss_value)

        scheduler.step()

        df_eval,loss_val = evaluation(model,epoch,test_dataloader,criterion,device)

        loss_value = np.mean(batch_losses)

        if loss_value < config['min_loss']:
            config['min_loss'] = loss_value
            save_checkpoint(model, config['save_checkpoint'], config['channel'], epoch)
        
        print('Current performance: ',df_eval.iloc[4])
        print('Current learning rate: ',optimizer.param_groups[0]['lr'])
        print('Current training loss: ',loss_value)



####################
Amount current training and validation
1    34711
0    24687
Name: target, dtype: int64 1    8659
0    6191
Name: target, dtype: int64


Downloading: "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth" to /root/.cache/torch/hub/checkpoints/resnext50_32x4d-7cdf4587.pth


  0%|          | 0.00/95.8M [00:00<?, ?B/s]

  0%|          | 0/928 [00:00<?, ?it/s]

**********
Current channel: FLAIR
Current epoch: 0


100%|██████████| 928/928 [21:17<00:00,  1.38s/it]
100%|██████████| 233/233 [01:51<00:00,  2.08it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  0%|          | 0/928 [00:00<?, ?it/s]

Saved checkpoint
Current performance:  precision      0.604409
recall         0.845363
specificity    0.226135
f1             0.704863
TP                 7320
FP                 4791
TN                 1400
FN                 1339
Name: 0.50, dtype: object
Current learning rate:  0.001
Current training loss:  0.684184262721703
**********
Current channel: FLAIR
Current epoch: 1


 13%|█▎        | 124/928 [02:55<18:18,  1.37s/it]