In [None]:
!pip install pyradiomics

In [None]:
import os
import sys 
from tqdm import tqdm 
import numpy as np
import pandas as pd
from PIL import Image
import pydicom
import torch
import nibabel as nib
import matplotlib.pyplot as plt
import SimpleITK as sitk
import radiomics
import cv2
from PIL import Image
import pydicom as dicom
from pydicom.dataset import Dataset, FileDataset
from pydicom.uid import ExplicitVRLittleEndian
import pydicom._storage_sopclass_uids
from pydicom.uid import RLELossless
import tensorflow as tf

# Directory whose entries are listed below as the patient ids of the training set.
train_path = '../input/rsna-miccai-brain-tumor-radiogenomic-classification/train/'

In [None]:
# Patient ids: the entry names found under train_path, sorted in place.
train_dirs = os.listdir(train_path)
train_dirs.sort()

# One shared SimpleITK reader that loads a whole DICOM series as a single image;
# private DICOM tags are loaded along with the standard ones.
reader = sitk.ImageSeriesReader()
reader.LoadPrivateTagsOn()

In [None]:
# Sanity check: number of patient entries found under train_path.
print(len(train_dirs))

In [None]:
def resample(image, ref_image):
    """Resample ``image`` onto the voxel grid of ``ref_image``.

    The output takes its size, spacing, origin and direction from
    ``ref_image``. Intensities are interpolated linearly and no spatial
    transform is applied (a default-constructed affine transform is used).

    Args:
        image: SimpleITK image to resample.
        ref_image: SimpleITK image that defines the output grid.

    Returns:
        The resampled SimpleITK image.
    """
    resampler = sitk.ResampleImageFilter()
    # SetReferenceImage sets the output size, spacing, origin and direction
    # from the reference, so they do not have to be set one by one.
    resampler.SetReferenceImage(ref_image)
    resampler.SetInterpolator(sitk.sitkLinear)
    resampler.SetTransform(sitk.AffineTransform(image.GetDimension()))
    # Voxels of the output grid that fall outside ``image`` are filled with 0.
    # The fill value used to be image.GetPixelIDValue(), which is the id of the
    # pixel *type*, not an intensity.
    resampler.SetDefaultPixelValue(0)

    return resampler.Execute(image)

In [None]:
def normalize(data):
    """Min-max scale ``data`` to the range [0, 1].

    Args:
        data: Non-empty array (or array-like) of numbers.

    Returns:
        ``(data - min) / (max - min)`` as an array. When all elements are
        equal the range is zero; in that case an all-zero array is returned
        instead of the NaNs a 0/0 division would produce.
    """
    lowest = np.min(data)
    span = np.max(data) - lowest
    shifted = data - lowest
    # Dividing by 1 for a constant input keeps the result dtype identical to
    # the regular path while avoiding the division by zero.
    return shifted / (span if span != 0 else 1)

In [None]:
%%time
def _read_series(index, mri_type):
    """Load the ``mri_type`` DICOM series of patient ``train_dirs[index]`` as one SimpleITK image."""
    filenamesDICOM = reader.GetGDCMSeriesFileNames(f'{train_path}/{train_dirs[index]}/{mri_type}')
    reader.SetFileNames(filenamesDICOM)
    return reader.Execute()


def get_img(index):
    """Build a three-channel image from the T1w, FLAIR and T1wCE series of one patient.

    FLAIR and T1wCE are resampled onto the T1w grid, every volume is min-max
    normalised on its own, and the slice in the middle of the first array axis
    of the T1w volume is taken from all three.

    Args:
        index: Position of the patient in ``train_dirs``.

    Returns:
        ``PIL.Image`` built from a ``uint8`` array of shape
        (height, width, 3) with the channels ordered T1w, FLAIR, T1wCE.
    """
    t1_sitk = _read_series(index, 'T1w')
    flair_sitk = _read_series(index, 'FLAIR')
    t1wce_sitk = _read_series(index, 'T1wCE')

    # Bring the other two series onto the T1w grid so the arrays can be stacked.
    flair_resampled = resample(flair_sitk, t1_sitk)
    t1wce_resampled = resample(t1wce_sitk, t1_sitk)

    t1_sitk_array = normalize(sitk.GetArrayFromImage(t1_sitk))
    flair_resampled_array = normalize(sitk.GetArrayFromImage(flair_resampled))
    t1wce_resampled_array = normalize(sitk.GetArrayFromImage(t1wce_resampled))

    stacked = np.stack([t1_sitk_array, flair_resampled_array, t1wce_resampled_array])

    # Middle slice of every channel, then channels-last for PIL.
    middle = t1_sitk_array.shape[0] // 2
    to_rgb = stacked[:, middle, :, :].transpose(1, 2, 0)
    return Image.fromarray((to_rgb * 255).astype(np.uint8))

