diff --git a/integration/test_collection_batch.py b/integration/test_collection_batch.py
index c6ef7f28d..5bed1dcbb 100644
--- a/integration/test_collection_batch.py
+++ b/integration/test_collection_batch.py
@@ -202,3 +202,24 @@ def test_error_reset(batch_collection: BatchCollection) -> None:
     assert len(errs) == 1
     assert errs[0].object_.properties is not None
     assert errs[0].object_.properties["name"] == 1
+
+
+def test_refs_and_objects(batch_collection: BatchCollection) -> None:
+    """Test that references are not added before the source object is added."""
+    col = batch_collection()
+    uuids = [uuid.uuid4() for _ in range(10)]
+    with col.batch.fixed_size(1, concurrent_requests=1) as batch:
+        for uid in uuids:
+            batch.add_object(properties={}, uuid=uid)
+        batch.add_reference(
+            from_uuid=uuids[-1],
+            from_property="test",
+            to=uuids[-1],
+        )
+
+    assert len(col.batch.failed_objects) == 0
+    assert len(col.batch.failed_references) == 0
+
+    obj = col.query.fetch_object_by_id(uuids[-1], return_references=QueryReference(link_on="test"))
+    assert "test" in obj.references
+    assert obj.references["test"].objects[0].uuid == uuids[-1]
diff --git a/profiling/test_refs.py b/profiling/test_refs.py
new file mode 100644
index 000000000..1fd4509b9
--- /dev/null
+++ b/profiling/test_refs.py
@@ -0,0 +1,279 @@
+import datetime
+import random
+import uuid as uuid_lib
+from dataclasses import dataclass, field
+from typing import List, Dict, Optional, Any
+
+import weaviate
+import weaviate.classes as wvc
+
+schema = {
+    "classes": [
+        {
+            "class": "Author",
+            "properties": [{"dataType": ["string"], "name": "name"}],
+            "vectorizer": "none",
+        },
+        {
+            "class": "Paragraph",
+            "properties": [
+                {"dataType": ["text"], "name": "contents"},
+                {"dataType": ["Paragraph"], "name": "hasParagraphs"},
+                {"dataType": ["Author"], "name": "author"},
+            ],
+            "vectorizer": "none",
+        },
+        {
+            "class": "Article",
+            "properties": [
+                {"dataType": ["string"], "name": "title"},
+                {"dataType": ["Paragraph"], "name": "hasParagraphs"},
+                {"dataType": ["date"], "name": "datePublished"},
+                {"dataType": ["Author"], "name": "author"},
+                {"dataType": ["string"], "name": "somestring"},
+                {"dataType": ["int"], "name": "counter"},
+            ],
+            "vectorizer": "none",
+        },
+    ]
+}
+
+
+@dataclass(frozen=True)
+class Reference:
+    to_class: str
+    to_uuid: uuid_lib.UUID
+
+
+@dataclass
+class DataObject:
+    properties: Dict[str, Any]
+    class_name: str
+    uuid: uuid_lib.UUID
+
+
+@dataclass
+class Author:
+    name: str
+    uuid: uuid_lib.UUID = field(init=False)
+    class_name: str = field(init=False)
+
+    def to_data_object(self) -> DataObject:
+        return DataObject({"name": self.name}, self.class_name, self.uuid)
+
+    def __post_init__(self) -> None:
+        self.uuid = uuid_lib.uuid4()
+        self.class_name = "Author"
+
+
+@dataclass
+class Paragraph:
+    contents: str
+    author: Reference
+    hasParagraphs: Optional[Reference]
+    uuid: uuid_lib.UUID = field(init=False)
+    class_name: str = field(init=False)
+
+    def to_data_object(self) -> DataObject:
+        return DataObject({"contents": self.contents}, self.class_name, self.uuid)
+
+    def __post_init__(self) -> None:
+        self.uuid = uuid_lib.uuid4()
+        self.class_name = "Paragraph"
+
+
+@dataclass
+class Article:
+    title: str
+    datePublished: str
+    somestring: str
+    counter: int
+    author: Reference
+    hasParagraphs: Reference
+    uuid: uuid_lib.UUID = field(init=False)
+    class_name: str = field(init=False)
+
+    def to_data_object(self) -> DataObject:
+        return DataObject(
+            {"title": self.title, "datePublished": self.datePublished}, self.class_name, self.uuid
+        )
+
+    def __post_init__(self) -> None:
+        self.uuid = uuid_lib.uuid4()
+        self.class_name = "Article"
+
+
+def test_stress() -> None:
+    random.seed(0)
+    client = weaviate.connect_to_local()
+    for col in schema["classes"]:
+        client.collections.delete(str(col["class"]))
+
+    authors = client.collections.create_from_dict(schema["classes"][0])
+    paragraphs = client.collections.create_from_dict(schema["classes"][1])
+    articles = client.collections.create_from_dict(schema["classes"][2])
+
+    author_data = create_authors(100)
+    paragraph_data = create_paragraphs(num_paragraphs=200, authors=author_data)
+    article_data = create_articles(3000, author_data, paragraph_data)
+
+    add_authors(client, author_data)
+    add_paragraphs(client, paragraph_data)
+    add_articles(client, article_data)
+
+    assert len(authors) == len(author_data)
+    assert len(paragraphs) == len(paragraph_data)
+    assert len(articles) == len(article_data)
+
+    # verify references
+    for article in article_data:
+        article_weav = articles.query.fetch_object_by_id(
+            article.uuid,
+            return_references=[
+                wvc.query.QueryReference(link_on="hasParagraphs"),
+                wvc.query.QueryReference(link_on="author"),
+            ],
+        )
+        assert article_weav.uuid == article.uuid
+        assert (
+            article_weav.references["hasParagraphs"].objects[0].uuid
+            == article.hasParagraphs.to_uuid
+        )
+        assert article_weav.references["author"].objects[0].uuid == article.author.to_uuid
+
+    for i, paragraph in enumerate(paragraph_data):
+        para_weav = paragraphs.query.fetch_object_by_id(
+            paragraph.uuid,
+            return_references=[
+                wvc.query.QueryReference(link_on="hasParagraphs"),
+                wvc.query.QueryReference(link_on="author"),
+            ],
+        )
+        assert para_weav.uuid == paragraph.uuid
+        assert "author" in para_weav.references, i
+        assert para_weav.references["author"].objects[0].uuid == paragraph.author.to_uuid, i
+        if paragraph.hasParagraphs is not None:
+            assert (
+                para_weav.references["hasParagraphs"].objects[0].uuid
+                == paragraph.hasParagraphs.to_uuid
+            )
+        else:
+            assert "hasParagraphs" not in para_weav.references
+
+    for col in schema["classes"]:
+        client.collections.delete(str(col["class"]))
+
+
+def add_authors(client: weaviate.WeaviateClient, authors: List[Author]) -> None:
+    with client.batch.dynamic() as batch:
+        for author in authors:
+            data_object = author.to_data_object()
+            batch.add_object(
+                collection=data_object.class_name,
+                properties=data_object.properties,
+                uuid=data_object.uuid,
+            )
+    assert len(client.batch.failed_objects) == 0
+
+
+def add_paragraphs(client: weaviate.WeaviateClient, paragraphs: List[Paragraph]) -> None:
+    with client.batch.dynamic() as batch:
+        for paragraph in paragraphs:
+            data_object = paragraph.to_data_object()
+            batch.add_object(
+                collection=data_object.class_name,
+                properties=data_object.properties,
+                uuid=data_object.uuid,
+            )
+            batch.add_reference(
+                from_uuid=str(paragraph.uuid),
+                from_property="author",
+                to=str(paragraph.author.to_uuid),
+                from_collection="Paragraph",
+            )
+            if paragraph.hasParagraphs is not None:
+                batch.add_reference(
+                    from_uuid=str(paragraph.uuid),
+                    from_property="hasParagraphs",
+                    to=str(paragraph.hasParagraphs.to_uuid),
+                    from_collection="Paragraph",
+                )
+    assert len(client.batch.failed_references) == 0
+    assert len(client.batch.failed_objects) == 0
+
+
+def add_articles(client: weaviate.WeaviateClient, articles: List[Article]) -> None:
+    with client.batch.dynamic() as batch:
+        for article in articles:
+            data_object = article.to_data_object()
+            batch.add_object(
+                collection=data_object.class_name,
+                properties=data_object.properties,
+                uuid=data_object.uuid,
+            )
+            batch.add_reference(
+                str(article.uuid),
+                from_property="author",
+                to=str(article.author.to_uuid),
+                from_collection="Article",
+            )
+            batch.add_reference(
+                str(article.uuid),
+                from_property="hasParagraphs",
+                to=str(article.hasParagraphs.to_uuid),
+                from_collection="Article",
+            )
+    assert len(client.batch.failed_references) == 0
+    assert len(client.batch.failed_objects) == 0
+
+
+def create_authors(num_authors: int) -> List[Author]:
+    authors: List[Author] = [Author(f"{i}") for i in range(num_authors)]
+    return authors
+
+
+def create_paragraphs(num_paragraphs: int, authors: List[Author]) -> List[Paragraph]:
+    paragraphs: List[Paragraph] = []
+    for i in range(num_paragraphs):
+        content: str = f"{i} {i} {i} {i}"
+
+        paragraph_to_reference: Optional[Paragraph] = None
+        if len(paragraphs) > 0 and i % 2 == 0:
+            paragraph_to_reference = paragraphs[i % len(paragraphs)]
+        author_to_reference: Author = authors[0]
+        paragraphs.append(
+            Paragraph(
+                content,
+                Reference("Author", author_to_reference.uuid),
+                (
+                    Reference("Paragraph", paragraph_to_reference.uuid)
+                    if paragraph_to_reference is not None
+                    else None
+                ),
+            )
+        )
+    return paragraphs
+
+
+def create_articles(
+    num_articles: int, authors: List[Author], paragraphs: List[Paragraph]
+) -> List[Article]:
+    articles: List[Article] = []
+    base_date: datetime.datetime = datetime.datetime(2023, 12, 9, 7, 1, 34)
+    for i in range(num_articles):
+        title: str = f"{i} {i} {i}"
+        paragraph_to_reference: Paragraph = paragraphs[i % len(paragraphs)]
+        author_to_reference: Author = authors[i % len(authors)]
+        date_published: str = (base_date + datetime.timedelta(hours=i)).isoformat() + "Z"
+        articles.append(
+            Article(
+                title,
+                date_published,
+                str(i),
+                i,
+                Reference("Author", author_to_reference.uuid),
+                Reference("Paragraph", paragraph_to_reference.uuid),
+            )
+        )
+
+    return articles
diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py
index c27d9d8a8..937c13d92 100644
--- a/weaviate/collections/batch/base.py
+++ b/weaviate/collections/batch/base.py
@@ -25,6 +25,7 @@
 from weaviate.collections.batch.grpc_batch_objects import _BatchGRPC
 from weaviate.collections.batch.rest import _BatchRESTAsync
 from weaviate.collections.classes.batch import (
+    _BatchReference,
     BatchObject,
     BatchReference,
     BatchResult,
@@ -32,7 +33,6 @@
     ErrorReference,
     _BatchObject,
     BatchObjectReturn,
-    _BatchReference,
     BatchReferenceReturn,
     Shard,
 )
