In [96]:
import libcst as cst
from typing import cast

# Parse the example source file into a libcst Module and display the CST of
# its third top-level statement (the `t_add` function, per the output below).
# NOTE(review): `read_file` is not defined in this view -- presumably an
# earlier cell; confirm it returns the file contents as a str.
parsed = cst.parse_module(read_file("test_data/ex_code_1.py"))
parsed.body[2]

FunctionDef(
    name=Name(
        value='t_add',
        lpar=[],
        rpar=[],
    ),
    params=Parameters(
        params=[
            Param(
                name=Name(
                    value='x',
                    lpar=[],
                    rpar=[],
                ),
                annotation=None,
                equal=MaybeSentinel.DEFAULT,
                default=None,
                comma=Comma(
                    whitespace_before=SimpleWhitespace(
                        value='',
                    ),
                    whitespace_after=SimpleWhitespace(
                        value=' ',
                    ),
                ),
                star='',
                whitespace_after_star=SimpleWhitespace(
                    value='',
                ),
                whitespace_after_param=SimpleWhitespace(
                    value='',
                ),
            ),
            Param(
                name=Name(
                    value='y',
    

In [97]:
from typing import List, Tuple, Dict, Set, Optional, Union

class SpecialNames:
    """Reserved path components that can never collide with a Python identifier."""

    # path component under which a function's return annotation is stored
    Return: str = "<return>"

class AnnotCollector(cst.CSTVisitor):
    """Collect the annotations of function returns, parameters and annotated
    assignments in a module, keyed by their canonical path.

    A canonical path is the tuple of enclosing class/function names followed by
    either a parameter name, an annotated-assignment target (its source text,
    e.g. ``"self.x"``), or ``SpecialNames.Return`` for a function's return
    annotation. For example ``def fib(n) -> int`` yields the keys
    ``("fib", "<return>")`` and ``("fib", "n")``. Unannotated parameters and
    returns are recorded with the value ``None``.
    """

    # node types that contribute one component to the canonical path
    _SCOPE_NODES = (cst.FunctionDef, cst.ClassDef, cst.Param, cst.AnnAssign)

    def __init__(self):
        super().__init__()
        # names of the enclosing classes/functions, plus the parameter or
        # assignment target currently being visited
        self.stack: List[str] = []
        # canonical path -> annotation found there (None if unannotated)
        self.annotations: Dict[
            Tuple[str, ...],  # key: canonical path of the annotated element
            Optional[cst.Annotation],  # value: its annotation, if any
        ] = {}

    @staticmethod
    def _scope_name(node: cst.CSTNode) -> Optional[str]:
        """Return the path component `node` contributes, or None if it adds none."""
        if isinstance(node, (cst.FunctionDef, cst.ClassDef, cst.Param)):
            return node.name.value
        if isinstance(node, cst.AnnAssign):
            target = node.target
            if isinstance(target, cst.Name):
                return target.value
            # attribute/subscript targets (`self.x: int`, `d["k"]: int`):
            # `target.value` would be a CST node, not a string, so use the
            # target's source text as the path component instead
            return cst.Module(body=[]).code_for_node(target)
        return None

    def on_visit(self, node: cst.CSTNode) -> bool:
        name = self._scope_name(node)
        if name is not None:
            self.stack.append(name)
        return super().on_visit(node)

    def on_leave(self, node: cst.CSTNode):
        r = super().on_leave(node)
        # pop only what on_visit pushed for this node
        if isinstance(node, self._SCOPE_NODES):
            self.stack.pop()
        return r

    def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
        # the stack already ends with the function name (pushed in on_visit)
        self.annotations[tuple(self.stack) + (SpecialNames.Return,)] = node.returns

    def visit_Param(self, node: cst.Param) -> Optional[bool]:
        self.annotations[tuple(self.stack)] = node.annotation

    def visit_AnnAssign(self, node: cst.AnnAssign) -> Optional[bool]:
        self.annotations[tuple(self.stack)] = node.annotation

class AnnotApplier(cst.CSTTransformer):
    """Rewrite a module so that selected elements carry the given annotations.

    ``annots`` maps a canonical path to the annotation to install there. A path
    is the tuple of enclosing class/function names followed by one of:
      - a parameter name (sets that parameter's annotation),
      - an annotated-assignment target, as its source text (e.g. ``"x"`` or
        ``"self.x"``) (replaces that assignment's annotation),
      - ``SpecialNames.Return`` (sets the function's return annotation).
    Paths mapped to ``None`` leave the corresponding element unchanged.
    Subtrees that cannot contain any target path are not traversed.
    """

    # node types that contribute one component to the canonical path
    _SCOPE_NODES = (cst.FunctionDef, cst.ClassDef, cst.Param, cst.AnnAssign)

    def __init__(self, annots: Dict[Tuple[str, ...], Optional[cst.Annotation]]):
        super().__init__()
        self.annots = annots
        # names of the enclosing classes/functions, plus the parameter or
        # assignment target currently being visited
        self.stack: List[str] = []
        # every prefix (including the empty one) of every target path; a node
        # whose path is not in here cannot contain a target, so it is skipped
        self.prefixes: Set[Tuple[str, ...]] = {
            path[:i] for path in annots for i in range(len(path) + 1)
        }

    @staticmethod
    def _scope_name(node: cst.CSTNode) -> Optional[str]:
        """Return the path component `node` contributes, or None if it adds none."""
        if isinstance(node, (cst.FunctionDef, cst.ClassDef, cst.Param)):
            return node.name.value
        if isinstance(node, cst.AnnAssign):
            target = node.target
            if isinstance(target, cst.Name):
                return target.value
            # attribute/subscript targets: use the source text so the path
            # component is a plain string rather than a CST node
            return cst.Module(body=[]).code_for_node(target)
        return None

    def on_visit(self, node: cst.CSTNode) -> bool:
        name = self._scope_name(node)
        if name is not None:
            self.stack.append(name)
        if tuple(self.stack) not in self.prefixes:
            # prune: nothing below this node is targeted. libcst still calls
            # on_leave for this node, so the push above is popped there.
            return False
        return super().on_visit(node)

    def on_leave(self, node, updated):
        # leave_* handlers run here, while the stack still ends with this node
        r = super().on_leave(node, updated)
        if isinstance(node, self._SCOPE_NODES):
            self.stack.pop()
        return r

    def leave_FunctionDef(self, node: cst.FunctionDef, updated: cst.FunctionDef) -> cst.FunctionDef:
        patch = self.annots.get(tuple(self.stack) + (SpecialNames.Return,))
        return updated if patch is None else updated.with_changes(returns=patch)

    def leave_Param(self, node: cst.Param, updated: cst.Param) -> cst.Param:
        patch = self.annots.get(tuple(self.stack))
        return updated if patch is None else updated.with_changes(annotation=patch)

    def leave_AnnAssign(self, node: cst.AnnAssign, updated: cst.AnnAssign) -> cst.AnnAssign:
        patch = self.annots.get(tuple(self.stack))
        return updated if patch is None else updated.with_changes(annotation=patch)
        

In [98]:
# Walk the parsed module and gather every parameter / return / annotated-
# assignment annotation keyed by its canonical path, then display the mapping.
# Unannotated parameters and returns appear with the value None.
collector = AnnotCollector()
parsed.visit(collector)
collector.annotations

{('fib',
  '<return>'): Annotation(
     annotation=Subscript(
         value=Name(
             value='List',
             lpar=[],
             rpar=[],
         ),
         slice=[
             SubscriptElement(
                 slice=Index(
                     value=Name(
                         value='int',
                         lpar=[],
                         rpar=[],
                     ),
                 ),
                 comma=MaybeSentinel.DEFAULT,
             ),
         ],
         lbracket=LeftSquareBracket(
             whitespace_after=SimpleWhitespace(
                 value='',
             ),
         ),
         rbracket=RightSquareBracket(
             whitespace_before=SimpleWhitespace(
                 value='',
             ),
         ),
         lpar=[],
         rpar=[],
         whitespace_after_value=SimpleWhitespace(
             value='',
         ),
     ),
     whitespace_before_indicator=SimpleWhitespace(
         value=' ',
     ),
     whi

In [99]:
# Annotate `fib`'s parameter `n` and its return value as `int`, check the
# rewrite took effect, then print the transformed source.
# TODO: make unit tests
applier = AnnotApplier({
    ("fib", "n"): cst.Annotation(cst.Name("int")),
    ("fib", SpecialNames.Return): cst.Annotation(cst.Name("int")),
})
patched_code = parsed.visit(applier).code
assert "def fib(n: int) -> int:" in patched_code, "fib was not annotated as expected"
print(patched_code)

from typing import List

# A recursive fibonacci function
def fib(n: int) -> int:
    if n == 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fib(n-1) + fib(n-2)

def t_add(x, y) -> "List[int]":
    r = x + y
    return r

x: int = fib(3)