In [None]:
# Pretrained U-Net fetched through torch.hub (3 input channels, 1 output channel).
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True)
# The notebook only runs forward passes with this network, so switch it to
# evaluation mode; this changes the behaviour of layers such as batch norm
# and dropout. eval() returns the module itself.
model = model.eval()

In [None]:
!pip3 install einops

In [None]:
def center_crop(img, dim):
    """Return the central region of ``img``.

    Args:
        img: Array whose first two axes are (height, width).
        dim: ``(width, height)`` of the requested crop. A side that is larger
            than the image is clamped to the image size.

    Returns:
        The slice of ``img`` with ``crop_height`` rows and ``crop_width``
        columns, centred on row ``height // 2`` and column ``width // 2``.
    """
    height, width = img.shape[0], img.shape[1]

    # Clamp the requested size to what the image can provide.
    crop_width = int(min(dim[0], width))
    crop_height = int(min(dim[1], height))

    # Start half a crop before the centre and extend by the full crop size.
    # Slicing mid-half:mid+half instead loses one row/column for odd sizes.
    left = width // 2 - crop_width // 2
    top = height // 2 - crop_height // 2
    return img[top:top + crop_height, left:left + crop_width]

def undesired_objects(image):
    """Keep only the largest connected foreground component of ``image``.

    The input is cast to ``uint8`` and labelled with 8-connectivity. Label 0
    is skipped when looking for the largest component. The result is appended
    to the module-level list ``truncated`` (callers create it before the
    call) and is also returned.

    Args:
        image: 2-D array; values that are non-zero after the ``uint8`` cast
            count as foreground, so float inputs should be binarised first.

    Returns:
        Float array with the shape of ``image`` holding 255 on the largest
        component and 0 everywhere else.

    Raises:
        ValueError: If ``image`` contains no foreground component.
    """
    image = image.astype('uint8')
    nb_components, output, stats, centroids = cv2.connectedComponentsWithStats(image, connectivity=8)
    if nb_components < 2:
        # Only label 0 exists; indexing sizes[1] used to fail here with an
        # IndexError that said nothing about the cause.
        raise ValueError("undesired_objects: image contains no foreground component")

    # The last column of stats is the component size; skip label 0.
    sizes = stats[1:, -1]
    # argmax returns the first maximum, like the strict '>' scan it replaces.
    max_label = 1 + int(np.argmax(sizes))

    img2 = np.zeros(output.shape)
    img2[output == max_label] = 255

    truncated.append(img2)
    return img2

# CROPPING THE IMAGE

