# Imports

In [None]:
import functools
import graphblas as gb
import numpy as np
import pronto

# Step 1. Load an ontology using Pronto

In [None]:
# Alternative input: presumably the full Gene Ontology release (uncomment to use it)
#go_ontology = pronto.Ontology("../data/out/go.obo")
# Ontology file from the test resources, used for the examples below
go_ontology = pronto.Ontology("../tests/resources/dummy_ontology.obo")

# Step 8. Create classes to organize the code

Objects to model:

- Lookup tables (LUTs)
- Graph properties
  - number of nodes
  - number of edges
  - relationship types
- Adjacency matrices (one per relationship type)

### Lookup Tables Class

In [None]:
class LookUpTables:
    def __init__(self, ontology):
        self.__lut_term_to_index = self.__create_lut_term_index(ontology=ontology)
        self.__lut_index_to_term = self.__create_lut_nodes(lookup_table=self.__lut_term_to_index)
        self.__lut_term_to_description = self.__create_lut_term_description(ontology=ontology)
        self.__lut_description_to_term = self.__create_lut_description_term(ontology=ontology)

    # Private methods
    def __create_lut_term_index(self, ontology)-> dict[str, int]:
        terms = [term for term in ontology.terms() if not term.obsolete]
        terms.sort(key=lambda term: term.id)
        return {term.id: idx for idx, term in enumerate(terms)}
    
    def __create_lut_nodes(self, lookup_table):
        return list(lookup_table.keys())
    
    def __create_lut_term_description(self, ontology):
        return {term.id: term.name for term in ontology.terms() if not term.obsolete}
    
    def __create_lut_description_term(self, ontology):
        return {term.name: term.id for term in ontology.terms() if not term.obsolete}
    

    # Public methods
    def get_lut_term_to_index(self):
        return self.__lut_term_to_index
    
    def get_lut_index_to_term(self):
        return self.__lut_index_to_term
    
    def get_lut_term_to_description(self):
        return self.__lut_term_to_description
    
    def get_lut_description_to_term(self):
        return self.__lut_description_to_term


    def term_to_index(self, terms: str | list):
        if isinstance(terms, str):
            # Single term ID
            return self.__lut_term_to_index[terms]
        elif isinstance(terms, list):
            # List of term IDs
            return [self.__lut_term_to_index[term] for term in terms]

    def index_to_term(self, indexes: int | list):
        if isinstance(indexes, int):
            # Single index
            return self.__lut_index_to_term[indexes]
        elif isinstance(indexes, list):
            # List of indexes
            return [self.__lut_index_to_term[idx] for idx in indexes]
        elif isinstance(indexes, np.ndarray):
            # NumPy array of indices (vectorized lookup)
            return [self.__lut_index_to_term[idx] for idx in indexes.tolist()]

        else:
            raise TypeError(
                f"Expected int, list[int], or np.ndarray, got {type(indexes).__name__}."
            )

    def term_to_description(self, terms: str | list):
        if isinstance(terms, str):
            # Single term ID
            return self.__lut_term_to_description[terms]
        elif isinstance(terms, list):
            # List of term IDs
            return [self.__lut_term_to_description[term] for term in terms]

    def description_to_term(self, descriptions: str | list):
        # TODO: improve with Levenstein or Regex expresions
        if isinstance(descriptions, str):
            # Single description
            return self.__lut_description_to_term[descriptions]
        elif isinstance(descriptions, list):
            # List of descriptions
            return [self.__lut_description_to_term[term] for term in descriptions]


In [None]:
# Create LookUpTables given an ontology
# (term ID <-> matrix index <-> term name, non-obsolete terms only)
L = LookUpTables(ontology=go_ontology)

In [None]:
# [x] verified against Protege
# The three lists are presumably parallel: the i-th ID, index and name are
# expected to refer to the same term. The prints below let you compare the
# lookup-table round trips with those expected values.

list_terms = ['GO:0008150', 'GO:0000017', 'GO:0000001']
list_indices = [5099, 10, 0]
list_descriptions = ['biological_process', 'alpha-glucoside transport', 'mitochondrion inheritance']


print(f"Indices:\n\t{L.term_to_index(list_terms)}")
print(f"Terms:\n\t{L.index_to_term(list_indices)}")
print(f"Descriptions:\n\t{L.term_to_description(list_terms)}")
print(f"Terms from descriptions:\n\t{L.description_to_term(list_descriptions)}")


### Graph Class

In [None]:
from functools import cached_property
from collections import defaultdict

