Skip to content

Commit

Permalink
Resolve saved views when serializing datasets (#3231)
Browse files Browse the repository at this point in the history
* resolve saved views

* fix get query ids

* tweaks
  • Loading branch information
benjaminpkane authored Jun 27, 2023
1 parent fcaccba commit 24c1f20
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
9 changes: 5 additions & 4 deletions app/packages/core/src/components/Actions/similar/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,28 @@ export const getQueryIds = async (

const selectedSamples = await snapshot.getPromise(fos.selectedSamples);
const isPatches = await snapshot.getPromise(fos.isPatchesView);
const modal = await snapshot.getPromise(fos.modalSample);

if (isPatches) {
if (selectedSamples.size) {
return [...selectedSamples].map((id) => {
const sample = fos.getSample(id);
if (sample) {
return sample.sample[labels_field]._id;
return sample.sample[labels_field]._id as string;
}

throw new Error("sample not found");
});
}

return modal.sample[labels_field]._id;
return (await snapshot.getPromise(fos.modalSample)).sample[labels_field]
._id as string;
}

if (selectedSamples.size) {
return [...selectedSamples];
}

return modal.id;
return await snapshot.getPromise(fos.modalSampleId);
};

export const useSortBySimilarity = (close) => {
Expand All @@ -77,6 +77,7 @@ export const useSortBySimilarity = (close) => {
const queryIds = parameters.query
? null
: await getQueryIds(snapshot, parameters.brainKey);

const view = await snapshot.getPromise(fos.view);
const subscription = await snapshot.getPromise(fos.stateSubscription);

Expand Down
15 changes: 14 additions & 1 deletion fiftyone/server/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""
import typing as t
from dataclasses import asdict
from datetime import date, datetime
from enum import Enum
import os
import typing as t

import eta.core.serial as etas
import eta.core.utils as etau
Expand Down Expand Up @@ -319,6 +320,7 @@ async def resolver(
dataset_name=name,
serialized_view=view,
saved_view_slug=saved_view_slug,
dicts=False,
)


Expand Down Expand Up @@ -542,6 +544,7 @@ async def serialize_dataset(
dataset_name: str,
serialized_view: BSONArray,
saved_view_slug: t.Optional[str] = None,
dicts=True,
) -> Dataset:
def run():
dataset = fod.load_dataset(dataset_name)
Expand Down Expand Up @@ -587,6 +590,16 @@ def run():
if dataset.media_type == fom.GROUP:
data.group_slice = collection.group_slice

if dicts:
saved_views = []
for view in data.saved_views:
view_dict = asdict(view)
view_dict["view_name"] = view.view_name()
view_dict["stage_dicts"] = view.stage_dicts()
saved_views.append(view_dict)

data.saved_views = saved_views

for brain_method in data.brain_methods:
try:
type = brain_method.config.type().value
Expand Down

0 comments on commit 24c1f20

Please sign in to comment.