diff --git a/replit_river/codegen/client.py b/replit_river/codegen/client.py index cac352c0..ee60c517 100644 --- a/replit_river/codegen/client.py +++ b/replit_river/codegen/client.py @@ -420,6 +420,12 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: current_chunks.append(f" {name}: {type_name}") typeddict_encoder.append(",") typeddict_encoder.append("}") + # exclude_none + typeddict_encoder = ( + ["{k: v for (k, v) in ("] + + typeddict_encoder + + [").items() if v is not None}"] + ) else: typeddict_encoder.append("{}") current_chunks.append(" pass") diff --git a/scripts/parity/check_parity.py b/scripts/parity/check_parity.py index 4ecc3626..8e75532c 100644 --- a/scripts/parity/check_parity.py +++ b/scripts/parity/check_parity.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Literal, TypedDict, TypeVar, Union, cast +from typing import Any, Callable, Literal, TypedDict, TypeVar, Union import pyd import tyd @@ -16,15 +16,45 @@ A = TypeVar("A") +PrimitiveType = ( + bool | str | int | float | dict[str, "PrimitiveType"] | list["PrimitiveType"] +) + + +def deep_equal(a: PrimitiveType, b: PrimitiveType) -> Literal[True]: + if a == b: + return True + elif isinstance(a, dict) and isinstance(b, dict): + a_keys: PrimitiveType = list(a.keys()) + b_keys: PrimitiveType = list(b.keys()) + assert deep_equal(a_keys, b_keys) + + # We do this dance again because Python variance is hard. Feel free to fix it. + keys = set(a.keys()) + keys.update(b.keys()) + for k in keys: + aa: PrimitiveType = a[k] + bb: PrimitiveType = b[k] + assert deep_equal(aa, bb) + return True + elif isinstance(a, list) and isinstance(b, list): + assert len(a) == len(b) + for i in range(len(a)): + assert deep_equal(a[i], b[i]) + return True + else: + assert a == b, f"{a} != {b}" + return True + def baseTestPattern( x: A, encode: Callable[[A], Any], adapter: TypeAdapter[Any] ) -> None: a = encode(x) m = adapter.validate_python(a) - z = adapter.dump_python(m) + z = adapter.dump_python(m, by_alias=True, exclude_none=True) - assert a == z + assert deep_equal(a, z) def testAiexecExecInit() -> None: @@ -93,7 +123,39 @@ def testAgenttoollanguageserverGetcodesymbolInput() -> None: "line": gen_float(), "character": gen_float(), }, - "kind": cast(kind_type, gen_opt(gen_choice(list(range(1, 27))))()), + "kind": gen_choice( + list[kind_type]( + [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + None, + ] + ) + )(), } baseTestPattern( @@ -116,17 +178,17 @@ def testShellexecSpawnInput() -> None: "env": gen_opt(gen_dict(gen_str))(), "cwd": gen_opt(gen_str)(), "size": gen_opt( - lambda: cast( - size_type, + lambda: size_type( { "rows": gen_int(), "cols": gen_int(), - }, - ) + } + ), )(), "useReplitRunEnv": gen_opt(gen_bool)(), "useCgroupMagic": gen_opt(gen_bool)(), "interactive": gen_opt(gen_bool)(), + "onlySpawnIfNoProcesses": gen_opt(gen_bool)(), } baseTestPattern( @@ -146,12 +208,82 @@ def testConmanfilesystemPersistInput() -> None: ) +closeFile = tyd.ReplspaceapiInitInputOneOf_closeFile +githubToken = tyd.ReplspaceapiInitInputOneOf_githubToken +sshToken0 = tyd.ReplspaceapiInitInputOneOf_sshToken0 +sshToken1 = tyd.ReplspaceapiInitInputOneOf_sshToken1 +allowDefaultBucketAccess = tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccess + +allowDefaultBucketAccessResultOk = ( + tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccessResultOneOf_ok +) +allowDefaultBucketAccessResultError = ( + tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccessResultOneOf_error +) + + +def testReplspaceapiInitInput() -> None: + x: tyd.ReplspaceapiInitInput = gen_choice( + list[tyd.ReplspaceapiInitInput]( + [ + closeFile( + {"kind": "closeFile", "filename": gen_str(), "nonce": gen_str()} + ), + githubToken( + {"kind": "githubToken", "token": gen_str(), "nonce": gen_str()} + ), + sshToken0( + { + "kind": "sshToken", + "nonce": gen_str(), + "SSHHostname": gen_str(), + "token": gen_str(), + } + ), + sshToken1({"kind": "sshToken", "nonce": gen_str(), "error": gen_str()}), + allowDefaultBucketAccess( + { + "kind": "allowDefaultBucketAccess", + "nonce": gen_str(), + "result": gen_choice( + list[ + tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccessResult + ]( + [ + allowDefaultBucketAccessResultOk( + { + "bucketId": gen_str(), + "sourceReplId": gen_str(), + "status": "ok", + "targetReplId": gen_str(), + } + ), + allowDefaultBucketAccessResultError( + {"message": gen_str(), "status": "error"} + ), + ] + ) + )(), + } + ), + ] + ) + )() + + baseTestPattern( + x, + tyd.encode_ReplspaceapiInitInput, + TypeAdapter(pyd.ReplspaceapiInitInput), + ) + + def main() -> None: testAiexecExecInit() testAgenttoollanguageserverOpendocumentInput() testAgenttoollanguageserverGetcodesymbolInput() testShellexecSpawnInput() testConmanfilesystemPersistInput() + testReplspaceapiInitInput() if __name__ == "__main__": diff --git a/scripts/parity/gen.py b/scripts/parity/gen.py index ae777183..92bae50b 100644 --- a/scripts/parity/gen.py +++ b/scripts/parity/gen.py @@ -1,15 +1,15 @@ import random +import string from typing import Callable, Optional, TypeVar A = TypeVar("A") +printable_chars = string.ascii_letters + string.digits + + def gen_char() -> str: - pos = random.randint(0, 26 * 2) - if pos < 26: - return chr(ord("A") + pos) - else: - return chr(ord("a") + pos - 26) + return random.choice(printable_chars) def gen_str() -> str: