Skip to content

Commit

Permalink
🔧 Enforce type checking in needuml.py (#1116)
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjsewell committed Feb 19, 2024
1 parent c2ad574 commit aef3a0f
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 36 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ ignore_missing_imports = true
[[tool.mypy.overrides]]
module = [
'sphinx_needs.api.need',
'sphinx_needs.directives.needuml',
]
ignore_errors = true

Expand Down
4 changes: 2 additions & 2 deletions sphinx_needs/diagrams_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from sphinx.util.docutils import SphinxDirective

from sphinx_needs.config import NeedsSphinxConfig
from sphinx_needs.data import NeedsFilteredBaseType, NeedsPartsInfoType
from sphinx_needs.data import NeedsFilteredBaseType, NeedsInfoType, NeedsPartsInfoType
from sphinx_needs.errors import NoUri
from sphinx_needs.logging import get_logger
from sphinx_needs.utils import get_scale, split_link_types
Expand Down Expand Up @@ -169,7 +169,7 @@ def get_debug_container(puml_node: nodes.Element) -> nodes.container:


def calculate_link(
app: Sphinx, need_info: NeedsPartsInfoType, _fromdocname: str
app: Sphinx, need_info: NeedsInfoType | NeedsPartsInfoType, _fromdocname: None | str
) -> str:
"""
Link calculation
Expand Down
98 changes: 65 additions & 33 deletions sphinx_needs/directives/needuml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import html
import os
from typing import Sequence
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, TypedDict

from docutils import nodes
from docutils.parsers.rst import directives
Expand All @@ -11,13 +11,25 @@
from sphinx.util.docutils import SphinxDirective

from sphinx_needs.config import NeedsSphinxConfig
from sphinx_needs.data import SphinxNeedsData
from sphinx_needs.data import NeedsInfoType, SphinxNeedsData
from sphinx_needs.debug import measure_time
from sphinx_needs.diagrams_common import calculate_link
from sphinx_needs.directives.needflow import make_entity_name
from sphinx_needs.filter_common import filter_needs
from sphinx_needs.utils import add_doc

if TYPE_CHECKING:
from sphinxcontrib.plantuml import plantuml


class ProcessedDataType(TypedDict):
art: str
key: None | str
arguments: dict[str, Any]


ProcessedNeedsType = Dict[str, List[ProcessedDataType]]


class Needuml(nodes.General, nodes.Element):
pass
Expand Down Expand Up @@ -130,8 +142,13 @@ def run(self) -> Sequence[nodes.Node]:


def transform_uml_to_plantuml_node(
app, uml_content: str, parent_need_id: str, key: str, kwargs: dict, config: str
):
app: Sphinx,
uml_content: str,
parent_need_id: None | str,
key: None | str,
kwargs: dict[str, Any],
config: str,
) -> plantuml:
try:
if "sphinxcontrib.plantuml" not in app.config.extensions:
raise ImportError
Expand All @@ -158,7 +175,7 @@ def transform_uml_to_plantuml_node(
puml_node["uml"] += "\n\n"

# jinja2uml to translate jinja statements to uml text
(uml_content_return, processed_need_ids_return) = jinja2uml(
(uml_content_return, _) = jinja2uml(
app=app,
fromdocname=None,
uml_content=uml_content,
Expand All @@ -167,16 +184,15 @@ def transform_uml_to_plantuml_node(
processed_need_ids={},
kwargs=kwargs,
)
# silently discard processed_need_ids_return

puml_node["uml"] += f"\n{uml_content_return}"
puml_node["uml"] += "\n@enduml\n"
return puml_node


def get_debug_node_from_puml_node(puml_node):
def get_debug_node_from_puml_node(puml_node: plantuml) -> nodes.container:
if isinstance(puml_node, nodes.figure):
data = puml_node.children[0]["uml"]
data = puml_node.children[0]["uml"] # type: ignore[index]
data = puml_node.get("uml", "")
data = "\n".join([html.escape(line) for line in data.split("\n")])
debug_para = nodes.raw("", f"<pre>{data}</pre>", format="html")
Expand All @@ -186,20 +202,20 @@ def get_debug_node_from_puml_node(puml_node):


def jinja2uml(
app,
fromdocname,
app: Sphinx,
fromdocname: None | str,
uml_content: str,
parent_need_id: str,
key: str,
processed_need_ids: {},
kwargs: dict,
) -> (str, {}):
parent_need_id: None | str,
key: None | str,
processed_need_ids: ProcessedNeedsType,
kwargs: dict[str, Any],
) -> tuple[str, ProcessedNeedsType]:
# Let's render jinja templates with uml content template to 'plantuml syntax' uml
# 1. Remove @startuml and @enduml
uml_content = uml_content.replace("@startuml", "").replace("@enduml", "")

# 2. Prepare jinja template
mem_template = Environment(loader=BaseLoader).from_string(uml_content)
mem_template = Environment(loader=BaseLoader()).from_string(uml_content)

# 3. Get a new instance of Jinja Helper Functions
jinja_utils = JinjaFunctions(app, fromdocname, parent_need_id, processed_need_ids)
Expand All @@ -211,7 +227,7 @@ def jinja2uml(
)

# 5. Get data for the jinja processing
data = {}
data: dict[str, Any] = {}
# 5.1 Set default config to data
data.update(**NeedsSphinxConfig(app.config).render_context)
# 5.2 Set uml() kwargs to data and maybe overwrite default settings
Expand Down Expand Up @@ -246,8 +262,12 @@ class JinjaFunctions:
"""

def __init__(
self, app: Sphinx, fromdocname, parent_need_id: str, processed_need_ids: dict
):
self,
app: Sphinx,
fromdocname: None | str,
parent_need_id: None | str,
processed_need_ids: ProcessedNeedsType,
) -> None:
self.needs = SphinxNeedsData(app.env).get_or_create_needs()
self.app = app
self.fromdocname = fromdocname
Expand All @@ -258,24 +278,28 @@ def __init__(
)
self.processed_need_ids = processed_need_ids

