In [44]:
import dotenv
dotenv.load_dotenv("../.env")

from termcolor import colored as col
import diskcache
import libcst
import openai


## Config
# Model name passed as `model=` to the OpenAI chat completion call below
OPENAI_MODEL_TO_USE = "gpt-3.5-turbo-0613"


## Types
# Quote delimiters accepted for `docstring_quote_style` in `DocstringTransformer`
StringQuoteLiteralStyles = ['"', "'", '"""', "'''"]


# Cache all OpenAI API calls on disk (in the "cache" directory) to avoid paying
# for the same call more than once
cache = diskcache.Cache("cache")


def preprocess_func_or_class_def(definition: str) -> str:
  """Prepare the source of a function or class definition for the OpenAI API.

  At the moment the only preprocessing step is trimming leading and trailing
  whitespace.

  Args:
    definition: Source code of the definition.

  Returns:
    The definition with surrounding whitespace removed.
  """
  assert isinstance(definition, str), "Expected definition to be a string"

  # TODO: try to minimize tokens passed to OpenAI API
  return definition.strip()


def capitalize_first_letter(string: str) -> str:
  """Trim surrounding whitespace and upper-case the first character.

  Only the first character is changed; the remainder of the text is kept
  exactly as given.

  Args:
    string: Text to capitalize; must contain at least one non-whitespace
      character.

  Returns:
    The trimmed text with its first character upper-cased.
  """
  assert isinstance(string, str), "Expected string to be a string"

  trimmed = string.strip()
  assert len(trimmed) > 0, "Expected string to be non-empty"

  head, tail = trimmed[0], trimmed[1:]
  return head.upper() + tail


def postprocess_docstring(docstring: str) -> str:
  """Clean up a generated docstring so that it starts with a capitalized verb.

  Surrounding whitespace is removed, at most one leading "This function" /
  "This method" / "This class" prefix is dropped, and the first remaining
  character is upper-cased.

  Args:
    docstring: Raw docstring text.

  Returns:
    The cleaned-up docstring.

  Raises:
    AssertionError: If `docstring` is not a string, or if nothing is left
      after cleaning (raised by `capitalize_first_letter`).
  """
  assert isinstance(docstring, str), "Expected docstring to be a string"

  PREFIXES_TO_REMOVE = (
    "This function",
    "This method",
    "This class",
  )

  # Strip first so that leading whitespace cannot hide a prefix
  docstring = docstring.strip()

  # Remove at most one leading prefix from docstring
  for prefix in PREFIXES_TO_REMOVE:
    if not docstring.startswith(prefix):
      continue

    remainder = docstring.removeprefix(prefix)

    # Only treat it as a prefix when it ends on a word boundary, so that e.g.
    # "This functionality ..." is not truncated to "ality ..."
    if not remainder or remainder[0].isspace():
      docstring = remainder.strip()
      break

  # Capitalize first letter
  return capitalize_first_letter(docstring)


@cache.memoize()
def generate_docstring_with_openai(definition: str, definition_type: str) -> str:
  """Request a one-sentence description of a definition from the OpenAI chat API.

  The call is wrapped in `cache.memoize()`, so the progress line below is only
  printed when the function body actually runs.

  Args:
    definition: Source code of the function or class to describe.
    definition_type: Either "function" or "class"; interpolated into the prompt.

  Returns:
    The message content of the first choice in the API response.
  """
  assert isinstance(definition, str), "Expected definition to be a string"
  assert isinstance(definition_type, str), "Expected definition_type to be a string"
  assert definition_type in ["function", "class"], "Expected definition_type to be either 'function' or 'class'"

  # Instructions for the model, including one example of the desired output
  system_prompt = f"""Describe what this Python {definition_type} does.
Limit your description to 1 sentence; pretend that it is going to be used as a comment within a codebase.
Your description should begin with a verb, and should not include the {definition_type} name itself in the description.

Example output:
Creates a matrix of given shape and propagates it with random samples from a uniform distribution."""

  chat_messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": definition},
  ]

  # Generate docstring with OpenAI API
  completion = openai.ChatCompletion.create(
    model=OPENAI_MODEL_TO_USE,
    messages=chat_messages,
  )

  first_choice = completion["choices"][0]
  generated = first_choice["message"]["content"]

  print(col(f"Generated docstring for {definition_type}:", "green"), generated, "\n")

  return generated