class Graph:
    """
    Ontology graph backed by python-graphblas sparse adjacency matrices.

    One boolean ``n x n`` matrix is built per relationship type, where ``n`` is
    the number of non-obsolete terms. An entry ``M[i, j]`` means "term ``i``
    points to term ``j``":

    * ``'is_a'``: ``i`` is a direct subclass of ``j`` (child -> parent);
    * explicit relationships (e.g. ``'part_of'``): ``i`` is the subject and
      ``j`` the target.

    Hence ``M @ e_j`` selects the children of ``j`` and ``M.T @ e_i`` the
    parents of ``i``. Node indices come from the supplied lookup tables.

    Parameters
    ----------
    ontology : pronto.Ontology
        Ontology to index.
    lookup_tables : LookUpTables
        Term <-> index <-> name tables built from the same ontology.
    """

    def __init__(self, ontology, lookup_tables):
        # Core elements
        self.lookup_tables = lookup_tables
        self.nodes_container = self.lookup_tables.get_lut_index_to_term()
        self.edges_container = self.populate_index_containers(
            ontology=ontology,
            lut_term_index=self.lookup_tables.get_lut_term_to_index(),
        )
        self.matrices_container = self.create_multiple_matrices(
            edge_container=self.edges_container,
            nrows=len(self.nodes_container),
            ncols=len(self.nodes_container),
        )

        # Metadata
        self.relation_types = self.get_ontology_relationships(ontology=ontology)
        self.number_nodes = self.number_nodes_ontology(ontology=ontology)
        # Only 'is_a' edges are counted; explicit relationships are excluded.
        self.number_edges = self.edges_container['is_a']['rows'].shape[0]

    ## Private methods
    def _validate_term(self, term_id):
        """Raise ``KeyError`` if ``term_id`` is not an indexed (non-obsolete) term."""
        if term_id not in self.lookup_tables.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {term_id}")

    def create_edges_index_containers(self, ontology):
        """
        Create an empty edge container with one entry per relationship type.

        ``rows`` holds source indices and ``cols`` target indices. The implicit
        ``'is_a'`` relationship is always added to the explicit ones.
        """
        relationships = self.get_ontology_relationships(ontology)
        relationships.append('is_a')
        return {rel: {'rows': [], 'cols': []} for rel in relationships}

    def populate_index_containers(self, ontology, lut_term_index):
        """
        Collect the (source, target) index pairs of every relationship.

        Parameters
        ----------
        ontology : pronto.Ontology
            Ontology to scan; obsolete terms and obsolete targets are skipped.
        lut_term_index : dict[str, int]
            Term ID -> matrix index table.

        Returns
        -------
        dict
            ``{relation_id: {'rows': np.ndarray, 'cols': np.ndarray}}`` with
            integer index arrays.
        """
        edge_container = self.create_edges_index_containers(ontology)

        for term in ontology.terms():
            if term.obsolete:
                continue
            term_index = lut_term_index[term.id]

            # 'is_a': every direct, non-obsolete subclass points to this term
            for subclass in term.subclasses(with_self=False, distance=1):
                if subclass.obsolete:
                    continue
                edge_container['is_a']['rows'].append(lut_term_index[subclass.id])
                edge_container['is_a']['cols'].append(term_index)

            # explicit relationships (e.g. 'part_of'): this term points to each target
            for rel, targets in term.relationships.items():
                for target in targets:
                    if target.obsolete:
                        continue
                    edge_container[rel.id]['rows'].append(term_index)
                    edge_container[rel.id]['cols'].append(lut_term_index[target.id])

        # An explicit integer dtype keeps relationships without edges usable as
        # COO indices (np.array([]) would otherwise default to float64).
        for data in edge_container.values():
            data['rows'] = np.asarray(data['rows'], dtype=np.int64)
            data['cols'] = np.asarray(data['cols'], dtype=np.int64)

        return edge_container

    def create_graphblas_matrix(self, rows_indexes, cols_indexes, nrows, ncols, name):
        """Build a boolean ``nrows x ncols`` matrix with ``True`` at each (row, col) pair."""
        return gb.Matrix.from_coo(
            rows=rows_indexes,
            columns=cols_indexes,
            values=True,
            nrows=nrows,
            ncols=ncols,
            dtype=bool,
            name=name,
        )

    def create_multiple_matrices(self, edge_container, nrows, ncols):
        """Build one adjacency matrix per relationship in ``edge_container``."""
        return {
            relation: self.create_graphblas_matrix(
                rows_indexes=indexes['rows'],
                cols_indexes=indexes['cols'],
                nrows=nrows,
                ncols=ncols,
                name=relation,
            )
            for relation, indexes in edge_container.items()
        }

    def get_ontology_relationships(self, ontology):
        """Return the sorted IDs of all explicit relationships used by any term."""
        return sorted({rel.id for term in ontology.terms() for rel in term.relationships})

    def number_nodes_ontology(self, ontology):
        """Count the non-obsolete terms of ``ontology``."""
        return sum(1 for term in ontology.terms() if not term.obsolete)

    def _expand_frontier(self, adjacency_matrix, frontier):
        """
        Return the set of node indices one hop away from the ``frontier`` indices.

        ``adjacency_matrix`` selects the direction: the 'is_a' matrix yields
        children, its transpose yields parents.
        """
        size = adjacency_matrix.nrows
        frontier_vector = gb.Vector.from_coo(sorted(frontier), [1] * len(frontier), size=size)
        next_vector = gb.Vector(dtype=int, size=size)
        next_vector << gb.semiring.plus_times(adjacency_matrix @ frontier_vector)
        return {int(i) for i in next_vector.to_coo()[0]}

    def _direct_neighbors(self, term_id, adjacency_matrix, include_self):
        """Return the term IDs one hop from ``term_id`` (optionally including it)."""
        self._validate_term(term_id)
        index = self.lookup_tables.term_to_index(term_id)
        neighbors = (adjacency_matrix @ self.one_hot_vector(index=index)).new()
        if include_self:
            neighbors[index] = True
        return self.lookup_tables.index_to_term(neighbors.to_coo()[0])

    ## Public methods

    def one_hot_vector(self, index: int) -> "gb.Vector":
        """
        Return a new one-hot vector of length ``number_nodes`` with a 1 at ``index``.

        A fresh vector is built on every call (it is cheap), so callers can never
        mutate a shared cached instance.
        """
        return gb.Vector.from_coo([index], [1], size=self.number_nodes, dtype=int)

    def get_children(self, term_id, include_self=False):
        """Return the direct 'is_a' children of ``term_id`` as term IDs."""
        return self._direct_neighbors(term_id, self.matrices_container['is_a'], include_self)

    def get_parents(self, term_id, include_self=False):
        """Return the direct 'is_a' parents of ``term_id`` as term IDs."""
        return self._direct_neighbors(term_id, self.matrices_container['is_a'].T, include_self)

    def get_root(self):
        """Return the term IDs that have no 'is_a' parent."""
        matrix = self.matrices_container['is_a'].T

        # Column j of A.T is row j of A, so this counts the parents of each node
        parent_counts = matrix.reduce_columnwise(gb.binary.plus).new()
        indices, values = parent_counts.to_coo()

        counts = np.zeros(matrix.ncols, dtype=np.int64)
        counts[indices] = values

        roots = np.where(counts == 0)[0]
        return self.lookup_tables.index_to_term(roots)

    def _traverse_graph(self, term_id, adjacency_matrix, distance=None, include_self=False):
        """
        Breadth-first traversal from ``term_id`` along ``adjacency_matrix``.

        Parameters
        ----------
        term_id : str
            The starting term ID.
        adjacency_matrix : gb.Matrix
            Adjacency matrix to traverse (forward for descendants, transposed for ancestors).
        distance : int or None
            Maximum number of hops. None means unlimited.
        include_self : bool
            Whether to include the starting node in the result.

        Returns
        -------
        List[str]
            Term IDs reached, in ascending index order.
        """
        self._validate_term(term_id)
        start = self.lookup_tables.term_to_index(term_id)

        reached = {start} if include_self else set()
        frontier = {start}
        remaining = distance

        while frontier and remaining != 0:
            next_indices = self._expand_frontier(adjacency_matrix, frontier) - reached
            if not next_indices:
                break
            reached |= next_indices
            frontier = next_indices
            if remaining is not None:
                remaining -= 1

        return self.lookup_tables.index_to_term(sorted(reached))

    def get_ancestors(self, term_id, distance=None, include_self=False):
        """Return all 'is_a' ancestors of ``term_id`` up to ``distance`` hops."""
        return self._traverse_graph(term_id, self.matrices_container['is_a'].T, distance, include_self)

    def get_descendants(self, term_id, distance=None, include_self=False):
        """Return all 'is_a' descendants of ``term_id`` up to ``distance`` hops."""
        return self._traverse_graph(term_id, self.matrices_container['is_a'], distance, include_self)

    def _traverse_graph_with_distance(self, term_id, adjacency_matrix, include_self=False):
        """
        Breadth-first traversal returning each reached node with its hop distance.

        Parameters
        ----------
        term_id : str
            The starting term ID.
        adjacency_matrix : gb.Matrix
            Adjacency matrix to traverse (forward for descendants, transposed for ancestors).
        include_self : bool
            Whether to include the starting node with distance 0.

        Returns
        -------
        List[Tuple[str, int]]
            ``(term_id, shortest_distance_from_start)`` pairs, grouped by distance.
        """
        self._validate_term(term_id)
        start = self.lookup_tables.term_to_index(term_id)

        distances = {start: 0} if include_self else {}
        frontier = {start}
        level = 0

        while frontier:
            next_indices = self._expand_frontier(adjacency_matrix, frontier).difference(distances)
            if not next_indices:
                break
            level += 1
            for idx in sorted(next_indices):
                distances[idx] = level
            frontier = next_indices

        return [(self.lookup_tables.index_to_term(idx), dist) for idx, dist in distances.items()]

    def get_ancestors_with_distance(self, term_id, include_self=False):
        """Return ``(ancestor_id, distance)`` pairs for all 'is_a' ancestors."""
        return self._traverse_graph_with_distance(term_id, self.matrices_container['is_a'].T, include_self)

    def get_descendants_with_distance(self, term_id, include_self=False):
        """Return ``(descendant_id, distance)`` pairs for all 'is_a' descendants."""
        return self._traverse_graph_with_distance(term_id, self.matrices_container['is_a'], include_self)

    def get_common_ancestors(self, node_ids, include_self=False):
        """
        Return the common ancestors of a list of terms.

        Parameters
        ----------
        node_ids : List[str]
            List of starting term IDs.
        include_self : bool
            Whether each starting node counts as one of its own ancestors.

        Returns
        -------
        List[str]
            Sorted term IDs that are ancestors of every input term
            (empty if ``node_ids`` is empty or nothing is shared).
        """
        if not node_ids:
            return []

        common = set(self.get_ancestors(node_ids[0], include_self=include_self))
        for term_id in node_ids[1:]:
            common.intersection_update(self.get_ancestors(term_id, include_self=include_self))
            if not common:
                return []

        return sorted(common)

    def get_lowest_common_ancestors(self, node_ids, include_self=False):
        """
        Return the lowest common ancestor(s) of a list of terms.

        "Lowest" is measured as: the common ancestor(s) minimising the largest
        shortest-path distance to any of the input terms.

        Parameters
        ----------
        node_ids : List[str]
            List of starting term IDs.
        include_self : bool
            Whether each starting node counts as one of its own ancestors.

        Returns
        -------
        List[str]
            Sorted term IDs of the lowest common ancestors (empty if none).
        """
        if not node_ids:
            return []

        # ancestor ID -> max distance from any input term seen so far
        lca_distances = dict(self.get_ancestors_with_distance(node_ids[0], include_self=include_self))

        for term_id in node_ids[1:]:
            other = dict(self.get_ancestors_with_distance(term_id, include_self=include_self))
            lca_distances = {
                ancestor: max(dist, other[ancestor])
                for ancestor, dist in lca_distances.items()
                if ancestor in other
            }
            if not lca_distances:
                return []

        if not lca_distances:
            return []

        best = min(lca_distances.values())
        return sorted(ancestor for ancestor, dist in lca_distances.items() if dist == best)

    def get_distance_from_root(self, term_id):
        """
        Return the number of 'is_a' edges between ``term_id`` and the root.

        The distance is the shortest path to a root (a term without parents).
        If several roots are reachable, the largest of those shortest distances
        is returned.

        Parameters
        ----------
        term_id : str
            The term ID for which to compute the distance from root.

        Returns
        -------
        int
            Distance in edges; 0 if the term is a root itself.
        """
        self._validate_term(term_id)

        roots = set(self.get_root())
        ancestors_with_distance = self.get_ancestors_with_distance(term_id, include_self=True)

        # Only distances to actual roots count; intermediate ancestors can be
        # farther away (via shortest paths) than the root itself.
        root_distances = [dist for ancestor, dist in ancestors_with_distance if ancestor in roots]
        return max(root_distances, default=0)

    def get_path_between(self, node_a, node_b):
        """
        Find the shortest 'is_a' path between two nodes in the ontology.

        A path exists only when one node is an ancestor of the other.

        Parameters
        ----------
        node_a : str
            Starting term ID.
        node_b : str
            Ending term ID.

        Returns
        -------
        List[str]
            Term IDs from ``node_a`` to ``node_b`` (inclusive); ``[node_a]`` when
            both are the same term; empty list if no path exists.
        """
        self._validate_term(node_a)
        self._validate_term(node_b)

        if node_a == node_b:
            return [node_a]

        # The search always walks downwards, from the ancestor to the descendant
        if self.is_ancestor(node_a, node_b):
            start, end, a_is_descendant = node_a, node_b, False
        elif self.is_descendant(node_a, node_b):
            start, end, a_is_descendant = node_b, node_a, True
        else:
            return []

        from collections import deque

        children_matrix = self.matrices_container['is_a']
        start_idx = self.lookup_tables.term_to_index(start)
        end_idx = self.lookup_tables.term_to_index(end)

        # BFS with predecessor links; children are visited in ascending index order
        predecessor = {start_idx: None}
        queue = deque([start_idx])
        while queue:
            current = queue.popleft()
            if current == end_idx:
                break
            for child in sorted(self._expand_frontier(children_matrix, {current})):
                if child not in predecessor:
                    predecessor[child] = current
                    queue.append(child)

        if end_idx not in predecessor:
            return []

        # Walk back from the descendant: yields descendant -> ancestor order
        path = []
        node = end_idx
        while node is not None:
            path.append(node)
            node = predecessor[node]

        if not a_is_descendant:
            path.reverse()  # node_a is the ancestor, so return ancestor -> descendant

        return self.lookup_tables.index_to_term(path)

    def is_ancestor(self, ancestor_node, descendant_node):
        """
        Check if ``ancestor_node`` is a (strict) ancestor of ``descendant_node``.

        Raises
        ------
        KeyError
            If either term ID is unknown.
        """
        self._validate_term(descendant_node)
        self._validate_term(ancestor_node)
        return ancestor_node in set(self.get_ancestors(descendant_node, include_self=False))

    def is_descendant(self, descendant_node, ancestor_node):
        """
        Check if ``descendant_node`` is a (strict) descendant of ``ancestor_node``.

        Raises
        ------
        KeyError
            If either term ID is unknown.
        """
        self._validate_term(ancestor_node)
        self._validate_term(descendant_node)
        return descendant_node in set(self.get_descendants(ancestor_node, include_self=False))

    def get_siblings(self, term_id, include_self: bool = False):
        """
        Retrieve all siblings of a term (nodes sharing at least one direct parent).

        Parameters
        ----------
        term_id : str
            The term ID whose siblings are to be found.
        include_self : bool, optional (default=False)
            Whether to include the term itself in the result.

        Returns
        -------
        List[str]
            Sorted sibling term IDs; empty for a root.
        """
        self._validate_term(term_id)

        parents = self.get_parents(term_id, include_self=False)
        if not parents:
            return []

        siblings = set()
        for parent_id in parents:
            siblings.update(self.get_children(parent_id, include_self=False))

        if not include_self:
            siblings.discard(term_id)

        return sorted(siblings)

    def is_sibling(self, node_a: str, node_b: str) -> bool:
        """
        Check if two nodes share at least one direct 'is_a' parent.

        Raises
        ------
        KeyError
            If either term ID is unknown.
        """
        self._validate_term(node_a)
        self._validate_term(node_b)

        parents_a = set(self.get_parents(node_a, include_self=False))
        parents_b = set(self.get_parents(node_b, include_self=False))
        return not parents_a.isdisjoint(parents_b)

    def get_trajectories_from_root(self, term_id: str) -> list[list[dict]]:
        """
        Enumerate every 'is_a' path from a root down to ``term_id``.

        Args:
            term_id (str): The identifier of the term.

        Returns:
            list[list[dict]]: One trajectory per path, ordered root -> term.
                Each step is ``{'id', 'name', 'distance'}`` where ``distance`` is
                the number of edges from that trajectory's root (root = 0).
        """
        self._validate_term(term_id)

        from collections import deque

        parents_matrix = self.matrices_container['is_a'].T
        start = int(self.lookup_tables.term_to_index(term_id))

        # Paths are stored root-first: path[0] is the topmost node reached so far
        queue = deque([[start]])
        trajectories = []

        while queue:
            path = queue.popleft()
            parents = sorted(self._expand_frontier(parents_matrix, {path[0]}))

            if not parents:
                # path[0] has no parent, i.e. it is a root: the path is complete
                trajectory = []
                for dist, idx in enumerate(path):
                    term = self.lookup_tables.index_to_term(idx)
                    trajectory.append({
                        'id': term,
                        'name': self.lookup_tables.term_to_description(term),
                        'distance': dist,
                    })
                trajectories.append(trajectory)
                continue

            for parent in parents:
                if parent not in path:  # guard against cycles
                    queue.append([parent] + path)

        return trajectories

    def print_term_trajectories_tree(self, trajectories: list[list[dict]]) -> None:
        """
        Print trajectories as ASCII tree(s), merging shared prefixes.

        One tree is printed per distinct root.

        Args:
            trajectories: Root-first trajectories as returned by
                ``get_trajectories_from_root``.
        """
        if not trajectories:
            print('No trajectories to display.')
            return
        for root in self._build_forest_from_trajectories(trajectories):
            self._print_ascii_tree(root)

    @staticmethod
    def _build_forest_from_trajectories(trajectories: list[list[dict]]) -> list:
        """
        Merge root-first trajectories into prefix trees, one per distinct root.

        Steps are identified by ``(id, name, distance)``, so identical prefixes
        collapse into a single branch.

        Returns:
            list: Root nodes (objects with ``id``, ``name``, ``distance`` and a
            ``children`` dict), in order of first appearance.
        """

        class Node:
            def __init__(self, node_id, name, distance) -> None:
                self.id = node_id
                self.name = name
                self.distance = distance
                self.children = {}

        sentinel = Node(None, None, None)
        for branch in trajectories:
            node = sentinel
            for item in branch:
                key = (item['id'], item['name'], item['distance'])
                if key not in node.children:
                    node.children[key] = Node(*key)
                node = node.children[key]
        return list(sentinel.children.values())

    @staticmethod
    def _build_tree_from_trajectories(trajectories: list[list[dict]]) -> object:
        """
        Return the merged tree rooted at the first trajectory's root.

        Trajectories starting from other roots are not grafted into this tree;
        use ``_build_forest_from_trajectories`` to obtain all of them.
        """
        return Graph._build_forest_from_trajectories(trajectories)[0]

    @staticmethod
    def _print_ascii_tree(root: object) -> None:
        """Print the tree structure in ASCII format starting from the root node."""

        def print_subtree(node: object, prefix: str = '', is_last: bool = True) -> None:
            connector = '└── ' if is_last else '├── '
            print(f'{prefix}{connector}{node.id}: {node.name} (distance={node.distance})')
            children = list(node.children.values())
            child_prefix = prefix + ('    ' if is_last else '│   ')
            for position, child in enumerate(children):
                print_subtree(child, child_prefix, position == len(children) - 1)

        # Print root without prefix
        print(f'{root.id}: {root.name} (distance={root.distance})')
        children = list(root.children.values())
        for position, child in enumerate(children):
            print_subtree(child, '', position == len(children) - 1)

