In [None]:
!pip install --quiet attrdict

In [None]:
from attrdict import AttrDict
from tqdm import tqdm
import os
import glob 
import numpy as np
from PIL import Image
import cv2 as cv
from matplotlib import pyplot as plt
import math
import numpy as np

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Notebook-wide configuration. Each sub-dict is unpacked as keyword arguments:
# metadata -> load_metadata, dataset -> Dataset3d, dataloader -> DataLoader.
cfg = AttrDict(
    metadata=AttrDict(
        images_path="/kaggle/input/rsna-miccai-png",
        labels_path='/kaggle/input/training-labels/train_labels.csv',
        val_mod=3,        # training scans whose id hash is divisible by val_mod go to validation (about 1/val_mod of them)
        limit_first_k=None, # if set, only the first k scan folders of train/ and of test/ are indexed; None loads everything
    ),
    dataset=AttrDict(
        input_keys=['T1wCE'],  # modality folders loaded for every scan
        img_size=128,     # every slice is resized to img_size x img_size
    ),
    dataloader=AttrDict(
        batch_size=1,
        num_workers=8,
    )
)

In [None]:
def load_metadata(images_path: str, labels_path: str, val_mod=3, limit_first_k=None):
    """Index the PNG slices of every scan and split them into train/val/test.

    Expected layout:
    ``{images_path}/{train|test}/{scan_id}/{kind}/<name>-<slice>.png``.

    Args:
        images_path: Root folder holding the ``train`` and ``test`` sub-folders.
        labels_path: CSV file with a header row followed by ``scan_id,label`` rows.
        val_mod: A training scan goes to the validation split when a stable
            checksum of its id is divisible by ``val_mod`` (about ``1/val_mod``
            of the scans). ``None`` or ``0`` disables the validation split.
        limit_first_k: If truthy, only the first ``k`` scan folders (in sorted
            order) of ``train`` and of ``test`` are indexed.

    Returns:
        AttrDict with keys ``train``, ``val`` and ``test``. Each value is a list
        of dicts holding ``scan`` (folder name), ``label`` (int, or ``None``
        when the scan is not listed in the CSV) and, per kind folder, the list
        of slice paths ordered by slice number.
    """
    import zlib  # local import: only needed for the stable split checksum

    result = {'train': [], 'val': [], 'test': []}

    scan_id_to_label = {}
    with open(labels_path, 'r') as f:
        for i, line in enumerate(f):
            # Skip the header row and blank lines (e.g. a trailing newline).
            if i == 0 or not line.strip():
                continue
            idx, label = line.split(',')
            scan_id_to_label[idx.strip()] = int(label.strip())

    def slice_order(item):
        # Numeric ids sort by value ("2" before "10"); non-numeric ids keep
        # text order and come after the numeric ones.
        slice_id, filepath = item
        if slice_id.isdigit():
            return (0, int(slice_id), filepath)
        return (1, slice_id, filepath)

    for items_key in ['train', 'test']:
        # os.listdir order is arbitrary; sorting makes limit_first_k reproducible.
        all_files = sorted(os.listdir(f"{images_path}/{items_key}"))
        if limit_first_k:
            all_files = all_files[:limit_first_k]

        for scan_id in tqdm(all_files):
            scan_slices = {}

            for filepath in glob.glob(f"{images_path}/{items_key}/{scan_id}/*/*.png"):
                kind = os.path.basename(os.path.dirname(filepath))
                slice_id = os.path.basename(filepath).split('-')[-1].split('.')[0]
                scan_slices.setdefault(kind, []).append((slice_id, filepath))

            for kind, slices in scan_slices.items():
                scan_slices[kind] = [path for _, path in sorted(slices, key=slice_order)]

            key = items_key
            # The built-in hash() of a str is salted per process, so it would
            # produce a different split on every kernel restart; crc32 is stable.
            if (items_key == 'train' and val_mod
                    and zlib.crc32(scan_id.encode('utf-8')) % val_mod == 0):
                key = 'val'

            result[key].append({
                'scan': scan_id,
                'label': scan_id_to_label.get(scan_id),
                **scan_slices,
            })
    return AttrDict(result)