def check_if_node_has_docstring(
  node: libcst.FunctionDef | libcst.ClassDef,
) -> bool:
  """Tell whether a function or class body starts with a string literal.

  A docstring is recognized when the first statement of the body is a
  `SimpleStatementLine` whose first small statement is an expression
  consisting of a `SimpleString`.

  Args:
    node: Function or class definition to inspect.

  Returns:
    True if the body starts with such a string expression, False otherwise.
    Always a real `bool`, never a falsy sequence.
  """
  assert isinstance(node, (libcst.FunctionDef | libcst.ClassDef)), "Expected node to be a function or a class definition"

  statements = node.body.body
  if not statements:
    return False

  first_statement = statements[0]
  if not isinstance(first_statement, libcst.SimpleStatementLine) or not first_statement.body:
    return False

  first_expr = first_statement.body[0]
  return isinstance(first_expr, libcst.Expr) and isinstance(first_expr.value, libcst.SimpleString)


class DocstringTransformer(libcst.CSTTransformer):
  """CST transformer that adds a generated docstring to definitions lacking one.

  Functions (including methods) are always processed. Classes are only
  processed when `generate_for_classes` is enabled; it is off by default.

  Args:
    module: Module the visited nodes belong to; used to render a node back
      into source code.
    docstring_quote_style: Quote delimiter for inserted docstrings; one of
      `StringQuoteLiteralStyles`.
    generate_docstring_func: Callable taking `(source_code, node_type)` and
      returning the docstring text, where `node_type` is "function" or "class".
    generate_for_classes: Whether to also generate docstrings for classes.
  """

  def __init__(
    self,
    module: libcst.Module,
    docstring_quote_style: str = '"""',
    generate_docstring_func: callable = generate_docstring_with_openai,
    generate_for_classes: bool = False,
  ) -> None:
    assert isinstance(module, libcst.Module), "Expected module to be a `libcst.Module`"
    assert docstring_quote_style in StringQuoteLiteralStyles, "Expected docstring_quote_style to be a `StringQuoteLiteralStyles`"
    assert callable(generate_docstring_func), "Expected generate_docstring_func to be a callable"
    assert isinstance(generate_for_classes, bool), "Expected generate_for_classes to be a bool"

    super().__init__()

    self._module = module
    self._docstring_quote_style = docstring_quote_style
    self._generate_docstring_func = generate_docstring_func
    self._generate_for_classes = generate_for_classes

  def leave_ClassDef(
    self,
    original_node: libcst.ClassDef,
    updated_node: libcst.ClassDef,
  ) -> libcst.ClassDef:
    """Add a docstring to the class if enabled and it has none."""
    del original_node  # Unused; only the updated node is inspected and changed

    # Docstring generation for classes is opt-in
    if not self._generate_for_classes:
      return updated_node

    return self._ensure_docstring(updated_node)

  def leave_FunctionDef(
    self,
    original_node: libcst.FunctionDef,
    updated_node: libcst.FunctionDef,
  ) -> libcst.FunctionDef:
    """Add a docstring to the function if it has none."""
    del original_node  # Unused; only the updated node is inspected and changed

    return self._ensure_docstring(updated_node)

  def _ensure_docstring(
    self,
    node: libcst.FunctionDef | libcst.ClassDef,
  ) -> libcst.FunctionDef | libcst.ClassDef:
    """Return `node` unchanged if it has a docstring, else with a generated one."""
    # Skip if node already has docstring
    if check_if_node_has_docstring(node):
      return node

    # Generate docstring value
    docstring_value = self.generate_docstring(node)

    # Insert docstring as first statement in node
    return self.insert_docstring(node, docstring_value)

  def extract_node_source_code(
    self,
    node: libcst.FunctionDef | libcst.ClassDef,
  ) -> str:
    """Render `node` back into source code via the module given at construction."""
    assert isinstance(node, (libcst.FunctionDef | libcst.ClassDef)), "Expected node to be a function or a class definition"

    return self._module.code_for_node(node)

  def generate_docstring(
    self,
    node: libcst.FunctionDef | libcst.ClassDef,
  ) -> str:
    """Generate and postprocess the docstring text for `node`.

    Prints progress information before and after the generation.

    Args:
      node: Function or class definition to describe.

    Returns:
      The postprocessed docstring text (without quotes).

    Raises:
      TypeError: If `node` is neither a function nor a class definition (only
        reachable when assertions are disabled).
    """
    assert isinstance(node, (libcst.FunctionDef | libcst.ClassDef)), "Expected node to be a function or a class definition"

    print(f"Generating docstring value for: `{col(node.name.value, 'blue')}`")

    # Extract node source code
    node_source_code = self.extract_node_source_code(node)

    # Preprocess function or class definition to minimize tokens passed to OpenAI API
    node_source_code = preprocess_func_or_class_def(node_source_code)

    # Determine node type
    if isinstance(node, libcst.FunctionDef):
      node_type = "function"
    elif isinstance(node, libcst.ClassDef):
      node_type = "class"
    else:
      raise TypeError(f"Unexpected node type: {type(node).__name__}")

    # Generate docstring (i.e. via OpenAI API wrapper function)
    docstring_value = self._generate_docstring_func(node_source_code, node_type)

    # Postprocess generated docstring
    docstring_value = postprocess_docstring(docstring_value)

    # Print docstring
    print("Generated docstring:", col(docstring_value, "green"))

    return docstring_value

  @staticmethod
  def _escape_docstring_text(text: str, quote_style: str) -> str:
    """Escape `text` so it can be wrapped in `quote_style` as a valid literal."""
    quote_char = quote_style[0]

    # Backslashes first, so that the escapes added below are not doubled
    escaped = text.replace("\\", "\\\\")

    if len(quote_style) == 1:
      # Single-quoted literals cannot contain bare quotes or line breaks
      escaped = escaped.replace(quote_char, "\\" + quote_char)
      escaped = escaped.replace("\r", "\\r").replace("\n", "\\n")
    elif quote_style in escaped or escaped.endswith(quote_char):
      # Triple-quoted literals only break on an embedded delimiter or a
      # trailing quote; leave ordinary quotes readable otherwise
      escaped = escaped.replace(quote_char, "\\" + quote_char)

    return escaped

  def insert_docstring(
    self,
    node: libcst.FunctionDef | libcst.ClassDef,
    docstring_value: str,
  ) -> libcst.FunctionDef | libcst.ClassDef:
    """Return a copy of `node` with `docstring_value` as its first statement.

    Args:
      node: Function or class definition to extend.
      docstring_value: Docstring text without quotes; it is escaped as needed
        for the configured quote style.

    Returns:
      The updated node; `node` itself is not modified.
    """
    assert isinstance(node, (libcst.FunctionDef | libcst.ClassDef)), "Expected node to be a function or a class definition"
    assert isinstance(docstring_value, str), "Expected docstring_value to be a string"

    # Create docstring node
    quote = self._docstring_quote_style
    literal = f"{quote}{self._escape_docstring_text(docstring_value, quote)}{quote}"
    docstring_node = libcst.SimpleStatementLine(body=[
      libcst.Expr(value=libcst.SimpleString(value=literal))
    ])

    # Insert docstring node as first statement in parent node
    body = node.body
    if isinstance(body, libcst.IndentedBlock):
      # Keep the block's own header, indent and footer
      new_body = body.with_changes(body=[docstring_node, *body.body])
    else:
      # One-line body (e.g. `def f(): pass`): its small statements have to be
      # moved onto their own line below the docstring
      statement_line = libcst.SimpleStatementLine(
        body=body.body,
        trailing_whitespace=body.trailing_whitespace,
      )
      new_body = libcst.IndentedBlock(body=[docstring_node, statement_line])

    # Return updated node
    return node.with_changes(body=new_body)


