diff --git a/.gitignore b/.gitignore index 106da18..1e69add 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,4 @@ *.out # Editor specifics -.vscode/* +.vscode diff --git a/.travis.yml b/.travis.yml index 0828349..4b30d1e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,21 +1,15 @@ language: go go: -- 1.12.x -- 1.13.4 +- 1.13.x +- 1.14.1 install: -- go get github.com/mattn/goveralls -- go get honnef.co/go/tools/cmd/staticcheck -- go get github.com/client9/misspell/cmd/misspell +- go get -t ./... +- ./scripts/install-checks.sh before_script: -- PACKAGES=$(go list ./...) -- go get -d ./... +- if [[ "$TRAVIS_GO_VERSION" =~ ^1\.14(\.[0-9]+)?$ ]]; then ./scripts/check.sh; fi - go build -- $(exit $(go fmt $PACKAGES | wc -l)) -- go vet $PACKAGES -- misspell -error -locale US . -- staticcheck $PACKAGES -- if [[ "$TRAVIS_GO_VERSION" =~ ^1\.13\. ]] && [ "$TRAVIS_TAG" != "" ]; then ./scripts/cross_compile.sh $TRAVIS_TAG; fi +- if [[ "$TRAVIS_GO_VERSION" =~ ^1\.14(\.[0-9]+)?$ ]] && [ "$TRAVIS_TAG" != "" ]; then ./scripts/cross_compile.sh $TRAVIS_TAG; fi script: -- if [[ "$TRAVIS_GO_VERSION" =~ ^1\.13\. ]]; then ./scripts/cover.sh TRAVIS; else go test -v -race $PACKAGES; fi +- if [[ "$TRAVIS_GO_VERSION" =~ ^1\.14(\.[0-9]+)?$ ]]; then ./scripts/cover.sh TRAVIS; else go test -v -race $PACKAGES; fi after_success: -- if [[ "$TRAVIS_GO_VERSION" =~ ^1\.13\. ]] && [ "$TRAVIS_TAG" != "" ]; then ghr --owner resgateio --token $GITHUB_TOKEN --draft --replace $TRAVIS_TAG pkg/; fi +- if [[ "$TRAVIS_GO_VERSION" =~ ^1\.14(\.[0-9]+)?$ ]] && [ "$TRAVIS_TAG" != "" ]; then ghr --owner resgateio --token $GITHUB_TOKEN --draft --replace $TRAVIS_TAG pkg/; fi diff --git a/README.md b/README.md index 9c60b4d..bd287ff 100644 --- a/README.md +++ b/README.md @@ -34,8 +34,6 @@ Whenever there is a change to the data, the responsible micro-service sends an e ## Quickstart -### Docker - If you install Docker, it is easy to run both *NATS server* and *Resgate* as containers: ```text @@ -44,37 +42,16 @@ docker run -d --name nats -p 4222:4222 --net res nats docker run --name resgate -p 8080:8080 --net res resgateio/resgate --nats nats://nats:4222 ``` -Both images are small, less than 10 MB each. - -### Download - -Another way to install *Resgate* and *NATS Server* is to download one of the pre-built binaries: - -* [Download](https://nats.io/download/nats-io/nats-server/) and run NATS Server -* [Download](https://resgate.io/download/) and run Resgate +Both images are small, about 10 MB each. -### Building - -If you wish to build your own binaries, first make sure you have: -* [installed Go](https://golang.org/doc/install) and [set your `$GOPATH`](https://golang.org/cmd/go/#hdr-GOPATH_environment_variable) -* added `$GOPATH/bin` (where your binaries are stored) to your `PATH` -* [installed node.js](https://nodejs.org/en/download/) (for the test app) - -Install and run [NATS server](https://nats-io.github.io/docs/nats_server/installation.html) and Resgate: -```bash -go get github.com/nats-io/nats-server -nats-server -``` -```bash -go get github.com/resgateio/resgate -resgate -``` +See [Resgate.io - Installation](https://resgate.io/docs/get-started/installation/) for other ways of installation. ## Examples -While Resgate may be used with any language, the examples in this repository are written in Javascript for Node.js. +While Resgate may be used with any language, the examples in this repository are written in Javascript for Node.js, without using any additional library. -For examples in other languages, visit [Resgate.io - Examples](https://resgate.io/docs/get-started/examples/). +* For Go (golang) examples, see [go-res package](https://github.com/jirenius/go-res) +* For C# (NETCore) examples, see [RES Service for .NET](https://github.com/jirenius/csharp-res) | Example | Description | --- | --- @@ -101,22 +78,43 @@ For more in depth information on the protocol: ``` resgate [options] ``` -| Option | Description | -|---|---| -| `-n, --nats ` | NATS Server URL | -| `-i, --addr ` | Bind to HOST address | -| `-p, --port ` | Use port for clients | -| `-w, --wspath ` | Path to WebSocket | -| `-a, --apipath ` | Path to web resources | -| `-r, --reqtimeout ` | Timeout duration for NATS requests | -| `-u, --headauth ` | Resource method for header authentication | -| ` --tls` | Enable TLS | -| ` --tlscert ` | Server certificate file | -| ` --tlskey ` | Private key for server certificate | -| ` --apiencoding ` | Encoding for web resources: json, jsonflat | -| `-c, --config ` | Configuration file | -| `-h, --help` | Show usage message | -| `-v, --version` | Show version | + +### Server options + +| Option | Description | Default value +| --- | --- | --- +| `-n`, `--nats ` | NATS Server URL | `nats://127.0.0.1:4222` +| `-i`, `--addr ` | Bind to HOST address | `0.0.0.0` +| `-p`, `--port ` | HTTP port for client connections | `8080` +| `-w`, `--wspath ` | WebSocket path for clients | `/` +| `-a`, `--apipath ` | Web resource path for clients | `/api/` +| `-r`, `--reqtimeout ` | Timeout duration for NATS requests | `3000` +| `-u`, `--headauth ` | Resource method for header authentication | +| ` --tls` | Enable TLS for HTTP | `false` +| ` --tlscert ` | HTTP server certificate file | +| ` --tlskey ` | Private key for HTTP server certificate | +| ` --apiencoding ` | Encoding for web resources: json, jsonflat | `json` +| ` --creds ` | NATS User Credentials file | +| ` --alloworigin ` | Allowed origin(s): *, or \://\\[:\\] | `*` +| ` --putmethod ` | Call method name mapped to HTTP PUT requests | +| ` --deletemethod ` | Call method name mapped to HTTP DELETE requests | +| ` --patchmethod ` | Call method name mapped to HTTP PATCH requests | +| `-c`, `--config ` | Configuration file in JSON format | + +### Logging options + +| Option | Description +| --- | --- +| `-D`, `--debug` | Enable debugging output +| `-V`, `--trace` | Enable trace logging +| `-DV` | Debug and trace + +### Common options + +| Option | Description +| --- | --- +| `-h`, `--help` | Show usage message +| `-v`, `--version` | Show version ## Configuration @@ -126,39 +124,61 @@ Configuration is a JSON encoded file. If no config file is found at the given pa ```javascript { - // URL to the NATS server - "natsUrl": "nats://127.0.0.1:4222", - // Timeout in milliseconds for NATS requests - "requestTimeout": 3000, - // Bind to HOST IPv4 or IPv6 address - // Empty string ("") means all IPv4 and IPv6 addresses. - // Invalid or missing IP address defaults to 0.0.0.0. - "addr": "0.0.0.0", - // Port for the http server to listen on. - // If the port value is missing or 0, standard http(s) port is used. - "port": 8080, - // Path for accessing the RES API WebSocket - "wsPath": "/", - // Path for accessing web resources - "apiPath": "/api", - // Encoding for web resources. - // Available encodings are: - // * json - JSON encoding with resource reference meta data - // * jsonflat - JSON encoding without resource reference meta data - "apiEncoding": "json", - // Header authentication resource method for web resources. - // Prior to accessing the resource, this resource method will be - // called, allowing an auth service to set a token using - // information such as the request headers. - // Missing value or null will disable header authentication. - // Eg. "authService.headerLogin" - "headerAuth": null, - // Flag telling if tls encryption is enabled - "tls": false, - // Certificate file path for tls encryption - "tlsCert": "", - // Key file path for tls encryption - "tlsKey": "" + // URL to the NATS server. + "natsUrl": "nats://127.0.0.1:4222", + // NATS User Credentials file path. + // Eg. "ngs.creds" + "natsCreds": null, + // Timeout in milliseconds for NATS requests + "requestTimeout": 3000, + // Bind to HOST IPv4 or IPv6 address. + // Empty string ("") means all IPv4 and IPv6 addresses. + // Invalid or missing IP address defaults to 0.0.0.0. + "addr": "0.0.0.0", + // Port for the http server to listen on. + // If the port value is missing or 0, standard http(s) port is used. + "port": 8080, + // Path for accessing the RES API WebSocket. + "wsPath": "/", + // Path prefix for accessing web resources. + "apiPath": "/api", + // Encoding for web resources. + // Available encodings are: + // * json - JSON encoding with resource reference meta data. + // * jsonflat - JSON encoding without resource reference meta data. + "apiEncoding": "json", + // Flag enabling WebSocket per message compression (RFC 7692). + "wsCompression": false, + // Call method name to map HTTP PUT method requests to. + // Eg. "put" + "putMethod": null, + // Call method name to map HTTP DELETE method requests to. + // Eg. "delete" + "deleteMethod": null, + // Call method name to map HTTP PATCH method requests to. + // Eg. "patch" + "patchMethod": null, + // Header authentication resource method for web resources. + // Prior to accessing the resource, this resource method will be + // called, allowing an auth service to set a token using + // information such as the request headers. + // Missing value or null will disable header authentication. + // Eg. "authService.headerLogin" + "headerAuth": null, + // Flag enabling tls encryption. + "tls": false, + // Certificate file path for tls encryption. + "tlsCert": "", + // Key file path for tls encryption. + "tlsKey": "", + // Allowed origin for CORS requests, or * to allow all origins. + // Multiple origins are separated by semicolon. + // Eg. "https://example.com;https://api.example.com" + "allowOrigin": "*", + // Flag enabling debug logging. + "debug": false, + // Flag enabling trace logging. + "trace": false } ``` diff --git a/docker/Dockerfile b/docker/Dockerfile index 5e18e19..387fc93 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.13.4-alpine3.10 as builder +FROM golang:1.14.1-alpine3.11 as builder LABEL maintainer="Samuel Jirenius " diff --git a/docker/Dockerfile.alpine b/docker/Dockerfile.alpine index e4501eb..50bab48 100644 --- a/docker/Dockerfile.alpine +++ b/docker/Dockerfile.alpine @@ -1,4 +1,4 @@ -FROM golang:1.13.4-alpine3.10 as builder +FROM golang:1.14.1-alpine3.11 as builder LABEL maintainer="Samuel Jirenius " @@ -15,7 +15,7 @@ COPY . . RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -v -ldflags "-s -w" -o /resgate -FROM alpine:3.10 +FROM alpine:3.11 COPY --from=builder /resgate /bin/resgate EXPOSE 8080 diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 4460572..f3680f6 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -10,6 +10,8 @@ All changes to the RES Protocol will be documented in this file. * #136 Added RES client *version* request. * #135 Resource response on call and auth requests. +See [v1.2 update page](res-protocol-v1.2-update.md) for more info. + ## v1.1.1 - [Resgate v1.3.0](compare/v1.2.2...v1.3.0) - 2019-10-02 * #110 Allow query on non-query requests. @@ -17,10 +19,11 @@ All changes to the RES Protocol will be documented in this file. * #113 Added [RES Protocol Semantic Versioning](blob/v1.3.0/docs/res-protocol-semver.md). ## v1.1.0 - [Resgate v1.2.1](compare/v1.2.0...v1.2.1) - 2019-08-05 -See [v1.1 update page](docs/res-protocol-v1.1-update.md) for more info. * #68 Props field on change event. +See [v1.1 update page](res-protocol-v1.1-update.md) for more info. + ## v1.0.0 - [Resgate v1.0.0](tree/v1.0.0) - 2018-09-22 * Initial release. \ No newline at end of file diff --git a/docs/res-protocol.md b/docs/res-protocol.md index 741cabc..f2cd9d7 100644 --- a/docs/res-protocol.md +++ b/docs/res-protocol.md @@ -135,7 +135,7 @@ If a gateway loses the connection to a client, it will make no attempt at recove A client is the application that accesses the API by connecting to a gateway's WebSocket. While it may be possible to access the API resources through HTTP requests, any reference in these documentations to *client* implies a client using the WebSocket. -A client uses the [RES-service protocol](res-service-protocol.md) for communication. +A client uses the [RES-client protocol](res-client-protocol.md) for communication. ## Connection IDs diff --git a/docs/res-service-protocol.md b/docs/res-service-protocol.md index cf17795..17dc038 100644 --- a/docs/res-service-protocol.md +++ b/docs/res-service-protocol.md @@ -402,6 +402,7 @@ A delete action is a JSON object used when a property has been deleted from a mo `event..add` Add events are sent when a value is added to a [collection](res-protocol.md#collections). +Any previous value at the same index or higher will implicitly be shifted one step to a higher index. MUST NOT be sent on [models](res-protocol.md#models). The event payload has the following parameters: @@ -410,7 +411,7 @@ The event payload has the following parameters: **idx** Zero-based index number of where the value is inserted. -MUST be a number that is zero or greater and less than the length of the collection. +MUST be a number that is zero or greater and less than or equal to the length of the collection. **Example payload** ```json @@ -426,6 +427,7 @@ MUST be a number that is zero or greater and less than the length of the collect `event..remove` Remove events are sent when a value is removed from a [collection](res-protocol.md#collections). +Any previous value at a higher index will implicitly be shifted one step to a lower index. MUST NOT be sent on [models](res-protocol.md#models). The event payload has the following parameter: diff --git a/examples/edit-text/package-lock.json b/examples/edit-text/package-lock.json index f8c3a7e..c4a15ea 100644 --- a/examples/edit-text/package-lock.json +++ b/examples/edit-text/package-lock.json @@ -553,9 +553,9 @@ } }, "minimist": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.0.tgz", - "integrity": "sha1-o1AIsg9BOD7sH7kU9M1d95omQoQ=", + "version": "1.2.5", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.5.tgz", + "integrity": "sha512-FM9nNUYrRBAELZQT3xeZQ7fmMOBg6nWNmJKTcgsJeaLstP/UODVpGsr5OhXhhXg6f+qtJ8uiZ+PUxkDWcgIXLw==", "dev": true }, "ms": { diff --git a/examples/hello-world/package-lock.json b/examples/hello-world/package-lock.json index bff5bee..a956e04 100644 --- a/examples/hello-world/package-lock.json +++ b/examples/hello-world/package-lock.json @@ -553,9 +553,9 @@ } }, "minimist": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.0.tgz", - "integrity": "sha1-o1AIsg9BOD7sH7kU9M1d95omQoQ=", + "version": "1.2.5", + "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.5.tgz", + "integrity": "sha512-FM9nNUYrRBAELZQT3xeZQ7fmMOBg6nWNmJKTcgsJeaLstP/UODVpGsr5OhXhhXg6f+qtJ8uiZ+PUxkDWcgIXLw==", "dev": true }, "ms": { diff --git a/go.mod b/go.mod index c890f93..1c55909 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,11 @@ module github.com/resgateio/resgate go 1.13 require ( - github.com/gorilla/websocket v1.4.1 + github.com/gorilla/websocket v1.4.2 github.com/jirenius/timerqueue v1.0.0 - github.com/nats-io/nats-server/v2 v2.1.0 // indirect - github.com/nats-io/nats.go v1.8.1 - github.com/posener/wstest v0.0.0-20180217133618-28272a7ea048 + github.com/nats-io/nats-server/v2 v2.1.4 // indirect + github.com/nats-io/nats.go v1.9.1 + github.com/posener/wstest v1.2.0 github.com/rs/xid v1.2.1 - github.com/stretchr/testify v1.4.0 // indirect - golang.org/x/crypto v0.0.0-20191001170739-f9e2070545dc // indirect + golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6 // indirect ) diff --git a/go.sum b/go.sum index 6c7e834..733a3f4 100644 --- a/go.sum +++ b/go.sum @@ -2,38 +2,39 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/jirenius/timerqueue v1.0.0 h1:TgcUQlrxKBBHYmStXPzLdMPJFfmqkWZZ1s7BA5G1d9E= github.com/jirenius/timerqueue v1.0.0/go.mod h1:pUEjy16BUruJMjLIsjWvWQh9Bu9CSXCIfGADZf37WIk= github.com/nats-io/jwt v0.3.0 h1:xdnzwFETV++jNc4W1mw//qFyJGb2ABOombmZJQS4+Qo= github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg= -github.com/nats-io/nats-server/v2 v2.1.0 h1:Yi0+ZhRPtPAGeIxFn5erIeJIV9wXA+JznfSxK621Fbk= -github.com/nats-io/nats-server/v2 v2.1.0/go.mod h1:r5y0WgCag0dTj/qiHkHrXAcKQ/f5GMOZaEGdoxxnJ4I= -github.com/nats-io/nats.go v1.8.1 h1:6lF/f1/NN6kzUDBz6pyvQDEXO39jqXcWRLu/tKjtOUQ= -github.com/nats-io/nats.go v1.8.1/go.mod h1:BrFz9vVn0fU3AcH9Vn4Kd7W0NpJ651tD5omQ3M8LwxM= -github.com/nats-io/nkeys v0.0.2 h1:+qM7QpgXnvDDixitZtQUBDY9w/s9mu1ghS+JIbsrx6M= -github.com/nats-io/nkeys v0.0.2/go.mod h1:dab7URMsZm6Z/jp9Z5UGa87Uutgc2mVpXLC4B7TDb/4= +github.com/nats-io/jwt v0.3.2 h1:+RB5hMpXUUA2dfxuhBTEkMOrYmM+gKIZYS1KjSostMI= +github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU= +github.com/nats-io/nats-server/v2 v2.1.4 h1:BILRnsJ2Yb/fefiFbBWADpViGF69uh4sxe8poVDQ06g= +github.com/nats-io/nats-server/v2 v2.1.4/go.mod h1:Jw1Z28soD/QasIA2uWjXyM9El1jly3YwyFOuR8tH1rg= +github.com/nats-io/nats.go v1.9.1 h1:ik3HbLhZ0YABLto7iX80pZLPw/6dx3T+++MZJwLnMrQ= +github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzEE/Zbp4w= github.com/nats-io/nkeys v0.1.0 h1:qMd4+pRHgdr1nAClu+2h/2a5F2TmKcCzjCDazVgRoX4= github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= +github.com/nats-io/nkeys v0.1.3 h1:6JrEfig+HzTH85yxzhSVbjHRJv9cn0p6n3IngIcM5/k= +github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/posener/wstest v0.0.0-20180217133618-28272a7ea048 h1:XJ1bEwzKDbW33q703QCy580ZEqT2/hXTrU5sUYZf5LI= -github.com/posener/wstest v0.0.0-20180217133618-28272a7ea048/go.mod h1:cjC8eRbwXrr5m2069dsjp7l7b0gWqFwMTUBDLNvVqho= +github.com/posener/wstest v1.2.0 h1:PAY0cRybxOjh0yqSDCrlAGUwtx+GNKpuUfid/08pv48= +github.com/posener/wstest v1.2.0/go.mod h1:GkplCx9zskpudjrMp23LyZHrSonab0aZzh2x0ACGRbU= github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9 h1:mKdxBk7AujPs8kU4m80U72y/zjbZ3UcXC7dClwKbUI0= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 h1:HuIa8hRrWRSrqYzx1qI49NNxhdi2PrY7gxVSq1JjLDc= golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191001170739-f9e2070545dc h1:KyTYo8xkh/2WdbFLUyQwBS0Jfn3qfZ9QmuPbok2oENE= -golang.org/x/crypto v0.0.0-20191001170739-f9e2070545dc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6 h1:TjszyFsQsyZNHwdVdZ5m7bjmreu0znc2kRYsEml9/Ww= +golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/main.go b/main.go index 02fe9e1..2a0ee63 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "os" "os/signal" + "strings" "syscall" "time" @@ -43,6 +44,10 @@ Server Options: --tlskey Private key for HTTP server certificate --apiencoding Encoding for web resources: json, jsonflat (default: json) --creds NATS User Credentials file + --alloworigin Allowed origin(s): *, or ://[:] (default: *) + --putmethod Call method name mapped to HTTP PUT requests + --deletemethod Call method name mapped to HTTP DELETE requests + --patchmethod Call method name mapped to HTTP PATCH requests -c, --config Configuration file Logging Options: @@ -67,6 +72,22 @@ type Config struct { server.Config } +// StringSlice is a slice of strings implementing the flag.Value interface. +type StringSlice []string + +func (s *StringSlice) String() string { + if s == nil { + return "" + } + return strings.Join(*s, ";") +} + +// Set adds a value to the slice. +func (s *StringSlice) Set(v string) error { + *s = append(*s, v) + return nil +} + // SetDefault sets the default values func (c *Config) SetDefault() { if c.NatsURL == "" { @@ -82,14 +103,18 @@ func (c *Config) SetDefault() { // If no file exists, a new file with default settings is created func (c *Config) Init(fs *flag.FlagSet, args []string) { var ( - showHelp bool - showVersion bool - configFile string - port uint - headauth string - addr string - natsCreds string - debugTrace bool + showHelp bool + showVersion bool + configFile string + port uint + headauth string + addr string + natsCreds string + debugTrace bool + allowOrigin StringSlice + putMethod string + deleteMethod string + patchMethod string ) fs.BoolVar(&showHelp, "h", false, "Show this message.") @@ -115,6 +140,10 @@ func (c *Config) Init(fs *flag.FlagSet, args []string) { fs.IntVar(&c.RequestTimeout, "r", 0, "Timeout in milliseconds for NATS requests.") fs.IntVar(&c.RequestTimeout, "reqtimeout", 0, "Timeout in milliseconds for NATS requests.") fs.StringVar(&natsCreds, "creds", "", "NATS User Credentials file.") + fs.Var(&allowOrigin, "alloworigin", "Allowed origin(s) for CORS.") + fs.StringVar(&putMethod, "putmethod", "", "Call method name mapped to HTTP PUT requests.") + fs.StringVar(&deleteMethod, "deletemethod", "", "Call method name mapped to HTTP DELETE requests.") + fs.StringVar(&patchMethod, "patchmethod", "", "Call method name mapped to HTTP PATCH requests.") fs.BoolVar(&c.Debug, "D", false, "Enable debugging output.") fs.BoolVar(&c.Debug, "debug", false, "Enable debugging output.") fs.BoolVar(&c.Trace, "V", false, "Enable trace logging.") @@ -139,6 +168,7 @@ func (c *Config) Init(fs *flag.FlagSet, args []string) { version() } + writeConfig := false if configFile != "" { fin, err := ioutil.ReadFile(configFile) if err != nil { @@ -147,13 +177,7 @@ func (c *Config) Init(fs *flag.FlagSet, args []string) { } c.SetDefault() - - fout, err := json.MarshalIndent(c, "", "\t") - if err != nil { - printAndDie(fmt.Sprintf("Error encoding config: %s", err), false) - } - - ioutil.WriteFile(configFile, fout, os.FileMode(0664)) + writeConfig = true } else { err = json.Unmarshal(fin, c) if err != nil { @@ -169,22 +193,31 @@ func (c *Config) Init(fs *flag.FlagSet, args []string) { c.Port = uint16(port) } + // Helper function to set string pointers to nil if empty. + setString := func(v string, s **string) { + if v == "" { + *s = nil + } else { + *s = &v + } + } fs.Visit(func(f *flag.Flag) { switch f.Name { case "u": fallthrough case "headauth": - if headauth == "" { - c.HeaderAuth = nil - } else { - c.HeaderAuth = &headauth - } + setString(headauth, &c.HeaderAuth) case "creds": - if natsCreds == "" { - c.NatsCreds = nil - } else { - c.NatsCreds = &natsCreds - } + setString(natsCreds, &c.NatsCreds) + case "alloworigin": + str := allowOrigin.String() + c.AllowOrigin = &str + case "putmethod": + setString(putMethod, &c.PUTMethod) + case "deletemethod": + setString(deleteMethod, &c.DELETEMethod) + case "patchmethod": + setString(patchMethod, &c.PATCHMethod) case "i": fallthrough case "addr": @@ -197,6 +230,15 @@ func (c *Config) Init(fs *flag.FlagSet, args []string) { // Any value not set, set it now c.SetDefault() + + // Write config file + if writeConfig { + fout, err := json.MarshalIndent(c, "", "\t") + if err != nil { + printAndDie(fmt.Sprintf("Error encoding config: %s", err), false) + } + ioutil.WriteFile(configFile, fout, os.FileMode(0664)) + } } // usage will print out the flag options for the server. diff --git a/nats/nats.go b/nats/nats.go index 2501607..ee35ffc 100644 --- a/nats/nats.go +++ b/nats/nats.go @@ -26,7 +26,7 @@ type Client struct { mq *nats.Conn mqCh chan *nats.Msg - mqReqs map[*nats.Subscription]responseCont + mqReqs map[*nats.Subscription]*responseCont tq *timerqueue.Queue mu sync.Mutex closeHandler func(error) @@ -85,7 +85,7 @@ func (c *Client) Connect() error { c.mq = nc c.mqCh = make(chan *nats.Msg, natsChannelSize) - c.mqReqs = make(map[*nats.Subscription]responseCont) + c.mqReqs = make(map[*nats.Subscription]*responseCont) c.tq = timerqueue.New(c.onTimeout, c.RequestTimeout) c.stopped = make(chan struct{}) @@ -126,7 +126,7 @@ func (c *Client) Close() { c.mq = nil // Set mqReqs to empty map to avoid possible nil reference error in listener - c.mqReqs = make(map[*nats.Subscription]responseCont) + c.mqReqs = make(map[*nats.Subscription]*responseCont) c.tq.Clear() c.tq = nil @@ -175,7 +175,7 @@ func (c *Client) SendRequest(subj string, payload []byte, cb mq.Response) { } c.tq.Add(sub) - c.mqReqs[sub] = responseCont{isReq: true, f: cb} + c.mqReqs[sub] = &responseCont{isReq: true, f: cb} } // Subscribe to all events on a resource namespace. @@ -191,7 +191,7 @@ func (c *Client) Subscribe(namespace string, cb mq.Response) (mq.Unsubscriber, e c.Tracef("S=> %s", sub.Subject) - c.mqReqs[sub] = responseCont{f: cb} + c.mqReqs[sub] = &responseCont{f: cb} us := &Subscription{c: c, sub: sub} return us, nil @@ -244,7 +244,7 @@ func (c *Client) listener(ch chan *nats.Msg, stopped chan struct{}) { close(stopped) } -func (c *Client) parseMeta(msg *nats.Msg, rc responseCont) { +func (c *Client) parseMeta(msg *nats.Msg, rc *responseCont) { tag := reflect.StructTag(msg.Data) // timeout tag diff --git a/scripts/check.sh b/scripts/check.sh new file mode 100755 index 0000000..21a04ca --- /dev/null +++ b/scripts/check.sh @@ -0,0 +1,16 @@ +#!/bin/bash -e +# Run from directory above via ./scripts/check.sh + +echo "Checking formatting..." +if [ -n "$(gofmt -s -l .)" ]; then + echo "Code is not formatted. Run 'gofmt -s -w .'" + exit 1 +fi +echo "Checking with go vet..." +go vet ./... +echo "Checking with staticcheck..." +staticcheck ./... +echo "Checking with golint..." +golint -set_exit_status ./... +echo "Checking with misspell..." +misspell -error -locale US . diff --git a/scripts/install-checks.sh b/scripts/install-checks.sh new file mode 100755 index 0000000..52b6644 --- /dev/null +++ b/scripts/install-checks.sh @@ -0,0 +1,8 @@ +#!/bin/bash -e + +pushd /tmp > /dev/null +go get -u github.com/mattn/goveralls +go get -u honnef.co/go/tools/cmd/staticcheck +go get -u golang.org/x/lint/golint +go get -u github.com/client9/misspell/cmd/misspell +popd > /dev/null diff --git a/server/apiHandler.go b/server/apiHandler.go index 2df4481..7dab2a8 100644 --- a/server/apiHandler.go +++ b/server/apiHandler.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "mime" "net/http" "strings" @@ -21,10 +22,48 @@ func (s *Service) initAPIHandler() error { return fmt.Errorf("invalid apiEncoding setting (%s) - available encodings: %s", s.cfg.APIEncoding, strings.Join(keys, ", ")) } s.enc = f(s.cfg) + mimetype, _, err := mime.ParseMediaType(s.enc.ContentType()) + s.mimetype = mimetype + return err +} + +// setCommonHeaders sets common headers such as Access-Control-*. +// It returns error if the origin header does not match any allowed origin. +func (s *Service) setCommonHeaders(w http.ResponseWriter, r *http.Request) error { + if s.cfg.allowOrigin[0] == "*" { + w.Header().Set("Access-Control-Allow-Origin", "*") + return nil + } + + // CORS validation + origin := r.Header["Origin"] + // If no Origin header is set, or the value is null, we can allow access + // as it is not coming from a CORS enabled browser. + if len(origin) > 0 && origin[0] != "null" { + if matchesOrigins(s.cfg.allowOrigin, origin[0]) { + w.Header().Set("Access-Control-Allow-Origin", origin[0]) + w.Header().Set("Vary", "Origin") + } else { + // No matching origin + w.Header().Set("Access-Control-Allow-Origin", s.cfg.allowOrigin[0]) + w.Header().Set("Vary", "Origin") + return reserr.ErrForbiddenOrigin + } + } return nil } func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) { + err := s.setCommonHeaders(w, r) + if r.Method == "OPTIONS" { + w.Header().Set("Access-Control-Allow-Methods", s.cfg.allowMethods) + return + } + if err != nil { + httpError(w, err, s.enc) + return + } + path := r.URL.RawPath if path == "" { path = r.URL.Path @@ -32,15 +71,18 @@ func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) { apiPath := s.cfg.APIPath + // NotFound on oaths with trailing slash (unless it is only the APIPath) + if len(path) > len(apiPath) && path[len(path)-1] == '/' { + notFoundHandler(w, r, s.enc) + return + } + + var rid, action string switch r.Method { + case "HEAD": + fallthrough case "GET": - // Redirect paths with trailing slash (unless it is only the APIPath) - if len(path) > len(apiPath) && path[len(path)-1] == '/' { - notFoundHandler(w, r, s.enc) - return - } - - rid := PathToRID(path, r.URL.RawQuery, apiPath) + rid = PathToRID(path, r.URL.RawQuery, apiPath) if !codec.IsValidRID(rid, true) { notFoundHandler(w, r, s.enc) return @@ -55,59 +97,79 @@ func (s *Service) apiHandler(w http.ResponseWriter, r *http.Request) { cb(s.enc.EncodeGET(sub)) }) }) + return case "POST": - // Redirect paths with trailing slash (unless it is only the APIPath) - if len(path) > len(apiPath) && path[len(path)-1] == '/' { - notFoundHandler(w, r, s.enc) - return + rid, action = PathToRIDAction(path, r.URL.RawQuery, apiPath) + default: + var m *string + switch r.Method { + case "PUT": + if s.cfg.PUTMethod != nil { + m = s.cfg.PUTMethod + } + case "DELETE": + if s.cfg.DELETEMethod != nil { + m = s.cfg.DELETEMethod + } + case "PATCH": + if s.cfg.PATCHMethod != nil { + m = s.cfg.PATCHMethod + } } - - rid, action := PathToRIDAction(path, r.URL.RawQuery, apiPath) - if !codec.IsValidRID(rid, true) || !codec.IsValidRID(action, false) { - notFoundHandler(w, r, s.enc) + // Return error if we have no mapping for the method + if m == nil { + httpError(w, reserr.ErrMethodNotAllowed, s.enc) return } + rid = PathToRID(path, r.URL.RawQuery, apiPath) + action = *m + } + + s.handleCall(w, r, rid, action) +} - // Try to parse the body - b, err := ioutil.ReadAll(r.Body) +func notFoundHandler(w http.ResponseWriter, r *http.Request, enc APIEncoder) { + w.Header().Set("Content-Type", enc.ContentType()) + w.WriteHeader(http.StatusNotFound) + w.Write(enc.NotFoundError()) +} + +func (s *Service) handleCall(w http.ResponseWriter, r *http.Request, rid string, action string) { + if !codec.IsValidRID(rid, true) || !codec.IsValidRIDPart(action) { + notFoundHandler(w, r, s.enc) + return + } + + // Try to parse the body + b, err := ioutil.ReadAll(r.Body) + if err != nil { + httpError(w, &reserr.Error{Code: reserr.CodeBadRequest, Message: "Error reading request body: " + err.Error()}, s.enc) + return + } + + var params json.RawMessage + if strings.TrimSpace(string(b)) != "" { + err = json.Unmarshal(b, ¶ms) if err != nil { - httpError(w, &reserr.Error{Code: reserr.CodeBadRequest, Message: "Error reading request body: " + err.Error()}, s.enc) + httpError(w, &reserr.Error{Code: reserr.CodeBadRequest, Message: "Error decoding request body: " + err.Error()}, s.enc) return } + } - var params json.RawMessage - if strings.TrimSpace(string(b)) != "" { - err = json.Unmarshal(b, ¶ms) + s.temporaryConn(w, r, func(c *wsConn, cb func([]byte, error)) { + c.CallHTTPResource(rid, s.cfg.APIPath, action, params, func(r json.RawMessage, href string, err error) { if err != nil { - httpError(w, &reserr.Error{Code: reserr.CodeBadRequest, Message: "Error decoding request body: " + err.Error()}, s.enc) - return + cb(nil, err) + } else if href != "" { + w.Header().Set("Location", href) + w.WriteHeader(http.StatusOK) + cb(nil, nil) + } else { + cb(s.enc.EncodePOST(r)) } - } - - s.temporaryConn(w, r, func(c *wsConn, cb func([]byte, error)) { - c.CallHTTPResource(rid, s.cfg.APIPath, action, params, func(r json.RawMessage, href string, err error) { - if err != nil { - cb(nil, err) - } else if href != "" { - w.Header().Set("Location", href) - w.WriteHeader(http.StatusOK) - cb(nil, nil) - } else { - cb(s.enc.EncodePOST(r)) - } - }) }) - - default: - httpError(w, reserr.ErrMethodNotAllowed, s.enc) - } -} - -func notFoundHandler(w http.ResponseWriter, r *http.Request, enc APIEncoder) { - w.Header().Set("Content-Type", enc.ContentType()) - w.WriteHeader(http.StatusNotFound) - w.Write(enc.NotFoundError()) + }) } func (s *Service) temporaryConn(w http.ResponseWriter, r *http.Request, cb func(*wsConn, func([]byte, error))) { @@ -123,6 +185,13 @@ func (s *Service) temporaryConn(w http.ResponseWriter, r *http.Request, cb func( defer close(done) if err != nil { + // Convert system.methodNotFound to system.methodNotAllowed for PUT/DELETE/PATCH + if rerr, ok := err.(*reserr.Error); ok { + if rerr.Code == reserr.CodeMethodNotFound && (r.Method == "PUT" || r.Method == "DELETE" || r.Method == "PATCH") { + httpError(w, reserr.ErrMethodNotAllowed, s.enc) + return + } + } httpError(w, err, s.enc) return } @@ -149,7 +218,6 @@ func (s *Service) temporaryConn(w http.ResponseWriter, r *http.Request, cb func( func httpError(w http.ResponseWriter, err error, enc APIEncoder) { rerr := reserr.RESError(err) - out := enc.EncodeError(rerr) var code int switch rerr.Code { @@ -167,11 +235,13 @@ func httpError(w http.ResponseWriter, err error, enc APIEncoder) { code = http.StatusInternalServerError case reserr.CodeServiceUnavailable: code = http.StatusServiceUnavailable + case reserr.CodeForbidden: + code = http.StatusForbidden default: code = http.StatusBadRequest } w.Header().Set("Content-Type", enc.ContentType()) w.WriteHeader(code) - w.Write(out) + w.Write(enc.EncodeError(rerr)) } diff --git a/server/codec/codec.go b/server/codec/codec.go index 059db97..f44959d 100644 --- a/server/codec/codec.go +++ b/server/codec/codec.go @@ -629,3 +629,13 @@ func IsValidRID(rid string, allowQuery bool) bool { return !start } + +// IsValidRIDPart returns true if the RID part is valid, otherwise false. +func IsValidRIDPart(part string) bool { + for _, r := range part { + if r < 33 || r > 126 || r == '.' || r == '*' || r == '>' || r == '?' { + return false + } + } + return len(part) > 0 +} diff --git a/server/config.go b/server/config.go index 790be08..1454b2e 100644 --- a/server/config.go +++ b/server/config.go @@ -1,21 +1,29 @@ package server import ( + "errors" "fmt" "net" + "net/url" + "sort" "strings" + "unicode/utf8" "github.com/resgateio/resgate/server/codec" ) // Config holds server configuration type Config struct { - Addr *string `json:"addr"` - Port uint16 `json:"port"` - WSPath string `json:"wsPath"` - APIPath string `json:"apiPath"` - APIEncoding string `json:"apiEncoding"` - HeaderAuth *string `json:"headerAuth"` + Addr *string `json:"addr"` + Port uint16 `json:"port"` + WSPath string `json:"wsPath"` + APIPath string `json:"apiPath"` + APIEncoding string `json:"apiEncoding"` + HeaderAuth *string `json:"headerAuth"` + AllowOrigin *string `json:"allowOrigin"` + PUTMethod *string `json:"putMethod"` + DELETEMethod *string `json:"deleteMethod"` + PATCHMethod *string `json:"patchMethod"` TLS bool `json:"tls"` TLSCert string `json:"certFile"` @@ -29,6 +37,8 @@ type Config struct { netAddr string headerAuthRID string headerAuthAction string + allowOrigin []string + allowMethods string } // SetDefault sets the default values @@ -49,6 +59,10 @@ func (c *Config) SetDefault() { if c.APIEncoding == "" { c.APIEncoding = DefaultAPIEncoding } + if c.AllowOrigin == nil { + origin := "*" + c.AllowOrigin = &origin + } } // prepare sets the unexported values @@ -79,7 +93,7 @@ func (c *Config) prepare() error { c.netAddr = ip.String() } } else { - return fmt.Errorf("invalid addr setting (%s) - must be a valid IPv4 or IPv6 address", s) + return fmt.Errorf("invalid addr setting (%s)\n\tmust be a valid IPv4 or IPv6 address", s) } } } else { @@ -94,9 +108,40 @@ func (c *Config) prepare() error { c.headerAuthRID = s[:idx] c.headerAuthAction = s[idx+1:] } else { - return fmt.Errorf("invalid headerAuth setting (%s) - must be a valid resource method", s) + return fmt.Errorf("invalid headerAuth setting (%s)\n\tmust be a valid resource method", s) } } + + if c.AllowOrigin != nil { + c.allowOrigin = strings.Split(*c.AllowOrigin, ";") + if err := validateAllowOrigin(c.allowOrigin); err != nil { + return fmt.Errorf("invalid allowOrigin setting (%s)\n\t%s\n\tvalid options are *, or a list of semi-colon separated origins", *c.AllowOrigin, err) + } + sort.Strings(c.allowOrigin) + } else { + c.allowOrigin = []string{"*"} + } + + c.allowMethods = "GET, HEAD, OPTIONS, POST" + if c.PUTMethod != nil { + if !codec.IsValidRIDPart(*c.PUTMethod) { + return fmt.Errorf("invalid putMethod setting (%s)\n\tmust be a valid call method name", *c.PUTMethod) + } + c.allowMethods += ", PUT" + } + if c.DELETEMethod != nil { + if !codec.IsValidRIDPart(*c.DELETEMethod) { + return fmt.Errorf("invalid deleteMethod setting (%s)\n\tmust be a valid call method name", *c.DELETEMethod) + } + c.allowMethods += ", DELETE" + } + if c.PATCHMethod != nil { + if !codec.IsValidRIDPart(*c.PATCHMethod) { + return fmt.Errorf("invalid patchMethod setting (%s)\n\tmust be a valid call method name", *c.PATCHMethod) + } + c.allowMethods += ", PATCH" + } + if c.WSPath == "" { c.WSPath = "/" } @@ -106,3 +151,65 @@ func (c *Config) prepare() error { return nil } + +func validateAllowOrigin(s []string) error { + for i, o := range s { + o = toLowerASCII(o) + s[i] = o + if o == "*" { + if len(s) > 1 { + return fmt.Errorf("'%s' must not be used together with other origin settings", o) + } + } else { + if o == "" { + return errors.New("origin must not be empty") + } + u, err := url.Parse(o) + if err != nil || u.Scheme == "" || u.Host == "" || u.Opaque != "" || u.User != nil || u.Path != "" || len(u.Query()) > 0 || u.Fragment != "" { + return fmt.Errorf("'%s' doesn't match ://[:]", o) + } + } + } + return nil +} + +// toLowerASCII converts only A-Z to lower case in a string +func toLowerASCII(s string) string { + var b strings.Builder + b.Grow(len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + if 'A' <= c && c <= 'Z' { + c += 'a' - 'A' + } + b.WriteByte(c) + } + return b.String() +} + +func matchesOrigins(os []string, o string) bool { +origin: + for _, s := range os { + t := o + for s != "" && t != "" { + sr, size := utf8.DecodeRuneInString(s) + s = s[size:] + tr, size := utf8.DecodeRuneInString(t) + t = t[size:] + if sr == tr { + continue + } + // Lowercase A-Z. Should already be done for origins. + if 'A' <= tr && tr <= 'Z' { + tr = tr + 'a' - 'A' + } + if sr != tr { + continue origin + } + } + if s == t { + return true + } + } + return false +} diff --git a/server/config_test.go b/server/config_test.go index 96cd32a..f7fd9b2 100644 --- a/server/config_test.go +++ b/server/config_test.go @@ -5,6 +5,12 @@ import ( "testing" ) +func compareString(t *testing.T, name string, str, exp string, i int) { + if str != exp { + t.Fatalf("expected %s to be:\n%s\nbut got:\n%s\nin test #%d", name, exp, str, i+1) + } +} + func compareStringPtr(t *testing.T, name string, str, exp *string, i int) { if str == exp { return @@ -26,6 +32,16 @@ func TestConfigPrepare(t *testing.T) { ipv6Addr := "::1" invalidAddr := "127.0.0" invalidHeaderAuth := "test" + allowOriginAll := "*" + allowOriginSingle := "http://resgate.io" + allowOriginMultiple := "http://localhost;http://resgate.io" + allowOriginInvalidEmpty := "" + allowOriginInvalidEmptyOrigin := ";http://localhost" + allowOriginInvalidMultipleAll := "http://localhost;*" + allowOriginInvalidMultipleSame := "http://localhost;*" + allowOriginInvalidOrigin := "http://this.is/invalid" + method := "foo" + invalidMethod := "foo.bar" defaultCfg := Config{} defaultCfg.SetDefault() @@ -34,14 +50,32 @@ func TestConfigPrepare(t *testing.T) { Expected Config PrepareError bool }{ - {defaultCfg, Config{Addr: &defaultAddr, Port: 8080, WSPath: "/", APIPath: "/api/", APIEncoding: "json", scheme: "http", netAddr: "0.0.0.0:8080"}, false}, - {Config{WSPath: "/"}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "0.0.0.0:80"}, false}, - {Config{WSPath: "/"}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "0.0.0.0:80"}, false}, - {Config{Addr: &emptyAddr, WSPath: "/"}, Config{Addr: &emptyAddr, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: ":80"}, false}, - {Config{Addr: &localAddr, WSPath: "/"}, Config{Addr: &localAddr, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "127.0.0.1:80"}, false}, - {Config{Addr: &ipv6Addr, WSPath: "/"}, Config{Addr: &ipv6Addr, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "[::1]:80"}, false}, + // Valid config + {defaultCfg, Config{Addr: &defaultAddr, Port: 8080, WSPath: "/", APIPath: "/api/", APIEncoding: "json", scheme: "http", netAddr: "0.0.0.0:8080", allowOrigin: []string{"*"}, allowMethods: "GET, HEAD, OPTIONS, POST"}, false}, + {Config{WSPath: "/"}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "0.0.0.0:80", allowOrigin: []string{"*"}, allowMethods: "GET, HEAD, OPTIONS, POST"}, false}, + {Config{Addr: &emptyAddr, WSPath: "/"}, Config{Addr: &emptyAddr, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: ":80", allowOrigin: []string{"*"}, allowMethods: "GET, HEAD, OPTIONS, POST"}, false}, + {Config{Addr: &localAddr, WSPath: "/"}, Config{Addr: &localAddr, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "127.0.0.1:80", allowOrigin: []string{"*"}, allowMethods: "GET, HEAD, OPTIONS, POST"}, false}, + {Config{Addr: &ipv6Addr, WSPath: "/"}, Config{Addr: &ipv6Addr, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "[::1]:80", allowOrigin: []string{"*"}, allowMethods: "GET, HEAD, OPTIONS, POST"}, false}, + // Allow origin + {Config{AllowOrigin: &allowOriginAll, WSPath: "/"}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "0.0.0.0:80", allowOrigin: []string{"*"}, allowMethods: "GET, HEAD, OPTIONS, POST"}, false}, + {Config{AllowOrigin: &allowOriginSingle, WSPath: "/"}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "0.0.0.0:80", allowOrigin: []string{"http://resgate.io"}, allowMethods: "GET, HEAD, OPTIONS, POST"}, false}, + {Config{AllowOrigin: &allowOriginMultiple, WSPath: "/"}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", scheme: "http", netAddr: "0.0.0.0:80", allowOrigin: []string{"http://localhost", "http://resgate.io"}, allowMethods: "GET, HEAD, OPTIONS, POST"}, false}, + // HTTP method mapping + {Config{WSPath: "/", PUTMethod: &method}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", PUTMethod: &method, scheme: "http", netAddr: "0.0.0.0:80", allowOrigin: []string{"*"}, allowMethods: "GET, HEAD, OPTIONS, POST, PUT"}, false}, + {Config{WSPath: "/", DELETEMethod: &method}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", DELETEMethod: &method, scheme: "http", netAddr: "0.0.0.0:80", allowOrigin: []string{"*"}, allowMethods: "GET, HEAD, OPTIONS, POST, DELETE"}, false}, + {Config{WSPath: "/", PATCHMethod: &method}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", PATCHMethod: &method, scheme: "http", netAddr: "0.0.0.0:80", allowOrigin: []string{"*"}, allowMethods: "GET, HEAD, OPTIONS, POST, PATCH"}, false}, + {Config{WSPath: "/", PUTMethod: &method, DELETEMethod: &method, PATCHMethod: &method}, Config{Addr: nil, Port: 80, WSPath: "/", APIPath: "/", PUTMethod: &method, DELETEMethod: &method, PATCHMethod: &method, scheme: "http", netAddr: "0.0.0.0:80", allowOrigin: []string{"*"}, allowMethods: "GET, HEAD, OPTIONS, POST, PUT, DELETE, PATCH"}, false}, + // Invalid config {Config{Addr: &invalidAddr, WSPath: "/"}, Config{}, true}, {Config{HeaderAuth: &invalidHeaderAuth, WSPath: "/"}, Config{}, true}, + {Config{AllowOrigin: &allowOriginInvalidEmpty, WSPath: "/"}, Config{}, true}, + {Config{AllowOrigin: &allowOriginInvalidEmptyOrigin, WSPath: "/"}, Config{}, true}, + {Config{AllowOrigin: &allowOriginInvalidMultipleAll, WSPath: "/"}, Config{}, true}, + {Config{AllowOrigin: &allowOriginInvalidMultipleSame, WSPath: "/"}, Config{}, true}, + {Config{AllowOrigin: &allowOriginInvalidOrigin, WSPath: "/"}, Config{}, true}, + {Config{PUTMethod: &invalidMethod, WSPath: "/"}, Config{}, true}, + {Config{DELETEMethod: &invalidMethod, WSPath: "/"}, Config{}, true}, + {Config{PATCHMethod: &invalidMethod, WSPath: "/"}, Config{}, true}, } for i, r := range tbl { @@ -56,42 +90,34 @@ func TestConfigPrepare(t *testing.T) { t.Fatalf("expected an error, but got none, in test #%d", i+1) } - if cfg.WSPath != r.Expected.WSPath { - t.Fatalf("expected WSPath to be:\n%s\nbut got:\n%s\nin test #%d", r.Expected.WSPath, cfg.WSPath, i+1) - } - - if cfg.APIPath != r.Expected.APIPath { - t.Fatalf("expected APIPath to be:\n%s\nbut got:\n%s\nin test %d", r.Expected.APIPath, cfg.APIPath, i+1) - } - - if cfg.APIEncoding != r.Expected.APIEncoding { - t.Fatalf("expected APIEncoding to be:\n%s\nbut got:\n%s\nin test %d", r.Expected.APIEncoding, cfg.APIEncoding, i+1) - } - + compareString(t, "WSPath", cfg.WSPath, r.Expected.WSPath, i) + compareString(t, "APIPath", cfg.APIPath, r.Expected.APIPath, i) + compareString(t, "APIEncoding", cfg.APIEncoding, r.Expected.APIEncoding, i) compareStringPtr(t, "Addr", cfg.Addr, r.Expected.Addr, i) + compareStringPtr(t, "PUTMethod", cfg.PUTMethod, r.Expected.PUTMethod, i) + compareStringPtr(t, "DELETEMethod", cfg.DELETEMethod, r.Expected.DELETEMethod, i) + compareStringPtr(t, "PATCHMethod", cfg.PATCHMethod, r.Expected.PATCHMethod, i) if cfg.Port != r.Expected.Port { t.Fatalf("expected Port to be:\n%d\nbut got:\n%d\nin test %d", r.Expected.Port, cfg.Port, i+1) } - if cfg.scheme != r.Expected.scheme { - t.Fatalf("expected scheme to be:\n%s\nbut got:\n%s\nin test %d", r.Expected.scheme, cfg.scheme, i+1) - } - - if cfg.netAddr != r.Expected.netAddr { - t.Fatalf("expected netAddr to be:\n%s\nbut got:\n%s\nin test %d", r.Expected.netAddr, cfg.netAddr, i+1) - } + compareString(t, "scheme", cfg.scheme, r.Expected.scheme, i) + compareString(t, "netAddr", cfg.netAddr, r.Expected.netAddr, i) + compareString(t, "headerAuthAction", cfg.headerAuthAction, r.Expected.headerAuthAction, i) + compareString(t, "headerAuthRID", cfg.headerAuthRID, r.Expected.headerAuthRID, i) + compareString(t, "allowMethods", cfg.allowMethods, r.Expected.allowMethods, i) - if cfg.headerAuthAction != r.Expected.headerAuthAction { - t.Fatalf("expected headerAuthAction to be:\n%s\nbut got:\n%s\nin test %d", r.Expected.headerAuthAction, cfg.headerAuthAction, i+1) + if len(cfg.allowOrigin) != len(r.Expected.allowOrigin) { + t.Fatalf("expected allowOrigin to be:\n%+v\nbut got:\n%+v\nin test %d", r.Expected.allowOrigin, cfg.allowOrigin, i+1) } - - if cfg.headerAuthRID != r.Expected.headerAuthRID { - t.Fatalf("expected headerAuthRID to be:\n%s\nbut got:\n%s\nin test %d", r.Expected.headerAuthRID, cfg.headerAuthRID, i+1) + for i, origin := range cfg.allowOrigin { + if origin != r.Expected.allowOrigin[i] { + t.Fatalf("expected allowOrigin to be:\n%+v\nbut got:\n%+v\nin test %d", r.Expected.allowOrigin, cfg.allowOrigin, i+1) + } } compareStringPtr(t, "HeaderAuth", cfg.HeaderAuth, r.Expected.HeaderAuth, i) - } } @@ -135,3 +161,30 @@ func TestVersionMatchesTag(t *testing.T) { t.Fatalf("Expected version %+v, got %+v", Version, tag[1:]) } } + +func TestMatchesOrigins(t *testing.T) { + tbl := []struct { + AllowedOrigins []string + Origin string + Expected bool + }{ + {[]string{"http://localhost"}, "http://localhost", true}, + {[]string{"https://resgate.io"}, "https://resgate.io", true}, + {[]string{"https://resgate.io"}, "https://Resgate.IO", true}, + {[]string{"http://localhost", "https://resgate.io"}, "http://localhost", true}, + {[]string{"http://localhost", "https://resgate.io"}, "https://resgate.io", true}, + {[]string{"http://localhost", "https://resgate.io"}, "https://Resgate.IO", true}, + {[]string{"http://localhost", "https://resgate.io", "http://resgate.io"}, "http://Localhost", true}, + {[]string{"http://localhost", "https://resgate.io", "http://resgate.io"}, "https://Resgate.io", true}, + {[]string{"http://localhost", "https://resgate.io", "http://resgate.io"}, "http://resgate.IO", true}, + {[]string{"https://resgate.io"}, "http://resgate.io", false}, + {[]string{"http://localhost", "https://resgate.io"}, "http://resgate.io", false}, + {[]string{"http://localhost", "https://resgate.io", "http://resgate.io"}, "http://localhost/", false}, + } + + for i, r := range tbl { + if matchesOrigins(r.AllowedOrigins, r.Origin) != r.Expected { + t.Fatalf("expected matchesOrigins to return %#v\n\tmatchesOrigins(%#v, %#v)\n\tin test #%d", r.Expected, r.AllowedOrigins, r.Origin, i+1) + } + } +} diff --git a/server/const.go b/server/const.go index 163e770..1eadcb2 100644 --- a/server/const.go +++ b/server/const.go @@ -4,7 +4,7 @@ import "time" const ( // Version is the current version for the server. - Version = "1.4.1" + Version = "1.5.0" // ProtocolVersion is the implemented RES protocol version. ProtocolVersion = "1.2.0" diff --git a/server/reserr/reserr.go b/server/reserr/reserr.go index ae87577..7d440c7 100644 --- a/server/reserr/reserr.go +++ b/server/reserr/reserr.go @@ -50,6 +50,7 @@ const ( CodeBadRequest = "system.badRequest" CodeMethodNotAllowed = "system.methodNotAllowed" CodeServiceUnavailable = "system.serviceUnavailable" + CodeForbidden = "system.forbidden" ) // Pre-defined RES errors @@ -71,4 +72,5 @@ var ( ErrBadRequest = &Error{Code: CodeBadRequest, Message: "Bad request"} ErrMethodNotAllowed = &Error{Code: CodeMethodNotAllowed, Message: "Method not allowed"} ErrServiceUnavailable = &Error{Code: CodeServiceUnavailable, Message: "Service unavailable"} + ErrForbiddenOrigin = &Error{Code: CodeForbidden, Message: "Forbidden origin"} ) diff --git a/server/rpc/rpc.go b/server/rpc/rpc.go index 6953413..496d7c7 100644 --- a/server/rpc/rpc.go +++ b/server/rpc/rpc.go @@ -151,7 +151,7 @@ func HandleRequest(data []byte, req Requester) error { return nil } method = rid[idx+1:] - if !codec.IsValidRID(method, false) { + if !codec.IsValidRIDPart(method) { req.Reply(r.ErrorResponse(reserr.ErrInvalidRequest)) return nil } diff --git a/server/service.go b/server/service.go index ec8f783..a908923 100644 --- a/server/service.go +++ b/server/service.go @@ -25,8 +25,9 @@ type Service struct { cache *rescache.Cache // httpServer - h *http.Server - enc APIEncoder + h *http.Server + enc APIEncoder + mimetype string // wsListener/wsConn upgrader websocket.Upgrader diff --git a/server/wsHandler.go b/server/wsHandler.go index 4498a77..bc57b39 100644 --- a/server/wsHandler.go +++ b/server/wsHandler.go @@ -9,12 +9,26 @@ import ( ) func (s *Service) initWSHandler() { - s.upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { + var co func(r *http.Request) bool + switch s.cfg.allowOrigin[0] { + case "*": + co = func(r *http.Request) bool { return true - }, + } + default: + origins := s.cfg.allowOrigin + co = func(r *http.Request) bool { + origin := r.Header["Origin"] + if len(origin) == 0 || origin[0] == "null" { + return true + } + return matchesOrigins(origins, origin[0]) + } + } + s.upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: co, EnableCompression: s.cfg.WSCompression, } s.conns = make(map[string]*wsConn) diff --git a/test/00connect_test.go b/test/00connect_test.go index bdc9235..73fae5b 100644 --- a/test/00connect_test.go +++ b/test/00connect_test.go @@ -1,7 +1,11 @@ package test import ( + "fmt" + "net/http" "testing" + + "github.com/resgateio/resgate/server" ) // Test that the server starts and stops the server without error @@ -16,12 +20,48 @@ func TestConnectClient(t *testing.T) { }) } -// // Test that a client gets error connecting to a server that is stopped -// func TestNotConnectedClientWhenStopped(t *testing.T) { -// var sess *Session -// runTest(t, func(s *Session) { -// sess = s -// }) -// sess.Connect() -// // c.AssertClosed(t) fails as the read from the websocket hijacked by wstest never returns -// } +func TestConnect_AllowOrigin_Connects(t *testing.T) { + tbl := []struct { + Origin string // Request's Origin header. Empty means no Origin header. + AllowOrigin string // AllowOrigin config + ExpectConnect bool // Expects a successful WebSocket connection/upgrade + }{ + // Valid Origin + {"http://localhost", "*", true}, + {"http://localhost", "http://localhost", true}, + {"http://localhost:8080", "http://localhost:8080", true}, + {"http://localhost", "http://localhost;https://resgate.io", true}, + {"https://resgate.io", "http://localhost;https://resgate.io", true}, + // Missing Origin + {"", "*", true}, + {"", "https://resgate.io", true}, + // Invalid Origin + {"http://resgate.io", "https://resgate.io", false}, + {"https://resgate.io", "https://api.resgate.io", false}, + {"https://resgate.io:8080", "https://resgate.io", false}, + {"https://resgate.io", "https://resgate.io:8080", false}, + } + + for i, l := range tbl { + l := l + runNamedTest(t, fmt.Sprintf("#%d", i+1), func(s *Session) { + var h http.Header + if l.Origin != "" { + h = http.Header{"Origin": {l.Origin}} + } + var c *Conn + if l.ExpectConnect { + c = s.ConnectWithHeader(h) + // Test sending a version request + creq := c.Request("version", versionRequest) + creq.GetResponse(s.t) + } else { + AssertPanic(t, func() { + c = s.ConnectWithHeader(h) + }) + } + }, func(cfg *server.Config) { + cfg.AllowOrigin = &l.AllowOrigin + }) + } +} diff --git a/test/01subscribe_test.go b/test/01subscribe_test.go index 5ae4e64..f155315 100644 --- a/test/01subscribe_test.go +++ b/test/01subscribe_test.go @@ -187,29 +187,29 @@ func TestSubscribe(t *testing.T) { responses := map[string][]string{ // Model responses - "test.model": []string{"test.model"}, - "test.model.parent": []string{"test.model.parent", "test.model"}, - "test.model.grandparent": []string{"test.model.grandparent", "test.model.parent", "test.model"}, - "test.model.secondparent": []string{"test.model.secondparent", "test.model"}, - "test.model.brokenchild": []string{"test.model.brokenchild", "test.err.notFound"}, + "test.model": {"test.model"}, + "test.model.parent": {"test.model.parent", "test.model"}, + "test.model.grandparent": {"test.model.grandparent", "test.model.parent", "test.model"}, + "test.model.secondparent": {"test.model.secondparent", "test.model"}, + "test.model.brokenchild": {"test.model.brokenchild", "test.err.notFound"}, // Cyclic model responses - "test.m.a": []string{"test.m.a"}, - "test.m.b": []string{"test.m.b", "test.m.c"}, - "test.m.d": []string{"test.m.d", "test.m.e", "test.m.f"}, - "test.m.g": []string{"test.m.d", "test.m.e", "test.m.f", "test.m.g"}, - "test.m.h": []string{"test.m.d", "test.m.e", "test.m.f", "test.m.h"}, + "test.m.a": {"test.m.a"}, + "test.m.b": {"test.m.b", "test.m.c"}, + "test.m.d": {"test.m.d", "test.m.e", "test.m.f"}, + "test.m.g": {"test.m.d", "test.m.e", "test.m.f", "test.m.g"}, + "test.m.h": {"test.m.d", "test.m.e", "test.m.f", "test.m.h"}, // Collection responses - "test.collection": []string{"test.collection"}, - "test.collection.parent": []string{"test.collection.parent", "test.collection"}, - "test.collection.grandparent": []string{"test.collection.grandparent", "test.collection.parent", "test.collection"}, - "test.collection.secondparent": []string{"test.collection.secondparent", "test.collection"}, - "test.collection.brokenchild": []string{"test.collection.brokenchild", "test.err.notFound"}, + "test.collection": {"test.collection"}, + "test.collection.parent": {"test.collection.parent", "test.collection"}, + "test.collection.grandparent": {"test.collection.grandparent", "test.collection.parent", "test.collection"}, + "test.collection.secondparent": {"test.collection.secondparent", "test.collection"}, + "test.collection.brokenchild": {"test.collection.brokenchild", "test.err.notFound"}, // Cyclic collection responses - "test.c.a": []string{"test.c.a"}, - "test.c.b": []string{"test.c.b", "test.c.c"}, - "test.c.d": []string{"test.c.d", "test.c.e", "test.c.f"}, - "test.c.g": []string{"test.c.d", "test.c.e", "test.c.f", "test.c.g"}, - "test.c.h": []string{"test.c.d", "test.c.e", "test.c.f", "test.c.h"}, + "test.c.a": {"test.c.a"}, + "test.c.b": {"test.c.b", "test.c.c"}, + "test.c.d": {"test.c.d", "test.c.e", "test.c.f"}, + "test.c.g": {"test.c.d", "test.c.e", "test.c.f", "test.c.g"}, + "test.c.h": {"test.c.d", "test.c.e", "test.c.f", "test.c.h"}, } for i, l := range sequenceTable { diff --git a/test/13query_event_test.go b/test/13query_event_test.go index 30d5389..929cd73 100644 --- a/test/13query_event_test.go +++ b/test/13query_event_test.go @@ -147,9 +147,8 @@ func TestModelQueryEventResponseUpdatesTheCache(t *testing.T) { c2 := s.Connect() // Subscribe a second time creq2 := c2.Request("subscribe.test.model?q=foo&f=bar", nil) - // Handle model get and access request - mreqs2 := s.GetParallelRequests(t, 1) - mreqs2.GetRequest(t, "access.test.model").RespondSuccess(json.RawMessage(`{"get":true}`)) + // Handle model access request + s.GetRequest(t).AssertSubject(t, "access.test.model").RespondSuccess(json.RawMessage(`{"get":true}`)) // Validate client response and validate creq2.GetResponse(t).AssertResult(t, json.RawMessage(`{"models":{"test.model?q=foo&f=bar":{"string":"bar","int":-12,"bool":true,"null":null}}}`)) }) @@ -721,3 +720,32 @@ func TestQueryEvent_DeleteEventOnCollection_DeletesFromCache(t *testing.T) { c.AssertNoNATSRequest(t, "test.collection") }) } + +func TestQueryEvent_MultipleClientsWithDifferentQueries_SendsMultipleQueryRequest(t *testing.T) { + runTest(t, func(s *Session) { + c1 := s.Connect() + c2 := s.Connect() + subscribeToTestQueryCollection(t, s, c1, "q=foo&f=bar", "q=foo&f=bar") + subscribeToTestQueryCollection(t, s, c2, "q=zoo&f=baz", "q=zoo&f=baz") + // Send query event + s.ResourceEvent("test.collection", "query", json.RawMessage(`{"subject":"_EVENT_01_"}`)) + // Respond to query request with an error + mreqs := s.GetParallelRequests(t, 2) + + // Ensure order + if mreqs[0].PathPayload(t, "query").(string) == "q=zoo&f=baz" { + temp := mreqs[0] + mreqs[0] = mreqs[1] + mreqs[1] = temp + } + // Validate query requests + mreqs[0].AssertSubject(t, "_EVENT_01_").AssertPathPayload(t, "query", "q=foo&f=bar") + mreqs[1].AssertSubject(t, "_EVENT_01_").AssertPathPayload(t, "query", "q=zoo&f=baz") + // Send query response + mreqs[0].RespondSuccess(json.RawMessage(`{"events":[]}`)) + mreqs[1].RespondSuccess(json.RawMessage(`{"events":[]}`)) + // Validate no events + c1.AssertNoEvent(t, "test.collection") + c2.AssertNoEvent(t, "test.collection") + }) +} diff --git a/test/14http_get_test.go b/test/14http_get_test.go index 37572a8..d9973b1 100644 --- a/test/14http_get_test.go +++ b/test/14http_get_test.go @@ -222,3 +222,60 @@ func TestHTTPGetOnPrimitiveQueryCollection(t *testing.T) { }) } } + +func TestHTTPGet_AllowOrigin_ExpectedResponse(t *testing.T) { + model := resourceData("test.model") + successResponse := json.RawMessage(model) + + tbl := []struct { + Origin string // Request's Origin header. Empty means no Origin header. + ContentType string // Request's Content-Type header. Empty means no Content-Type header. + AllowOrigin string // AllowOrigin config + ExpectedCode int // Expected response status code + ExpectedHeaders map[string]string // Expected response Headers + ExpectedMissingHeaders []string // Expected response headers not to be included + ExpectedBody interface{} // Expected response body + }{ + {"http://localhost", "", "*", http.StatusOK, map[string]string{"Access-Control-Allow-Origin": "*"}, []string{"Vary"}, successResponse}, + {"http://localhost", "", "http://localhost", http.StatusOK, map[string]string{"Access-Control-Allow-Origin": "http://localhost", "Vary": "Origin"}, nil, successResponse}, + {"https://resgate.io", "", "http://localhost;https://resgate.io", http.StatusOK, map[string]string{"Access-Control-Allow-Origin": "https://resgate.io", "Vary": "Origin"}, nil, successResponse}, + // Invalid requests + {"http://example.com", "", "http://localhost;https://resgate.io", http.StatusForbidden, map[string]string{"Access-Control-Allow-Origin": "http://localhost", "Vary": "Origin"}, nil, reserr.ErrForbiddenOrigin}, + // No Origin header in request + {"", "", "*", http.StatusOK, map[string]string{"Access-Control-Allow-Origin": "*"}, []string{"Vary"}, successResponse}, + {"", "", "http://localhost", http.StatusOK, nil, []string{"Access-Control-Allow-Origin", "Vary"}, successResponse}, + } + + for i, l := range tbl { + l := l + runNamedTest(t, fmt.Sprintf("#%d", i+1), func(s *Session) { + hreq := s.HTTPRequest("GET", "/api/test/model", nil, func(req *http.Request) { + if l.Origin != "" { + req.Header.Set("Origin", l.Origin) + } + if l.ContentType != "" { + req.Header.Set("Content-Type", l.ContentType) + } + }) + + if l.ExpectedCode == http.StatusOK { + // Handle model get and access request + mreqs := s.GetParallelRequests(t, 2) + mreqs. + GetRequest(t, "access.test.model"). + RespondSuccess(json.RawMessage(`{"get":true}`)) + mreqs. + GetRequest(t, "get.test.model"). + RespondSuccess(json.RawMessage(`{"model":` + model + `}`)) + } + + // Validate http response + hreq.GetResponse(t). + Equals(t, l.ExpectedCode, l.ExpectedBody). + AssertHeaders(t, l.ExpectedHeaders). + AssertMissingHeaders(t, l.ExpectedMissingHeaders) + }, func(cfg *server.Config) { + cfg.AllowOrigin = &l.AllowOrigin + }) + } +} diff --git a/test/15http_post_test.go b/test/15http_post_test.go index 28da4db..fd77ad5 100644 --- a/test/15http_post_test.go +++ b/test/15http_post_test.go @@ -6,6 +6,7 @@ import ( "net/http" "testing" + "github.com/resgateio/resgate/server" "github.com/resgateio/resgate/server/mq" "github.com/resgateio/resgate/server/reserr" ) @@ -254,3 +255,59 @@ func TestHTTPPostInvalidURLs(t *testing.T) { }) } } + +func TestHTTPPost_AllowOrigin_ExpectedResponse(t *testing.T) { + successResponse := json.RawMessage(`{"get":true,"call":"*"}`) + + tbl := []struct { + Origin string // Request's Origin header. Empty means no Origin header. + ContentType string // Request's Content-Type header. Empty means no Content-Type header. + AllowOrigin string // AllowOrigin config + ExpectedCode int // Expected response status code + ExpectedHeaders map[string]string // Expected response Headers + ExpectedMissingHeaders []string // Expected response headers not to be included + ExpectedBody interface{} // Expected response body + }{ + {"http://localhost", "", "*", http.StatusOK, map[string]string{"Access-Control-Allow-Origin": "*"}, []string{"Vary"}, successResponse}, + {"http://localhost", "", "http://localhost", http.StatusOK, map[string]string{"Access-Control-Allow-Origin": "http://localhost", "Vary": "Origin"}, nil, successResponse}, + {"https://resgate.io", "", "http://localhost;https://resgate.io", http.StatusOK, map[string]string{"Access-Control-Allow-Origin": "https://resgate.io", "Vary": "Origin"}, nil, successResponse}, + // Invalid requests + {"http://example.com", "", "http://localhost;https://resgate.io", http.StatusForbidden, map[string]string{"Access-Control-Allow-Origin": "http://localhost", "Vary": "Origin"}, nil, reserr.ErrForbiddenOrigin}, + // No Origin header in request + {"", "", "*", http.StatusOK, map[string]string{"Access-Control-Allow-Origin": "*"}, []string{"Vary"}, successResponse}, + {"", "", "http://localhost", http.StatusOK, nil, []string{"Access-Control-Allow-Origin", "Vary"}, successResponse}, + } + + for i, l := range tbl { + l := l + runNamedTest(t, fmt.Sprintf("#%d", i+1), func(s *Session) { + hreq := s.HTTPRequest("POST", "/api/test/model/method", nil, func(req *http.Request) { + if l.Origin != "" { + req.Header.Set("Origin", l.Origin) + } + if l.ContentType != "" { + req.Header.Set("Content-Type", l.ContentType) + } + }) + + if l.ExpectedCode == http.StatusOK { + // Get access request + req := s.GetRequest(t) + req.AssertSubject(t, "access.test.model") + req.RespondSuccess(json.RawMessage(`{"get":true,"call":"*"}`)) + // Get call request + req = s.GetRequest(t) + req.AssertSubject(t, "call.test.model.method") + req.RespondSuccess(successResponse) + } + + // Validate http response + hreq.GetResponse(t). + Equals(t, l.ExpectedCode, l.ExpectedBody). + AssertHeaders(t, l.ExpectedHeaders). + AssertMissingHeaders(t, l.ExpectedMissingHeaders) + }, func(cfg *server.Config) { + cfg.AllowOrigin = &l.AllowOrigin + }) + } +} diff --git a/test/21http_options_test.go b/test/21http_options_test.go new file mode 100644 index 0000000..f53d9a6 --- /dev/null +++ b/test/21http_options_test.go @@ -0,0 +1,44 @@ +package test + +import ( + "fmt" + "net/http" + "testing" + + "github.com/resgateio/resgate/server" +) + +func TestHTTPOptions_AllowOrigin_ExpectedResponseHeaders(t *testing.T) { + tbl := []struct { + Origin string // Request's Origin header. Empty means no Origin header. + AllowOrigin string // AllowOrigin config + ExpectedHeaders map[string]string // Expected response Headers + ExpectedMissingHeaders []string // Expected response headers not to be included + }{ + {"http://localhost", "*", map[string]string{"Access-Control-Allow-Origin": "*"}, []string{"Vary"}}, + {"http://localhost", "http://localhost", map[string]string{"Access-Control-Allow-Origin": "http://localhost", "Vary": "Origin"}, nil}, + {"https://resgate.io", "http://localhost;https://resgate.io", map[string]string{"Access-Control-Allow-Origin": "https://resgate.io", "Vary": "Origin"}, nil}, + {"http://example.com", "http://localhost;https://resgate.io", map[string]string{"Access-Control-Allow-Origin": "http://localhost", "Vary": "Origin"}, nil}, + // No Origin header in request + {"", "*", map[string]string{"Access-Control-Allow-Origin": "*"}, []string{"Vary"}}, + {"", "http://localhost", nil, []string{"Access-Control-Allow-Origin", "Vary"}}, + } + + for i, l := range tbl { + l := l + runNamedTest(t, fmt.Sprintf("#%d", i+1), func(s *Session) { + hreq := s.HTTPRequest("OPTIONS", "/api/test/model", nil, func(req *http.Request) { + if l.Origin != "" { + req.Header.Set("Origin", l.Origin) + } + }) + // Validate http response + hreq.GetResponse(t). + Equals(t, http.StatusOK, nil). + AssertHeaders(t, l.ExpectedHeaders). + AssertMissingHeaders(t, l.ExpectedMissingHeaders) + }, func(cfg *server.Config) { + cfg.AllowOrigin = &l.AllowOrigin + }) + } +} diff --git a/test/22http_method_mapping_test.go b/test/22http_method_mapping_test.go new file mode 100644 index 0000000..6e1d31f --- /dev/null +++ b/test/22http_method_mapping_test.go @@ -0,0 +1,84 @@ +package test + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/resgateio/resgate/server" + "github.com/resgateio/resgate/server/reserr" +) + +func TestHTTPMethod_MappedMethod_ExpectedResponse(t *testing.T) { + params := json.RawMessage(`{"foo":"bar"}`) + result := json.RawMessage(`"zoo"`) + method := "mappedMethod" + + tbl := []struct { + Config func(cfg *server.Config) // Server config + Method string // HTTP method to use + CallResponse interface{} // Response on call request. requestTimeout means timeout. noRequest means no call request is expected + ExpectedCode int // Expected response status code + Expected interface{} // Expected response body + }{ + // With mapped methods and success + {func(cfg *server.Config) { cfg.PUTMethod = &method }, "PUT", result, http.StatusOK, result}, + {func(cfg *server.Config) { cfg.DELETEMethod = &method }, "DELETE", result, http.StatusOK, result}, + {func(cfg *server.Config) { cfg.PATCHMethod = &method }, "PATCH", result, http.StatusOK, result}, + // With mapped methods and error + {func(cfg *server.Config) { cfg.PUTMethod = &method }, "PUT", reserr.ErrInvalidParams, http.StatusBadRequest, reserr.ErrInvalidParams}, + {func(cfg *server.Config) { cfg.DELETEMethod = &method }, "DELETE", reserr.ErrInvalidParams, http.StatusBadRequest, reserr.ErrInvalidParams}, + {func(cfg *server.Config) { cfg.PATCHMethod = &method }, "PATCH", reserr.ErrInvalidParams, http.StatusBadRequest, reserr.ErrInvalidParams}, + {func(cfg *server.Config) { cfg.PUTMethod = &method }, "PUT", requestTimeout, http.StatusNotFound, reserr.ErrTimeout}, + {func(cfg *server.Config) { cfg.DELETEMethod = &method }, "DELETE", requestTimeout, http.StatusNotFound, reserr.ErrTimeout}, + {func(cfg *server.Config) { cfg.PATCHMethod = &method }, "PATCH", requestTimeout, http.StatusNotFound, reserr.ErrTimeout}, + {func(cfg *server.Config) { cfg.PUTMethod = &method }, "PUT", reserr.ErrAccessDenied, http.StatusUnauthorized, reserr.ErrAccessDenied}, + {func(cfg *server.Config) { cfg.DELETEMethod = &method }, "DELETE", reserr.ErrAccessDenied, http.StatusUnauthorized, reserr.ErrAccessDenied}, + {func(cfg *server.Config) { cfg.PATCHMethod = &method }, "PATCH", reserr.ErrAccessDenied, http.StatusUnauthorized, reserr.ErrAccessDenied}, + {func(cfg *server.Config) { cfg.PUTMethod = &method }, "PUT", reserr.ErrMethodNotFound, http.StatusMethodNotAllowed, reserr.ErrMethodNotAllowed}, + {func(cfg *server.Config) { cfg.DELETEMethod = &method }, "DELETE", reserr.ErrMethodNotFound, http.StatusMethodNotAllowed, reserr.ErrMethodNotAllowed}, + {func(cfg *server.Config) { cfg.PATCHMethod = &method }, "PATCH", reserr.ErrMethodNotFound, http.StatusMethodNotAllowed, reserr.ErrMethodNotAllowed}, + // Without mapping + {func(cfg *server.Config) { cfg.DELETEMethod = &method }, "PUT", noRequest, http.StatusMethodNotAllowed, reserr.ErrMethodNotAllowed}, + {func(cfg *server.Config) { cfg.PATCHMethod = &method }, "DELETE", noRequest, http.StatusMethodNotAllowed, reserr.ErrMethodNotAllowed}, + {func(cfg *server.Config) { cfg.PUTMethod = &method }, "PATCH", noRequest, http.StatusMethodNotAllowed, reserr.ErrMethodNotAllowed}, + } + + for i, l := range tbl { + l := l + runNamedTest(t, fmt.Sprintf("#%d", i+1), func(s *Session) { + hreq := s.HTTPRequest(l.Method, "/api/test/model", params) + + if l.CallResponse != noRequest { + // Handle access request + s.GetRequest(t). + AssertSubject(t, "access.test.model"). + RespondSuccess(json.RawMessage(`{"get":true,"call":"*"}`)) + + // Handle call request + req := s.GetRequest(t). + AssertSubject(t, "call.test.model."+method). + AssertPathPayload(t, "params", json.RawMessage(params)) + if l.CallResponse == requestTimeout { + req.Timeout() + } else if err, ok := l.CallResponse.(*reserr.Error); ok { + req.RespondError(err) + } else { + req.RespondSuccess(l.CallResponse) + } + } + + // Validate HTTP response + hresp := hreq.GetResponse(t) + hresp.AssertStatusCode(t, l.ExpectedCode) + if err, ok := l.Expected.(*reserr.Error); ok { + hresp.AssertError(t, err) + } else if code, ok := l.Expected.(string); ok { + hresp.AssertErrorCode(t, code) + } else { + hresp.AssertBody(t, l.Expected) + } + }, l.Config) + } +} diff --git a/test/23http_head_test.go b/test/23http_head_test.go new file mode 100644 index 0000000..1dd461b --- /dev/null +++ b/test/23http_head_test.go @@ -0,0 +1,77 @@ +package test + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/resgateio/resgate/server/reserr" +) + +// Test invalid urls for HTTP get requests +func TestHTTPMethodHEAD_InvalidURLs_CorrectStatus(t *testing.T) { + tbl := []struct { + URL string // Url path + ExpectedCode int + }{ + {"/wrong_prefix/test/model", http.StatusNotFound}, + {"/api/", http.StatusNotFound}, + {"/api/test.model", http.StatusNotFound}, + {"/api/test/model/", http.StatusNotFound}, + {"/api/test//model", http.StatusNotFound}, + {"/api/test/mådel/action", http.StatusNotFound}, + } + + for i, l := range tbl { + runNamedTest(t, fmt.Sprintf("#%d", i+1), func(s *Session) { + s.HTTPRequest("HEAD", l.URL, nil). + GetResponse(t). + AssertStatusCode(t, l.ExpectedCode) + // We don't check the Body as the httptest.ResponseRecorder + // does not discard the written bytes to the body, unlike the + // actual http package. + }) + } +} + +func TestHTTPHead_OnSuccess_NoBody(t *testing.T) { + model := resourceData("test.model") + runTest(t, func(s *Session) { + hreq := s.HTTPRequest("HEAD", "/api/test/model", nil) + + /// Handle model get and access request + mreqs := s.GetParallelRequests(t, 2) + req := mreqs.GetRequest(t, "access.test.model") + req.RespondSuccess(json.RawMessage(`{"get":true}`)) + req = mreqs.GetRequest(t, "get.test.model") + req.RespondSuccess(json.RawMessage(`{"model":` + model + `}`)) + + // Validate http response + hreq.GetResponse(t). + AssertStatusCode(t, http.StatusOK) + // We don't check the Body as the httptest.ResponseRecorder + // does not discard the written bytes to the body, unlike the + // actual http package. + }) +} + +func TestHTTPHead_OnError_NoBody(t *testing.T) { + runTest(t, func(s *Session) { + hreq := s.HTTPRequest("HEAD", "/api/test/model", nil) + + /// Handle model get and access request + mreqs := s.GetParallelRequests(t, 2) + req := mreqs.GetRequest(t, "access.test.model") + req.RespondSuccess(json.RawMessage(`{"get":true}`)) + req = mreqs.GetRequest(t, "get.test.model") + req.RespondError(reserr.ErrNotFound) + + // Validate http response + hreq.GetResponse(t). + AssertStatusCode(t, http.StatusNotFound) + // We don't check the Body as the httptest.ResponseRecorder + // does not discard the written bytes to the body, unlike the + // actual http package. + }) +} diff --git a/test/http.go b/test/http.go index abb3ffa..9c653e8 100644 --- a/test/http.go +++ b/test/http.go @@ -152,7 +152,7 @@ func (hr *HTTPResponse) AssertIsError(t *testing.T) *HTTPResponse { func (hr *HTTPResponse) AssertHeaders(t *testing.T, h map[string]string) *HTTPResponse { for k, v := range h { hv := hr.Result().Header.Get(k) - if hr.Result().Header.Get(k) != v { + if hv != v { if hv == "" { t.Fatalf("expected response header %s to be %s, but header not found", k, v) } else { @@ -162,3 +162,14 @@ func (hr *HTTPResponse) AssertHeaders(t *testing.T, h map[string]string) *HTTPRe } return hr } + +// AssertMissingHeaders asserts that the response does not include the given headers +func (hr *HTTPResponse) AssertMissingHeaders(t *testing.T, h []string) *HTTPResponse { + for _, h := range h { + hv := hr.Result().Header.Get(h) + if hv != "" { + t.Fatalf("expected response header %s to be missing, but got %s", h, hv) + } + } + return hr +} diff --git a/test/natstest.go b/test/natstest.go index c0b8eb4..8ae5021 100644 --- a/test/natstest.go +++ b/test/natstest.go @@ -453,3 +453,15 @@ func (pr ParallelRequests) GetRequest(t *testing.T, subject string) *Request { t.Fatalf("expected parallel requests to contain subject %#v, but found none", subject) return nil } + +// AssertPanic expects the callback function to panic, otherwise +// logs an error with t.Errorf +func AssertPanic(t *testing.T, cb func()) { + defer func() { + v := recover() + if v == nil { + t.Errorf("expected callback to panic, but it didn't") + } + }() + cb() +} diff --git a/test/resources.go b/test/resources.go index 8088996..734405e 100644 --- a/test/resources.go +++ b/test/resources.go @@ -44,49 +44,49 @@ func resourceData(rid string) string { var resources = map[string]resource{ // Model resources - "test.model": resource{typeModel, `{"string":"foo","int":42,"bool":true,"null":null}`, nil}, - "test.model.parent": resource{typeModel, `{"name":"parent","child":{"rid":"test.model"}}`, nil}, - "test.model.secondparent": resource{typeModel, `{"name":"secondparent","child":{"rid":"test.model"}}`, nil}, - "test.model.grandparent": resource{typeModel, `{"name":"grandparent","child":{"rid":"test.model.parent"}}`, nil}, - "test.model.brokenchild": resource{typeModel, `{"name":"brokenchild","child":{"rid":"test.err.notFound"}}`, nil}, + "test.model": {typeModel, `{"string":"foo","int":42,"bool":true,"null":null}`, nil}, + "test.model.parent": {typeModel, `{"name":"parent","child":{"rid":"test.model"}}`, nil}, + "test.model.secondparent": {typeModel, `{"name":"secondparent","child":{"rid":"test.model"}}`, nil}, + "test.model.grandparent": {typeModel, `{"name":"grandparent","child":{"rid":"test.model.parent"}}`, nil}, + "test.model.brokenchild": {typeModel, `{"name":"brokenchild","child":{"rid":"test.err.notFound"}}`, nil}, // Cyclic model resources - "test.m.a": resource{typeModel, `{"a":{"rid":"test.m.a"}}`, nil}, + "test.m.a": {typeModel, `{"a":{"rid":"test.m.a"}}`, nil}, - "test.m.b": resource{typeModel, `{"c":{"rid":"test.m.c"}}`, nil}, - "test.m.c": resource{typeModel, `{"b":{"rid":"test.m.b"}}`, nil}, + "test.m.b": {typeModel, `{"c":{"rid":"test.m.c"}}`, nil}, + "test.m.c": {typeModel, `{"b":{"rid":"test.m.b"}}`, nil}, - "test.m.d": resource{typeModel, `{"e":{"rid":"test.m.e"},"f":{"rid":"test.m.f"}}`, nil}, - "test.m.e": resource{typeModel, `{"d":{"rid":"test.m.d"}}`, nil}, - "test.m.f": resource{typeModel, `{"d":{"rid":"test.m.d"}}`, nil}, + "test.m.d": {typeModel, `{"e":{"rid":"test.m.e"},"f":{"rid":"test.m.f"}}`, nil}, + "test.m.e": {typeModel, `{"d":{"rid":"test.m.d"}}`, nil}, + "test.m.f": {typeModel, `{"d":{"rid":"test.m.d"}}`, nil}, - "test.m.g": resource{typeModel, `{"e":{"rid":"test.m.e"},"f":{"rid":"test.m.f"}}`, nil}, - "test.m.h": resource{typeModel, `{"e":{"rid":"test.m.e"}}`, nil}, + "test.m.g": {typeModel, `{"e":{"rid":"test.m.e"},"f":{"rid":"test.m.f"}}`, nil}, + "test.m.h": {typeModel, `{"e":{"rid":"test.m.e"}}`, nil}, // Collection resources - "test.collection": resource{typeCollection, `["foo",42,true,null]`, nil}, - "test.collection.parent": resource{typeCollection, `["parent",{"rid":"test.collection"}]`, nil}, - "test.collection.secondparent": resource{typeCollection, `["secondparent",{"rid":"test.collection"}]`, nil}, - "test.collection.grandparent": resource{typeCollection, `["grandparent",{"rid":"test.collection.parent"}]`, nil}, - "test.collection.brokenchild": resource{typeCollection, `["brokenchild",{"rid":"test.err.notFound"}]`, nil}, + "test.collection": {typeCollection, `["foo",42,true,null]`, nil}, + "test.collection.parent": {typeCollection, `["parent",{"rid":"test.collection"}]`, nil}, + "test.collection.secondparent": {typeCollection, `["secondparent",{"rid":"test.collection"}]`, nil}, + "test.collection.grandparent": {typeCollection, `["grandparent",{"rid":"test.collection.parent"}]`, nil}, + "test.collection.brokenchild": {typeCollection, `["brokenchild",{"rid":"test.err.notFound"}]`, nil}, // Cyclic collection resources - "test.c.a": resource{typeCollection, `[{"rid":"test.c.a"}]`, nil}, + "test.c.a": {typeCollection, `[{"rid":"test.c.a"}]`, nil}, - "test.c.b": resource{typeCollection, `[{"rid":"test.c.c"}]`, nil}, - "test.c.c": resource{typeCollection, `[{"rid":"test.c.b"}]`, nil}, + "test.c.b": {typeCollection, `[{"rid":"test.c.c"}]`, nil}, + "test.c.c": {typeCollection, `[{"rid":"test.c.b"}]`, nil}, - "test.c.d": resource{typeCollection, `[{"rid":"test.c.e"},{"rid":"test.c.f"}]`, nil}, - "test.c.e": resource{typeCollection, `[{"rid":"test.c.d"}]`, nil}, - "test.c.f": resource{typeCollection, `[{"rid":"test.c.d"}]`, nil}, + "test.c.d": {typeCollection, `[{"rid":"test.c.e"},{"rid":"test.c.f"}]`, nil}, + "test.c.e": {typeCollection, `[{"rid":"test.c.d"}]`, nil}, + "test.c.f": {typeCollection, `[{"rid":"test.c.d"}]`, nil}, - "test.c.g": resource{typeCollection, `[{"rid":"test.c.e"},{"rid":"test.c.f"}]`, nil}, - "test.c.h": resource{typeCollection, `[{"rid":"test.c.e"}]`, nil}, + "test.c.g": {typeCollection, `[{"rid":"test.c.e"},{"rid":"test.c.f"}]`, nil}, + "test.c.h": {typeCollection, `[{"rid":"test.c.e"}]`, nil}, // Errors - "test.err.notFound": resource{typeError, "", reserr.ErrNotFound}, - "test.err.internalError": resource{typeError, "", reserr.ErrInternalError}, - "test.err.timeout": resource{typeError, "", reserr.ErrTimeout}, + "test.err.notFound": {typeError, "", reserr.ErrNotFound}, + "test.err.internalError": {typeError, "", reserr.ErrInternalError}, + "test.err.timeout": {typeError, "", reserr.ErrTimeout}, } // Call responses diff --git a/test/test.go b/test/test.go index 80c4f8b..ee3a760 100644 --- a/test/test.go +++ b/test/test.go @@ -33,7 +33,7 @@ func setup(t *testing.T, cfgs ...func(*server.Config)) *Session { l := NewCountLogger(true, true) c := NewNATSTestClient(l) - serv, err := server.NewService(c, TestConfig(cfgs...)) + serv, err := server.NewService(c, DefaultConfig(cfgs...)) if err != nil { t.Fatalf("error creating new service: %s", err) } @@ -57,8 +57,12 @@ func setup(t *testing.T, cfgs ...func(*server.Config)) *Session { // ConnectWithChannel makes a new mock client websocket connection // with a ClientEvent channel. func (s *Session) ConnectWithChannel(evs chan *ClientEvent) *Conn { + return s.connect(evs, nil) +} + +func (s *Session) connect(evs chan *ClientEvent, h http.Header) *Conn { d := wstest.NewDialer(s.s.GetWSHandlerFunc()) - c, _, err := d.Dial("ws://example.org/", nil) + c, _, err := d.Dial("ws://example.org/", h) if err != nil { panic(err) } @@ -71,7 +75,7 @@ func (s *Session) ConnectWithChannel(evs chan *ClientEvent) *Conn { // Connect makes a new mock client websocket connection // that handshakes with version v1.999.999. func (s *Session) Connect() *Conn { - c := s.ConnectWithChannel(make(chan *ClientEvent, 256)) + c := s.connect(make(chan *ClientEvent, 256), nil) // Send version connect creq := c.Request("version", versionRequest) @@ -86,8 +90,14 @@ func (s *Session) ConnectWithoutVersion() *Conn { return s.ConnectWithChannel(make(chan *ClientEvent, 256)) } +// ConnectWithHeader makes a new mock client websocket connection +// using provided headers. It does not send a version handshake. +func (s *Session) ConnectWithHeader(h http.Header) *Conn { + return s.connect(make(chan *ClientEvent, 256), h) +} + // HTTPRequest sends a request over HTTP -func (s *Session) HTTPRequest(method, url string, body []byte) *HTTPRequest { +func (s *Session) HTTPRequest(method, url string, body []byte, opts ...func(r *http.Request)) *HTTPRequest { r := bytes.NewReader(body) req, err := http.NewRequest(method, url, r) @@ -95,6 +105,10 @@ func (s *Session) HTTPRequest(method, url string, body []byte) *HTTPRequest { panic("test: failed to create new http request: " + err.Error()) } + for _, opt := range opts { + opt(req) + } + // Record the response into a httptest.ResponseRecorder rr := httptest.NewRecorder() @@ -107,7 +121,7 @@ func (s *Session) HTTPRequest(method, url string, body []byte) *HTTPRequest { go func() { s.Tracef("H-> %s %s: %s", method, url, body) s.s.ServeHTTP(rr, req) - s.Tracef("<-H %s %s: %s", method, url, rr.Body.String()) + s.Tracef("<-H %s %s: (%d) %s", method, url, rr.Code, rr.Body.String()) hr.ch <- &HTTPResponse{ResponseRecorder: rr} }() @@ -138,8 +152,8 @@ func teardown(s *Session) { } } -// TestConfig returns a default server configuration used for testing -func TestConfig(cfgs ...func(*server.Config)) server.Config { +// DefaultConfig returns a default server configuration used for testing +func DefaultConfig(cfgs ...func(*server.Config)) server.Config { var cfg server.Config cfg.SetDefault() cfg.NoHTTP = true