In [None]:
# Disabled experiment kept as a bare string literal, so none of it executes:
# for 30 patients it centre-crops channel 1 of the input, runs the model on the
# crop and plots the mask and bounding box.
# NOTE(review): it calls get_img(int(train_dirs[ind])), i.e. it passes the
# patient id where the active loop passes the position `ind` - confirm before
# re-enabling.
'''
from einops import rearrange, reduce, repeat

for ind in range(0,30):    
    im = get_img(int(train_dirs[ind]))
    print(train_dirs[int(train_dirs[ind])])
    test_img = np.array([np.moveaxis(np.array(im.resize((256, 256))), -1, 0)])
    test_img = (center_crop(test_img[0, 1, :, :],(160,180)))
    test_img = cv2.resize(test_img, dsize=(256, 256), interpolation=cv2.INTER_CUBIC)
    test_img = cv2.merge((test_img, test_img, test_img))
    test_img = np.expand_dims(test_img, axis=0)
    input_tensor = torch.Tensor(test_img)
    input_tensor = rearrange(input_tensor, 'b h w c -> b c h w')
    print(input_tensor.shape)
    #test_img = np.array([np.moveaxis(np.array(im.resize((256, 256))), -1, 0)])
    #print(test_img.shape)
    
    
    test_res = model(input_tensor)

    #largest connected contour
    response_array = test_res.detach().numpy()
    response_array = np.reshape(response_array,(256,256))
    temp = response_array
    truncated = []
    undesired_objects(response_array)
    
    kernel = np.ones((5, 5), 'uint8')
    truncated[0] = cv2.dilate(truncated[0],kernel,iterations = 10)
    
    
    #draw bounding box on the largest contour
    truncated[0] = truncated[0].astype('uint8')
    bounding_boxes = []
    contours, hierarchy = cv2.findContours(truncated[0],cv2.RETR_LIST,cv2.CHAIN_APPROX_SIMPLE)[-2:]
    idx =0 
    for cnt in contours:
        idx += 1
        x,y,w,h = cv2.boundingRect(cnt)
        cv2.rectangle(truncated[0],(x,y),(x+w,y+h),(200,0,0),2)
    bounding_boxes.append(truncated[0])
    
    image_rectangles[int(train_dirs[ind])] = [x,y,w,h]
    


    f, axarr = plt.subplots(2,2, figsize=(20, 20))
    axarr[0][0].imshow(input_tensor[0, 1, :, :].detach().cpu().numpy())
    axarr[0][1].imshow(test_res[0][0].detach().cpu().numpy() > 0.5)
    axarr[1][0].imshow(truncated[0])


    #map the bounding box to the original image
    marked_image = input_tensor.detach().cpu().numpy()
    marked_image = np.ascontiguousarray(marked_image, dtype=np.uint8)
    print(type(marked_image))
    cv2.rectangle(marked_image[0, 1, :, :],(x,y),(x+w,y+h),(0,50,50),2)


    
    axarr[1][1].imshow(marked_image[0, 1, :, :])

'''

# NON-CROPPED IMAGE

In [None]:
def numpyToDicom(image2d,index,path):
    """Write a 2-D array to ``<path><index>.dcm`` as a 16-bit MR DICOM file.

    The array is cast to ``uint16`` and stored as single-sample MONOCHROME2
    pixel data. Patient, study and geometry attributes are fixed placeholder
    values and fresh UIDs are generated on every call.

    Args:
        image2d: 2-D array of pixel values (rows, columns).
        index: Value converted with ``str`` and used as the file name stem.
        path: Prefix the file name is appended to; it is concatenated as-is,
            so it must end with a path separator to name a directory.
    """
    image2d = image2d.astype(np.uint16)

    # Populate required values for file meta information
    meta = pydicom.Dataset()
    meta.MediaStorageSOPClassUID = pydicom._storage_sopclass_uids.MRImageStorage
    meta.MediaStorageSOPInstanceUID = pydicom.uid.generate_uid()
    meta.TransferSyntaxUID = pydicom.uid.ExplicitVRLittleEndian

    ds = Dataset()
    ds.file_meta = meta

    ds.is_little_endian = True
    ds.is_implicit_VR = False

    ds.SOPClassUID = pydicom._storage_sopclass_uids.MRImageStorage
    # The instance UID in the dataset has to match the one announced in the
    # file meta information; it was previously missing from the dataset.
    ds.SOPInstanceUID = meta.MediaStorageSOPInstanceUID
    ds.PatientName = "Test^Firstname"
    ds.PatientID = "123456"

    ds.Modality = "MR"
    ds.SeriesInstanceUID = pydicom.uid.generate_uid()
    ds.StudyInstanceUID = pydicom.uid.generate_uid()
    ds.FrameOfReferenceUID = pydicom.uid.generate_uid()

    ds.BitsStored = 16
    ds.BitsAllocated = 16
    ds.SamplesPerPixel = 1
    ds.HighBit = 15

    ds.ImagesInAcquisition = "1"

    ds.Rows = image2d.shape[0]
    ds.Columns = image2d.shape[1]
    ds.InstanceNumber = 1

    ds.ImagePositionPatient = r"0\0\1"
    ds.ImageOrientationPatient = r"1\0\0\0\-1\0"
    ds.ImageType = r"ORIGINAL\PRIMARY\AXIAL"

    ds.RescaleIntercept = "0"
    ds.RescaleSlope = "1"
    ds.PixelSpacing = r"1\1"
    ds.PhotometricInterpretation = "MONOCHROME2"
    # 0 = unsigned integers, matching the uint16 cast above. The value 1
    # (two's complement) made readers interpret values >= 32768 as negative.
    ds.PixelRepresentation = 0

    pydicom.dataset.validate_file_meta(ds.file_meta, enforce_standard=True)

    ds.PixelData = image2d.tobytes()
    ds.save_as(path + str(index)+'.dcm')

