Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,16 +457,35 @@ def {_field_name}(
case _:
encoder_parts.append((None, "x"))

# Build the ternary chain from encoder_parts
# Build the ternary chain from encoder_parts.
#
# Every entry that has a `type_check` (isinstance / `x is None`) gets
# its own guard, including the last one. Falling off the end means the
# input did not match any declared anyOf variant, which should not
# happen for a well-formed value; we emit a `cast(Any, x)` so mypy
# doesn't try to narrow the value through the chain.
#
# Previously the last entry was emitted unconditionally as the `else`
# branch. That works for simple unions (object | str | list), but
# breaks down when the last variant's encoder requires iteration
# (e.g. `[encode_X(y) for y in x]` when the variant is an array) and
# mypy fails to fully narrow `x` through the prior `isinstance`
# checks. The unguarded final branch then triggers `union-attr`
# errors like "Item 'float' has no attribute '__iter__'".
typeddict_encoder = list[str]()
for i, (type_check, encoder_expr) in enumerate(encoder_parts):
is_last = i == len(encoder_parts) - 1
if is_last or type_check is None:
# Last item or no type check - just the expression
has_unguarded_terminal = False
for type_check, encoder_expr in encoder_parts:
if type_check is None:
# No type check available — emit the bare expression and stop;
# nothing after it could be reached anyway.
typeddict_encoder.append(encoder_expr)
else:
# Add expression with type check
typeddict_encoder.append(f"{encoder_expr} if {type_check} else")
has_unguarded_terminal = True
break
typeddict_encoder.append(f"{encoder_expr} if {type_check} else")
if not has_unguarded_terminal and encoder_parts:
# Unreachable in practice (every declared variant was guarded
# above), but mypy needs a concrete final expression.
typeddict_encoder.append("cast(Any, x)")
if permit_unknown_members:
union = _make_open_union_type_expr(any_of)
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Code generated by river.codegen. DO NOT EDIT.
from pydantic import BaseModel
from typing import Literal

import replit_river as river


from .test_service import Test_ServiceService


class AnyOfArrayInUnionClient:
def __init__(self, client: river.Client[Literal[None]]):
self.test_service = Test_ServiceService(client)
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any
import datetime

from pydantic import TypeAdapter

from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
import replit_river as river


from .exec_sql_method import (
Exec_Sql_MethodInput,
Exec_Sql_MethodOutput,
Exec_Sql_MethodOutputTypeAdapter,
encode_Exec_Sql_MethodInput,
encode_Exec_Sql_MethodInputParams,
)


class Test_ServiceService:
def __init__(self, client: river.Client[Any]):
self.client = client

async def exec_sql_method(
self,
input: Exec_Sql_MethodInput,
timeout: datetime.timedelta,
) -> Exec_Sql_MethodOutput:
return await self.client.send_rpc(
"test_service",
"exec_sql_method",
input,
encode_Exec_Sql_MethodInput,
lambda x: Exec_Sql_MethodOutputTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
lambda x: RiverErrorTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
timeout,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
import datetime
from typing import (
Any,
Literal,
Mapping,
NotRequired,
TypedDict,
cast,
)
from typing_extensions import Annotated

from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
from replit_river.error_schema import RiverError
from replit_river.client import (
RiverUnknownError,
translate_unknown_error,
RiverUnknownValue,
translate_unknown_value,
)

import replit_river as river


Exec_Sql_MethodInputParamsAnyOf_4 = str | float | bool | None


def encode_Exec_Sql_MethodInputParamsAnyOf_4(
x: "Exec_Sql_MethodInputParamsAnyOf_4",
) -> Any:
return x


Exec_Sql_MethodInputParams = (
str | float | bool | list[Exec_Sql_MethodInputParamsAnyOf_4] | None
)


def encode_Exec_Sql_MethodInputParams(x: "Exec_Sql_MethodInputParams") -> Any:
return (
x
if isinstance(x, str)
else x
if isinstance(x, (int, float))
else x
if isinstance(x, bool)
else None
if x is None
else [encode_Exec_Sql_MethodInputParamsAnyOf_4(y) for y in x]
if isinstance(x, list)
else cast(Any, x)
)


def encode_Exec_Sql_MethodInput(
x: "Exec_Sql_MethodInput",
) -> Any:
return {
k: v
for (k, v) in (
{
"params": [encode_Exec_Sql_MethodInputParams(y) for y in x["params"]]
if "params" in x and x["params"] is not None
else None,
}
).items()
if v is not None
}


class Exec_Sql_MethodInput(TypedDict):
params: NotRequired[list[Exec_Sql_MethodInputParams] | None]


class Exec_Sql_MethodOutput(BaseModel):
ok: bool


Exec_Sql_MethodOutputTypeAdapter: TypeAdapter[Exec_Sql_MethodOutput] = TypeAdapter(
Exec_Sql_MethodOutput
)
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def encode_Anyof_Mixed_MethodInputRun_Command(
else x
if isinstance(x, str)
else list(x)
if isinstance(x, list)
else cast(Any, x)
)


Expand Down
31 changes: 31 additions & 0 deletions tests/v1/codegen/snapshot/test_anyof_array_in_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pytest_snapshot.plugin import Snapshot

from tests.fixtures.codegen_snapshot_fixtures import validate_codegen


async def test_anyof_array_in_union(snapshot: Snapshot) -> None:
"""Test codegen for an array field whose item type is a non-discriminated
anyOf union that itself contains an `array` variant.

Concretely this mirrors the PostgreSQL `executeSqlCommand.params` schema:
`array<scalar | array<scalar>>`. The inner union encoder ends in an
iteration over `x` (for the array variant), and historically that branch
was emitted as the unguarded `else` of a ternary chain. When mypy failed
to fully narrow `x` to `list[...]` through the preceding `isinstance`
checks, it complained that scalar items of the union have no
`__iter__` attribute (`union-attr`).

The fix emits an explicit `isinstance(x, list)` guard for the array
branch and a `cast(Any, x)` fallback, so mypy never has to negative-
narrow into the iterating branch.
"""
validate_codegen(
snapshot=snapshot,
snapshot_dir="tests/v1/codegen/snapshot/snapshots",
read_schema=lambda: open(
"tests/v1/codegen/types/anyof_array_in_union_schema.json"
),
target_path="test_anyof_array_in_union",
client_name="AnyOfArrayInUnionClient",
protocol_version="v1.1",
)
49 changes: 49 additions & 0 deletions tests/v1/codegen/types/anyof_array_in_union_schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
{
"services": {
"test_service": {
"procedures": {
"exec_sql_method": {
"input": {
"type": "object",
"properties": {
"params": {
"description": "Parameterized query values. Each entry is either a scalar or an array of scalars (for ANY($1::text[]) etc.).",
"type": "array",
"items": {
"anyOf": [
{ "type": "string" },
{ "type": "number" },
{ "type": "boolean" },
{ "type": "null" },
{
"type": "array",
"items": {
"anyOf": [
{ "type": "string" },
{ "type": "number" },
{ "type": "boolean" },
{ "type": "null" }
]
}
}
]
}
}
}
},
"output": {
"type": "object",
"properties": {
"ok": { "type": "boolean" }
},
"required": ["ok"]
},
"errors": {
"not": {}
},
"type": "rpc"
}
}
}
}
}
Loading