In [None]:
# Build the scan index, then report how many scans landed in each split.
metadata = load_metadata(**cfg.metadata)
split_report = " | ".join(
    f"{len(getattr(metadata, split))} {split}" for split in ("train", "val", "test")
)
print(split_report)
# print(metadata.train[1])

In [None]:
class Dataset3d(torch.utils.data.Dataset):
    """Map-style dataset that yields one scan (a stack of 2-D slices) per item.

    Args:
        metadata: Sequence of per-scan mappings with ``scan``, ``label`` and,
            per modality, a list of slice image paths.
        input_keys: Modalities to load for each item (e.g. ``['T1wCE']``).
        img_size: Side length every slice is resized to; ``None`` keeps the
            stored resolution.
    """

    # Stand-in for a missing label: the default DataLoader collation cannot
    # batch ``None``, so unlabeled scans carry this value instead.
    UNLABELED = -1

    def __init__(self, metadata, input_keys, img_size=None):
        super().__init__()

        self.metadata = metadata
        self.input_keys = input_keys
        self.img_size = img_size

    def load(self, idx, prop):
        """Load every slice of modality ``prop`` for scan ``idx``.

        Pixels are rescaled with ``(p / 255 - 0.5) * 2``, i.e. to [-1, 1] for
        8-bit images (assumes the PNGs are 8-bit — TODO confirm).

        Returns:
            float32 array of shape ``(n_slices, H, W)``. A scan without that
            modality yields an empty array with ``n_slices == 0``.
        """
        side = self.img_size
        filenames = self.metadata[idx].get(prop) or []

        volume = []
        for filename in filenames:
            # The context manager closes the file handle right after decoding.
            with Image.open(filename) as img_file:
                img = np.asarray(img_file, dtype=np.float32)
            img = (img / 255 - 0.5) * 2
            if side is not None:
                img = cv.resize(img, (side, side),
                                interpolation=cv.INTER_NEAREST)
            volume.append(img)

        if not volume:
            empty_side = side or 0
            return np.zeros((0, empty_side, empty_side), dtype=np.float32)
        return np.array(volume)

    def __len__(self):
        """Number of scans in this split."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return ``{'label', 'scan', <one volume per input key>}`` for scan ``idx``."""
        record = self.metadata[idx]
        label = record['label']

        result = {
            'label': self.UNLABELED if label is None else label,
            'scan': record['scan'],
        }
        for key in self.input_keys:
            result[key] = self.load(idx, key)

        return result

In [None]:
# One Dataset3d per split, all built with the options from cfg.dataset.
dataset = AttrDict({
    split: Dataset3d(getattr(metadata, split), **cfg.dataset)
    for split in ('train', 'val', 'test')
})

In [None]:
# item = dataset.train[2]
# plt.figure(figsize=(5 * len(cfg.dataset.input_keys), 5))
# print(f"Scan: {item['scan']} (label: {item['label']})")
# for i, key in enumerate(cfg.dataset.input_keys):
#     print(f"  {key}: {item[key].shape}")
#     plt.subplot(1, len(cfg.dataset.input_keys), i + 1)
#     plt.imshow(item[key][10])
# plt.show()

In [None]:
# One loader per split. Only the training loader shuffles; validation and test
# keep a fixed order. Each loader wraps its OWN split (val and test previously
# re-used the training set, so evaluation would have run on training data).
# NOTE(review): scans absent from the labels CSV (presumably all test scans)
# have label None, which default collation rejects unless the dataset
# substitutes a value — confirm before iterating `dataloader.test`.
dataloader = AttrDict(
    train=torch.utils.data.DataLoader(dataset.train, shuffle=True, **cfg.dataloader),
    val=torch.utils.data.DataLoader(dataset.val, shuffle=False, **cfg.dataloader),
    test=torch.utils.data.DataLoader(dataset.test, shuffle=False, **cfg.dataloader),
)

In [None]:
# for batch in dataloader.train:
#     print(batch)
#     break

In [None]:
# print(len(metadata.train[2]['FLAIR']))