### Test methods

In [None]:
# Build the lookup tables and the matrix-backed graph for the loaded ontology
L = LookUpTables(ontology=go_ontology)

G = Graph(ontology=go_ontology, lookup_tables=L)

In [None]:
# Print every 'is_a' path from the root(s) down to GO:0000017 as an ASCII tree
G.print_term_trajectories_tree(G.get_trajectories_from_root('GO:0000017'))

In [None]:
# Terms sharing at least one direct 'is_a' parent with GO:0000017 (excluding itself)
G.get_siblings('GO:0000017', include_self=False)

In [None]:
# True if the two terms share at least one direct 'is_a' parent
G.is_sibling('GO:0015759', 'GO:0051325')

In [None]:
# Only a cell's last bare expression is displayed in a notebook, so the first
# path is printed explicitly instead of being silently discarded.
print(G.get_path_between('GO:0000092', 'GO:0051325'))
L.term_to_description(G.get_path_between('GO:0044848', 'GO:0000092'))

In [None]:
# Depth of GO:0000092 as computed by Graph.get_distance_from_root
G.get_distance_from_root('GO:0000092')

In [None]:
# Ancestors shared by both terms
G.get_common_ancestors(['GO:0000092', 'GO:0051325'])

In [None]:
# Raw source ('rows') / target ('cols') index arrays of the 'ends_during' relationship
print(G.edges_container['ends_during'])

