diff --git a/modules/postgres/postgres.go b/modules/postgres/postgres.go index cb3d01694b..9feac8b585 100644 --- a/modules/postgres/postgres.go +++ b/modules/postgres/postgres.go @@ -3,6 +3,7 @@ package postgres import ( "context" "fmt" + "io" "net" "path/filepath" "strings" @@ -26,10 +27,9 @@ type PostgresContainer struct { snapshotName string } - // MustConnectionString panics if the address cannot be determined. func (c *PostgresContainer) MustConnectionString(ctx context.Context, args ...string) string { - addr, err := c.ConnectionString(ctx,args...) + addr, err := c.ConnectionString(ctx, args...) if err != nil { panic(err) } @@ -185,6 +185,10 @@ func (c *PostgresContainer) Snapshot(ctx context.Context, opts ...SnapshotOption snapshotName = config.snapshotName } + if c.dbName == "postgres" { + return fmt.Errorf("cannot snapshot the postgres system database as it cannot be dropped to be restored") + } + // execute the commands to create the snapshot, in order cmds := []string{ // Drop the snapshot database if it already exists @@ -196,10 +200,19 @@ func (c *PostgresContainer) Snapshot(ctx context.Context, opts ...SnapshotOption } for _, cmd := range cmds { - _, _, err := c.Exec(ctx, []string{"psql", "-U", c.user, "-c", cmd}) + exitCode, reader, err := c.Exec(ctx, []string{"psql", "-U", c.user, "-d", c.dbName, "-c", cmd}) if err != nil { return err } + if exitCode != 0 { + buf := new(strings.Builder) + _, err := io.Copy(buf, reader) + if err != nil { + return fmt.Errorf("non-zero exit code for snapshot command, could not read command output: %w", err) + } + + return fmt.Errorf("non-zero exit code for snapshot command: %s", buf.String()) + } } c.snapshotName = snapshotName @@ -220,6 +233,10 @@ func (c *PostgresContainer) Restore(ctx context.Context, opts ...SnapshotOption) snapshotName = config.snapshotName } + if c.dbName == "postgres" { + return fmt.Errorf("cannot restore the postgres system database as it cannot be dropped to be restored") + } + // execute the commands to restore the snapshot, in order cmds := []string{ // Drop the entire database by connecting to the postgres global database @@ -229,10 +246,19 @@ func (c *PostgresContainer) Restore(ctx context.Context, opts ...SnapshotOption) } for _, cmd := range cmds { - _, _, err := c.Exec(ctx, []string{"psql", "-U", c.user, "-d", "postgres", "-c", cmd}) + exitCode, reader, err := c.Exec(ctx, []string{"psql", "-v", "ON_ERROR_STOP=1", "-U", c.user, "-d", "postgres", "-c", cmd}) if err != nil { return err } + if exitCode != 0 { + buf := new(strings.Builder) + _, err := io.Copy(buf, reader) + if err != nil { + return fmt.Errorf("non-zero exit code for restore command, could not read command output: %w", err) + } + + return fmt.Errorf("non-zero exit code for restore command: %s", buf.String()) + } } return nil diff --git a/modules/postgres/postgres_test.go b/modules/postgres/postgres_test.go index 2a74e147f6..9dc2194685 100644 --- a/modules/postgres/postgres_test.go +++ b/modules/postgres/postgres_test.go @@ -87,12 +87,12 @@ func TestPostgres(t *testing.T) { connStr, err := container.ConnectionString(ctx, "sslmode=disable", "application_name=test") // } require.NoError(t, err) - - mustConnStr := container.MustConnectionString(ctx,"sslmode=disable", "application_name=test") - if mustConnStr!=connStr{ + + mustConnStr := container.MustConnectionString(ctx, "sslmode=disable", "application_name=test") + if mustConnStr != connStr { t.Errorf("ConnectionString was not equal to MustConnectionString") } - + // Ensure connection string is using generic format id, err := container.MappedPort(ctx, "5432/tcp") require.NoError(t, err) @@ -327,3 +327,76 @@ func TestSnapshot(t *testing.T) { }) // } } + +func TestSnapshotWithOverrides(t *testing.T) { + ctx := context.Background() + + dbname := "other-db" + user := "other-user" + password := "other-password" + + container, err := postgres.RunContainer( + ctx, + testcontainers.WithImage("docker.io/postgres:16-alpine"), + postgres.WithDatabase(dbname), + postgres.WithUsername(user), + postgres.WithPassword(password), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2). + WithStartupTimeout(5*time.Second)), + ) + if err != nil { + t.Fatal(err) + } + + _, _, err = container.Exec(ctx, []string{"psql", "-U", user, "-d", dbname, "-c", "CREATE TABLE users (id SERIAL, name TEXT NOT NULL, age INT NOT NULL)"}) + if err != nil { + t.Fatal(err) + } + + err = container.Snapshot(ctx, postgres.WithSnapshotName("other-snapshot")) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := container.Terminate(ctx); err != nil { + t.Fatalf("failed to terminate container: %s", err) + } + }) + + dbURL, err := container.ConnectionString(ctx) + if err != nil { + t.Fatal(err) + } + + t.Run("Test that the restore works when not using defaults", func(t *testing.T) { + _, _, err = container.Exec(ctx, []string{"psql", "-U", user, "-d", dbname, "-c", "INSERT INTO users(name, age) VALUES ('test', 42)"}) + if err != nil { + t.Fatal(err) + } + + // Doing the restore before we connect since this resets the pgx connection + err = container.Restore(ctx) + if err != nil { + t.Fatal(err) + } + + conn, err := pgx.Connect(context.Background(), dbURL) + if err != nil { + t.Fatal(err) + } + defer conn.Close(context.Background()) + + var count int64 + err = conn.QueryRow(context.Background(), "SELECT COUNT(1) FROM users").Scan(&count) + if err != nil { + t.Fatal(err) + } + + if count != 0 { + t.Fatalf("Expected %d to equal `0`", count) + } + }) +}