In [None]:
# Maps int(patient id) -> [x, y, w, h]; filled by the segmentation loop below
# and read back by load_img.
image_rectangles = {}


In [None]:
# For every patient: segment the three-channel middle slice, keep the largest
# connected region, dilate it and store its bounding box in image_rectangles.

# Responses above this value count as foreground.
# NOTE(review): presumably the network outputs a probability map in [0, 1] - confirm.
mask_threshold = 0.5

for ind in tqdm(range(len(train_dirs))):
    im = get_img(ind)

    # One 256x256 image, channels moved first, wrapped into a batch of one.
    test_img = np.array([np.moveaxis(np.array(im.resize((256, 256))), -1, 0)])
    with torch.no_grad():  # forward pass only, no gradients needed
        test_res = model(torch.Tensor(test_img))

    response_array = np.reshape(test_res.detach().numpy(), (256, 256))

    # Binarise before labelling: undesired_objects casts its input to uint8,
    # which would truncate every float response below 1 to 0.
    mask = response_array > mask_threshold

    if mask.any():
        # undesired_objects appends the largest component to `truncated`.
        truncated = []
        undesired_objects(mask)

        kernel = np.ones((5, 5), 'uint8')
        truncated[0] = cv2.dilate(truncated[0], kernel, iterations=10)

        # Bounding box of all pixels of the dilated region as x, y, width,
        # height (x/width along columns, y/height along rows). Taking it from
        # the mask avoids depending on which contour happens to come last.
        rows, cols = np.nonzero(truncated[0])
        x, y = int(cols.min()), int(rows.min())
        w, h = int(cols.max()) - x + 1, int(rows.max()) - y + 1
    else:
        # Nothing segmented: fall back to the whole 256x256 frame instead of
        # reusing the box of the previous patient.
        print(f'No foreground found for patient {train_dirs[ind]}; using the full frame')
        x, y, w, h = 0, 0, 256, 256

    image_rectangles[int(train_dirs[ind])] = [x, y, w, h]



In [None]:
# Number of patients for which a bounding box was stored.
len(image_rectangles)

In [None]:
# Accumulates the cropped slices appended by load_img.
st = []

In [None]:
%%time
def load_img(index, mri_type='T1wCE', out_size=(1024, 1024), num_slices=64):
    """Crop the central slices of one series to the patient's bounding box.

    The DICOM files of the series are read one by one. Each slice is resized
    to 256x256 (the size used when the bounding boxes were computed), cropped
    to the box stored in ``image_rectangles`` and resized to ``out_size``.
    Every resulting slice is appended to the module-level list ``st``.

    Args:
        index: Position of the patient in ``train_dirs``.
        mri_type: Name of the series sub-directory to read, e.g. 'T1w',
            'FLAIR', 'T1wCE' or 'T2w'.
        out_size: ``(width, height)`` of every output slice.
        num_slices: Maximum number of slices to take, centred on the middle
            file of the series; fewer are taken when the series is shorter.

    Raises:
        KeyError: If no bounding box is stored for the patient.
    """
    filenamesDICOM = reader.GetGDCMSeriesFileNames(f'{train_path}/{train_dirs[index]}/{mri_type}')

    x, y, w, h = image_rectangles[int(train_dirs[index])]

    # Window of num_slices files around the middle of the series, clamped to
    # the files that exist (64 slices -> mid-32 .. mid+32).
    total = len(filenamesDICOM)
    start = total // 2 - num_slices // 2
    first = max(start, 0)
    last = min(start + num_slices, total)

    for k in range(first, last):
        img = dicom.dcmread(filenamesDICOM[k]).pixel_array

        # The bounding box is expressed in 256x256 coordinates.
        img = cv2.resize(img, dsize=(256, 256), interpolation=cv2.INTER_CUBIC)
        # Array rows follow y/height and columns follow x/width. The crop used
        # to be img[x:x+w, y:y+w], which swapped the axes and ignored h.
        img = img[y:y+h, x:x+w]
        img = cv2.resize(img, dsize=tuple(out_size), interpolation=cv2.INTER_CUBIC)
        st.append(img)

