# Imports

In [1]:
import functools
import re
from dataclasses import dataclass

import graphblas as gb
import numpy as np
import pandas as pd
import pronto

# Step 1. Load an ontology using Pronto

In [None]:
# Parse the Gene Ontology OBO file (path is relative to the notebook) into a
# pronto Ontology; every later cell reads its terms from this object.
go_ontology = pronto.Ontology("../data/out/go.obo")

In [None]:
#chebi = pronto.Ontology("../data/out/chebi.obo", encoding="utf-8")

---

# Step 2. Create Functions and Classes

#### Create Nodes Dataframe

In [3]:
# Function to split a string by multiple separators and get the last part
def get_last_part_string(s, separators=('/', ':', '.', '#')):
    """Return the trailing segment of ``s`` after its last separator.

    ``s`` is split on every occurrence of any string in ``separators`` (each
    one is matched literally) and the final piece is returned.  When no
    separator occurs, ``s`` is returned unchanged; when ``s`` ends with a
    separator, the result is the empty string.
    """
    # Escape each separator so that characters such as '.' lose their regex meaning.
    alternation = '|'.join(re.escape(sep) for sep in separators)

    # re.split always yields at least one element, so the unpacking cannot fail.
    *_, tail = re.split(alternation, s)
    return tail


def get_dictitionary_annotations(annotations):
    """Convert pronto annotation objects into a plain nested dictionary.

    The key of each entry is the last segment of the annotation's property
    (see ``get_last_part_string``).  The value depends on the annotation type:

    * ``pronto.ResourcePropertyValue`` -> ``{'resource': ...}``
    * ``pronto.LiteralPropertyValue``  -> ``{'literal': ..., 'datatype': ...}``

    Annotations of any other type are ignored.  When two annotations share the
    same key, the later one overwrites the earlier one.
    """
    result = {}
    for ann in annotations:
        key = get_last_part_string(ann.property)

        if isinstance(ann, pronto.ResourcePropertyValue):
            payload = {'resource': ann.resource}
        elif isinstance(ann, pronto.LiteralPropertyValue):
            payload = {
                'literal': ann.literal,
                # Keep only the short datatype name, e.g. 'anyURI'.
                'datatype': get_last_part_string(ann.datatype),
            }
        else:
            continue

        result[key] = payload
    return result

def get_string_relationships(relations):
    """Join the ``name`` of every key of ``relations`` with ``'|'``.

    ``relations`` is a mapping whose keys expose a ``name`` attribute; the
    mapped values are not used.
    """
    names = (relation.name for relation in relations.keys())
    return "|".join(names)

def get_dictionary_synonyms(synonyms):
    """Map each synonym description to a dict of its non-empty attributes.

    For every synonym the entry may contain:

    * ``'type'``   - the ``id`` of the synonym type, when a type is set;
    * ``'source'`` - the xref ids joined with ``'|'``, when xrefs exist;
    * ``'scope'``  - the synonym scope.

    Attributes that are ``None`` or the empty string are left out, so an entry
    can be an empty dict.
    """
    result = {}
    for synonym in synonyms:
        syn_type = synonym.type
        xrefs = synonym.xrefs
        candidates = {
            'type': getattr(syn_type, 'id', None) if syn_type else None,
            'source': "|".join(str(xref.id) for xref in xrefs) if xrefs else None,
            'scope': synonym.scope,
        }
        # Drop attributes that carry no information.
        result[synonym.description] = {
            field: value
            for field, value in candidates.items()
            if value is not None and value != ""
        }
    return result

def get_dictionary_xrefs(xrefs):
    """Map each xref id (formatted as a string) to a small attribute dict.

    The attribute dict holds ``'description'`` when the xref has a non-empty
    description and is empty otherwise.
    """
    result = {}
    for xref in xrefs:
        description = xref.description
        # A falsy description (None or "") means there is nothing to record.
        result[f"{xref.id}"] = {'description': description} if description else {}
    return result
    
    
# Column order of the node table (the positional 'index' column is added later).
_NODE_COLUMNS = (
    # Identity & naming
    "term_id", "name", "alternate_ids", "namespace",
    # Status & lifecycle
    "obsolete", "anonymous", "builtin", "created_by", "creation_date",
    "replaced_by", "consider",
    # Description & annotation
    "definition", "comment", "annotations", "subsets", "synonyms", "xrefs",
    # Logical & semantic relations
    "relationships", "disjoint_from", "equivalent_to", "intersection_of",
)


def _term_to_row(term):
    """Flatten one ontology term into a dict keyed by ``_NODE_COLUMNS``."""
    join = "|".join
    return {
        # Identity & Naming
        "term_id": term.id,
        "name": term.name,
        # Sorted so the joined string does not depend on set iteration order.
        "alternate_ids": join(sorted(term.alternate_ids)) if term.alternate_ids else None,
        "namespace": term.namespace,

        # Status & Lifecycle
        "obsolete": term.obsolete,
        "anonymous": term.anonymous,
        "builtin": term.builtin,
        "created_by": term.created_by,
        "creation_date": term.creation_date,
        "replaced_by": join([replacer.id for replacer in term.replaced_by]) if term.replaced_by else None,
        # Entries may be term objects or plain id strings; use the id in both cases
        # so this column is built the same way as 'replaced_by'.
        "consider": join([getattr(item, "id", item) for item in term.consider]) if term.consider else None,

        # Description & Annotation
        "definition": str(term.definition) if term.definition else None,
        "comment": term.comment,
        "annotations": str(get_dictitionary_annotations(term.annotations)) if term.annotations else None,
        "subsets": join(sorted(term.subsets)) if term.subsets else None,
        "synonyms": str(get_dictionary_synonyms(term.synonyms)) if term.synonyms else None,
        "xrefs": str(get_dictionary_xrefs(term.xrefs)) if term.xrefs else None,

        # Logical & Semantic Relations (stored as the raw objects from the term)
        "relationships": get_string_relationships(term.relationships) if term.relationships else None,
        "disjoint_from": term.disjoint_from if term.disjoint_from else None,
        "equivalent_to": term.equivalent_to if term.equivalent_to else None,
        "intersection_of": term.intersection_of if term.intersection_of else None,
    }


