From b90f4f6a18dcfaec5d2e2b65dfbc1b471fc37d5a Mon Sep 17 00:00:00 2001 From: nitely Date: Tue, 25 Nov 2025 04:40:59 -0300 Subject: [PATCH 1/2] Fix decode params forwarding --- serialization.nim | 4 ++-- serialization/macros.nim | 26 ++++++++++++-------------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/serialization.nim b/serialization.nim index d673007..23fad94 100644 --- a/serialization.nim +++ b/serialization.nim @@ -74,7 +74,7 @@ template decodeImpl[InputType]( # from the fact that the dynamic dispatch mechanisms used in # faststreams may be reading from a file or a network device. {.noSideEffect.}: - var reader = unpackForwarded(init, [ReaderType, stream, params]) + var reader = unpackForwarded(init, [ReaderType, stream], params) reader.readValue(result) except IOError: raiseAssert "memory input doesn't raise IOError" @@ -90,7 +90,7 @@ template decodeImpl[InputType]( # Something's terribly wrong if we're reaching this point raiseAssert "negative memory input length" - unpackForwarded(decodeProc, [inputParam, params]) + unpackArgs(decodeProc, [inputParam, params]) template decode*( Format: type SerializationFormat, diff --git a/serialization/macros.nim b/serialization/macros.nim index a6ba4eb..d3c0dc4 100644 --- a/serialization/macros.nim +++ b/serialization/macros.nim @@ -54,24 +54,22 @@ macro forward*(args, prc: untyped): untyped = # this exact ident instance .. prc.params.add nnkIdentDefs.newTree(ident $arg[0], nnkCall.newTree(ident "typeof", arg[1]), newEmptyNode()) else: - prc.params.add nnkIdentDefs.newTree(ident "fwd" & $i, nnkCall.newTree(ident "typeof", arg[0]), newEmptyNode()) - i += 1 + prc.params.add nnkIdentDefs.newTree(ident "fwd" & $i, nnkCall.newTree(ident "typeof", arg), newEmptyNode()) + i += 1 prc -macro unpackForwarded*(callee: untyped, args: untyped): untyped = +macro unpackForwarded*(callee: untyped, args: untyped, params: varargs[untyped]): untyped = # pass on `args` to callee - args should be an array of parameters to pass - # on to callee where one of them should be the `varargs[untyped]` passed to - # the forward macro. Messy. + # on to callee; `varargs[untyped]` should be the params passed to the forward macro. result = newCall(callee) - var i = 0 for arg in usefulArgs(args): - if arg.kind == nnkArgList: - for subarg in usefulArgs(arg): - if subarg.kind == nnkExprEqExpr: - result.add nnkExprEqExpr.newTree(ident $subarg[0], ident $subarg[0]) - else: - result.add ident "fwd" & $i - i += 1 + result.add arg + + var i = 0 + for arg in usefulArgs(params): + if arg.kind == nnkExprEqExpr: + result.add nnkExprEqExpr.newTree(ident $arg[0], ident $arg[0]) else: - result.add arg + result.add ident "fwd" & $i + i += 1 From 8b8f969e97b7943d077e070893343f60af18abb8 Mon Sep 17 00:00:00 2001 From: nitely Date: Tue, 25 Nov 2025 22:14:10 -0300 Subject: [PATCH 2/2] tests --- tests/test_serializer.nim | 98 ++++++++++++++++++++++++++++++++++++++ tests/utils/serializer.nim | 8 +++- 2 files changed, 104 insertions(+), 2 deletions(-) diff --git a/tests/test_serializer.nim b/tests/test_serializer.nim index 4d72b84..36bde4d 100644 --- a/tests/test_serializer.nim +++ b/tests/test_serializer.nim @@ -81,3 +81,101 @@ suite "Rountrips": var val: ref string let ser = Ser.encode(val) check Ser.decode(ser, typeof(val)).isNil() + +type StringLimErr = object of SerializationError +type StringLim = distinct string + +proc `==`(a, b: StringLim): bool {.borrow.} +proc add(a: var StringLim, b: char) {.borrow.} + +proc readValue(r: var SerReader, val: var StringLim) {.raises: [IOError, SerializationError].} = + consumeKind r, SerKind.String + let L = r.readUint64() + if L > r.conf.limit.uint64: + raise newException(StringLimErr, "limit err") + for _ in 0 ..< L: + val.add r.stream.read().char + +suite "Config": + test "pass let conf": + let val = "1234567890" + let ser = Ser.encode(val) + + let conf10 = SerConf(limit: 10) + check Ser.decode(ser, StringLim, conf = conf10) == val.StringLim + check Ser.decode(ser, StringLim, conf10) == val.StringLim + + let conf5 = SerConf(limit: 5) + expect StringLimErr: + discard Ser.decode(ser, StringLim, conf = conf5) + expect StringLimErr: + discard Ser.decode(ser, StringLim, conf5) + + test "pass const conf": + let val = "1234567890" + let ser = Ser.encode(val) + + const conf10 = SerConf(limit: 10) + check Ser.decode(ser, StringLim, conf = conf10) == val.StringLim + check Ser.decode(ser, StringLim, conf10) == val.StringLim + + const conf5 = SerConf(limit: 5) + expect StringLimErr: + discard Ser.decode(ser, StringLim, conf = conf5) + expect StringLimErr: + discard Ser.decode(ser, StringLim, conf5) + + test "pass inlined conf": + let val = "1234567890" + let ser = Ser.encode(val) + + check Ser.decode(ser, StringLim, conf = SerConf(limit: 10)) == val.StringLim + check Ser.decode(ser, StringLim, SerConf(limit: 10)) == val.StringLim + + expect StringLimErr: + discard Ser.decode(ser, StringLim, conf = SerConf(limit: 5)) + expect StringLimErr: + discard Ser.decode(ser, StringLim, SerConf(limit: 5)) + + test "pass expression conf": + let val = "1234567890" + let ser = Ser.encode(val) + + template conf10: untyped = + var conf = SerConf(limit: 10) + conf + + check Ser.decode(ser, StringLim, conf = conf10) == val.StringLim + check Ser.decode(ser, StringLim, conf10) == val.StringLim + + template conf5: untyped = + var conf = SerConf(limit: 5) + conf + + expect StringLimErr: + discard Ser.decode(ser, StringLim, conf = conf5) + expect StringLimErr: + discard Ser.decode(ser, StringLim, conf5) + + test "multi params": + func init( + R: type SerReader, + stream: InputStream, + limit1: int, + limit2: int, + conf = default(SerConf) + ): R = + R(stream: stream, conf: SerConf(limit: conf.limit + limit1 + limit2)) + + let val = "1234567890" + let ser = Ser.encode(val) + + let lim1 = 10 + let lim2 = 0 + check: + Ser.decode(ser, StringLim, 10, 0) == val.StringLim + Ser.decode(ser, StringLim, 0, 10) == val.StringLim + Ser.decode(ser, StringLim, 0, 0, SerConf(limit: 10)) == val.StringLim + Ser.decode(ser, StringLim, limit1 = 10, limit2 = 0) == val.StringLim + Ser.decode(ser, StringLim, limit1 = 0, limit2 = 10) == val.StringLim + Ser.decode(ser, StringLim, lim1, lim2) == val.StringLim diff --git a/tests/utils/serializer.nim b/tests/utils/serializer.nim index 9a6f422..518ec0f 100644 --- a/tests/utils/serializer.nim +++ b/tests/utils/serializer.nim @@ -27,13 +27,17 @@ Ser.setWriter SerWriter, PreferredOutput = seq[byte] func init*(W: type SerWriter, stream: OutputStream): W = W(stream: stream) +type SerConf* = object + limit*: int + type SerReader* = object stream*: InputStream + conf*: SerConf Ser.setReader SerReader -func init*(R: type SerReader, stream: InputStream): R = - R(stream: stream) +func init*(R: type SerReader, stream: InputStream, conf = default(SerConf)): R = + R(stream: stream, conf: conf) type SerKind* {.pure.} = enum Int = 0