In [None]:
!pip install segmentation_models_pytorch -q
!pip install timm -q

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
import cv2
from glob import glob
import numpy as np
import pandas as pd
import re
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Test-set scan ids come from the sample submission. `assign` returns a new
# frame, so `sample_submission` keeps exactly the columns of the CSV instead of
# silently gaining the extra column through an alias.
sample_submission = pd.read_csv('../input/rsna-miccai-brain-tumor-radiogenomic-classification/sample_submission.csv')
test = sample_submission.assign(
    # Zero-padded 5-digit id (e.g. 1 -> "00001"), the same form as the scan
    # folder names used by the export cell below.
    BraTS21ID5=[format(x, '05d') for x in sample_submission.BraTS21ID]
)
test.head(3)

In [None]:
def avg_img(img_size, img_list):
    """Return the pixel-wise mean of ``img_list`` rounded to a uint8 image.

    Args:
        img_size: side length of the square output; also the shape of the
            all-zero image returned when ``img_list`` is empty.
        img_list: sized sequence of 2-D arrays, expected to have shape
            (img_size, img_size).

    Returns:
        np.ndarray of dtype uint8, shape (img_size, img_size) for same-shape
        inputs.
    """
    count = len(img_list)
    # Each frame is divided by the count before being added to a float32 zero
    # canvas; the generator is never consumed when the list is empty, so an
    # empty list never divides by zero.
    canvas = np.zeros((img_size, img_size), np.float32)
    total = sum((frame / count for frame in img_list), canvas)
    # np.round rounds exact halves to the nearest even value (0.5 -> 0, 3.5 -> 4).
    return np.round(total).astype(np.uint8)


def _natural_sort_key(path):
    """Sort key splitting ``path`` into digit runs (as ints) and single other chars.

    Orders e.g. ``Image-2.png`` before ``Image-10.png``, unlike plain string order.
    """
    return [int(tok) if tok.isdigit() else tok
            for tok in re.findall(r'[^0-9]|[0-9]+', path)]


def load_slices_3d(path_to_scan_dir, num_imgs=16, img_size=256):
    """Load one scan's slice images and collapse them into ``num_imgs`` averaged images.

    Every entry matched by ``<path_to_scan_dir>/*`` is read as a grayscale image
    and resized to ``img_size`` x ``img_size``. The entries are processed in
    natural (numeric-aware) path order. The ordered slices are cut into
    ``num_imgs`` consecutive chunks of ``len(slices) // num_imgs`` slices, and
    each chunk is averaged into one image with ``avg_img``.

    Args:
        path_to_scan_dir: (string) directory holding the slice images of one
            scan for one MRI type.
        num_imgs: (integer, >= 1) number of averaged images to produce.
        img_size: (integer) side length, in pixels, of every output image.

    Returns:
        np.ndarray of dtype uint8 and shape (num_imgs, img_size, img_size).

    Raises:
        ValueError: if ``num_imgs`` < 1, or if a matched file cannot be decoded
            by ``cv2.imread``.

    Note:
        The trailing ``len(slices) % num_imgs`` slices are not used. When the
        directory holds fewer than ``num_imgs`` slices (including none), every
        chunk is empty and the result is all zeros.
    """
    if num_imgs < 1:
        raise ValueError(f"num_imgs must be >= 1, got {num_imgs}")

    slice_paths = sorted(glob(os.path.join(path_to_scan_dir, "*")), key=_natural_sort_key)

    slices = []
    for slice_path in slice_paths:
        img = cv2.imread(slice_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports failure by returning None rather than raising;
            # without this check cv2.resize fails with an opaque assertion.
            raise ValueError(f"could not read image file: {slice_path}")
        slices.append(cv2.resize(img, (img_size, img_size)))

    # Integer division: slices that do not fill a whole chunk are dropped. This is
    # kept unchanged so outputs match images produced by earlier runs.
    chunk_size = len(slices) // num_imgs

    averaged_slices = [
        avg_img(img_size, slices[i * chunk_size:(i + 1) * chunk_size])
        for i in range(num_imgs)
    ]
    return np.array(averaged_slices)

In [None]:
# Collapse every slice of each test scan into a single averaged image per MRI
# type and save it to the working directory as "<scan id>_<mri type>.png".
# (`glob`, `tqdm`, `os`, `cv2` come from the import cell at the top.)
TEST_SCAN_ROOT = "../input/rsna-miccai-png/test"
EXPORT_IMG_SIZE = 384
mri_types = ['T1w', 'T1wCE', 'T2w', 'FLAIR']

# The list of scan directories does not depend on the MRI type, so glob it once.
scan_dirs = sorted(glob(os.path.join(TEST_SCAN_ROOT, "*")))

for mri_type in mri_types:
    for scan_dir in tqdm(scan_dirs, desc=mri_type):
        # e.g. scan_dir = ../input/rsna-miccai-png/test/00001 -> scan_id = "00001"
        scan_id = os.path.basename(scan_dir)
        img = load_slices_3d(os.path.join(scan_dir, mri_type), num_imgs=1, img_size=EXPORT_IMG_SIZE)
        out_path = f'{scan_id}_{mri_type}.png'
        # cv2.imwrite returns False instead of raising when it cannot write.
        if not cv2.imwrite(out_path, img[0]):
            raise OSError(f"failed to write {out_path}")
        

In [None]:
# Spot-check one exported image. cv2.imread returns None (rather than raising)
# for a missing or unreadable file, so fail with a clear message first. The file
# was written from a single-channel array, so the 3 identical BGR channels that
# cv2 returns display as gray.
preview = cv2.imread('./00702_T1wCE.png')
assert preview is not None, "./00702_T1wCE.png not found or unreadable - run the export cell first"
plt.imshow(preview)
plt.title('00702_T1wCE')
plt.axis('off');

In [None]:
# Spot-check one exported image. cv2.imread returns None (rather than raising)
# for a missing or unreadable file, so fail with a clear message first. The file
# was written from a single-channel array, so the 3 identical BGR channels that
# cv2 returns display as gray.
preview = cv2.imread('./00592_T1wCE.png')
assert preview is not None, "./00592_T1wCE.png not found or unreadable - run the export cell first"
plt.imshow(preview)
plt.title('00592_T1wCE')
plt.axis('off');

In [None]:
# Spot-check one exported image. cv2.imread returns None (rather than raising)
# for a missing or unreadable file, so fail with a clear message first. The file
# was written from a single-channel array, so the 3 identical BGR channels that
# cv2 returns display as gray.
preview = cv2.imread('./00256_T1wCE.png')
assert preview is not None, "./00256_T1wCE.png not found or unreadable - run the export cell first"
plt.imshow(preview)
plt.title('00256_T1wCE')
plt.axis('off');

In [None]:
# Final marker cell: prints a fixed line when execution reaches the end of the notebook.
print('OK', end='\n')