def create_nodes_dataframe(terms, include_obsolete=False):
    """Build the node table: one row per term, sorted by ``term_id``.

    Parameters
    ----------
    terms : iterable
        Term objects exposing the attributes read in ``_term_to_row``.
    include_obsolete : bool, optional
        When False (default), terms whose ``obsolete`` flag is set are skipped.

    Returns
    -------
    pandas.DataFrame
        Columns are ``'index'`` (0..n-1 after sorting) followed by
        ``_NODE_COLUMNS``.  Dict-valued fields (annotations, synonyms, xrefs)
        are stored as their ``str()`` representation, multi-valued string
        fields are joined with ``'|'``, and empty values become ``None``.
        An empty input yields an empty frame that still has every column.
    """
    rows = [
        _term_to_row(term)
        for term in terms
        if include_obsolete or not term.obsolete
    ]

    # Passing the columns explicitly keeps the schema when there are no rows,
    # so the sort below cannot fail on a missing 'term_id' column.
    df = pd.DataFrame(rows, columns=list(_NODE_COLUMNS))

    # Sort by term_id and reset index
    df.sort_values("term_id", inplace=True)
    df.reset_index(drop=True, inplace=True)

    # Add positional index column
    df.insert(0, 'index', range(len(df)))

    return df

#### Create classes for lookup tables, node/edge containers, and the graph

In [4]:
# --- Refactored LookUpTables: does NOT store terms ---
class LookUpTables:
    def __init__(self, terms: list):
        self.__lut_term_to_index = {term.id: idx for idx, term in enumerate(terms)}
        self.__lut_index_to_term = [term.id for term in terms]
        self.__lut_term_to_description = {term.id: term.name for term in terms}
        self.__lut_description_to_term = {term.name: term.id for term in terms}

    def get_lut_term_to_index(self):
        return self.__lut_term_to_index

    def get_lut_index_to_term(self):
        return self.__lut_index_to_term

    def get_lut_term_to_description(self):
        return self.__lut_term_to_description

    def get_lut_description_to_term(self):
        return self.__lut_description_to_term

    def term_to_index(self, terms: str | list):
        if isinstance(terms, str):
            return self.__lut_term_to_index[terms]
        elif isinstance(terms, list):
            return [self.__lut_term_to_index[term] for term in terms]

    def index_to_term(self, indexes: int | list):
        if isinstance(indexes, int):
            return self.__lut_index_to_term[indexes]
        elif isinstance(indexes, list):
            return [self.__lut_index_to_term[idx] for idx in indexes]
        elif isinstance(indexes, np.ndarray):
            return [self.__lut_index_to_term[idx] for idx in indexes.tolist()]
        else:
            raise TypeError(
                f"Expected int, list[int], or np.ndarray, got {type(indexes).__name__}."
            )

    def term_to_description(self, terms: str | list):
        if isinstance(terms, str):
            return self.__lut_term_to_description[terms]
        elif isinstance(terms, list):
            return [self.__lut_term_to_description[term] for term in terms]

    def description_to_term(self, descriptions: str | list):
        if isinstance(descriptions, str):
            return self.__lut_description_to_term[descriptions]
        elif isinstance(descriptions, list):
            return [self.__lut_description_to_term[term] for term in descriptions]

@dataclass
class NodeContainer:
    """Lightweight wrapper around the array of node indices.

    Supports ``len()``, indexing and ``in`` by delegating to the wrapped numpy
    array, plus conversion to a list or a set.
    """

    # Array holding the index of every node.
    nodes_indices: np.ndarray

    def __contains__(self, item):
        # Membership is delegated to the numpy array.
        return item in self.nodes_indices

    def __getitem__(self, idx):
        return self.nodes_indices[idx]

    def __len__(self):
        return len(self.nodes_indices)

    def as_set(self):
        """Return the indices as a set (elements are the array's scalars)."""
        return {value for value in self.nodes_indices}

    def as_list(self):
        """Return the indices as a plain Python list."""
        return self.nodes_indices.tolist()

# --- Refactored EdgesContainer: does NOT store terms ---
class EdgesContainer:
    """Edge lists of the ontology, grouped by relation type.

    ``edges_indices`` maps a normalised relation name to
    ``{'rows': np.ndarray, 'cols': np.ndarray}`` (both ``int64``), where
    ``rows[k] -> cols[k]`` is one edge expressed in lookup-table indices:

    * ``'is_a'``: row = subclass, column = its direct superclass;
    * any other relation: row = the term declaring the relationship,
      column = the target term.

    Edges whose subclass / target is obsolete are skipped.  The term objects
    are not stored.  ``relations`` lists the relation names in table order.
    """

    def __init__(self, terms: list, lookup_tables: "LookUpTables"):
        self.edges_indices = self._populate_index_containers(terms, lookup_tables)
        self.relations = list(self.edges_indices.keys())

    @staticmethod
    def _relation_key(rel):
        """Return the dictionary key for a relationship: lower-case, '_' for spaces.

        Falls back to the relationship ``id`` when it has no name, instead of
        failing on ``None.lower()``.
        """
        label = rel.name or rel.id
        return label.lower().replace(" ", "_")

    # Search all possible relations in the ontology
    def _get_ontology_relationships(self, terms):
        """Return the sorted, de-duplicated relation keys used by ``terms``."""
        return sorted({
            self._relation_key(rel)
            for term in terms
            for rel in term.relationships
        })

    # Create empty containers for each relation type
    def _create_edges_index_containers(self, terms):
        """Return ``{relation: {'rows': [], 'cols': []}}`` for every relation."""
        relationships = self._get_ontology_relationships(terms)

        # Always include 'is_a' relationship
        relationships.append('is_a')
        return {rel: {'rows': [], 'cols': []} for rel in relationships}

    # Populate the containers with row and column indices
    def _populate_index_containers(self, terms, lookup_tables):
        """Fill the per-relation edge lists and convert them to int64 arrays.

        Raises ``KeyError`` (from the lookup table) if a non-obsolete edge
        endpoint is not part of ``lookup_tables``.
        """
        edge_container = self._create_edges_index_containers(terms)
        to_index = lookup_tables.term_to_index
        is_a_edges = edge_container['is_a']

        for term in terms:
            # 'is_a' edges: every direct, non-obsolete subclass points to this term.
            for subclass in term.subclasses(with_self=False, distance=1):
                if subclass.obsolete:
                    continue
                is_a_edges['rows'].append(to_index(subclass.id))
                is_a_edges['cols'].append(to_index(term.id))

            # Other relationships: this term points to each non-obsolete target.
            for rel, targets in term.relationships.items():
                bucket = edge_container[self._relation_key(rel)]
                for target in targets:
                    if target.obsolete:
                        continue
                    bucket['rows'].append(to_index(term.id))
                    bucket['cols'].append(to_index(target.id))

        # Convert lists to numpy arrays with dtype np.int64
        for data in edge_container.values():
            data['rows'] = np.array(data['rows'], dtype=np.int64)
            data['cols'] = np.array(data['cols'], dtype=np.int64)
        return edge_container