##
## Demo
##


# Load source code to transform (Python source is UTF-8 by default, so do not
# depend on the locale's preferred encoding)
source_code_path = "example_source.py"
with open(source_code_path, encoding="utf-8") as file:
  source_code = file.read()

# Parse source code
module = libcst.parse_module(source_code)

# Initialize docstring transformer
transformer = DocstringTransformer(
  module=module,
  docstring_quote_style='"""',
  generate_docstring_func=generate_docstring_with_openai,
)

# Apply docstring transformer to source code
modified_module = module.visit(transformer)

# Write transformed code to a new file next to the original
with open(f"modified_{source_code_path}", "w", encoding="utf-8") as file:
  file.write(modified_module.code)

Generating docstring value for: `[34mexample_func_0[0m)`
Generated docstring: [32mPrints "Hello".[0m
Generating docstring value for: `[34mexample_func_1[0m)`
Generated docstring: [32mPrints the values of `arg0` and `arg1` to the console.[0m
Generating docstring value for: `[34m__init__[0m)`
Generated docstring: [32mAssigns the given value to the "value" attribute of the object.[0m
Generating docstring value for: `[34madd[0m)`
[32mGenerated docstring for function:[0m Increments the value attribute of an object by a given value. 

Generated docstring: [32mIncrements the value attribute of an object by a given value.[0m
Generating docstring value for: `[34msubtract[0m)`
[32mGenerated docstring for function:[0m Subtracts the given value from the attribute value of the object. 

Generated docstring: [32mSubtracts the given value from the attribute value of the object.[0m
