Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug/tf upsampling #208

Merged
merged 2 commits into from Oct 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 18 additions & 0 deletions sleap/config/active.yaml
Expand Up @@ -145,6 +145,24 @@ learning:

inference:

- name: conf_job
label: Node (confmap) Training Profile
type: list
default: a
options: a,b,c

- name: paf_job
label: Edge (paf) Training Profile
type: list
default: a
options: a,b,c

- name: centroid_job
label: Centroid Training Profile
type: list
default: a
options: a,b,c

- name: _predict_frames
label: Predict On
type: list
Expand Down
47 changes: 32 additions & 15 deletions sleap/gui/active.py
Expand Up @@ -19,6 +19,9 @@
from PySide2 import QtWidgets, QtCore


SELECT_FILE_OPTION = "Select a training profile file..."


class ActiveLearningDialog(QtWidgets.QDialog):
"""Active learning dialog.

Expand Down Expand Up @@ -49,6 +52,10 @@ def __init__(
self.labels_filename = labels_filename
self.labels = labels
self.mode = mode
self._job_filter = None

if self.mode == "inference":
self._job_filter = lambda job: job.is_trained

print(f"Number of frames to train on: {len(labels.user_labeled_frames)}")

Expand Down Expand Up @@ -162,6 +169,13 @@ def _rebuild_job_options(self):
# list default profiles
find_saved_jobs(profile_dir, self.job_options)

# Apply any filters
if self._job_filter:
for model_type, jobs_list in self.job_options.items():
self.job_options[model_type] = [
(path, job) for (path, job) in jobs_list if self._job_filter(job)
]

def _update_job_menus(self, init: bool = False):
"""Updates the menus with training profile options.

Expand All @@ -176,9 +190,11 @@ def _update_job_menus(self, init: bool = False):
if model_type not in self.job_options:
self.job_options[model_type] = []
if init:
field.currentIndexChanged.connect(
lambda idx, mt=model_type: self._update_from_selected_job(mt, idx)
)

def menu_action(idx, mt=model_type, field=field):
self._update_from_selected_job(mt, idx, field)

field.currentIndexChanged.connect(menu_action)
else:
# block signals so we can update combobox without overwriting
# any user data with the defaults from the profile
Expand Down Expand Up @@ -365,6 +381,9 @@ def _get_current_training_jobs(self) -> Dict[ModelOutputType, TrainingJob]:
for model_type in self._get_model_types_to_use():
job, _ = self._get_current_job(model_type)

if job is None:
continue

if job.model.output_type != ModelOutputType.CENTROIDS:
# update training job from params in form
trainer = job.trainer
Expand Down Expand Up @@ -499,8 +518,9 @@ def _option_list_from_jobs(self, model_type: ModelOutputType):
"""Returns list of menu options for given model type."""
jobs = self.job_options[model_type]
option_list = [name for (name, job) in jobs]
option_list.append("")
option_list.append("---")
option_list.append("Select a training profile file...")
option_list.append(SELECT_FILE_OPTION)
return option_list

def _add_job_file(self, model_type):
Expand Down Expand Up @@ -548,9 +568,10 @@ def _add_job_file_to_list(self, filename: str, model_type: ModelOutputType):
text=f"Profile selected is for training {str(file_model_type)} instead of {str(model_type)}."
).exec_()

def _update_from_selected_job(self, model_type: ModelOutputType, idx: int):
def _update_from_selected_job(self, model_type: ModelOutputType, idx: int, field):
"""Updates dialog settings after user selects a training profile."""
jobs = self.job_options[model_type]
field_text = field.currentText()
if idx == -1:
return
if idx < len(jobs):
Expand All @@ -569,17 +590,13 @@ def _update_from_selected_job(self, model_type: ModelOutputType, idx: int):
self.form_widget.set_form_data(training_params)

# is the model already trained?
has_trained = False
final_model_filename = job.final_model_filename
if final_model_filename is not None:
if os.path.exists(os.path.join(job.save_dir, final_model_filename)):
has_trained = True
is_trained = job.is_trained
field_name = f"_use_trained_{str(model_type)}"
# update "use trained" checkbox
self.form_widget.fields[field_name].setEnabled(has_trained)
self.form_widget[field_name] = has_trained
else:
# last item is "select file..."
# update "use trained" checkbox if present
if field_name in self.form_widget.fields:
self.form_widget.fields[field_name].setEnabled(is_trained)
self.form_widget[field_name] = is_trained
elif field_text == SELECT_FILE_OPTION:
self._add_job_file(model_type)


Expand Down
90 changes: 46 additions & 44 deletions sleap/nn/peakfinding_tf.py
Expand Up @@ -45,6 +45,49 @@ def impeaksnms_tf(I, min_thresh=0.3):

return inds, peak_vals

def upsample_peaks(unrolled_confmaps, peaks, h, w, channel_sample_ind, upsample_factor, win_size):
offset = (win_size - 1) / 2

