Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 94 additions & 38 deletions specifyweb/backend/trees/defaults.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional, TypedDict, NotRequired
from typing import Dict, Optional, TypedDict, NotRequired, Union
import json

from django.db import transaction
Expand Down Expand Up @@ -133,9 +133,14 @@ class RankMappingConfiguration(TypedDict):
fullnameseparator: NotRequired[str]
fields: Dict[str, str]

class TreeConfiguration(TypedDict):
    """Configuration for importing one default tree from a CSV source.

    Passed to DefaultTreeContext and add_default_tree_record, which read
    'ranks' in order to build each row's parent chain.
    """
    # Column names expected in the source CSV.
    all_columns: list[str]
    # Per-rank mapping configuration, iterated in order for each row
    # (ancestor ranks first, based on how add_default_tree_record links parents).
    ranks: list[RankMappingConfiguration]
    # Optional root-node description; exact shape not visible here — TODO confirm.
    root: NotRequired[dict]

class DefaultTreeContext():
"""Context for a default tree creation task"""
def __init__(self, tree_type: str, tree_def, tree_cfg: dict[str, RankMappingConfiguration], create_missing_ranks: bool):
def __init__(self, tree_type: str, tree_def, tree_cfg: TreeConfiguration, create_missing_ranks: bool):
self.tree_type = tree_type
self.tree_def = tree_def

