In [361]:
import dotenv

dotenv.load_dotenv("../.env")

import openai
import libcst

MODEL = "gpt-3.5-turbo-0613"

In [362]:
def recursively_remove_field_titles(schema: dict) -> dict:
    keys_to_remove = []
    for key, value in schema.items():
        if key == "title":
            keys_to_remove.append(key)
        elif isinstance(value, dict):
            recursively_remove_field_titles(value)
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    recursively_remove_field_titles(item)

    for key in keys_to_remove:
        schema.pop(key)

    return schema

In [363]:
from typing import List
from pydantic import BaseModel, Field
import json

class FunctionParameterData(BaseModel):
  name:         str = Field(description="The name of the parameter")
  description:  str = Field(description="The description of the parameter")
  assumed_type: str = Field(description="The assumed type(s) of the parameter; example: 'str, int, float'")


class FunctionReturnData(BaseModel):
  name:         str = Field(description="The name of the return value")
  description:  str = Field(description="The description of the return value")
  assumed_type: str = Field(description="The assumed type(s) of the return value; example: 'str, int, float'")


class FunctionExceptionData(BaseModel):
  name:        str = Field(description="The name of the exception")
  description: str = Field(description="The description of the exception")


class FunctionDocstringData(BaseModel):
  description:   str                         = Field(description="The overall description of the function in 1 sentence, you may consider starting it with a verb.")
  parameters:    List[FunctionParameterData] = Field(description="An array of descriptions for each of the function's parameters, if it has any.")
  return_values: List[FunctionReturnData]    = Field(description="An array of descriptions for each of the function's return values, if it has any.")
  exceptions:    List[FunctionExceptionData] = Field(description="An array of descriptions for each of the exceptions the function may raise, if it has any.")


schema = FunctionDocstringData.schema(
  by_alias=False,
  ref_template="#/components/schemas/{model}",
)

schema = recursively_remove_field_titles(schema)

schema_string = json.dumps(schema, indent=2)

print(schema_string)

