Skip to content

Commit

Permalink
Add flag to pass comma delimited arrays by REST client. (#8)
Browse files Browse the repository at this point in the history
Simplify server array decoding.
Add tests for comma delimited arrays.
  • Loading branch information
cheatfate committed May 17, 2021
1 parent 37b25ea commit 5163805
Show file tree
Hide file tree
Showing 3 changed files with 339 additions and 293 deletions.
57 changes: 41 additions & 16 deletions presto/client.nim
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type
session: HttpSessionRef
address*: HttpAddress
agent: string
flags: RestClientFlags

RestClientRef* = ref RestClient

Expand All @@ -36,6 +37,11 @@ type

RestStatus* = distinct int

RestClientFlag* {.pure.} = enum
CommaSeparatedArray

RestClientFlags* = set[RestClientFlag]

RestRequestFlag* {.pure.} = enum
ConsumeBody

Expand All @@ -60,15 +66,16 @@ proc `$`*(x: RestStatus): string {.borrow.}

proc new*(t: typedesc[RestClientRef],
url: string,
flags: HttpClientFlags = {},
flags: RestClientFlags = {},
httpFlags: HttpClientFlags = {},
maxConnections: int = -1,
maxRedirections: int = HttpMaxRedirections,
connectTimeout = HttpConnectTimeout,
headersTimeout = HttpHeadersTimeout,
bufferSize: int = 4096,
userAgent = PrestoIdent
): RestResult[RestClientRef] =
let session = HttpSessionRef.new(flags, maxRedirections, connectTimeout,
let session = HttpSessionRef.new(httpFlags, maxRedirections, connectTimeout,
headersTimeout, bufferSize, maxConnections)
var uri = parseUri(url)
uri.path = ""
Expand All @@ -81,23 +88,26 @@ proc new*(t: typedesc[RestClientRef],
if res.isErr():
return err("Unable to resolve remote hostname")
res.get()
ok(RestClientRef(session: session, address: address, agent: userAgent))
ok(RestClientRef(session: session, address: address, agent: userAgent,
flags: flags))

proc new*(t: typedesc[RestClientRef],
ta: TransportAddress,
scheme: HttpClientScheme = HttpClientScheme.NonSecure,
flags: HttpClientFlags = {},
flags: RestClientFlags = {},
httpFlags: HttpClientFlags = {},
maxConnections: int = -1,
maxRedirections: int = HttpMaxRedirections,
connectTimeout = HttpConnectTimeout,
headersTimeout = HttpHeadersTimeout,
bufferSize: int = 4096,
userAgent = PrestoIdent
): RestClientRef =
let session = HttpSessionRef.new(flags, maxRedirections, connectTimeout,
let session = HttpSessionRef.new(httpFlags, maxRedirections, connectTimeout,
headersTimeout, bufferSize, maxConnections)
let address = ta.getAddress(scheme, "")
RestClientRef(session: session, address: address, agent: userAgent)
RestClientRef(session: session, address: address, agent: userAgent,
flags: flags)

proc closeWait*(client: RestClientRef) {.async.} =
await client.session.closeWait()
Expand Down Expand Up @@ -674,16 +684,31 @@ proc restSingleProc(prc: NimNode): NimNode {.compileTime.} =
statements.add quote do:
let `encodedName` =
block:
var res: seq[string]
for item in `paramName`.items():
let eres = encodeString(item)
if eres.isErr():
raiseRestEncodingStringError(`paramLiteral`)
var sres = `paramLiteral`
sres.add('=')
sres.add(encodeUrl(eres.get(), true))
res.add(sres)
res.join("&")
if RestClientFlag.CommaSeparatedArray in `clientIdent`.flags:
var res: seq[string]
for item in `paramName`.items():
let eres = encodeString(item)
if eres.isErr():
raiseRestEncodingStringError(`paramLiteral`)
res.add(encodeUrl(eres.get(), true))
if len(res) > 0:
var sres = `paramLiteral`
sres.add('=')
sres.add(res.join(","))
sres
else:
""
else:
var res: seq[string]
for item in `paramName`.items():
let eres = encodeString(item)
if eres.isErr():
raiseRestEncodingStringError(`paramLiteral`)
var sres = `paramLiteral`
sres.add('=')
sres.add(encodeUrl(eres.get(), true))
res.add(sres)
res.join("&")
else:
statements.add quote do:
let `encodedName` =
Expand Down
2 changes: 1 addition & 1 deletion presto/route.nim
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ macro api*(router: RestRouter, meth: static[HttpMethod],
block:
var sres: seq[`seqType`]
var errorMsg: cstring = nil
for index, item in `queryParams`.getList(`strName`).pairs():
for item in `queryParams`.getList(`strName`).items():
let res = decodeString(`seqType`, item)
if res.isErr():
errorMsg = res.error()
Expand Down
Loading

0 comments on commit 5163805

Please sign in to comment.