Expand All @@ -144,6 +149,9 @@ def __init__(self, tree_type: str, tree_def, tree_cfg: dict[str, RankMappingConf
self.tree_cfg = tree_cfg
if create_missing_ranks:
self.create_missing_ranks()

self.local_count = 0
self.local_id_field = 'text1'

self.create_rank_map()
self.root_parent = self.tree_node_model.objects.filter(
Expand Down Expand Up @@ -176,40 +184,54 @@ def create_rank_map(self):
self.tree_def_item_map = {rank.name: rank for rank in ranks}
# Buffers for batches
self.rankid_map = {rank.rankid: rank for rank in ranks}
# All node objects to be created in this batch, separated by rank
self.buffers = {rank.rankid: {} for rank in ranks}
# Contains all nodes that can be parents at the current row. Name -> Object or database ID.
self.parent_lookup = {rank.rankid: {} for rank in ranks}
# IDs of nodes commited to the database. Local ID -> Database ID
self.created = {rank.rankid: {} for rank in ranks}
self.highest_rank = 0

def add_node_to_buffer(self, node, rank_id, row_id):
"""Add node to the current batch of nodes to be created"""
if rank_id not in self.buffers:
self.buffers[rank_id] = {}
self.parent_lookup[rank_id] = {}
self.created[rank_id] = {}
self.buffers[rank_id][row_id] = node
self.parent_lookup[rank_id][node.name] = node
return node

def get_node_in_buffer(self, rank_id: int, name: str):
def get_existing_parent(self, rank_id: int, name: str) -> Union[object, int, None]:
"""Gets a node if its already in the current batch's buffer. Prevents duplication within a batch."""
# Check for node in buffer, return node
buffer = self.buffers.get(rank_id, {})
for node in buffer.values():
if node.name == name:
return node
return None
lookup = self.parent_lookup.get(rank_id, {})
return lookup.get(name, None)

def get_existing_node_id(self, rank_id: int, name: str) -> Optional[int]:
    """Gets a node's id if it has already been created. Prevents duplication across an entire import."""
    # NOTE(review): elsewhere in this file `self.created[rank_id]` is keyed by
    # the integer local id (see flush()), not by name, so this name-based
    # lookup looks stale — confirm whether this method is still called.
    # Check for existing id, return id
    created_in_rank = self.created.get(rank_id)
    if created_in_rank:
        return created_in_rank.get(name)
    return None
def clear_parent_lookup(self, highest_rank: int):
"""Clears all higher-rank buffers, since they are no longer relevant"""
# This will prevent a node from being parented to an incorrect parent with the same name
if highest_rank < self.highest_rank:
for id in list(self.parent_lookup.keys()):
if id > highest_rank:
self.parent_lookup[id] = {}
self.highest_rank = highest_rank
self.highest_rank = max(highest_rank, self.highest_rank)

def finalize(self):
    """Clears temporary local id values from tree.

    The import writes a per-import counter into ``local_id_field`` so freshly
    bulk-created nodes can be re-queried; once the tree is complete those
    scratch values are wiped with a single bulk UPDATE over this definition's
    nodes.
    """
    # Fix: `f"{self.local_id_field}"` was a redundant f-string around a value
    # that is already a str — use the field name directly.
    self.tree_node_model.objects.filter(
        definition=self.tree_def
    ).update(
        **{self.local_id_field: None}
    )

def flush(self, force=False):
"""Flushes this batch's buffer if the batch is complete. Bulk creates the nodes in a complete batch."""
self.counter += 1
if not (force or self.counter > self.batch_size):
return
logger.debug(f"Batch creating {self.batch_size} rows.")
logger.debug(f"Batch creating {self.counter} rows.")

# Go through ranks in ascending order and bulk create nodes
ordered_rank_ids = sorted(self.buffers.keys())
Expand All @@ -228,7 +250,7 @@ def flush(self, force=False):
parent = getattr(node, 'parent', None)
parent_id = getattr(node, 'parent_id', None)
if parent is not None and getattr(parent, 'pk', None) is None:
saved_parent_id = self.created[parent.rankid].get(parent.name)
saved_parent_id = self.created[parent.rankid].get(int(getattr(parent, self.local_id_field)))
# Handle root
if not saved_parent_id and parent.name == getattr(self.root_parent, 'name', None):
saved_parent_id = self.root_parent.id
Expand All @@ -246,22 +268,32 @@ def flush(self, force=False):
self.tree_node_model.objects.bulk_create(nodes_to_create, ignore_conflicts=True)

# Store the ids of the nodes were created in this batch
created_names = [n.name for n in nodes_to_create]
placeholders = ",".join(["%s"] * len(created_names))
created_local_ids = [str(getattr(n, self.local_id_field)) for n in nodes_to_create]
created_nodes = self.tree_node_model.objects.filter(
definition=self.tree_def,
definitionitem=rank,
).extra(
where=[f"BINARY name IN ({placeholders})"],
params=created_names
**{f"{self.local_id_field}__in": created_local_ids}
)
self.created[rank_id].update({int(getattr(n, self.local_id_field)): n.id for n in created_nodes})

# parent_lookup still contains unsaved objects. Replace them with IDs.
sorted_created_nodes = sorted(
created_nodes,
key=lambda n: int(getattr(n, self.local_id_field))
)
self.created[rank_id].update({n.name: n.id for n in created_nodes})
for node in sorted_created_nodes:
local_id = int(getattr(node, self.local_id_field))
name = node.name
# Check that the name is already in the lookup as to not re-introduce irrelevant parents.
if self.parent_lookup[rank_id].get(name):
self.parent_lookup[rank_id][name] = self.created[rank_id].get(local_id)


self.buffers[rank_id] = {}

self.counter = 0

def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: dict[str, RankMappingConfiguration], row_id: int):
def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: TreeConfiguration, row_id: int):
"""
Given one CSV row and a column mapping / rank configuration dictionary,
walk through the 'ranks' in order, creating or updating each tree record and linking
Expand All @@ -271,16 +303,18 @@ def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: di
tree_def = context.tree_def
parent = context.root_parent
parent_id = None
rank_id = 10

for rank_mapping in tree_cfg['ranks']:
highest_rank = 0
rank_count = len(tree_cfg['ranks'])
for index in range(rank_count):
rank_mapping = tree_cfg['ranks'][index]
rank_name = rank_mapping['name']
fields_mapping = rank_mapping['fields']

record_name = row.get(rank_mapping.get('column', rank_name)) # Record's name is in the <rank_name> column.

if not record_name:
continue # This row doesn't contain a record for this rank.
continue # This row doesn't contain a record for this rank.

defaults = {}
for model_field, csv_col in fields_mapping.items():
Expand All @@ -301,16 +335,30 @@ def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: di

if tree_def_item is None:
continue

# Check if this is the last node in this row.
# If so, do not attempt to de-duplicate it. Non-parent nodes are allowed to share names.
is_last = (index == rank_count-1)
if not is_last and index < rank_count-1:
next_rank_mapping = tree_cfg['ranks'][index+1]
next_rank_name = next_rank_mapping['name']
next_record_name = row.get(next_rank_mapping.get('column', next_rank_name))
if not next_record_name:
is_last = True

highest_rank = tree_def_item.rankid

# Create the node at this rank if it isn't already there.
buffered = context.get_node_in_buffer(tree_def_item.rankid, record_name)
existing_id = context.get_existing_node_id(tree_def_item.rankid, record_name)
if existing_id is not None:
parent_id = existing_id
parent = None
elif buffered is not None:
parent_id = None
parent = buffered
existing = context.get_existing_parent(tree_def_item.rankid, record_name)
if not is_last and existing is not None:
if type(existing) is int:
# Use parent's true id
parent_id = existing
parent = None
else:
# Unsaved parent, use the object directly (It will be replaced with the true id when buffer is flushed)
parent_id = None
parent = existing
else:
# Add new node to buffer
data = {
Expand All @@ -323,6 +371,11 @@ def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: di
if hasattr(tree_node_model, 'isaccepted'):
data['isaccepted'] = True
data.update(defaults)

# Add a unique identifier in this import context (to be deleted when tree is finalized)
# This will be used to query this exact node again once its saved
context.local_count += 1
data[context.local_id_field] = context.local_count

if parent is not None:
data['parent'] = parent
Expand All @@ -334,7 +387,9 @@ def add_default_tree_record(context: DefaultTreeContext, row: dict, tree_cfg: di

parent = obj
parent_id = None
rank_id += 10

# Clear irrelevant parents
context.clear_parent_lookup(highest_rank)

def queue_create_default_tree_task(task_id):
"""Store queued (and active) default tree creation tasks so they can be reliably tracked later."""
Expand Down Expand Up @@ -428,13 +483,14 @@ def progress(cur: int, additional_total: int=0) -> None:
total_rows = row_count-2
progress(0, total_rows)

for row in stream_default_tree_csv(url):
add_default_tree_record(context, row, tree_cfg, current)
for row_idx, row in enumerate(stream_default_tree_csv(url)):
add_default_tree_record(context, row, tree_cfg, row_idx)
context.flush()
progress(1, 0)
context.flush(force=True)

# Finalize Tree
context.finalize()
renumber_tree(tree_type)
set_fullnames(tree_def)
except Exception as e:
Expand Down
Loading