In [1]:
# Install gravis (imported as `gv` in the next cell) with the %pip magic.
# NOTE(review): no version is pinned here, so a fresh run may pick up a
# different gravis release than the one this notebook was written against.
%pip install gravis


Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import ast 
import networkx as nx
import gravis as gv
from typing import Sequence
from pprint import pprint

class ComprehensiveVisitor(ast.NodeVisitor):
    """Walk a Python AST and collect its nodes into per-node-class lists.

    After ``visitor.visit(tree)`` each list attribute (``functions``,
    ``classes``, ``calls``, ``names``, ...) holds the AST nodes of the matching
    class that were found in the tree.

    Attributes that are not plain node buckets:
        module:    the ``ast.Module`` node that was visited, or ``None`` if the
                   visited tree did not contain one.
        relations: ``(base_name, class_name)`` tuples, one per class base that
                   is written as a bare name (``class B(A)`` -> ``("A", "B")``).
        variables: names bound by plain ``name = value`` assignments.

    Several bucket names are misspelled (``uarynops``, ``constats``,
    ``attribuetes``, ``machors``). They are kept as they are because they are
    public attributes that ``ProjectParser.add_to_graph`` reads by name.
    """

    # AST class name -> bucket attribute. For every entry a ``visit_<class>``
    # method is generated below the class; it visits the node's children first
    # and then appends the node, so nested nodes of the same class are recorded
    # innermost-first.
    _POST_ORDER_BUCKETS = {
        "Interactive": "interactives",
        "Expression": "expressions",
        "FunctionType": "functiontypes",
        "Return": "returns",
        "Delete": "deletes",
        "TypeAlias": "typealiases",
        "AnnAssign": "annassigns",
        "AugAssign": "augassigns",
        "For": "forloops",
        "AsyncFor": "asyncforloops",
        "While": "whileloops",
        "If": "ifs",
        "With": "withs",
        "AsyncWith": "asyncwiths",
        "Match": "matches",
        "Raise": "raises",
        "Try": "trys",
        "TryStar": "trystars",
        "Assert": "asserts",
        "Global": "globals",
        "Nonlocal": "nonlocals",
        "Expr": "exprs",
        "Pass": "passes",
        "Break": "breaks",
        "Continue": "continues",
        "BoolOp": "boolops",
        "NamedExpr": "namedexprs",
        "BinOp": "binops",
        "UnaryOp": "uarynops",
        "Lambda": "lambdas",
        "IfExp": "ifexps",
        "Dict": "dicts",
        "Set": "sets",
        "ListComp": "listcomps",
        "SetComp": "setcomps",
        "DictComp": "dictcomps",
        "GeneratorExp": "generatorexps",
        "Await": "awaits",
        "Yield": "yields",
        "YieldFrom": "yieldfroms",
        "Compare": "comparisons",
        "Call": "calls",
        "FormattedValue": "formattedvalues",
        "JoinedStr": "joinedstrs",
        "Constant": "constats",
        "Attribute": "attribuetes",
        "Subscript": "subscripts",
        "Starred": "starreds",
        "Name": "names",
        "List": "lists",
        "Tuple": "tuples",
        "Slice": "slices",
        "Load": "loads",
        "Store": "stores",
        "Del": "dels",
        "And": "ands",
        "Or": "ors",
        "Add": "adds",
        "Sub": "subs",
        "Mult": "mults",
        "MatMult": "matmults",
        "Div": "divs",
        "Mod": "mods",
        "Pow": "pows",
        "LShift": "lshifts",
        "RShift": "rshifts",
        "BitOr": "bitors",
        "BitXor": "bitxors",
        "BitAnd": "bitands",
        "FloorDiv": "floordivs",
        "Invert": "inverts",
        "Not": "nots",
        "UAdd": "uaddss",
        "USub": "usubss",
        "Eq": "eqss",
        "NotEq": "not_eqss",
        "Lt": "lts",
        "LtE": "ltes",
        "Gt": "gts",
        "GtE": "gtes",
        "Is": "iss",
        "IsNot": "isnots",
        "In": "ins",
        "NotIn": "notins",
        "ExceptHandler": "excepthandlers",
        "MatchValue": "matchvalues",
        "MatchSingleton": "matchsingleton",
        "MatchSequence": "matchsequences",
        "MatchMapping": "matchmappings",
        "MatchClass": "matchclasses",
        "MatchStar": "matchstars",
        "MatchAs": "matchases",
        "MatchOr": "machors",
        "TypeIgnore": "typeignores",
        "TypeVar": "typevars",
        "ParamSpec": "paramspecs",
        "TypeVarTuple": "typevartuples",
    }

    def __init__(self):
        """Create a visitor with every bucket empty."""
        self.module = None
        self.relations = []
        self.variables = []
        # Buckets filled by the hand-written visit_* methods below.
        self.functions = []
        self.asyncfunctions = []
        self.classes = []
        self.assigns = []
        self.imports = []
        self.importfroms = []
        # Buckets filled by the generated post-order visit_* methods.
        for bucket in self._POST_ORDER_BUCKETS.values():
            setattr(self, bucket, [])

    #region Node Visitors
    def visit_Module(self, node):
        """Remember the module node, then visit its body."""
        self.module = node
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        """Record a ``def`` before descending into it (outer functions first)."""
        self.functions.append(node)
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node):
        """Record an ``async def`` before descending into it.

        The node itself is stored (not just its name), matching
        ``visit_FunctionDef``; the name is available as ``node.name``.
        """
        self.asyncfunctions.append(node)
        self.generic_visit(node)

    def visit_ClassDef(self, node):
        """Record a class and its bare-name bases, then visit its body."""
        self.classes.append(node)
        # Only bases written as a plain name are recorded; dotted bases such
        # as ``pkg.Base`` are ast.Attribute nodes and are not added here.
        self.relations.extend(
            (base.id, node.name) for base in node.bases if isinstance(base, ast.Name)
        )
        self.generic_visit(node)

    def visit_Assign(self, node):
        """Record an assignment and the plain names it binds, then descend."""
        self.assigns.append(node)
        # Tuple, attribute and subscript targets are not ast.Name nodes, so
        # they contribute nothing to ``variables``.
        self.variables.extend(
            target.id for target in node.targets if isinstance(target, ast.Name)
        )
        self.generic_visit(node)

    def visit_Import(self, node):
        """Record an ``import ...`` statement (one node per statement)."""
        self.imports.append(node)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Record a ``from ... import ...`` statement (one node per statement)."""
        self.importfroms.append(node)
        self.generic_visit(node)
    #endregion


def _make_bucket_visitor(bucket):
    """Build a visit_* method that visits children, then appends the node to ``bucket``."""
    def visit(self, node):
        self.generic_visit(node)
        getattr(self, bucket).append(node)
    return visit


# ast.NodeVisitor.visit() dispatches on the method name "visit_<ClassName>",
# so installing the generated methods under those names is all that is needed.
for _node_class, _bucket in ComprehensiveVisitor._POST_ORDER_BUCKETS.items():
    _visit_method = _make_bucket_visitor(_bucket)
    _visit_method.__name__ = f"visit_{_node_class}"
    _visit_method.__qualname__ = f"ComprehensiveVisitor.visit_{_node_class}"
    setattr(ComprehensiveVisitor, _visit_method.__name__, _visit_method)
del _node_class, _bucket, _visit_method

class ProjectParser:
    """Build a NetworkX directed graph from the ASTs of a project's Python files.

    Every collected AST node becomes a graph node carrying a ``label`` and a
    ``type`` attribute, and each file's module node gets an edge to every node
    collected from that file.

    Attributes:
        graph:         the ``nx.DiGraph`` being built.
        skipped_files: ``(filepath, exception)`` tuples, one per file that
                       ``parse_file`` could not process.
    """

    def __init__(self):
        self.graph = nx.DiGraph()
        self.skipped_files = []

    def parse_file(self, filepath):
        """Parse one Python file and add its nodes to the graph.

        Processing is best-effort: a file that cannot be read, decoded, parsed
        or added is skipped, and the failure is recorded in
        ``self.skipped_files`` instead of being silently discarded.
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as file:
                file_content = file.read()
            node = ast.parse(file_content, filename=str(filepath))
            visitor = ComprehensiveVisitor()
            visitor.visit(node)
            self.add_to_graph(visitor, module_label=os.path.basename(filepath))
        except Exception as exc:
            # SyntaxError and UnicodeDecodeError are Exception subclasses, so
            # this single handler covers every case the file is skipped for.
            self.skipped_files.append((filepath, exc))

    @staticmethod
    def _label_for(ast_node):
        """Return the display label (always a string) for one collected item."""
        if isinstance(ast_node, ast.Call):
            return type(ast_node.func).__name__
        if isinstance(ast_node, ast.Expr):
            return type(ast_node.value).__name__
        if isinstance(ast_node, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
            return ast_node.name
        if isinstance(ast_node, (ast.Import, ast.ImportFrom)):
            # One statement can import several names; show all of them.
            return ', '.join(alias.name for alias in ast_node.names)
        if isinstance(ast_node, ast.Name):
            return ast_node.id
        if isinstance(ast_node, ast.Constant):
            # The value may be bytes, None, a number, ...; keep labels textual.
            return str(ast_node.value)
        if isinstance(ast_node, ast.Attribute):
            return ast_node.attr
        if isinstance(ast_node, ast.AST):
            # Class name instead of the default repr, which embeds a memory
            # address that differs on every run.
            return type(ast_node).__name__
        # Non-AST entries (e.g. plain name strings) label themselves.
        return str(ast_node)

    def add_to_graph(self, visitor: "ComprehensiveVisitor", module_label=None):
        """Add everything a visitor collected to the graph.

        Args:
            visitor:      a visitor that has already walked one tree.
            module_label: label for the module node; defaults to the module
                          node's class name when omitted.
        """
        module = visitor.module

        # Add the module node explicitly so it carries a label and the
        # 'module' type that apply_styles has a style for.
        if module is not None:
            self.graph.add_node(
                module,
                label=module_label or type(module).__name__,
                type='module',
            )

        # Node type -> items collected by the visitor for that type.
        node_types: dict[str, Sequence[ast.AST]] = {
            'interactive': visitor.interactives,
            'expression': visitor.expressions,
            'functiontype': visitor.functiontypes,
            'function': visitor.functions,
            'asyncfunction': visitor.asyncfunctions,
            'class': visitor.classes,
            'return': visitor.returns,
            'delete': visitor.deletes,
            'assign': visitor.assigns,
            'typealias': visitor.typealiases,
            'augassign': visitor.augassigns,
            'annassign': visitor.annassigns,
            'forloop': visitor.forloops,
            'asyncforloop': visitor.asyncforloops,
            'whileloop': visitor.whileloops,
            'if': visitor.ifs,
            'with': visitor.withs,
            'asyncwith': visitor.asyncwiths,
            'match': visitor.matches,
            'raise': visitor.raises,
            'try': visitor.trys,
            'trystar': visitor.trystars,
            'assert': visitor.asserts,
            'import': visitor.imports,
            'importfrom': visitor.importfroms,
            'global': visitor.globals,
            'nonlocal': visitor.nonlocals,
            'expr': visitor.exprs,
            'pass': visitor.passes,
            'break': visitor.breaks,
            'continue': visitor.continues,
            'boolop': visitor.boolops,
            'namedexpr': visitor.namedexprs,
            'binop': visitor.binops,
            'unaryop': visitor.uarynops,
            'lambda': visitor.lambdas,
            'ifexp': visitor.ifexps,
            'dict': visitor.dicts,
            'set': visitor.sets,
            'listcomp': visitor.listcomps,
            'setcomp': visitor.setcomps,
            'dictcomp': visitor.dictcomps,
            'generatorexp': visitor.generatorexps,
            'await': visitor.awaits,
            'yield': visitor.yields,
            'yieldfrom': visitor.yieldfroms,
            'comparison': visitor.comparisons,
            'call': visitor.calls,
            'formattedvalue': visitor.formattedvalues,
            'joinedstr': visitor.joinedstrs,
            'constant': visitor.constats,
            'attribute': visitor.attribuetes,
            'subscript': visitor.subscripts,
            'starred': visitor.starreds,
            'name': visitor.names,
            'list': visitor.lists,
            'tuple': visitor.tuples,
            'slice': visitor.slices,
            'load': visitor.loads,
            'store': visitor.stores,
            'del': visitor.dels,
            'and': visitor.ands,
            'or': visitor.ors,
            'add': visitor.adds,
            'sub': visitor.subs,
            'mult': visitor.mults,
            'matmult': visitor.matmults,
            'div': visitor.divs,
            'mod': visitor.mods,
            'pow': visitor.pows,
            'lshift': visitor.lshifts,
            'rshift': visitor.rshifts,
            'bitor': visitor.bitors,
            'bitxor': visitor.bitxors,
            'bitand': visitor.bitands,
            'floordiv': visitor.floordivs,
            'invert': visitor.inverts,
            'not': visitor.nots,
            'uadd': visitor.uaddss,
            'usub': visitor.usubss,
            'eq': visitor.eqss,
            'not_eq': visitor.not_eqss,
            'lt': visitor.lts,
            'lte': visitor.ltes,
            'gt': visitor.gts,
            'gte': visitor.gtes,
            'is': visitor.iss,
            'isnot': visitor.isnots,
            'in': visitor.ins,
            'notin': visitor.notins,
            'excepthandler': visitor.excepthandlers,
            'matchvalue': visitor.matchvalues,
            'matchsingleton': visitor.matchsingleton,
            'matchsequence': visitor.matchsequences,
            'matchmapping': visitor.matchmappings,
            'matchclass': visitor.matchclasses,
            'matchstar': visitor.matchstars,
            'matchas': visitor.matchases,
            'machor': visitor.machors,
            'typeignore': visitor.typeignores,
            'typevar': visitor.typevars,
            'paramspec': visitor.paramspecs,
            'typevartuple': visitor.typevartuples,
        }

        for node_type, ast_nodes in node_types.items():
            for ast_node in ast_nodes:
                self.graph.add_node(
                    ast_node, label=self._label_for(ast_node), type=node_type
                )
                # Without a module node there is nothing to attach the edge to.
                if module is not None:
                    # NOTE(review): every module -> node edge is typed
                    # 'import', whatever the node is; kept as-is because
                    # apply_styles keys its edge styling on this value.
                    self.graph.add_edge(module, ast_node, type='import')

    def get_files_from_directory(self, root_dir):
        """Parse every ``*.py`` file found under ``root_dir``, recursively."""
        for dirpath, _dirnames, filenames in os.walk(root_dir):
            for filename in filenames:
                if filename.endswith(".py"):
                    self.parse_file(os.path.join(dirpath, filename))

    def apply_styles(self):
        """Copy colour, shape and line-style attributes onto nodes and edges.

        The style is chosen from each element's ``type`` attribute; unknown or
        missing types get the ``'default'`` style. Every node additionally gets
        ``opacity = 0.5``.
        """
        node_styles = {
            'class': {'color': '#ff9999', 'shape': 'ellipse'},
            'function': {'color': '#99ff99', 'shape': 'rectangle'},
            'variable': {'color': '#009999', 'shape': 'rectangle'},
            'import': {'color': '#9999ff', 'shape': 'diamond'},
            'module': {'color': '#ffff99', 'shape': 'rectangle'},
            'default': {'color': '#0c0c44', 'shape': 'circle'}
        }

        edge_styles = {
            'inheritance': {'color': '#990066', 'line_style': 'solid'},
            'import': {'color': '#005500', 'line_style': 'dashed'},
            'variable': {'color': '#000055', 'line_style': 'dotted'},
            'function': {'color': '#550000', 'line_style': 'dotted'},
            'default': {'color': '#999999', 'line_style': 'dotted'}
        }

        for _node, data in self.graph.nodes(data=True):
            node_type = data.get('type', 'default')
            data.update(node_styles.get(node_type, node_styles['default']))
            data['opacity'] = 0.5

        for _, _, data in self.graph.edges(data=True):
            edge_type = data.get('type', 'default')
            data.update(edge_styles.get(edge_type, edge_styles['default']))


In [3]:
# Use the framework: parse every .py file under the project root, then style the graph.
# The default is the original author's local checkout; set the CODECARTO_SRC
# environment variable to point the notebook at another source tree.
path = os.environ.get('CODECARTO_SRC', 'D:/Programming/Repos/QM/codecartographer/src/codecarto')
print(path)
# os.walk yields nothing for a directory that does not exist, which would
# produce an empty graph with no error, so fail loudly instead.
if not os.path.isdir(path):
    raise FileNotFoundError(f"Project directory not found: {path}")
parser = ProjectParser()
parser.get_files_from_directory(path)
parser.apply_styles()
 

D:/Programming/Repos/QM/codecartographer/src/codecarto


In [8]:
# Render the graph built above with gravis' d3 function. Node text is taken
# from each node's 'label' attribute (node_label_data_source='label').
# The call is the cell's last expression, so its return value is what the
# notebook displays as this cell's output.
gv.d3(parser.graph,
      node_size_factor=2, node_label_size_factor=0.2,
      edge_size_factor=0.5, edge_label_size_factor=0.2,
      node_label_data_source='label',
      show_menu=True,
      show_node_label=True,
      graph_height=1000)

In [5]:
# gv.three(parser.graph,
#       node_size_factor=2, node_label_size_factor=0.2,
#       edge_size_factor=0.5, edge_label_size_factor=0.2,
#       show_menu=True,
#       graph_height=1000)

In [6]:
# gv.vis(parser.graph)

In [7]:
# gv.convert.any_to_gjgf(parser.graph, 'graph.gjgf')