In [None]:
# Show the 'ends_during' index arrays, then translate every index that takes
# part in the relationship (sources followed by targets) into a term name.
ends_during_edges = G.edges_container['ends_during']
for key, indexes in ends_during_edges.items():
    print(f"{key}: {indexes}")
end_during_list_indexes = [index for indexes in ends_during_edges.values() for index in indexes]

L.term_to_description(L.index_to_term(end_during_list_indexes))

In [None]:
# Direct 'is_a' children of GO:0048308
G.get_children('GO:0048308', include_self=False)

In [None]:
# Direct 'is_a' parents of GO:0048308
G.get_parents('GO:0048308', include_self=False)

In [None]:
# Terms without any 'is_a' parent (the ontology roots)
G.get_root()

In [None]:
# Ancestors of GO:0051322 at most 5 'is_a' hops away
G.get_ancestors('GO:0051322', distance=5, include_self=False)

In [None]:
# All ancestors of GO:0051322 paired with their breadth-first hop distance
G.get_ancestors_with_distance('GO:0051322', include_self=False)

In [None]:
# Descendants of GO:0051322 at most 5 'is_a' hops away
G.get_descendants('GO:0051322', distance=5, include_self=False)

In [None]:
# All descendants of GO:0051322 paired with their breadth-first hop distance
G.get_descendants_with_distance('GO:0051322', include_self=False)

