Commit eb43341

Merge 62a7c67 into 87a55f1
willgraf committed Feb 8, 2021
2 parents 87a55f1 + 62a7c67 commit eb43341
Showing 6 changed files with 52 additions and 38 deletions.
20 changes: 10 additions & 10 deletions deepcell/layers/location.py
@@ -64,11 +64,11 @@ def compute_output_shape(self, input_shape):
     def call(self, inputs):
         input_shape = self.in_shape
         if self.data_format == 'channels_first':
-            x = K.arange(0, input_shape[1], dtype=K.floatx())
-            y = K.arange(0, input_shape[2], dtype=K.floatx())
+            x = K.arange(0, input_shape[1], dtype=inputs.dtype)
+            y = K.arange(0, input_shape[2], dtype=inputs.dtype)
         else:
-            x = K.arange(0, input_shape[0], dtype=K.floatx())
-            y = K.arange(0, input_shape[1], dtype=K.floatx())
+            x = K.arange(0, input_shape[0], dtype=inputs.dtype)
+            y = K.arange(0, input_shape[1], dtype=inputs.dtype)
 
         x = x / K.max(x)
         y = y / K.max(y)
@@ -131,13 +131,13 @@ def call(self, inputs):
         input_shape = self.in_shape
 
         if self.data_format == 'channels_first':
-            z = K.arange(0, input_shape[1], dtype=K.floatx())
-            x = K.arange(0, input_shape[2], dtype=K.floatx())
-            y = K.arange(0, input_shape[3], dtype=K.floatx())
+            z = K.arange(0, input_shape[1], dtype=inputs.dtype)
+            x = K.arange(0, input_shape[2], dtype=inputs.dtype)
+            y = K.arange(0, input_shape[3], dtype=inputs.dtype)
         else:
-            z = K.arange(0, input_shape[0], dtype=K.floatx())
-            x = K.arange(0, input_shape[1], dtype=K.floatx())
-            y = K.arange(0, input_shape[2], dtype=K.floatx())
+            z = K.arange(0, input_shape[0], dtype=inputs.dtype)
+            x = K.arange(0, input_shape[1], dtype=inputs.dtype)
+            y = K.arange(0, input_shape[2], dtype=inputs.dtype)
 
         x = x / K.max(x)
         y = y / K.max(y)
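Why the location.py change matters: under a mixed-precision policy, layer inputs arrive as float16 while K.floatx() still reports 'float32', so coordinate grids built with K.floatx() would no longer match the tensors they are combined with. A minimal sketch of the failure mode this avoids (the policy name and shapes are illustrative, not from the diff):

import tensorflow as tf
from tensorflow.keras import backend as K

tf.keras.mixed_precision.set_global_policy('mixed_float16')

inputs = tf.ones((1, 32, 32, 1), dtype=tf.float16)  # what a layer sees under the policy
bad = K.arange(0, 32, dtype=K.floatx())             # float32: combining with inputs raises a dtype error
good = K.arange(0, 32, dtype=inputs.dtype)          # float16: matches the input tensor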
9 changes: 5 additions & 4 deletions deepcell/layers/normalization.py
@@ -142,8 +142,8 @@ def build(self, input_shape):
         #                        trainable=False,
         #                        dtype=self.dtype)
 
-        W = K.ones(kernel_shape, dtype=K.floatx())
-        W = W / K.cast(K.prod(K.int_shape(W)), dtype=K.floatx())
+        W = K.ones(kernel_shape, dtype=self.compute_dtype)
+        W = W / K.cast(K.prod(K.int_shape(W)), dtype=self.compute_dtype)
         self.kernel = W
         # self.set_weights([W])
 
@@ -166,6 +166,7 @@ def compute_output_shape(self, input_shape):
         return tensor_shape.TensorShape(input_shape)
 
     def _average_filter(self, inputs):
+        # Depthwise convolution on CPU is only supported for NHWC format
         if self.data_format == 'channels_first':
            inputs = K.permute_dimensions(inputs, pattern=[0, 2, 3, 1])
         outputs = tf.nn.depthwise_conv2d(inputs, self.kernel, [1, 1, 1, 1],
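The new comment documents why the permute is there: as the hunk itself states, TensorFlow's CPU kernel for depthwise convolution only handles NHWC. A self-contained sketch of the round-trip (shapes and filter are illustrative, not from the diff):

import tensorflow as tf

x_nchw = tf.random.normal((1, 3, 8, 8))           # channels_first input
kernel = tf.ones((3, 3, 3, 1)) / 9.0              # 3x3 mean filter per channel
x_nhwc = tf.transpose(x_nchw, [0, 2, 3, 1])       # NCHW -> NHWC for the CPU kernel
y_nhwc = tf.nn.depthwise_conv2d(x_nhwc, kernel, [1, 1, 1, 1], 'SAME')
y_nchw = tf.transpose(y_nhwc, [0, 3, 1, 2])       # back to channels_first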
@@ -329,8 +330,8 @@ def build(self, input_shape):
         #                        trainable=False,
         #                        dtype=self.dtype)
 
-        W = K.ones(kernel_shape, dtype=K.floatx())
-        W = W / K.cast(K.prod(K.int_shape(W)), dtype=K.floatx())
+        W = K.ones(kernel_shape, dtype=self.compute_dtype)
+        W = W / K.cast(K.prod(K.int_shape(W)), dtype=self.compute_dtype)
         self.kernel = W
         # self.set_weights([W])
 
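Here the constant averaging kernel is built with self.compute_dtype rather than K.floatx(). Under 'mixed_float16', a Keras layer's compute_dtype is float16 (the dtype its forward pass runs in), while self.dtype remains the float32 variable dtype, so compute_dtype is the one that matches the inputs. A hedged sketch with a made-up MeanFilter layer, not the normalization layer in the diff:

import tensorflow as tf

class MeanFilter(tf.keras.layers.Layer):
    """Illustrative layer (not from the diff): a fixed box-filter kernel."""

    def build(self, input_shape):
        shape = (3, 3, int(input_shape[-1]), 1)
        # compute_dtype tracks the mixed-precision policy; self.dtype does not.
        w = tf.ones(shape, dtype=self.compute_dtype)
        self.kernel = w / tf.cast(tf.reduce_prod(shape), self.compute_dtype)

    def call(self, inputs):
        return tf.nn.depthwise_conv2d(inputs, self.kernel, [1, 1, 1, 1], 'SAME')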
44 changes: 31 additions & 13 deletions deepcell/losses.py
@@ -49,6 +49,8 @@ def categorical_crossentropy(y_true, y_pred, class_weights=None, axis=None, from_logits=False):
     """
     # Note: tf.nn.softmax_cross_entropy_with_logits
     # expects logits, Keras expects probabilities.
+    y_pred = tf.convert_to_tensor(y_pred)
+    y_true = K.cast(y_true, y_pred.dtype)
     if axis is None:
         axis = 1 if K.image_data_format() == 'channels_first' else K.ndim(y_pred) - 1
     if not from_logits:
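Every loss in this file now opens with the same two-line preamble: convert y_pred first (so plain NumPy arrays are handled too), then cast y_true to y_pred.dtype, keeping float16 predictions from colliding with float32 labels. A sketch of the pattern as a standalone helper (the helper name is ours, not the module's):

import tensorflow as tf
from tensorflow.keras import backend as K

def _align_dtypes(y_true, y_pred):
    # y_pred drives the dtype; the labels follow it.
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = K.cast(y_true, y_pred.dtype)
    return y_true, y_pred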
@@ -83,6 +85,9 @@ def weighted_categorical_crossentropy(y_true, y_pred,
     """
     if from_logits:
         raise Exception('weighted_categorical_crossentropy cannot take logits')
+    y_pred = tf.convert_to_tensor(y_pred)
+    y_true = K.cast(y_true, y_pred.dtype)
+    n_classes = K.cast(n_classes, y_pred.dtype)
     if axis is None:
         axis = 1 if K.image_data_format() == 'channels_first' else K.ndim(y_pred) - 1
     reduce_axis = [x for x in list(range(K.ndim(y_pred))) if x != axis]
@@ -91,11 +96,10 @@
     # manual computation of crossentropy
     _epsilon = tf.convert_to_tensor(K.epsilon(), y_pred.dtype.base_dtype)
     y_pred = tf.clip_by_value(y_pred, _epsilon, 1. - _epsilon)
-    y_true_cast = K.cast(y_true, K.floatx())
-    total_sum = K.sum(y_true_cast)
-    class_sum = K.sum(y_true_cast, axis=reduce_axis, keepdims=True)
-    class_weights = 1.0 / K.cast_to_floatx(n_classes) * tf.divide(total_sum, class_sum + 1.)
-    return - K.sum((y_true_cast * K.log(y_pred) * class_weights), axis=axis)
+    total_sum = K.sum(y_true)
+    class_sum = K.sum(y_true, axis=reduce_axis, keepdims=True)
+    class_weights = 1.0 / n_classes * tf.divide(total_sum, class_sum + 1.)
+    return - K.sum((y_true * K.log(y_pred) * class_weights), axis=axis)
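With n_classes cast up front, the weight term reduces to total_sum / (n_classes * (class_sum + 1)), so rarer classes get proportionally larger weights. A small worked example (the 90/10 split is invented for illustration):

import tensorflow as tf

y_true = tf.concat([tf.one_hot([0] * 90, 2), tf.one_hot([1] * 10, 2)], axis=0)
total_sum = tf.reduce_sum(y_true)                      # 100.0
class_sum = tf.reduce_sum(y_true, axis=0)              # [90., 10.]
class_weights = 1.0 / 2.0 * total_sum / (class_sum + 1.0)
print(class_weights.numpy())                           # [~0.55, ~4.55]: the rare class is weighted up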


@@ -120,6 +124,8 @@ def sample_categorical_crossentropy(y_true,
     """
     # Note: tf.nn.softmax_cross_entropy_with_logits
     # expects logits, Keras expects probabilities.
+    y_pred = tf.convert_to_tensor(y_pred)
+    y_true = K.cast(y_true, y_pred.dtype)
     if axis is None:
         axis = 1 if K.image_data_format() == 'channels_first' else K.ndim(y_pred) - 1
     if not from_logits:
@@ -148,6 +154,8 @@ def dice_loss(y_true, y_pred, smooth=1):
     Returns:
         tensor: Output tensor.
     """
+    y_pred = tf.convert_to_tensor(y_pred)
+    y_true = K.cast(y_true, y_pred.dtype)
     y_true_f = K.flatten(y_true)
     y_pred_f = K.flatten(y_pred)
     intersection = K.sum(y_true_f * y_pred_f)
@@ -167,6 +175,8 @@ def discriminative_instance_loss(y_true, y_pred,
     Returns:
         tensor: Output tensor.
     """
+    y_pred = tf.convert_to_tensor(y_pred)
+    y_true = K.cast(y_true, y_pred.dtype)
 
     def temp_norm(ten, axis=None):
         if axis is None:
@@ -180,11 +190,11 @@ def temp_norm(ten, axis=None):
     # Compute variance loss
     cells_summed = tf.tensordot(y_true, y_pred, axes=[axes, axes])
     nonzeros = tf.math.count_nonzero(y_true, axis=axes)
-    n_pixels = K.cast(nonzeros, dtype=K.floatx()) + K.epsilon()
+    n_pixels = K.cast(nonzeros, dtype=y_pred.dtype) + K.epsilon()
     n_pixels_expand = K.expand_dims(n_pixels, axis=1) + K.epsilon()
     mu = tf.divide(cells_summed, n_pixels_expand)
 
-    delta_v = K.constant(delta_v, dtype=K.floatx())
+    delta_v = K.constant(delta_v, dtype=y_pred.dtype)
     mu_tensor = tf.tensordot(y_true, mu, axes=[[channel_axis], [0]])
     L_var_1 = y_pred - mu_tensor
     L_var_2 = K.square(K.relu(temp_norm(L_var_1) - delta_v))
@@ -198,8 +208,8 @@ def temp_norm(ten, axis=None):
 
     diff_matrix = tf.subtract(mu_b, mu_a)
     L_dist_1 = temp_norm(diff_matrix)
-    L_dist_2 = K.square(K.relu(K.constant(2 * delta_d, dtype=K.floatx()) - L_dist_1))
-    diag = K.constant(0, dtype=K.floatx()) * tf.linalg.diag_part(L_dist_2)
+    L_dist_2 = K.square(K.relu(K.constant(2 * delta_d, dtype=y_pred.dtype) - L_dist_1))
+    diag = K.constant(0, dtype=y_pred.dtype) * tf.linalg.diag_part(L_dist_2)
     L_dist_3 = tf.linalg.set_diag(L_dist_2, diag)
     L_dist = K.mean(L_dist_3)
 
@@ -228,6 +238,9 @@ def weighted_focal_loss(y_true, y_pred, n_classes=3, gamma=2., axis=None, from_logits=False):
     """
     if from_logits:
         raise Exception('weighted_focal_loss cannot take logits')
+    y_pred = tf.convert_to_tensor(y_pred)
+    y_true = K.cast(y_true, y_pred.dtype)
+    n_classes = K.cast(n_classes, y_pred.dtype)
     if axis is None:
         axis = 1 if K.image_data_format() == 'channels_first' else K.ndim(y_pred) - 1
     reduce_axis = [x for x in list(range(K.ndim(y_pred))) if x != axis]
@@ -236,10 +249,9 @@
 
     # manual computation of crossentropy
     _epsilon = tf.convert_to_tensor(K.epsilon(), y_pred.dtype.base_dtype)
     y_pred = tf.clip_by_value(y_pred, _epsilon, 1. - _epsilon)
-    y_true_cast = K.cast(y_true, K.floatx())
-    total_sum = K.sum(y_true_cast)
-    class_sum = K.sum(y_true_cast, axis=reduce_axis, keepdims=True)
-    class_weights = 1.0 / K.cast_to_floatx(n_classes) * tf.divide(total_sum, class_sum + 1.)
+    total_sum = K.sum(y_true)
+    class_sum = K.sum(y_true, axis=reduce_axis, keepdims=True)
+    class_weights = 1.0 / n_classes * tf.divide(total_sum, class_sum + 1.)
     temp_loss = (K.pow(1. - y_pred, gamma) * K.log(y_pred) * class_weights)
     focal_loss = - K.sum(y_true * temp_loss, axis=axis)
     return focal_loss
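The focal variant multiplies the weighted crossentropy by (1 - y_pred)**gamma, which down-weights pixels the model already classifies confidently. A toy illustration of the modulation (the probabilities are invented):

gamma = 2.0
p_easy, p_hard = 0.9, 0.1                 # confident vs. uncertain prediction
print((1 - p_easy) ** gamma)              # 0.01: the easy pixel contributes ~80x less
print((1 - p_hard) ** gamma)              # 0.81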
@@ -258,6 +270,9 @@ def smooth_l1(y_true, y_pred, sigma=3.0, axis=None):
     Returns:
         The smooth L1 loss of ``y_pred`` w.r.t. ``y_true``.
     """
+    y_pred = tf.convert_to_tensor(y_pred)
+    y_true = K.cast(y_true, y_pred.dtype)
+
     if axis is None:
         axis = 1 if K.image_data_format() == 'channels_first' else K.ndim(y_pred) - 1
 
@@ -289,6 +304,9 @@ def focal(y_true, y_pred, alpha=0.25, gamma=2.0, axis=None):
     Returns:
         float: The focal loss of ``y_pred`` w.r.t. ``y_true``.
     """
+    y_pred = tf.convert_to_tensor(y_pred)
+    y_true = K.cast(y_true, y_pred.dtype)
+
     if axis is None:
         axis = 1 if K.image_data_format() == 'channels_first' else K.ndim(y_pred) - 1
 
1 change: 0 additions & 1 deletion deepcell/model_zoo/panopticnet.py
@@ -150,7 +150,6 @@ def PanopticNet(backbone,
             If equal to 1, assumes 2D data.
         temporal_mode: Mode of temporal convolution. Choose from
             ``{'conv','lstm','gru', None}``.
-        num_semantic_heads (int): Total number of semantic heads to build.
         num_semantic_classes (list): Number of semantic classes
             for each semantic head.
         norm_method (str): Normalization method to use with the
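Dropping num_semantic_heads from the docstring matches the API change exercised by the tests and notebook below: the argument disappears from every call, so the head count is evidently inferred from len(num_semantic_classes). A sketch of the updated call (input_shape is illustrative; the other arguments mirror the tests and notebook):

from deepcell.model_zoo.panopticnet import PanopticNet

model = PanopticNet(
    backbone='resnet50',
    input_shape=(128, 128, 1),
    num_semantic_classes=[1, 1, 2],  # three heads inferred from the list length
    location=True,
    use_imagenet=False,
)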
3 changes: 0 additions & 3 deletions deepcell/model_zoo/panopticnet_test.py
@@ -366,7 +366,6 @@ def test_panopticnet(self, pooling, location, frames_per_batch,
             location=location,
             pooling=pooling,
             upsample_type=upsample_type,
-            num_semantic_heads=len(num_semantic_classes),
             num_semantic_classes=num_semantic_classes,
             use_imagenet=False,
         )
@@ -395,7 +394,6 @@ def test_panopticnet_bad_input(self):
             norm_method=norm_method,
             location=True,
             pooling='avg',
-            num_semantic_heads=len(num_semantic_classes),
             num_semantic_classes=num_semantic_classes,
             use_imagenet=False,
         )
@@ -410,7 +408,6 @@ def test_panopticnet_bad_input(self):
             norm_method=norm_method,
             location=True,
             pooling='avg',
-            num_semantic_heads=len(num_semantic_classes),
             num_semantic_classes=num_semantic_classes,
             use_imagenet=False,
         )
13 changes: 6 additions & 7 deletions (Jupyter notebook)
@@ -196,7 +196,6 @@
" backbone='resnet50',\n",
" input_shape=X_train.shape[1:],\n",
" norm_method='std',\n",
" num_semantic_heads=3,\n",
" num_semantic_classes=[1, 1, 2], # inner distance, outer distance, fgbg\n",
" location=True, # should always be true\n",
" include_top=True)"
@@ -393,11 +392,11 @@
     "\n",
     "\n",
     "def semantic_loss(n_classes):\n",
-    "    def _semantic_loss(y_pred, y_true):\n",
+    "    def _semantic_loss(y_true, y_pred):\n",
     "        if n_classes > 1:\n",
     "            return 0.01 * losses.weighted_categorical_crossentropy(\n",
-    "                y_pred, y_true, n_classes=n_classes)\n",
-    "        return MSE(y_pred, y_true)\n",
+    "                y_true, y_pred, n_classes=n_classes)\n",
+    "        return MSE(y_true, y_pred)\n",
     "    return _semantic_loss\n",
     "\n",
     "\n",
@@ -425,7 +424,7 @@
    "source": [
     "## Train the model\n",
     "\n",
-    "Call `fit_generator` on the compiled model, along with a default set of callbacks."
+    "Call `fit` on the compiled model, along with a default set of callbacks."
]
},
{
@@ -493,7 +492,7 @@
     "                              monitor='val_loss',\n",
     "                              verbose=1)\n",
     "\n",
-    "loss_history = model.fit_generator(\n",
+    "loss_history = model.fit(\n",
     "    train_data,\n",
     "    steps_per_epoch=train_data.y.shape[0] // batch_size,\n",
     "    epochs=n_epoch,\n",
@@ -808,4 +807,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
\ No newline at end of file
+}
