In [1]:
# Generate the Python 3 lexer, parser, and visitor base class from the Java grammar files.
# The generated modules land in _work/06-parsing, which the next cell adds to sys.path
# so that JavaLexer, JavaParser and JavaParserVisitor can be imported.
! cd _work/06-parsing && antlr4 -Dlanguage=Python3 -visitor JavaLexer.g4 JavaParser.g4

In [2]:
# Import the generated lexer, parser, and visitor
import sys
sys.path.append("_work/06-parsing")

import antlr4
from antlr4 import FileStream, CommonTokenStream
from JavaLexer import JavaLexer
from JavaParser import JavaParser
from JavaParserVisitor import JavaParserVisitor

In [26]:
# Other imports
from pathlib import Path
import seutil as su
import random
import collections
import dataclasses
from tqdm import tqdm
from typing import Dict, List, Optional, Set

# Notebook-wide logger: the visitor defined further down reports through it, and
# the final cell sets its level to pick detailed or warnings-only output.
logger = su.log.get_logger("log")

In [4]:
# Parse the subject file
subject_file = Path.cwd() / "07-static-analysis-bytecode" / "src" / "main" / "java" / "ca" / "uwaterloo" / "cs846" / "Subject.java"

# FileStream decodes with ASCII unless an encoding is given, which makes it fail
# on any Java source containing a non-ASCII character (e.g. in a comment or a
# string literal). Decode as UTF-8 instead, and pass the path as a plain string,
# which is what the runtime's signature declares.
input_stream = FileStream(str(subject_file), encoding="utf-8")
# Standard ANTLR pipeline: characters -> tokens -> parse tree.
lexer = JavaLexer(input_stream)
stream = CommonTokenStream(lexer)
parser = JavaParser(stream)
# `compilationUnit` is the rule invoked here to parse the whole file.
tree = parser.compilationUnit()

In [5]:
def pretty_print_tree(node, rule_names=None, parser=parser, indent_level=0) -> str:
    """Render a parse tree as indented, human-readable text.

    Rule nodes are shown as ``<ruleName>`` (or ``<ruleName:alt>`` when the node
    reports a non-zero alternative number), error nodes as ``<...>`` around
    their string form, and terminal nodes as their token text. A node with
    exactly one child is collapsed onto the same line as ``parent : child``;
    otherwise every child goes on its own line, two spaces deeper than the
    line its parent is printed on.

    Args:
        node: the parse-tree node to render.
        rule_names: rule names indexed by rule index; taken from
            ``parser.ruleNames`` when omitted.
        parser: parser used only to look up ``ruleNames`` when ``rule_names``
            is not given.
        indent_level: number of spaces to put in front of ``node``'s own line.

    Returns:
        The rendered (possibly multi-line) string, without a trailing newline.
    """
    if rule_names is None:
        rule_names = parser.ruleNames

    def label_of(current) -> str:
        # ErrorNode must be tested before TerminalNode, as in the original
        # ordering, so error nodes get the bracketed form.
        if isinstance(current, antlr4.RuleNode):
            rule_name = rule_names[current.getRuleIndex()]
            alt = current.getAltNumber()
            # 0 is the "no alternative recorded" value; print the alternative
            # only when there is one, and never print the rule name twice.
            if alt != 0:
                return "<" + rule_name + ":" + str(alt) + ">"
            return "<" + rule_name + ">"
        if isinstance(current, antlr4.ErrorNode):
            return "<" + str(current) + ">"
        if isinstance(current, antlr4.TerminalNode):
            if current.symbol is not None:
                return current.symbol.text
        return ""

    def render(current, depth) -> str:
        # `depth` is the indentation of the line `current` is printed on; the
        # caller has already emitted that indentation.
        text = label_of(current)
        child_count = current.getChildCount()
        if child_count == 1:
            # The only child shares this line, so it keeps the same depth;
            # its own children are then indented relative to this line
            # instead of restarting from column 0.
            return text + " : " + render(current.getChild(0), depth)
        child_depth = depth + 2
        for i in range(child_count):
            text += "\n" + child_depth * " " + render(current.getChild(i), child_depth)
        return text

    return indent_level * " " + render(node, indent_level)

def get_text(node) -> str:
    """Return the source text covered by a parse-tree node.

    For a terminal node this is the token's own text. For any other node the
    text is sliced out of the character stream the node's first token was
    read from, spanning the first token's start offset through the last
    token's stop offset, rather than out of a notebook-level global stream.

    Falls back to ``node.getText()`` when the node has no start/stop token or
    the token carries no input stream, instead of raising.
    """
    if isinstance(node, antlr4.TerminalNode):
        return node.getText()

    first_token, last_token = node.start, node.stop
    if first_token is None or last_token is None:
        # No token span to slice; use the node's own text.
        return node.getText()

    # Ask the token for the stream it came from so the result stays correct
    # even after the global `input_stream` has been rebound to another file.
    char_stream = first_token.getInputStream()
    if char_stream is None:
        return node.getText()
    return char_stream.getText(first_token.start, last_token.stop)



