In [14]:
import os
import glob
import numpy as np
import pandas as pd
import radio
from radio.batchflow import FilesIndex, Dataset, Pipeline
from radio import CTImagesMaskedBatch as CTIMB
from radio_utils import show_slices, get_nodules_pixel_coords, num_of_cancerous_pixels
from radio.pipelines import split_dump

from config import config
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [32, 32]

# import utils
import warnings
warnings.filterwarnings("ignore")

%matplotlib inline

In [15]:
# Location of the raw LUNA data, read from the project config.
PATH_TO_FOLDER = config['luna_raw']

# Preprocessed test scans, split by whether they contain cancerous nodules.
cancerous_folder = '../jfr_pp/test/positive'
non_cancerous_folder = '../jfr_pp/test/negative'

# One FilesIndex per folder; dirs=True indexes the matching sub-directories
# (presumably one directory per scan).
cancer_index, non_cancer_index = (
    FilesIndex(path=os.path.join(folder, '*'), dirs=True)
    for folder in (cancerous_folder, non_cancerous_folder)
)

# Wrap each index in a Dataset whose batches are CTImagesMaskedBatch objects.
cancer_set = Dataset(cancer_index, batch_class=CTIMB)
non_cancer_set = Dataset(non_cancer_index, batch_class=CTIMB)

from radio.pipelines import combine_crops
# Pipeline mixing crops from both datasets (batch_sizes=(5, 5)).
crops_sampling = combine_crops(cancer_set, non_cancer_set, batch_sizes=(5, 5))

In [16]:
from radio.models import Keras3DUNet
from radio.models.keras.losses import dice_loss

# Keras3DUNet settings: one-channel 64x64x64 crops (channels-first, matching the
# data_format used when unpacking batches for training), a single target mask,
# optimised with the dice loss.
unet_config = {
    'input_shape': (1, 64, 64, 64),
    'num_targets': 1,
    'loss': dice_loss,
}

In [17]:
from radio.batchflow import F

# Named expressions evaluated per batch: unpack images (inputs) and masks
# (targets) as channels-first arrays for the Keras model.
images_expr = F(CTIMB.unpack, component='images', data_format='channels_first')
masks_expr = F(CTIMB.unpack, component='masks', data_format='channels_first')

# Mix cancerous / non-cancerous crops (batch_sizes=(4, 6)), create the 3D U-Net
# under the name '3dunet', and train it on every batch the pipeline yields.
train_unet_pipeline = (
    combine_crops(cancer_set, non_cancer_set, batch_sizes=(4, 6))
    .init_model(name='3dunet', model_class=Keras3DUNet, config=unet_config, mode='static')
    .train_model(name='3dunet', x=images_expr, y=masks_expr)
)

In [18]:
%%time

N_ITERS = 1

# workflow = train_unet_pipeline << 
for i in range (N_ITERS):
    print(i)
    train_unet_pipeline.next_batch(1)
# train_unet_pipeline.run()
keras_unet = train_unet_pipeline.get_model_by_name('3dunet')
keras_unet.save('./model/test.h5')

0
CPU times: user 3min 38s, sys: 17.5 s, total: 3min 56s
Wall time: 59.4 s