class Graph:
    def __init__(self, nodes_indexes, nodes_dataframe, edges_indexes, edges_dataframe, lookup_tables):
        # --- Nodes ---
        self.nodes_indexes = nodes_indexes
        self.nodes_dataframe = nodes_dataframe
        # --- Edges ---
        self.edges_indexes = edges_indexes
        self.edges_dataframe = edges_dataframe

        # --- Lookup Tables ---
        self.lookup_tables = lookup_tables
        self.number_nodes = len(self.nodes_indexes)
        self.number_edges = len(self.edges_indexes.edges_indices['is_a']['rows'])

        self.matrices_container = self.create_multiple_matrices(
            edge_container=self.edges_indexes.edges_indices,
            nrows=self.number_nodes,
            ncols=self.number_nodes
        )

    def create_graphblas_matrix(self, rows_indexes, cols_indexes, nrows, ncols, name):
        """Build a named boolean GraphBLAS matrix from COO row/column indices.

        Every listed (row, column) position receives the same scalar value,
        stored with ``dtype=bool``; the matrix has shape ``nrows x ncols``.
        """
        return gb.Matrix.from_coo(
            rows=rows_indexes,
            columns=cols_indexes,
            values=1.0,
            nrows=nrows,
            ncols=ncols,
            dtype=bool,
            name=name,
        )

    def create_multiple_matrices(self, edge_container, nrows, ncols):
        matrices = {}
        for relation, indexes in edge_container.items():
            M = self.create_graphblas_matrix(rows_indexes=indexes['rows'],
                                        cols_indexes=indexes['cols'],
                                        nrows=nrows,
                                        ncols=ncols,
                                        name=relation)

            matrices[relation] = M
        return matrices
    

    ## Graph Operations

    def one_hot_vector(self, index: int) -> gb.Vector:
        """Return the one-hot vector for ``index``: a single stored 1, length ``number_nodes``.

        Vectors are memoised per graph instance and per index, so repeated
        calls return the same object; callers must treat it as read-only.

        The cache lives on the instance rather than in ``functools.lru_cache``
        on the method: the latter keys on ``self``, which keeps every graph
        alive for the lifetime of the process and also stores keyword and
        positional calls under different keys.
        """
        cache = self.__dict__.setdefault('_one_hot_cache', {})
        key = int(index)  # numpy integer scalars and Python ints share one entry
        vector = cache.get(key)
        if vector is None:
            vector = gb.Vector.from_coo([key], [1], size=self.number_nodes, dtype=int)
            cache[key] = vector
        return vector

    # -- get_children(term_id, include_self=False)
    def get_children(self, term_id, include_self=False):
        # validate and resolve the index
        if term_id not in self.lookup_tables.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {term_id}")

        index = self.lookup_tables.term_to_index(term_id)

        # Initialize a one-hot vector for the term node
        vector_node = self.one_hot_vector(index=index)

        # Propagate to children using matrix-vector multiplication
        children_vec = (self.matrices_container['is_a'] @ vector_node).new()

        # Optionally include the node itself
        if include_self:
            children_vec[index] = True

        # translate indexes to terms
        terms = [term for term in children_vec]
        
        return self.lookup_tables.index_to_term(terms)
    
    # -- get_parents(term_id, include_self=False)
    def get_parents(self, term_id, include_self=False):
        # validate and resolve the index
        if term_id not in self.lookup_tables.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {term_id}")

        index = self.lookup_tables.term_to_index(term_id)

        # Initialize a one-hot vector for the term node
        vector_node = self.one_hot_vector(index=index)

        # Propagate to children using matrix-vector multiplication
        parent_vec = (self.matrices_container['is_a'].T @ vector_node).new()

        # Optionally include the node itself
        if include_self:
            parent_vec[index] = True

        # translate indexes to terms
        terms = [term for term in parent_vec]
        
        return self.lookup_tables.index_to_term(terms)

    # -- get_root()
    def get_root(self):
        """Return the term IDs of all root nodes of the ``is_a`` hierarchy.

        The transposed ``is_a`` matrix is reduced column-wise; every node for
        which that reduction is zero (or has no entry at all) is a root.
        Several roots may be returned.
        """
        transposed = self.matrices_container['is_a'].T

        # Column-wise reduction, materialised and exported as (indices, values).
        touched, totals = transposed.reduce_columnwise(gb.binary.plus).new().to_coo()

        # Dense per-node totals: nodes missing from the sparse result stay at 0.
        per_node = np.zeros(transposed.ncols, dtype=np.int64)
        per_node[touched] = totals

        root_indices = np.flatnonzero(per_node == 0)
        return self.lookup_tables.index_to_term(root_indices)
    
    def _traverse_graph(self, term_id, adjacency_matrix, distance=None, include_self=False):
        """
        Breadth-first expansion from ``term_id`` along ``adjacency_matrix``.

        Parameters
        ----------
        term_id : str
            The starting term ID.
        adjacency_matrix : gb.Matrix
            Matrix applied at every step (as given for descendants, transposed
            for ancestors).
        distance : int or None
            Maximum number of expansion steps. ``None`` means unlimited;
            ``0`` performs no step at all.
        include_self : bool
            Whether the starting node is part of the result.

        Returns
        -------
        List[str]
            Term IDs reached, in no particular order.

        Raises
        ------
        KeyError
            If ``term_id`` is not in the lookup tables.
        """
        luts = self.lookup_tables
        if term_id not in luts.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {term_id}")

        start = luts.term_to_index(term_id)
        frontier = self.one_hot_vector(index=start)
        seen = {start} if include_self else set()
        size = adjacency_matrix.nrows

        while frontier.nvals != 0 and distance != 0:
            # One step: everything adjacent to the current frontier.
            reached = gb.Vector(dtype=int, size=size)
            reached << gb.semiring.plus_times(adjacency_matrix @ frontier)

            # Keep only nodes not collected before.
            fresh = set(reached.to_coo()[0])
            fresh.difference_update(seen)
            if not fresh:
                break

            seen.update(fresh)
            frontier = gb.Vector.from_coo(list(fresh), [1] * len(fresh), size=size)

            if distance is not None:
                distance -= 1

        return luts.index_to_term(list(seen))


    # Public API functions
    def get_ancestors(self, term_id, distance=None, include_self=False):
        adjacency_matrix = self.matrices_container['is_a'].T  # transpose for ancestors
        return self._traverse_graph(term_id, adjacency_matrix, distance, include_self)

    def get_descendants(self, term_id, distance=None, include_self=False):
        adjacency_matrix = self.matrices_container['is_a']  # normal direction for descendants
        return self._traverse_graph(term_id, adjacency_matrix, distance, include_self)

    def _traverse_graph_with_distance(self, term_id, adjacency_matrix, include_self=False):
        """
        Breadth-first expansion from ``term_id`` that records each node's step count.

        Parameters
        ----------
        term_id : str
            The starting term ID.
        adjacency_matrix : gb.Matrix
            Matrix applied at every step (as given for descendants, transposed
            for ancestors).
        include_self : bool
            Whether to include the starting node with distance 0.

        Returns
        -------
        List[Tuple[str, int]]
            ``(term_id, distance)`` pairs, where distance is the number of the
            step at which the node was first reached.

        Raises
        ------
        KeyError
            If ``term_id`` is not in the lookup tables.
        """
        luts = self.lookup_tables
        if term_id not in luts.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {term_id}")

        start = luts.term_to_index(term_id)
        frontier = self.one_hot_vector(index=start)
        size = adjacency_matrix.nrows

        # node index -> step at which it was first reached
        first_seen = {start: 0} if include_self else {}
        step = 0

        while frontier.nvals != 0:
            reached = gb.Vector(dtype=int, size=size)
            reached << gb.semiring.plus_times(adjacency_matrix @ frontier)

            # Discard nodes that already have a recorded distance.
            fresh = set(reached.to_coo()[0])
            fresh.difference_update(first_seen.keys())
            if not fresh:
                break

            step += 1
            first_seen.update(dict.fromkeys(fresh, step))
            frontier = gb.Vector.from_coo(list(fresh), [1] * len(fresh), size=size)

        # Translate indices back to term IDs.
        return [
            (luts.index_to_term(int(node)), reached_at)
            for node, reached_at in first_seen.items()
        ]


    # Public API functions
    def get_ancestors_with_distance(self, term_id, include_self=False):
        adjacency_matrix = self.matrices_container['is_a'].T  # transpose for ancestors
        return self._traverse_graph_with_distance(term_id, adjacency_matrix, include_self)

    def get_descendants_with_distance(self, term_id, include_self=False):
        adjacency_matrix = self.matrices_container['is_a']  # normal direction for descendants
        return self._traverse_graph_with_distance(term_id, adjacency_matrix, include_self)
    
    def get_common_ancestors(self, node_ids):
        """
        Return the common ancestors of a list of terms.

        Parameters
        ----------
        node_ids : List[str]
            List of starting term IDs.
        include_self : bool
            Whether to include the starting nodes themselves in the ancestor sets.

        Returns
        -------
        List[str]
            List of term IDs that are common ancestors to all input terms.
        """
        if not node_ids:
            return []

        # get ancestors for the first node
        common_ancestors = set(self.get_ancestors(node_ids[0], include_self=False))

        # intersect with ancestors of the rest
        for term_id in node_ids[1:]:
            ancestors = set(self.get_ancestors(term_id, include_self=False))
            common_ancestors.intersection_update(ancestors)

            # early exit if no common ancestor remains
            if not common_ancestors:
                return []

        return set(common_ancestors)
    
    def get_lowest_common_ancestors(self, node_ids):
        """
        Return the lowest common ancestor(s) of a list of terms. 
        Lowest = closest to the given terms.

        Parameters
        ----------
        node_ids : List[str]
            List of starting term IDs.
        include_self : bool
            Whether to include the starting nodes in ancestor sets.

        Returns
        -------
        List[str]
            List of term IDs that are the lowest common ancestors.
        """
        if not node_ids:
            return []

        # Compute ancestors with distances for the first node
        first_ancestors = dict(self.get_ancestors_with_distance(node_ids[0], include_self=False))
        common_ancestors = set(first_ancestors.keys())

        # Initialize distances dict for LCA calculation
        # key: ancestor index, value: max distance from any node
        lca_distances = {idx: dist for idx, dist in first_ancestors.items()}

        # Process remaining nodes
        for term_id in node_ids[1:]:
            ancestors_with_distance = dict(self.get_ancestors_with_distance(term_id, include_self=False))
            ancestors_set = set(ancestors_with_distance.keys())
            common_ancestors.intersection_update(ancestors_set)

            # Update max distance for each common ancestor
            lca_distances = {idx: max(lca_distances[idx], ancestors_with_distance[idx])
                            for idx in common_ancestors}

            # Early exit if no common ancestor remains
            if not common_ancestors:
                return []

        if not lca_distances:
            return []

        # Find the minimum of the maximum distances
        min_distance = min(lca_distances.values())

        # Return ancestor IDs that have this minimum distance
        lowest_common_indices = [idx for idx, dist in lca_distances.items() if dist == min_distance]
        return lowest_common_indices
        
    def get_distance_from_root(self, term_id):
        """
        Calculate the distance from the given term to the root node(s) of the ontology.

        Parameters
        ----------
        term_id : str
            The term ID for which to compute the distance from root.

        Returns
        -------
        int
            Distance from the term to the root (number of edges).
            Returns 0 if the term is a root itself.
        """
        # Validate term
        if term_id not in self.lookup_tables.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {term_id}")

        # Get all ancestors with distance
        ancestors_with_distance = self.get_ancestors_with_distance(term_id, include_self=True)

        if not ancestors_with_distance:
            # No ancestors, this term is a root
            return 0

        # Distance from root = maximum distance in the ancestors path
        max_distance = max(distance for _, distance in ancestors_with_distance)

        return max_distance
    
    def get_path_between(self, node_a, node_b):
        """
        Find the shortest path between two nodes in the ontology.

        Parameters
        ----------
        node_a : str
            Starting term ID.
        node_b : str
            Ending term ID.

        Returns
        -------
        List[str]
            List of term IDs representing the path from node_a to node_b (inclusive).
            Returns empty list if no path exists.
        """
        if node_a not in self.lookup_tables.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {node_a}")
        if node_b not in self.lookup_tables.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {node_b}")

        # Check if a path exists
        if not (self.is_ancestor(node_a, node_b) or self.is_descendant(node_a, node_b)):
            return []

        # Determine direction
        if self.is_ancestor(node_a, node_b):
            start, end = node_a, node_b
            adjacency_matrix = self.matrices_container['is_a']
        else:
            start, end = node_b, node_a
            adjacency_matrix = self.matrices_container['is_a']

        start_idx = self.lookup_tables.term_to_index(start)
        end_idx = self.lookup_tables.term_to_index(end)

        # BFS to find shortest path
        from collections import deque
        queue = deque([[start_idx]])
        visited = set([start_idx])

        while queue:
            path = queue.popleft()
            current = path[-1]

            if current == end_idx:
                return self.lookup_tables.index_to_term(path)

            # Get children (or parents depending on direction)
            neighbors_vec = adjacency_matrix @ self.one_hot_vector(current)
            neighbors = neighbors_vec.to_coo()[0]

            for n in neighbors:
                if n not in visited:
                    visited.add(n)
                    queue.append(path + [n])

        return []
    
    def is_ancestor(self, ancestor_node, descendant_node):
        """
        Check if `ancestor_node` is an ancestor of `descendant_node`.

        Parameters
        ----------
        ancestor_node : str
            Candidate ancestor term ID.
        descendant_node : str
            Candidate descendant term ID.

        Returns
        -------
        bool
            True if `ancestor_node` is an ancestor of `descendant_node`, else False.
        """
        if descendant_node not in self.lookup_tables.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {descendant_node}")
        if ancestor_node not in self.lookup_tables.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {ancestor_node}")
        
        # Retrieve ancestors of the descendant
        ancestors = set(self.get_ancestors(descendant_node, include_self=False))
        return ancestor_node in ancestors


    def is_descendant(self, descendant_node, ancestor_node):
        """
        Check if `descendant_node` is a descendant of `ancestor_node`.

        Parameters
        ----------
        descendant_node : str
            Candidate descendant term ID.
        ancestor_node : str
            Candidate ancestor term ID.

        Returns
        -------
        bool
            True if `descendant_node` is a descendant of `ancestor_node`, else False.
        """
        if ancestor_node not in self.lookup_tables.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {ancestor_node}")
        if descendant_node not in self.lookup_tables.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {descendant_node}")
        
        # Retrieve descendants of the ancestor
        descendants = set(self.get_descendants(ancestor_node, include_self=False))
        return descendant_node in descendants
    
    def get_siblings(self, term_id, include_self: bool = False):
        """
        Retrieve all siblings of a given term (i.e., nodes that share at least one parent).

        Parameters
        ----------
        term_id : str
            The term ID whose siblings are to be found.
        include_self : bool, optional (default=False)
            Whether to include the term itself in the returned set.

        Returns
        -------
        List[str]
            List of sibling term IDs.
        """
        # Validate term existence
        if term_id not in self.lookup_tables.get_lut_term_to_index():
            raise KeyError(f"Unknown term ID: {term_id}")

        # Step 1: Get parents of the given term
        parents = self.get_parents(term_id, include_self=False)
        if not parents:
            # No parents means this term is a root -> no siblings
            return []

        # Step 2: For each parent, get its children
        siblings_set = set()
        for parent_id in parents:
            children = self.get_children(parent_id, include_self=False)
            siblings_set.update(children)

        # Step 3: Optionally remove the term itself
        if not include_self and term_id in siblings_set:
            siblings_set.remove(term_id)

        # Return as sorted list for deterministic output
        return sorted(siblings_set)
    
    def is_sibling(self, node_a: str, node_b: str) -> bool:
        """
        Check if two nodes are siblings (i.e., share at least one common parent).

        Parameters
        ----------
        node_a : str
            First node (term ID).
        node_b : str
            Second node (term ID).

        Returns
        -------
        bool
            True if both nodes share at least one parent; False otherwise.
        """
        # Validate existence
        lut = self.lookup_tables.get_lut_term_to_index()
        if node_a not in lut:
            raise KeyError(f"Unknown term ID: {node_a}")
        if node_b not in lut:
            raise KeyError(f"Unknown term ID: {node_b}")

        # Step 1: Get parents for both nodes
        parents_a = set(self.get_parents(node_a, include_self=False))
        parents_b = set(self.get_parents(node_b, include_self=False))

        # Step 2: Intersection of parents indicates sibling relationship
        shared_parents = parents_a.intersection(parents_b)

        # Step 3: Return True if they share any parent
        return len(shared_parents) > 0
    
    def get_trajectories_from_root(self, term_id: str) -> list[list[dict]]:
        """
        Get all ancestor trajectories from the root(s) to the given term using GraphBLAS operations.

        Args:
            term_id (str): The identifier of the term.

        Returns:
            list[list[dict]]: List of trajectories; each trajectory is a list of dictionaries
                            with keys: 'id', 'name', and 'distance' (from the queried term).
        """
        # Validate input
        lut_term_to_index = self.lookup_tables.get_lut_term_to_index()
        if term_id not in lut_term_to_index:
            raise KeyError(f"Unknown term ID: {term_id}")

        A_T = self.matrices_container['is_a'].T
        term_idx = int(self.lookup_tables.term_to_index(term_id))

        # Root detection
        roots = set(self.get_root())
        root_indices = {int(self.lookup_tables.term_to_index(r)) for r in roots}

        from collections import deque
        queue = deque([[term_idx]])
        trajectories = []

        while queue:
            path = queue.popleft()
            current_idx = int(path[0])

            # Parent discovery using GraphBLAS multiplication
            parent_vec = (A_T @ self.one_hot_vector(current_idx)).new()
            parent_indices = [int(i) for i in parent_vec.to_coo()[0]]

            # Termination condition: reached a root or no parents
            if not parent_indices or current_idx in root_indices:
                # Reverse path → root → term order
                reversed_path = list(reversed(path))
                traj = []
                for dist, idx in enumerate(reversed_path[::-1]):  # distance from term
                    idx = int(idx)
                    traj.append({
                        'id': self.lookup_tables.index_to_term(idx),
                        'name': self.lookup_tables.term_to_description(self.lookup_tables.index_to_term(idx)),
                        'distance': dist
                    })
                trajectories.append(list(reversed(traj)))  # ensure root→term order
            else:
                for p in parent_indices:
                    if p not in path:
                        queue.append([p] + path)

        for traj in trajectories:
            traj.reverse()  # optional: reverse to have root-first order

        return trajectories  # optional: reverse to have root-first order
    
    def print_term_trajectories_tree(self,trajectories: list[dict]) -> None:
        """Print all ancestor trajectories as a single ASCII tree from root to the original term.

        Combining shared nodes.

        Args:
            trajectories: List of lists, each inner list is a trajectory (branch) as returned by ancestor_trajectories.
        """
        if not trajectories:
            print('No trajectories to display.')
            return
        root = self._build_tree_from_trajectories(trajectories)
        self._print_ascii_tree(root)

    @staticmethod
    def _build_tree_from_trajectories(trajectories: list[dict]) -> object:
        """Build a tree structure from the list of branches (trajectories).

        Returns the root node.

        Args:
            trajectories (list[dict]): List of trajectory branches.

        Returns:
            object: The root node of the tree.
        """

        class Node:
            def __init__(self, node_id: str, name: str, distance: int) -> None:
                self.id = node_id
                self.name = name
                self.distance = distance
                self.children = {}

        def insert_branch(root: Node, branch: list) -> None:
            node = root
            for item in branch:
                key = (item['id'], item['name'], item['distance'])
                if key not in node.children:
                    node.children[key] = Node(*key)
                node = node.children[key]

        # All branches are sorted from term to root, so reverse to root-to-term
        branch_lists = [list(branch) for branch in trajectories]
        root_info = branch_lists[0][0]
        root = Node(root_info['id'], root_info['name'], root_info['distance'])
        for branch in branch_lists:
            insert_branch(root, branch[1:])  # skip root itself, already created
        return root

    @staticmethod
    def _print_ascii_tree(root: object) -> None:
        """Print the tree structure in ASCII format starting from the root node."""

        def print_ascii_tree(
            node: object, prefix: str = '', is_last: bool = True
        ) -> None:
            connector = '└── ' if is_last else '├── '
            print(
                f'{prefix}{connector}{node.id}: {node.name} (distance={node.distance})'
            )
            child_items = list(node.children.values())
            for idx, child in enumerate(child_items):
                is_last_child = idx == len(child_items) - 1
                next_prefix = prefix + ('    ' if is_last else '│   ')
                print_ascii_tree(child, next_prefix, is_last_child)

        # Print root without prefix
        print(f'{root.id}: {root.name} (distance={root.distance})')
        child_items = list(root.children.values())
        for idx, child in enumerate(child_items):
            is_last_child = idx == len(child_items) - 1
            print_ascii_tree(child, '', is_last_child)

