From e45c2faeac8aa6ab59a4e91f643b2ef9cb1de03c Mon Sep 17 00:00:00 2001 From: Guillaume St-Pierre Date: Thu, 18 Apr 2024 11:59:45 -0400 Subject: [PATCH] fix(postgres): Fix the non-default dbname error (#2489) * Fix the non-default dbname error The linked issue described in great detail an issue where we assumed everyone would use the default database user, whose home DB defaults to the postgres database. When that was not the case, the snapshots would fail silently as the user would not connect to the right database to take the commands. This PR fixes the issue by adding the dbname by default in the command, and adds a test to validate this works as intended. In addition, it also adds some logic to handle any error that does not cause the exec command to fail, such as database access failures. Run the added test to test this works as intended. Closes #2474 * Document the postgres dbname issue in the docs --- docs/modules/postgres.md | 6 +++ modules/postgres/postgres.go | 31 +++++++++++- modules/postgres/postgres_test.go | 81 +++++++++++++++++++++++++++++-- 3 files changed, 112 insertions(+), 6 deletions(-) diff --git a/docs/modules/postgres.md b/docs/modules/postgres.md index 683ed29162..ebd1ed919e 100644 --- a/docs/modules/postgres.md +++ b/docs/modules/postgres.md @@ -97,6 +97,12 @@ This example shows the usage of the postgres module's Snapshot feature to give e to recreate the database container on every test or run heavy scripts to clean your database. This makes the individual tests very modular, since they always run on a brand-new database. +!!!tip + You should never pass the `"postgres"` system database as the container database name if you want to use snapshots. + The Snapshot logic requires dropping the connected database and using the system database to run commands, which will + not work if the database for the container is set to `"postgres"`. + + [Test with a reusable Postgres container](../../modules/postgres/postgres_test.go) inside_block:snapshotAndReset diff --git a/modules/postgres/postgres.go b/modules/postgres/postgres.go index cd349494c8..ccc66ba12a 100644 --- a/modules/postgres/postgres.go +++ b/modules/postgres/postgres.go @@ -3,6 +3,7 @@ package postgres import ( "context" "fmt" + "io" "net" "path/filepath" "strings" @@ -184,6 +185,10 @@ func (c *PostgresContainer) Snapshot(ctx context.Context, opts ...SnapshotOption snapshotName = config.snapshotName } + if c.dbName == "postgres" { + return fmt.Errorf("cannot snapshot the postgres system database as it cannot be dropped to be restored") + } + // execute the commands to create the snapshot, in order cmds := []string{ // Drop the snapshot database if it already exists @@ -195,10 +200,19 @@ func (c *PostgresContainer) Snapshot(ctx context.Context, opts ...SnapshotOption } for _, cmd := range cmds { - _, _, err := c.Exec(ctx, []string{"psql", "-U", c.user, "-c", cmd}) + exitCode, reader, err := c.Exec(ctx, []string{"psql", "-U", c.user, "-d", c.dbName, "-c", cmd}) if err != nil { return err } + if exitCode != 0 { + buf := new(strings.Builder) + _, err := io.Copy(buf, reader) + if err != nil { + return fmt.Errorf("non-zero exit code for snapshot command, could not read command output: %w", err) + } + + return fmt.Errorf("non-zero exit code for snapshot command: %s", buf.String()) + } } c.snapshotName = snapshotName @@ -219,6 +233,10 @@ func (c *PostgresContainer) Restore(ctx context.Context, opts ...SnapshotOption) snapshotName = config.snapshotName } + if c.dbName == "postgres" { + return fmt.Errorf("cannot restore the postgres system database as it cannot be dropped to be restored") + } + // execute the commands to restore the snapshot, in order cmds := []string{ // Drop the entire database by connecting to the postgres global database @@ -228,10 +246,19 @@ func (c *PostgresContainer) Restore(ctx context.Context, opts ...SnapshotOption) } for _, cmd := range cmds { - _, _, err := c.Exec(ctx, []string{"psql", "-U", c.user, "-d", "postgres", "-c", cmd}) + exitCode, reader, err := c.Exec(ctx, []string{"psql", "-v", "ON_ERROR_STOP=1", "-U", c.user, "-d", "postgres", "-c", cmd}) if err != nil { return err } + if exitCode != 0 { + buf := new(strings.Builder) + _, err := io.Copy(buf, reader) + if err != nil { + return fmt.Errorf("non-zero exit code for restore command, could not read command output: %w", err) + } + + return fmt.Errorf("non-zero exit code for restore command: %s", buf.String()) + } } return nil diff --git a/modules/postgres/postgres_test.go b/modules/postgres/postgres_test.go index 2a74e147f6..9dc2194685 100644 --- a/modules/postgres/postgres_test.go +++ b/modules/postgres/postgres_test.go @@ -87,12 +87,12 @@ func TestPostgres(t *testing.T) { connStr, err := container.ConnectionString(ctx, "sslmode=disable", "application_name=test") // } require.NoError(t, err) - - mustConnStr := container.MustConnectionString(ctx,"sslmode=disable", "application_name=test") - if mustConnStr!=connStr{ + + mustConnStr := container.MustConnectionString(ctx, "sslmode=disable", "application_name=test") + if mustConnStr != connStr { t.Errorf("ConnectionString was not equal to MustConnectionString") } - + // Ensure connection string is using generic format id, err := container.MappedPort(ctx, "5432/tcp") require.NoError(t, err) @@ -327,3 +327,76 @@ func TestSnapshot(t *testing.T) { }) // } } + +func TestSnapshotWithOverrides(t *testing.T) { + ctx := context.Background() + + dbname := "other-db" + user := "other-user" + password := "other-password" + + container, err := postgres.RunContainer( + ctx, + testcontainers.WithImage("docker.io/postgres:16-alpine"), + postgres.WithDatabase(dbname), + postgres.WithUsername(user), + postgres.WithPassword(password), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2). + WithStartupTimeout(5*time.Second)), + ) + if err != nil { + t.Fatal(err) + } + + _, _, err = container.Exec(ctx, []string{"psql", "-U", user, "-d", dbname, "-c", "CREATE TABLE users (id SERIAL, name TEXT NOT NULL, age INT NOT NULL)"}) + if err != nil { + t.Fatal(err) + } + + err = container.Snapshot(ctx, postgres.WithSnapshotName("other-snapshot")) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := container.Terminate(ctx); err != nil { + t.Fatalf("failed to terminate container: %s", err) + } + }) + + dbURL, err := container.ConnectionString(ctx) + if err != nil { + t.Fatal(err) + } + + t.Run("Test that the restore works when not using defaults", func(t *testing.T) { + _, _, err = container.Exec(ctx, []string{"psql", "-U", user, "-d", dbname, "-c", "INSERT INTO users(name, age) VALUES ('test', 42)"}) + if err != nil { + t.Fatal(err) + } + + // Doing the restore before we connect since this resets the pgx connection + err = container.Restore(ctx) + if err != nil { + t.Fatal(err) + } + + conn, err := pgx.Connect(context.Background(), dbURL) + if err != nil { + t.Fatal(err) + } + defer conn.Close(context.Background()) + + var count int64 + err = conn.QueryRow(context.Background(), "SELECT COUNT(1) FROM users").Scan(&count) + if err != nil { + t.Fatal(err) + } + + if count != 0 { + t.Fatalf("Expected %d to equal `0`", count) + } + }) +}