## Human-readable query engine

In [None]:
from typing import Set, List
import re

class QueryEngine:
    """Evaluate human-readable boolean queries against an ontology graph.

    A query combines terms with ``AND``, ``OR`` and ``NOT`` (case-insensitive)
    and may group sub-expressions with parentheses, e.g.::

        "'actomyosin' AND 'stress fiber'"
        "(term1 OR term2) AND NOT term3"

    Each term is replaced by the set returned by
    ``graph.get_descendants(term, include_self=False)``; ``NOT`` takes the
    complement of its operand with respect to ``set(graph.nodes_container)``.
    Malformed queries raise ``ValueError``.
    """

    # Binding strength used by the shunting-yard conversion (higher binds tighter).
    _PRECEDENCE = {"NOT": 3, "AND": 2, "OR": 1}

    # One alternative per token kind: single-quoted term | double-quoted term |
    # parenthesis | operator | bare (unquoted) term.  Operators are tried before
    # bare terms and must be whole words, so "AND" is an operator while
    # "ANDROID" is a term.
    _TOKEN_PATTERN = re.compile(
        r"""'([^']+)'|"([^"]+)"|([()])|\b(AND|OR|NOT)\b|([^\s()'"]+)""",
        flags=re.IGNORECASE,
    )

    def __init__(self, graph):
        """
        Initialize the QueryEngine with a Graph instance.

        Parameters
        ----------
        graph : Graph
            Ontology graph. It must provide
            ``get_descendants(term, include_self=...)`` and an iterable
            ``nodes_container`` (used as the universe for ``NOT``).
        """
        self.graph = graph

    # --------------------------
    # Parsing
    # --------------------------
    @staticmethod
    def parse_query(query: str) -> List[str]:
        """
        Split a query into tokens: terms, AND, OR, NOT and parentheses.

        Terms may be single-quoted, double-quoted or bare (unquoted); quotes
        are removed. Operators are matched case-insensitively as whole words
        and returned upper-cased. Whitespace and unpaired quote characters are
        skipped.

        Examples:
            "'actomyosin' AND 'stress fiber'"
                -> ['actomyosin', 'AND', 'stress fiber']
            "(term1 OR term2) AND NOT term3"
                -> ['(', 'term1', 'OR', 'term2', ')', 'AND', 'NOT', 'term3']

        Note: tokens are plain strings, so a quoted term spelled exactly
        ``AND``/``OR``/``NOT`` (upper case) is indistinguishable from the
        operator downstream. A bare word that starts with an operator followed
        by punctuation (e.g. ``not-x``) splits into the operator and the rest;
        quote such terms.
        """
        tokens = []
        for match in QueryEngine._TOKEN_PATTERN.finditer(query):
            single, double, paren, operator, bare = match.groups()
            if operator is not None:
                tokens.append(operator.upper())
            else:
                tokens.append(single or double or paren or bare)
        return tokens

    # --------------------------
    # Postfix conversion (shunting-yard)
    # --------------------------
    def _infix_to_postfix(self, tokens: List[str]) -> List[str]:
        """Convert infix tokens to postfix (RPN) order.

        AND/OR are left-associative binary operators; NOT is a unary prefix
        operator. Raises ValueError on unbalanced parentheses.
        """
        output: List[str] = []
        stack: List[str] = []

        for token in tokens:
            if token == "(":
                stack.append(token)
            elif token == ")":
                while stack and stack[-1] != "(":
                    output.append(stack.pop())
                if not stack:
                    raise ValueError("Malformed query: unmatched ')'.")
                stack.pop()  # discard the matching '('
            elif token == "NOT":
                # Prefix unary operator: its operand has not been read yet, so
                # nothing already on the stack can be applied here. Popping
                # (as for binary operators) would turn "NOT NOT x" into garbage.
                stack.append(token)
            elif token in ("AND", "OR"):
                while (
                    stack
                    and stack[-1] != "("
                    and self._PRECEDENCE[stack[-1]] >= self._PRECEDENCE[token]
                ):
                    output.append(stack.pop())
                stack.append(token)
            else:
                output.append(token)

        while stack:
            top = stack.pop()
            if top == "(":
                raise ValueError("Malformed query: unmatched '('.")
            output.append(top)

        return output

    # --------------------------
    # Query operations
    # --------------------------
    def _operation_AND(self, sets: List[Set[str]]) -> Set[str]:
        """Intersection of all *sets* (empty set when *sets* is empty)."""
        if not sets:
            return set()
        result_set = sets[0]
        for s in sets[1:]:
            result_set = result_set.intersection(s)
        return result_set

    def _operation_OR(self, sets: List[Set[str]]) -> Set[str]:
        """Union of all *sets*."""
        result_set = set()
        for s in sets:
            result_set.update(s)
        return result_set

    def _operation_NOT(self, base_set: Set[str], term_set: Set[str]) -> Set[str]:
        """Elements of *base_set* that are not in *term_set*."""
        return base_set - term_set

    # --------------------------
    # Postfix evaluation
    # --------------------------
    @staticmethod
    def _pop_operand(stack: List[Set[str]], operator: str) -> Set[str]:
        """Pop one operand for *operator*, raising ValueError if none is left."""
        if not stack:
            raise ValueError(
                f"Malformed query: operator {operator} is missing an operand."
            )
        return stack.pop()

    def _eval_postfix(self, postfix: List[str]) -> Set[str]:
        """Evaluate postfix tokens to a set; raises ValueError if malformed."""
        stack: List[Set[str]] = []
        universe = None  # built lazily, only if the query contains NOT

        for token in postfix:
            if token in ("AND", "OR"):
                right = self._pop_operand(stack, token)
                left = self._pop_operand(stack, token)
                if token == "AND":
                    stack.append(self._operation_AND([left, right]))
                else:
                    stack.append(self._operation_OR([left, right]))
            elif token == "NOT":
                operand = self._pop_operand(stack, token)
                if universe is None:
                    # Universe = all nodes in the graph.
                    # NOTE(review): assumes iterating nodes_container yields the
                    # same identifiers get_descendants returns — TODO confirm.
                    universe = set(self.graph.nodes_container)
                stack.append(self._operation_NOT(universe, operand))
            else:
                # Replace the term by its set of descendants.
                stack.append(set(self.graph.get_descendants(token, include_self=False)))

        if len(stack) != 1:
            raise ValueError("Malformed query. Check parentheses and operators.")
        return stack[0]

    # --------------------------
    # Public API
    # --------------------------
    def execute_query(self, query: str) -> Set[str]:
        """Parse and evaluate *query*; raises ValueError if it is malformed."""
        tokens = self.parse_query(query)
        postfix = self._infix_to_postfix(tokens)
        return self._eval_postfix(postfix)

    def format_results(self, result_set: Set[str]) -> List[str]:
        """Return the results as a sorted list."""
        return sorted(result_set)