#### Extract ontology info and populate objects

In [5]:
def extract_terms(ontology, include_obsolete=False):
    """Return the ontology's terms as a list ordered by ``term.id``.

    Parameters
    ----------
    ontology
        Object exposing a ``terms()`` iterable of term objects.
    include_obsolete : bool
        When false (the default), terms whose ``obsolete`` flag is truthy
        are left out.
    """
    if include_obsolete:
        candidates = ontology.terms()
    else:
        candidates = (term for term in ontology.terms() if not term.obsolete)
    return sorted(candidates, key=lambda term: term.id)


# --- Refactored extract_data_ontology ---
def extract_data_ontology(ontology, include_obsolete=False):
    """Build the lookup tables and the graph for ``ontology``.

    The terms are extracted once via ``extract_terms`` and shared by the
    lookup tables, the nodes dataframe and the edges container.

    Parameters
    ----------
    ontology
        Ontology to read the terms from.
    include_obsolete : bool
        Forwarded to ``extract_terms`` and ``create_nodes_dataframe``.

    Returns
    -------
    tuple
        ``(lookup_tables, graph)``.
    """
    selected_terms = extract_terms(ontology, include_obsolete=include_obsolete)
    lookup_tables = LookUpTables(terms=selected_terms)
    node_frame = create_nodes_dataframe(
        terms=selected_terms, include_obsolete=include_obsolete
    )
    node_index_array = node_frame['index'].to_numpy(dtype=np.int64)

    graph = Graph(
        nodes_indexes=NodeContainer(nodes_indices=node_index_array),
        nodes_dataframe=node_frame,
        edges_indexes=EdgesContainer(
            terms=selected_terms, lookup_tables=lookup_tables
        ),
        # No edges dataframe is built here.
        edges_dataframe=None,
        lookup_tables=lookup_tables,
    )
    return lookup_tables, graph