@@ -61,52 +61,69 @@ class BatchRequest(ABC, Generic[TBatchInput, TBatchReturn]):
     """`BatchRequest` abstract class used as a interface for batch requests."""
 
     def __init__(self) -> None:
-        self.__items: List[TBatchInput] = []
-        self.__lock = threading.Lock()
+        self._items: List[TBatchInput] = []
+        self._lock = threading.Lock()
 
     def __len__(self) -> int:
-        return len(self.__items)
+        return len(self._items)
 
     def add(self, item: TBatchInput) -> None:
        """Add an item to the BatchRequest."""
-        self.__lock.acquire()
-        self.__items.append(item)
-        self.__lock.release()
+        self._lock.acquire()
+        self._items.append(item)
+        self._lock.release()
 
     def prepend(self, item: List[TBatchInput]) -> None:
         """Add items to the front of the BatchRequest.
 
         This is intended to be used when objects should be retries, eg. after a temporary error.
""" - self.__lock.acquire() - self.__items = item + self.__items - self.__lock.release() + self._lock.acquire() + self._items = item + self._items + self._lock.release() - def pop_items(self, pop_amount: int) -> List[TBatchInput]: + +class ReferencesBatchRequest(BatchRequest[_BatchReference, BatchReferenceReturn]): + """Collect Weaviate-object references to add them in one request to Weaviate.""" + + def pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[_BatchReference]: """Pop the given number of items from the BatchRequest queue. Returns - `List[TBatchInput]` items from the BatchRequest. + `List[_BatchReference]` items from the BatchRequest. """ - self.__lock.acquire() - if pop_amount >= len(self.__items): - ret = copy(self.__items) - self.__items.clear() - else: - ret = copy(self.__items[:pop_amount]) - self.__items = self.__items[pop_amount:] - - self.__lock.release() + ret: List[_BatchReference] = [] + i = 0 + self._lock.acquire() + while len(ret) < pop_amount and len(self._items) > 0 and i < len(self._items): + if self._items[i].from_uuid not in uuid_lookup: + ret.append(self._items.pop(i)) + else: + i += 1 + self._lock.release() return ret -class ReferencesBatchRequest(BatchRequest[_BatchReference, BatchReferenceReturn]): - """Collect Weaviate-object references to add them in one request to Weaviate.""" - - class ObjectsBatchRequest(BatchRequest[_BatchObject, BatchObjectReturn]): """Collect objects for one batch request to weaviate.""" + def pop_items(self, pop_amount: int) -> List[_BatchObject]: + """Pop the given number of items from the BatchRequest queue. + + Returns + `List[_BatchObject]` items from the BatchRequest. + """ + self._lock.acquire() + if pop_amount >= len(self._items): + ret = copy(self._items) + self._items.clear() + else: + ret = copy(self._items[:pop_amount]) + self._items = self._items[pop_amount:] + + self._lock.release() + return ret + @dataclass class _BatchDataWrapper: @@ -143,18 +160,20 @@ def __init__( results: _BatchDataWrapper, batch_mode: _BatchMode, objects_: Optional[ObjectsBatchRequest] = None, - references: Optional[BatchRequest[_BatchReference, BatchReferenceReturn]] = None, + references: Optional[ReferencesBatchRequest] = None, ) -> None: self.__batch_objects = objects_ or ObjectsBatchRequest() - self.__batch_references = ( - references or BatchRequest[_BatchReference, BatchReferenceReturn]() - ) + self.__batch_references = references or ReferencesBatchRequest() self.__connection = connection self.__consistency_level: Optional[ConsistencyLevel] = consistency_level self.__batch_grpc = _BatchGRPC(connection, self.__consistency_level) self.__batch_rest = _BatchRESTAsync(connection, self.__consistency_level) + # lookup table for objects that are currently being processed - is used to not send references from objects that have not been added yet + self.__uuid_lookup_lock = threading.Lock() + self.__uuid_lookup: Set[str] = set() + # we do not want that users can access the results directly as they are not thread-safe self.__results_for_wrapper_backup = results self.__results_for_wrapper = _BatchDataWrapper() @@ -186,7 +205,6 @@ def __init__( self.__recommended_num_objects = 10 self.__concurrent_requests = 2 - # there seems to be a bug with weaviate when sending > 50 refs at once self.__recommended_num_refs: int = 50 self.__active_requests = 0 @@ -281,11 +299,16 @@ def __batch_send(self) -> None: self.__active_requests_lock.release() objs = self.__batch_objects.pop_items(self.__recommended_num_objects) + self.__uuid_lookup_lock.acquire() + 
+                refs = self.__batch_references.pop_items(
+                    self.__recommended_num_refs, uuid_lookup=self.__uuid_lookup
+                )
+                self.__uuid_lookup_lock.release()
                 # do not block the thread - the results are written to a central (locked) list and we want to have multiple concurrent batch-requests
                 asyncio.run_coroutine_threadsafe(
                     self.__send_batch_async(
                         objs,
-                        self.__batch_references.pop_items(self.__recommended_num_refs),
+                        refs,
                         readd_rate_limit=isinstance(self.__batching_mode, _RateLimitedBatching),
                     ),
                     loop,
@@ -422,6 +445,7 @@ async def __send_batch_async(
                 uuids={},
             )
 
+        readded_uuids = set()
         if readd_rate_limit:
             readded_objects = []
             highest_retry_count = 0
@@ -450,13 +474,14 @@ async def __send_batch_async(
                         self.__fix_rate_batching_base_time * (highest_retry_count + 1),
                     )
 
-                    self.__batch_objects.prepend(
-                        [
-                            err.object_
-                            for i, err in response_obj.errors.items()
-                            if i in readded_objects
-                        ]
-                    )
+                    readd_objects = [
+                        err.object_
+                        for i, err in response_obj.errors.items()
+                        if i in readded_objects
+                    ]
+                    readded_uuids = {obj.uuid for obj in readd_objects}
+
+                    self.__batch_objects.prepend(readd_objects)
 
                     new_errors = {
                         i: err for i, err in response_obj.errors.items() if i not in readded_objects
@@ -482,6 +507,11 @@ async def __send_batch_async(
                 self.__fix_rate_batching_base_time += (
                     1  # increase the base time as the current one is too low
                 )
+        self.__uuid_lookup_lock.acquire()
+        self.__uuid_lookup.difference_update(
+            obj.uuid for obj in objs if obj.uuid not in readded_uuids
+        )
+        self.__uuid_lookup_lock.release()
 
         self.__results_lock.acquire()
         self.__results_for_wrapper.results.objs += response_obj
@@ -547,6 +577,9 @@ def _add_object(
             )
         except ValidationError as e:
             raise WeaviateBatchValidationError(repr(e))
+        self.__uuid_lookup_lock.acquire()
+        self.__uuid_lookup.add(str(batch_object.uuid))
+        self.__uuid_lookup_lock.release()
         self.__batch_objects.add(batch_object._to_internal())
 
         # block if queue gets too long or weaviate is overloaded - reading files is faster them sending them so we do
diff --git a/weaviate/collections/classes/batch.py b/weaviate/collections/classes/batch.py
index 9cdbed5d0..517261376 100644
--- a/weaviate/collections/classes/batch.py
+++ b/weaviate/collections/classes/batch.py
@@ -14,7 +14,7 @@
 class _BatchObject:
     collection: str
     vector: Optional[List[float]]
-    uuid: Optional[UUID]
+    uuid: str
     properties: Optional[Dict[str, WeaviateField]]
     tenant: Optional[str]
     references: Optional[ReferenceInputs]
@@ -26,6 +26,7 @@ class _BatchReference:
     from_: str
     to: str
     tenant: Optional[str]
+    from_uuid: str
 
 
 class BatchObject(BaseModel):
@@ -54,7 +55,7 @@ def _to_internal(self) -> _BatchObject:
         return _BatchObject(
             collection=self.collection,
             vector=cast(list, self.vector),
-            uuid=self.uuid,
+            uuid=str(self.uuid),
             properties=self.properties,
             tenant=self.tenant,
             references=self.references,
@@ -113,6 +114,7 @@ def _to_internal(self) -> _BatchReference:
         else:
             self.to_object_collection = self.to_object_collection + "/"
         return _BatchReference(
+            from_uuid=str(self.from_object_uuid),
             from_=f"{BEACON}{self.from_object_collection}/{self.from_object_uuid}/{self.from_property_name}",
             to=f"{BEACON}{self.to_object_collection}{str(self.to_object_uuid)}",
             tenant=self.tenant,
diff --git a/weaviate/collections/data.py b/weaviate/collections/data.py
index d795d1406..e851851f2 100644
--- a/weaviate/collections/data.py
+++ b/weaviate/collections/data.py
@@ -213,6 +213,7 @@ def _reference_add_many(self, refs: List[DataReferences]) -> BatchReferenceRetur
from_=f"{BEACON}{self.name}/{ref.from_uuid}/{ref.from_property}", to=beacon, tenant=self._tenant, + from_uuid=str(ref.from_uuid), ) for ref in refs for beacon in ref._to_beacons() @@ -404,7 +405,7 @@ def insert_many( _BatchObject( collection=self.name, vector=obj.vector, - uuid=obj.uuid, + uuid=str(obj.uuid if obj.uuid is not None else uuid_package.uuid4()), properties=cast(dict, obj.properties), tenant=self._tenant, references=obj.references, @@ -413,7 +414,7 @@ def insert_many( else _BatchObject( collection=self.name, vector=None, - uuid=None, + uuid=str(uuid_package.uuid4()), properties=cast(dict, obj), tenant=self._tenant, references=None,