In [None]:
import glob
import math
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tqdm import tqdm


In [None]:
# Load the FLAIR array stored in patient folder 00000 of the voxel dataset
# so it can be inspected visually in the next cell.
flair_path = '/kaggle/input/rsna-miccai-voxel-256-dataset/voxel/train/00000/FLAIR.npy'
img = np.load(flair_path)

In [None]:
# Show the slices with first-axis index 100..255, one grayscale figure each.
first_slice, end_slice = 100, 256
for slice_idx in range(first_slice, end_slice):
    plt.imshow(img[slice_idx, :, :], cmap='gray')
    plt.show()

In [None]:
class DataLoader(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields per-patient ``.npy`` arrays grouped by modality.

    Every entry directly under ``base_dir`` is treated as one patient folder
    and is expected to contain one ``<MODALITY>.npy`` file for each requested
    modality.
    """

    def __init__(self,base_dir='/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/test',
                 mods=None,batch_size=1):
        """Collect the patient folders and remember which modalities to load.

        Args:
            base_dir: Directory whose direct children are patient folders.
            mods: Iterable of modality names (file stems); defaults to
                ``['FLAIR']`` when omitted.
            batch_size: Number of patient folders returned per batch.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        # Let the Keras base class initialise itself; this is a no-op when the
        # base class defines no __init__ of its own.
        super().__init__()
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        self.batch_size = batch_size
        self.base_dir = base_dir
        self.pat_ids = sorted(glob.glob(os.path.join(base_dir, '*')))
        # Copy into a fresh list so instances never share (or mutate) a
        # default argument object.
        self.modalities = ['FLAIR'] if mods is None else list(mods)
        print('PAT IDS:',len(self.pat_ids),' | Modalities:',self.modalities)

    def __getitem__(self,index):
        """Load batch number ``index``.

        Returns:
            dict mapping each modality to ``{'images': [...], 'ids': [...]}``,
            where ``images`` holds the arrays loaded with ``np.load`` and
            ``ids`` the matching patient folder names.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        n_batches = len(self)
        if index < 0:
            index += n_batches
        # Raising IndexError lets plain sequence-protocol iteration terminate
        # instead of yielding empty batches forever.
        if not 0 <= index < n_batches:
            raise IndexError(f'batch index out of range (have {n_batches} batches)')

        # Batches are consecutive, non-overlapping windows over the sorted folders.
        start = index * self.batch_size
        batch_patids = self.pat_ids[start:start + self.batch_size]

        all_images = {mod: {'images': [], 'ids': []} for mod in self.modalities}
        for patid in batch_patids:
            # Last path component, tolerating Windows-style separators.
            pat_name = patid.replace('\\','/').split('/')[-1]
            for mod, entry in all_images.items():
                entry['images'].append(np.load(os.path.join(patid, mod+'.npy')))
                entry['ids'].append(pat_name)
        return all_images

    def __len__(self):
        """Number of batches, counting a final partial batch."""
        # Ceiling division so trailing patients are not silently dropped.
        return (len(self.pat_ids) + self.batch_size - 1) // self.batch_size

In [None]:
class StageOne:
    """First-stage slice filter driven by a Keras classifier.

    For every volume in a batch, the slices whose top-scoring class is 1 with
    a score of at least ``score_min_thresh`` are kept. If the classifier keeps
    nothing for a volume, the middle part of the volume is returned instead.
    """

    def __init__(self,modelpath='',height=256,width=256,score_min_thresh=0.5):
        """Load the classifier and store the filtering settings.

        Args:
            modelpath: Path handed to ``load_model``.
            height: Height each slice is resized to before prediction.
            width: Width each slice is resized to before prediction.
            score_min_thresh: Minimum class-1 score for a slice to be kept.
        """
        self.model = load_model(modelpath)
        self.score_min_thresh = score_min_thresh
        self.height = height
        self.width = width
        # Fraction of slices trimmed from each end of a volume when the
        # classifier keeps no slice at all.
        self.offset_perc = 0.1

    def infer(self,image_batch,filter_batchsize=16):
        """Filter every volume of every modality in ``image_batch``.

        Args:
            image_batch: dict mapping modality name to
                ``{'images': [volume, ...], 'ids': [patient_id, ...]}``.
            filter_batchsize: Number of slices sent to the model per
                ``predict`` call.

        Returns:
            dict with the same keys and layout as ``image_batch`` in which each
            volume is replaced by the list of slices that were kept. An empty
            input yields an empty dict.

        Raises:
            ValueError: If ``filter_batchsize`` is smaller than 1.
        """
        if filter_batchsize < 1:
            raise ValueError(f'filter_batchsize must be >= 1, got {filter_batchsize}')

        filtered_images = {}
        for modality, entry in image_batch.items():
            kept = {'images': [], 'ids': []}
            for volume, patid in zip(entry['images'], entry['ids']):
                kept['images'].append(self._filter_volume(volume, patid, filter_batchsize))
                kept['ids'].append(patid)
            filtered_images[modality] = kept
        # Return only after ALL modalities were processed (returning inside
        # the loop would stop after the first one).
        return filtered_images

    def _filter_volume(self, volume, patid, filter_batchsize):
        """Return the list of slices of one volume accepted by the classifier."""
        n_slices = len(volume)
        n_chunks = math.ceil(n_slices / filter_batchsize)
        progress = tqdm(range(0, n_slices, filter_batchsize), total=n_chunks, position=0, leave=True)
        progress.set_description(f'{patid}|Filtering')

        kept_slices = []
        for start in progress:
            chunk = volume[start:start + filter_batchsize]
            # The model sees resized, /255-scaled copies; the untouched
            # original slices are what gets kept.
            model_input = np.array([cv2.resize(img, (self.width, self.height)) / 255. for img in chunk])
            scores = self.model.predict(model_input)
            top_class = np.argmax(scores, axis=1)
            for j, cls in enumerate(top_class):
                if cls == 1 and scores[j][cls] >= self.score_min_thresh:
                    kept_slices.append(chunk[j])

        if not kept_slices:
            # Fallback: drop `offset_perc` of the slices at each end and keep
            # the rest, as a list so the return type matches the normal path.
            offset = math.ceil(n_slices * self.offset_perc)
            kept_slices = list(volume[offset:n_slices - offset])
        return kept_slices
              

In [None]:
# Modalities to load from each patient folder.
mods = ['FLAIR']
# Pass `mods` explicitly so that editing the list above actually changes what
# the loader reads (it was previously defined but never used).
generator = DataLoader(base_dir='/kaggle/input/rsna-miccai-voxel-256-dataset/voxel/train/', mods=mods)
stage_one = StageOne(modelpath='/kaggle/input/models/FINAL_MODELALL_acc0.9825_ep26.h5')

In [None]:
# Run the stage-one filter over every batch the generator produces and print
# the shape of the FLAIR slices that were kept for that batch.
all_results = {}
progress = tqdm(enumerate(generator), total=len(generator), position=0, leave=True)
progress.set_description('Loading_test')
for batch_idx, sample in progress:
    filtered_images = stage_one.infer(sample)
    flair_kept = filtered_images['FLAIR']['images']
    print(np.array(flair_kept).shape)

In [None]:
# Display every slice kept for the first volume of the last processed batch.
kept_slices = filtered_images['FLAIR']['images'][0]
for kept_slice in kept_slices:
    plt.imshow(kept_slice, cmap='gray')
    plt.show()

In [None]:
# Shape of the first FLAIR array in the last batch taken from the generator.
first_volume = sample['FLAIR']['images'][0]
first_volume.shape