In [None]:
import shutil

dir_path = './test_final'

# Start from an empty output directory. Only remove it when it exists, so a
# fresh run does not print a spurious "No such file or directory" error.
if os.path.isdir(dir_path):
    shutil.rmtree(dir_path)

os.makedirs(dir_path, exist_ok=True)

In [None]:
# Run load_img for every patient that has a stored bounding box, in train_dirs
# order, instead of relying on a hard-coded patient count.
patient_indices = [i for i, patient in enumerate(train_dirs) if int(patient) in image_rectangles]
for i in tqdm(patient_indices):
    load_img(i)

In [None]:
import shutil
# Zip the contents of ./test_final into Test_Final.zip.
# NOTE(review): the file-writing calls in load_img are commented out, so this
# directory looks empty at this point - confirm the archive is still needed.
shutil.make_archive('Test_Final','zip','./test_final')

In [None]:
# Total number of slices collected in st.
len(st)

In [None]:
# Stack the collected slices into one array, then print the shape of the first
# slice and display it.
stkk = np.array(st)
print(stkk[0].shape)
plt.imshow(stkk[0])

In [None]:
# Serialise every slice as raw bytes under the 'image' key, one Example per
# slice. The context manager closes the writer even if a write fails part-way.
# NOTE(review): neither shape nor dtype is stored, so readers must know both.
with tf.io.TFRecordWriter('./T1wce-TRAIN.tfrecord') as tfrecord_writer:
    for image in tqdm(stkk):
        example = tf.train.Example(features=tf.train.Features(feature={'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()]))}))
        tfrecord_writer.write(example.SerializeToString())

In [None]:
def decode_fn(record_bytes):
    """Parse one serialized Example into a dict holding its 'image' feature.

    Args:
        record_bytes: Serialized ``tf.train.Example``.

    Returns:
        Dict with the key "image" mapped to a scalar string tensor; the
        empty string is used when the feature is missing.
    """
    image_feature = tf.io.FixedLenFeature([], dtype=tf.string, default_value='')
    feature_spec = {"image": image_feature}
    # Further features that are currently disabled:
    #   'shape_h': tf.io.FixedLenFeature([], dtype=tf.int64, default_value=0)
    #   'shape_w': tf.io.FixedLenFeature([], dtype=tf.int64, default_value=0)
    #   'type': tf.io.FixedLenFeature([], dtype=tf.string, default_value='')
    return tf.io.parse_single_example(record_bytes, feature_spec)

In [None]:
# Dataset that applies decode_fn to every record of the file.
# NOTE(review): the only TFRecord written in the visible cells is
# './T1wce-TRAIN.tfrecord'; confirm that './T1w-FLAIR-T2w.tfrecord' exists.
parsed_dataset = tf.data.TFRecordDataset("./T1w-FLAIR-T2w.tfrecord").map(decode_fn)

In [None]:
# Decode only the first record (note the break) and print its shape.
# NOTE(review): assumes every record holds exactly 256*256 float32 values.
# load_img resizes its slices to 1024x1024 and the writer stores image.tobytes()
# without a dtype, so records written from `st` presumably do not match -
# confirm dtype and size against the file being read.
for parsed_record in tqdm(parsed_dataset):
    image = tf.io.decode_raw(parsed_record['image'], out_type=tf.float32)
    image = tf.reshape(image, [256, 256])
    print(image.shape)
    break

In [None]:
import shutil

dir_path = './T1WCE.tfrecord'

# Best-effort clean-up: failures are reported, not raised.
# shutil.rmtree only works on directories, so a plain file at this path was
# never deleted; remove files with os.remove instead.
# NOTE(review): this path differs from the './T1wce-TRAIN.tfrecord' written
# earlier in the notebook - confirm which file is meant.
try:
    if os.path.isdir(dir_path):
        shutil.rmtree(dir_path)
    else:
        os.remove(dir_path)
except OSError as e:
    print("Error: %s : %s" % (dir_path, e.strerror))

In [None]:
# Delete the archive produced by shutil.make_archive; raises if the file is missing.
os.remove('./Test_Final.zip')