Skip to content

Commit

Permalink
Merge pull request #76 from voxel51/patches-similarity
Browse files Browse the repository at this point in the history
Adding support for sorting patches views by similarity
  • Loading branch information
brimoor committed Apr 28, 2021
2 parents 2226b22 + 6eb4e31 commit 84593d2
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions fiftyone/brain/internal/core/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from fiftyone import ViewField as F
import fiftyone.core.brain as fob
import fiftyone.core.stages as fos
import fiftyone.core.patches as fop
import fiftyone.core.validation as fov
import fiftyone.zoo as foz

Expand Down Expand Up @@ -165,12 +166,39 @@ def sort_by_similarity(
sample_ids = sample_ids[keep]
embeddings = embeddings[keep]
ids = sample_ids
elif isinstance(samples, (fop.PatchesView, fop.EvaluationPatchesView)):
if (
isinstance(samples, fop.PatchesView)
and patches_field != samples.patches_field
):
raise ValueError(
"This patches view contains labels from field '%s', not "
"'%s'" % (samples.patches_field, patches_field)
)

if isinstance(
samples, fop.EvaluationPatchesView
) and patches_field not in (samples.gt_field, samples.pred_field):
raise ValueError(
"This evaluation patches view contains patches from "
"fields '%s' and '%s', not '%s'"
% (samples.gt_field, samples.pred_field, patches_field)
)

ids_map = samples._get_ids_map(patches_field)
keep = np.array([label_id in ids_map for label_id in label_ids])

label_ids = label_ids[keep]
embeddings = embeddings[keep]
sample_ids = np.array([ids_map[_id] for _id in label_ids])
ids = label_ids
else:
possible_ids = set(
l["label_id"]
for l in samples._get_selected_labels(fields=patches_field)
)
keep = np.array([_id in possible_ids for _id in label_ids])

sample_ids = sample_ids[keep]
label_ids = label_ids[keep]
embeddings = embeddings[keep]
Expand Down

0 comments on commit 84593d2

Please sign in to comment.