In [6]:
LUTS, G = extract_data_ontology(ontology=go_ontology)

In [7]:
# Translate the root term IDs into their indices and show them.
indexes = LUTS.term_to_index(G.get_root())
print(indexes)

# Show the annotations of the third root; its index is used as a positional
# row for `.iloc`.
# NOTE(review): eval() executes whatever the `annotations` string contains.
# The column presumably holds the repr of a plain dict; if so,
# ast.literal_eval would be the safer way to parse it - confirm how
# create_nodes_dataframe serialises this column.
eval(G.nodes_dataframe.iloc[indexes[2]].annotations)

[2344, 3517, 5111]


{'term_tracker_item': {'literal': 'https://github.com/geneontology/go-ontology/issues/24968',
  'datatype': 'anyURI'}}

# Step 3. Delete Ontologies from memory

In [8]:
# del go_ontology, chebi

# Step 4. Evaluate operations

- [x] `get_ancestors`
- [x] `get_ancestors_with_distance`
- [x] `get_children`
- [x] `get_common_ancestors`
- [x] `get_descendants`
- [x] `get_descendants_with_distance`
- [x] `get_distance_from_root`
- [x] `get_lowest_common_ancestors`
- [x] `get_parents`
- [x] `get_path_between`
- [x] `get_root`
- [x] `get_siblings`
- [ ] `get_term`  <--- which information to include?
- [x] `get_trajectories_from_root`
- [x] `is_ancestor`
- [x] `is_descendant`
- [x] `is_sibling`
- [ ] `load` <--- Should it load both the graph and the ontology?
- [x] `print_term_trajectories_tree`

