Skip to content

Commit

Permalink
sync support for labels (#2070)
Browse files Browse the repository at this point in the history
* more sync support for file paths + saved searches

* sync support for labels

* update sync prisma generator to support more than tags

* workey

* don't do illegal db migration

* use name as label id in explorer
  • Loading branch information
Brendonovich committed Feb 9, 2024
1 parent 6f28d8e commit 177b2a2
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 69 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 9 additions & 8 deletions core/prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,14 @@ model Tag {
@@map("tag")
}

/// @relation(item: tag, group: object)
/// @relation(item: object, group: tag)
model TagOnObject {
tag_id Int
tag Tag @relation(fields: [tag_id], references: [id], onDelete: Restrict)
object_id Int
object Object @relation(fields: [object_id], references: [id], onDelete: Restrict)
tag_id Int
tag Tag @relation(fields: [tag_id], references: [id], onDelete: Restrict)
date_created DateTime?
@@id([tag_id, object_id])
Expand All @@ -344,9 +344,9 @@ model TagOnObject {

//// Label ////

/// @shared(id: name)
model Label {
id Int @id @default(autoincrement())
pub_id Bytes @unique
name String @unique
date_created DateTime @default(now())
date_modified DateTime @default(now())
Expand All @@ -356,15 +356,16 @@ model Label {
@@map("label")
}

/// @relation(item: object, group: label)
model LabelOnObject {
date_created DateTime @default(now())
label_id Int
label Label @relation(fields: [label_id], references: [id], onDelete: Restrict)
object_id Int
object Object @relation(fields: [object_id], references: [id], onDelete: Restrict)
label_id Int
label Label @relation(fields: [label_id], references: [id], onDelete: Restrict)
@@id([label_id, object_id])
@@map("label_on_object")
}
Expand Down
28 changes: 23 additions & 5 deletions core/src/api/labels.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use crate::{invalidate_query, library::Library, object::media::thumbnail::get_indexed_thumb_key};

use sd_prisma::prisma::{label, label_on_object, object, SortOrder};
use sd_prisma::{
prisma::{label, label_on_object, object, SortOrder},
prisma_sync,
};
use sd_sync::OperationFactory;

use std::collections::BTreeMap;

Expand Down Expand Up @@ -117,12 +121,26 @@ pub(crate) fn mount() -> AlphaRouter<Ctx> {
"delete",
R.with2(library())
.mutation(|(_, library), label_id: i32| async move {
library
.db
let Library { db, sync, .. } = library.as_ref();

let label = db
.label()
.delete(label::id::equals(label_id))
.find_unique(label::id::equals(label_id))
.exec()
.await?;
.await?
.ok_or_else(|| {
rspc::Error::new(
rspc::ErrorCode::NotFound,
"Label not found".to_string(),
)
})?;

sync.write_op(
db,
sync.shared_delete(prisma_sync::label::SyncId { name: label.name }),
db.label().delete(label::id::equals(label_id)),
)
.await?;

invalidate_query!(library, "labels.list");

Expand Down
9 changes: 7 additions & 2 deletions core/src/object/media/media_processor/job.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl StatefulJob for MediaProcessorJobInit {
ctx: &WorkerContext,
data: &mut Option<Self::Data>,
) -> Result<JobInitOutput<Self::RunMetadata, Self::Step>, JobError> {
let Library { db, .. } = ctx.library.as_ref();
let Library { db, sync, .. } = ctx.library.as_ref();

let location_id = self.location.id;
let location_path =
Expand Down Expand Up @@ -186,6 +186,7 @@ impl StatefulJob for MediaProcessorJobInit {
location_path.clone(),
file_paths_for_labeling,
Arc::clone(db),
sync.clone(),
)
.await;

Expand Down Expand Up @@ -336,7 +337,11 @@ impl StatefulJob for MediaProcessorJobInit {
match ctx
.node
.image_labeller
.resume_batch(data.labeler_batch_token, Arc::clone(&ctx.library.db))
.resume_batch(
data.labeler_batch_token,
Arc::clone(&ctx.library.db),
ctx.library.sync.clone(),
)
.await
{
Ok(labels_rx) => labels_rx,
Expand Down
3 changes: 2 additions & 1 deletion core/src/object/media/media_processor/shallow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ const BATCH_SIZE: usize = 10;
pub async fn shallow(
location: &location::Data,
sub_path: &PathBuf,
library @ Library { db, .. }: &Library,
library @ Library { db, sync, .. }: &Library,
#[cfg(feature = "ai")] regenerate_labels: bool,
node: &Node,
) -> Result<(), JobError> {
Expand Down Expand Up @@ -116,6 +116,7 @@ pub async fn shallow(
location_path.clone(),
file_paths_for_labelling,
Arc::clone(db),
sync.clone(),
)
.await;

Expand Down
3 changes: 3 additions & 0 deletions crates/ai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ edition = { workspace = true }

[dependencies]
sd-prisma = { path = "../prisma" }
sd-core-sync = { path = "../../core/crates/sync" }
sd-sync = { path = "../sync" }
sd-utils = { path = "../utils" }
sd-file-path-helper = { path = "../file-path-helper" }

Expand All @@ -24,6 +26,7 @@ prisma-client-rust = { workspace = true }
reqwest = { workspace = true, features = ["stream", "native-tls-vendored"] }
rmp-serde = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["fs"] }
tokio-stream = { workspace = true }
Expand Down
20 changes: 15 additions & 5 deletions crates/ai/src/image_labeler/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ const PENDING_BATCHES_FILE: &str = "pending_image_labeler_batches.bin";
type ResumeBatchRequest = (
BatchToken,
Arc<PrismaClient>,
Arc<sd_core_sync::Manager>,
oneshot::Sender<Result<chan::Receiver<LabelerOutput>, ImageLabelerError>>,
);

Expand All @@ -53,6 +54,7 @@ pub(super) struct Batch {
pub(super) output_tx: chan::Sender<LabelerOutput>,
pub(super) is_resumable: bool,
pub(super) db: Arc<PrismaClient>,
pub(super) sync: Arc<sd_core_sync::Manager>,
}

#[derive(Serialize, Deserialize, Debug)]
Expand Down Expand Up @@ -165,6 +167,7 @@ impl ImageLabeler {
location_path: PathBuf,
file_paths: Vec<file_path_for_media_processor::Data>,
db: Arc<PrismaClient>,
sync: Arc<sd_core_sync::Manager>,
is_resumable: bool,
) -> (BatchToken, chan::Receiver<LabelerOutput>) {
let (tx, rx) = chan::bounded(usize::max(file_paths.len(), 1));
Expand All @@ -180,6 +183,7 @@ impl ImageLabeler {
output_tx: tx,
is_resumable,
db,
sync,
})
.await
.is_err()
Expand All @@ -201,8 +205,9 @@ impl ImageLabeler {
location_path: PathBuf,
file_paths: Vec<file_path_for_media_processor::Data>,
db: Arc<PrismaClient>,
sync: Arc<sd_core_sync::Manager>,
) -> chan::Receiver<LabelerOutput> {
self.new_batch_inner(location_id, location_path, file_paths, db, false)
self.new_batch_inner(location_id, location_path, file_paths, db, sync, false)
.await
.1
}
Expand All @@ -214,8 +219,9 @@ impl ImageLabeler {
location_path: PathBuf,
file_paths: Vec<file_path_for_media_processor::Data>,
db: Arc<PrismaClient>,
sync: Arc<sd_core_sync::Manager>,
) -> (BatchToken, chan::Receiver<LabelerOutput>) {
self.new_batch_inner(location_id, location_path, file_paths, db, true)
self.new_batch_inner(location_id, location_path, file_paths, db, sync, true)
.await
}

Expand Down Expand Up @@ -284,11 +290,12 @@ impl ImageLabeler {
&self,
token: BatchToken,
db: Arc<PrismaClient>,
sync: Arc<sd_core_sync::Manager>,
) -> Result<chan::Receiver<LabelerOutput>, ImageLabelerError> {
let (tx, rx) = oneshot::channel();

self.resume_batch_tx
.send((token, db, tx))
.send((token, db, sync, tx))
.await
.expect("critical error: image labeler communication channel unexpectedly closed");

Expand Down Expand Up @@ -334,6 +341,7 @@ async fn actor_loop(
ResumeBatch(
BatchToken,
Arc<PrismaClient>,
Arc<sd_core_sync::Manager>,
oneshot::Sender<Result<chan::Receiver<LabelerOutput>, ImageLabelerError>>,
),
UpdateModel(
Expand All @@ -350,7 +358,8 @@ async fn actor_loop(

let mut msg_stream = pin!((
new_batches_rx.map(StreamMessage::NewBatch),
resume_batch_rx.map(|(token, db, done_tx)| StreamMessage::ResumeBatch(token, db, done_tx)),
resume_batch_rx
.map(|(token, db, sync, done_tx)| StreamMessage::ResumeBatch(token, db, sync, done_tx)),
update_model_rx.map(|(model, done_tx)| StreamMessage::UpdateModel(model, done_tx)),
done_rx.clone().map(StreamMessage::BatchDone),
shutdown_rx.map(StreamMessage::Shutdown)
Expand All @@ -376,7 +385,7 @@ async fn actor_loop(
}
}

StreamMessage::ResumeBatch(token, db, resume_done_tx) => {
StreamMessage::ResumeBatch(token, db, sync, resume_done_tx) => {
let resume_result = if let Some((batch, output_rx)) =
to_resume_batches.write().await.remove(&token).map(
|ResumableBatch {
Expand All @@ -390,6 +399,7 @@ async fn actor_loop(
Batch {
token,
db,
sync,
output_tx,
location_id,
location_path,
Expand Down
Loading

0 comments on commit 177b2a2

Please sign in to comment.