## Run the Model on an Entire Image

The images used for training were cropped into smaller pieces to make annotation easier. Now that the model is trained, we can run it on the entire raw image. Z-stacks are large, so we use the wrapper function `process_whole_image`. It slices the raw image into several crops, runs the model on each crop, and stitches the predictions back into place.

### Load the raw images from disk

```python
from deepcell.utils.data_utils import load_training_images_3d

whole_images = load_training_images_3d(
    '/data/data/cells/MouseBrain/generic',
    training_direcs=['set0', 'set1'],
    num_frames=30,
    raw_image_direc='raw',
    channel_names=['DAPI'],
    image_size=(1024, 1024))
```

### Define the input shape

The whole images must be padded so the model output will be the same size as the input.
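Here is why padding is needed. Suppose a model has a receptive field 61 pixels wide and only produces an output where the full window fits inside its input (a "valid" convolution). Each spatial axis then shrinks by `receptive_field - 1 = 60` pixels, so padding 30 pixels on each side restores the original size. The NumPy sketch below checks this arithmetic on a random 256 x 256 image. The valid-window model and reflect padding are illustrative assumptions, not a description of deepcell's internals.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

receptive_field = 61
half_window = (receptive_field - 1) // 2  # 30 pixels of context on each side

image = np.random.rand(256, 256)

# Without padding, a 61-pixel window fits in only 256 - 60 = 196 positions per axis.
unpadded_windows = sliding_window_view(image, (receptive_field, receptive_field))
print(unpadded_windows.shape[:2])  # (196, 196)

# Reflect-padding by half the window on every side restores the original size.
padded = np.pad(image, half_window, mode='reflect')
padded_windows = sliding_window_view(padded, (receptive_field, receptive_field))
print(padded.shape, padded_windows.shape[:2])  # (316, 316) (256, 256)
```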

Use `get_cropped_input_shape` to compute the input shape of each padded crop. That shape is then used to instantiate the model.

```python
from deepcell.running import get_cropped_input_shape

# get the size of each cropped Z-stack
cropped_input = get_cropped_input_shape(
    whole_images, num_crops=4, receptive_field=61)

print('Whole Image shape:', whole_images.shape[1:])
print('Cropped Input Shape:', cropped_input)
```
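The arithmetic behind the cropped shape can be sketched as follows. This rests on an assumption not confirmed here: that each spatial axis is split into `num_crops` pieces, `receptive_field - 1` pixels of padding are added to each piece, and the data is channels-last with shape `(batch, frames, rows, cols, channels)`. The helper `cropped_shape_sketch` is hypothetical, and so is the example batch of two 30-frame stacks. Compare its result with the printed `cropped_input` to check the real value.

```python
def cropped_shape_sketch(image_shape, num_crops=4, receptive_field=61):
    """Hypothetical re-derivation of a padded crop shape (channels-last).

    image_shape: (batch, frames, rows, cols, channels)
    Returns (frames, crop_rows, crop_cols, channels).
    """
    _, frames, rows, cols, channels = image_shape
    pad = receptive_field - 1  # total padding per axis, split between both sides
    crop_rows = rows // num_crops + pad
    crop_cols = cols // num_crops + pad
    return (frames, crop_rows, crop_cols, channels)


# Hypothetical batch of two 30-frame, 1024 x 1024, single-channel Z-stacks
print(cropped_shape_sketch((2, 30, 1024, 1024, 1)))  # (30, 316, 316, 1)
```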

### Instantiate the full-sized model

Re-create both models with the same parameters used during training, changing only the input shape to `cropped_input`, then load the trained weights.
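Only the input shape needs to change because convolutional weights do not depend on the spatial size of the input. The same kernel can slide over any image at least as large as itself. The NumPy sketch below illustrates this with a single hypothetical 3 x 3 kernel applied to two differently sized inputs. It is a toy stand-in for the idea, not the deepcell architecture. Layers whose weights are tied to a fixed input size, such as a dense layer over a flattened image, would not transfer this way.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

rng = np.random.default_rng(0)
kernel = rng.standard_normal((3, 3))  # one hypothetical set of "trained" weights


def valid_conv2d(image, kernel):
    """Valid 2D cross-correlation: apply `kernel` at every position where it fits."""
    windows = sliding_window_view(image, kernel.shape)
    return np.einsum('ijkl,kl->ij', windows, kernel)


small = rng.standard_normal((64, 64))    # e.g. a training-sized patch
large = rng.standard_normal((316, 316))  # e.g. a padded crop of the whole image

# The same weights produce an output for both sizes.
print(valid_conv2d(small, kernel).shape)  # (62, 62)
print(valid_conv2d(large, kernel).shape)  # (314, 314)
```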

```python
import os

from deepcell.model_zoo import bn_feature_net_skip_3D

# Example training parameters (must match the values used during training)
frames_per_batch = 3
n_skips = 3
model_dir = os.path.join(os.getcwd(), 'models')

# Re-instantiate the foreground/background model
fgbg_model = bn_feature_net_skip_3D(
    receptive_field=61,
    n_skips=n_skips,
    n_features=2,  # segmentation mask (is_cell, is_not_cell)
    n_frames=frames_per_batch,
    input_shape=cropped_input,
    n_conv_filters=32,
    n_dense_filters=128,
    last_only=False)

# Load the FGBG weights
fgbg_weights_file = '2018-09-15_MouseBrain_3d_nuclear_fgbg.h5'  # use custom file
fgbg_weights_file = os.path.join(model_dir, fgbg_weights_file)
fgbg_model.load_weights(fgbg_weights_file)

# Re-instantiate the conv model, reusing the FGBG model's features
run_conv_model = bn_feature_net_skip_3D(
    fgbg_model=fgbg_model,
    receptive_field=61,  # same receptive field as the FGBG model and process_whole_image
    n_features=4,  # number of output classes
    n_skips=n_skips,
    n_frames=frames_per_batch,
    input_shape=cropped_input,
    n_conv_filters=32,
    n_dense_filters=128,
    last_only=True)

# Load the conv weights
conv_weights_file = '2018-09-15_MouseBrain_3d_nuclear_conv.h5'  # use custom file
conv_weights_file = os.path.join(model_dir, conv_weights_file)
run_conv_model.load_weights(conv_weights_file)
```

### Process the entire image

Use the built-in function `process_whole_image` to run the model on each crop of the large image and stitch the predictions back together.
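The crop, predict, and stitch pattern can be sketched in NumPy for a single 2D image. The sketch assumes `num_crops` tiles per axis and reflect padding of `(receptive_field - 1) // 2` pixels. It also assumes a "valid" model whose output is `receptive_field - 1` pixels smaller than its input along each axis. Both `fake_valid_model` (which simply averages each window) and `tile_predict_stitch` are hypothetical illustrations, not deepcell's implementation. A small receptive field of 9 keeps the example fast.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def fake_valid_model(tile, receptive_field):
    """Hypothetical model: mean over each receptive-field window ("valid" output)."""
    windows = sliding_window_view(tile, (receptive_field, receptive_field))
    return windows.mean(axis=(-2, -1))


def tile_predict_stitch(image, model, num_crops=4, receptive_field=61):
    """Split `image` into num_crops x num_crops tiles, predict each, and stitch."""
    half = (receptive_field - 1) // 2
    padded = np.pad(image, half, mode='reflect')
    rows, cols = image.shape
    crop_r, crop_c = rows // num_crops, cols // num_crops  # assumes even division
    output = np.zeros_like(image, dtype=float)
    for i in range(num_crops):
        for j in range(num_crops):
            # Each padded tile carries `half` extra pixels of context on every side.
            tile = padded[i * crop_r:(i + 1) * crop_r + 2 * half,
                          j * crop_c:(j + 1) * crop_c + 2 * half]
            output[i * crop_r:(i + 1) * crop_r,
                   j * crop_c:(j + 1) * crop_c] = model(tile, receptive_field)
    return output


image = np.random.default_rng(0).random((128, 128))
stitched = tile_predict_stitch(image, fake_valid_model, num_crops=4, receptive_field=9)
whole = fake_valid_model(np.pad(image, 4, mode='reflect'), 9)
print(stitched.shape, np.allclose(stitched, whole))  # (128, 128) True
```

Because every tile includes the context its border pixels need, the stitched result matches a single pass over the padded whole image. This is why the crops can be processed independently without visible seams.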

```python
from deepcell.running import process_whole_image

output = process_whole_image(
    model=run_conv_model,
    images=whole_images,
    num_crops=4,
    receptive_field=61)
print(output.shape)
```

### Plot the results

```python
import matplotlib.pyplot as plt
import numpy as np

index = 1  # which Z-stack in the batch to display
frame = 10  # which frame (Z-slice) of that stack to display

fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(15, 15), sharex=True, sharey=True)
ax = axes.ravel()

ax[0].imshow(whole_images[index, frame, ..., 0], cmap='gray')
ax[0].set_title('Source Image')

ax[1].imshow(output[index, frame, ..., 0] + output[index, frame, ..., 1], cmap='jet')
ax[1].set_title('Edge Segmentation Prediction')

ax[2].imshow(output[index, frame, ..., 2], cmap='jet')
ax[2].set_title('Interior Segmentation Prediction')

ax[3].imshow(np.argmax(output[index, frame, ...], axis=-1), cmap='jet')
ax[3].set_title('Argmax Prediction')

fig.tight_layout()
plt.show()
```

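To step through the frames of a Z-stack, render the interior prediction (channel 2) of the first stack as an animation:

```python
from deepcell.utils.plot_utils import get_js_video
from IPython.display import HTML

HTML(get_js_video(output, batch=0, channel=2))
```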