### `get_ancestors(term_id, distance=None, include_self=False)`

In [9]:
G.get_ancestors('GO:0051322', distance=5, include_self=False)

['GO:0022403', 'GO:0008150', 'GO:0044848']

### `get_ancestors_with_distance(term_id, include_self=False)`

In [10]:
G.get_ancestors_with_distance('GO:0051322', include_self=False)

[('GO:0022403', 1), ('GO:0044848', 2), ('GO:0008150', 3)]

### `get_children(term_id, include_self=False)`

In [11]:
G.get_children('GO:0048308', include_self=False)

['GO:0000001',
 'GO:0000011',
 'GO:0009665',
 'GO:0045033',
 'GO:0048309',
 'GO:0048313']

### `get_common_ancestors(node_ids)`

In [12]:
G.get_common_ancestors(['GO:0000092', 'GO:0051325'])

{'GO:0008150', 'GO:0022403', 'GO:0044848'}

### `get_descendants(term_id, distance=None, include_self=False)`

In [13]:
G.get_descendants('GO:0051322', distance=5, include_self=False)

['GO:0000092', 'GO:0000090', 'GO:0000091']

### `get_descendants_with_distance(term_id, include_self=False)`

In [14]:
G.get_descendants_with_distance('GO:0051322', include_self=False)

