diff --git a/CHANGELOG.md b/CHANGELOG.md index efa0aa7cf..8d7cea7a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ # Changelog +## v1.0.0-beta4 (unreleased) + +- **[Breaking]** Introduce a config loader, this will allow to override config loading + and use custom dirs to load from. In order to load configs calls to `config.Load()` + should be replaced with `config.NewLoader().Load()`. +- Added `metrics.NopScope` for tests on service.NopHost with tagging capabilities + turned on by default +- Added a command line provider `config.NewCommandLineProvider()`, which can be used + to pass configuration parameters through command line. +- **[Breaking]** `uhttp module` now accepts `http.Handler` as part of module setup. + As part of refactor, RouteHandler is removed from the module registration. +- `Loader.Path() string` is now `Loader.Paths() []string`, to better reflect that + configuration is loaded from multiple directories. +- **[Breaking]** Removed `CreateAuthInfo` interface from auth package. package auth + RegisterFunc now accepts `config.Provider` and `tally.Scope` for initialization. +- **[Breaking]** Removed `auth.Client` access from `service.Host`. `auth.Client` can + now be accessed via `auth.Load()` call. + ## v1.0.0-beta3 (28 Mar 2017) - **[Breaking]** Environment config provider was removed. If you were using diff --git a/auth/README.md b/auth/README.md index fd2aba647..d6b5c0669 100644 --- a/auth/README.md +++ b/auth/README.md @@ -36,7 +36,7 @@ type userAuthClient struct { // embed backend security service client here } -func userAuthClient(info CreateAuthInfo) auth.Client { +func userAuthClient(config config.Provider, scope tally.Scope) auth.Client { return &userAuthClient{} } diff --git a/auth/auth_failure_client.go b/auth/auth_failure_client.go index 355e6b7ff..f817fd6a7 100644 --- a/auth/auth_failure_client.go +++ b/auth/auth_failure_client.go @@ -23,13 +23,19 @@ package auth import ( "context" "errors" + + "go.uber.org/fx/config" + + "github.com/uber-go/tally" ) -type failureClient struct { -} +// FailureClient is used for auth failure testing +var FailureClient = FakeFailureClient(nil, tally.NoopScope) + +type failureClient struct{} // FakeFailureClient fails all auth request and must only be used for testing -func FakeFailureClient(info CreateAuthInfo) Client { +func FakeFailureClient(config config.Provider, scope tally.Scope) Client { return &failureClient{} } diff --git a/auth/auth_stub.go b/auth/auth_stub.go index 826a7041d..ac5596f31 100644 --- a/auth/auth_stub.go +++ b/auth/auth_stub.go @@ -20,11 +20,17 @@ package auth -import "context" +import ( + "context" + + "go.uber.org/fx/config" + + "github.com/uber-go/tally" +) var ( // NopClient is used for testing and no-op integration - NopClient = nopClient(nil) + NopClient = nopClient(nil, tally.NoopScope) _ Client = &nop{} ) @@ -32,7 +38,7 @@ var ( type nop struct { } -func nopClient(info CreateAuthInfo) Client { +func nopClient(config config.Provider, scope tally.Scope) Client { return &nop{} } diff --git a/auth/doc.go b/auth/doc.go index cef756d18..c7a00a731 100644 --- a/auth/doc.go +++ b/auth/doc.go @@ -66,7 +66,7 @@ // // embed backend security service client here // } // -// func userAuthClient(info CreateAuthInfo) auth.Client { +// func userAuthClient(config config.Provider, scope tally.Scope) auth.Client { // return &userAuthClient{} // } // diff --git a/auth/localauth.go b/auth/localauth.go index dad36f814..eac9411e5 100644 --- a/auth/localauth.go +++ b/auth/localauth.go @@ -20,7 +20,13 @@ package auth -import "context" +import ( + "context" + + "go.uber.org/fx/config" + + "github.com/uber-go/tally" +) var _ Client = &defaultClient{} @@ -30,7 +36,7 @@ type defaultClient struct { // defaultAuth is a placeholder auth client when no auth client is registered // TODO(anup): add configurable authentication, whether a service needs one or not -func defaultAuth(info CreateAuthInfo) Client { +func defaultAuth(config config.Provider, scope tally.Scope) Client { return &defaultClient{ authClient: NopClient, } diff --git a/auth/uauth.go b/auth/uauth.go index 1fa185245..718475102 100644 --- a/auth/uauth.go +++ b/auth/uauth.go @@ -24,8 +24,9 @@ import ( "context" "sync" - "github.com/uber-go/tally" "go.uber.org/fx/config" + + "github.com/uber-go/tally" ) var ( @@ -44,14 +45,8 @@ var ( ErrAuthorization = "Error authorizing the service" ) -// CreateAuthInfo interface provides necessary data -type CreateAuthInfo interface { - Config() config.Provider - Metrics() tally.Scope -} - // RegisterFunc is used during service init time to register the Auth client -type RegisterFunc func(info CreateAuthInfo) Client +type RegisterFunc func(config config.Provider, scope tally.Scope) Client // RegisterClient sets up the registerFunc for Auth client initialization func RegisterClient(registerFunc RegisterFunc) { @@ -71,11 +66,11 @@ func UnregisterClient() { } // Load returns a Client instance based on registered auth client implementation -func Load(info CreateAuthInfo) Client { +func Load(config config.Provider, scope tally.Scope) Client { _setupMu.Lock() defer _setupMu.Unlock() if _registerFunc != nil { - return _registerFunc(info) + return _registerFunc(config, scope) } return NopClient } diff --git a/auth/uauth_test.go b/auth/uauth_test.go index fbc7c73aa..f1a3b5297 100644 --- a/auth/uauth_test.go +++ b/auth/uauth_test.go @@ -24,15 +24,12 @@ import ( "context" "testing" - "go.uber.org/fx/config" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/uber-go/tally" - "go.uber.org/zap" ) -func withAuthClientSetup(t *testing.T, registerFunc RegisterFunc, info CreateAuthInfo, fn func()) { +func withAuthClientSetup(t *testing.T, registerFunc RegisterFunc, fn func()) { UnregisterClient() RegisterClient(registerFunc) fn() @@ -40,7 +37,7 @@ func withAuthClientSetup(t *testing.T, registerFunc RegisterFunc, info CreateAut func TestUauth_Stub(t *testing.T) { RegisterClient(defaultAuth) - authClient := Load(fakeAuthInfo{}) + authClient := Load(nil, tally.NoopScope) assert.Equal(t, "auth", authClient.Name()) assert.NotNil(t, authClient) assert.Nil(t, authClient.Authorize(context.Background())) @@ -52,8 +49,8 @@ func TestUauth_Stub(t *testing.T) { } func TestUauth_Register(t *testing.T) { - withAuthClientSetup(t, FakeFailureClient, fakeAuthInfo{}, func() { - authClient := Load(fakeAuthInfo{}) + withAuthClientSetup(t, FakeFailureClient, func() { + authClient := Load(nil, tally.NoopScope) assert.Equal(t, "failure", authClient.Name()) assert.NotNil(t, authClient) err := authClient.Authorize(context.Background()) @@ -67,7 +64,7 @@ func TestUauth_Register(t *testing.T) { } func TestUauth_RegisterPanic(t *testing.T) { - withAuthClientSetup(t, FakeFailureClient, nil, func() { + withAuthClientSetup(t, FakeFailureClient, func() { assert.Panics(t, func() { RegisterClient(FakeFailureClient) }) @@ -75,21 +72,7 @@ func TestUauth_RegisterPanic(t *testing.T) { } func TestUauth_Default(t *testing.T) { - withAuthClientSetup(t, nil, fakeAuthInfo{}, func() { - assert.Equal(t, "nop", Load(fakeAuthInfo{}).Name()) + withAuthClientSetup(t, nil, func() { + assert.Equal(t, "nop", Load(nil, tally.NoopScope).Name()) }) } - -type fakeAuthInfo struct{} - -func (fakeAuthInfo) Config() config.Provider { - return nil -} - -func (fakeAuthInfo) Logger() *zap.Logger { - return zap.NewNop() -} - -func (fakeAuthInfo) Metrics() tally.Scope { - return tally.NoopScope -} diff --git a/config/README.md b/config/README.md index 8895611cf..77ef9d136 100644 --- a/config/README.md +++ b/config/README.md @@ -15,7 +15,7 @@ The configuration system wraps a set of _providers_ that each know how to get values from an underlying source: * Static YAML configuration -* Environment variables +* Command-line flags So by stacking these providers, we can have a priority system for defining configuration that can be overridden by higher priority providers. For example, @@ -65,14 +65,15 @@ fmt.Println("Port is", target.Port) // "Port is 8081" ``` This model respects priority of providers to allow overriding of individual -values. +values. Read [Loading Configuration](#Loading-Configuration) section for more details +about the loading process. ## Provider `Provider` is the interface for anything that can provide values. We provide a few reference implementations (environment and YAML), but you are -free to register your own providers via `config.RegisterProviders()` and -`config.RegisterDynamicProviders`. +free to register your own providers via `RegisterProviders()` and +`RegisterDynamicProviders()`. ### Static configuration providers @@ -80,7 +81,7 @@ Static configuration providers conform to the `Provider` interface and are bootstrapped first. Use these for simple providers such as file-backed or environment-based configuration providers. -### Dynamic Configuration Providers +### Dynamic configuration providers Dynamic configuration providers frequently need some bootstrap configuration to be useful, so UberFx treats them specially. Dynamic configuration providers @@ -117,7 +118,7 @@ fmt.Println(foo) // Output: hello ``` -To get an access to the root element use `config.Root`: +To get an access to the root element use `Root`: ```go root := provider.Get(config.Root).AsString() @@ -138,8 +139,8 @@ If the underlying value cannot be converted to the requested type, `As*` will ## Populate `Populate` is akin to `json.Unmarshal()` in that it takes a pointer to a -custom struct and fills in the fields. It returns a `true` if the requested -fields were found and populated properly, and `false` otherwise. +custom struct or any other type and fills in the fields. It returns an error, +if the requested fields were not populated properly. For example, say we have the following YAML file: @@ -166,34 +167,252 @@ fmt.Println(m.World) Note that any fields you wish to deserialize into must be exported, just like `json.Unmarshal` and friends. +## Environment variables + +The YAML provider supports accepting values from the environment in which the process +runs. For example, consider the following YAML file: + +```yaml +modules: + http: + port: ${HTTP_PORT:3001} +``` + +When it loads the file, the YAML provider looks up the `HTTP_PORT` environment +variable and checks for a value to use. If the YAML provider doesn't find a value, +it uses the provided 3001 default. + +## Command-line arguments + +The command-line provider is a static provider that reads flags passed to a +program and wraps them in the `Provider` interface. Dots in flag names act +as separators for nested values (read about dotted notation in the +[Dynamic configuration providers](Dynamic-configuration-providers) section above). +Commas indicate to the provider that the flag value is an array of values. +For example, command `./service --roles=actor,writer` will set roles to a slice +with two values `[]string{"actor","writer"}`. + +Use the `pflag.CommandLine` global variable to define your own flags: + +```go +type Wonka struct { + Source string + Array []string +} + +type Willy struct { + Name Wonka +} + +func main() { + pflag.CommandLine.String("Name.Source", "default value", "String example") + pflag.CommandLine.Var( + &config.StringSlice{}, + "Name.Array", + "Example of a nested array") + + var v Willy + config.DefaultLoader.Load().Get(config.Root).Populate(&v) + log.Println(v) +} +``` + +If you run this program with arguments +`./main --Name.Source=chocolateFactory --Name.Array=cookie,candy`, it will print +`{{chocolateFactory [cookie candy]}}` + +## Testing + +The `Provider` interface makes unit testing easy. You can use the config +that came loaded with your service or mock it with a static provider. For example, +let's create a calculator type that does operations with two arguments: + +```go +// Operation is a simple binary function. +type Operation func(left, right int) int + +// Calculator evaluates operation Op on its Left and Right fields. +type Calculator struct { + Left int + Right int + Op Operation +} + +func (c Calculator) Eval() int { + return c.Op(c.Left, c.Right) +} +``` + +The calculator constructor needs only `Provider` and it loads configuration from +the root: + +```go +func NewCalculator(cfg Provider) (*Calculator, error){ + calc := &Calculator{} + return calc, cfg.Get(Root).Populate(calc) +} +``` + +`Operation` has a function type, but we can make it configurable. In order for +a provider to know how to deserialize it, `Operation` type needs to implement the +`text.Unmarshaller` interface: + +```go +func (o *Operation) UnmarshalText(text []byte) error { + switch s := string(text); s { + case "+": + *o = func(left, right int) int { return left + right } + case "-": + *o = func(left, right int) int { return left - right } + default: + return fmt.Errorf("unknown operation %q", s) + } + + return nil +} +``` + +To test with a static provider will be easy, define all arguments with the +expected results: + +```go +func TestCalculator_Eval(t *testing.T) { + t.Parallel() + + table := map[string]Provider{ + "1+2": NewStaticProvider(map[string]string{ + "Op": "+", "Left": "1", "Right": "2", "Expected": "3"}), + "1-2": NewStaticProvider(map[string]string{ + "Op": "-", "Left": "2", "Right": "1", "Expected": "1"}), + } + + for name, cfg := range table { + t.Run(name, func(t *testing.T) { + calc, err := NewCalculator(cfg) + require.NoError(t, err) + assert.Equal(t, cfg.Get("Expected").AsInt(), calc.Eval()) + }) + } +} +``` + +Don't forget to test the error path: + +```go +func TestCalculator_Errors(t *testing.T) { + t.Parallel() + + _, err := newCalculator(NewStaticProvider(map[string]string{ + "Op": "*", "Left": "3", "Right": "5" + })) + + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown operation") +} +``` + +For integration/E2E testing you can customize `Loader` to load the +configuration files from either custom folders (`Loader.SetDirs()`) +or custom files (`Loader.SetFiles()`), or you can register providers +on top of the existing providers (`Loader.RegisterProviders()`) that will +override values of the default configs. + +## Utilities + +The `config` package comes with several helpers for writing tests, creating +new providers, and amending existing providers. + +* `NewCachedProvider(p Provider)` returns a new provider that wraps `p` + and caches values in underlying map. It also registers callbacks to track + changes in all cached values, so you can call `cached.Get("something")` + without worrying about latency. It is safe for concurrent use by + multiple goroutines. + +* The `MockDynamicProvider` is a mock provider that can be used to test dynamic + features. It implements `Provider` interface and lets you set values + to trigger change callbacks. + +* Sometimes dynamic providers only let you register one callback per key. + If you want to have multiple keys per callback, use the + `NewMultiCallbackProvider(p Provider)` wrapper. It stores a list of + all callbacks for each value and calls them when a value changes. + **Caution**: provider is locked during callbacks execution, you should try to + make the callbacks as fast as possible. + +* `NopProvider` is useful for testing because it can be embedded in any type + if you are not interested in implementing all Provider methods. + +* `NewProviderGroup(name string, providers ...Provider)` groups providers into + one. Lookups for values are determined by the order providers passed: + + ```go + group := NewProviderGroup("global", provider1, provider2) + value := group.Get("X") + ``` + + The `group` provider checks `provider1` for "X" first. If there is no value, + it returns the result of `provider2.Get()`. + +* `NewStaticProvider(data interface{})` is a very useful wrapper for testing. + You can pass custom maps and use them as configs instead of loading them + from files. + +## Loading Configuration + +The load process is controlled by `Loader`. If a service doesn't +specify a config provider, `service.Manager` is going to use a provider +returned by `DefaultLoader.Load()`. + +The default loader creates static providers first: + +* YAML provider will look for `base.yaml` and `${environment}.yaml` files in + the current directory and then in the `./config` directory. You can override + directories to look for these files with `Loader.SetDirs()`. + To override file names, use `Loader.SetFiles()`. + +* The command-line provider looks for `--roles` argument to specify service + roles. Use `pflags.CommandLine` variable to introduce or override config + values before building a service. + +You can add more static providers on top of those mentioned above with +`RegisterProviders()` function: + +```go +config.DefaultLoader.RegisterProviders( + func() Provider, error { + return config.NewStaticProvider(map[string]int{"1+2": 3}) + } +) +``` + +After static providers are loaded, they are used to create dynamic providers. +You can add new dynamic providers in the loader with the `RegisterDynamicProviders()` +call as well. + +In the end all providers are grouped together using +`NewProviderGroup("global", staticProviders, dynamicProviders)` and returned to +your service. + +If you only want a config, you don't need to build a service. You can use +`DefaultLoader.Load()` and get exactly the same config as `service.Config()`. + +The loader type is customizable, letting you write parallel tests easily. If you +don't want to use the `os.LookupEnv()` function to look for environment variables, +override it with your custom function: `DefaultLoader.SetLookupFn()`. + ### Benchmarks Current performance benchmark data: ``` -BenchmarkYAMLCreateSingleFile-8 119 allocs/op -BenchmarkYAMLCreateMultiFile-8 203 allocs/op +BenchmarkYAMLCreateSingleFile-8 117 allocs/op +BenchmarkYAMLCreateMultiFile-8 204 allocs/op BenchmarkYAMLSimpleGetLevel1-8 0 allocs/op BenchmarkYAMLSimpleGetLevel3-8 0 allocs/op BenchmarkYAMLSimpleGetLevel7-8 0 allocs/op -BenchmarkYAMLPopulate-8 16 allocs/op +BenchmarkYAMLPopulate-8 18 allocs/op BenchmarkYAMLPopulateNested-8 42 allocs/op BenchmarkYAMLPopulateNestedMultipleFiles-8 52 allocs/op -BenchmarkYAMLPopulateNestedTextUnmarshaler-8 211 allocs/op -BenchmarkZapConfigLoad-8 188 allocs/op -``` - -## Environment Variables - -YAML provider supports accepting values from the environment. -For example, consider the following YAML file: - -```yaml -modules: - http: - port: ${HTTP_PORT:3001} -``` - -Upon loading file, YAML provider will look up the HTTP_PORT environment variable -and if available use it's value. If it's not found, the provided `3001` default -will be used. +BenchmarkYAMLPopulateNestedTextUnmarshaler-8 233 allocs/op +BenchmarkZapConfigLoad-8 136 allocs/op +``` \ No newline at end of file diff --git a/config/command_line_provider.go b/config/command_line_provider.go new file mode 100644 index 000000000..616cdef39 --- /dev/null +++ b/config/command_line_provider.go @@ -0,0 +1,109 @@ +// Copyright (c) 2017 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package config + +import ( + "fmt" + "strings" + + flag "github.com/ogier/pflag" +) + +// StringSlice is an alias to string slice, that is used to read comma separated flag values. +type StringSlice []string + +var _ flag.Value = (*StringSlice)(nil) + +// String returns slice elements separated by comma. +func (s *StringSlice) String() string { + return strings.Join(*s, ",") +} + +// Set splits val using comma as separators. +func (s *StringSlice) Set(val string) error { + *s = StringSlice(strings.Split(val, ",")) + return nil +} + +type commandLineProvider struct { + Provider +} + +// NewCommandLineProvider returns a Provider that is using command line parameters as config values. +// In order to address nested elements one can use dots in flag names which are considered separators. +// One can use StringSlice type to work with a list of comma separated strings. +func NewCommandLineProvider(flags *flag.FlagSet, args []string) Provider { + if err := flags.Parse(args); err != nil { + panic(err) + } + + m := make(map[string]interface{}) + flags.VisitAll(func(f *flag.Flag) { + prev, last := traversePath(m, f) + assignValues(prev, last, f.Value) + }) + + return commandLineProvider{Provider: NewStaticProvider(m)} +} + +// Assign values to a map element based on value type. +// If value is a StringSlice - create a new map and with keys - indices and values - StringSlice elements. +// Otherwise just assign it's string value. +func assignValues(m map[string]interface{}, key string, value flag.Value) { + if ss, ok := value.(*StringSlice); ok { + slice := []string(*ss) + tmp := make(map[string]interface{}, len(slice)) + m[key] = tmp + for i, str := range slice { + tmp[fmt.Sprint(i)] = str + } + + return + } + + m[key] = value.String() +} + +// Traverse map with the flag name used as path. +func traversePath(m map[string]interface{}, f *flag.Flag) (map[string]interface{}, string) { + curr, prev := m, m + path := strings.Split(f.Name, _separator) + for _, item := range path { + if _, ok := curr[item]; !ok { + curr[item] = map[string]interface{}{} + } + + prev = curr + if tmp, ok := curr[item].(map[string]interface{}); ok { + curr = tmp + } else { + // This should never happen, because pflag/flag sort flags before calling a visitor, + // but it is better to be safe then sorry. + curr = map[string]interface{}{} + } + } + + return prev, path[len(path)-1] +} + +func (commandLineProvider) Name() string { + return "cmd" +} diff --git a/config/command_line_provider_test.go b/config/command_line_provider_test.go new file mode 100644 index 000000000..364e7f1c8 --- /dev/null +++ b/config/command_line_provider_test.go @@ -0,0 +1,139 @@ +// Copyright (c) 2017 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package config + +import ( + "testing" + + flag "github.com/ogier/pflag" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCommandLineProvider_Roles(t *testing.T) { + t.Parallel() + + f := flag.NewFlagSet("", flag.PanicOnError) + var s StringSlice + f.Var(&s, "roles", "") + + c := NewCommandLineProvider(f, []string{`--roles=a,b,c"d"`}) + v := c.Get("roles") + require.True(t, v.HasValue()) + var roles []string + require.NoError(t, v.Populate(&roles)) + assert.Equal(t, []string{"a", "b", `c"d"`}, roles) +} + +func TestCommandLineProvider_Default(t *testing.T) { + t.Parallel() + + f := flag.NewFlagSet("", flag.PanicOnError) + f.String("killerFeature", "minesweeper", "Start->Games->Minesweeper") + + c := NewCommandLineProvider(f, nil) + v := c.Get("killerFeature") + require.True(t, v.HasValue()) + assert.Equal(t, "minesweeper", v.AsString()) +} + +func TestCommandLineProvider_Conversion(t *testing.T) { + t.Parallel() + + f := flag.NewFlagSet("", flag.PanicOnError) + f.String("dozen", "14", " that number of rolls being allowed to the purchaser of a dozen") + + c := NewCommandLineProvider(f, []string{"--dozen=13"}) + v := c.Get("dozen") + require.True(t, v.HasValue()) + assert.Equal(t, 13, v.AsInt()) +} + +func TestCommandLineProvider_PanicOnUnknownFlags(t *testing.T) { + t.Parallel() + + assert.Panics(t, func() { + NewCommandLineProvider(flag.NewFlagSet("", flag.ContinueOnError), []string{"--boom"}) + }) +} + +func TestCommandLineProvider_Name(t *testing.T) { + t.Parallel() + p := NewCommandLineProvider(flag.NewFlagSet("", flag.PanicOnError), nil) + assert.Equal(t, "cmd", p.Name()) +} + +func TestCommandLineProvider_RepeatingArguments(t *testing.T) { + t.Parallel() + + f := flag.NewFlagSet("", flag.PanicOnError) + f.Int("count", 1, "If I had a million dollars") + + c := NewCommandLineProvider(f, []string{"--count=2", "--count=3"}) + v := c.Get("count") + require.True(t, v.HasValue()) + assert.Equal(t, "3", v.AsString()) +} + +func TestCommandLineProvider_NestedValues(t *testing.T) { + t.Parallel() + + f := flag.NewFlagSet("", flag.PanicOnError) + f.String("Name.Source", "default", "Data provider source") + f.Var(&StringSlice{}, "Name.Array", "Example of a nested array") + + c := NewCommandLineProvider(f, []string{"--Name.Source=chocolateFactory", "--Name.Array=cookie, candy,brandy"}) + type Wonka struct { + Source string + Array []string + } + + type Willy struct { + Name Wonka + } + + g := NewProviderGroup("group", NewStaticProvider(Willy{Name: Wonka{Source: "staticFactory"}}), c) + var v Willy + require.NoError(t, g.Get(Root).Populate(&v)) + assert.Equal(t, Willy{Name: Wonka{ + Source: "chocolateFactory", + Array: []string{"cookie", " candy", "brandy"}, + }}, v) +} + +func TestCommandLineProvider_OverlappingFlags(t *testing.T) { + t.Parallel() + + f := flag.NewFlagSet("", flag.PanicOnError) + f.String("Sushi.Tools.1", "Saibashi", "Chopsticks are extremely helpful!") + f.Var(&StringSlice{}, "Sushi.Tools", "yolo") + + c := NewCommandLineProvider(f, []string{"--Sushi.Tools.1=Fork", "--Sushi.Tools=Hocho, Hashi"}) + type Sushi struct { + Tools []string + } + + var v Sushi + require.NoError(t, c.Get("Sushi").Populate(&v)) + assert.Equal(t, Sushi{ + Tools: []string{"Hocho", "Fork"}, + }, v) +} diff --git a/config/config.go b/config/config.go index 723f6f900..6e15b7b7b 100644 --- a/config/config.go +++ b/config/config.go @@ -26,49 +26,82 @@ import ( "path" "path/filepath" "sync" + + flag "github.com/ogier/pflag" ) const ( - // ServiceNameKey is the config key of the service name + // ServiceNameKey is the config key of the service name. ServiceNameKey = "name" + // ServiceDescriptionKey is the config key of the service - // description + // description. ServiceDescriptionKey = "description" - // ServiceOwnerKey is the config key for a service owner + + // ServiceOwnerKey is the config key for a service owner. ServiceOwnerKey = "owner" ) const ( - _appRoot = "APP_ROOT" + _appRoot = "_ROOT" _environment = "_ENVIRONMENT" _configDir = "_CONFIG_DIR" - _configRoot = "./config" - _baseFile = "base" - _secretsFile = "secrets" + _baseFile = "base.yaml" + _secretsFile = "secrets.yaml" + _devEnv = "development" ) -var ( - _setupMux sync.Mutex +type lookUpFunc func(string) (string, bool) - _envPrefix = "APP" - _staticProviderFuncs = []ProviderFunc{YamlProvider()} - _configFiles = baseFiles() - _dynamicProviderFuncs []DynamicProviderFunc -) +// Loader is responsible for loading config providers. +type Loader struct { + lock sync.RWMutex -var ( - _devEnv = "development" -) + envPrefix string + staticProviderFuncs []ProviderFunc + dynamicProviderFuncs []DynamicProviderFunc + + // Files to load. + configFiles []string + + // Dirs to load from. + dirs []string + + // Where to look for environment variables. + lookUp lookUpFunc +} + +// DefaultLoader is going to be used by a service if config is not specified. +// First values are going to be looked in dynamic providers, then in command line provider +// and YAML provider is going to be the last. +var DefaultLoader = NewLoader(commandLineProviderFunc) + +// NewLoader returns a default Loader with providers overriding the YAML provider. +func NewLoader(providers ...ProviderFunc) *Loader { + l := &Loader{ + envPrefix: "APP", + dirs: []string{".", "./config"}, + lookUp: os.LookupEnv, + } + + l.configFiles = l.baseFiles() + // Order is important: we want users to be able to override static provider + l.RegisterProviders(l.YamlProvider()) + l.RegisterProviders(providers...) + + return l +} // AppRoot returns the root directory of your application. UberFx developers // can edit this via the APP_ROOT environment variable. If the environment // variable is not set then it will fallback to the current working directory. -func AppRoot() string { - if appRoot := os.Getenv(_appRoot); appRoot != "" { +func (l *Loader) AppRoot() string { + if appRoot, ok := l.lookUp(l.EnvironmentPrefix() + _appRoot); ok { return appRoot } + if cwd, err := os.Getwd(); err != nil { - panic(fmt.Sprintf("Unable to get the current working directory: %s", err.Error())) + panic(fmt.Sprintf("Unable to get the current working directory: %q", err.Error())) } else { return cwd } @@ -76,139 +109,150 @@ func AppRoot() string { // ResolvePath returns an absolute path derived from AppRoot and the relative path. // If the input parameter is already an absolute path it will be returned immediately. -func ResolvePath(relative string) (string, error) { +func (l *Loader) ResolvePath(relative string) (string, error) { if filepath.IsAbs(relative) { return relative, nil } - abs := path.Join(AppRoot(), relative) + + abs := path.Join(l.AppRoot(), relative) if _, err := os.Stat(abs); err != nil { return "", err } - return abs, nil -} - -func getConfigFiles(fileSet ...string) []string { - var files []string - dirs := []string{".", _configRoot} - for _, dir := range dirs { - for _, baseFile := range fileSet { - files = append(files, fmt.Sprintf("%s/%s.yaml", dir, baseFile)) - } - } - return files + return abs, nil } -func baseFiles() []string { - env := Environment() - return []string{_baseFile, env, _secretsFile} +func (l *Loader) baseFiles() []string { + return []string{_baseFile, l.Environment() + ".yaml", _secretsFile} } -func getResolver() FileResolver { - paths := []string{} - configDir := Path() - if configDir != "" { - paths = []string{configDir} - } - return NewRelativeResolver(paths...) +func (l *Loader) getResolver() FileResolver { + return NewRelativeResolver(l.Paths()...) } // YamlProvider returns function to create Yaml based configuration provider -func YamlProvider() ProviderFunc { +func (l *Loader) YamlProvider() ProviderFunc { return func() (Provider, error) { - return NewYAMLProviderFromFiles(false, getResolver(), getConfigFiles(_configFiles...)...), nil + return NewYAMLProviderFromFiles(false, l.getResolver(), l.getFiles()...), nil } } // Environment returns current environment setup for the service -func Environment() string { - env := os.Getenv(EnvironmentKey()) - if env == "" { - env = _devEnv +func (l *Loader) Environment() string { + if env, ok := l.lookUp(l.EnvironmentKey()); ok { + return env } - return env + + return _devEnv } -// Path returns path to the yaml configurations -func Path() string { - configPath := os.Getenv(EnvironmentPrefix() + _configDir) - if configPath == "" { - configPath = _configRoot +// Paths returns paths to the yaml configurations +func (l *Loader) Paths() []string { + if path, ok := l.lookUp(l.EnvironmentPrefix() + _configDir); ok { + return []string{path} } - return configPath + + return l.dirs +} + +// SetConfigFiles overrides the set of available config files for the service. +func (l *Loader) SetConfigFiles(files ...string) { + l.lock.Lock() + defer l.lock.Unlock() + + l.configFiles = files +} + +func (l *Loader) getFiles() []string { + l.lock.RLock() + defer l.lock.RUnlock() + + res := make([]string, len(l.configFiles)) + copy(res, l.configFiles) + return res } -// SetConfigFiles overrides the set of available config files -// for the service -func SetConfigFiles(files ...string) { - _configFiles = files +// SetDirs overrides the set of dirs to load config files from. +func (l *Loader) SetDirs(dirs ...string) { + l.lock.Lock() + defer l.lock.Unlock() + + l.dirs = dirs } -// SetEnvironmentPrefix sets environment prefix for the application -func SetEnvironmentPrefix(envPrefix string) { - _envPrefix = envPrefix +// SetEnvironmentPrefix sets environment prefix for the application. +func (l *Loader) SetEnvironmentPrefix(envPrefix string) { + l.lock.Lock() + defer l.lock.Unlock() + + l.envPrefix = envPrefix } -// EnvironmentPrefix returns environment prefix for the application -func EnvironmentPrefix() string { - return _envPrefix +// EnvironmentPrefix returns environment prefix for the application. +func (l *Loader) EnvironmentPrefix() string { + l.lock.RLock() + defer l.lock.RUnlock() + + return l.envPrefix } // EnvironmentKey returns environment variable key name -func EnvironmentKey() string { - return _envPrefix + _environment +func (l *Loader) EnvironmentKey() string { + l.lock.RLock() + defer l.lock.RUnlock() + + return l.envPrefix + _environment } -// ProviderFunc is used to create config providers on configuration initialization +// ProviderFunc is used to create config providers on configuration initialization. type ProviderFunc func() (Provider, error) -// DynamicProviderFunc is used to create config providers on configuration initialization +// DynamicProviderFunc is used to create config providers on configuration initialization. type DynamicProviderFunc func(config Provider) (Provider, error) -// RegisterProviders registers configuration providers for the global config -func RegisterProviders(providerFuncs ...ProviderFunc) { - _setupMux.Lock() - defer _setupMux.Unlock() - _staticProviderFuncs = append(_staticProviderFuncs, providerFuncs...) +// RegisterProviders registers configuration providers for the global config. +func (l *Loader) RegisterProviders(providerFuncs ...ProviderFunc) { + l.lock.Lock() + defer l.lock.Unlock() + + l.staticProviderFuncs = append(l.staticProviderFuncs, providerFuncs...) } // RegisterDynamicProviders registers dynamic config providers for the global config // Dynamic provider initialization needs access to Provider for accessing necessary // information for bootstrap, such as port number,keys, endpoints etc. -func RegisterDynamicProviders(dynamicProviderFuncs ...DynamicProviderFunc) { - _setupMux.Lock() - defer _setupMux.Unlock() - _dynamicProviderFuncs = append(_dynamicProviderFuncs, dynamicProviderFuncs...) -} +func (l *Loader) RegisterDynamicProviders(dynamicProviderFuncs ...DynamicProviderFunc) { + l.lock.Lock() + defer l.lock.Unlock() -// Providers should only be used during tests -func Providers() []ProviderFunc { - return _staticProviderFuncs + l.dynamicProviderFuncs = append(l.dynamicProviderFuncs, dynamicProviderFuncs...) } -// UnregisterProviders clears all the default providers -func UnregisterProviders() { - _setupMux.Lock() - defer _setupMux.Unlock() - _staticProviderFuncs = nil - _dynamicProviderFuncs = nil +// UnregisterProviders clears all the default providers. +func (l *Loader) UnregisterProviders() { + l.lock.Lock() + defer l.lock.Unlock() + + l.staticProviderFuncs = nil + l.dynamicProviderFuncs = nil } -// Load creates a Provider for use in a service -func Load() Provider { +// Load creates a Provider for use in a service. +func (l *Loader) Load() Provider { var static []Provider - - for _, providerFunc := range _staticProviderFuncs { + for _, providerFunc := range l.staticProviderFuncs { cp, err := providerFunc() if err != nil { panic(err) } + static = append(static, cp) } + baseCfg := NewProviderGroup("global", static...) - var dynamic = make([]Provider, 0, 2) - for _, providerFunc := range _dynamicProviderFuncs { + var dynamic []Provider + for _, providerFunc := range l.dynamicProviderFuncs { cp, err := providerFunc(baseCfg) if err != nil { panic(err) @@ -217,5 +261,20 @@ func Load() Provider { dynamic = append(dynamic, cp) } } + return NewProviderGroup("global", append(static, dynamic...)...) } + +// SetLookupFn sets the lookup function to get environment variables. +func (l *Loader) SetLookupFn(fn func(string) (string, bool)) { + l.lock.Lock() + defer l.lock.Unlock() + + l.lookUp = fn +} + +func commandLineProviderFunc() (Provider, error) { + var s StringSlice + flag.CommandLine.Var(&s, "roles", "") + return NewCommandLineProvider(flag.CommandLine, os.Args[1:]), nil +} diff --git a/config/config_test.go b/config/config_test.go index 646b69bf8..0d752b662 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -21,13 +21,15 @@ package config import ( + "bytes" + "errors" "fmt" + "io/ioutil" "os" "path" + "sync" "testing" - "go.uber.org/fx/testutils/env" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -96,11 +98,19 @@ type arrayOfStructs struct { } func TestGlobalConfig(t *testing.T) { - SetEnvironmentPrefix("TEST") - cfg := Load() + t.Parallel() + + l := NewLoader() + l.lookUp = func(string) (string, bool) { + return "", false + } + + l.SetConfigFiles("base", "development") + l.SetEnvironmentPrefix("TEST") + cfg := l.Load() assert.Equal(t, "global", cfg.Name()) - assert.Equal(t, "development", Environment()) + assert.Equal(t, "development", l.Environment()) cfg = NewProviderGroup("test", NewYAMLProviderFromBytes([]byte(`name: sample`))) assert.Equal(t, "test", cfg.Name()) @@ -203,6 +213,8 @@ func TestGetAsFloatValue(t *testing.T) { } func TestNestedStructs(t *testing.T) { + t.Parallel() + provider := NewProviderGroup( "test", NewYAMLProviderFromBytes(nestedYaml), @@ -225,6 +237,8 @@ func TestNestedStructs(t *testing.T) { } func TestArrayOfStructs(t *testing.T) { + t.Parallel() + provider := NewProviderGroup( "test", NewYAMLProviderFromBytes(structArrayYaml), @@ -241,6 +255,8 @@ func TestArrayOfStructs(t *testing.T) { } func TestDefault(t *testing.T) { + t.Parallel() + provider := NewProviderGroup( "test", NewYAMLProviderFromBytes(nest1), @@ -266,64 +282,48 @@ boolean: } func TestRegisteredProvidersInitialization(t *testing.T) { - RegisterProviders(StaticProvider(map[string]interface{}{ + t.Parallel() + + l := NewLoader() + l.RegisterProviders(StaticProvider(map[string]interface{}{ "hello": "world", })) - RegisterDynamicProviders(func(dynamic Provider) (Provider, error) { + + l.RegisterDynamicProviders(func(dynamic Provider) (Provider, error) { return NewStaticProvider(map[string]interface{}{ "dynamic": "provider", }), nil }) - cfg := Load() + + cfg := l.Load() assert.Equal(t, "global", cfg.Name()) assert.Equal(t, "world", cfg.Get("hello").AsString()) assert.Equal(t, "provider", cfg.Get("dynamic").AsString()) - UnregisterProviders() - assert.Nil(t, _staticProviderFuncs) - assert.Nil(t, _dynamicProviderFuncs) + l.UnregisterProviders() + assert.Nil(t, l.staticProviderFuncs) + assert.Nil(t, l.dynamicProviderFuncs) } func TestNilProvider(t *testing.T) { - RegisterProviders(func() (Provider, error) { - return nil, fmt.Errorf("error creating Provider") + t.Parallel() + + l := NewLoader() + l.RegisterProviders(func() (Provider, error) { + return nil, errors.New("error creating Provider") }) - assert.Panics(t, func() { Load() }, "Can't initialize with nil provider") - oldProviders := _staticProviderFuncs - defer func() { - _staticProviderFuncs = oldProviders - }() + assert.Panics(t, func() { l.Load() }, "Can't initialize with nil provider") - UnregisterProviders() - RegisterProviders(func() (Provider, error) { + l.UnregisterProviders() + l.RegisterProviders(func() (Provider, error) { return nil, nil }) - // don't panic on Load - Load() - UnregisterProviders() - assert.Nil(t, _staticProviderFuncs) -} - -func TestGetConfigFiles(t *testing.T) { - SetEnvironmentPrefix("TEST") - - files := getConfigFiles(baseFiles()...) - assert.Contains(t, files, "./base.yaml") - assert.Contains(t, files, "./development.yaml") - assert.Contains(t, files, "./secrets.yaml") - assert.Contains(t, files, "./config/base.yaml") - assert.Contains(t, files, "./config/development.yaml") - assert.Contains(t, files, "./config/secrets.yaml") -} + // don't panic on Load + l.Load() -func TestSetConfigFiles(t *testing.T) { - SetConfigFiles("x", "y") - files := getConfigFiles(_configFiles...) - assert.Contains(t, files, "./x.yaml") - assert.Contains(t, files, "./y.yaml") - assert.Contains(t, files, "./config/x.yaml") - assert.Contains(t, files, "./config/y.yaml") + l.UnregisterProviders() + assert.Nil(t, l.staticProviderFuncs) } func expectedResolvePath(t *testing.T) string { @@ -333,26 +333,37 @@ func expectedResolvePath(t *testing.T) string { } func TestResolvePath(t *testing.T) { - res, err := ResolvePath("testdata") + t.Parallel() + + l := NewLoader() + + res, err := l.ResolvePath("testdata") assert.NoError(t, err) assert.Equal(t, expectedResolvePath(t), res) } func TestResolvePathInvalid(t *testing.T) { - res, err := ResolvePath("invalid") + t.Parallel() + + l := NewLoader() + res, err := l.ResolvePath("invalid") assert.Error(t, err) assert.Equal(t, "", res) } func TestResolvePathAbs(t *testing.T) { + t.Parallel() + + l := NewLoader() abs := expectedResolvePath(t) - res, err := ResolvePath(abs) + res, err := l.ResolvePath(abs) assert.NoError(t, err) assert.Equal(t, abs, res) } func TestNopProvider_Get(t *testing.T) { t.Parallel() + p := NopProvider{} assert.Equal(t, "NopProvider", p.Name()) assert.NoError(t, p.RegisterChangeCallback("key", nil)) @@ -439,7 +450,7 @@ ps: } func TestRPCPortField(t *testing.T) { - defer env.Override(t, "COMPANY_TCHANNEL_PORT", "4324")() + t.Parallel() type Port int type TChannelOutbound struct { @@ -473,10 +484,15 @@ rpc: host: 127.0.0.1 port: ${COMPANY_TCHANNEL_PORT:321} ` - p := NewProviderGroup( - "test", - NewYAMLProviderFromBytes([]byte(rpc)), - ) + lookup := func(key string) (string, bool) { + if key == "COMPANY_TCHANNEL_PORT" { + return "4324", true + } + + return "", false + } + + p := newYAMLProviderCore(lookup, ioutil.NopCloser(bytes.NewBufferString(rpc))) cfg := &YARPCConfig{} v := p.Get("rpc") @@ -484,3 +500,122 @@ rpc: require.NoError(t, v.Populate(cfg)) require.Equal(t, 4324, int(*cfg.Outbounds[0].TChannel.Port)) } + +func TestLoader_Environment(t *testing.T) { + t.Parallel() + + l := NewLoader() + l.SetLookupFn(func(key string) (string, bool) { + require.Equal(t, "APP_ENVIRONMENT", key) + return "KGBeast", true + }) + + assert.Equal(t, "KGBeast", l.Environment()) +} + +func TestLoader_AppRoot(t *testing.T) { + t.Parallel() + + l := NewLoader() + l.SetLookupFn(func(key string) (string, bool) { + require.Equal(t, "APP_ROOT", key) + return "Harley Quinn", true + }) + + assert.Equal(t, "Harley Quinn", l.AppRoot()) +} + +func TestLoader_LoadPanicOnDynamicError(t *testing.T) { + t.Parallel() + + l := NewLoader() + l.RegisterDynamicProviders(func(config Provider) (Provider, error) { return nil, errors.New("something scary") }) + + assert.Panics(t, func() { l.Load() }) +} + +func withBase(t *testing.T, f func(dir string), contents string) { + dir, err := ioutil.TempDir("", "TestLoader_Dirs") + require.NoError(t, err) + + defer func() { require.NoError(t, os.Remove(dir)) }() + + base, err := os.Create(fmt.Sprintf("%s/base.yaml", dir)) + require.NoError(t, err) + defer os.Remove(base.Name()) + + base.WriteString(contents) + base.Close() + + f(dir) +} + +func TestLoader_Dirs(t *testing.T) { + t.Parallel() + + f := func(dir string) { + l := NewLoader() + l.SetDirs(dir) + p := l.Load() + assert.Equal(t, "jocker", p.Get("vilain").String()) + } + + withBase(t, f, "vilain: jocker") +} + +func TestParallelLoad(t *testing.T) { + t.Parallel() + + l := NewLoader() + + f := func(dir string) { + l.SetDirs(dir) + p := l.Load() + assert.Equal(t, "bane", p.Get("vilain").String()) + } + + wg := sync.WaitGroup{} + wg.Add(2) + op := func() { + withBase(t, f, "vilain: bane") + wg.Done() + } + + go op() + go op() + + wg.Wait() +} + +func TestZeroInitializeLoader(t *testing.T) { + t.Parallel() + var l Loader + assert.NotPanics(t, func() { l.Load() }) +} + +func TestLoader_StaticProviderOrder(t *testing.T) { + t.Parallel() + f := func(dir string) { + l := NewLoader(func() (Provider, error) { + return NewStaticProvider(map[string]string{"value": "correct"}), nil + }) + + l.SetDirs(dir) + p := l.Load() + assert.Equal(t, "correct", p.Get("value").AsString()) + } + + withBase(t, f, "value: wrong") +} + +func TestLoader_LoadFromCurrentFolder(t *testing.T) { + t.Parallel() + f := func(dir string) { + l := NewLoader() + l.SetConfigFiles(dir + "/base.yaml") + p := l.Load() + assert.Equal(t, "base", p.Get("value").AsString()) + } + + withBase(t, f, "value: base") +} diff --git a/config/decoder.go b/config/decoder.go index 7d9810891..8f7b25d9b 100644 --- a/config/decoder.go +++ b/config/decoder.go @@ -386,14 +386,7 @@ func (d *decoder) iface(key string, value reflect.Value, def string) error { // Sets value to an object type. func (d *decoder) object(childKey string, value reflect.Value) error { - value = value.Addr() - - if value.IsNil() { - tmp := reflect.New(value.Type().Elem()) - value.Set(tmp) - } - - return d.valueStruct(childKey, value.Interface()) + return d.valueStruct(childKey, value.Addr().Interface()) } // Walk through the struct and start asking the providers for values at each key. @@ -408,7 +401,7 @@ func (d *decoder) valueStruct(key string, target interface{}) error { field := targetType.Field(i) // Check for the private field - if field.PkgPath != "" && !field.Anonymous { + if field.PkgPath != "" || field.Anonymous { continue } diff --git a/config/doc.go b/config/doc.go index 36f5f0b1c..b1d2d807c 100644 --- a/config/doc.go +++ b/config/doc.go @@ -39,7 +39,7 @@ // // • Static YAML configuration // -// • Environment variables +// • Command-line flags // // So by stacking these providers, we can have a priority system for defining // configuration that can be overridden by higher priority providers. For example, @@ -84,7 +84,9 @@ // fmt.Println("Port is", target.Port) // "Port is 8081" // // This model respects priority of providers to allow overriding of individual -// values. +// values. Read +// Loading Configuration (#Loading-Configuration) section for more details +// about the loading process. // // // Provider @@ -92,8 +94,8 @@ // Provider is the interface for anything that can provide values. // We provide a few reference implementations (environment and YAML), but you are // free to register your own providers via -// config.RegisterProviders() and -// config.RegisterDynamicProviders. +// RegisterProviders() and +// RegisterDynamicProviders(). // // Static configuration providers // @@ -102,7 +104,7 @@ // environment-based configuration providers. // // -// Dynamic Configuration Providers +// Dynamic configuration providers // // Dynamic configuration providers frequently need some bootstrap configuration to // be useful, so UberFx treats them specially. Dynamic configuration providers @@ -139,7 +141,7 @@ // fmt.Println(foo) // // Output: hello // -// To get an access to the root element use config.Root: +// To get an access to the root element use Root: // // root := provider.Get(config.Root).AsString() // fmt.Println(root) @@ -159,10 +161,9 @@ // Populate // // Populate is akin to json.Unmarshal() in that it takes a pointer to a -// custom struct and fills in the fields. It returns a -// true if the requested -// fields were found and populated properly, and -// false otherwise. +// custom struct or any other type and fills in the fields. It returns an error, +// if the requested fields were not populated properly. +// // // For example, say we have the following YAML file: // @@ -185,36 +186,267 @@ // Note that any fields you wish to deserialize into must be exported, just like // json.Unmarshal and friends. // +// Environment variables +// +// The YAML provider supports accepting values from the environment in which the process +// runs. For example, consider the following YAML file: +// +// +// modules: +// http: +// port: ${HTTP_PORT:3001} +// +// When it loads the file, the YAML provider looks up the HTTP_PORT environment +// variable and checks for a value to use. If the YAML provider doesn't find a value, +// it uses the provided 3001 default. +// +// +// Command-line arguments +// +// The command-line provider is a static provider that reads flags passed to a +// program and wraps them in the +// Provider interface. Dots in flag names act +// as separators for nested values (read about dotted notation in the +// Dynamic configuration providers (Dynamic-configuration-providers) section above). +// Commas indicate to the provider that the flag value is an array of values. +// For example, command +// ./service --roles=actor,writer will set roles to a slice +// with two values +// []string{"actor","writer"}. +// +// Use the pflag.CommandLine global variable to define your own flags: +// +// type Wonka struct { +// Source string +// Array []string +// } +// +// type Willy struct { +// Name Wonka +// } +// +// func main() { +// pflag.CommandLine.String("Name.Source", "default value", "String example") +// pflag.CommandLine.Var( +// &config.StringSlice{}, +// "Name.Array", +// "Example of a nested array") +// +// var v Willy +// config.DefaultLoader.Load().Get(config.Root).Populate(&v) +// log.Println(v) +// } +// +// If you run this program with arguments +// ./main --Name.Source=chocolateFactory --Name.Array=cookie,candy, it will print +// {{chocolateFactory [cookie candy]}} +// +// Testing +// +// The Provider interface makes unit testing easy. You can use the config +// that came loaded with your service or mock it with a static provider. For example, +// let's create a calculator type that does operations with two arguments: +// +// +// // Operation is a simple binary function. +// type Operation func(left, right int) int +// +// // Calculator evaluates operation Op on its Left and Right fields. +// type Calculator struct { +// Left int +// Right int +// Op Operation +// } +// +// func (c Calculator) Eval() int { +// return c.Op(c.Left, c.Right) +// } +// +// The calculator constructor needs only Provider and it loads configuration from +// the root: +// +// +// func NewCalculator(cfg Provider) (*Calculator, error){ +// calc := &Calculator{} +// return calc, cfg.Get(Root).Populate(calc) +// } +// +// Operation has a function type, but we can make it configurable. In order for +// a provider to know how to deserialize it, +// Operation type needs to implement the +// text.Unmarshaller interface: +// +// func (o *Operation) UnmarshalText(text []byte) error { +// switch s := string(text); s { +// case "+": +// *o = func(left, right int) int { return left + right } +// case "-": +// *o = func(left, right int) int { return left - right } +// default: +// return fmt.Errorf("unknown operation %q", s) +// } +// +// return nil +// } +// +// To test with a static provider will be easy, define all arguments with the +// expected results: +// +// +// func TestCalculator_Eval(t *testing.T) { +// t.Parallel() +// +// table := map[string]Provider{ +// "1+2": NewStaticProvider(map[string]string{ +// "Op": "+", "Left": "1", "Right": "2", "Expected": "3"}), +// "1-2": NewStaticProvider(map[string]string{ +// "Op": "-", "Left": "2", "Right": "1", "Expected": "1"}), +// } +// +// for name, cfg := range table { +// t.Run(name, func(t *testing.T) { +// calc, err := NewCalculator(cfg) +// require.NoError(t, err) +// assert.Equal(t, cfg.Get("Expected").AsInt(), calc.Eval()) +// }) +// } +// } +// +// Don't forget to test the error path:: +// +// func TestCalculator_Errors(t *testing.T) { +// t.Parallel() +// +// _, err := newCalculator(NewStaticProvider(map[string]string{ +// "Op": "*", "Left": "3", "Right": "5" +// })) +// +// require.Error(t, err) +// assert.Contains(t, err.Error(), `unknown operation "*"`) +// } +// +// For integration/E2E testing you can customize Loader to load the +// configuration files from either custom folders ( +// Loader.SetDirs()) +// or custom files ( +// Loader.SetFiles()), or you can register providers +// on top of the existing providers ( +// Loader.RegisterProviders()) that will +// override values of the default configs. +// +// +// Utilities +// +// The config package comes with several helpers for writing tests, creating +// new providers, and amending existing providers. +// +// +// • NewCachedProvider(p Provider) returns a new provider that wraps pand caches values in underlying map. It also registers callbacks to track +// changes in all cached values, so you can call +// cached.Get("something")without worrying about latency. It is safe for concurrent use by +// multiple goroutines. +// +// +// • The MockDynamicProvider is a mock provider that can be used to test dynamic +// features. It implements +// Provider interface and lets you set values +// to trigger change callbacks. +// +// +// • Sometimes dynamic providers only let you register one callback per key. +// If you want to have multiple keys per callback, use the +// NewMultiCallbackProvider(p Provider) wrapper. It stores a list of +// all callbacks for each value and calls them when a value changes. +// **Caution**: provider is locked during callbacks execution, you should try to +// make the callbacks as fast as possible. +// +// +// • NopProvider is useful for testing because it can be embedded in any type +// if you are not interested in implementing all Provider methods. +// +// +// • NewProviderGroup(name string, providers ...Provider) groups providers into +// one. Lookups for values are determined by the order providers passed: +// +// +// group := NewProviderGroup("global", provider1, provider2) +// value := group.Get("X") +// +// The group provider checks provider1 for "X" first. If there is no value, +// it returns the result of +// provider2.Get(). +// +// • NewStaticProvider(data interface{}) is a very useful wrapper for testing. +// You can pass custom maps and use them as configs instead of loading them +// from files. +// +// +// Loading Configuration +// +// The load process is controlled by Loader. If a service doesn't +// specify a config provider, +// service.Manager is going to use a provider +// returned by +// DefaultLoader.Load(). +// +// The default loader creates static providers first: +// +// • YAML provider will look for base.yaml and ${environment}.yaml files in +// the current directory and then in the +// ./config directory. You can override +// directories to look for these files with +// Loader.SetDirs(). +// To override file names, use +// Loader.SetFiles(). +// +// • The command-line provider looks for --roles argument to specify service +// roles. Use +// pflags.CommandLine variable to introduce or override config +// values before building a service. +// +// +// You can add more static providers on top of those mentioned above with +// RegisterProviders() function: +// +// config.DefaultLoader.RegisterProviders( +// func() Provider, error { +// return config.NewStaticProvider(map[string]int{"1+2": 3}) +// } +// ) +// +// After static providers are loaded, they are used to create dynamic providers. +// You can add new dynamic providers in the loader with the +// RegisterDynamicProviders()call as well. +// +// +// In the end all providers are grouped together using +// NewProviderGroup("global", staticProviders, dynamicProviders) and returned to +// your service. +// +// +// If you only want a config, you don't need to build a service. You can use +// DefaultLoader.Load() and get exactly the same config as service.Config(). +// +// The loader type is customizable, letting you write parallel tests easily. If you +// don't want to use the +// os.LookupEnv() function to look for environment variables, +// override it with your custom function: +// DefaultLoader.SetLookupFn(). +// // Benchmarks // // Current performance benchmark data: // -// BenchmarkYAMLCreateSingleFile-8 119 allocs/op -// BenchmarkYAMLCreateMultiFile-8 203 allocs/op +// BenchmarkYAMLCreateSingleFile-8 117 allocs/op +// BenchmarkYAMLCreateMultiFile-8 204 allocs/op // BenchmarkYAMLSimpleGetLevel1-8 0 allocs/op // BenchmarkYAMLSimpleGetLevel3-8 0 allocs/op // BenchmarkYAMLSimpleGetLevel7-8 0 allocs/op -// BenchmarkYAMLPopulate-8 16 allocs/op +// BenchmarkYAMLPopulate-8 18 allocs/op // BenchmarkYAMLPopulateNested-8 42 allocs/op // BenchmarkYAMLPopulateNestedMultipleFiles-8 52 allocs/op -// BenchmarkYAMLPopulateNestedTextUnmarshaler-8 211 allocs/op -// BenchmarkZapConfigLoad-8 188 allocs/op -// -// Environment Variables -// -// YAML provider supports accepting values from the environment. -// For example, consider the following YAML file: -// -// -// modules: -// http: -// port: ${HTTP_PORT:3001} -// -// Upon loading file, YAML provider will look up the HTTP_PORT environment variable -// and if available use it's value. If it's not found, the provided -// 3001 default -// will be used. -// +// BenchmarkYAMLPopulateNestedTextUnmarshaler-8 233 allocs/op +// BenchmarkZapConfigLoad-8 136 allocs/op // // package config diff --git a/config/file_resolver.go b/config/file_resolver.go index 59e63c66b..51890d44d 100644 --- a/config/file_resolver.go +++ b/config/file_resolver.go @@ -39,14 +39,7 @@ type RelativeResolver struct { // NewRelativeResolver returns a file resolver relative to the given paths func NewRelativeResolver(paths ...string) FileResolver { pathList := make([]string, len(paths)) - copy(pathList, paths) - - pathList = append(pathList, AppRoot()) - - // add the exe dir - pathList = append(pathList, path.Dir(os.Args[0])) - return &RelativeResolver{ paths: pathList, } diff --git a/config/provider_group.go b/config/provider_group.go index 521a46441..22b84423d 100644 --- a/config/provider_group.go +++ b/config/provider_group.go @@ -25,23 +25,20 @@ type providerGroup struct { providers []Provider } -// NewProviderGroup creates a configuration provider from a group of backends +// NewProviderGroup creates a configuration provider from a group of backends. +// The highest priority provider is the last. func NewProviderGroup(name string, providers ...Provider) Provider { - group := providerGroup{ - name: name, + l := len(providers) + p := providerGroup{ + name: name, + providers: make([]Provider, l), } - for _, provider := range providers { - group.providers = append([]Provider{provider}, group.providers...) - } - return group -} -// WithProvider updates the current Provider -func (p providerGroup) WithProvider(provider Provider) Provider { - return providerGroup{ - name: p.name, - providers: append([]Provider{provider}, p.providers...), + for i := 0; i < l; i++ { + p.providers[i] = providers[l-i-1] } + + return p } func (p providerGroup) Get(key string) Value { diff --git a/config/provider_group_test.go b/config/provider_group_test.go index d1c2e26f7..99b3c1240 100644 --- a/config/provider_group_test.go +++ b/config/provider_group_test.go @@ -47,8 +47,7 @@ func TestProviderGroupScope(t *testing.T) { func TestCallbacks_WithDynamicProvider(t *testing.T) { t.Parallel() data := map[string]interface{}{"hello.world": 42} - mock := NewProviderGroup("with-dynamic", NewStaticProvider(data)) - mock = mock.(providerGroup).WithProvider(NewMockDynamicProvider(data)) + mock := NewProviderGroup("with-dynamic", NewStaticProvider(data), NewMockDynamicProvider(data)) assert.Equal(t, "with-dynamic", mock.Name()) require.NoError(t, mock.RegisterChangeCallback("mockcall", nil)) @@ -65,7 +64,6 @@ func TestCallbacks_WithoutDynamicProvider(t *testing.T) { t.Parallel() data := map[string]interface{}{"hello.world": 42} mock := NewProviderGroup("with-dynamic", NewStaticProvider(data)) - mock = mock.(providerGroup).WithProvider(NewStaticProvider(data)) assert.Equal(t, "with-dynamic", mock.Name()) assert.NoError(t, mock.RegisterChangeCallback("mockcall", nil)) assert.NoError(t, mock.UnregisterChangeCallback("mock")) diff --git a/config/static_provider.go b/config/static_provider.go index 45fd9fd38..f0c7845e1 100644 --- a/config/static_provider.go +++ b/config/static_provider.go @@ -34,7 +34,7 @@ func NewStaticProvider(data interface{}) Provider { panic(err) } - return staticProvider{NewYAMLProviderFromBytes(b)} + return staticProvider{Provider: NewYAMLProviderFromBytes(b)} } // StaticProvider returns function to create StaticProvider during configuration initialization @@ -47,5 +47,3 @@ func StaticProvider(data interface{}) ProviderFunc { func (staticProvider) Name() string { return "static" } - -var _ Provider = &staticProvider{} diff --git a/config/static_provider_test.go b/config/static_provider_test.go index 05ebaeb90..1eacf0dfb 100644 --- a/config/static_provider_test.go +++ b/config/static_provider_test.go @@ -208,3 +208,13 @@ func TestPopulateForNestedMaps(t *testing.T) { assert.Contains(t, err.Error(), `empty map key is ambigious`) assert.Contains(t, err.Error(), `a.`) } + +func TestPopulateNonPointerType(t *testing.T) { + t.Parallel() + + p := NewStaticProvider(42) + x := 13 + err := p.Get(Root).Populate(x) + require.Error(t, err) + assert.Contains(t, err.Error(), "can't populate non pointer type") +} diff --git a/config/testdata/secrets.yaml b/config/testdata/secrets.yaml index 485ec4a7e..21dc509b2 100644 --- a/config/testdata/secrets.yaml +++ b/config/testdata/secrets.yaml @@ -1 +1 @@ -secret: my_secret +secret: my_${secret} diff --git a/config/value.go b/config/value.go index 07e36e1e6..a1bf281b9 100644 --- a/config/value.go +++ b/config/value.go @@ -295,6 +295,10 @@ func convertValue(value interface{}, targetType reflect.Type) (interface{}, erro // Populate fills in an object from configuration func (cv Value) Populate(target interface{}) error { + if reflect.TypeOf(target).Kind() != reflect.Ptr { + return fmt.Errorf("can't populate non pointer type %T", target) + } + d := decoder{Value: &cv, m: make(map[interface{}]struct{})} return d.unmarshal(cv.key, reflect.Indirect(reflect.ValueOf(target)), "") diff --git a/config/yaml.go b/config/yaml.go index 4b09a2a25..87a9f72a3 100644 --- a/config/yaml.go +++ b/config/yaml.go @@ -43,11 +43,11 @@ var ( _emptyDefault = `""` ) -func newYAMLProviderCore(files ...io.ReadCloser) Provider { +func newYAMLProviderCore(lookUp lookUpFunc, files ...io.ReadCloser) Provider { var root interface{} for _, v := range files { var curr interface{} - if err := unmarshalYAMLValue(v, &curr); err != nil { + if err := unmarshalYAMLValue(v, &curr, lookUp); err != nil { if file, ok := v.(*os.File); ok { panic(errors.Wrapf(err, "in file: %q", file.Name())) } @@ -132,13 +132,13 @@ func NewYAMLProviderFromFiles(mustExist bool, resolver FileResolver, files ...st } } - return NewCachedProvider(newYAMLProviderCore(readers...)) + return NewCachedProvider(newYAMLProviderCore(os.LookupEnv, readers...)) } // NewYAMLProviderFromReader creates a configuration provider from a list of `io.ReadClosers`. // As above, all the objects are going to be merged and arrays/values overridden in the order of the files. func NewYAMLProviderFromReader(readers ...io.ReadCloser) Provider { - return NewCachedProvider(newYAMLProviderCore(readers...)) + return NewCachedProvider(newYAMLProviderCore(os.LookupEnv, readers...)) } // NewYAMLProviderFromBytes creates a config provider from a byte-backed YAML blobs. @@ -149,7 +149,7 @@ func NewYAMLProviderFromBytes(yamls ...[]byte) Provider { closers[i] = ioutil.NopCloser(bytes.NewReader(yml)) } - return NewCachedProvider(newYAMLProviderCore(closers...)) + return NewCachedProvider(newYAMLProviderCore(os.LookupEnv, closers...)) } func (y yamlConfigProvider) getNode(key string) *yamlNode { @@ -282,15 +282,26 @@ func (n *yamlNode) Children() []*yamlNode { return nodes } -func unmarshalYAMLValue(reader io.ReadCloser, value interface{}) error { +func unmarshalYAMLValue(reader io.ReadCloser, value interface{}, lookUp lookUpFunc) error { raw, err := ioutil.ReadAll(reader) if err != nil { return errors.Wrap(err, "failed to read the yaml config") } - data, err := interpolateEnvVars(raw) - if err != nil { - return errors.Wrap(err, "failed to interpolate environment variables") + var data []byte + skipInterpolate := false + if f, ok := reader.(*os.File); ok { + if strings.Contains(f.Name(), _secretsFile) { + skipInterpolate = true + data = raw + } + } + + if !skipInterpolate { + data, err = interpolateEnvVars(raw, lookUp) + if err != nil { + return errors.Wrap(err, "failed to interpolate environment variables") + } } if err = yaml.Unmarshal(data, value); err != nil { @@ -312,7 +323,7 @@ func unmarshalYAMLValue(reader io.ReadCloser, value interface{}) error { // will be used // // TODO: what if someone wanted a literal ${FOO} in config? need a small escape hatch -func interpolateEnvVars(data []byte) ([]byte, error) { +func interpolateEnvVars(data []byte, lookUp lookUpFunc) ([]byte, error) { // Is this conversion ok? str := string(data) errs := []string{} @@ -332,7 +343,7 @@ func interpolateEnvVars(data []byte) ([]byte, error) { def = in[sep+1:] } - if envVal, ok := os.LookupEnv(key); ok { + if envVal, ok := lookUp(key); ok { return envVal } diff --git a/config/yaml_test.go b/config/yaml_test.go index 660daeba8..e16b86d84 100644 --- a/config/yaml_test.go +++ b/config/yaml_test.go @@ -26,7 +26,6 @@ import ( "fmt" "io/ioutil" "os" - "path" "reflect" "testing" "time" @@ -167,17 +166,13 @@ func TestExtends(t *testing.T) { assert.Equal(t, "dev_setting", devValue) secretValue := provider.Get("secret").AsString() - assert.Equal(t, "my_secret", secretValue) + assert.Equal(t, "my_${secret}", secretValue) } func TestAppRoot(t *testing.T) { t.Parallel() - cwd, err := os.Getwd() - assert.NoError(t, err) - - defer env.Override(t, _appRoot, path.Join(cwd, "testdata"))() - provider := NewYAMLProviderFromFiles(false, NewRelativeResolver(), "base.yaml", "dev.yaml", "secrets.yaml") + provider := NewYAMLProviderFromFiles(false, NewRelativeResolver("testdata"), "base.yaml", "dev.yaml", "secrets.yaml") baseValue := provider.Get("value").AsString() assert.Equal(t, "base_only", baseValue) @@ -186,7 +181,7 @@ func TestAppRoot(t *testing.T) { assert.Equal(t, "dev_setting", devValue) secretValue := provider.Get("secret").AsString() - assert.Equal(t, "my_secret", secretValue) + assert.Equal(t, "my_${secret}", secretValue) } func TestNewYAMLProviderFromReader(t *testing.T) { @@ -203,7 +198,7 @@ func TestYAMLNode(t *testing.T) { t.Parallel() buff := bytes.NewBuffer([]byte("a: b")) node := &yamlNode{value: make(map[interface{}]interface{})} - err := unmarshalYAMLValue(ioutil.NopCloser(buff), &node.value) + err := unmarshalYAMLValue(ioutil.NopCloser(buff), &node.value, nil) require.NoError(t, err) assert.Equal(t, "map[a:b]", node.String()) assert.Equal(t, "map[interface {}]interface {}", node.Type().String()) @@ -214,7 +209,7 @@ func TestYamlNodeWithNil(t *testing.T) { provider := NewYAMLProviderFromFiles(false, nil) assert.NotNil(t, provider) assert.Panics(t, func() { - _ = unmarshalYAMLValue(nil, nil) + _ = unmarshalYAMLValue(nil, nil, nil) }, "Expected panic with nil inpout.") } @@ -841,3 +836,56 @@ func TestFileNameInPanic(t *testing.T) { NewYAMLProviderFromFiles(true, NewRelativeResolver(), f.Name()) } + +func TestYAMLName(t *testing.T) { + t.Parallel() + + p := NewYAMLProviderFromBytes([]byte(``)) + require.Contains(t, p.Name(), "yaml") +} + +func TestYAMLCallbacks(t *testing.T) { + t.Parallel() + + p := newYAMLProviderCore(nil, ioutil.NopCloser(bytes.NewBuffer(nil))) + require.Nil(t, p.RegisterChangeCallback("key", nil)) + require.Nil(t, p.UnregisterChangeCallback("key")) +} + +func TestAbsolutePaths(t *testing.T) { + t.Parallel() + + file, err := ioutil.TempFile("", "TestAbsolutePaths") + require.NoError(t, err) + file.WriteString("") + require.NoError(t, file.Close()) + defer func() { assert.NoError(t, os.Remove(file.Name())) }() + + p := NewYAMLProviderFromFiles(true, nil, file.Name()) + require.NotNil(t, p) + + val := p.Get("Imaginary") + assert.False(t, val.HasValue()) + assert.Equal(t, time.Time{}, val.LastUpdated()) +} + +func TestPrivateAnonymousField(t *testing.T) { + t.Parallel() + + type x struct { + field string + } + + type y struct { + x + } + + b := []byte(` +x: + field: something +`) + var z y + provider := NewYAMLProviderFromBytes(b) + require.NoError(t, provider.Get(Root).Populate(&z)) + assert.Empty(t, z.field) +} diff --git a/modules/uhttp/router.go b/examples/keyvalue/kv/idl.go similarity index 66% rename from modules/uhttp/router.go rename to examples/keyvalue/kv/idl.go index 64785b917..be97e305d 100644 --- a/modules/uhttp/router.go +++ b/examples/keyvalue/kv/idl.go @@ -1,3 +1,6 @@ +// Code generated by thriftrw v1.2.0 +// @generated + // Copyright (c) 2017 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy @@ -18,31 +21,10 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -package uhttp - -import ( - "net/http" - - "github.com/gorilla/mux" - - "go.uber.org/fx/service" -) - -// Router is wrapper around gorila mux -type Router struct { - mux.Router +package kv - host service.Host -} +import "go.uber.org/thriftrw/thriftreflect" -// NewRouter creates a new empty router -func NewRouter(host service.Host) *Router { - return &Router{ - host: host, - } -} +var ThriftModule = &thriftreflect.ThriftModule{Name: "kv", Package: "go.uber.org/fx/examples/keyvalue/kv", FilePath: "kv.thrift", SHA1: "9e8c1c30d0b6bd7d83426a92962269aaef706295", Raw: rawIDL} -// Handle wraps and calls the http.Handler underneath -func (h *Router) Handle(path string, handler http.Handler) { - h.Router.Handle(path, handler) -} +const rawIDL = "exception ResourceDoesNotExist {\n 1: required string key\n 2: optional string message\n}\n\nservice KeyValue {\n string getValue(1: string key)\n throws (1: ResourceDoesNotExist doesNotExist)\n void setValue(1: string key, 2: string value)\n}\n" diff --git a/examples/keyvalue/kv/keyvalue_getvalue.go b/examples/keyvalue/kv/keyvalue_getvalue.go index 96ca8c6b6..bf7d3f20c 100644 --- a/examples/keyvalue/kv/keyvalue_getvalue.go +++ b/examples/keyvalue/kv/keyvalue_getvalue.go @@ -1,4 +1,4 @@ -// Code generated by thriftrw v1.0.0 +// Code generated by thriftrw v1.2.0 // @generated // Copyright (c) 2017 Uber Technologies, Inc. @@ -71,6 +71,9 @@ func (v *KeyValue_GetValue_Args) FromWire(w wire.Value) error { } func (v *KeyValue_GetValue_Args) String() string { + if v == nil { + return "" + } var fields [1]string i := 0 if v.Key != nil { @@ -80,6 +83,13 @@ func (v *KeyValue_GetValue_Args) String() string { return fmt.Sprintf("KeyValue_GetValue_Args{%v}", strings.Join(fields[:i], ", ")) } +func (v *KeyValue_GetValue_Args) Equals(rhs *KeyValue_GetValue_Args) bool { + if !_String_EqualsPtr(v.Key, rhs.Key) { + return false + } + return true +} + func (v *KeyValue_GetValue_Args) MethodName() string { return "getValue" } @@ -210,6 +220,9 @@ func (v *KeyValue_GetValue_Result) FromWire(w wire.Value) error { } func (v *KeyValue_GetValue_Result) String() string { + if v == nil { + return "" + } var fields [2]string i := 0 if v.Success != nil { @@ -223,6 +236,16 @@ func (v *KeyValue_GetValue_Result) String() string { return fmt.Sprintf("KeyValue_GetValue_Result{%v}", strings.Join(fields[:i], ", ")) } +func (v *KeyValue_GetValue_Result) Equals(rhs *KeyValue_GetValue_Result) bool { + if !_String_EqualsPtr(v.Success, rhs.Success) { + return false + } + if !((v.DoesNotExist == nil && rhs.DoesNotExist == nil) || (v.DoesNotExist != nil && rhs.DoesNotExist != nil && v.DoesNotExist.Equals(rhs.DoesNotExist))) { + return false + } + return true +} + func (v *KeyValue_GetValue_Result) MethodName() string { return "getValue" } diff --git a/examples/keyvalue/kv/keyvalue_setvalue.go b/examples/keyvalue/kv/keyvalue_setvalue.go index 120d430ef..4ebd620ea 100644 --- a/examples/keyvalue/kv/keyvalue_setvalue.go +++ b/examples/keyvalue/kv/keyvalue_setvalue.go @@ -1,4 +1,4 @@ -// Code generated by thriftrw v1.0.0 +// Code generated by thriftrw v1.2.0 // @generated // Copyright (c) 2017 Uber Technologies, Inc. @@ -88,6 +88,9 @@ func (v *KeyValue_SetValue_Args) FromWire(w wire.Value) error { } func (v *KeyValue_SetValue_Args) String() string { + if v == nil { + return "" + } var fields [2]string i := 0 if v.Key != nil { @@ -101,6 +104,16 @@ func (v *KeyValue_SetValue_Args) String() string { return fmt.Sprintf("KeyValue_SetValue_Args{%v}", strings.Join(fields[:i], ", ")) } +func (v *KeyValue_SetValue_Args) Equals(rhs *KeyValue_SetValue_Args) bool { + if !_String_EqualsPtr(v.Key, rhs.Key) { + return false + } + if !_String_EqualsPtr(v.Value, rhs.Value) { + return false + } + return true +} + func (v *KeyValue_SetValue_Args) MethodName() string { return "setValue" } @@ -156,11 +169,18 @@ func (v *KeyValue_SetValue_Result) FromWire(w wire.Value) error { } func (v *KeyValue_SetValue_Result) String() string { + if v == nil { + return "" + } var fields [0]string i := 0 return fmt.Sprintf("KeyValue_SetValue_Result{%v}", strings.Join(fields[:i], ", ")) } +func (v *KeyValue_SetValue_Result) Equals(rhs *KeyValue_SetValue_Result) bool { + return true +} + func (v *KeyValue_SetValue_Result) MethodName() string { return "setValue" } diff --git a/examples/keyvalue/kv/keyvalueclient/client.go b/examples/keyvalue/kv/keyvalueclient/client.go index 9e0b3803b..eefa5e1c1 100644 --- a/examples/keyvalue/kv/keyvalueclient/client.go +++ b/examples/keyvalue/kv/keyvalueclient/client.go @@ -26,10 +26,10 @@ package keyvalueclient import ( "context" "go.uber.org/fx/examples/keyvalue/kv" - "go.uber.org/thriftrw/wire" - "go.uber.org/yarpc" "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/encoding/thrift" + "go.uber.org/yarpc" + "go.uber.org/thriftrw/wire" ) // Interface is a client for the KeyValue service. diff --git a/examples/keyvalue/kv/keyvalueserver/server.go b/examples/keyvalue/kv/keyvalueserver/server.go index 1d5bec053..bf0f25ec7 100644 --- a/examples/keyvalue/kv/keyvalueserver/server.go +++ b/examples/keyvalue/kv/keyvalueserver/server.go @@ -26,9 +26,9 @@ package keyvalueserver import ( "context" "go.uber.org/fx/examples/keyvalue/kv" - "go.uber.org/thriftrw/wire" "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/encoding/thrift" + "go.uber.org/thriftrw/wire" ) // Interface is the server-side interface for the KeyValue service. diff --git a/examples/keyvalue/kv/keyvaluetest/client.go b/examples/keyvalue/kv/keyvaluetest/client.go index 8a1bf3e6c..0c162c4dc 100644 --- a/examples/keyvalue/kv/keyvaluetest/client.go +++ b/examples/keyvalue/kv/keyvaluetest/client.go @@ -25,9 +25,9 @@ package keyvaluetest import ( "context" - "github.com/golang/mock/gomock" "go.uber.org/fx/examples/keyvalue/kv/keyvalueclient" "go.uber.org/yarpc" + "github.com/golang/mock/gomock" ) // MockClient implements a gomock-compatible mock client for service diff --git a/examples/keyvalue/kv/types.go b/examples/keyvalue/kv/types.go index 34dbbd45f..b2367eedb 100644 --- a/examples/keyvalue/kv/types.go +++ b/examples/keyvalue/kv/types.go @@ -1,4 +1,4 @@ -// Code generated by thriftrw v1.0.0 +// Code generated by thriftrw v1.2.0 // @generated // Copyright (c) 2017 Uber Technologies, Inc. @@ -90,6 +90,9 @@ func (v *ResourceDoesNotExist) FromWire(w wire.Value) error { } func (v *ResourceDoesNotExist) String() string { + if v == nil { + return "" + } var fields [2]string i := 0 fields[i] = fmt.Sprintf("Key: %v", v.Key) @@ -101,6 +104,25 @@ func (v *ResourceDoesNotExist) String() string { return fmt.Sprintf("ResourceDoesNotExist{%v}", strings.Join(fields[:i], ", ")) } +func _String_EqualsPtr(lhs, rhs *string) bool { + if lhs != nil && rhs != nil { + x := *lhs + y := *rhs + return (x == y) + } + return lhs == nil && rhs == nil +} + +func (v *ResourceDoesNotExist) Equals(rhs *ResourceDoesNotExist) bool { + if !(v.Key == rhs.Key) { + return false + } + if !_String_EqualsPtr(v.Message, rhs.Message) { + return false + } + return true +} + func (v *ResourceDoesNotExist) Error() string { return v.String() } diff --git a/examples/keyvalue/kv/versioncheck.go b/examples/keyvalue/kv/versioncheck.go index fc212c814..af2cec23f 100644 --- a/examples/keyvalue/kv/versioncheck.go +++ b/examples/keyvalue/kv/versioncheck.go @@ -1,4 +1,4 @@ -// Code generated by thriftrw v1.0.0 +// Code generated by thriftrw v1.2.0 // @generated // Copyright (c) 2017 Uber Technologies, Inc. @@ -26,5 +26,5 @@ package kv import "go.uber.org/thriftrw/version" func init() { - version.CheckCompatWithGeneratedCodeAt("1.0.0", "go.uber.org/fx/examples/keyvalue/kv") + version.CheckCompatWithGeneratedCodeAt("1.2.0", "go.uber.org/fx/examples/keyvalue/kv") } diff --git a/examples/keyvalue/server/config/development.yaml b/examples/keyvalue/server/config/development.yaml new file mode 100644 index 000000000..4044c265e --- /dev/null +++ b/examples/keyvalue/server/config/development.yaml @@ -0,0 +1,7 @@ +logging: + level: debug + stdout: true + encoding: console + encoderConfig: + levelEncoder: color + timeEncoder: iso8601 diff --git a/examples/simple/config/development.yaml b/examples/simple/config/development.yaml new file mode 100644 index 000000000..4044c265e --- /dev/null +++ b/examples/simple/config/development.yaml @@ -0,0 +1,7 @@ +logging: + level: debug + stdout: true + encoding: console + encoderConfig: + levelEncoder: color + timeEncoder: iso8601 diff --git a/examples/simple/handlers.go b/examples/simple/handlers.go index 0c145b7a4..aba256a10 100644 --- a/examples/simple/handlers.go +++ b/examples/simple/handlers.go @@ -25,7 +25,6 @@ import ( "io" "net/http" - "go.uber.org/fx/modules/uhttp" "go.uber.org/fx/service" ) @@ -35,21 +34,8 @@ func (exampleHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { io.WriteString(w, fmt.Sprintf("Headers: %+v", r.Header)) } -func enforceHeader(r uhttp.Route) uhttp.Route { - // require some weird headers - return r.Headers("X-Uber-FX", "yass") -} - -func registerHTTPers(service service.Host) []uhttp.RouteHandler { - handler := &exampleHandler{} - return []uhttp.RouteHandler{ - uhttp.NewRouteHandler("/", handler), - } -} - -type simpleInboundMiddleware struct{} - -func (simpleInboundMiddleware) Handle(w http.ResponseWriter, r *http.Request, next http.Handler) { - io.WriteString(w, "Going through simpleInboundMiddleware") - next.ServeHTTP(w, r) +func registerHTTPers(service service.Host) http.Handler { + router := http.NewServeMux() + router.Handle("/", &exampleHandler{}) + return router } diff --git a/examples/simple/main.go b/examples/simple/main.go index 6d254395c..8f8a4534b 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -29,7 +29,7 @@ import ( func main() { svc, err := service.WithModule( - uhttp.New(registerHTTPers, uhttp.WithInboundMiddleware(simpleInboundMiddleware{})), + uhttp.New(registerHTTPers), ).Build() if err != nil { diff --git a/glide.lock b/glide.lock index 529d5defe..5e4c694b9 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: 03dda616164575043fff9196cb586410a53906cceb2aa0a9105eb17b7a4a2e1f -updated: 2017-03-28T11:36:43.637865373-07:00 +hash: 1b2a476d5083d953c59101472d4cc525971daeaae4d75769775f184780d0aff8 +updated: 2017-04-27T11:45:59.497375972-07:00 imports: - name: github.com/apache/thrift version: 9549b25c77587b29be4e0b5c258221a4ed85d37a @@ -10,11 +10,11 @@ imports: subpackages: - statsd - name: github.com/certifi/gocertifi - version: 03be5e6bb9874570ea7fb0961225d193cbc374c5 + version: a9c833d2837d3b16888d55d5aafa9ffe9afb22b0 - name: github.com/codahale/hdrhistogram version: f8ad88b59a584afeee9d334eff879b104439117b - name: github.com/davecgh/go-spew - version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9 + version: 346938d642f2ec3594ed81d874461961cd0faa76 subpackages: - spew - name: github.com/facebookgo/clock @@ -29,27 +29,27 @@ imports: - gomock - name: github.com/gorilla/context version: 1ea25387ff6f684839d82767c1733ff4d4d15d0a -- name: github.com/gorilla/mux - version: 392c28fe23e1c45ddba891b0320b3b5df220beea - name: github.com/gorilla/websocket - version: 3ab3a8b8831546bd18fd182c20687ca853b2bb13 + version: c36f2fe5c330f0ac404b616b96c438b8616b1aaf +- name: github.com/ogier/pflag + version: 45c278ab3607870051a2ea9040bb85fcb8557481 - name: github.com/opentracing/opentracing-go version: 6edb48674bd9467b8e91fda004f2bd7202d60ce4 subpackages: - ext - log - name: github.com/pborman/uuid - version: c55201b036063326c5b1b89ccfe45a184973d073 + version: 1b00554d822231195d1babd97ff4a781231955c9 - name: github.com/pkg/errors version: 645ef00459ed84a119197bfb8d8205042c6df63d - name: github.com/pmezard/go-difflib - version: d8ed2627bdf02c080bf22230dbb337003b7aba2d + version: 792786c7400a136282c1664665ae0a8db921c6c2 subpackages: - difflib - name: github.com/Sirupsen/logrus version: 08a8a7c27e3d058a8989316a850daad1c10bf4ab - name: github.com/stretchr/objx - version: cbeaeb16a013161a98496fad62933b1d21786672 + version: 1a9d0bb9f541897e62256577b352fdbc1fb4fd94 - name: github.com/stretchr/testify version: 4d4bfba8f1d1027c4fdbe371823030df51419987 subpackages: @@ -60,6 +60,8 @@ imports: version: d52ffa061726911f47fcd3d9e8b9b55f22794772 - name: github.com/uber-go/atomic version: e682c1008ac17bf26d2e4b5ad6cdd08520ed0b22 +- name: github.com/uber-go/multierr + version: 737b41aa3bf31c25d11dc84d5275af6bfe2ef4d2 - name: github.com/uber-go/tally version: 34be4a565ce6286a0ba91a54a81be3f6181ca2e2 - name: github.com/uber/cherami-client-go @@ -72,7 +74,7 @@ imports: - common/websocket - stream - name: github.com/uber/cherami-thrift - version: 09ed2ceaeab9e52820a81caece0dee9914c31f5d + version: 0f0585c53937209f08a57c6ae51d8a7fd281e100 subpackages: - .generated/go/cherami - name: github.com/uber/jaeger-client-go @@ -90,7 +92,7 @@ imports: - transport/udp - utils - name: github.com/uber/jaeger-lib - version: 9dd8526f119f8cd8379427bfefdc406e81bc3b2f + version: b9556711760c45a30bd79c31c8f041d0d9aba997 subpackages: - metrics - metrics/tally @@ -108,11 +110,11 @@ imports: - trand - typed - name: go.uber.org/atomic - version: 3b8db5e93c4c02efbc313e17b2e796b0914a01fb + version: 4e336646b2ef9fc6e47be8e21594178f98e5ebcf - name: go.uber.org/dig version: 869ade8e3afd0b6dee05418a8e165cdcb487070a - name: go.uber.org/thriftrw - version: 05f870b3c56597d180af568a6392209cc33269e2 + version: dde90c2a40f45fb2b6361d13c1b4bf09465401c0 subpackages: - envelope - internal @@ -127,10 +129,11 @@ imports: - protocol - protocol/binary - ptr + - thriftreflect - version - wire - name: go.uber.org/yarpc - version: 6ad92c34d7e982d4bb7299f20e6a27652116cf5d + version: 6ae533f0810337028ef055b690360e050aa13219 subpackages: - api/encoding - api/middleware @@ -139,7 +142,6 @@ imports: - encoding/thrift - encoding/thrift/internal - internal - - internal/buffer - internal/clientconfig - internal/encoding - internal/errors @@ -156,7 +158,7 @@ imports: - transport/tchannel - transport/tchannel/internal - name: go.uber.org/zap - version: 4257c7cf05477d92ec25c31cfd3d60e89575f18a + version: 6a4e056f2cc954cfec3581729e758909604b3f76 subpackages: - buffer - internal/bufferpool @@ -166,7 +168,7 @@ imports: - zapcore - zaptest - name: golang.org/x/net - version: a6577fac2d73be281a500b310739095313165611 + version: da118f7b8e5954f39d0d2130ab35d4bf0e3cb344 subpackages: - context - context/ctxhttp @@ -179,12 +181,12 @@ imports: subpackages: - go/ast/astutil - name: gopkg.in/yaml.v2 - version: a3f3340b5840cee44f372bddb5880fcbc419b46a + version: a83829b6f1293c91addabc89d0571c246397bbf4 testImports: - name: github.com/anmitsu/go-shlex version: 648efa622239a2f6ff949fed78ee37b48d499ba4 - name: github.com/axw/gocov - version: c77561ca0c0cb1ed5d4ce4a912a75f5532566422 + version: 3a69a0d2a4ef1f263e2d92b041a69593d6964fe8 subpackages: - gocov - name: github.com/go-playground/overalls @@ -196,25 +198,27 @@ testImports: - name: github.com/google/gofuzz version: 44d81051d367757e1c7c6a5a86423ece9afcf63c - name: github.com/jessevdk/go-flags - version: 460c7bb0abd6e927f2767cadc91aa6ef776a98b4 + version: 0648c820cd4e564706597268ae2d2c7d9e6900c6 - name: github.com/kisielk/errcheck version: 23699b7e2cbfdb89481023524954ba2aeff6be90 - name: github.com/kisielk/gotool version: 0de1eaf82fa3f583ce21fde859f1e7e0c5e9b220 - name: github.com/kyoh86/richgo - version: 35d295f8d8df6dc5159273293c5d294cb6fb6b84 + version: 4aa6fe9df163a501e6eed51be841fc234103295c - name: github.com/mattn/goveralls version: a99c5ee06aeeca2a2befc7e90b99061b1180850c - name: github.com/mvdan/interfacer - version: 049d0176189d83d4e2535611f0253b6f1db487ad + version: 22c51662ff476dfd97944f74db1b263ed920ee83 subpackages: - cmd/interfacer +- name: github.com/mvdan/lint + version: c9cbe299b369cbfea16318baaa037b19a69e45d2 - name: github.com/russross/blackfriday version: 5ebfae50aacdef0dacd1a1acc469c2c1c7a7d4d8 - name: github.com/sectioneight/md-to-godoc version: f274e5a4257c85a9eaf60ac820ee813b78cac6ab - name: github.com/shurcooL/sanitized_anchor_name - version: 1dba4b3954bc059efc3991ec364f9f9a35f597d2 + version: 79c90efaf01eddc01945af5bc1797859189b830b - name: github.com/yookoala/realpath version: c416d99ab5ed256fa30c1f3bab73152deb59bb69 - name: go.uber.org/tools diff --git a/glide.yaml b/glide.yaml index cbbab0b04..14106d9cb 100644 --- a/glide.yaml +++ b/glide.yaml @@ -4,8 +4,6 @@ import: version: ^1 - package: github.com/uber-go/tally version: ^2.1.0 -- package: github.com/gorilla/mux - version: ^1.1.0 - package: github.com/gorilla/context version: ^1.1.0 - package: go.uber.org/yarpc @@ -35,6 +33,9 @@ import: - package: github.com/uber/cherami-thrift subpackages: - .generated/go/cherami +- package: github.com/ogier/pflag +- package: github.com/uber-go/multierr + version: ~0.1 testImport: - package: golang.org/x/tools subpackages: @@ -67,6 +68,7 @@ testImport: version: 2 - package: github.com/shurcooL/sanitized_anchor_name - package: github.com/mvdan/interfacer/cmd/interfacer +- package: github.com/mvdan/lint - package: github.com/kyoh86/richgo - package: go.uber.org/tools subpackages: diff --git a/metrics/nop_reporter.go b/metrics/nop_reporter.go index 770bbe311..e407d8c37 100644 --- a/metrics/nop_reporter.go +++ b/metrics/nop_reporter.go @@ -27,9 +27,9 @@ import ( ) var ( - capabilitiesReportingNoTagging = &capabilities{ + capabilitiesReporting = &capabilities{ reporting: true, - tagging: false, + tagging: true, } ) @@ -52,6 +52,9 @@ func (c *capabilities) Tagging() bool { // Remove and replace metrics.NopCachedStatsReporter with tally.NopCachedStatsReporter once issue is resolved var NopCachedStatsReporter tally.CachedStatsReporter = nopCachedStatsReporter{} +// NopScope is a root scope that does nothing +var NopScope, _ = tally.NewRootScope(tally.ScopeOptions{CachedReporter: NopCachedStatsReporter}, 0) + type nopCachedStatsReporter struct{} func (nopCachedStatsReporter) AllocateCounter(name string, tags map[string]string) tally.CachedCount { @@ -71,7 +74,7 @@ func (nopCachedStatsReporter) AllocateTimer(name string, tags map[string]string) } func (r nopCachedStatsReporter) Capabilities() tally.Capabilities { - return capabilitiesReportingNoTagging + return capabilitiesReporting } func (r nopCachedStatsReporter) Flush() {} diff --git a/modules/task/cherami/cherami.go b/modules/task/cherami/cherami.go index cc5d5b006..d8cbd1aa0 100644 --- a/modules/task/cherami/cherami.go +++ b/modules/task/cherami/cherami.go @@ -30,6 +30,8 @@ import ( "go.uber.org/fx/service" "go.uber.org/fx/ulog" + "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go/ext" "github.com/pkg/errors" "github.com/uber-go/tally" "github.com/uber/cherami-client-go/client/cherami" @@ -43,7 +45,9 @@ const ( _initialized state = iota _running _stopped - _pathPrefix = "/uberfx_async/" + _pathPrefix = "/uberfx_async/" + _ctxKey = "ctxKey" + _operationName = "task.Run" ) var ( @@ -86,6 +90,8 @@ type Backend struct { stateMu sync.RWMutex taskSuccess tally.Counter taskFailure tally.Counter + tracer opentracing.Tracer + ctxEncoder task.ContextEncoding } // RegisterHyperbahnBootstrapFile registers the hyperbahn bootstrap filename required for cherami @@ -173,6 +179,8 @@ func newBackendWithConfig( scope: scope, taskSuccess: scope.Counter("task.success"), taskFailure: scope.Counter("task.fail"), + tracer: host.Tracer(), + ctxEncoder: task.ContextEncoding{Tracer: host.Tracer()}, }, nil } @@ -259,22 +267,48 @@ func (b *Backend) consumeAndExecute() { b.consumeAndExecute() } }() + for delivery := range b.deliveryCh { messageData := delivery.GetMessage().GetPayload().GetData() - // TODO (madhu): Only specific errors should be retried - // TODO (madhu): Once context is added to the message, use that here - ctx := context.Background() - if err := task.Run(ctx, messageData); err != nil { + b.withContext(delivery, func(ctx context.Context) { + // TODO (madhu): Only specific errors should be retried + if err := task.Run(ctx, messageData); err != nil { + ulog.Logger(ctx).Error("Task run failed", zap.Error(err)) + b.taskFailure.Inc(1) + if err := delivery.Nack(); err != nil { + ulog.Logger(ctx).Error("Delivery Nack failed", zap.Error(err)) + } + } else { + b.taskSuccess.Inc(1) + if err = delivery.Ack(); err != nil { + ulog.Logger(ctx).Error("Task ack to cherami failed", zap.Error(err)) + } + } + }) + } +} + +func (b *Backend) withContext(delivery cherami.Delivery, f func(context.Context)) { + ctxData := delivery.GetMessage().GetPayload().GetUserContext() + ctx := context.Background() + if ctxVal, ok := ctxData[_ctxKey]; ok { + if spanCtx, err := b.ctxEncoder.Unmarshal([]byte(ctxVal)); err != nil { + ulog.Logger(ctx).Error("Unable to decode context", zap.Error(err)) b.taskFailure.Inc(1) - ulog.Logger(ctx).Error("Task run failed", zap.Error(err)) - _ = delivery.Nack() - } else { - b.taskSuccess.Inc(1) - if err = delivery.Ack(); err != nil { - ulog.Logger(ctx).Error("Task ack to cherami failed", zap.Error(err)) + if err := delivery.Nack(); err != nil { + ulog.Logger(ctx).Error("Delivery Nack failed", zap.Error(err)) } + } else { + var span opentracing.Span + span = b.tracer.StartSpan(_operationName, ext.RPCServerOption(spanCtx)) + defer span.Finish() + f(opentracing.ContextWithSpan(ctx, span)) } + + return } + + f(ctx) } // IsRunning returns true if backend is running @@ -286,11 +320,21 @@ func (b *Backend) isRunning() bool { // Enqueue sends the message to cherami func (b *Backend) Enqueue(ctx context.Context, message []byte) error { - // TODO (madhu): Extract and serialize context with the message + ctxBytes, err := b.ctxEncoder.Marshal(ctx) + if err != nil { + return errors.Wrap(err, "unable to encode context") + } + + ctxMap := make(map[string]string) + if len(ctxBytes) > 0 { + ctxMap[_ctxKey] = string(ctxBytes) + } + receipt := b.publisher.Publish(&cherami.PublisherMessage{ Data: message, - UserContext: make(map[string]string), + UserContext: ctxMap, }) + return receipt.Error } diff --git a/modules/task/cherami/cherami_test.go b/modules/task/cherami/cherami_test.go index 7b63dbc6d..a656a8c1e 100644 --- a/modules/task/cherami/cherami_test.go +++ b/modules/task/cherami/cherami_test.go @@ -29,12 +29,12 @@ import ( "testing" "time" - "go.uber.org/fx/auth" "go.uber.org/fx/config" cherami_mocks "go.uber.org/fx/mocks/modules/task/cherami" "go.uber.org/fx/modules/task" "go.uber.org/fx/service" "go.uber.org/fx/testutils" + "go.uber.org/fx/testutils/tracing" "go.uber.org/fx/ulog" "github.com/opentracing/opentracing-go" @@ -46,11 +46,10 @@ import ( ) var ( - _host = service.NopHostConfigured( - auth.NopClient, ulog.Logger(context.Background()), opentracing.NoopTracer{}, - ) - _pathName = _pathPrefix + _host.Name() - _cgName = _pathPrefix + _host.Name() + "_cg" + _host = service.NopHostConfigured(ulog.Logger(context.Background()), opentracing.NoopTracer{}) + _pathName = _pathPrefix + _host.Name() + _cgName = _pathPrefix + _host.Name() + "_cg" + _publishMsg = []byte("Hello") ) func TestBackendWorkflow(t *testing.T) { @@ -58,22 +57,23 @@ func TestBackendWorkflow(t *testing.T) { defer m.AssertExpectations(t) zapLogger, buf := testutils.GetLockedInMemoryLogger() defer ulog.SetLogger(zapLogger)() - bknd := createNewBackend(t, m) - assert.NotNil(t, bknd.Encoder()) - deliveryCh, err := startBackend(t, m, bknd, nil, nil) - require.NoError(t, err) - assert.True(t, bknd.(*Backend).isRunning()) - require.NoError(t, bknd.ExecuteAsync()) - - publish(t, m, bknd, deliveryCh) - publish(t, m, bknd, deliveryCh) - time.Sleep(10 * time.Millisecond) - stopBackend(t, m, bknd) - lines := buf.Lines() - require.Equal(t, 2, len(lines)) - for _, line := range lines { - assert.Contains(t, line, "forget to register") - } + tracing.WithTracer(t, zapLogger, func(tracer opentracing.Tracer) { + host := service.NopHostConfigured(zapLogger, tracer) + bknd := createNewBackend(t, m, host) + assert.NotNil(t, bknd.Encoder()) + deliveryCh, err := startBackend(t, m, bknd, nil, nil) + require.NoError(t, err) + assert.True(t, bknd.(*Backend).isRunning()) + require.NoError(t, bknd.ExecuteAsync()) + tracing.WithSpan(t, zapLogger, func(span opentracing.Span) { + publish(t, m, bknd, deliveryCh, span, nil) + publish(t, m, bknd, deliveryCh, span, errors.New("nack error")) + }) + time.Sleep(30 * time.Millisecond) + stopBackend(t, m, bknd) + lines := buf.Lines() + findInLogs(t, lines, map[string]int{"forget to register": 2, "nack error": 1}) + }) } func TestBackendWorkflowWorkerPanic(t *testing.T) { @@ -81,42 +81,40 @@ func TestBackendWorkflowWorkerPanic(t *testing.T) { defer m.AssertExpectations(t) zapLogger, buf := testutils.GetLockedInMemoryLogger() defer ulog.SetLogger(zapLogger)() - bknd := createNewBackend(t, m) + bknd := createNewBackend(t, m, _host) deliveryCh, err := startBackend(t, m, bknd, nil, nil) require.NoError(t, err) assert.True(t, bknd.(*Backend).isRunning()) require.NoError(t, bknd.ExecuteAsync()) // Panic on ConsumeWorkerCount and make sure workers are still alive to consume messages for i := 0; i < _defaultClientConfig.ConsumeWorkerCount; i++ { - deliveryCh <- m.Delivery m.Delivery.On("GetMessage").Return( &cherami_gen.ConsumerMessage{ - Payload: &cherami_gen.PutMessage{Data: []byte("Hello")}, + Payload: &cherami_gen.PutMessage{Data: _publishMsg}, }, ) - m.Delivery.On("Nack").Run(func(mock.Arguments) { panic("nack panic") }) + m.Delivery.On("Nack").Run(func(mock.Arguments) { panic("nack panic") }).Once() + deliveryCh <- m.Delivery } // Publish valid message - publish(t, m, bknd, deliveryCh) - time.Sleep(10 * time.Millisecond) + publish(t, m, bknd, deliveryCh, nil, nil) + time.Sleep(100 * time.Millisecond) assert.True(t, bknd.(*Backend).isRunning()) stopBackend(t, m, bknd) // Nack panics are sent for a count of _numWorkers and 1 valid publish. Make sure they are // all processed lines := buf.Lines() - var ct int - for _, line := range lines { - if strings.Contains(line, "forget to register") { - ct++ - } - } - require.Equal(t, _defaultClientConfig.ConsumeWorkerCount+1, ct) + findInLogs( + t, + lines, + map[string]int{"forget to register": _defaultClientConfig.ConsumeWorkerCount + 1}, + ) } func TestBackendWorkflowStateLocks(t *testing.T) { m := newMock() defer m.AssertExpectations(t) - bknd := createNewBackend(t, m) + bknd := createNewBackend(t, m, _host) assert.NotNil(t, bknd.Encoder()) var wg sync.WaitGroup wg.Add(2) @@ -241,7 +239,7 @@ func TestStartBackendInvalidStateError(t *testing.T) { stateToError := map[state]string{_running: "already running", _stopped: "has been stopped"} for state, errStr := range stateToError { m := newMock() - bknd := createNewBackend(t, m) + bknd := createNewBackend(t, m, _host) bknd.(*Backend).setState(state) err := bknd.Start() assert.Contains(t, err.Error(), errStr) @@ -251,7 +249,7 @@ func TestStartBackendInvalidStateError(t *testing.T) { func TestStartBackendOpenPublisherError(t *testing.T) { m := newMock() defer m.AssertExpectations(t) - bknd := createNewBackend(t, m) + bknd := createNewBackend(t, m, _host) errStr := "publish error" _, err := startBackend(t, m, bknd, errors.New(errStr), nil) assert.False(t, bknd.(*Backend).isRunning()) @@ -261,13 +259,70 @@ func TestStartBackendOpenPublisherError(t *testing.T) { func TestStartBackendOpenConsumerError(t *testing.T) { m := newMock() defer m.AssertExpectations(t) - bknd := createNewBackend(t, m) + bknd := createNewBackend(t, m, _host) errStr := "consume error" _, err := startBackend(t, m, bknd, nil, errors.New(errStr)) assert.False(t, bknd.(*Backend).isRunning()) assert.Contains(t, err.Error(), errStr) } +func TestEncodingErrors(t *testing.T) { + m := newMock() + defer m.AssertExpectations(t) + testArgs := []struct { + nackError error + expectedLogs map[string]int + }{ + {nil, map[string]int{"extract error": 1}}, + {errors.New("nack error"), map[string]int{"extract error": 1, "nack error": 1}}, + } + for _, testArg := range testArgs { + tracer := &tracing.ErrorTracer{Tracer: opentracing.NoopTracer{}} + zapLogger, buf := testutils.GetLockedInMemoryLogger() + defer ulog.SetLogger(zapLogger)() + host := service.NopHostConfigured(zapLogger, tracer) + + bknd := createNewBackend(t, m, host) + cBknd := bknd.(*Backend) + m.Delivery.On("GetMessage").Return( + &cherami_gen.ConsumerMessage{ + Payload: &cherami_gen.PutMessage{ + Data: _publishMsg, + UserContext: map[string]string{_ctxKey: ""}, + }, + }, + ) + m.Delivery.On("Nack").Return(testArg.nackError).Once() + cBknd.withContext(m.Delivery, func(context.Context) {}) + tracing.WithSpan(t, zapLogger, func(span opentracing.Span) { + ctx := opentracing.ContextWithSpan(context.Background(), span) + err := cBknd.Enqueue(ctx, _publishMsg) + require.Error(t, err) + assert.Contains(t, err.Error(), "unable to encode context") + }) + findInLogs(t, buf.Lines(), testArg.expectedLogs) + } +} + +func findInLogs(t *testing.T, logs []string, expectedLinesWithCt map[string]int) { + actualLinesWithCt := make(map[string]int) + for _, line := range logs { + for k := range expectedLinesWithCt { + if strings.Contains(line, k) { + actualLinesWithCt[k]++ + } + } + } + for k, v := range expectedLinesWithCt { + assert.Equal( + t, + v, + actualLinesWithCt[k], + "Expected msg: %s to occur %d times but found %d", k, v, actualLinesWithCt[k], + ) + } +} + type cheramiMock struct { Client *cherami_mocks.Client Pub *cherami_mocks.Publisher @@ -291,12 +346,12 @@ func (m *cheramiMock) AssertExpectations(t *testing.T) { m.Delivery.AssertExpectations(t) } -func createNewBackend(t *testing.T, m *cheramiMock) task.Backend { +func createNewBackend(t *testing.T, m *cheramiMock, host service.Host) task.Backend { setupHappyClientFunc(m) setupDest(m, _pathName, nil) setupCg(m, _pathName, _cgName, nil) setupPublisherConsumer(m, _pathName, _cgName) - bknd, err := NewBackend(_host) + bknd, err := NewBackend(host) require.NoError(t, err) assert.NotNil(t, bknd) assert.False(t, bknd.(*Backend).isRunning()) @@ -365,18 +420,31 @@ func startBackend( return deliveryCh, err } -func publish(t *testing.T, m *cheramiMock, bknd task.Backend, deliveryCh chan cherami.Delivery) { - msg := []byte("Hello") +func publish( + t *testing.T, m *cheramiMock, + bknd task.Backend, + deliveryCh chan cherami.Delivery, + span opentracing.Span, + nackErr error, +) { + ctx := context.Background() + userCtx := make(map[string]string) + if span != nil { + ctx = opentracing.ContextWithSpan(ctx, span) + ctxBytes, err := bknd.(*Backend).ctxEncoder.Marshal(ctx) + require.NoError(t, err) + userCtx[_ctxKey] = string(ctxBytes) + } m.Pub.On( - "Publish", &cherami.PublisherMessage{Data: msg, UserContext: map[string]string{}}, + "Publish", &cherami.PublisherMessage{Data: _publishMsg, UserContext: userCtx}, ).Run( func(mock.Arguments) { deliveryCh <- m.Delivery }, - ).Return(&cherami.PublisherReceipt{}) + ).Return(&cherami.PublisherReceipt{}).Once() m.Delivery.On("GetMessage").Return( &cherami_gen.ConsumerMessage{ - Payload: &cherami_gen.PutMessage{Data: msg}, + Payload: &cherami_gen.PutMessage{Data: _publishMsg}, }, ) - m.Delivery.On("Nack").Return(nil) - require.NoError(t, bknd.Enqueue(context.Background(), msg)) + m.Delivery.On("Nack").Return(nackErr).Once() + require.NoError(t, bknd.Enqueue(ctx, _publishMsg)) } diff --git a/modules/task/encoding.go b/modules/task/encoding.go index 08cadba05..e7068aa22 100644 --- a/modules/task/encoding.go +++ b/modules/task/encoding.go @@ -22,8 +22,10 @@ package task import ( "bytes" + "context" "encoding/gob" + "github.com/opentracing/opentracing-go" "github.com/pkg/errors" ) @@ -81,3 +83,40 @@ func (g GobEncoding) Unmarshal(data []byte, obj interface{}) error { } return nil } + +// ContextEncoding supports encoding for the context object +type ContextEncoding struct { + Tracer opentracing.Tracer +} + +// Marshal encodes a context into bytes +func (c *ContextEncoding) Marshal(ctx context.Context) ([]byte, error) { + span := opentracing.SpanFromContext(ctx) + if span == nil { + return nil, nil + } + + spanCtx := span.Context() + if spanCtx == nil { + return nil, nil + } + + var carrier bytes.Buffer + err := c.Tracer.Inject(spanCtx, opentracing.Binary, &carrier) + return carrier.Bytes(), err +} + +// Unmarshal decodes a bytes array into context +// NOTE: If we were to add more things to the context, this will need to change to return a +// collection of context values instead of just SpanContext +func (c *ContextEncoding) Unmarshal(data []byte) (opentracing.SpanContext, error) { + carrier := bytes.NewBuffer(data) + spanContext, err := c.Tracer.Extract(opentracing.Binary, carrier) + + // If no SpanContext was given, we return nil instead of erroring + if err == opentracing.ErrSpanContextNotFound { + return nil, nil + } + + return spanContext, err +} diff --git a/modules/task/encoding_test.go b/modules/task/encoding_test.go index 4e0e0651d..6db1f62dc 100644 --- a/modules/task/encoding_test.go +++ b/modules/task/encoding_test.go @@ -21,38 +21,66 @@ package task import ( + "context" "reflect" "testing" + "go.uber.org/fx/testutils/tracing" + "go.uber.org/zap" + + "github.com/opentracing/opentracing-go" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -type encodingTest struct { - encoding Encoding - inputObj interface{} - verifyEncoding bool -} - var kvMap = map[string]string{"key": "value"} -var encodingTests = []encodingTest{ - {&NopEncoding{}, kvMap, false}, - {&GobEncoding{}, kvMap, true}, -} - func TestEncoding(t *testing.T) { + encodingTests := []struct { + encoding Encoding + inputObj interface{} + expectedObj interface{} + }{ + {&NopEncoding{}, kvMap, nil}, + {&GobEncoding{}, kvMap, kvMap}, + } for _, test := range encodingTests { - testEncMethods(t, test.encoding, test.inputObj, test.verifyEncoding) + testEncMethods(t, test.encoding, test.inputObj, test.expectedObj) } } -func testEncMethods(t *testing.T, encoding Encoding, obj interface{}, deepChecks bool) { +func testEncMethods(t *testing.T, encoding Encoding, obj interface{}, expectedObj interface{}) { assert.NoError(t, encoding.Register(obj)) msg, err := encoding.Marshal(obj) - assert.NoError(t, err) + require.NoError(t, err) receivedObj := make(map[string]string) - assert.NoError(t, encoding.Unmarshal(msg, &receivedObj)) - if deepChecks { - assert.True(t, reflect.DeepEqual(obj, receivedObj)) + require.NoError(t, encoding.Unmarshal(msg, &receivedObj)) + if expectedObj != nil { + assert.True(t, reflect.DeepEqual(expectedObj, receivedObj)) } } + +func TestContextEncoding(t *testing.T) { + nopZap := zap.NewNop() + tracing.WithTracer(t, nopZap, func(tracer opentracing.Tracer) { + encoding := ContextEncoding{Tracer: tracer} + tracing.WithSpan(t, nopZap, func(span opentracing.Span) { + ctx := opentracing.ContextWithSpan(context.Background(), span) + msg, err := encoding.Marshal(ctx) + require.NoError(t, err) + spanCtx, err := encoding.Unmarshal(msg) + require.NoError(t, err) + assert.Equal(t, span.Context(), spanCtx) + }) + }) +} + +func TestContextEncodingWithNoSpan(t *testing.T) { + encoding := ContextEncoding{Tracer: opentracing.NoopTracer{}} + msg, err := encoding.Marshal(context.Background()) + require.NoError(t, err) + assert.Nil(t, msg) + spanCtx, err := encoding.Unmarshal(msg) + require.NoError(t, err) + assert.Nil(t, spanCtx) +} diff --git a/modules/task/task.go b/modules/task/task.go index 9681191ed..70041eb21 100644 --- a/modules/task/task.go +++ b/modules/task/task.go @@ -23,16 +23,16 @@ package task import ( "sync" - "github.com/pkg/errors" - "github.com/uber-go/tally" - + "go.uber.org/fx/metrics" "go.uber.org/fx/service" + + "github.com/pkg/errors" ) var ( _globalBackendMu sync.RWMutex _globalBackend Backend = &NopBackend{} - _globalBackendStatsClient = newStatsClient(tally.NoopScope) + _globalBackendStatsClient = newStatsClient(metrics.NopScope) _asyncMod service.Module _asyncModErr error _once sync.Once diff --git a/modules/uhttp/README.md b/modules/uhttp/README.md index 47ab2482f..c08678655 100644 --- a/modules/uhttp/README.md +++ b/modules/uhttp/README.md @@ -1,7 +1,9 @@ # HTTP Module -The HTTP module is built on top of [Gorilla Mux](https://github.com/gorilla/mux), +The HTTP module is built on top of [standardlib http library](https://golang.org/pkg/net/http/), but the details of that are abstracted away through `uhttp.RouteHandler`. +As part of module initialization, you can now pass in a `mux.Router` to the +`uhttp` module. ```go package main @@ -25,15 +27,14 @@ func main() { svc.Start(true) } -func registerHTTP(service service.Host) []uhttp.RouteHandler { - handleHome := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func registerHTTP(service service.Host) http.Handler { + handleHome := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ulog.Logger(r.Context()).Info("Inside the handler") io.WriteString(w, "Hello, world") }) - - return []uhttp.RouteHandler{ - uhttp.NewRouteHandler("/", handleHome) - } + router := http.NewServeMux() + router.Handle("/", handleHome) + return router } ``` @@ -68,7 +69,7 @@ func main() { log.Fatal("Could not initialize service: ", err) } - client := uhttpclient.New(svc) + client := uhttpclient.New(opentracing.GlobalTracer(), svc) client.Get("https://www.uber.com") } ``` diff --git a/modules/uhttp/doc.go b/modules/uhttp/doc.go index 184e44db6..222a7ebd1 100644 --- a/modules/uhttp/doc.go +++ b/modules/uhttp/doc.go @@ -20,9 +20,12 @@ // Package uhttp is the HTTP Module. // -// The HTTP module is built on top of Gorilla Mux (https://github.com/gorilla/mux), +// The HTTP module is built on top of standardlib http library (https://golang.org/pkg/net/http/), // but the details of that are abstracted away through // uhttp.RouteHandler. +// As part of module initialization, you can now pass in a +// mux.Router to the +// uhttp module. // // package main // @@ -45,15 +48,14 @@ // svc.Start(true) // } // -// func registerHTTP(service service.Host) []uhttp.RouteHandler { -// handleHome := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// func registerHTTP(service service.Host) http.Handler { +// handleHome := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // ulog.Logger(r.Context()).Info("Inside the handler") // io.WriteString(w, "Hello, world") // }) -// -// return []uhttp.RouteHandler{ -// uhttp.NewRouteHandler("/", handleHome) -// } +// router := http.NewServeMux() +// router.Handle("/", handleHome) +// return router // } // // HTTP handlers are set up with inbound middleware that inject tracing, @@ -88,7 +90,7 @@ // log.Fatal("Could not initialize service: ", err) // } // -// client := uhttpclient.New(svc) +// client := uhttpclient.New(opentracing.GlobalTracer(), svc) // client.Get("https://www.uber.com") // } // diff --git a/modules/uhttp/http.go b/modules/uhttp/http.go index 12e753345..50cd6dc4b 100644 --- a/modules/uhttp/http.go +++ b/modules/uhttp/http.go @@ -29,6 +29,7 @@ import ( "sync" "time" + "go.uber.org/fx/auth" "go.uber.org/fx/service" "go.uber.org/fx/ulog" @@ -55,24 +56,19 @@ const ( // default healthcheck endpoint healthPath = "/health" + + // default pprof endpoint + pprofPath = "/debug/pprof" ) var _ service.Module = &Module{} // A Module is a module to handle HTTP requests type Module struct { - service.Host - config Config - log *zap.Logger - srv *http.Server listener net.Listener - handlers []RouteHandler - mcb inboundMiddlewareChainBuilder lock sync.RWMutex } -var _ service.Module = &Module{} - // Config handles config for HTTP modules type Config struct { Port int `yaml:"port"` @@ -80,27 +76,18 @@ type Config struct { Debug bool `yaml:"debug" default:"true"` } -// GetHandlersFunc returns a slice of registrants from a service host -type GetHandlersFunc func(service service.Host) []RouteHandler +// GetHandlersFunc returns http handler created by caller +type GetHandlersFunc func(service service.Host) http.Handler // New returns a new HTTP ModuleProvider. -func New(hookup GetHandlersFunc, options ...ModuleOption) service.ModuleProvider { +func New(handlerFunc GetHandlersFunc) service.ModuleProvider { return service.ModuleProviderFromFunc("uhttp", func(host service.Host) (service.Module, error) { - return newModule(host, hookup, options...) + handler := handlerFunc(host) + return newModule(host, handler) }) } -func newModule( - host service.Host, - getHandlers GetHandlersFunc, - options ...ModuleOption, -) (*Module, error) { - moduleOptions := &moduleOptions{} - for _, option := range options { - if err := option(moduleOptions); err != nil { - return nil, err - } - } +func newModule(host service.Host, handler http.Handler) (*Module, error) { // setup config defaults cfg := Config{ Port: defaultPort, @@ -110,50 +97,53 @@ func newModule( if err := host.Config().Get("modules").Get(host.ModuleName()).Populate(&cfg); err != nil { log.Error("Error loading http module configuration", zap.Error(err)) } - module := &Module{ - Host: host, - handlers: addHealth(getHandlers(host)), - mcb: defaultInboundMiddlewareChainBuilder(log, host.AuthClient(), newStatsClient(host.Metrics())), - config: cfg, - log: log, - } - module.mcb = module.mcb.AddMiddleware(moduleOptions.inboundMiddleware...) - return module, nil -} - -// Start begins serving requests over HTTP -func (m *Module) Start() error { - mux := http.NewServeMux() - // Do something unrelated to annotations - router := NewRouter(m.Host) - - mux.Handle("/", router) - - for _, h := range m.handlers { - router.Handle(h.Path, m.mcb.Build(h.Handler)) - } - if m.config.Debug { - router.PathPrefix("/debug/pprof").Handler(http.DefaultServeMux) + serveMux := http.NewServeMux() + serveMux.Handle(healthPath, healthHandler{}) + + // TODO: pass in the auth client as part of module construction + authClient := auth.Load(host.Config(), host.Metrics()) + stats := newStatsClient(host.Metrics()) + + handle := + panicInbound( + metricsInbound( + tracingInbound( + authorizationInbound(handler, authClient, stats), + ), stats, + ), stats, + ) + serveMux.Handle("/", handle) + + if cfg.Debug { + serveMux.Handle(pprofPath, http.DefaultServeMux) } - // Set up the socket - listener, err := net.Listen("tcp", fmt.Sprintf(":%d", m.config.Port)) + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", cfg.Port)) if err != nil { - return errors.Wrap(err, "unable to open TCP listener for HTTP module") + return nil, errors.Wrap(err, "unable to open TCP listener for HTTP module") } + // finally, start the http server. - // TODO update log object to be accessed via http context #74 - m.log.Info("Server listening on port", zap.Int("port", m.config.Port)) + log.Info("Server listening on port", zap.Int("port", cfg.Port)) + + srv := &http.Server{ + Handler: serveMux, + } - m.listener = listener - m.srv = &http.Server{Handler: mux} go func() { // TODO(pedge): what to do about error? - if err := m.srv.Serve(listener); err != nil { - m.log.Error("HTTP Serve error", zap.Error(err)) + if err := srv.Serve(listener); err != nil { + log.Error("HTTP Serve error", zap.Error(err)) } }() + return &Module{ + listener: listener, + }, nil +} + +// Start begins serving requests over HTTP +func (m *Module) Start() error { return nil } @@ -171,17 +161,3 @@ func (m *Module) Stop() error { } return err } - -// addHealth adds in the default if health handler is not set -func addHealth(handlers []RouteHandler) []RouteHandler { - healthFound := false - for _, h := range handlers { - if h.Path == healthPath { - healthFound = true - } - } - if !healthFound { - handlers = append(handlers, NewRouteHandler(healthPath, healthHandler{})) - } - return handlers -} diff --git a/modules/uhttp/http_options.go b/modules/uhttp/http_options.go deleted file mode 100644 index 56cc987d2..000000000 --- a/modules/uhttp/http_options.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2017 Uber Technologies, Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -package uhttp - -// ModuleOption is a function that configures module creation. -type ModuleOption func(*moduleOptions) error - -type moduleOptions struct { - inboundMiddleware []InboundMiddleware -} - -// WithInboundMiddleware adds inbound middleware to uhttp Module that will be applied to all incoming http requests. -func WithInboundMiddleware(m ...InboundMiddleware) ModuleOption { - return func(moduleOptions *moduleOptions) error { - moduleOptions.inboundMiddleware = append(moduleOptions.inboundMiddleware, m...) - return nil - } -} diff --git a/modules/uhttp/http_test.go b/modules/uhttp/http_test.go index 7ccc51387..17905fd6c 100644 --- a/modules/uhttp/http_test.go +++ b/modules/uhttp/http_test.go @@ -23,12 +23,12 @@ package uhttp import ( "fmt" "io" - "io/ioutil" "net/http" "runtime" "testing" "time" + "go.uber.org/fx/config" "go.uber.org/fx/service" . "go.uber.org/fx/service/testutils" . "go.uber.org/fx/testutils" @@ -39,48 +39,27 @@ import ( "github.com/uber-go/tally" ) +var _httpconfig = []byte(` +modules: + uhttp: + port: 0 + debug: true +`) + // Custom default client since http's defaultClient does not set timeout var _defaultHTTPClient = &http.Client{Timeout: 2 * time.Second} func TestNew_OK(t *testing.T) { + t.Parallel() WithService(New(registerNothing), nil, []service.Option{configOption()}, func(s service.Manager) { assert.NotNil(t, s, "Should create a module") - }) -} -func TestHTTPModule_WithInboundMiddleware(t *testing.T) { - withModule( - t, - registerPanic, - []ModuleOption{WithInboundMiddleware(fakeInbound())}, - false, - func(m *Module) { - assert.NotNil(t, m) - makeRequest(m, "GET", "/", nil, func(r *http.Response) { - body, err := ioutil.ReadAll(r.Body) - assert.NoError(t, err) - assert.Contains(t, string(body), "inbound middleware is executed") - }) - verifyMetrics(t, m.Metrics()) - }) -} - -func TestHTTPModule_WithUserPanicInboundMiddleware(t *testing.T) { - withModule( - t, - registerTracerCheckHandler, - []ModuleOption{WithInboundMiddleware(userPanicInbound())}, - false, - func(m *Module) { - assert.NotNil(t, m) - makeRequest(m, "GET", "/", nil, func(r *http.Response) { - assert.Equal(t, http.StatusInternalServerError, r.StatusCode, "Expected 500 with panic wrapper") - }) - }) + }) } func TestHTTPModule_Panic_OK(t *testing.T) { - withModule(t, registerPanic, nil, false, func(m *Module) { + t.Parallel() + withModule(t, registerPanic(), false, func(m *Module) { assert.NotNil(t, m) makeRequest(m, "GET", "/", nil, func(r *http.Response) { assert.Equal(t, http.StatusInternalServerError, r.StatusCode, "Expected 500 with panic wrapper") @@ -89,7 +68,8 @@ func TestHTTPModule_Panic_OK(t *testing.T) { } func TestHTTPModule_Tracer(t *testing.T) { - withModule(t, registerTracerCheckHandler, nil, false, func(m *Module) { + t.Parallel() + withModule(t, registerTracerCheckHandler(), false, func(m *Module) { assert.NotNil(t, m) makeRequest(m, "GET", "/", nil, func(r *http.Response) { assert.Equal(t, http.StatusOK, r.StatusCode, "Expected 200 with tracer check") @@ -98,13 +78,15 @@ func TestHTTPModule_Tracer(t *testing.T) { } func TestHTTPModule_StartsAndStops(t *testing.T) { - withModule(t, registerPanic, nil, false, func(m *Module) { + t.Parallel() + withModule(t, registerPanic(), false, func(m *Module) { assert.NotNil(t, m.listener, "Start should be successful") }) } func TestBuiltinHealth_OK(t *testing.T) { - withModule(t, registerNothing, nil, false, func(m *Module) { + t.Parallel() + withModule(t, registerNothing(nil), false, func(m *Module) { assert.NotNil(t, m) makeRequest(m, "GET", "/health", nil, func(r *http.Response) { assert.Equal(t, http.StatusOK, r.StatusCode, "Expected 200 with default health handler") @@ -112,27 +94,6 @@ func TestBuiltinHealth_OK(t *testing.T) { }) } -func TestOverrideHealth_OK(t *testing.T) { - withModule(t, registerCustomHealth, nil, false, func(m *Module) { - assert.NotNil(t, m) - makeRequest(m, "GET", "/health", nil, func(r *http.Response) { - assert.Equal(t, http.StatusOK, r.StatusCode, "Expected 200 with default health handler") - body, err := ioutil.ReadAll(r.Body) - require.NoError(t, err, "Should be able to read health body") - assert.Equal(t, "not ok", string(body)) - }) - }) -} - -func TestPProf_Registered(t *testing.T) { - withModule(t, registerNothing, nil, false, func(m *Module) { - assert.NotNil(t, m) - makeRequest(m, "GET", "/debug/pprof", nil, func(r *http.Response) { - assert.Equal(t, http.StatusOK, r.StatusCode, "Expected 200 from pprof handler") - }) - }) -} - // TODO(ai) add a test for binding a bad port and get an error out of Start() func configOption() service.Option { @@ -141,22 +102,24 @@ func configOption() service.Option { func withModule( t testing.TB, - hookup GetHandlersFunc, - moduleOptions []ModuleOption, + handler http.Handler, expectError bool, fn func(*Module), ) { - host, err := service.NewScopedHost(service.NopHost(), "uhttp", "hello") + host, err := service.NewScopedHost( + service.NopHostWithConfig( + config.NewYAMLProviderFromBytes(_httpconfig)), + "uhttp", + "hello", + ) require.NoError(t, err) - mod, err := newModule(host, hookup, moduleOptions...) + mod, err := newModule(host, handler) if expectError { require.Error(t, err, "Expected error instantiating module") fn(nil) return } require.NoError(t, err, "Unable to instantiate module") - // us an ephemeral port on tests - mod.config.Port = 0 assert.NoError(t, mod.Start(), "Got error from starting") fn(mod) runtime.Gosched() @@ -183,21 +146,12 @@ func makeRequest(m *Module, method, url string, body io.Reader, fn func(r *http. fn(response) } -func registerNothing(_ service.Host) []RouteHandler { +func registerNothing(_ service.Host) http.Handler { return nil } -func makeSingleHandler(path string, fn func(http.ResponseWriter, *http.Request)) []RouteHandler { - return []RouteHandler{ - { - Path: path, - Handler: http.HandlerFunc(fn), - }, - } -} - -func registerTracerCheckHandler(host service.Host) []RouteHandler { - return makeSingleHandler("/", func(_ http.ResponseWriter, r *http.Request) { +func registerTracerCheckHandler() http.HandlerFunc { + return func(_ http.ResponseWriter, r *http.Request) { span := opentracing.SpanFromContext(r.Context()) if span == nil { panic(fmt.Sprintf("Intentional panic, invalid span: %v", span)) @@ -207,30 +161,11 @@ func registerTracerCheckHandler(host service.Host) []RouteHandler { opentracing.GlobalTracer(), )) } - }) -} - -func registerCustomHealth(_ service.Host) []RouteHandler { - return makeSingleHandler("/health", func(w http.ResponseWriter, _ *http.Request) { - io.WriteString(w, "not ok") - }) -} - -func registerPanic(_ service.Host) []RouteHandler { - return makeSingleHandler("/", func(_ http.ResponseWriter, r *http.Request) { - panic("Intentional panic for:" + r.URL.Path) - }) -} - -func fakeInbound() InboundMiddlewareFunc { - return func(w http.ResponseWriter, r *http.Request, next http.Handler) { - io.WriteString(w, "inbound middleware is executed") - next.ServeHTTP(w, r) } } -func userPanicInbound() InboundMiddlewareFunc { - return func(_ http.ResponseWriter, r *http.Request, _ http.Handler) { +func registerPanic() http.HandlerFunc { + return func(_ http.ResponseWriter, r *http.Request) { panic("Intentional panic for:" + r.URL.Path) } } diff --git a/modules/uhttp/middleware.go b/modules/uhttp/middleware.go index 315aeecaa..ebcfecf2b 100644 --- a/modules/uhttp/middleware.go +++ b/modules/uhttp/middleware.go @@ -35,91 +35,61 @@ import ( const _panicResponse = "Server Error" -// InboundMiddleware applies inbound middleware on requests or responses such as -// adding tracing to the context. -type InboundMiddleware interface { - Handle(w http.ResponseWriter, r *http.Request, next http.Handler) -} - -// InboundMiddlewareFunc is an adaptor to call normal functions to apply inbound middleware. -type InboundMiddlewareFunc func(w http.ResponseWriter, r *http.Request, next http.Handler) - -// Handle implements Handle from the InboundMiddleware interface and simply delegates to the function -func (f InboundMiddlewareFunc) Handle(w http.ResponseWriter, r *http.Request, next http.Handler) { - f(w, r, next) -} - -type contextInbound struct { - log *zap.Logger -} - -func (f contextInbound) Handle(w http.ResponseWriter, r *http.Request, next http.Handler) { - next.ServeHTTP(w, r.WithContext(r.Context())) -} - -type tracingInbound struct{} +func tracingInbound(next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + operationName := r.Method + carrier := opentracing.HTTPHeadersCarrier(r.Header) + spanCtx, err := opentracing.GlobalTracer().Extract(opentracing.HTTPHeaders, carrier) + if err != nil && err != opentracing.ErrSpanContextNotFound { + ulog.Logger(ctx).Warn("Malformed inbound tracing context: ", zap.Error(err)) + } + span := opentracing.GlobalTracer().StartSpan(operationName, ext.RPCServerOption(spanCtx)) + ext.HTTPUrl.Set(span, r.URL.String()) + defer span.Finish() -func (f tracingInbound) Handle(w http.ResponseWriter, r *http.Request, next http.Handler) { - ctx := r.Context() - operationName := r.Method - carrier := opentracing.HTTPHeadersCarrier(r.Header) - spanCtx, err := opentracing.GlobalTracer().Extract(opentracing.HTTPHeaders, carrier) - if err != nil && err != opentracing.ErrSpanContextNotFound { - ulog.Logger(ctx).Warn("Malformed inbound tracing context: ", zap.Error(err)) + ctx = opentracing.ContextWithSpan(ctx, span) + next.ServeHTTP(w, r.WithContext(ctx)) } - span := opentracing.GlobalTracer().StartSpan(operationName, ext.RPCServerOption(spanCtx)) - ext.HTTPUrl.Set(span, r.URL.String()) - defer span.Finish() - - ctx = opentracing.ContextWithSpan(ctx, span) - next.ServeHTTP(w, r.WithContext(ctx)) } -// authorizationInbound authorizes services based on configuration -type authorizationInbound struct { - authClient auth.Client - statsClient *statsClient -} - -func (f authorizationInbound) Handle(w http.ResponseWriter, r *http.Request, next http.Handler) { - if err := f.authClient.Authorize(r.Context()); err != nil { - f.statsClient.HTTPAuthFailCounter().Inc(1) - ulog.Logger(r.Context()).Error(auth.ErrAuthorization, zap.Error(err)) - http.Error(w, fmt.Sprintf("Unauthorized access: %+v", err), http.StatusUnauthorized) - return +func authorizationInbound(next http.Handler, authClient auth.Client, statsClient *statsClient) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := authClient.Authorize(r.Context()); err != nil { + statsClient.HTTPAuthFailCounter().Inc(1) + ulog.Logger(r.Context()).Error(auth.ErrAuthorization, zap.Error(err)) + http.Error(w, fmt.Sprintf("Unauthorized access: %+v", err), http.StatusUnauthorized) + return + } + next.ServeHTTP(w, r) } - next.ServeHTTP(w, r) } // panicInbound handles any panics and return an error // panic inbound middleware should be added at the end of middleware chain to catch panics -type panicInbound struct { - statsClient *statsClient -} - -func (f panicInbound) Handle(w http.ResponseWriter, r *http.Request, next http.Handler) { - ctx := r.Context() - defer func() { - if err := recover(); err != nil { - ulog.Logger(ctx).Error("Panic recovered serving request", - zap.Error(errors.Errorf("panic in handler: %+v", err)), - zap.Stringer("url", r.URL), - ) - f.statsClient.HTTPPanicCounter().Inc(1) - http.Error(w, _panicResponse, http.StatusInternalServerError) - } - }() - next.ServeHTTP(w, r) +func panicInbound(next http.Handler, statsClient *statsClient) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + defer func() { + if err := recover(); err != nil { + ulog.Logger(ctx).Error("Panic recovered serving request", + zap.Error(errors.Errorf("panic in handler: %+v", err)), + zap.Stringer("url", r.URL), + ) + statsClient.HTTPPanicCounter().Inc(1) + http.Error(w, _panicResponse, http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + } } // metricsInbound adds any default metrics related to HTTP -type metricsInbound struct { - statsClient *statsClient -} - -func (f metricsInbound) Handle(w http.ResponseWriter, r *http.Request, next http.Handler) { - stopwatch := f.statsClient.HTTPMethodTimer().Timer(r.Method).Start() - defer stopwatch.Stop() - defer f.statsClient.HTTPStatusCountScope().Tagged(map[string]string{_tagStatus: w.Header().Get("Status")}).Counter("total").Inc(1) - next.ServeHTTP(w, r) +func metricsInbound(next http.Handler, statsClient *statsClient) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + stopwatch := statsClient.HTTPMethodTimer().Timer(r.Method).Start() + defer stopwatch.Stop() + defer statsClient.HTTPStatusCountScope().Tagged(map[string]string{_tagStatus: w.Header().Get("Status")}).Counter("total").Inc(1) + next.ServeHTTP(w, r) + } } diff --git a/modules/uhttp/middleware_test.go b/modules/uhttp/middleware_test.go index 8e45d3a26..fac3da633 100644 --- a/modules/uhttp/middleware_test.go +++ b/modules/uhttp/middleware_test.go @@ -27,6 +27,7 @@ import ( "strings" "testing" + "go.uber.org/fx/auth" "go.uber.org/fx/service" "go.uber.org/fx/testutils" "go.uber.org/fx/tracing" @@ -89,7 +90,7 @@ func TestDefaultMiddlewareWithNopHostAuthFailure(t *testing.T) { } // setup - host := service.NopHostAuthFailure() + host := service.NopHost() t.Run("parallel group", func(t *testing.T) { for _, tt := range tests { @@ -121,8 +122,7 @@ func TestDefaultInboundMiddlewareWithNopHostConfigured(t *testing.T) { } func testInboundMiddlewareChain(t *testing.T, host service.Host) { - chain := newInboundMiddlewareChainBuilder().AddMiddleware([]InboundMiddleware{}...).Build(getNopHandler()) - response := testServeHTTP(chain) + response := testServeHTTP(getNopHandler()) assert.True(t, strings.Contains(response.Body.String(), "inbound middleware ok")) } @@ -141,9 +141,8 @@ func testTracingInboundWithLogs(t *testing.T) { defer closer.Close() opentracing.InitGlobalTracer(tracer) defer opentracing.InitGlobalTracer(opentracing.NoopTracer{}) + response := testServeHTTP(tracingInbound(getNopHandler())) - chain := newInboundMiddlewareChainBuilder().AddMiddleware([]InboundMiddleware{contextInbound{loggerWithZap}, tracingInbound{}}...).Build(getNopHandler()) - response := testServeHTTP(chain) assert.Contains(t, response.Body.String(), "inbound middleware ok") assert.True(t, len(buf.Lines()) > 0) var tracecount = 0 @@ -163,34 +162,18 @@ func testTracingInboundWithLogs(t *testing.T) { } func testInboundTraceInboundAuthChain(t *testing.T, host service.Host) { - chain := newInboundMiddlewareChainBuilder().AddMiddleware( - tracingInbound{}, - authorizationInbound{ - authClient: host.AuthClient(), - statsClient: newStatsClient(host.Metrics()), - }).Build(getNopHandler()) - - response := testServeHTTP(chain) + response := testServeHTTP(authorizationInbound(tracingInbound(getNopHandler()), auth.NopClient, newStatsClient(host.Metrics()))) assert.Contains(t, response.Body.String(), "inbound middleware ok") } func testInboundMiddlewareChainAuthFailure(t *testing.T, host service.Host) { - chain := newInboundMiddlewareChainBuilder().AddMiddleware( - tracingInbound{}, - authorizationInbound{ - authClient: host.AuthClient(), - statsClient: newStatsClient(host.Metrics()), - }).Build(getNopHandler()) - response := testServeHTTP(chain) + response := testServeHTTP(authorizationInbound(tracingInbound(getNopHandler()), auth.FailureClient, newStatsClient(host.Metrics()))) assert.Equal(t, response.Body.String(), "Unauthorized access: Error authorizing the service\n") assert.Equal(t, 401, response.Code) } func testPanicInbound(t *testing.T, host service.Host) { - chain := newInboundMiddlewareChainBuilder().AddMiddleware( - panicInbound{newStatsClient(host.Metrics())}, - ).Build(getPanicHandler()) - response := testServeHTTP(chain) + response := testServeHTTP(panicInbound(getPanicHandler(), newStatsClient(host.Metrics()))) assert.Equal(t, response.Body.String(), _panicResponse+"\n") assert.Equal(t, http.StatusInternalServerError, response.Code) @@ -201,10 +184,7 @@ func testPanicInbound(t *testing.T, host service.Host) { } func testMetricsInbound(t *testing.T, host service.Host) { - chain := newInboundMiddlewareChainBuilder().AddMiddleware( - metricsInbound{newStatsClient(host.Metrics())}, - ).Build(getNopHandler()) - response := testServeHTTP(chain) + response := testServeHTTP(metricsInbound(getNopHandler(), newStatsClient(host.Metrics()))) assert.Contains(t, response.Body.String(), "inbound middleware ok") testScope := host.Metrics() @@ -215,7 +195,7 @@ func testMetricsInbound(t *testing.T, host service.Host) { assert.NotNil(t, timers["GET"].Values()) } -func testServeHTTP(chain inboundMiddlewareChain) *httptest.ResponseRecorder { +func testServeHTTP(chain http.Handler) *httptest.ResponseRecorder { request := httptest.NewRequest("", "http://middleware", nil) response := httptest.NewRecorder() chain.ServeHTTP(response, request) diff --git a/modules/uhttp/middlewarechain_builder.go b/modules/uhttp/middlewarechain_builder.go deleted file mode 100644 index 17bb1600f..000000000 --- a/modules/uhttp/middlewarechain_builder.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) 2017 Uber Technologies, Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -package uhttp - -import ( - "net/http" - - "go.uber.org/fx/auth" - - "go.uber.org/zap" -) - -type inboundMiddlewareChain struct { - currentMiddleware int - finalHandler http.Handler - middleware []InboundMiddleware -} - -func (fc inboundMiddlewareChain) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if fc.currentMiddleware == len(fc.middleware) { - fc.finalHandler.ServeHTTP(w, r) - } else { - middleware := fc.middleware[fc.currentMiddleware] - fc.currentMiddleware++ - middleware.Handle(w, r, fc) - } -} - -type inboundMiddlewareChainBuilder struct { - finalHandler http.Handler - middleware []InboundMiddleware -} - -func defaultInboundMiddlewareChainBuilder(log *zap.Logger, authClient auth.Client, statsClient *statsClient) inboundMiddlewareChainBuilder { - mcb := newInboundMiddlewareChainBuilder() - return mcb.AddMiddleware( - contextInbound{log}, - panicInbound{statsClient}, - metricsInbound{statsClient}, - tracingInbound{}, - authorizationInbound{authClient, statsClient}, - ) -} - -// newInboundMiddlewareChainBuilder creates an empty middlewareChainBuilder for setup -func newInboundMiddlewareChainBuilder() inboundMiddlewareChainBuilder { - return inboundMiddlewareChainBuilder{} -} - -func (m inboundMiddlewareChainBuilder) AddMiddleware(middleware ...InboundMiddleware) inboundMiddlewareChainBuilder { - m.middleware = append(m.middleware, middleware...) - return m -} - -func (m inboundMiddlewareChainBuilder) Build(finalHandler http.Handler) inboundMiddlewareChain { - return inboundMiddlewareChain{ - middleware: m.middleware, - finalHandler: finalHandler, - } -} diff --git a/modules/uhttp/router_test.go b/modules/uhttp/router_test.go deleted file mode 100644 index e3224eac9..000000000 --- a/modules/uhttp/router_test.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) 2017 Uber Technologies, Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -package uhttp - -import ( - "fmt" - "io/ioutil" - "net" - "net/http" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.uber.org/fx/service" -) - -func serve(t *testing.T, h http.Handler) net.Listener { - l, err := net.Listen("tcp", "127.0.0.1:0") - require.Nil(t, err) - - go http.Serve(l, h) - return l -} - -func withRouter(t *testing.T, f func(r *Router, l net.Listener)) { - r := NewRouter(service.NopHost()) - l := serve(t, r) - defer l.Close() - r.Handle("/foo/baz/quokka", - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("hello")) - })) - r.Handle("/foo/bar/quokka", - http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("world")) - })) - f(r, l) -} - -func TestRouting_ExpectSecond(t *testing.T) { - withRouter(t, func(r *Router, l net.Listener) { - req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/foo/bar/quokka", l.Addr().String()), nil) - require.NoError(t, err) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - assert.Equal(t, "world", string(body)) - }) -} - -func TestRouting_ExpectFirst(t *testing.T) { - withRouter(t, func(r *Router, l net.Listener) { - req, err := http.NewRequest("GET", fmt.Sprintf("http://%s/foo/baz/quokka", l.Addr().String()), nil) - require.NoError(t, err) - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - assert.Equal(t, "hello", string(body)) - }) -} diff --git a/modules/uhttp/routes.go b/modules/uhttp/routes.go deleted file mode 100644 index e5ad0fd28..000000000 --- a/modules/uhttp/routes.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) 2017 Uber Technologies, Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -package uhttp - -import ( - "net/http" - - "github.com/gorilla/mux" -) - -// FromGorilla turns a gorilla mux route into an UberFx route -func FromGorilla(r *mux.Route) Route { - return Route{ - r: r, - } -} - -// A RouteHandler is an HTTP handler for a single route -type RouteHandler struct { - Path string - Handler http.Handler -} - -// NewRouteHandler creates a route handler -func NewRouteHandler(path string, handler http.Handler) RouteHandler { - return RouteHandler{ - Path: path, - Handler: handler, - } -} - -// A Route represents a handler for HTTP requests, with restrictions -type Route struct { - r *mux.Route -} - -// GorillaMux returns the underlying mux if you need to use it directly -func (r Route) GorillaMux() *mux.Route { - return r.r -} - -// Headers allows easy enforcement of headers -func (r Route) Headers(headerPairs ...string) Route { - return Route{ - r.r.Headers(headerPairs...), - } -} - -// Methods allows easy enforcement of metthods (HTTP Verbs) -func (r Route) Methods(methods ...string) Route { - return Route{ - r.r.Methods(methods...), - } -} diff --git a/modules/uhttp/routes_test.go b/modules/uhttp/routes_test.go deleted file mode 100644 index c0f3434f7..000000000 --- a/modules/uhttp/routes_test.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) 2017 Uber Technologies, Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -package uhttp - -import ( - "fmt" - "net/http" - "testing" - - "github.com/gorilla/mux" - "github.com/stretchr/testify/assert" -) - -func TestFromGorilla_OK(t *testing.T) { - r := mux.NewRouter() - route := r.Headers("foo", "bar") - f := FromGorilla(route) - assert.Equal(t, f.r, route) -} - -func TestNewRouteHandler(t *testing.T) { - rh := NewRouteHandler("/", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - fmt.Fprintf(w, "Hi\n") - })) - - assert.Equal(t, rh.Path, "/") -} - -func TestGorillaMux_OK(t *testing.T) { - r := mux.NewRouter() - route := r.Path("/foo") - ours := FromGorilla(route) - rounded := ours.GorillaMux() - assert.Equal(t, route, rounded) -} - -func TestHeaders_OK(t *testing.T) { - r := mux.NewRouter() - route := Route{r.Path("/foo")} - withHeaders := route.Headers("foo", "bar") - assert.NotNil(t, withHeaders.r) -} - -func TestMethods_OK(t *testing.T) { - r := mux.NewRouter() - route := Route{r.Path("/foo")} - withMethods := route.Methods("GET") - assert.NotNil(t, withMethods.r) -} diff --git a/modules/uhttp/uhttpclient/client.go b/modules/uhttp/uhttpclient/client.go index 72d1c0853..685123417 100644 --- a/modules/uhttp/uhttpclient/client.go +++ b/modules/uhttp/uhttpclient/client.go @@ -25,23 +25,25 @@ import ( "time" "go.uber.org/fx/auth" + "go.uber.org/fx/config" "github.com/opentracing/opentracing-go" + "github.com/uber-go/tally" ) // New creates an http.Client that includes 2 extra outbound middleware: tracing and auth // they are going to be applied in following order: tracing, auth, remaining outbound middleware // and only if all of them passed the request is going to be send. // Client is safe to use by multiple go routines, if global tracer is not changed. -func New(tracer opentracing.Tracer, info auth.CreateAuthInfo, middleware ...OutboundMiddleware) *http.Client { +func New(tracer opentracing.Tracer, config config.Provider, scope tally.Scope, middleware ...OutboundMiddleware) *http.Client { defaultMiddleware := make([]OutboundMiddleware, 0, 2+len(middleware)) if tracer != nil { defaultMiddleware = append(defaultMiddleware, tracingOutbound(tracer)) - if info != nil { - defaultMiddleware = append(defaultMiddleware, authenticationOutbound(info)) - } } - + if config.Get("auth").HasValue() { + authClient := auth.Load(config, scope) + defaultMiddleware = append(defaultMiddleware, authenticationOutbound(config, authClient)) + } defaultMiddleware = append(defaultMiddleware, middleware...) return &http.Client{ Transport: newExecutionChain(defaultMiddleware, http.DefaultTransport), diff --git a/modules/uhttp/uhttpclient/client_middleware.go b/modules/uhttp/uhttpclient/client_middleware.go index db6a0b323..cf1e97fbb 100644 --- a/modules/uhttp/uhttpclient/client_middleware.go +++ b/modules/uhttp/uhttpclient/client_middleware.go @@ -85,9 +85,8 @@ func tracingOutbound(tracer opentracing.Tracer) OutboundMiddlewareFunc { // authenticationOutbound on client side calls authenticate, and gets a claim that client is who they say they are // We only authorize with the claim on server side -func authenticationOutbound(info auth.CreateAuthInfo) OutboundMiddlewareFunc { - authClient := auth.Load(info) - serviceName := info.Config().Get(config.ServiceNameKey).AsString() +func authenticationOutbound(cfg config.Provider, authClient auth.Client) OutboundMiddlewareFunc { + serviceName := cfg.Get(config.ServiceNameKey).AsString() return func(req *http.Request, next Executor) (resp *http.Response, err error) { ctx := req.Context() // Client needs to know what service it is to authenticate diff --git a/modules/uhttp/uhttpclient/client_middleware_benchmark_test.go b/modules/uhttp/uhttpclient/client_middleware_benchmark_test.go index 65fae606a..569ac71c2 100644 --- a/modules/uhttp/uhttpclient/client_middleware_benchmark_test.go +++ b/modules/uhttp/uhttpclient/client_middleware_benchmark_test.go @@ -26,6 +26,7 @@ import ( "testing" "go.uber.org/fx/auth" + "go.uber.org/fx/config" "go.uber.org/fx/tracing" "github.com/opentracing/opentracing-go" @@ -47,11 +48,12 @@ func BenchmarkClientMiddleware(b *testing.B) { } defer closer.Close() + cfg := config.NewYAMLProviderFromBytes(_testYaml) bm := map[string][]OutboundMiddleware{ "empty": {}, "tracing": {tracingOutbound(tracer)}, - "auth": {authenticationOutbound(fakeAuthInfo{_testYaml})}, - "default": {tracingOutbound(tracer), authenticationOutbound(fakeAuthInfo{_testYaml})}, + "auth": {authenticationOutbound(cfg, auth.Load(cfg, tally.NoopScope))}, + "default": {tracingOutbound(tracer), authenticationOutbound(cfg, auth.Load(cfg, tally.NoopScope))}, } for name, middleware := range bm { diff --git a/modules/uhttp/uhttpclient/client_middleware_test.go b/modules/uhttp/uhttpclient/client_middleware_test.go index 843b04509..69e4946ff 100644 --- a/modules/uhttp/uhttpclient/client_middleware_test.go +++ b/modules/uhttp/uhttpclient/client_middleware_test.go @@ -117,15 +117,17 @@ func TestExecutionChainOutboundMiddleware_AuthContextPropagationFailure(t *testi } func getExecChainWithAuth(t *testing.T) executionChain { + cfg := config.NewYAMLProviderFromBytes(_testYaml) return newExecutionChain( - []OutboundMiddleware{authenticationOutbound(fakeAuthInfo{_testYaml})}, + []OutboundMiddleware{authenticationOutbound(cfg, auth.Load(cfg, tally.NoopScope))}, contextPropagationTransport{t}, ) } func TestOutboundMiddlewareWithTracerErrors(t *testing.T) { + cfg := config.NewYAMLProviderFromBytes(_testYaml) testCases := map[string]OutboundMiddleware{ - "auth": authenticationOutbound(fakeAuthInfo{_testYaml}), + "auth": authenticationOutbound(cfg, auth.Load(cfg, tally.NoopScope)), "tracing": tracingOutbound(opentracing.NoopTracer{}), } @@ -156,22 +158,6 @@ func TestOutboundMiddlewareWithTracerErrors(t *testing.T) { } } -type fakeAuthInfo struct { - yaml []byte -} - -func (f fakeAuthInfo) Config() config.Provider { - return config.NewYAMLProviderFromBytes(f.yaml) -} - -func (f fakeAuthInfo) Logger() *zap.Logger { - return zap.NewNop() -} - -func (f fakeAuthInfo) Metrics() tally.Scope { - return tally.NoopScope -} - type contextPropagationTransport struct { *testing.T } diff --git a/modules/uhttp/uhttpclient/client_test.go b/modules/uhttp/uhttpclient/client_test.go index f923d3afe..2a7b53aee 100644 --- a/modules/uhttp/uhttpclient/client_test.go +++ b/modules/uhttp/uhttpclient/client_test.go @@ -25,18 +25,21 @@ import ( "net/http/httptest" "testing" - "go.uber.org/fx/auth" + "go.uber.org/fx/config" "github.com/opentracing/opentracing-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/uber-go/tally" ) var ( _testYaml = []byte(` name: test +auth: + service: test `) - _testClient = New(opentracing.NoopTracer{}, fakeAuthInfo{yaml: _testYaml}) + _testClient = New(opentracing.NoopTracer{}, config.NewYAMLProviderFromBytes(_testYaml), tally.NoopScope) ) func TestNew(t *testing.T) { @@ -49,7 +52,7 @@ func TestNew(t *testing.T) { func TestNew_Panic(t *testing.T) { t.Parallel() assert.Panics(t, func() { - New(opentracing.NoopTracer{}, fakeAuthInfo{yaml: []byte(``)}) + New(opentracing.NoopTracer{}, nil, tally.NoopScope) }) } @@ -81,7 +84,7 @@ func TestClientGetTwiceExecutesAllMiddleware(t *testing.T) { return next.Execute(r) } - cl := New(opentracing.NoopTracer{}, fakeAuthInfo{yaml: _testYaml}, f) + cl := New(opentracing.NoopTracer{}, config.NewYAMLProviderFromBytes(_testYaml), tally.NoopScope, f) resp, err := cl.Get(svr.URL) checkOKResponse(t, resp, err) require.Equal(t, 1, count) @@ -129,17 +132,19 @@ func TestClientPostForm(t *testing.T) { func TestClientWithNilParameters(t *testing.T) { t.Parallel() svr := startServer() + cfg := config.NewYAMLProviderFromBytes([]byte(``)) tests := map[string]struct { - info auth.CreateAuthInfo + cfg config.Provider + scope tally.Scope tracer opentracing.Tracer }{ - "NilInfo": {info: nil, tracer: opentracing.NoopTracer{}}, - "NilTracer": {info: fakeAuthInfo{yaml: _testYaml}, tracer: nil}, - "BothNil": {}, + "NilScope": {cfg: cfg, scope: nil, tracer: opentracing.NoopTracer{}}, + "NilTracer": {cfg: cfg, scope: tally.NoopScope, tracer: nil}, + "TracerScopeNil": {cfg: cfg}, } for name, params := range tests { t.Run(name, func(t *testing.T) { - client := New(params.tracer, params.info) + client := New(params.tracer, params.cfg, params.scope) resp, err := client.Head(svr.URL) checkOKResponse(t, resp, err) }) diff --git a/modules/yarpc/middleware.go b/modules/yarpc/middleware.go index 335409319..33da4140f 100644 --- a/modules/yarpc/middleware.go +++ b/modules/yarpc/middleware.go @@ -24,48 +24,16 @@ import ( "context" "go.uber.org/fx/auth" - "go.uber.org/fx/service" "go.uber.org/fx/ulog" "go.uber.org/zap" - "github.com/pkg/errors" "go.uber.org/yarpc/api/transport" ) const _panicResponse = "Server Error" -type contextInboundMiddleware struct { - statsClient *statsClient -} - -func (f contextInboundMiddleware) Handle( - ctx context.Context, - req *transport.Request, - resw transport.ResponseWriter, - handler transport.UnaryHandler, -) error { - stopwatch := f.statsClient.RPCHandleTimer(). - Tagged(map[string]string{_tagProcedure: req.Procedure}). - Timer(req.Procedure). - Start() - defer stopwatch.Stop() - - return handler.Handle(ctx, req, resw) -} - -type contextOnewayInboundMiddleware struct{} - -func (f contextOnewayInboundMiddleware) HandleOneway( - ctx context.Context, - req *transport.Request, - handler transport.OnewayHandler, -) error { - return handler.HandleOneway(ctx, req) -} - type authInboundMiddleware struct { - service.Host - statsClient *statsClient + authClient auth.Client } func (a authInboundMiddleware) Handle( @@ -74,7 +42,7 @@ func (a authInboundMiddleware) Handle( resw transport.ResponseWriter, handler transport.UnaryHandler, ) error { - fxctx, err := authorize(ctx, a.Host, a.statsClient) + fxctx, err := authorize(ctx, a.authClient) if err != nil { return err } @@ -82,8 +50,7 @@ func (a authInboundMiddleware) Handle( } type authOnewayInboundMiddleware struct { - service.Host - statsClient *statsClient + authClient auth.Client } func (a authOnewayInboundMiddleware) HandleOneway( @@ -91,16 +58,15 @@ func (a authOnewayInboundMiddleware) HandleOneway( req *transport.Request, handler transport.OnewayHandler, ) error { - fxctx, err := authorize(ctx, a.Host, a.statsClient) + fxctx, err := authorize(ctx, a.authClient) if err != nil { return err } return handler.HandleOneway(fxctx, req) } -func authorize(ctx context.Context, host service.Host, statsClient *statsClient) (context.Context, error) { - if err := host.AuthClient().Authorize(ctx); err != nil { - statsClient.RPCAuthFailCounter().Inc(1) +func authorize(ctx context.Context, authClient auth.Client) (context.Context, error) { + if err := authClient.Authorize(ctx); err != nil { ulog.Logger(ctx).Error(auth.ErrAuthorization, zap.Error(err)) // TODO(anup): GFM-255 update returned error to transport.BadRequestError (user error than server error) // https://github.com/yarpc/yarpc-go/issues/687 @@ -108,42 +74,3 @@ func authorize(ctx context.Context, host service.Host, statsClient *statsClient) } return ctx, nil } - -type panicInboundMiddleware struct { - statsClient *statsClient -} - -func (p panicInboundMiddleware) Handle( - ctx context.Context, - req *transport.Request, - resw transport.ResponseWriter, - handler transport.UnaryHandler, -) error { - defer panicRecovery(ctx, p.statsClient) - return handler.Handle(ctx, req, resw) -} - -type panicOnewayInboundMiddleware struct { - statsClient *statsClient -} - -func (p panicOnewayInboundMiddleware) HandleOneway( - ctx context.Context, - req *transport.Request, - handler transport.OnewayHandler, -) error { - defer panicRecovery(ctx, p.statsClient) - return handler.HandleOneway(ctx, req) -} - -func panicRecovery(ctx context.Context, statsClient *statsClient) { - if err := recover(); err != nil { - statsClient.RPCPanicCounter().Inc(1) - ulog.Logger(ctx).Error("Panic recovered serving request", - zap.Error(errors.Errorf("panic in handler: %+v", err)), - ) - // rethrow panic back to yarpc - // before https://github.com/yarpc/yarpc-go/issues/734 fixed, throw a generic error. - panic(_panicResponse) - } -} diff --git a/modules/yarpc/middleware_test.go b/modules/yarpc/middleware_test.go index bc6c6f20f..b9fb81f7e 100644 --- a/modules/yarpc/middleware_test.go +++ b/modules/yarpc/middleware_test.go @@ -25,48 +25,16 @@ import ( "errors" "testing" - "go.uber.org/fx/service" - "go.uber.org/fx/testutils" - "go.uber.org/fx/testutils/tracing" + "go.uber.org/fx/auth" "go.uber.org/fx/ulog" "go.uber.org/thriftrw/wire" "go.uber.org/yarpc/api/transport" - "github.com/opentracing/opentracing-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/uber-go/tally" - "go.uber.org/zap" "go.uber.org/zap/zaptest" ) -func TestInboundMiddleware_Context(t *testing.T) { - host := service.NopHost() - unary := contextInboundMiddleware{newStatsClient(host.Metrics())} - testutils.WithInMemoryLogger(t, nil, func(loggerWithZap *zap.Logger, buf *zaptest.Buffer) { - defer ulog.SetLogger(loggerWithZap)() - tracing.WithSpan(t, loggerWithZap, func(span opentracing.Span) { - ctx := opentracing.ContextWithSpan(context.Background(), span) - err := unary.Handle(ctx, &transport.Request{}, nil, &fakeUnary{t: t}) - require.Contains(t, err.Error(), "handle") - checkLogForTrace(t, buf) - }) - }) -} - -func TestOnewayInboundMiddleware_Context(t *testing.T) { - oneway := contextOnewayInboundMiddleware{} - testutils.WithInMemoryLogger(t, nil, func(loggerWithZap *zap.Logger, buf *zaptest.Buffer) { - defer ulog.SetLogger(loggerWithZap)() - tracing.WithSpan(t, loggerWithZap, func(span opentracing.Span) { - ctx := opentracing.ContextWithSpan(context.Background(), span) - err := oneway.HandleOneway(ctx, &transport.Request{}, &fakeOneway{t: t}) - require.Contains(t, err.Error(), "oneway handle") - checkLogForTrace(t, buf) - }) - }) -} - func checkLogForTrace(t *testing.T, buf *zaptest.Buffer) { require.True(t, len(buf.Lines()) > 0) for _, line := range buf.Lines() { @@ -77,64 +45,30 @@ func checkLogForTrace(t *testing.T, buf *zaptest.Buffer) { } func TestInboundMiddleware_auth(t *testing.T) { - host := service.NopHost() - unary := authInboundMiddleware{host, newStatsClient(host.Metrics())} + unary := authInboundMiddleware{auth.NopClient} err := unary.Handle(context.Background(), &transport.Request{}, nil, &fakeUnary{t: t}) assert.EqualError(t, err, "handle") } func TestInboundMiddleware_authFailure(t *testing.T) { - host := service.NopHostAuthFailure() - unary := authInboundMiddleware{host, newStatsClient(host.Metrics())} + unary := authInboundMiddleware{auth.FailureClient} err := unary.Handle(context.Background(), &transport.Request{}, nil, &fakeUnary{t: t}) assert.EqualError(t, err, "Error authorizing the service") } func TestOnewayInboundMiddleware_auth(t *testing.T) { - oneway := authOnewayInboundMiddleware{ - Host: service.NopHost(), - } + oneway := authOnewayInboundMiddleware{auth.NopClient} err := oneway.HandleOneway(context.Background(), &transport.Request{}, &fakeOneway{t: t}) assert.EqualError(t, err, "oneway handle") } func TestOnewayInboundMiddleware_authFailure(t *testing.T) { - host := service.NopHostAuthFailure() - oneway := authOnewayInboundMiddleware{host, newStatsClient(host.Metrics())} + oneway := authOnewayInboundMiddleware{auth.FailureClient} err := oneway.HandleOneway(context.Background(), &transport.Request{}, &fakeOneway{t: t}) assert.EqualError(t, err, "Error authorizing the service") } -func TestInboundMiddleware_panic(t *testing.T) { - host := service.NopHost() - testScope := host.Metrics() - statsClient := newStatsClient(testScope) - - defer testPanicHandler(t, testScope) - unary := panicInboundMiddleware{statsClient} - unary.Handle(context.Background(), &transport.Request{}, nil, &alwaysPanicUnary{}) -} - -func TestOnewayInboundMiddleware_panic(t *testing.T) { - host := service.NopHost() - testScope := host.Metrics() - statsClient := newStatsClient(testScope) - - defer testPanicHandler(t, testScope) - oneway := panicOnewayInboundMiddleware{statsClient} - oneway.HandleOneway(context.Background(), &transport.Request{}, &alwaysPanicOneway{}) -} - -func testPanicHandler(t *testing.T, testScope tally.Scope) { - r := recover() - assert.EqualValues(t, r, _panicResponse) - - snapshot := testScope.(tally.TestScope).Snapshot() - counters := snapshot.Counters() - assert.True(t, counters["panic"].Value() > 0) -} - type fakeEnveloper struct { serviceName string } diff --git a/modules/yarpc/stats.go b/modules/yarpc/stats.go deleted file mode 100644 index 4de3821e9..000000000 --- a/modules/yarpc/stats.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2017 Uber Technologies, Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -package yarpc - -import "github.com/uber-go/tally" - -const ( - //_tagModule is module tag for metrics - _tagModule = "module" - // _tagType is either request or response - _tagType = "type" - // _tagProcedure is the procedure name - _tagProcedure = "procedure" - //_tagMiddleware is middleware type - _tagMiddleware = "middleware" -) - -var ( - rpcTags = map[string]string{ - _tagModule: "rpc", - _tagType: "request", - } -) - -type statsClient struct { - rpcAuthFailCounter tally.Counter - rpcHandleTimer tally.Scope - rpcPanicCounter tally.Counter -} - -func newStatsClient(scope tally.Scope) *statsClient { - rpcTagsScope := scope.Tagged(rpcTags) - return &statsClient{ - rpcTagsScope.Tagged(map[string]string{_tagMiddleware: "auth"}).Counter("fail"), - rpcTagsScope.Tagged(rpcTags), - rpcTagsScope.Counter("panic"), - } -} - -// RPCAuthFailCounter counts auth failures -func (c *statsClient) RPCAuthFailCounter() tally.Counter { - return c.rpcAuthFailCounter -} - -// RPCHandleTimer is a turnaround time for rpc handler -func (c *statsClient) RPCHandleTimer() tally.Scope { - return c.rpcHandleTimer -} - -// RPCPanicCounter counts panics occurred for rpc handler -func (c *statsClient) RPCPanicCounter() tally.Counter { - return c.rpcPanicCounter -} diff --git a/modules/yarpc/thriftrw-plugin-thriftsync/main.go b/modules/yarpc/thriftrw-plugin-thriftsync/main.go index 9a09fefbf..8b81f2eeb 100644 --- a/modules/yarpc/thriftrw-plugin-thriftsync/main.go +++ b/modules/yarpc/thriftrw-plugin-thriftsync/main.go @@ -107,7 +107,7 @@ func (generator) Generate(req *api.GenerateServiceRequest) (*api.GenerateService f := NewUpdater(opts) if err := f.UpdateExistingHandlerFile(service, gofilePath, *_baseDir, *_handlerStructName); err != nil { return nil, err - } else if err = f.RefreshAll(service, gofilePath); err != nil { + } else if err = f.RefreshAll(service, gofilePath, *_handlerStructName); err != nil { return nil, err } } diff --git a/modules/yarpc/thriftrw-plugin-thriftsync/update.go b/modules/yarpc/thriftrw-plugin-thriftsync/update.go index da70b9c8e..6f9efd734 100644 --- a/modules/yarpc/thriftrw-plugin-thriftsync/update.go +++ b/modules/yarpc/thriftrw-plugin-thriftsync/update.go @@ -121,7 +121,7 @@ func (u *Updater) compare(service *api.Service, filepath string, handlerDir stri // RefreshAll creates new funcs if they are missing from *.go file, and updates // all the existing functions from the idl. -func (u *Updater) RefreshAll(service *api.Service, filepath string) error { +func (u *Updater) RefreshAll(service *api.Service, filepath string, handlerStructName string) error { fset := token.NewFileSet() file, err := parser.ParseFile(fset, filepath, nil, parser.ParseComments) if err != nil { @@ -132,7 +132,7 @@ func (u *Updater) RefreshAll(service *api.Service, filepath string) error { switch x := n.(type) { case *ast.FuncDecl: if x.Name.Name == function.Name { - exp, err := u.createExpr(filepath, service, function) + exp, err := u.createExpr(filepath, service, function, handlerStructName) if err != nil { return false } @@ -163,8 +163,8 @@ func (u *Updater) RefreshAll(service *api.Service, filepath string) error { return err } -func (u *Updater) createExpr(filepath string, service *api.Service, f *api.Function) (*ast.File, error) { - buff, err := u.generateSingleFunction(filepath, service, f) +func (u *Updater) createExpr(filepath string, service *api.Service, f *api.Function, handlerStructName string) (*ast.File, error) { + buff, err := u.generateSingleFunction(filepath, service, f, handlerStructName) tmpf, err := ioutil.TempFile("", "tempbuff") if _, err := tmpf.Write(buff); err != nil { return nil, err @@ -176,12 +176,12 @@ func (u *Updater) createExpr(filepath string, service *api.Service, f *api.Funct return exp, err } -func (u *Updater) generateSingleFunction(goFilePath string, service *api.Service, f *api.Function) ([]byte, error) { - var funcs []*api.Function +func (u *Updater) generateSingleFunction(goFilePath string, service *api.Service, f *api.Function, handlerStructName string) ([]byte, error) { newData := &updatedData{ - Service: service, - Functions: append(funcs, f), + Service: service, } + newData.Functions = append(newData.Functions, f) + newData.HandlerStructName = handlerStructName return u.generate(goFilePath, newData) } diff --git a/modules/yarpc/yarpc.go b/modules/yarpc/yarpc.go index 0f029bb2b..81bced911 100644 --- a/modules/yarpc/yarpc.go +++ b/modules/yarpc/yarpc.go @@ -27,6 +27,7 @@ import ( "strconv" "sync" + "go.uber.org/fx/auth" "go.uber.org/fx/service" "go.uber.org/fx/ulog" @@ -86,11 +87,11 @@ func New(hookup ServiceCreateFunc, options ...ModuleOption) service.ModuleProvid // the lifecycle of all of the in/out bound traffic, so we will // register it in a dig.Graph provided with options/default graph. type Module struct { - host service.Host - statsClient *statsClient - config yarpcConfig - log *zap.Logger - controller *dispatcherController + authClient auth.Client + host service.Host + config yarpcConfig + log *zap.Logger + controller *dispatcherController } // ModuleOption is a function that configures module creation. @@ -131,14 +132,16 @@ func newModule( } } module := &Module{ - host: host, - statsClient: newStatsClient(host.Metrics()), - log: ulog.Logger(context.Background()).With(zap.String("module", host.ModuleName())), + host: host, + log: ulog.Logger(context.Background()).With(zap.String("module", host.ModuleName())), } if err := host.Config().Get("modules").Get(host.ModuleName()).Populate(&module.config); err != nil { return nil, errs.Wrap(err, "can't read inbounds") } + // TODO: pass in the auth client as part of module construction + module.authClient = auth.Load(host.Config(), host.Metrics()) + // iterate over inbounds transportsIn, err := prepareInbounds(module.config.Inbounds, host.Name()) if err != nil { @@ -183,7 +186,7 @@ func newModule( // Start begins serving requests with YARPC. func (m *Module) Start() error { // TODO(alsam) allow services to advertise with a name separate from the host name. - if err := m.controller.Start(m.host, m.statsClient); err != nil { + if err := m.controller.Start(m.authClient, m.host); err != nil { return errs.Wrap(err, "unable to start dispatcher") } m.log.Info("Module started") @@ -248,9 +251,9 @@ type dispatcherController struct { // 4. Start the dispatcher // // Once started the controller will not start the dispatcher again. -func (c *dispatcherController) Start(host service.Host, statsClient *statsClient) error { +func (c *dispatcherController) Start(authClient auth.Client, host service.Host) error { c.start.Do(func() { - c.addDefaultMiddleware(host, statsClient) + c.addDefaultMiddleware(authClient) var cfg yarpc.Config var err error @@ -319,17 +322,13 @@ func (c *dispatcherController) applyHandlers() error { } // Adds the default middleware: context propagation and auth. -func (c *dispatcherController) addDefaultMiddleware(host service.Host, statsClient *statsClient) { +func (c *dispatcherController) addDefaultMiddleware(authClient auth.Client) { cfg := yarpcConfig{ inboundMiddleware: []middleware.UnaryInbound{ - contextInboundMiddleware{statsClient}, - panicInboundMiddleware{statsClient}, - authInboundMiddleware{host, statsClient}, + authInboundMiddleware{authClient}, }, onewayInboundMiddleware: []middleware.OnewayInbound{ - contextOnewayInboundMiddleware{}, - panicOnewayInboundMiddleware{statsClient}, - authOnewayInboundMiddleware{host, statsClient}, + authOnewayInboundMiddleware{authClient}, }, } diff --git a/modules/yarpc/yarpc_test.go b/modules/yarpc/yarpc_test.go index 848e6bd56..abe5560ee 100644 --- a/modules/yarpc/yarpc_test.go +++ b/modules/yarpc/yarpc_test.go @@ -25,6 +25,7 @@ import ( "fmt" "testing" + "go.uber.org/fx/auth" "go.uber.org/fx/config" "go.uber.org/fx/service" "go.uber.org/yarpc" @@ -101,7 +102,7 @@ func TestDispatcher(t *testing.T) { c := dispatcherController{} host := service.NopHost() c.addConfig(yarpcConfig{transports: transports{inbounds: []transport.Inbound{}}}) - assert.NoError(t, c.Start(host, newStatsClient(host.Metrics()))) + assert.NoError(t, c.Start(auth.NopClient, host)) } func TestBindToBadPortReturnsError(t *testing.T) { @@ -114,8 +115,7 @@ func TestBindToBadPortReturnsError(t *testing.T) { } c.addConfig(cfg) - host := service.NopHost() - assert.Error(t, c.Start(host, newStatsClient(host.Metrics()))) + assert.Error(t, c.Start(auth.NopClient, service.NopHost())) } func TestMergeOfEmptyConfigCollectionReturnsError(t *testing.T) { @@ -124,7 +124,7 @@ func TestMergeOfEmptyConfigCollectionReturnsError(t *testing.T) { _, err := c.mergeConfig("test") assert.EqualError(t, err, "unable to merge empty configs") host := service.NopHost() - assert.EqualError(t, c.Start(host, newStatsClient(host.Metrics())), err.Error()) + assert.EqualError(t, c.Start(auth.NopClient, host), err.Error()) } func TestInboundPrint(t *testing.T) { diff --git a/service/README.md b/service/README.md index 67730c150..2f19f2cc4 100644 --- a/service/README.md +++ b/service/README.md @@ -65,14 +65,10 @@ func main() { } ``` -Which then allows us to set the roles either via a command line variable: +Which then allows us to set the roles either via a command line flags: -`export CONFIG__roles__0=worker` - -Or via the service parameters, we would activate in the following ways: - -* `./myservice` or `./myservice --roles "service,worker"`: Runs all modules -* `./myservice --roles "worker"`: Runs only the **Kakfa** module +* `./myservice` or `./myservice --roles=service,worker`: Runs all modules +* `./myservice --roles=worker`: Runs only the **Kafka** module * Etc... ## Options diff --git a/service/builder_test.go b/service/builder_test.go index 4c9795ea4..8d9151a06 100644 --- a/service/builder_test.go +++ b/service/builder_test.go @@ -23,15 +23,20 @@ package service import ( "errors" "testing" + "time" . "go.uber.org/fx/testutils" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/fx/config" ) var ( - nopModuleProvider = &StubModuleProvider{"nop", nopModule} - errModuleProvider = &StubModuleProvider{"err", errModule} + nopModuleProvider = &StubModuleProvider{"nop", nopModule} + errModuleProvider = &StubModuleProvider{"err", errModule} + startTimeoutProvider = &StubModuleProvider{"timeoutStart", timeoutStartModule} + stopTimeoutProvider = &StubModuleProvider{"timeoutStop", timeoutStopModule} ) func TestWithModules_OK(t *testing.T) { @@ -56,6 +61,68 @@ func TestWithModules_SkipsModulesBadInit(t *testing.T) { assert.Error(t, err, "Expected service name to be provided") } +func TestWithModules_StartTimeout(t *testing.T) { + cfg := config.NewStaticProvider(map[string]interface{}{ + "startTimeout": time.Microsecond, + "stopTimeout": time.Microsecond, + "name": "test", + "owner": "test@uber.com", + }) + + svc, err := WithModule(startTimeoutProvider). + WithModule(startTimeoutProvider). + WithOptions( + WithConfiguration(cfg), + ).Build() + + require.NoError(t, err) + + ctl := svc.StartAsync() + require.Error(t, ctl.ServiceError) + assert.Contains(t, ctl.ServiceError.Error(), "timeoutStart") + assert.Contains(t, ctl.ServiceError.Error(), `didn't start after "1µs"`) +} + +func TestWithModules_StopTimeout(t *testing.T) { + cfg := config.NewStaticProvider(map[string]interface{}{ + "startTimeout": time.Microsecond, + "stopTimeout": time.Microsecond, + "name": "test", + "owner": "test@uber.com", + }) + + svc, err := WithModule(stopTimeoutProvider). + WithModule(stopTimeoutProvider). + WithOptions( + WithConfiguration(cfg), + ).Build() + + require.NoError(t, err) + + ctl := svc.StartAsync() + require.NoError(t, ctl.ServiceError) + + err = svc.Stop("someReason", 1) + require.Error(t, err) + assert.Contains(t, err.Error(), "timeoutStop") + assert.Contains(t, err.Error(), `timedout after "1µs"`) +} + +func TestDefaultTimeouts(t *testing.T) { + svc, err := WithModule(stopTimeoutProvider). + WithModule(stopTimeoutProvider). + WithOptions( + WithConfiguration(StaticAppData(nil)), + ).Build() + + require.NoError(t, err) + m, ok := svc.(*manager) + require.True(t, ok, "expect manager returned by Build") + require.NotNil(t, m) + assert.Equal(t, 10*time.Second, m.StartTimeout) + assert.Equal(t, 10*time.Second, m.StopTimeout) +} + func nopModule(_ Host) (Module, error) { return nil, nil } @@ -63,3 +130,33 @@ func nopModule(_ Host) (Module, error) { func errModule(_ Host) (Module, error) { return nil, errors.New("intentional module creation failure") } + +func timeoutStartModule(_ Host) (Module, error) { + return timeoutStart{}, nil +} + +type timeoutStart struct{} + +func (timeoutStart) Start() error { + <-make(chan int) + return nil +} + +func (timeoutStart) Stop() error { + return nil +} + +func timeoutStopModule(_ Host) (Module, error) { + return timeoutStop{}, nil +} + +type timeoutStop struct{} + +func (timeoutStop) Start() error { + return nil +} + +func (timeoutStop) Stop() error { + <-make(chan int) + return nil +} diff --git a/service/doc.go b/service/doc.go index 4ac1fa1da..988abb9c5 100644 --- a/service/doc.go +++ b/service/doc.go @@ -86,15 +86,11 @@ // svc.Start() // } // -// Which then allows us to set the roles either via a command line variable: +// Which then allows us to set the roles either via a command line flags: // -// export CONFIG__roles__0=worker +// • ./myservice or ./myservice --roles=service,worker: Runs all modules // -// Or via the service parameters, we would activate in the following ways: -// -// • ./myservice or ./myservice --roles "service,worker": Runs all modules -// -// • ./myservice --roles "worker": Runs only the **Kakfa** module +// • ./myservice --roles=worker: Runs only the **Kafka** module // // • Etc... // diff --git a/service/host_mock.go b/service/host_mock.go index 68fdd6422..8fa1b002d 100644 --- a/service/host_mock.go +++ b/service/host_mock.go @@ -21,12 +21,10 @@ package service import ( - "go.uber.org/fx/auth" "go.uber.org/fx/config" "go.uber.org/fx/metrics" "github.com/opentracing/opentracing-go" - "github.com/uber-go/tally" "go.uber.org/zap" ) @@ -37,28 +35,19 @@ func NopHost() Host { // NopHostWithConfig is to be used in tests and allows setting of config. func NopHostWithConfig(configProvider config.Provider) Host { - return nopHostConfigured(auth.NopClient, zap.NewNop(), opentracing.NoopTracer{}, configProvider) -} - -// NopHostAuthFailure is nop manager with failure auth client -func NopHostAuthFailure() Host { - auth.UnregisterClient() - defer auth.UnregisterClient() - auth.RegisterClient(auth.FakeFailureClient) - return NopHostConfigured(auth.Load(nil), zap.NewNop(), opentracing.NoopTracer{}) + return nopHostConfigured(zap.NewNop(), opentracing.NoopTracer{}, configProvider) } // NopHostConfigured is a nop manager with set logger and tracer for tests -func NopHostConfigured(client auth.Client, logger *zap.Logger, tracer opentracing.Tracer) Host { - return nopHostConfigured(client, logger, tracer, nil) +func NopHostConfigured(logger *zap.Logger, tracer opentracing.Tracer) Host { + return nopHostConfigured(logger, tracer, nil) } -func nopHostConfigured(client auth.Client, logger *zap.Logger, tracer opentracing.Tracer, configProvider config.Provider) Host { +func nopHostConfigured(logger *zap.Logger, tracer opentracing.Tracer, configProvider config.Provider) Host { if configProvider == nil { configProvider = config.NewStaticProvider(nil) } return &serviceCore{ - authClient: client, configProvider: configProvider, standardConfig: serviceConfig{ Name: "dummy", @@ -66,7 +55,7 @@ func nopHostConfigured(client auth.Client, logger *zap.Logger, tracer opentracin Description: "does cool stuff", }, metricsCore: metricsCore{ - metrics: tally.NoopScope, + metrics: metrics.NopScope, statsReporter: metrics.NopCachedStatsReporter, }, tracerCore: tracerCore{ diff --git a/service/host_mock_test.go b/service/host_mock_test.go index b58a4520f..6afdef4cf 100644 --- a/service/host_mock_test.go +++ b/service/host_mock_test.go @@ -30,9 +30,3 @@ func TestNopHost_OK(t *testing.T) { sh := NopHost() assert.Equal(t, "dummy", sh.Name()) } - -func TestNopHost_AuthFailures(t *testing.T) { - sh := NopHostAuthFailure() - assert.Equal(t, "dummy", sh.Name()) - assert.Equal(t, "failure", sh.AuthClient().Name()) -} diff --git a/service/manager.go b/service/manager.go index a43b36c90..57df1b93b 100644 --- a/service/manager.go +++ b/service/manager.go @@ -29,13 +29,10 @@ import ( "time" "go.uber.org/fx/config" - "go.uber.org/zap" "github.com/pkg/errors" -) - -const ( - defaultStartupWait = 10 * time.Second + "github.com/uber-go/multierr" + "go.uber.org/zap" ) // A ExitCallback is a function to handle a service shutdown and provide @@ -45,6 +42,10 @@ type ExitCallback func(shutdown Exit) int // Implements Manager interface type manager struct { serviceCore + + StartTimeout time.Duration `default:"10s"` + StopTimeout time.Duration `default:"10s"` + locked bool observer Observer moduleWrappers []*moduleWrapper @@ -70,10 +71,6 @@ func newManager(builder *Builder) (Manager, error) { moduleWrappers: []*moduleWrapper{}, serviceCore: serviceCore{}, } - m.roles = map[string]bool{} - for _, r := range m.standardConfig.Roles { - m.roles[r] = true - } for _, opt := range builder.options { if optionErr := opt(m); optionErr != nil { return nil, errors.Wrap(optionErr, "option failed to apply") @@ -83,21 +80,32 @@ func newManager(builder *Builder) (Manager, error) { // If the user didn't pass in a configuration provider, load the standard. // Bypassing standard config load is pretty much only used for tests, although it could be // useful in certain circumstances. - m.configProvider = config.Load() + m.configProvider = config.DefaultLoader.Load() } if err := m.setupStandardConfig(); err != nil { return nil, err } + + if err := m.configProvider.Get(config.Root).Populate(m); err != nil { + return nil, err + } + + m.roles = map[string]bool{} + for _, r := range m.standardConfig.Roles { + m.roles[r] = true + } + // Initialize metrics. If no metrics reporters were Registered, do nop // TODO(glib): add a logging reporter and use it by default, rather than nop m.setupMetrics() if err := m.setupLogging(); err != nil { return nil, err } - m.setupAuthClient() + if err := m.setupRuntimeMetricsCollector(); err != nil { return nil, err } + m.setupVersionMetricsEmitter() if err := m.setupTracer(); err != nil { return nil, err @@ -262,8 +270,7 @@ func (m *manager) shutdown(err error, reason string, exitCode *int) (bool, error } m.transitionState(Stopped) - - return true, err + return true, multierr.Combine(err, multierr.Combine(errs...)) } func (m *manager) addModule(provider ModuleProvider, options ...ModuleOption) error { @@ -325,7 +332,7 @@ func (m *manager) start() Control { m.registerSignalHandlers() if len(errs) > 0 { var serviceErr error - errChan := make(chan Exit, 1) + errChan := make(chan Exit, len(errs)) // grab the first error, shut down the service and return the error for _, e := range errs { errChan <- Exit{ @@ -334,19 +341,24 @@ func (m *manager) start() Control { ExitCode: 4, } - m.shutdownMu.Unlock() - if _, err := m.shutdown(e, "", nil); err != nil { - zap.L().Error("Unable to shut down modules", - zap.NamedError("initialError", e), - zap.NamedError("shutdownError", err), - ) - } - zap.L().Error("Error starting the module", zap.Error(e)) - // return first service error - if serviceErr == nil { - serviceErr = e - } } + + m.shutdownMu.Unlock() + + e := multierr.Combine(errs...) + if _, err := m.shutdown(e, "", nil); err != nil { + zap.L().Error("Unable to shut down modules", + zap.NamedError("combinedErrors", e), + zap.NamedError("shutdownError", err), + ) + } + + zap.L().Error("Error starting the module", zap.Error(e)) + // return first service error + if serviceErr == nil { + serviceErr = e + } + return Control{ ExitChan: errChan, ReadyChan: readyCh, @@ -403,11 +415,11 @@ func (m *manager) startModules() []error { } else { zap.L().Info("Module started up cleanly", zap.String("module", mw.Name())) } - case <-time.After(defaultStartupWait): + case <-time.After(m.StartTimeout): lock.Lock() results = append( results, - fmt.Errorf("module: %s didn't start after %v", mw.Name(), defaultStartupWait), + fmt.Errorf("module: %q didn't start after %q", mw.Name(), m.StartTimeout), ) lock.Unlock() } @@ -423,24 +435,28 @@ func (m *manager) startModules() []error { func (m *manager) stopModules() []error { var results []error - var lock sync.Mutex - wg := sync.WaitGroup{} - wg.Add(len(m.moduleWrappers)) for _, mod := range m.moduleWrappers { - go func(m *moduleWrapper) { - if !m.IsRunning() { - // TODO: have a timeout here so a bad shutdown - // doesn't block everyone - if err := m.Stop(); err != nil { - lock.Lock() - results = append(results, err) - lock.Unlock() - } + errC := make(chan error, 1) + go func(mod *moduleWrapper) { + if mod.IsRunning() { + errC <- mod.Stop() + return } - wg.Done() + errC <- nil }(mod) + + select { + case err := <-errC: + if err != nil { + results = append(results, + fmt.Errorf("module %q stopped with error %q", mod.Name(), err)) + } + case <-time.After(m.StopTimeout): + results = append(results, + fmt.Errorf("stop module %q timedout after %q", mod.Name(), m.StopTimeout)) + } } - wg.Wait() + return results } diff --git a/service/options_test.go b/service/options_test.go index 8aa665801..9ef7f18e4 100644 --- a/service/options_test.go +++ b/service/options_test.go @@ -43,13 +43,17 @@ func TestNewOwner_ModulesErr(t *testing.T) { func TestNewOwner_WithMetricsOK(t *testing.T) { assert.NotPanics(t, func() { - newManager(WithModule(nopModuleProvider).WithOptions(WithMetrics(tally.NoopScope, metrics.NopCachedStatsReporter))) + newManager(WithModule(nopModuleProvider).WithOptions( + withConfig(validServiceConfig), + WithMetrics(tally.NoopScope, metrics.NopCachedStatsReporter))) }) } func TestNewOwner_WithTracingOK(t *testing.T) { tracer := &opentracing.NoopTracer{} assert.NotPanics(t, func() { - newManager(WithModule(nopModuleProvider).WithOptions(WithTracer(tracer))) + newManager(WithModule(nopModuleProvider).WithOptions( + withConfig(validServiceConfig), + WithTracer(tracer))) }) } diff --git a/service/service.go b/service/service.go index e36fee23f..8bce1712d 100644 --- a/service/service.go +++ b/service/service.go @@ -21,7 +21,6 @@ package service import ( - "go.uber.org/fx/auth" "go.uber.org/fx/config" "go.uber.org/fx/metrics" @@ -49,7 +48,6 @@ const ( // A Host represents the hosting environment for a service instance type Host interface { - AuthClient() auth.Client Name() string Description() string Roles() []string diff --git a/service/service_core.go b/service/service_core.go index 059ba500b..7b57e121d 100644 --- a/service/service_core.go +++ b/service/service_core.go @@ -26,7 +26,6 @@ import ( "sync" "time" - "go.uber.org/fx/auth" "go.uber.org/fx/config" "go.uber.org/fx/internal/util" "go.uber.org/fx/metrics" @@ -84,7 +83,6 @@ type serviceConfig struct { type serviceCore struct { metricsCore tracerCore - authClient auth.Client configProvider config.Provider logConfig ulog.Configuration observer Observer @@ -97,10 +95,6 @@ type serviceCore struct { var _ Host = &serviceCore{} -func (s *serviceCore) AuthClient() auth.Client { - return s.authClient -} - func (s *serviceCore) Name() string { return s.standardConfig.Name } @@ -228,13 +222,6 @@ func (s *serviceCore) setupObserver() { } } -func (s *serviceCore) setupAuthClient() { - if s.authClient != nil { - return - } - s.authClient = auth.Load(s) -} - func loadInstanceConfig(cfg config.Provider, key string, instance interface{}) bool { fieldName := instanceConfigName if field, found := util.FindField(instance, &fieldName, nil); found { diff --git a/testutils/tracing/tracer.go b/testutils/tracing/tracer.go index 1533f6aa6..4dc67d66a 100644 --- a/testutils/tracing/tracer.go +++ b/testutils/tracing/tracer.go @@ -21,6 +21,7 @@ package tracing import ( + "errors" "testing" "go.uber.org/fx/tracing" @@ -31,14 +32,43 @@ import ( "go.uber.org/zap" ) -// WithSpan is used for generating a span to be used in testing -func WithSpan(t *testing.T, log *zap.Logger, f func(opentracing.Span)) { - tracer, closer, err := tracing.CreateTracer(nil, "serviceName", log, tally.NoopScope) +// WithTracer is used for generating a tracer to be used in testing +func WithTracer(t *testing.T, log *zap.Logger, f func(opentracing.Tracer)) { + tracer, closer, err := tracing.CreateTracer(nil, "dummy", log, tally.NoopScope) require.NoError(t, err) defer func() { require.NoError(t, closer.Close()) }() - span := tracer.StartSpan("test") - defer span.Finish() - f(span) + f(tracer) +} + +// WithSpan is used for generating a span to be used in testing +func WithSpan(t *testing.T, log *zap.Logger, f func(opentracing.Span)) { + WithTracer(t, log, func(tracer opentracing.Tracer) { + span := tracer.StartSpan("test") + defer span.Finish() + f(span) + }) +} + +// ErrorTracer is used to test error scenarios from context encoding +type ErrorTracer struct { + opentracing.Tracer +} + +// Inject implements opentracing.Tracer +func (e *ErrorTracer) Inject( + sm opentracing.SpanContext, + format interface{}, + carrier interface{}, +) error { + return errors.New("inject error") +} + +// Extract implements opentracing.Tracer +func (e *ErrorTracer) Extract( + format interface{}, + carrier interface{}, +) (opentracing.SpanContext, error) { + return nil, errors.New("extract error") } diff --git a/version.go b/version.go index a6f4e2ef4..688209934 100644 --- a/version.go +++ b/version.go @@ -21,4 +21,4 @@ package fx // Version is exported for runtime compatibility checks -const Version = "1.0.0-beta3" +const Version = "1.0.0-beta4-dev"