Skip to content

Commit

Permalink
test: use T.Setenv to set env vars in tests
Browse files Browse the repository at this point in the history
This commit replaces `os.Setenv` with `t.Setenv` in tests. The
environment variable is automatically restored to its original value
when the test and all its subtests complete.

Reference: https://pkg.go.dev/testing#T.Setenv
Signed-off-by: Eng Zer Jun <engzerjun@gmail.com>
  • Loading branch information
Juneezee authored and arxdsilva committed Jan 5, 2023
1 parent 022d2fa commit 4ed2a9f
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 156 deletions.
5 changes: 2 additions & 3 deletions adapters/postgres/postgres_test.go
Expand Up @@ -37,7 +37,7 @@ func TestLoad(t *testing.T) {
// Only run the failing part when a specific env variable is set
if os.Getenv("BE_CRASHER") == "1" {
Load()
os.Setenv("PREST_PG_DATABASE", "prest-test")
t.Setenv("PREST_PG_DATABASE", "prest-test")
return
}
// Start the actual test in a different subprocess
Expand Down Expand Up @@ -1294,7 +1294,7 @@ func BenchmarkPrepare(b *testing.B) {
}

func TestDisableCache(t *testing.T) {
os.Setenv("PREST_PG_CACHE", "false")
t.Setenv("PREST_PG_CACHE", "false")
config.Load()
Load()
ClearStmt()
Expand All @@ -1306,7 +1306,6 @@ func TestDisableCache(t *testing.T) {
if ok {
t.Error("has query in cache")
}
os.Setenv("PREST_PG_CACHE", "true")
}

func TestParseBatchInsertRequest(t *testing.T) {
Expand Down
3 changes: 1 addition & 2 deletions cache/buntdb_test.go
Expand Up @@ -2,7 +2,6 @@ package cache

import (
"net/http/httptest"
"os"
"testing"

"github.com/prest/prest/adapters/postgres"
Expand All @@ -15,7 +14,7 @@ func init() {
}

func TestBuntGetDoesntExist(t *testing.T) {
os.Setenv("PREST_CACHE", "true")
t.Setenv("PREST_CACHE", "true")
config.Load()
w := httptest.NewRecorder()

Expand Down
7 changes: 1 addition & 6 deletions cache/cache_test.go
Expand Up @@ -10,12 +10,10 @@ import (
func init() {
os.Setenv("PREST_CONF", "./testdata/prest.toml")
os.Setenv("PREST_CACHE_ENABLED", "true")
os.Setenv("PREST_PG_CACHE", "true")
config.Load()
}
func TestEndpointRulesEnable(t *testing.T) {
os.Setenv("PREST_CONF", "./testdata/prest.toml")
os.Setenv("PREST_CACHE_ENABLED", "true")
config.Load()
config.PrestConf.Cache.Endpoints = append(config.PrestConf.Cache.Endpoints, config.CacheEndpoint{
Time: 5,
Endpoint: "/prest/public/test",
Expand All @@ -38,9 +36,6 @@ func TestEndpointRulesNotExist(t *testing.T) {
}

func TestEndpointRulesDisable(t *testing.T) {
os.Setenv("PREST_CONF", "./testdata/prest.toml")
os.Setenv("PREST_CACHE_ENABLED", "true")
config.Load()
config.PrestConf.Cache.Endpoints = append(config.PrestConf.Cache.Endpoints, config.CacheEndpoint{
Endpoint: "/prest/public/test-disable",
Enabled: false,
Expand Down
240 changes: 119 additions & 121 deletions config/config_test.go
Expand Up @@ -8,7 +8,7 @@ import (
)

func TestLoad(t *testing.T) {
os.Setenv("PREST_CONF", "../testdata/prest.toml")
t.Setenv("PREST_CONF", "../testdata/prest.toml")
Load()
require.Greaterf(t, len(PrestConf.AccessConf.Tables), 2,
"expected > 2, got: %d", len(PrestConf.AccessConf.Tables))
Expand All @@ -19,78 +19,75 @@ func TestLoad(t *testing.T) {
}
require.True(t, PrestConf.AccessConf.Restrict, "expected true, but got false")
require.Equal(t, 60, PrestConf.HTTPTimeout)

os.Setenv("PREST_CONF", "foo/bar/prest.toml")
}

func TestParse(t *testing.T) {
os.Setenv("PREST_CONF", "../testdata/prest.toml")
viperCfg()
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, 3000, cfg.HTTPPort)

var expected string
expected = os.Getenv("PREST_PG_DATABASE")
if len(expected) == 0 {
expected = "prest"
}

require.Equal(t, expected, cfg.PGDatabase)

os.Unsetenv("PREST_CONF")
os.Unsetenv("PREST_JWT_DEFAULT")
os.Setenv("PREST_HTTP_PORT", "4000")

viperCfg()
cfg = &Prest{}
err = Parse(cfg)
require.NoError(t, err)
require.Equal(t, 4000, cfg.HTTPPort)
require.True(t, cfg.EnableDefaultJWT)

os.Unsetenv("PREST_CONF")

os.Setenv("PREST_CONF", "")
os.Setenv("PREST_JWT_DEFAULT", "false")

viperCfg()
cfg = &Prest{}
err = Parse(cfg)
require.NoError(t, err)
require.Equal(t, 4000, cfg.HTTPPort)
require.False(t, cfg.EnableDefaultJWT)

os.Unsetenv("PREST_JWT_DEFAULT")

viperCfg()
cfg = &Prest{}
err = Parse(cfg)
require.NoError(t, err)
require.Equal(t, 4000, cfg.HTTPPort)

os.Unsetenv("PREST_CONF")
os.Unsetenv("PREST_HTTP_PORT")
os.Setenv("PREST_JWT_KEY", "s3cr3t")

viperCfg()
cfg = &Prest{}
err = Parse(cfg)
require.NoError(t, err)
require.Equal(t, "s3cr3t", cfg.JWTKey)
require.Equal(t, "HS256", cfg.JWTAlgo)

os.Unsetenv("PREST_JWT_KEY")
os.Setenv("PREST_JWT_ALGO", "HS512")

viperCfg()
cfg = &Prest{}
err = Parse(cfg)
require.NoError(t, err)
require.Equal(t, "HS512", cfg.JWTAlgo)

os.Unsetenv("PREST_JWT_ALGO")
t.Run("PREST_CONF", func(t *testing.T) {
t.Setenv("PREST_CONF", "../testdata/prest.toml")
viperCfg()
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, 3000, cfg.HTTPPort)

var expected string
expected = os.Getenv("PREST_PG_DATABASE")
if len(expected) == 0 {
expected = "prest"
}

require.Equal(t, expected, cfg.PGDatabase)
})

t.Run("PREST_HTTP_PORT and unset PREST_JWT_DEFAULT", func(t *testing.T) {
t.Setenv("PREST_HTTP_PORT", "4000")
os.Unsetenv("PREST_JWT_DEFAULT")
viperCfg()
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, 4000, cfg.HTTPPort)
require.True(t, cfg.EnableDefaultJWT)
})

t.Run("empty PREST_CONF and falsey PREST_JWT_DEFAULT", func(t *testing.T) {
t.Setenv("PREST_CONF", "")
t.Setenv("PREST_JWT_DEFAULT", "false")
viperCfg()
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, 3000, cfg.HTTPPort)
require.False(t, cfg.EnableDefaultJWT)
})

t.Run("empty PREST_CONF", func(t *testing.T) {
t.Setenv("PREST_CONF", "")
viperCfg()
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, 3000, cfg.HTTPPort)
})

t.Run("PREST_JWT_KEY", func(t *testing.T) {
t.Setenv("PREST_JWT_KEY", "s3cr3t")
viperCfg()
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, "s3cr3t", cfg.JWTKey)
require.Equal(t, "HS256", cfg.JWTAlgo)
})

t.Run("PREST_JWT_ALGO", func(t *testing.T) {
t.Setenv("PREST_JWT_ALGO", "HS512")
viperCfg()
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, "HS512", cfg.JWTAlgo)
})
}

func TestGetDefaultPrestConf(t *testing.T) {
Expand Down Expand Up @@ -120,59 +117,62 @@ func TestGetDefaultPrestConf(t *testing.T) {
}

func TestDatabaseURL(t *testing.T) {
os.Setenv("PREST_PG_URL", "postgresql://user:pass@localhost:1234/mydatabase/?sslmode=disable")

viperCfg()
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, "mydatabase", cfg.PGDatabase)
require.Equal(t, "localhost", cfg.PGHost)
require.Equal(t, 1234, cfg.PGPort)
require.Equal(t, "user", cfg.PGUser)
require.Equal(t, "pass", cfg.PGPass)
require.Equal(t, "disable", cfg.SSLMode)

os.Unsetenv("PREST_PG_URL")
os.Setenv("DATABASE_URL", "postgresql://cloud:cloudPass@localhost:5432/CloudDatabase/?sslmode=disable")

cfg = &Prest{}
err = Parse(cfg)
require.NoError(t, err)
require.Equal(t, 5432, cfg.PGPort)
require.Equal(t, "cloud", cfg.PGUser)
require.Equal(t, "cloudPass", cfg.PGPass)
require.Equal(t, "disable", cfg.SSLMode)

os.Unsetenv("DATABASE_URL")
t.Run("PREST_PG_URL", func(t *testing.T) {
t.Setenv("PREST_PG_URL", "postgresql://user:pass@localhost:1234/mydatabase/?sslmode=disable")
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, "mydatabase", cfg.PGDatabase)
require.Equal(t, "localhost", cfg.PGHost)
require.Equal(t, 1234, cfg.PGPort)
require.Equal(t, "user", cfg.PGUser)
require.Equal(t, "pass", cfg.PGPass)
require.Equal(t, "disable", cfg.SSLMode)
})

t.Run("DATABASE_URL", func(t *testing.T) {
t.Setenv("DATABASE_URL", "postgresql://cloud:cloudPass@localhost:5432/CloudDatabase/?sslmode=disable")
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, 5432, cfg.PGPort)
require.Equal(t, "cloud", cfg.PGUser)
require.Equal(t, "cloudPass", cfg.PGPass)
require.Equal(t, "disable", cfg.SSLMode)
})
}

func TestHTTPPort(t *testing.T) {
os.Setenv("PORT", "8080")

viperCfg()
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, 8080, cfg.HTTPPort)

// set env PREST_HTTP_PORT and PORT
os.Setenv("PREST_HTTP_PORT", "3000")

cfg = &Prest{}
err = Parse(cfg)
require.NoError(t, err)
require.Equal(t, 8080, cfg.HTTPPort)

// unset env PORT and set PREST_HTTP_PORT
os.Unsetenv("PORT")

cfg = &Prest{}
err = Parse(cfg)
require.NoError(t, err)
require.Equal(t, 3000, cfg.HTTPPort)

os.Unsetenv("PREST_HTTP_PORT")
t.Run("set PORT", func(t *testing.T) {
t.Setenv("PORT", "8080")
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, 8080, cfg.HTTPPort)
})

t.Run("set PREST_HTTP_PORT", func(t *testing.T) {
t.Setenv("PREST_HTTP_PORT", "3000")
viperCfg()
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, 3000, cfg.HTTPPort)
})

t.Run("set PORT and PREST_HTTP_PORT", func(t *testing.T) {
t.Setenv("PORT", "8080")
t.Setenv("PREST_HTTP_PORT", "3000")
viperCfg()
cfg := &Prest{}
err := Parse(cfg)
require.NoError(t, err)
require.Equal(t, 8080, cfg.HTTPPort)
})
}