In [None]:
# Children of the term named 'actomyosin', mapped name -> ID -> children -> names.
L.term_to_description(G.get_children(L.description_to_term('actomyosin')))

In [None]:
# 'A' and 'B' are presumably term IDs of the dummy ontology loaded at the top
# of the notebook — TODO confirm.
qe = QueryEngine(G)

query1 = "'A' AND 'B'"
results1 = qe.execute_query(query1)
print(qe.format_results(results1))

query2 = "'A' OR 'B'"
results2 = qe.execute_query(query2)
print(qe.format_results(results2))

query3 = "'A' AND NOT 'B'"
results3 = qe.execute_query(query3)
print(qe.format_results(results3))

In [None]:
G.get_root()

In [None]:
# NOTE(review): QueryEngine passes each term unchanged to
# graph.get_descendants, which the other cells call with term IDs; names such
# as 'stress fiber' presumably need L.description_to_term first — TODO confirm.
# QueryEngine.parse_query also has no DESCENDANTS(...) operator, so query2 is
# not parsed as a function call.
qe = QueryEngine(G)

query1 = "actomyosin AND 'stress fiber'"
results1 = qe.execute_query(query1)
print(qe.format_results(results1))

query2 = "DESCENDANTS(actomyosin) AND 'stress fiber'"
results2 = qe.execute_query(query2)
print(qe.format_results(results2))

