diff --git a/adapters/postgres/postgres_test.go b/adapters/postgres/postgres_test.go index 0cd2044a4..421a41da6 100644 --- a/adapters/postgres/postgres_test.go +++ b/adapters/postgres/postgres_test.go @@ -37,7 +37,7 @@ func TestLoad(t *testing.T) { // Only run the failing part when a specific env variable is set if os.Getenv("BE_CRASHER") == "1" { Load() - os.Setenv("PREST_PG_DATABASE", "prest-test") + t.Setenv("PREST_PG_DATABASE", "prest-test") return } // Start the actual test in a different subprocess @@ -1294,7 +1294,7 @@ func BenchmarkPrepare(b *testing.B) { } func TestDisableCache(t *testing.T) { - os.Setenv("PREST_PG_CACHE", "false") + t.Setenv("PREST_PG_CACHE", "false") config.Load() Load() ClearStmt() @@ -1306,7 +1306,6 @@ func TestDisableCache(t *testing.T) { if ok { t.Error("has query in cache") } - os.Setenv("PREST_PG_CACHE", "true") } func TestParseBatchInsertRequest(t *testing.T) { diff --git a/cache/buntdb_test.go b/cache/buntdb_test.go index bb1c3d63f..6af1a6cbb 100644 --- a/cache/buntdb_test.go +++ b/cache/buntdb_test.go @@ -2,7 +2,6 @@ package cache import ( "net/http/httptest" - "os" "testing" "github.com/prest/prest/adapters/postgres" @@ -15,7 +14,7 @@ func init() { } func TestBuntGetDoesntExist(t *testing.T) { - os.Setenv("PREST_CACHE", "true") + t.Setenv("PREST_CACHE", "true") config.Load() w := httptest.NewRecorder() diff --git a/cache/cache_test.go b/cache/cache_test.go index ebe45281e..9524d529d 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -10,12 +10,10 @@ import ( func init() { os.Setenv("PREST_CONF", "./testdata/prest.toml") os.Setenv("PREST_CACHE_ENABLED", "true") + os.Setenv("PREST_PG_CACHE", "true") config.Load() } func TestEndpointRulesEnable(t *testing.T) { - os.Setenv("PREST_CONF", "./testdata/prest.toml") - os.Setenv("PREST_CACHE_ENABLED", "true") - config.Load() config.PrestConf.Cache.Endpoints = append(config.PrestConf.Cache.Endpoints, config.CacheEndpoint{ Time: 5, Endpoint: "/prest/public/test", @@ -38,9 +36,6 @@ func TestEndpointRulesNotExist(t *testing.T) { } func TestEndpointRulesDisable(t *testing.T) { - os.Setenv("PREST_CONF", "./testdata/prest.toml") - os.Setenv("PREST_CACHE_ENABLED", "true") - config.Load() config.PrestConf.Cache.Endpoints = append(config.PrestConf.Cache.Endpoints, config.CacheEndpoint{ Endpoint: "/prest/public/test-disable", Enabled: false, diff --git a/config/config_test.go b/config/config_test.go index ff7ce1e71..9db6d01a4 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -8,7 +8,7 @@ import ( ) func TestLoad(t *testing.T) { - os.Setenv("PREST_CONF", "../testdata/prest.toml") + t.Setenv("PREST_CONF", "../testdata/prest.toml") Load() require.Greaterf(t, len(PrestConf.AccessConf.Tables), 2, "expected > 2, got: %d", len(PrestConf.AccessConf.Tables)) @@ -19,78 +19,75 @@ func TestLoad(t *testing.T) { } require.True(t, PrestConf.AccessConf.Restrict, "expected true, but got false") require.Equal(t, 60, PrestConf.HTTPTimeout) - - os.Setenv("PREST_CONF", "foo/bar/prest.toml") } func TestParse(t *testing.T) { - os.Setenv("PREST_CONF", "../testdata/prest.toml") - viperCfg() - cfg := &Prest{} - err := Parse(cfg) - require.NoError(t, err) - require.Equal(t, 3000, cfg.HTTPPort) - - var expected string - expected = os.Getenv("PREST_PG_DATABASE") - if len(expected) == 0 { - expected = "prest" - } - - require.Equal(t, expected, cfg.PGDatabase) - - os.Unsetenv("PREST_CONF") - os.Unsetenv("PREST_JWT_DEFAULT") - os.Setenv("PREST_HTTP_PORT", "4000") - - viperCfg() - cfg = &Prest{} - err = Parse(cfg) - require.NoError(t, err) - require.Equal(t, 4000, cfg.HTTPPort) - require.True(t, cfg.EnableDefaultJWT) - - os.Unsetenv("PREST_CONF") - - os.Setenv("PREST_CONF", "") - os.Setenv("PREST_JWT_DEFAULT", "false") - - viperCfg() - cfg = &Prest{} - err = Parse(cfg) - require.NoError(t, err) - require.Equal(t, 4000, cfg.HTTPPort) - require.False(t, cfg.EnableDefaultJWT) - - os.Unsetenv("PREST_JWT_DEFAULT") - - viperCfg() - cfg = &Prest{} - err = Parse(cfg) - require.NoError(t, err) - require.Equal(t, 4000, cfg.HTTPPort) - - os.Unsetenv("PREST_CONF") - os.Unsetenv("PREST_HTTP_PORT") - os.Setenv("PREST_JWT_KEY", "s3cr3t") - - viperCfg() - cfg = &Prest{} - err = Parse(cfg) - require.NoError(t, err) - require.Equal(t, "s3cr3t", cfg.JWTKey) - require.Equal(t, "HS256", cfg.JWTAlgo) - - os.Unsetenv("PREST_JWT_KEY") - os.Setenv("PREST_JWT_ALGO", "HS512") - - viperCfg() - cfg = &Prest{} - err = Parse(cfg) - require.NoError(t, err) - require.Equal(t, "HS512", cfg.JWTAlgo) - - os.Unsetenv("PREST_JWT_ALGO") + t.Run("PREST_CONF", func(t *testing.T) { + t.Setenv("PREST_CONF", "../testdata/prest.toml") + viperCfg() + cfg := &Prest{} + err := Parse(cfg) + require.NoError(t, err) + require.Equal(t, 3000, cfg.HTTPPort) + + var expected string + expected = os.Getenv("PREST_PG_DATABASE") + if len(expected) == 0 { + expected = "prest" + } + + require.Equal(t, expected, cfg.PGDatabase) + }) + + t.Run("PREST_HTTP_PORT and unset PREST_JWT_DEFAULT", func(t *testing.T) { + t.Setenv("PREST_HTTP_PORT", "4000") + os.Unsetenv("PREST_JWT_DEFAULT") + viperCfg() + cfg := &Prest{} + err := Parse(cfg) + require.NoError(t, err) + require.Equal(t, 4000, cfg.HTTPPort) + require.True(t, cfg.EnableDefaultJWT) + }) + + t.Run("empty PREST_CONF and falsey PREST_JWT_DEFAULT", func(t *testing.T) { + t.Setenv("PREST_CONF", "") + t.Setenv("PREST_JWT_DEFAULT", "false") + viperCfg() + cfg := &Prest{} + err := Parse(cfg) + require.NoError(t, err) + require.Equal(t, 3000, cfg.HTTPPort) + require.False(t, cfg.EnableDefaultJWT) + }) + + t.Run("empty PREST_CONF", func(t *testing.T) { + t.Setenv("PREST_CONF", "") + viperCfg() + cfg := &Prest{} + err := Parse(cfg) + require.NoError(t, err) + require.Equal(t, 3000, cfg.HTTPPort) + }) + + t.Run("PREST_JWT_KEY", func(t *testing.T) { + t.Setenv("PREST_JWT_KEY", "s3cr3t") + viperCfg() + cfg := &Prest{} + err := Parse(cfg) + require.NoError(t, err) + require.Equal(t, "s3cr3t", cfg.JWTKey) + require.Equal(t, "HS256", cfg.JWTAlgo) + }) + + t.Run("PREST_JWT_ALGO", func(t *testing.T) { + t.Setenv("PREST_JWT_ALGO", "HS512") + viperCfg() + cfg := &Prest{} + err := Parse(cfg) + require.NoError(t, err) + require.Equal(t, "HS512", cfg.JWTAlgo) + }) } func TestGetDefaultPrestConf(t *testing.T) { @@ -120,59 +117,62 @@ func TestGetDefaultPrestConf(t *testing.T) { } func TestDatabaseURL(t *testing.T) { - os.Setenv("PREST_PG_URL", "postgresql://user:pass@localhost:1234/mydatabase/?sslmode=disable") - viperCfg() - cfg := &Prest{} - err := Parse(cfg) - require.NoError(t, err) - require.Equal(t, "mydatabase", cfg.PGDatabase) - require.Equal(t, "localhost", cfg.PGHost) - require.Equal(t, 1234, cfg.PGPort) - require.Equal(t, "user", cfg.PGUser) - require.Equal(t, "pass", cfg.PGPass) - require.Equal(t, "disable", cfg.SSLMode) - - os.Unsetenv("PREST_PG_URL") - os.Setenv("DATABASE_URL", "postgresql://cloud:cloudPass@localhost:5432/CloudDatabase/?sslmode=disable") - - cfg = &Prest{} - err = Parse(cfg) - require.NoError(t, err) - require.Equal(t, 5432, cfg.PGPort) - require.Equal(t, "cloud", cfg.PGUser) - require.Equal(t, "cloudPass", cfg.PGPass) - require.Equal(t, "disable", cfg.SSLMode) - os.Unsetenv("DATABASE_URL") + t.Run("PREST_PG_URL", func(t *testing.T) { + t.Setenv("PREST_PG_URL", "postgresql://user:pass@localhost:1234/mydatabase/?sslmode=disable") + cfg := &Prest{} + err := Parse(cfg) + require.NoError(t, err) + require.Equal(t, "mydatabase", cfg.PGDatabase) + require.Equal(t, "localhost", cfg.PGHost) + require.Equal(t, 1234, cfg.PGPort) + require.Equal(t, "user", cfg.PGUser) + require.Equal(t, "pass", cfg.PGPass) + require.Equal(t, "disable", cfg.SSLMode) + }) + + t.Run("DATABASE_URL", func(t *testing.T) { + t.Setenv("DATABASE_URL", "postgresql://cloud:cloudPass@localhost:5432/CloudDatabase/?sslmode=disable") + cfg := &Prest{} + err := Parse(cfg) + require.NoError(t, err) + require.Equal(t, 5432, cfg.PGPort) + require.Equal(t, "cloud", cfg.PGUser) + require.Equal(t, "cloudPass", cfg.PGPass) + require.Equal(t, "disable", cfg.SSLMode) + }) } func TestHTTPPort(t *testing.T) { - os.Setenv("PORT", "8080") - viperCfg() - cfg := &Prest{} - err := Parse(cfg) - require.NoError(t, err) - require.Equal(t, 8080, cfg.HTTPPort) - - // set env PREST_HTTP_PORT and PORT - os.Setenv("PREST_HTTP_PORT", "3000") - cfg = &Prest{} - err = Parse(cfg) - require.NoError(t, err) - require.Equal(t, 8080, cfg.HTTPPort) - - // unset env PORT and set PREST_HTTP_PORT - os.Unsetenv("PORT") - - cfg = &Prest{} - err = Parse(cfg) - require.NoError(t, err) - require.Equal(t, 3000, cfg.HTTPPort) - - os.Unsetenv("PREST_HTTP_PORT") + t.Run("set PORT", func(t *testing.T) { + t.Setenv("PORT", "8080") + cfg := &Prest{} + err := Parse(cfg) + require.NoError(t, err) + require.Equal(t, 8080, cfg.HTTPPort) + }) + + t.Run("set PREST_HTTP_PORT", func(t *testing.T) { + t.Setenv("PREST_HTTP_PORT", "3000") + viperCfg() + cfg := &Prest{} + err := Parse(cfg) + require.NoError(t, err) + require.Equal(t, 3000, cfg.HTTPPort) + }) + + t.Run("set PORT and PREST_HTTP_PORT", func(t *testing.T) { + t.Setenv("PORT", "8080") + t.Setenv("PREST_HTTP_PORT", "3000") + viperCfg() + cfg := &Prest{} + err := Parse(cfg) + require.NoError(t, err) + require.Equal(t, 8080, cfg.HTTPPort) + }) } func Test_parseDatabaseURL(t *testing.T) { @@ -194,16 +194,14 @@ func Test_parseDatabaseURL(t *testing.T) { func Test_portFromEnv(t *testing.T) { c := &Prest{} - os.Setenv("PORT", "PORT") + t.Setenv("PORT", "PORT") err := portFromEnv(c) require.Error(t, err) - - os.Unsetenv("PORT") } func Test_Auth(t *testing.T) { - os.Setenv("PREST_CONF", "../testdata/prest.toml") + t.Setenv("PREST_CONF", "../testdata/prest.toml") viperCfg() cfg := &Prest{} @@ -225,7 +223,7 @@ func Test_Auth(t *testing.T) { } func Test_ExposeDataConfig(t *testing.T) { - os.Setenv("PREST_CONF", "../testdata/prest_expose.toml") + t.Setenv("PREST_CONF", "../testdata/prest_expose.toml") viperCfg() cfg := &Prest{} diff --git a/controllers/sql_test.go b/controllers/sql_test.go index 9bc6b672e..42312f083 100644 --- a/controllers/sql_test.go +++ b/controllers/sql_test.go @@ -4,7 +4,6 @@ package controllers import ( "net/http" "net/http/httptest" - "os" "testing" "github.com/gorilla/mux" @@ -92,7 +91,7 @@ func TestRenderWithXML(t *testing.T) { }{ {"Get schemas with COUNT clause with XML Render", "/schemas?_count=*&_renderer=xml", "GET", 200, "4"}, } - os.Setenv("PREST_DEBUG", "true") + t.Setenv("PREST_DEBUG", "true") config.Load() postgres.Load() n := middlewares.GetApp() diff --git a/middlewares/config_test.go b/middlewares/config_test.go index 653655ca4..b3730554f 100644 --- a/middlewares/config_test.go +++ b/middlewares/config_test.go @@ -5,7 +5,6 @@ import ( "io" "net/http" "net/http/httptest" - "os" "testing" "github.com/gorilla/mux" @@ -77,7 +76,7 @@ func TestGetAppWithoutReorderedMiddleware(t *testing.T) { } func Test_Middleware_DoesntBlock_CustomRoutes(t *testing.T) { - os.Setenv("PREST_DEBUG", "true") + t.Setenv("PREST_DEBUG", "true") config.Load() postgres.Load() app = nil @@ -90,7 +89,7 @@ func Test_Middleware_DoesntBlock_CustomRoutes(t *testing.T) { AccessControl(), negroni.Wrap(crudRoutes), )) - os.Setenv("PREST_CONF", "../testdata/prest.toml") + t.Setenv("PREST_CONF", "../testdata/prest.toml") n := GetApp() n.UseHandler(r) @@ -121,7 +120,6 @@ func Test_Middleware_DoesntBlock_CustomRoutes(t *testing.T) { require.Contains(t, string(body), "required authorization to table") MiddlewareStack = []negroni.Handler{} - os.Setenv("PREST_CONF", "") } func customMiddleware(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { @@ -136,7 +134,7 @@ func customMiddleware(w http.ResponseWriter, r *http.Request, next http.HandlerF } func TestDebug(t *testing.T) { - os.Setenv("PREST_DEBUG", "true") + t.Setenv("PREST_DEBUG", "true") config.Load() nd := appTest() serverd := httptest.NewServer(nd) @@ -148,8 +146,8 @@ func TestDebug(t *testing.T) { func TestEnableDefaultJWT(t *testing.T) { app = nil - os.Setenv("PREST_JWT_DEFAULT", "false") - os.Setenv("PREST_DEBUG", "false") + t.Setenv("PREST_JWT_DEFAULT", "false") + t.Setenv("PREST_DEBUG", "false") config.Load() nd := appTest() serverd := httptest.NewServer(nd) @@ -162,8 +160,8 @@ func TestEnableDefaultJWT(t *testing.T) { func TestJWTIsRequired(t *testing.T) { MiddlewareStack = []negroni.Handler{} app = nil - os.Setenv("PREST_JWT_DEFAULT", "true") - os.Setenv("PREST_DEBUG", "false") + t.Setenv("PREST_JWT_DEFAULT", "true") + t.Setenv("PREST_DEBUG", "false") config.Load() nd := appTestWithJwt() serverd := httptest.NewServer(nd) @@ -178,10 +176,10 @@ func TestJWTSignatureOk(t *testing.T) { app = nil MiddlewareStack = nil bearer := "Bearer eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6ImpvaG4uZG9lQHNvbWV3aGVyZS5jb20iLCJpYXQiOjE1MTc1NjM2MTYsImlzcyI6InByaXZhdGUiLCJqdGkiOiJjZWZhNzRmZS04OTRjLWZmNjMtZDgxNi00NjIwYjhjZDkyZWUiLCJvcmciOiJwcml2YXRlIiwic3ViIjoiam9obi5kb2UifQ.zLWkEd4hP4XdCD_DlRy6mgPeKwEl1dcdtx5A_jHSfmc87EsrGgNSdi8eBTzCgSU0jgV6ssTgQwzY6x4egze2xA" - os.Setenv("PREST_JWT_DEFAULT", "true") - os.Setenv("PREST_DEBUG", "false") - os.Setenv("PREST_JWT_KEY", "s3cr3t") - os.Setenv("PREST_JWT_ALGO", "HS512") + t.Setenv("PREST_JWT_DEFAULT", "true") + t.Setenv("PREST_DEBUG", "false") + t.Setenv("PREST_JWT_KEY", "s3cr3t") + t.Setenv("PREST_JWT_ALGO", "HS512") config.Load() nd := appTestWithJwt() serverd := httptest.NewServer(nd) @@ -201,10 +199,10 @@ func TestJWTSignatureOk(t *testing.T) { func TestJWTSignatureKo(t *testing.T) { app = nil bearer := "Bearer: eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6ImpvaG4uZG9lQHNvbWV3aGVyZS5jb20iLCJleHAiOjE1MjUzMzk2MTYsImlhdCI6MTUxNzU2MzYxNiwiaXNzIjoicHJpdmF0ZSIsImp0aSI6ImNlZmE3NGZlLTg5NGMtZmY2My1kODE2LTQ2MjBiOGNkOTJlZSIsIm9yZyI6InByaXZhdGUiLCJzdWIiOiJqb2huLmRvZSJ9.zGP1Xths2bK2r9FN0Gv1SzyoisO0dhRwvqrPvunGxUyU5TbkfdnTcQRJNYZzJfGILeQ9r3tbuakWm-NIoDlbbA" - os.Setenv("PREST_JWT_DEFAULT", "true") - os.Setenv("PREST_DEBUG", "false") - os.Setenv("PREST_JWT_KEY", "s3cr3t") - os.Setenv("PREST_JWT_ALGO", "HS256") + t.Setenv("PREST_JWT_DEFAULT", "true") + t.Setenv("PREST_DEBUG", "false") + t.Setenv("PREST_JWT_KEY", "s3cr3t") + t.Setenv("PREST_JWT_ALGO", "HS256") config.Load() nd := appTestWithJwt() serverd := httptest.NewServer(nd) @@ -251,9 +249,9 @@ func appTestWithJwt() *negroni.Negroni { func Test_CORS_Middleware(t *testing.T) { MiddlewareStack = []negroni.Handler{} - os.Setenv("PREST_DEBUG", "true") - os.Setenv("PREST_CORS_ALLOWORIGIN", "*") - os.Setenv("PREST_CONF", "../testdata/prest.toml") + t.Setenv("PREST_DEBUG", "true") + t.Setenv("PREST_CORS_ALLOWORIGIN", "*") + t.Setenv("PREST_CONF", "../testdata/prest.toml") config.Load() app = nil r := mux.NewRouter() @@ -283,8 +281,8 @@ func Test_CORS_Middleware(t *testing.T) { func TestExposeTablesMiddleware(t *testing.T) { MiddlewareStack = []negroni.Handler{} app = nil - os.Setenv("PREST_DEBUG", "true") - os.Setenv("PREST_CONF", "../testdata/prest_expose.toml") + t.Setenv("PREST_DEBUG", "true") + t.Setenv("PREST_CONF", "../testdata/prest_expose.toml") config.Load() app = nil r := mux.NewRouter()