@@ -605,6 +605,7 @@ class Navigation:
605605 maybe_edge_types : Optional [List [EdgeType ]] = None
606606 direction : str = Direction .outbound
607607 maybe_two_directional_outbound_edge_type : Optional [List [EdgeType ]] = None
608+ edge_filter : Optional [Term ] = None
608609
609610 @property
610611 def edge_types (self ) -> List [EdgeType ]:
@@ -617,14 +618,18 @@ def __str__(self) -> str:
617618 mo = self .maybe_two_directional_outbound_edge_type
618619 depth = ("" if start == 1 else f"[{ start } ]" ) if start == until and not mo else f"[{ start } :{ until_str } ]"
619620 out_nav = "," .join (mo ) if mo else ""
620- nav = f'{ "," .join (self .edge_types )} { depth } { out_nav } '
621+ fltr = f"{{{ self .edge_filter } }}" if self .edge_filter else ""
622+ nav = f'{ "," .join (self .edge_types )} { depth } { fltr } { out_nav } '
621623 if self .direction == Direction .outbound :
622624 return f"-{ nav } ->"
623625 elif self .direction == Direction .inbound :
624626 return f"<-{ nav } -"
625627 else :
626628 return f"<-{ nav } ->"
627629
630+ def change_variable (self , fn : Callable [[str ], str ]) -> Navigation :
631+ return evolve (self , edge_filter = self .edge_filter .change_variable (fn )) if self .edge_filter else self
632+
628633
629634NavigateUntilRoot = Navigation (
630635 start = 1 , until = Navigation .Max , maybe_edge_types = [EdgeTypes .default ], direction = Direction .inbound
@@ -740,6 +745,7 @@ def change_variable(self, fn: Callable[[str], str]) -> Part:
740745 term = self .term .change_variable (fn ),
741746 with_clause = self .with_clause .change_variable (fn ) if self .with_clause else None ,
742747 sort = [sort .change_variable (fn ) for sort in self .sort ],
748+ navigation = self .navigation .change_variable (fn ) if self .navigation else None ,
743749 )
744750
745751 # ancestor.some_type.reported.prop -> MergeQuery
@@ -1012,17 +1018,40 @@ def filter_with(self, clause: WithClause) -> Query:
10121018 first_part = evolve (self .parts [0 ], with_clause = clause )
10131019 return evolve (self , parts = [first_part , * self .parts [1 :]])
10141020
1015- def traverse_out (self , start : int = 1 , until : int = 1 , edge_type : EdgeType = EdgeTypes .default ) -> Query :
1016- return self .traverse (start , until , edge_type , Direction .outbound )
1017-
1018- def traverse_in (self , start : int = 1 , until : int = 1 , edge_type : EdgeType = EdgeTypes .default ) -> Query :
1019- return self .traverse (start , until , edge_type , Direction .inbound )
1020-
1021- def traverse_inout (self , start : int = 1 , until : int = 1 , edge_type : EdgeType = EdgeTypes .default ) -> Query :
1022- return self .traverse (start , until , edge_type , Direction .any )
1021+ def traverse_out (
1022+ self ,
1023+ start : int = 1 ,
1024+ until : int = 1 ,
1025+ edge_type : EdgeType = EdgeTypes .default ,
1026+ edge_filter : Optional [Term ] = None ,
1027+ ) -> Query :
1028+ return self .traverse (start , until , edge_type , Direction .outbound , edge_filter )
1029+
1030+ def traverse_in (
1031+ self ,
1032+ start : int = 1 ,
1033+ until : int = 1 ,
1034+ edge_type : EdgeType = EdgeTypes .default ,
1035+ edge_filter : Optional [Term ] = None ,
1036+ ) -> Query :
1037+ return self .traverse (start , until , edge_type , Direction .inbound , edge_filter )
1038+
1039+ def traverse_inout (
1040+ self ,
1041+ start : int = 1 ,
1042+ until : int = 1 ,
1043+ edge_type : EdgeType = EdgeTypes .default ,
1044+ edge_filter : Optional [Term ] = None ,
1045+ ) -> Query :
1046+ return self .traverse (start , until , edge_type , Direction .any , edge_filter )
10231047
10241048 def traverse (
1025- self , start : int , until : int , edge_type : EdgeType = EdgeTypes .default , direction : str = Direction .outbound
1049+ self ,
1050+ start : int ,
1051+ until : int ,
1052+ edge_type : EdgeType = EdgeTypes .default ,
1053+ direction : str = Direction .outbound ,
1054+ edge_filter : Optional [Term ] = None ,
10261055 ) -> Query :
10271056 parts = self .parts .copy ()
10281057 p0 = parts [0 ]
@@ -1034,9 +1063,15 @@ def traverse(
10341063 parts [0 ] = evolve (p0 , navigation = evolve (p0 .navigation , start = start_m , until = until_m ))
10351064 # this is another traversal: so we need to start a new part
10361065 else :
1037- parts .insert (0 , Part (AllTerm (), navigation = Navigation (start , until , [edge_type ], direction )))
1066+ parts .insert (
1067+ 0 ,
1068+ Part (
1069+ AllTerm (),
1070+ navigation = Navigation (start , until , [edge_type ], direction , edge_filter = edge_filter ),
1071+ ),
1072+ )
10381073 else :
1039- parts [0 ] = evolve (p0 , navigation = Navigation (start , until , [edge_type ], direction ))
1074+ parts [0 ] = evolve (p0 , navigation = Navigation (start , until , [edge_type ], direction , edge_filter = edge_filter ))
10401075 return evolve (self , parts = parts )
10411076
10421077 def group_by (self , variables : List [AggregateVariable ], funcs : List [AggregateFunction ]) -> Query :
0 commit comments