query3 = "actomyosin AND NOT membrane"
results3 = qe.execute_query(query3)
print(qe.format_results(results3))

In [None]:
G.get_root()

In [None]:
def tokenize(query: str):
    """Split a query string into tokens for the DESCENDANTS(...) query syntax.

    Tokens are recognised in this priority order:

    * ``DESCENDANTS(<term>)`` calls, returned verbatim as a single token;
    * single- or double-quoted terms, returned *including* their quotes;
    * the operators ``AND`` / ``OR`` / ``NOT`` (case-insensitive), returned
      upper-cased;
    * ``(`` and ``)``;
    * bare terms made of word characters, hyphens and colons, so term IDs
      such as ``GO:0000017`` stay in one piece.

    Any other character (e.g. ``,`` or an unpaired quote) is skipped.

    Parameters
    ----------
    query : str
        The raw query text.

    Returns
    -------
    list of str
        The tokens in the order they appear in ``query``.
    """
    pattern = r"""
        (DESCENDANTS\s*\(\s*[^)]+\s*\))      # DESCENDANTS(term)
        |("[^"]+"|'[^']+')                    # quoted terms
        |\bAND\b|\bOR\b|\bNOT\b               # boolean operators
        |\(|\)                                # parentheses
        |[\w:-]+                              # bare terms and term IDs
    """
    operators = {"AND", "OR", "NOT"}
    tokens = []
    # finditer + group(0) yields the whole match for every alternative.
    # re.findall would return only the capture groups, i.e. empty strings for
    # operators, parentheses and bare terms, and those tokens would be lost.
    for match in re.finditer(pattern, query, flags=re.VERBOSE | re.IGNORECASE):
        token = match.group(0).strip()
        tokens.append(token.upper() if token.upper() in operators else token)
    return tokens

