In [None]:
# This notebook trains a featurenet model on the TissueNet dataset

### Locate the training data and create the output folder

In [2]:
import os

# Folder for this set of experiments; trained model checkpoints are saved here.
experiment_folder = "featurenet"
MODEL_DIR = os.path.join("/data/analyses", experiment_folder)
# Directory holding the training .npz files (joined with npz_name later).
NPZ_DIR = "/data/npz_data/20201018_freeze/"
# Passed to train_model_conv as log_dir.
LOG_DIR = '/data/logs'

# exist_ok=True avoids the check-then-create race of isdir() + makedirs()
# and keeps the cell safe to re-run.
os.makedirs(MODEL_DIR, exist_ok=True)

### Set up filepath constants

### Set up training parameters

In [3]:
from tensorflow.keras.optimizers import SGD
from deepcell.utils.train_utils import rate_scheduler

# Identifier of the data split to train on; it selects the npz file
# ("..._seed_<model>_train_...") and is embedded in the checkpoint name.
model = '3'

n_epoch = 25  # Number of training epochs
norm_method = None  # data normalization method passed to the model (None presumably disables it)
receptive_field = 61  # should be adjusted for the scale of the data

# Derive the checkpoint name from n_epoch so the saved filename always
# reflects the actual training length. A hard-coded "100_epochs" label
# drifts as soon as n_epoch changes.
conv_model_name = 'featurenet_split_{}_{}_epochs'.format(model, n_epoch)
npz_name = "20201018_multiplex_seed_{}_train_256x256.npz".format(model)
DATA_FILE = os.path.join(NPZ_DIR, npz_name)

# Single source of truth for the initial learning rate, shared by the
# optimizer and the learning-rate scheduler.
learning_rate = 0.01
optimizer = SGD(lr=learning_rate, decay=1e-6, momentum=0.9, nesterov=True)
lr_sched = rate_scheduler(lr=learning_rate, decay=0.99)

# FC training settings
n_skips = 3  # number of skip-connections (only for FC training)
batch_size = 1  # FC training uses 1 image per batch

# Transformation settings
transform = 'pixelwise'
dilation_radius = 1  # change dilation radius for edge dilation
separate_edge_classes = False  # break edges into cell-background edge, cell-cell edge
# Splitting edges into two classes adds one output class.
n_features = 4 if separate_edge_classes else 3

### Next, create a model for edge/interior segmentation

#### Instantiate the segmentation transform model

In [4]:
from deepcell import model_zoo

# Training tiles are 256x256 with 2 channels, matching the "256x256" npz file
# (assumes channels-last layout; confirm if the data format changes).
image_shape = (256, 256, 2)

conv_model = model_zoo.bn_feature_net_skip_2D(
    input_shape=image_shape,
    # architecture
    receptive_field=receptive_field,
    n_skips=n_skips,
    n_conv_filters=32,
    n_dense_filters=128,
    last_only=False,
    # outputs / preprocessing
    n_features=n_features,
    norm_method=norm_method,
)

W1115 03:11:12.391070 140369887557440 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


#### Train the segmentation transform model

In [5]:
from deepcell.training import train_model_conv

# Lower bound of the zoom augmentation. The range (min_zoom, 1 / min_zoom)
# is symmetric around 1 on a log scale.
min_zoom = 0.7

conv_model = train_model_conv(
    model=conv_model,
    # data and output locations
    dataset=DATA_FILE,  # full path to npz file
    model_name=conv_model_name,
    model_dir=MODEL_DIR,
    log_dir=LOG_DIR,
    # hold out 10% of the data; seed presumably fixes the split (confirm)
    test_size=0.1,
    seed=1,
    # target transform
    transform=transform,
    dilation_radius=dilation_radius,
    separate_edge_classes=separate_edge_classes,
    # optimization
    optimizer=optimizer,
    lr_sched=lr_sched,
    batch_size=batch_size,
    n_epoch=n_epoch,
    # augmentation
    rotation_range=180,
    flip=True,
    shear=False,
    zoom_range=(min_zoom, 1 / min_zoom))

X_train shape: (2384, 256, 256, 2)
y_train shape: (2384, 256, 256, 1)
X_test shape: (265, 256, 256, 2)
y_test shape: (265, 256, 256, 1)
Output Shape: (None, 256, 256, 3)
Number of Classes: 3
Training on 1 GPUs
Epoch 1/25


W1115 03:11:51.598669 140369887557440 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


   1/2384 [..............................] - ETA: 7:34:05 - loss: 5.5636 - model_loss: 1.3106 - model_1_loss: 1.4279 - model_2_loss: 1.4038 - model_3_loss: 1.3852 - model_acc: 0.2765 - model_1_acc: 0.4211 - model_2_acc: 0.1967 - model_3_acc: 0.1994

W1115 03:12:02.653802 140369887557440 callbacks.py:257] Method (on_train_batch_end) is slow compared to the batch update (0.224747). Check your callbacks.


Epoch 00001: val_loss improved from inf to 3.17922, saving model to /data/analyses/featurenet/featurenet_split_3_redo.h5
Epoch 2/25
Epoch 00002: val_loss did not improve from 3.17922
Epoch 3/25
Epoch 00003: val_loss improved from 3.17922 to 2.88583, saving model to /data/analyses/featurenet/featurenet_split_3_redo.h5
Epoch 4/25
Epoch 00004: val_loss improved from 2.88583 to 2.78427, saving model to /data/analyses/featurenet/featurenet_split_3_redo.h5
Epoch 5/25
Epoch 00005: val_loss did not improve from 2.78427
Epoch 6/25
Epoch 00006: val_loss did not improve from 2.78427
Epoch 7/25
Epoch 00007: val_loss did not improve from 2.78427
Epoch 8/25
Epoch 00008: val_loss did not improve from 2.78427
Epoch 9/25
Epoch 00009: val_loss did not improve from 2.78427
Epoch 10/25
Epoch 00010: val_loss improved from 2.78427 to 2.74173, saving model to /data/analyses/featurenet/featurenet_split_3_redo.h5
Epoch 11/25
Epoch 00011: val_loss did not improve from 2.74173
Epoch 12/25
Epoch 00012: val_loss d

Epoch 22/25
Epoch 00022: val_loss did not improve from 2.58759
Epoch 23/25
Epoch 00023: val_loss did not improve from 2.58759
Epoch 24/25
Epoch 00024: val_loss did not improve from 2.58759
Epoch 25/25
Epoch 00025: val_loss did not improve from 2.58759