[('GO:0000090', 1), ('GO:0000092', 2), ('GO:0000091', 2)]

### `get_distance_from_root(term_id)`

In [15]:
G.get_distance_from_root('GO:0000092')

5

### `get_lowest_common_ancestors(node_ids)`

In [16]:
G.get_lowest_common_ancestors(['GO:0000092', 'GO:0051325'])

['GO:0022403']

### `get_parents(term_id, include_self=False)`

In [17]:
G.get_parents('GO:0048308', include_self=False)

['GO:0006996']

### `get_path_between(node_a, node_b)`

In [18]:
LUTS.term_to_description(G.get_path_between('GO:0044848', 'GO:0000092'))

['biological phase',
 'cell cycle phase',
 'anaphase',
 'mitotic anaphase',
 'mitotic anaphase B']

### `get_root()`

In [19]:
G.get_root()

['GO:0003674', 'GO:0005575', 'GO:0008150']

### `get_siblings(term_id)`

In [20]:
G.get_siblings('GO:0000017', include_self=False)

['GO:0015759', 'GO:1902418']

### `get_trajectories_from_root(term_id)`

In [21]:
G.get_trajectories_from_root('GO:0000017')

[[{'id': 'GO:0008150', 'name': 'biological_process', 'distance': 0},
  {'id': 'GO:0051179', 'name': 'localization', 'distance': 1},
  {'id': 'GO:0051234', 'name': 'establishment of localization', 'distance': 2},
  {'id': 'GO:0006810', 'name': 'transport', 'distance': 3},
  {'id': 'GO:1901264',
   'name': 'carbohydrate derivative transport',
   'distance': 4},
  {'id': 'GO:1901656', 'name': 'glycoside transport', 'distance': 5},
  {'id': 'GO:0042946', 'name': 'glucoside transport', 'distance': 6},
  {'id': 'GO:0000017', 'name': 'alpha-glucoside transport', 'distance': 7}]]

### `is_ancestor(ancestor_node, descendant_node)`

In [22]:
G.is_ancestor('GO:0015759', 'GO:0051325')

False

In [23]:
G.is_ancestor('GO:0015759', 'GO:0042946')

False

### `is_descendant(descendant_node, ancestor_node)`

In [24]:
G.is_descendant('GO:0015759', 'GO:0042946')

True

### `is_sibling(node_a, node_b)`

In [25]:
G.is_sibling('GO:0015759', 'GO:0000017')

True

### `print_term_trajectories_tree(trajectories)`

In [26]:
G.print_term_trajectories_tree(G.get_trajectories_from_root('GO:0000017'))

GO:0008150: biological_process (distance=0)
└── GO:0051179: localization (distance=1)
    └── GO:0051234: establishment of localization (distance=2)
        └── GO:0006810: transport (distance=3)
            └── GO:1901264: carbohydrate derivative transport (distance=4)
                └── GO:1901656: glycoside transport (distance=5)
                    └── GO:0042946: glucoside transport (distance=6)
                        └── GO:0000017: alpha-glucoside transport (distance=7)


---

# Human-readable query engine

In [27]:
from typing import Set, List
import re