In [None]:
# Show how tokenize() splits a query that uses the DESCENDANTS(...) syntax.
query = "DESCENDANTS(actomyosin) AND 'stress fiber'"
print(tokenize(query))

In [None]:
import re
from functools import reduce

def parse_query(query: str):
    """
    Parse a query string into a list of terms and a list of operators.

    Multi-word terms may be enclosed in single or double quotes; the quotes
    are removed. AND / OR / NOT are matched case-insensitively as whole words
    and returned upper-cased. Every other whitespace-separated chunk is a term.

    The two lists are returned separately, so the positions of operators
    relative to terms are not preserved. Parentheses are not tokens of their
    own: an unquoted "(a" is returned as the term "(a".

    Example:
        "'actomyosin' AND 'stress fiber' OR membrane NOT nucleus"
    Returns:
        terms: ['actomyosin', 'stress fiber', 'membrane', 'nucleus']
        operators: ['AND', 'OR', 'NOT']
    """
    # Alternatives: single-quoted term | double-quoted term | operator |
    # any other run of non-whitespace characters (a bare term).
    pattern = r"'([^']+)'|\"([^\"]+)\"|(\bAND\b|\bOR\b|\bNOT\b)|(\S+)"
    terms = []
    operators = []

    for match in re.finditer(pattern, query, re.IGNORECASE):
        quoted1, quoted2, operator, word = match.groups()
        if operator:
            operators.append(operator.upper())
        else:
            term = quoted1 or quoted2 or word
            terms.append(term)

    return terms, operators

# Example usage
terms, operators = parse_query("'actomyosin' AND 'stress fiber'")
print("Terms:", terms)        # ['actomyosin', 'stress fiber']
print("Operators:", operators)  # ['AND']

In [None]:
# Intersect the descendant sets of every term parsed above.
# NOTE: the `operators` returned by parse_query are not consulted here; the
# terms are always combined with AND, whatever operators the query contained.
superset = []
for description in terms:
    # Look up the term ID for this human-readable name.
    term = L.description_to_term(description)
    superset.append(set(G.get_descendants(term, include_self=False)))

sets = superset
# Equivalent to folding the sets with `&`, but an empty `terms` list yields an
# empty set instead of reduce()'s TypeError on an empty iterable.
intersection = set.intersection(*sets) if sets else set()
print(intersection)

In [None]:
# Render, as an ASCII tree, the root-to-term trajectories for term 'G'. The
# string 'G' is a term ID (presumably from the dummy ontology — TODO confirm),
# not the Graph variable G.
G.print_term_trajectories_tree(G.get_trajectories_from_root('G'))