# Get the boxes coordinates centered on the peaks, normalized to image
# coordinates
box_ind = tf.squeeze(tf.cast(channel_sample_ind, tf.int32))
top_left = (
tf.cast(peaks[:, 1:3], tf.float32)
+ tf.constant([-offset, -offset], dtype="float32")
) / (h - 1.0)
bottom_right = (
tf.cast(peaks[:, 1:3], tf.float32)
+ tf.constant([offset, offset], dtype="float32")
) / (w - 1.0)
boxes = tf.concat([top_left, bottom_right], axis=1)

small_windows = tf.image.crop_and_resize(
unrolled_confmaps, boxes, box_ind, crop_size=[win_size, win_size]
)

# Upsample cropped windows
windows = tf.image.resize_bicubic(
small_windows, [upsample_factor * win_size, upsample_factor * win_size]
)

windows = tf.squeeze(windows)

# Find global maximum of each window
windows_peaks = find_maxima_tf(windows) # [row_ind, col_ind] ==> (nc, 2)

# Adjust back to resolution before upsampling
windows_peaks = tf.cast(windows_peaks, tf.float32) / tf.cast(
upsample_factor, tf.float32
)

# Convert to offsets relative to the original peaks (center of cropped windows)
windows_offsets = windows_peaks - tf.cast(offset, tf.float32) # (nc, 2)
windows_offsets = tf.pad(
windows_offsets, [[0, 0], [1, 1]], mode="CONSTANT", constant_values=0
) # (nc, 4)

# Apply offsets
return tf.cast(peaks, tf.float32) + windows_offsets

def find_peaks_tf(
confmaps,
Expand All @@ -68,54 +111,13 @@ def find_peaks_tf(
sample_ind = tf.floordiv(channel_sample_ind, c)

peaks = tf.concat([sample_ind, y, x, channel_ind], axis=1) # (nc, 4)

# If we have run prediction on low res and need to upsample the peaks
# to a higher resolution. Compute sub-pixel accurate peaks
# from these approximate peaks and return the upsampled sub-pixel peaks.
if upsample_factor > 1:

offset = (win_size - 1) / 2

# Get the boxes coordinates centered on the peaks, normalized to image
# coordinates
box_ind = tf.squeeze(tf.cast(channel_sample_ind, tf.int32))
top_left = (
tf.cast(peaks[:, 1:3], tf.float32)
+ tf.constant([-offset, -offset], dtype="float32")
) / (h - 1.0)
bottom_right = (
tf.cast(peaks[:, 1:3], tf.float32)
+ tf.constant([offset, offset], dtype="float32")
) / (w - 1.0)
boxes = tf.concat([top_left, bottom_right], axis=1)

small_windows = tf.image.crop_and_resize(
unrolled_confmaps, boxes, box_ind, crop_size=[win_size, win_size]
)

# Upsample cropped windows
windows = tf.image.resize_bicubic(
small_windows, [upsample_factor * win_size, upsample_factor * win_size]
)

windows = tf.squeeze(windows)

# Find global maximum of each window
windows_peaks = find_maxima_tf(windows) # [row_ind, col_ind] ==> (nc, 2)

# Adjust back to resolution before upsampling
windows_peaks = tf.cast(windows_peaks, tf.float32) / tf.cast(
upsample_factor, tf.float32
)

# Convert to offsets relative to the original peaks (center of cropped windows)
windows_offsets = windows_peaks - tf.cast(offset, tf.float32) # (nc, 2)
windows_offsets = tf.pad(
windows_offsets, [[0, 0], [1, 1]], mode="CONSTANT", constant_values=0
) # (nc, 4)

# Apply offsets
peaks = tf.cast(peaks, tf.float32) + windows_offsets
peaks = tf.cond(tf.less(tf.shape(peaks)[0], 1),
lambda: upsample_peaks(unrolled_confmaps, peaks, h, w, channel_sample_ind, upsample_factor, win_size),
lambda: tf.cast(peaks, tf.float32))

return peaks, peak_vals

Expand Down
8 changes: 8 additions & 0 deletions sleap/nn/training.py
Expand Up @@ -625,6 +625,14 @@ class TrainingJob:
newest_model_filename: Union[str, None] = None
final_model_filename: Union[str, None] = None

@property
def is_trained(self):
if self.final_model_filename is not None:
path = os.path.join(self.save_dir, self.final_model_filename)
if os.path.exists(path):
return True
return False

@staticmethod
def save_json(training_job: "TrainingJob", filename: str):
"""
Expand Down
13 changes: 13 additions & 0 deletions tests/gui/test_active.py
Expand Up @@ -31,6 +31,19 @@ def test_active_gui(qtbot, centered_pair_labels):
assert ModelOutputType.PART_AFFINITY_FIELD not in jobs


def test_inference_gui(qtbot, centered_pair_labels):
win = ActiveLearningDialog(
labels_filename="foo.json", labels=centered_pair_labels, mode="inference"
)
win.show()
qtbot.addWidget(win)

# There aren't any trained models, so there should be no options shown for
# inference
jobs = win._get_current_training_jobs()
assert len(jobs) == 0


def test_make_default_training_jobs():
jobs = make_default_training_jobs()

Expand Down