class QueryEngine:
    """Evaluate boolean queries written with quoted ontology term names.

    A query combines single-quoted term names with ``AND``, ``OR``, ``NOT``
    and parentheses, e.g. ``"('a' OR 'b') AND NOT 'c'"``. Each term stands
    for the set of its descendants (the term itself excluded); the operators
    are applied to those sets.
    """

    # Quoted terms, parentheses and the three keywords (case-insensitive).
    _TOKEN_RE = re.compile(
        r"'[^']+'|\(|\)|\bAND\b|\bOR\b|\bNOT\b", flags=re.IGNORECASE
    )
    _MALFORMED = "Malformed query. Check parentheses and operators."

    def __init__(self, graph):
        """
        Initialize the QueryEngine with a Graph instance.

        Parameters
        ----------
        graph : Graph
            Your ontology graph instance.
        """
        self.graph = graph

    # --------------------------
    # Parsing
    # --------------------------
    @staticmethod
    def parse_query(query: str) -> List[str]:
        """
        Very simple parser: splits query into tokens (terms, AND, OR, NOT, parentheses).

        Terms must be wrapped in single quotes; any text that is neither a
        quoted term, a parenthesis nor an operator keyword is ignored.
        Operators are matched case-insensitively and returned in upper case;
        quoted terms are returned without their quotes.

        Examples:
            "'actomyosin' AND 'stress fiber'"
            "('term1' OR 'term2') AND NOT 'term3'"
        """
        tokens = []
        for raw in QueryEngine._TOKEN_RE.findall(query):
            keyword = raw.upper()
            if keyword in {"AND", "OR", "NOT"}:
                tokens.append(keyword)
            else:
                tokens.append(raw.strip("'"))
        return tokens

    # --------------------------
    # Postfix conversion (shunting-yard)
    # --------------------------
    def _infix_to_postfix(self, tokens: List[str]) -> List[str]:
        """Reorder infix tokens into postfix (shunting-yard).

        Precedence is NOT > AND > OR. AND/OR are left-associative; NOT is a
        prefix operator.

        Raises
        ------
        ValueError
            If the parentheses are not balanced.
        """
        precedence = {"NOT": 3, "AND": 2, "OR": 1}
        output: List[str] = []
        stack: List[str] = []

        for token in tokens:
            if token == "(":
                stack.append(token)
            elif token == ")":
                while stack and stack[-1] != "(":
                    output.append(stack.pop())
                if not stack:
                    # ')' without a matching '('
                    raise ValueError(self._MALFORMED)
                stack.pop()  # remove '('
            elif token == "NOT":
                # A prefix operator has no left operand, so nothing already on
                # the stack may be emitted before it. Popping an earlier NOT
                # here would put it in front of its operand ("NOT NOT 'a'").
                stack.append(token)
            elif token in precedence:
                while (
                    stack
                    and stack[-1] != "("
                    and precedence[stack[-1]] >= precedence[token]
                ):
                    output.append(stack.pop())
                stack.append(token)
            else:
                output.append(token)

        while stack:
            operator = stack.pop()
            if operator == "(":
                # '(' that was never closed
                raise ValueError(self._MALFORMED)
            output.append(operator)

        return output

    # --------------------------
    # Query operations
    # --------------------------
    def _operation_AND(self, sets: List[Set[str]]) -> Set[str]:
        """Return a new set with the elements common to all ``sets`` (empty for no sets)."""
        if not sets:
            return set()
        return set(sets[0]).intersection(*sets[1:])

    def _operation_OR(self, sets: List[Set[str]]) -> Set[str]:
        """Return a new set with the elements found in any of ``sets``."""
        return set().union(*sets)

    def _operation_NOT(self, base_set: Set[str], term_set: Set[str]) -> Set[str]:
        """Return the elements of ``base_set`` that are not in ``term_set``."""
        return base_set - term_set

    # --------------------------
    # Postfix evaluation
    # --------------------------
    def _pop_operand(self, stack: List[Set[str]]) -> Set[str]:
        """Pop one operand, reporting a missing operand as a malformed query."""
        if not stack:
            raise ValueError(self._MALFORMED)
        return stack.pop()

    def _eval_postfix(self, postfix: List[str]) -> Set[str]:
        """Evaluate a postfix token list and return the resulting set.

        Raises
        ------
        ValueError
            If an operator lacks an operand or the expression does not reduce
            to exactly one result.
        """
        stack: List[Set[str]] = []
        universe = None  # built lazily, at most once per evaluation
        for token in postfix:
            if token in {"AND", "OR"}:
                right = self._pop_operand(stack)
                left = self._pop_operand(stack)
                if token == "AND":
                    stack.append(self._operation_AND([left, right]))
                else:
                    stack.append(self._operation_OR([left, right]))
            elif token == "NOT":
                operand = self._pop_operand(stack)
                if universe is None:
                    # Universe = all nodes in the graph
                    # NOTE(review): confirm that the graph exposes
                    # `nodes_container` and that iterating it yields the same
                    # kind of identifiers that get_descendants returns.
                    universe = set(self.graph.nodes_container)
                stack.append(self._operation_NOT(universe, operand))
            else:
                # Convert term to descendants set
                term_id = self.graph.lookup_tables.description_to_term(token)
                descendants = set(self.graph.get_descendants(term_id, include_self=False))
                stack.append(descendants)

        if len(stack) != 1:
            raise ValueError(self._MALFORMED)
        return stack[0]

    # --------------------------
    # Public API
    # --------------------------
    def execute_query(self, query: str) -> Set[str]:
        """Parse and evaluate ``query`` and return the matching set.

        Raises
        ------
        ValueError
            If the query is malformed (unbalanced parentheses, missing
            operands, or no terms at all).
        """
        tokens = self.parse_query(query)
        postfix = self._infix_to_postfix(tokens)
        return self._eval_postfix(postfix)

    def format_results(self, result_set: Set[str]) -> List[str]:
        """Return the results as a sorted list."""
        return sorted(result_set)


In [28]:
LUTS.term_to_description(G.get_children(LUTS.description_to_term('actomyosin')))

['stress fiber']

In [29]:
# Run the same pair of term names through AND and then OR, printing the
# matching term IDs followed by their descriptions.
qe = QueryEngine(G)

query1 = "'striated muscle myosin thick filament assembly' AND 'cellular process'"
results1 = qe.execute_query(query1)
sorted_ids1 = qe.format_results(results1)
print(sorted_ids1)
print(LUTS.term_to_description(sorted_ids1))

query2 = "'striated muscle myosin thick filament assembly' OR 'cellular process'"
results2 = qe.execute_query(query2)
sorted_ids2 = qe.format_results(results2)
print('\n', sorted_ids2)
print(LUTS.term_to_description(sorted_ids2))

['GO:0030241', 'GO:0071690']
['skeletal muscle myosin thick filament assembly', 'cardiac muscle myosin thick filament assembly']

 ['GO:0000001', 'GO:0000011', 'GO:0000012', 'GO:0000022', 'GO:0000023', 'GO:0000024', 'GO:0000025', 'GO:0000027', 'GO:0000028', 'GO:0000032', 'GO:0000038', 'GO:0000045', 'GO:0000050', 'GO:0000052', 'GO:0000053', 'GO:0000054', 'GO:0000055', 'GO:0000056', 'GO:0000070', 'GO:0000073', 'GO:0000075', 'GO:0000076', 'GO:0000077', 'GO:0000082', 'GO:0000086', 'GO:0000096', 'GO:0000097', 'GO:0000098', 'GO:0000103', 'GO:0000105', 'GO:0000128', 'GO:0000132', 'GO:0000147', 'GO:0000154', 'GO:0000160', 'GO:0000162', 'GO:0000165', 'GO:0000183', 'GO:0000184', 'GO:0000196', 'GO:0000209', 'GO:0000212', 'GO:0000226', 'GO:0000244', 'GO:0000245', 'GO:0000255', 'GO:0000256', 'GO:0000266', 'GO:0000270', 'GO:0000271', 'GO:0000272', 'GO:0000278', 'GO:0000280', 'GO:0000281', 'GO:0000282', 'GO:0000288', 'GO:0000289', 'GO:0000290', 'GO:0000292', 'GO:0000294', 'GO:0000301', 'GO:0000316', 

In [30]:
LUTS.term_to_description(['GO:0030241'])

['skeletal muscle myosin thick filament assembly']

---

### Experimental

In [31]:
import h5py

In [32]:
with h5py.File('../data/out/nodes_edges_container.h5', 'w') as f:
    # Create a group for edges
    grp_edges = f.create_group('edges')
    for rel, data in G.edges_indexes.edges_indices.items():
        rel_grp = grp_edges.create_group(rel)
        rel_grp.create_dataset('rows', data=data['rows'], compression='gzip')
        rel_grp.create_dataset('cols', data=data['cols'], compression='gzip')

    # Create a group for nodes
    grp_nodes = f.create_group('nodes')
    grp_nodes.create_dataset('indices', data=G.nodes_indexes.nodes_indices, compression='gzip')