In [None]:
# Show the raw Java source first, then the parse tree built from it, so the
# tree can be read against the code it describes.
print(su.io.load(subject_file, fmt=su.io.fmts.txt))
print(pretty_print_tree(tree))

In [43]:
# Implement the visitor
@dataclasses.dataclass
class MethodContext:
    """Bookkeeping for one method while the visitor walks its declaration."""

    # Method name, as written in the declaration.
    name: str = ""

    # Readable "<return type> <name>(<param types>)" string used in log messages.
    signature: str = ""

    # Variable names declared while inside this method.
    defs: Set[str] = dataclasses.field(default_factory=set)

    # Subset of `defs` that has been seen as a primary expression.
    uses: Set[str] = dataclasses.field(default_factory=set)


class FindUnusedLocalVarsVisitor(JavaParserVisitor):
    """Parse-tree visitor that warns about variables declared but never used in a method.

    While inside a method declaration, every ``variableDeclaratorId`` is
    recorded as a definition and every ``primary`` whose identifier matches a
    recorded definition is recorded as a use. When the method ends, the names
    that were defined but never used are reported through ``logger.warning``.

    Limitation: a use is only matched against the definitions of the method
    currently being visited, so a variable of an outer method that is used
    only inside a nested method declaration is still reported as unused.
    """

    def __init__(self):
        super().__init__()
        # Context of the method being walked; None while outside any method.
        self.cur_method_ctx: Optional[MethodContext] = None

    def visitMethodDeclaration(self, ctx: JavaParser.MethodDeclarationContext):
        """Open a fresh MethodContext, walk the method, then report unused names.

        Returns whatever the generated base visitor returns for this node.
        """
        # save the current method context (in case of nested methods)
        prev_method_ctx = self.cur_method_ctx

        # create a new method context
        name = ctx.identifier().getText()
        ret_type = ctx.typeTypeOrVoid().getText()
        param_types = []
        # formalParameterList() is None for a method declared with no
        # parameters, so it must be checked before asking for its children.
        param_list = ctx.formalParameters().formalParameterList()
        if param_list is not None:
            for parameter in param_list.formalParameter():
                param_types.append(parameter.typeType().getText())
            # A trailing varargs parameter is a separate rule in the grammar;
            # include it so the signature does not silently drop it.
            last_parameter = param_list.lastFormalParameter()
            if last_parameter is not None:
                param_types.append(last_parameter.typeType().getText() + "...")
        signature = f"{ret_type} {name}({', '.join(param_types)})"

        self.cur_method_ctx = MethodContext(name=name, signature=signature)
        logger.info(f"beginning a method, {self.cur_method_ctx}")

        try:
            # delegate to super visitor
            node = super().visitMethodDeclaration(ctx)

            # check for any unused variables
            logger.info(f"ending a method, {self.cur_method_ctx}")
            unused_vars = self.cur_method_ctx.defs - self.cur_method_ctx.uses
            if len(unused_vars) > 0:
                logger.warning(f"unused variables in `{self.cur_method_ctx.signature}`: {unused_vars}")
        finally:
            # restore the previous method context, even if the walk raised, so
            # the visitor is not left pointing at a half-visited method
            self.cur_method_ctx = prev_method_ctx
        return node

    # record definitions of variables (in VariableDeclaratorId)
    def visitVariableDeclaratorId(self, ctx: JavaParser.VariableDeclaratorIdContext):
        """Record the declared name as a definition of the current method, if any."""
        if self.cur_method_ctx is not None:
            self.cur_method_ctx.defs.add(ctx.identifier().getText())
        return super().visitVariableDeclaratorId(ctx)

    # record uses of variables (usually in Primary (see Expression rule))
    def visitPrimary(self, ctx: JavaParser.PrimaryContext):
        """Record the identifier as a use when it names a definition of the current method."""
        if self.cur_method_ctx is not None:
            identifier = ctx.identifier()
            if identifier is not None:
                var = identifier.getText()
                # Primary may be a mix of local/global variable uses and types, so we only record the ones defined in the current method
                if var in self.cur_method_ctx.defs:
                    self.cur_method_ctx.uses.add(var)
        return super().visitPrimary(ctx)


In [None]:
# Pick the verbosity before running the analysis: at INFO the visitor also logs
# each method as it is entered and left; at WARNING only the unused-variable
# findings are shown. Swap which of the two lines below is commented out.
# logger.setLevel(su.log.INFO)  # with detailed logging
logger.setLevel(su.log.WARNING)  # with only warnings (final results)
# Walk the whole parse tree; findings are emitted through `logger`.
visitor = FindUnusedLocalVarsVisitor()
tree.accept(visitor)