In [None]:
# Display the training DataLoader object as a quick check that it was built.
dataloader.train

### T1
Fat is depicted in white and water in black.<br/>
The shape of the brain can be clearly seen, and morphological abnormalities (atrophy, tumors, etc.) are easy to detect.<br/>

### T2
Water is painted white.<br/>
Lesions appear white. Suitable for lesion evaluation.<br/>

### FLAIR
In T2, both the cerebrospinal fluid (water) and lesions appear white, so a lesion has to be found as white on white, which makes it hard to spot.<br/>
FLAIR can be roughly thought of as a T2 image in which the fluid signal is suppressed (water appears black), making lesions easier to find.<br/>


### Observatie
 - T1w este T1 weighted pre-contrast
 - T1wCE este T1 weighted post-contrast
 - T2w este T2 weighted
 - FLAIR = Fluid Attenuated Inversion Recovery
 - fiecare folder contine tipuri diferite de RMN (contrastul difera)
 - NU exista o regula de orientare a scanarilor (de ex. T1w contine rmn in plan sagital, dar si in plan coronal sau orizontal)
 - Plan Sagital = stanga-dreapta
 - Plan Coronal = fata-spate
 - Plan Orizontal = sus-jos
<br/>
<br/>

[Link](https://case.edu/med/neurology/NR/MRI%20Basics.htm) explicatii la ce inseamna T1w, T2w, FLAIR.<br/>

## #1 try - One folder only

In [None]:
# plt.figure(figsize=(6 * len(cfg.dataset.input_keys),6))
# for i in range(4):
#     plt.subplot(1, len(cfg.dataset.input_keys), i + 1)
#     plt.imshow(dataset.train[i]['T2w'][20])
# plt.show()


In [None]:
# patient = 0
# len(dataset.train[patient]['FLAIR'])

In [None]:
def get_images(dataset, patient: int, folder: str, train=True):
    """Return the non-blank slices of one scan.

    Args:
        dataset: Object exposing ``train`` and ``test`` datasets whose items
            map a modality name to a stack of 2-D slices.
        patient: Index of the scan inside the chosen split.
        folder: Modality key to read (e.g. ``'T1wCE'``).
        train: Read from ``dataset.train`` when truthy, else ``dataset.test``.

    Returns:
        list with the slices that are not blank, in their original order.
    """
    split = dataset.train if train else dataset.test
    # A blank slice is a constant image. Testing max against min instead of
    # max against 0 also catches blanks after the pixels have been rescaled
    # (Dataset3d maps them to [-1, 1], where an all-black slice is all -1 and
    # would never have a maximum of 0).
    return [img for img in split[patient][folder] if np.max(img) != np.min(img)]

In [None]:
# patient = 1

# images = get_images(dataset, patient, 'T2w')
# print('Nr of images:', len(images))

# fig = plt.figure(figsize=(50,50))

# c = 1
# for image in images:
#     ax = fig.add_subplot(len(images)//10+1, 10, c)
#     ax.imshow(image, cmap='gray')
#     c+=1
    
#     plt.axis('off')
    
# fig.tight_layout()

In [None]:
# Number of keys in the first training record: 'scan', 'label' and one entry
# per modality folder found for that scan.
len(metadata.train[0])

In [None]:
# Smallest and largest number of slices any training scan has for one modality.
label = 'T1wCE'
first_count = len(metadata.train[0][label])
max_images = min_images = first_count
for scan in metadata.train[1:]:
    n_slices = len(scan[label])
    max_images = max(max_images, n_slices)
    min_images = min(min_images, n_slices)

print(f'Min: {min_images}\nMax: {max_images}')

In [None]:
# Print only the first batch from the training loader, then stop iterating.
# Each batch is the dict built by Dataset3d.__getitem__ after collation.
for batch in dataloader.train:
    print(batch)
    break

In [None]:
# Number of T1wCE slices loaded for the first training scan.
len(dataset.train[0]['T1wCE'])

In [None]:
from tqdm import tqdm

# TODO: n/15

# Keep only the first 15 slices of every scan.
label = 'T1wCE'
small_dataset = []
for d in tqdm(dataloader.train):
    # A collated batch is shaped (batch, slices, H, W), so the slice axis is
    # dim 1; a plain `[:15]` would cut the batch axis and keep every slice.
    small_dataset.append(d[label][:, :15])

In [None]:
from tqdm import tqdm


def split_into_subvolumes(volume, depth=15, pad_value=0.0):
    """Split a (slices, H, W) scan into interleaved sub-volumes of ``depth`` slices.

    The scan is padded at the end with ``pad_value`` slices up to the next
    multiple of ``depth``; sub-volume ``i`` then takes every ``n_parts``-th
    slice starting at ``i``, so each sub-volume spans the whole scan.

    Args:
        volume: Tensor or array with the slice axis first.
        depth: Number of slices every sub-volume must have.
        pad_value: Fill value of the padding slices.
            NOTE(review): slices scaled to [-1, 1] have a black level of -1,
            so 0 is mid-gray there — confirm 0 is the intended filler.

    Returns:
        list of tensors shaped ``(depth, H, W)``; empty for a scan with no slices.
    """
    volume = torch.as_tensor(volume)
    n_slices = volume.shape[0]
    if n_slices == 0:
        return []

    n_parts = math.ceil(n_slices / depth)
    n_missing = n_parts * depth - n_slices
    if n_missing:
        filler = volume.new_full((n_missing, *volume.shape[1:]), pad_value)
        volume = torch.cat([volume, filler], dim=0)

    return [volume[start::n_parts] for start in range(n_parts)]


label = 'T1wCE'
small_dataset = []  # sub-volumes, each with exactly 15 slices
small_labels = []   # label of the scan every sub-volume was cut from
for d in tqdm(dataloader.train):
    # A collated batch is shaped (batch, slices, H, W): walk the batch axis so
    # the slice count is read per scan (len(d[label]) is the batch size).
    for volume, target in zip(d[label], d['label']):
        for sub_volume in split_into_subvolumes(volume):
            small_dataset.append(sub_volume)
            small_labels.append(target)

In [None]:
# Peek at the data: first entry of small_dataset, indexed twice more along its
# leading axes (what that selects depends on the shape stored by the cell above).
small_dataset[0][0][0]

In [None]:
class Convnet(nn.Module):
    """Two-block 3-D CNN that produces one logit per input volume.

    Expected input shape: ``(N, 1, depth, img_size, img_size)``.

    Args:
        depth: Number of slices in every input volume.
        img_size: Height and width of every slice.

    Raises:
        ValueError: If the input is too small to survive both conv/pool blocks.
    """

    def __init__(self, depth=15, img_size=128):
        super(Convnet, self).__init__()
        # The slice axis is padded by 2 so a 5-wide kernel keeps its length;
        # unpadded, 15 slices shrink 15 -> 11 -> 5 -> 1 and the second pooling
        # layer is left with nothing to pool. Height and width stay unpadded.
        # in-plane: 128 -> 124 -> 62
        # depth:     15 ->  15 ->  7
        self.conv1 = nn.Sequential(
            nn.Conv3d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=(2, 0, 0),
            ),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2)
        )
        # in-plane: 62 -> 58 -> 29
        # depth:     7 ->  7 ->  3
        self.conv2 = nn.Sequential(
            nn.Conv3d(
                in_channels=16,
                out_channels=32,
                kernel_size=5,
                stride=1,
                padding=(2, 0, 0),
            ),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2)
        )

        out_depth = depth // 2 // 2
        out_side = ((img_size - 4) // 2 - 4) // 2
        if out_depth < 1 or out_side < 1:
            raise ValueError(
                f"input of depth {depth} and size {img_size} is too small for Convnet"
            )
        # The flattened feature count includes the slice axis.
        self.out = nn.Linear(32 * out_depth * out_side * out_side, 1)

    def forward(self, x):
        """Return ``(logits, features)``: logits are ``(N, 1)``, features the flattened conv output."""
        x = self.conv1(x)
        x = self.conv2(x)

        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output, x  # x is returned for visualization

In [None]:
from torch import optim

cnn = Convnet()

# Convnet ends in a single output unit, i.e. one logit for a yes/no target.
# BCEWithLogitsLoss applies the sigmoid itself and expects float targets with
# the same shape as the logits. CrossEntropyLoss would treat that one unit as a
# one-class softmax, whose loss carries no gradient signal.
loss_func = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(cnn.parameters(), lr = 0.001)

In [None]:
import numpy as np
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers

def get_model(width=128, height=128, depth=15):
    """Build a 3-D CNN (Keras) with one sigmoid output for binary classification.

    Four Conv3D -> MaxPool3D -> BatchNormalization blocks are followed by
    global average pooling and a small dense head.

    Every convolution and pooling layer uses ``padding="same"``. With the
    default "valid" padding each 3-wide convolution removes two voxels per
    axis, so the default depth of 15 goes 15 -> 13 -> 6 -> 4 -> 2 and the third
    convolution has no valid output; "same" padding only halves each axis per
    block (rounding up), which works for any positive size.

    Args:
        width, height, depth: Spatial size of the input volume.

    Returns:
        An uncompiled ``keras.Model`` taking ``(width, height, depth, 1)`` inputs.
    """
    inputs = keras.Input((width, height, depth, 1))

    x = inputs
    for filters in (64, 64, 128, 256):
        x = layers.Conv3D(filters=filters, kernel_size=3, activation="relu", padding="same")(x)
        x = layers.MaxPool3D(pool_size=2, padding="same")(x)
        x = layers.BatchNormalization()(x)

    x = layers.GlobalAveragePooling3D()(x)
    x = layers.Dense(units=512, activation="relu")(x)
    x = layers.Dropout(0.3)(x)

    outputs = layers.Dense(units=1, activation="sigmoid")(x)

    model = keras.Model(inputs, outputs)
    return model

In [None]:
model = get_model(128,128,15)

# A Keras model must be compiled before fit(); binary cross-entropy matches the
# single sigmoid output unit built by get_model.
model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

# FIXME(review): `train`, `folder`, `pacient` and `validation_dataset` are not
# defined in any cell shown here, and no targets are passed to fit(); this call
# cannot run until they are built from the loaded scans.
model.fit(
    train[folder][pacient],
    validation_data=validation_dataset,
    epochs=10,
)

In [None]:
def minMaxNormalize(volume, min_value=0, max_value=255):
    """Clip ``volume`` to ``[min_value, max_value]`` and rescale it to [0, 1].

    The input is left untouched; a new array is returned.

    Args:
        volume: Array-like of intensities.
        min_value: Intensity mapped to 0; lower values are clipped to it.
        max_value: Intensity mapped to 1; higher values are clipped to it.

    Returns:
        float32 array with the same shape as ``volume``.
    """
    # Work on a float64 copy so clipping never writes into the caller's array
    # and integer inputs cannot overflow against the bounds.
    clipped = np.clip(np.asarray(volume, dtype=np.float64), min_value, max_value)
    scaled = (clipped - min_value) / (max_value - min_value)
    return scaled.astype("float32")

In [None]:
# Training loop for the PyTorch Convnet.
# `net`, `criterion` and `trainloader` were not defined anywhere in this
# notebook, so they are bound here to the objects created in earlier cells;
# `optimizer` is the Adam instance built together with `cnn`.
net = cnn
trainloader = dataloader.train
# Single-logit output -> binary cross-entropy on logits (float targets).
criterion = nn.BCEWithLogitsLoss()
modality = cfg.dataset.input_keys[0]

# NOTE(review): the loader yields every slice of a scan, while Convnet's linear
# layer is sized for one fixed depth — confirm the volumes are cut or padded to
# that depth before training on them.
for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # Batches are dicts; add the channel axis Conv3d expects:
        # (N, D, H, W) -> (N, 1, D, H, W).
        inputs = data[modality].float().unsqueeze(1)
        labels = data['label'].float().unsqueeze(1)
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs, _ = net(inputs)  # Convnet returns (logits, features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0