<a href="https://colab.research.google.com/github/raynardj/python4ml/blob/master/experiments/mri_find_brain_tumor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A study on MRI data
> First we download the data, which is published on [this figshare page](https://figshare.com/articles/dataset/brain_tumor_dataset/1512427)

In [1]:
!wget https://ndownloader.figshare.com/articles/1512427/versions/5

--2021-01-06 10:39:16--  https://ndownloader.figshare.com/articles/1512427/versions/5
Resolving ndownloader.figshare.com (ndownloader.figshare.com)... 63.32.121.244, 108.128.58.52, 52.208.116.143, ...
Connecting to ndownloader.figshare.com (ndownloader.figshare.com)|63.32.121.244|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 879501695 (839M) [application/zip]
Saving to: ‘5’


2021-01-06 10:39:42 (33.2 MB/s) - ‘5’ saved [879501695/879501695]



In [3]:
!mv 5 1512427.zip

In [6]:
!unzip 1512427.zip

Archive:  1512427.zip
 extracting: brainTumorDataPublic_1-766.zip  
 extracting: brainTumorDataPublic_1533-2298.zip  
 extracting: brainTumorDataPublic_767-1532.zip  
 extracting: brainTumorDataPublic_2299-3064.zip  
 extracting: cvind.mat               
 extracting: README.txt              


In [4]:
!cat README.txt

This brain tumor dataset containing 3064 T1-weighted contrast-inhanced images
from 233 patients with three kinds of brain tumor: meningioma (708 slices), 
glioma (1426 slices), and pituitary tumor (930 slices). Due to the file size
limit of repository, we split the whole dataset into 4 subsets, and achive 
them in 4 .zip files with each .zip file containing 766 slices.The 5-fold
cross-validation indices are also provided.

-----
This data is organized in matlab data format (.mat file). Each file stores a struct
containing the following fields for an image:

cjdata.label: 1 for meningioma, 2 for glioma, 3 for pituitary tumor
cjdata.PID: patient ID
cjdata.image: image data
cjdata.tumorBorder: a vector storing the coordinates of discrete points on tumor border.
		For example, [x1, y1, x2, y2,...] in which x1, y1 are planar coordinates on tumor border.
		It was generated by manually delineating the tumor border. So we can use it to generate
		binary image of tumor mask.
c

In [None]:
!unzip brainTumorDataPublic_1533-2298.zip > /dev/null
!unzip brainTumorDataPublic_1-766.zip > /dev/null
!unzip brainTumorDataPublic_2299-3064.zip > /dev/null
!unzip brainTumorDataPublic_767-1532.zip > /dev/null

replace 1533.mat? [y]es, [n]o, [A]ll, [N]one, [r]ename: replace 1.mat? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
!ls -l *.mat |wc -l

In [None]:
!ls -l *.mat |head 

In [None]:
import scipy.io
import h5py

In [None]:
with h5py.File("1000.mat", "r") as f:
    print(f.keys())

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from glob import glob

In [None]:
def mat_to_data(filepath: str):
    """
    Read one slice of the brain tumor dataset from an HDF5-based ``.mat`` file.

    The file is expected to hold a ``cjdata`` group with the datasets
    'PID', 'image', 'label', 'tumorBorder' and 'tumorMask'.

    Parameters
    ----------
    filepath: path of the ``.mat`` file to open (read-only)

    Returns
    -------
    tuple ``(img, pid, label, boarder, mask)``
        img     -- ``cjdata/image`` as a numpy array
        pid     -- patient ID decoded from character codes into a ``str``
        label   -- first element of ``cjdata/label``
                   (README: 1 meningioma, 2 glioma, 3 pituitary tumor)
        boarder -- ``cjdata/tumorBorder`` as a numpy array
        mask    -- ``cjdata/tumorMask`` as a numpy array
    """
    with h5py.File(filepath, "r") as f:
        cjdata = f['cjdata']
        img = np.array(cjdata['image'])
        # PID is stored as an array of character codes. Flatten it so that every
        # item is a scalar: iterating a 2-D array would hand 1-element arrays to
        # chr(), which is not guaranteed to accept them.
        pid_codes = np.array(cjdata['PID']).ravel()
        pid = "".join(chr(int(code)) for code in pid_codes)
        label = list(cjdata['label'])[0][0]
        border = np.array(cjdata['tumorBorder'])
        mask = np.array(cjdata['tumorMask'])
    return img, pid, label, border, mask

In [None]:
img, pid, label, boarder, mask  = mat_to_data("2000.mat")

In [None]:
!mkdir -p mri_data

In [None]:
def extract_file(filepath: str):
    """
    Convert one ``.mat`` slice into three ``.npy`` files under ``mri_data/``.

    The image, tumor mask and tumor border returned by ``mat_to_data`` are
    saved as ``<name>_img.npy``, ``<name>_mask.npy`` and ``<name>_bd.npy``,
    where ``<name>`` is the file name of ``filepath`` (directory stripped, so
    a path such as ``some/dir/1.mat`` still lands directly in ``mri_data/``).

    Parameters
    ----------
    filepath: path of the source ``.mat`` file

    Returns
    -------
    dict with keys ``pid``, ``img``, ``mask``, ``boarder`` (the three saved
    file names, relative to ``mri_data/``), ``label`` and ``shape`` (image
    dimensions joined by ``_``, e.g. ``"512_512"``).
    """
    import os

    img, pid, label, boarder, mask = mat_to_data(filepath)

    out_dir = "mri_data"
    # do not depend on a separate `mkdir` cell having been run first
    os.makedirs(out_dir, exist_ok=True)

    # only the file name takes part in the output names
    name = os.path.basename(filepath)
    img_file = f"{name}_img.npy"
    mask_file = f"{name}_mask.npy"
    boarder_file = f"{name}_bd.npy"

    for file_name, array in ((img_file, img), (mask_file, mask), (boarder_file, boarder)):
        np.save(f"{out_dir}/{file_name}", array)

    shape = "_".join(map(str, img.shape))
    return dict(pid=pid, img=img_file, mask=mask_file, boarder=boarder_file, label=label, shape=shape)

In [None]:
# Slice files are named <number>.mat; the digit pattern leaves out cvind.mat
# (presumably the cross-validation indices mentioned in the README).
# glob returns files in arbitrary order, so sort for a reproducible run.
files = sorted(glob("[0-9]*.mat"))

In [None]:
df = pd.DataFrame(list(extract_file(i) for i in tqdm(files)))

In [None]:
!du -sh mri_data/

In [None]:
df.sort_values(by=["pid","img"])

In [None]:
df["img_id"] = df.img.apply(lambda x:int(x.split('.')[0]))

In [None]:
df = df.query("shape=='512_512'").sort_values(by=["img_id"]).reset_index(drop=True)
df

In [None]:
import plotly.express as px
import plotly.graph_objects as go
from ipywidgets import interact
from PIL import Image

In [None]:
def vis_patient(pid):
    """
    Show an interactive slice viewer for all slices of one patient.

    Slices are looked up in the global ``df`` (columns ``pid``, ``img``,
    ``mask``, ``img_id``), loaded from ``mri_data/`` and ordered by ``img_id``.
    Each slice is rendered as an RGB image:
    red = tumor mask, green = image minus mask (clipped to [0, 1]), blue = image.
    Image intensities are divided by 1000 before display.

    Parameters
    ----------
    pid: patient id; it is compared as a string against ``df.pid``

    Raises
    ------
    ValueError: if ``df`` holds no slice for ``pid``
    """
    sub_df = df[df["pid"] == str(pid)].sort_values(by="img_id")
    if sub_df.empty:
        # np.stack on an empty list would fail with a much less helpful message
        raise ValueError(f"No slices found for patient id {pid!r}")

    img_names = list(sub_df["img"])
    img_arr = np.stack([np.load(f"mri_data/{name}") for name in img_names])\
        .astype(np.float32)/1000
    mask_arr = np.stack([np.load(f"mri_data/{name}") for name in sub_df["mask"]])\
        .astype(np.float32)

    @interact
    def show_mri(i = (1,len(img_arr))):
        # the slider is 1-based, the arrays are 0-based
        print(img_names[i-1])
        rgb_arr = np.stack([
          mask_arr[i-1],
          np.clip(img_arr[i-1]-mask_arr[i-1],0.,1.),
          img_arr[i-1],
        ], axis=-1)
        # float RGB data must lie in [0, 1]; clip explicitly instead of relying
        # on imshow's own clipping (and its warning)
        plt.imshow(np.clip(rgb_arr, 0., 1.))
        plt.show()

In [None]:
vis_patient('100360')

In [None]:
df.query("pid=='MR029209I'").sort_values(by="img_id")

In [None]:
plt.imshow(np.array(img))

In [None]:
plt.imshow(np.array(mask))

## Model Training

In [None]:
!pip install -q forgebox

In [None]:
!pip install -q pytorch-lightning

In [None]:
from forgebox.imports import *
import pytorch_lightning as pl

In [None]:
class mri_data(Dataset):
    """
    Map-style dataset over the extracted MRI slices.

    Parameters
    ----------
    df: one row per slice; the columns ``img`` and ``mask`` hold ``.npy`` file
        names relative to ``data_dir``, ``label`` holds the 1-based tumor class
        and ``pid`` the patient id (used by ``__repr__`` only)
    data_dir: directory holding the ``.npy`` files
    """
    def __init__(self, df: pd.DataFrame, data_dir: Path=Path("./mri_data")):
        super().__init__()
        # a fresh 0..n-1 index so that positions and index labels coincide
        self.df = df.reset_index(drop = True)
        self.data_dir = Path(data_dir)

    def __len__(self):
        """Number of slices."""
        return len(self.df)

    def __repr__(self):
        return f"MRI Dataset:\n\t{len(self.df.pid.unique())} patients, {len(self)} slices"

    def __getitem__(self,idx):
        """
        Load slice ``idx``.

        Returns
        -------
        tuple ``(img, mask, label)``: the image divided by its maximum, with a
        leading channel axis; the mask with a leading channel axis; and the
        label shifted to start at 0.
        """
        row = dict(self.df.iloc[idx])
        img = np.load(str(self.data_dir/(row["img"])))
        peak = img.max()
        # an all-zero slice would otherwise divide by zero and feed NaNs to the model
        img = img/(peak if peak != 0 else 1)
        mask = np.load(str(self.data_dir/(row["mask"])))
        return img[None, ...], mask[None, ...], row['label']-1


def split_by(
    df: pd.DataFrame,
    col: str,
    val_ratio: float=.2,
    seed: int=None,
):
    """
    Split ``df`` into train and validation frames by the values of ``col``,
    so that no value of ``col`` (e.g. a patient id) appears on both sides.

    Parameters
    ----------
    df: frame to split
    col: name of the grouping column
    val_ratio: fraction of the unique ``col`` values sent to validation
        (rounded down)
    seed: optional integer seed; when given, the split and the shuffling of the
        training rows are fully reproducible. When ``None`` the global
        ``np.random`` state is used, as before.

    Returns
    -------
    tuple ``(train_df, val_df)``, both with a fresh index; the training rows
    are shuffled, the validation rows keep their original order.
    """
    # unique() keeps order of first appearance; a set() of strings is ordered
    # by randomized hashes, which made the split differ between kernel sessions
    uniques = np.asarray(df[col].unique())
    n_val = int(len(uniques)*val_ratio)

    rng = np.random if seed is None else np.random.RandomState(seed)
    validation_ids = rng.choice(uniques, size=n_val, replace=False)

    val_slice = df[col].isin(validation_ids)
    train_df = df[~val_slice].sample(
        frac=1., random_state=None if seed is None else rng)
    return train_df.reset_index(drop=True),\
        df[val_slice].reset_index(drop=True)

In [None]:
train_df, val_df = split_by(df, "pid")

In [None]:
len(train_df), len(val_df)

In [None]:
len(train_df.pid.unique()), len(val_df.pid.unique())

In [None]:
total_ds = mri_data(df)
train_ds = mri_data(train_df)
val_ds = mri_data(val_df)

In [None]:
train_ds, val_ds

In [None]:
x,y,z = train_ds[5]

### Mean and standard deviation
> Of the entire dataset

In [None]:
# per-slice [mean, std] of the normalized image
all_x = []
for i in tqdm(range(len(total_ds))):
    # use a dedicated name here: `x` holds the sample that is later fed to the
    # model in the pipeline smoke test and must not be overwritten by this loop
    slice_img = total_ds[i][0]
    all_x.append(np.array([slice_img.mean(), slice_img.std()]))

In [None]:
all_arr = np.array(all_x)   # column 0: per-slice mean, column 1: per-slice std
x_mean = all_arr[:, 0].mean()
# Dataset-wide std, not the average of the per-slice stds:
# var = E[std_i^2 + mean_i^2] - mean^2. This holds because every slice has the
# same number of pixels (df was filtered to shape 512_512).
x_std = np.sqrt((all_arr[:, 1]**2 + all_arr[:, 0]**2).mean() - x_mean**2)
x_mean, x_std

In [None]:
all_arr[:,0].min(), all_arr[:,0].max(),all_arr[:,1].min(), all_arr[:,1].max()

## Model

In [None]:
!pip install -q segmentation-models-pytorch

In [None]:
import segmentation_models_pytorch as smp

In [None]:
model = smp.Unet(
    "efficientnet-b5",
    encoder_weights="imagenet",
    in_channels=1,                  # model input channels (1 for grayscale images, 3 for RGB, etc.)
    classes=1, 
    )

### Test model pipeline

In [None]:
model(torch.FloatTensor(x)[None,...]).shape

In [None]:
??pl.LightningModule

In [None]:
class PlData(pl.LightningDataModule):
    """
    Lightning data module wrapping the train / validation slice frames.

    Both frames are turned into ``mri_data`` datasets at construction time;
    ``bs`` is the batch size used by both loaders.
    """

    def __init__(self, train_df, val_df, bs):
        super().__init__()
        self.bs = bs
        self.train_df, self.val_df = train_df, val_df
        self.train_ds, self.val_ds = mri_data(train_df), mri_data(val_df)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader over ``dataset`` with the configured batch size."""
        return DataLoader(dataset, shuffle=shuffle, batch_size=self.bs)

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_loader(self.train_ds, True)

    def val_dataloader(self):
        """Loader over the validation dataset, in fixed order."""
        return self._make_loader(self.val_ds, False)

In [None]:
class PlMRIModel(pl.LightningModule):
    """
    Lightning wrapper that trains a segmentation network on (image, mask) pairs.

    ``base_model`` maps an image batch to per-pixel logits; the loss is
    ``BCEWithLogitsLoss`` against the mask, and pixel accuracy is logged next
    to it. The third element of each batch (the class label) is not used.
    """
    def __init__(self, base_model):
        super().__init__()
        self.base = base_model
        self.sigmoid = nn.Sigmoid()
        self.crit = nn.BCEWithLogitsLoss()

    @staticmethod
    def accuracy_f(probs, target, threshold: float=0.5):
        """
        Fraction of pixels whose thresholded prediction matches the mask.

        Computed with plain tensor ops: ``pl.metrics`` no longer exists in
        current PyTorch Lightning releases, and a single stateful metric object
        shared by the train and validation steps would mix their statistics.
        Assumes the mask is binary (0/1) -- TODO confirm.
        """
        return ((probs >= threshold) == (target >= threshold)).float().mean()

    def forward(self, x):
        """Return the raw logits of the wrapped model."""
        return self.base(x)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.base.parameters(), lr=1e-4)

    def _shared_step(self, batch, stage: str):
        """Forward pass, loss and accuracy for one batch; logs ``<stage>_loss`` / ``<stage>_acc``."""
        x, y, z = batch
        x = x.float(); y = y.float()
        logits = self(x)
        loss = self.crit(logits, y)
        acc = self.accuracy_f(self.sigmoid(logits), y)

        self.log(f'{stage}_loss', loss)
        self.log(f'{stage}_acc', acc)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')

In [None]:
pl_data = PlData(train_df, val_df, bs=8)
pl_model = PlMRIModel(model)

In [None]:
logger = pl.loggers.TensorBoardLogger("tb_log")
# val_acc is a "higher is better" quantity, so state the direction explicitly
# instead of relying on the callback's default mode
early = pl.callbacks.EarlyStopping(monitor="val_acc", mode="max")

In [None]:
trainer = pl.Trainer(
    logger=logger,
    callbacks=[early,],
    gpus=1,
    fast_dev_run=True)

In [None]:
trainer.fit(pl_model,pl_data)