func Test_parseDatabaseURL(t *testing.T) {
Expand All @@ -194,16 +194,14 @@ func Test_parseDatabaseURL(t *testing.T) {
func Test_portFromEnv(t *testing.T) {
c := &Prest{}

os.Setenv("PORT", "PORT")
t.Setenv("PORT", "PORT")

err := portFromEnv(c)
require.Error(t, err)

os.Unsetenv("PORT")
}

func Test_Auth(t *testing.T) {
os.Setenv("PREST_CONF", "../testdata/prest.toml")
t.Setenv("PREST_CONF", "../testdata/prest.toml")

viperCfg()
cfg := &Prest{}
Expand All @@ -225,7 +223,7 @@ func Test_Auth(t *testing.T) {
}

func Test_ExposeDataConfig(t *testing.T) {
os.Setenv("PREST_CONF", "../testdata/prest_expose.toml")
t.Setenv("PREST_CONF", "../testdata/prest_expose.toml")

viperCfg()
cfg := &Prest{}
Expand Down
3 changes: 1 addition & 2 deletions controllers/sql_test.go
Expand Up @@ -4,7 +4,6 @@ package controllers
import (
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/gorilla/mux"
Expand Down Expand Up @@ -92,7 +91,7 @@ func TestRenderWithXML(t *testing.T) {
}{
{"Get schemas with COUNT clause with XML Render", "/schemas?_count=*&_renderer=xml", "GET", 200, "<objects><object><count>4</count></object></objects>"},
}
os.Setenv("PREST_DEBUG", "true")
t.Setenv("PREST_DEBUG", "true")
config.Load()
postgres.Load()
n := middlewares.GetApp()
Expand Down

0 comments on commit 4ed2a9f

Please sign in to comment.