{
  "type": "object",
  "properties": {
    "description": {
      "description": "The overall description of the function in 1 sentence, you may consider starting it with a verb.",
      "type": "string"
    },
    "parameters": {
      "description": "An array of descriptions for each of the function's parameters, if it has any.",
      "type": "array",
      "items": {
        "$ref": "#/components/schemas/FunctionParameterData"
      }
    },
    "return_values": {
      "description": "An array of descriptions for each of the function's return values, if it has any.",
      "type": "array",
      "items": {
        "$ref": "#/components/schemas/FunctionReturnData"
      }
    },
    "exceptions": {
      "description": "An array of descriptions for each of the exceptions the function may raise, if it has any.",
      "type": "array",
      "items": {
        "$ref": "#/components/schemas/FunctionExceptionData"
      }
    }
  },
  "required": [
    "description",
    "parameters",

In [364]:
functions = [
    {
        "name": "respond_with_structured_docstring_data_for_function",
        "parameters": {
            "type": "object",
            "properties": {
                "description": {
                    "type": "string",
                    "description": "The overall description of the function in 1 sentence.",
                },
                "parameters": {
                    "type": "array",
                    "description": "An array of descriptions for each of the function's parameters, if it has any.",
                    "items": {
                        "type": "object",
                        "properties": {
                            "name": {
                                "type": "string",
                            },
                            "description": {
                                "type": "string",
                            },
                            "assumed_type": {
                                "type": "string",
                                "description": "The assumed type(s) of the parameter; example: 'str, int, float'.",
                            },
                        },
                    },
                },
                "return_values": {
                    "type": "array",
                    "description": "An array of descriptions for each of the function's return values, if it has any.",
                    "items": {
                        "type": "object",
                        "properties": {
                            "name": {
                                "type": "string",
                            },
                            "description": {
                                "type": "string",
                            },
                            "assumed_type": {
                                "type": "string",
                                "description": "The assumed type(s) of the return value; example: 'str, int, float'.",
                            },
                        },
                    },
                },
                "exceptions": {
                    "type": "array",
                    "description": "An array of descriptions for each of the exceptions the function may raise, if it has any.",
                    "items": {
                        "type": "object",
                        "properties": {
                            "name": {
                                "type": "string",
                            },
                            "description": {
                                "type": "string",
                            },
                        },
                    },
                },
            },
        },
    },
]

In [365]:
functions = [
    {
        "name": "respond_with_structured_docstring_data_for_function",
        "parameters": schema,
    },
]

In [366]:
import diskcache

cache = diskcache.Cache("cache")

# @cache.memoize()
def generate_function_docstring(
    node_source_code: str,
    temperature: float = 0.25,
):
    # Generate docstring
    response = openai.ChatCompletion.create(
        model=MODEL,
        temperature=temperature,
        functions=functions,
        function_call={ "name": "respond_with_structured_docstring_data_for_function", },
        messages=[
            {
                "role": "system",
                "content": f"""
You are a robot who is an expert at writing docstrings for Python functions, mainly because you are extremely good at being concise.
Help the human write a docstring for the following function:
""",
            },
            { "role": "user", "content": node_source_code, },
        ],
    )
    
    # Extract response
    message = response.choices[0]["message"]
    if "function_call" not in message:
        raise Exception(f"Unexpected response: {response}")

    function_call = message["function_call"]
    function_call_name = function_call["name"]
    
    if function_call_name == "respond_with_structured_docstring_data_for_function":
        return function_call["arguments"]

    if function_call_name == "skip_function":
        print("Skipping function...")
        return None

    raise Exception(f"Unexpected function call: {function_call_name}")

In [367]:
import textwrap

def build_google_docstring(data: FunctionDocstringData) -> str:
    docstring = data.description + '\n\n'
    
    if data.parameters:
        docstring += 'Args:\n'
        for param in data.parameters:
            docstring += f'    {param.name} ({param.assumed_type}): {param.description}\n'

    if data.return_values:
        docstring += '\nReturns:\n'
        for ret in data.return_values:
            docstring += f'    {ret.name} ({ret.assumed_type}): {ret.description}\n'

    if data.exceptions:
        docstring += '\nRaises:\n'
        for exc in data.exceptions:
            docstring += f'    {exc.name}: {exc.description}\n'

    return docstring

In [368]:
import textwrap

def build_numpy_docstring(data: FunctionDocstringData) -> str:
    docstring = data.description + "\n\n"

    # Parameters section
    if data.parameters:
        docstring += "Parameters\n----------\n"
        for param in data.parameters:
            docstring += f"{param.name} : {param.assumed_type}\n    {param.description}\n"

    # Return values section
    if data.return_values:
        docstring += "\nReturns\n-------\n"
        for ret_val in data.return_values:
            docstring += f"{ret_val.name} : {ret_val.assumed_type}\n    {ret_val.description}\n"

    # Exceptions section
    if data.exceptions:
        docstring += "\nRaises\n------\n"
        for exc in data.exceptions:
            docstring += f"{exc.name}\n    {exc.description}\n"

    return docstring

In [369]:
# Basic function visitor
from termcolor import colored


class FunctionVisitor(libcst.CSTVisitor):
    def __init__(self, module) -> None:
        self.module = module

    def visit_FunctionDef(self, node: libcst.FunctionDef) -> None:
        print(f"Found function:\n{node.name.value}")

        # Extract node source code
        node_source_code = self.module.code_for_node(node)

        # TEMP, testing something
        node_source_code = "abcd"

        # Generate docstring
        docstring = generate_function_docstring(node_source_code)

        # Parse to docstring struct
        docstring_struct = FunctionDocstringData.parse_raw(docstring)

        # Generate NumPy docstring
        numpy_docstring = build_numpy_docstring(docstring_struct)
        print(f"Generated NumPy docstring:\n{colored(numpy_docstring, 'green')}")
        print("-" * 80)

        # Generate Google docstring
        google_docstring = build_google_docstring(docstring_struct)
        print(f"Generated Google docstring:\n{colored(google_docstring, 'yellow')}")
        print("-" * 80)

In [370]:
# Read example source code from disk
with open("example_source_one_function.py") as file:
    example_source = file.read()

# Parse source code into a CST
module = libcst.parse_module(example_source)

# Test
_ = module.visit(FunctionVisitor(module))

Found function:
download_github_repo
Generated NumPy docstring:
[32mReturn the result of concatenating the strings 'a', 'b', 'c', and 'd'.


Returns
-------
result : str
    The concatenated string.
[0m
--------------------------------------------------------------------------------
Generated Google docstring:
[33mReturn the result of concatenating the strings 'a', 'b', 'c', and 'd'.


Returns:
    result (str): The concatenated string.
[0m
--------------------------------------------------------------------------------
