Skip to content

Commit

Permalink
Handle default values correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
unexge committed Jan 3, 2023
1 parent 8621ca1 commit 5e56f45
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 14 deletions.
Expand Up @@ -288,7 +288,7 @@ class PythonApplicationGenerator(
/// :param workers ${PythonType.Optional(PythonType.Int).render()}:
/// :param tls ${PythonType.Optional(tlsConfig).render()}:
/// :rtype ${PythonType.None.render()}:
##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers, tls)")]
##[pyo3(text_signature = "(${'$'}self, address=None, port=None, backlog=None, workers=None, tls=None)")]
pub fn run(
&mut self,
py: #{pyo3}::Python,
Expand All @@ -315,7 +315,7 @@ class PythonApplicationGenerator(
}
/// Build the service and start a single worker.
##[pyo3(text_signature = "(${'$'}self, socket, worker_number, tls)")]
##[pyo3(text_signature = "(${'$'}self, socket, worker_number, tls=None)")]
pub fn start_worker(
&mut self,
py: pyo3::Python,
Expand Down
29 changes: 22 additions & 7 deletions rust-runtime/aws-smithy-http-server-python/examples/stubgen.py
Expand Up @@ -150,8 +150,8 @@ def is_fn_like(obj: Any) -> bool:
)


def walk_field(writer: Writer, name: str, field: Any) -> str:
return f"{name}: {writer.fix_and_include(DocstringParser.parse_type(field))} = ..."
def make_field(writer: Writer, name: str, field: Any) -> str:
return f"{name}: {writer.fix_and_include(DocstringParser.parse_type(field))}"


def make_function(writer: Writer, name: str, obj: Any, indent_level: int = 0) -> str:
Expand All @@ -161,6 +161,13 @@ def make_function(writer: Writer, name: str, obj: Any, indent_level: int = 0) ->
return f"{name}: {writer.include('typing.Any')}"

params, rtype = res
# We're using signature for getting default values only, currently type hints are not supported
# in signatures. We can leverage signatures more if it supports type hints in future.
sig: Optional[inspect.Signature] = None
try:
sig = inspect.signature(obj)
except:
pass

receivers: List[str] = []
attrs: List[str] = []
Expand All @@ -169,9 +176,17 @@ def make_function(writer: Writer, name: str, obj: Any, indent_level: int = 0) ->
else:
attrs.append("@staticmethod")

params = ", ".join(
receivers + [f"{n}: {writer.fix_and_include(t)} = ..." for n, t in params]
)
def format_param(name: str, ty: str) -> str:
param = f"{name}: {writer.fix_and_include(ty)}"

if sig is not None:
sig_param = sig.parameters.get(name)
if sig_param and sig_param.default is not sig_param.empty:
param += f" = {sig_param.default}"

return param

params = ", ".join(receivers + [format_param(n, t) for n, t in params])

fn_def = ""
if len(attrs) > 0:
Expand All @@ -194,15 +209,15 @@ def make_class(
if inspect.isdatadescriptor(member):
is_empty = False
definition += (
indent(walk_field(writer, name, member), indent_level + 4) + "\n"
indent(make_field(writer, name, member), indent_level + 4) + "\n"
)
elif is_fn_like(member):
is_empty = False
definition += make_function(writer, name, member, indent_level + 4) + "\n"
# Enum variant
elif isinstance(member, klass):
is_empty = False
definition += indent(f"{name}: {class_name} = ...\n", indent_level + 4)
definition += indent(f"{name}: {class_name}\n", indent_level + 4)
else:
print(f"Unknown member type={member}")

Expand Down
2 changes: 1 addition & 1 deletion rust-runtime/aws-smithy-http-server-python/src/error.rs
Expand Up @@ -44,7 +44,7 @@ impl From<PyError> for PyErr {
/// :param status_code typing.Optional[int]:
/// :rtype None:
#[pyclass(name = "MiddlewareException", extends = BasePyException)]
#[pyo3(text_signature = "(message, status_code)")]
#[pyo3(text_signature = "($self, message, status_code=None)")]
#[derive(Debug, Clone)]
pub struct PyMiddlewareException {
#[pyo3(get, set)]
Expand Down
1 change: 1 addition & 0 deletions rust-runtime/aws-smithy-http-server-python/src/logging.rs
Expand Up @@ -91,6 +91,7 @@ fn setup_tracing_subscriber(
/// :param logfile typing.Optional[pathlib.Path]:
/// :rtype None:
#[pyclass(name = "TracingHandler")]
#[pyo3(text_signature = "($self, level=None, logfile=None)")]
#[derive(Debug)]
pub struct PyTracingHandler {
_guard: Option<WorkerGuard>,
Expand Down
Expand Up @@ -17,7 +17,6 @@ use super::{PyHeaderMap, PyMiddlewareError};

/// Python-compatible [Request] object.
#[pyclass(name = "Request")]
#[pyo3(text_signature = "(request)")]
#[derive(Debug)]
pub struct PyRequest {
parts: Option<Parts>,
Expand Down
Expand Up @@ -23,7 +23,7 @@ use super::{PyHeaderMap, PyMiddlewareError};
/// :param body typing.Optional[bytes]:
/// :rtype None:
#[pyclass(name = "Response")]
#[pyo3(text_signature = "(status, headers, body)")]
#[pyo3(text_signature = "($self, status, headers=None, body=None)")]
pub struct PyResponse {
parts: Option<Parts>,
headers: PyHeaderMap,
Expand Down
1 change: 0 additions & 1 deletion rust-runtime/aws-smithy-http-server-python/src/socket.rs
Expand Up @@ -49,7 +49,6 @@ impl PySocket {

/// Clone the inner socket allowing it to be shared between multiple
/// Python processes.
#[pyo3(text_signature = "($self, socket, worker_number)")]
pub fn try_clone(&self) -> PyResult<PySocket> {
let copied = self.inner.try_clone()?;
Ok(PySocket { inner: copied })
Expand Down
2 changes: 1 addition & 1 deletion rust-runtime/aws-smithy-http-server-python/src/tls.rs
Expand Up @@ -27,7 +27,7 @@ pub mod listener;
/// :rtype None:
#[pyclass(
name = "TlsConfig",
text_signature = "(*, key_path, cert_path, reload)"
text_signature = "($self, *, key_path, cert_path, reload_secs=86400)"
)]
#[derive(Clone)]
pub struct PyTlsConfig {
Expand Down

0 comments on commit 5e56f45

Please sign in to comment.