diff --git a/specifyweb/backend/trees/defaults.py b/specifyweb/backend/trees/defaults.py index 105f126f806..02d60e399e1 100644 --- a/specifyweb/backend/trees/defaults.py +++ b/specifyweb/backend/trees/defaults.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, TypedDict, NotRequired +from typing import Dict, Optional, TypedDict, NotRequired, Union import json from django.db import transaction @@ -133,9 +133,14 @@ class RankMappingConfiguration(TypedDict): fullnameseparator: NotRequired[str] fields: Dict[str, str] +class TreeConfiguration(TypedDict): + all_columns: list[str] + ranks: list[RankMappingConfiguration] + root: NotRequired[dict] + class DefaultTreeContext(): """Context for a default tree creation task""" - def __init__(self, tree_type: str, tree_def, tree_cfg: dict[str, RankMappingConfiguration], create_missing_ranks: bool): + def __init__(self, tree_type: str, tree_def, tree_cfg: TreeConfiguration, create_missing_ranks: bool): self.tree_type = tree_type self.tree_def = tree_def @@ -144,6 +149,9 @@ def __init__(self, tree_type: str, tree_def, tree_cfg: dict[str, RankMappingConf self.tree_cfg = tree_cfg if create_missing_ranks: self.create_missing_ranks() + + self.local_count = 0 + self.local_id_field = 'text1' self.create_rank_map() self.root_parent = self.tree_node_model.objects.filter( @@ -176,40 +184,54 @@ def create_rank_map(self): self.tree_def_item_map = {rank.name: rank for rank in ranks} # Buffers for batches self.rankid_map = {rank.rankid: rank for rank in ranks} + # All node objects to be created in this batch, separated by rank self.buffers = {rank.rankid: {} for rank in ranks} + # Contains all nodes that can be parents at the current row. Name -> Object or database ID. + self.parent_lookup = {rank.rankid: {} for rank in ranks} + # IDs of nodes commited to the database. Local ID -> Database ID self.created = {rank.rankid: {} for rank in ranks} + self.highest_rank = 0 def add_node_to_buffer(self, node, rank_id, row_id): """Add node to the current batch of nodes to be created""" if rank_id not in self.buffers: self.buffers[rank_id] = {} + self.parent_lookup[rank_id] = {} self.created[rank_id] = {} self.buffers[rank_id][row_id] = node + self.parent_lookup[rank_id][node.name] = node return node - def get_node_in_buffer(self, rank_id: int, name: str): + def get_existing_parent(self, rank_id: int, name: str) -> Union[object, int, None]: """Gets a node if its already in the current batch's buffer. Prevents duplication within a batch.""" # Check for node in buffer, return node - buffer = self.buffers.get(rank_id, {}) - for node in buffer.values(): - if node.name == name: - return node - return None + lookup = self.parent_lookup.get(rank_id, {}) + return lookup.get(name, None) - def get_existing_node_id(self, rank_id: int, name: str) -> Optional[int]: - """Gets a node's id if it has already been created. Prevents duplication across an entire import.""" - # Check for existing id, return id - created_in_rank = self.created.get(rank_id) - if created_in_rank: - return created_in_rank.get(name) - return None + def clear_parent_lookup(self, highest_rank: int): + """Clears all higher-rank buffers, since they are no longer relevant""" + # This will prevent a node from being parented to an incorrect parent with the same name + if highest_rank < self.highest_rank: + for id in list(self.parent_lookup.keys()): + if id > highest_rank: + self.parent_lookup[id] = {} + self.highest_rank = highest_rank + self.highest_rank = max(highest_rank, self.highest_rank) + + def finalize(self): + """Clears temporary local id values from tree.""" + self.tree_node_model.objects.filter( + definition=self.tree_def + ).update( + **{f"{self.local_id_field}": None} + ) def flush(self, force=False): """Flushes this batch's buffer if the batch is complete. Bulk creates the nodes in a complete batch.""" self.counter += 1 if not (force or self.counter > self.batch_size): return - logger.debug(f"Batch creating {self.batch_size} rows.") + logger.debug(f"Batch creating {self.counter} rows.") # Go through ranks in ascending order and bulk create nodes ordered_rank_ids = sorted(self.buffers.keys()) @@ -228,7 +250,7 @@ def flush(self, force=False): parent = getattr(node, 'parent', None) parent_id = getattr(node, 'parent_id', None) if parent is not None and getattr(parent, 'pk', None) is None: - saved_parent_id = self.created[parent.rankid].get(parent.name) + saved_parent_id = self.created[parent.rankid].get(int(getattr(parent, self.local_id_field))) # Handle root if not saved_parent_id and parent.name == getattr(self.root_parent, 'name', None): saved_parent_id = self.root_parent.id @@ -246,22 +268,32 @@ def flush(self, force=False): self.tree_node_model.objects.bulk_create(nodes_to_create, ignore_conflicts=True) # Store the ids of the nodes were created in this batch - created_names = [n.name for n in nodes_to_create] - placeholders = ",".join(["%s"] * len(created_names)) + created_local_ids = [str(getattr(n, self.local_id_field)) for n in nodes_to_create] created_nodes = self.tree_node_model.objects.filter( definition=self.tree_def, definitionitem=rank, - ).extra( - where=[f"BINARY name IN ({placeholders})"], - params=created_names + **{f"{self.local_id_field}__in": created_local_ids} + ) + self.created[rank_id].update({int(getattr(n, self.local_id_field)): n.id for n in created_nodes}) + + # parent_lookup still contains unsaved objects. Replace them with IDs. + sorted_created_nodes = sorted( + created_nodes, + key=lambda n: int(getattr(n, self.local_id_field)) ) - self.created[rank_id].update({n.name: n.id for n in created_nodes}) + for node in sorted_created_nodes: + local_id = int(getattr(node, self.local_id_field)) + name = node.name + # Check that the name is already in the lookup as to not re-introduce irrelevant parents. + if self.parent_lookup[rank_id].get(name): + self.parent_lookup[rank_id][name] = self.created[rank_id].get(local_id) + self.buffers[rank_id] = {} self.counter = 0 -def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: dict[str, RankMappingConfiguration], row_id: int): +def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: TreeConfiguration, row_id: int): """ Given one CSV row and a column mapping / rank configuration dictionary, walk through the 'ranks' in order, creating or updating each tree record and linking @@ -271,16 +303,18 @@ def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: di tree_def = context.tree_def parent = context.root_parent parent_id = None - rank_id = 10 - for rank_mapping in tree_cfg['ranks']: + highest_rank = 0 + rank_count = len(tree_cfg['ranks']) + for index in range(rank_count): + rank_mapping = tree_cfg['ranks'][index] rank_name = rank_mapping['name'] fields_mapping = rank_mapping['fields'] record_name = row.get(rank_mapping.get('column', rank_name)) # Record's name is in the column. if not record_name: - continue # This row doesn't contain a record for this rank. + continue # This row doesn't contain a record for this rank. defaults = {} for model_field, csv_col in fields_mapping.items(): @@ -301,16 +335,30 @@ def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: di if tree_def_item is None: continue + + # Check if this is the last node in this row. + # If so, do not attempt to de-duplicate it. Non-parent nodes are allowed to share names. + is_last = (index == rank_count-1) + if not is_last and index < rank_count-1: + next_rank_mapping = tree_cfg['ranks'][index+1] + next_rank_name = next_rank_mapping['name'] + next_record_name = row.get(next_rank_mapping.get('column', next_rank_name)) + if not next_record_name: + is_last = True + + highest_rank = tree_def_item.rankid # Create the node at this rank if it isn't already there. - buffered = context.get_node_in_buffer(tree_def_item.rankid, record_name) - existing_id = context.get_existing_node_id(tree_def_item.rankid, record_name) - if existing_id is not None: - parent_id = existing_id - parent = None - elif buffered is not None: - parent_id = None - parent = buffered + existing = context.get_existing_parent(tree_def_item.rankid, record_name) + if not is_last and existing is not None: + if type(existing) is int: + # Use parent's true id + parent_id = existing + parent = None + else: + # Unsaved parent, use the object directly (It will be replaced with the true id when buffer is flushed) + parent_id = None + parent = existing else: # Add new node to buffer data = { @@ -323,6 +371,11 @@ def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: di if hasattr(tree_node_model, 'isaccepted'): data['isaccepted'] = True data.update(defaults) + + # Add a unique identifier in this import context (to be deleted when tree is finalized) + # This will be used to query this exact node again once its saved + context.local_count += 1 + data[context.local_id_field] = context.local_count if parent is not None: data['parent'] = parent @@ -334,7 +387,9 @@ def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: di parent = obj parent_id = None - rank_id += 10 + + # Clear irrelevant parents + context.clear_parent_lookup(highest_rank) def queue_create_default_tree_task(task_id): """Store queued (and active) default tree creation tasks so they can be reliably tracked later.""" @@ -428,13 +483,14 @@ def progress(cur: int, additional_total: int=0) -> None: total_rows = row_count-2 progress(0, total_rows) - for row in stream_default_tree_csv(url): - add_default_tree_record(context, row, tree_cfg, current) + for row_idx, row in enumerate(stream_default_tree_csv(url)): + add_default_tree_record(context, row, tree_cfg, row_idx) context.flush() progress(1, 0) context.flush(force=True) # Finalize Tree + context.finalize() renumber_tree(tree_type) set_fullnames(tree_def) except Exception as e: