diff --git a/viper.go b/viper.go index 20eb4da17..c223a90e9 100644 --- a/viper.go +++ b/viper.go @@ -689,6 +689,59 @@ func (v *Viper) searchMap(source map[string]any, path []string) any { return nil } +// searchMapWithAliases recursively searches for slice field in source map and +// replace them with the environment variable value if it exists. +// +// Returns replaced values. +func (v *Viper) searchAndReplaceSliceValueWithEnv(source any, envKey string) any { + switch sourceValue := source.(type) { + case []any: + var newSliceValues []any + for i, sliceValue := range sourceValue { + envKey := envKey + v.keyDelim + strconv.Itoa(i) + switch existingValue := sliceValue.(type) { + case map[string]any: + newVal := v.searchAndReplaceSliceValueWithEnv(existingValue, envKey) + newSliceValues = append(newSliceValues, newVal) + + default: + if newVal, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { + newSliceValues = append(newSliceValues, newVal) + } else { + newSliceValues = append(newSliceValues, existingValue) + } + } + } + return newSliceValues + + case map[string]any: + var newMapValues map[string]any = make(map[string]any) + for key, mapValue := range sourceValue { + envKey := envKey + v.keyDelim + key + switch existingValue := mapValue.(type) { + case map[string]any: + newVal := v.searchAndReplaceSliceValueWithEnv(existingValue, envKey) + newMapValues[key] = newVal + + default: + if newVal, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { + newMapValues[key] = newVal + } else { + newMapValues[key] = existingValue + } + } + } + return newMapValues + + default: + if newVal, ok := v.getEnv(v.mergeWithEnvPrefix(envKey)); ok { + return newVal + } else { + return source + } + } +} + // searchIndexableWithPathPrefixes recursively searches for a value for path in source map/slice. // // While searchMap() considers each path element as a single map key or slice index, this @@ -906,6 +959,11 @@ func (v *Viper) Get(key string) any { return nil } + // Check for Env override again, to handle slices + if v.automaticEnvApplied { + val = v.searchAndReplaceSliceValueWithEnv(val, lcaseKey) + } + if v.typeByDefValue { // TODO(bep) this branch isn't covered by a single test. valType := val diff --git a/viper_test.go b/viper_test.go index 0b1f40741..0375300d2 100644 --- a/viper_test.go +++ b/viper_test.go @@ -2606,6 +2606,85 @@ func TestSliceIndexAccess(t *testing.T) { assert.Equal(t, "Static", v.GetString("tv.0.episodes.1.2")) } +var yamlSimpleSlice = []byte(` +name: Steve +port: 8080 +auth: + secret: 88888-88888 +modes: + - 1 + - 2 + - 3 +clients: + - name: foo + - name: bar +proxy: + clients: + - name: proxy_foo + - name: proxy_bar + - name: proxy_baz +`) + +func TestSliceIndexAutomaticEnv(t *testing.T) { + v.SetConfigType("yaml") + r := strings.NewReader(string(yamlSimpleSlice)) + + type ClientConfig struct { + Name string + } + + type AuthConfig struct { + Secret string + } + + type ProxyConfig struct { + Clients []ClientConfig + } + + type Configuration struct { + Port int + Name string + Auth AuthConfig + Modes []int + Clients []ClientConfig + Proxy ProxyConfig + } + + // Read yaml as default value + err := v.unmarshalReader(r, v.config) + require.NoError(t, err) + + assert.Equal(t, "Steve", v.GetString("name")) + assert.Equal(t, 8080, v.GetInt("port")) + assert.Equal(t, "88888-88888", v.GetString("auth.secret")) + assert.Equal(t, "foo", v.GetString("clients.0.name")) + assert.Equal(t, "bar", v.GetString("clients.1.name")) + assert.Equal(t, "proxy_foo", v.GetString("proxy.clients.0.name")) + assert.Equal(t, []int{1, 2, 3}, v.GetIntSlice("modes")) + + // Override with env variable + t.Setenv("NAME", "Steven") + t.Setenv("AUTH_SECRET", "99999-99999") + t.Setenv("MODES_2", "300") + t.Setenv("CLIENTS_1_NAME", "baz") + t.Setenv("PROXY_CLIENTS_0_NAME", "ProxyFoo") + + SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + AutomaticEnv() + + // Unmarshal into struct + var config Configuration + v.Unmarshal(&config) + + assert.Equal(t, "Steven", config.Name) + assert.Equal(t, 8080, config.Port) + assert.Equal(t, "99999-99999", config.Auth.Secret) + assert.Equal(t, []int{1, 2, 300}, config.Modes) + assert.Equal(t, "foo", config.Clients[0].Name) + assert.Equal(t, "baz", config.Clients[1].Name) + assert.Equal(t, "ProxyFoo", config.Proxy.Clients[0].Name) +} + func TestIsPathShadowedInFlatMap(t *testing.T) { v := New()