def need_to_processed_data(self, art: str, key: str, kwargs: dict) -> {}:
d = {
def need_to_processed_data(
self, art: str, key: None | str, kwargs: dict[str, Any]
) -> ProcessedDataType:
d: ProcessedDataType = {
"art": art,
"key": key,
"arguments": kwargs,
}
return d

def append_need_to_processed_needs(
self, need_id: str, art: str, key: str, kwargs: dict
self, need_id: str, art: str, key: None | str, kwargs: dict[str, Any]
) -> None:
data = self.need_to_processed_data(art=art, key=key, kwargs=kwargs)
if need_id not in self.processed_need_ids:
self.processed_need_ids[need_id] = []
if data not in self.processed_need_ids[need_id]:
self.processed_need_ids[need_id].append(data)

def append_needs_to_processed_needs(self, processed_needs_data: dict) -> None:
def append_needs_to_processed_needs(
self, processed_needs_data: ProcessedNeedsType
) -> None:
for k, v in processed_needs_data.items():
if k not in self.processed_need_ids:
self.processed_need_ids[k] = []
Expand All @@ -284,17 +308,17 @@ def append_needs_to_processed_needs(self, processed_needs_data: dict) -> None:
self.processed_need_ids[k].append(d)

def data_in_processed_data(
self, need_id: str, art: str, key: str, kwargs: dict
self, need_id: str, art: str, key: str, kwargs: dict[str, Any]
) -> bool:
data = self.need_to_processed_data(art=art, key=key, kwargs=kwargs)
return (need_id in self.processed_need_ids) and (
data in self.processed_need_ids[need_id]
)

def get_processed_need_ids(self) -> {}:
def get_processed_need_ids(self) -> ProcessedNeedsType:
return self.processed_need_ids

def uml_from_need(self, need_id: str, key: str = "diagram", **kwargs) -> str:
def uml_from_need(self, need_id: str, key: str = "diagram", **kwargs: Any) -> str:
if need_id not in self.needs:
raise NeedumlException(
f"Jinja function uml() is called with undefined need_id: '{need_id}'."
Expand Down Expand Up @@ -337,7 +361,7 @@ def uml_from_need(self, need_id: str, key: str = "diagram", **kwargs) -> str:

return uml

def flow(self, need_id) -> str:
def flow(self, need_id: str) -> str:
if need_id not in self.needs:
raise NeedumlException(
f"Jinja function flow is called with undefined need_id: '{need_id}'."
Expand Down Expand Up @@ -368,7 +392,9 @@ def flow(self, need_id) -> str:

return need_uml

def ref(self, need_id: str, option: str = None, text: str = None) -> str:
def ref(
self, need_id: str, option: None | str = None, text: None | str = None
) -> str:
if need_id not in self.needs:
raise NeedumlException(
f"Jinja function ref is called with undefined need_id: '{need_id}'."
Expand All @@ -388,7 +414,7 @@ def ref(self, need_id: str, option: str = None, text: str = None) -> str:

return need_uml

def filter(self, filter_string):
def filter(self, filter_string: str) -> list[NeedsInfoType]:
"""
Return a list of found needs that pass the given filter string.
"""
Expand All @@ -398,7 +424,7 @@ def filter(self, filter_string):
list(self.needs.values()), needs_config, filter_string=filter_string
)

def imports(self, *args):
def imports(self, *args: str) -> str:
if not self.parent_need_id:
raise NeedumlException(
"Jinja function 'import()' is not supported in needuml directive."
Expand All @@ -408,8 +434,14 @@ def imports(self, *args):
uml_ids = []
for option_name in args:
# check if link option_name exists in current need object
if option_name in need_info and need_info[option_name]:
for id in need_info[option_name]:
if option_value := need_info.get(option_name):
try:
iterable_value = list(option_value) # type: ignore
except TypeError:
raise NeedumlException(
f"Option value for {option_name!r} is not iterable."
)
for id in iterable_value:
uml_ids.append(id)
umls = ""
if uml_ids:
Expand All @@ -420,7 +452,7 @@ def imports(self, *args):
umls += local_uml_from_need
return umls

def need(self):
def need(self) -> NeedsInfoType:
if not self.parent_need_id:
raise NeedumlException(
"Jinja function 'need()' is not supported in needuml directive."
Expand Down

0 comments on commit aef3a0f

Please sign in to comment.