From 7dd3e1b1b4c846cbb2f8b1327058d941863d8b3f Mon Sep 17 00:00:00 2001 From: Harsh Daryani Date: Thu, 30 Oct 2025 11:46:37 +0000 Subject: [PATCH 1/6] Upgrading to upstream version v5.7.6 --- .github/workflows/ci.yml | 28 +- .golangci.yml | 21 + CHANGELOG.md | 118 ++ CONTRIBUTING.md | 1 + README.md | 69 +- Rakefile | 2 +- batch.go | 92 +- batch_test.go | 87 +- bench_test.go | 153 ++- ci/setup_test.bash | 4 +- conn.go | 218 ++-- conn_test.go | 228 +++- copy_from_test.go | 1 - derived_types.go | 256 ++++ derived_types_test.go | 40 + doc.go | 34 +- extended_query_builder.go | 88 +- go.mod | 19 +- go.sum | 38 +- internal/sanitize/benchmmark.sh | 60 + internal/sanitize/sanitize.go | 184 ++- internal/sanitize/sanitize_bench_test.go | 62 + internal/sanitize/sanitize_fuzz_test.go | 55 + internal/sanitize/sanitize_test.go | 82 +- internal/stmtcache/lru_cache.go | 1 - large_objects.go | 7 +- multitracer/tracer.go | 152 +++ multitracer/tracer_test.go | 115 ++ named_args.go | 58 +- named_args_test.go | 58 + pgbouncer_test.go | 1 - pgconn/auth_scram.go | 2 +- pgconn/benchmark_test.go | 12 + pgconn/config.go | 192 +-- pgconn/config_test.go | 90 +- pgconn/ctxwatch/context_watcher.go | 80 ++ pgconn/ctxwatch/context_watcher_test.go | 185 +++ pgconn/doc.go | 16 +- pgconn/errors.go | 86 +- pgconn/errors_test.go | 6 +- pgconn/export_test.go | 8 - pgconn/krb5.go | 2 +- pgconn/pgconn.go | 703 +++++++---- pgconn/pgconn_private_test.go | 2 +- pgconn/pgconn_test.go | 1048 +++++++++++++++-- pgproto3/authentication_cleartext_password.go | 7 +- pgproto3/authentication_gss.go | 7 +- pgproto3/authentication_gss_continue.go | 7 +- pgproto3/authentication_md5_password.go | 7 +- pgproto3/authentication_ok.go | 7 +- pgproto3/authentication_sasl.go | 10 +- pgproto3/authentication_sasl_continue.go | 12 +- pgproto3/authentication_sasl_final.go | 12 +- pgproto3/backend.go | 40 +- pgproto3/backend_key_data.go | 7 +- pgproto3/backend_test.go | 4 +- pgproto3/bind.go | 19 +- pgproto3/bind_complete.go | 4 +- 
pgproto3/bind_test.go | 20 + pgproto3/cancel_request.go | 4 +- pgproto3/close.go | 14 +- pgproto3/close_complete.go | 4 +- pgproto3/command_complete.go | 14 +- pgproto3/copy_both_response.go | 13 +- pgproto3/copy_both_response_test.go | 6 +- pgproto3/copy_data.go | 9 +- pgproto3/copy_done.go | 7 +- pgproto3/copy_fail.go | 14 +- pgproto3/copy_in_response.go | 13 +- pgproto3/copy_out_response.go | 13 +- pgproto3/data_row.go | 15 +- pgproto3/describe.go | 14 +- pgproto3/empty_query_response.go | 4 +- pgproto3/error_response.go | 135 +-- pgproto3/example/pgfortune/server.go | 21 +- pgproto3/execute.go | 13 +- pgproto3/flush.go | 4 +- pgproto3/frontend.go | 151 ++- pgproto3/frontend_test.go | 18 + pgproto3/function_call.go | 19 +- pgproto3/function_call_response.go | 10 +- pgproto3/function_call_test.go | 7 +- pgproto3/gss_enc_request.go | 7 +- pgproto3/gss_response.go | 9 +- pgproto3/json_test.go | 4 +- pgproto3/no_data.go | 4 +- pgproto3/notice_response.go | 6 +- pgproto3/notification_response.go | 12 +- pgproto3/parameter_description.go | 15 +- pgproto3/parameter_status.go | 14 +- pgproto3/parse.go | 16 +- pgproto3/parse_complete.go | 4 +- pgproto3/password_message.go | 13 +- pgproto3/pgproto3.go | 28 +- pgproto3/pgproto3_private_test.go | 3 + pgproto3/portal_suspended.go | 4 +- pgproto3/query.go | 11 +- pgproto3/query_test.go | 20 + pgproto3/ready_for_query.go | 4 +- pgproto3/row_description.go | 16 +- pgproto3/sasl_initial_response.go | 12 +- pgproto3/sasl_response.go | 11 +- pgproto3/ssl_request.go | 7 +- pgproto3/startup_message.go | 6 +- pgproto3/sync.go | 4 +- pgproto3/terminate.go | 4 +- pgtype/array.go | 8 +- pgtype/array_codec.go | 3 +- pgtype/array_codec_test.go | 3 + pgtype/bits.go | 11 +- pgtype/bool.go | 15 +- pgtype/box.go | 9 +- pgtype/builtin_wrappers.go | 1 - pgtype/bytea.go | 1 - pgtype/bytea_test.go | 6 +- pgtype/circle.go | 6 +- pgtype/composite.go | 3 + pgtype/composite_test.go | 17 +- pgtype/date.go | 11 +- pgtype/derived_types_test.go | 61 + 
pgtype/doc.go | 28 +- pgtype/enum_codec_test.go | 4 +- pgtype/float4.go | 16 +- pgtype/float8.go | 14 +- pgtype/hstore.go | 13 +- pgtype/hstore_test.go | 7 +- pgtype/inet.go | 5 +- pgtype/int.go | 30 +- pgtype/int.go.erb | 9 +- pgtype/int_test.go | 3 +- pgtype/integration_benchmark_test.go | 2 + pgtype/integration_benchmark_test.go.erb | 4 +- pgtype/interval.go | 16 +- pgtype/interval_test.go | 19 + pgtype/json.go | 143 ++- pgtype/json_test.go | 142 ++- pgtype/jsonb.go | 28 +- pgtype/jsonb_test.go | 33 +- pgtype/line.go | 12 +- pgtype/lseg.go | 9 +- pgtype/macaddr_test.go | 19 + pgtype/multirange.go | 5 +- pgtype/multirange_test.go | 3 + pgtype/numeric.go | 36 +- pgtype/numeric_test.go | 3 + pgtype/path.go | 9 +- pgtype/pgtype.go | 150 ++- pgtype/pgtype_default.go | 35 +- pgtype/pgtype_test.go | 67 +- pgtype/point.go | 11 +- pgtype/polygon.go | 9 +- pgtype/range.go | 15 +- pgtype/range_codec_test.go | 6 + pgtype/record_codec.go | 1 - pgtype/text.go | 9 +- pgtype/tid.go | 7 +- pgtype/time.go | 58 +- pgtype/time_test.go | 68 ++ pgtype/timestamp.go | 77 +- pgtype/timestamp_test.go | 80 +- pgtype/timestamptz.go | 58 +- pgtype/timestamptz_test.go | 34 + pgtype/uint32.go | 55 +- pgtype/uint32_test.go | 1 + pgtype/uint64.go | 323 +++++ pgtype/uint64_test.go | 30 + pgtype/uuid.go | 16 +- pgtype/uuid_test.go | 46 +- pgtype/xml.go | 198 ++++ pgtype/xml_test.go | 128 ++ pgtype/zeronull/float8.go | 8 +- pgtype/zeronull/int.go | 90 +- pgtype/zeronull/int.go.erb | 29 +- pgtype/zeronull/int_test.go | 3 +- pgtype/zeronull/text.go | 7 +- pgtype/zeronull/timestamp.go | 7 +- pgtype/zeronull/timestamptz.go | 8 +- pgtype/zeronull/uuid.go | 8 +- pgx_test.go | 22 + pgxpool/common_test.go | 5 +- pgxpool/conn.go | 4 + pgxpool/doc.go | 2 +- pgxpool/pool.go | 183 ++- pgxpool/pool_test.go | 146 ++- pgxpool/stat.go | 7 + pgxpool/tracer.go | 33 + pgxpool/tracer_test.go | 130 ++ pgxpool/tx.go | 13 +- query_test.go | 165 ++- rows.go | 321 +++-- rows_test.go | 50 + stdlib/sql.go | 59 +- 
testsetup/generate_certs.go | 7 +- tracelog/tracelog.go | 113 +- tracelog/tracelog_test.go | 140 ++- tracer_test.go | 40 + tx.go | 20 +- values.go | 8 - values_test.go | 51 +- 199 files changed, 7831 insertions(+), 1780 deletions(-) create mode 100644 .golangci.yml create mode 100644 derived_types.go create mode 100644 derived_types_test.go create mode 100644 internal/sanitize/benchmmark.sh create mode 100644 internal/sanitize/sanitize_bench_test.go create mode 100644 internal/sanitize/sanitize_fuzz_test.go create mode 100644 multitracer/tracer.go create mode 100644 multitracer/tracer_test.go create mode 100644 pgconn/ctxwatch/context_watcher.go create mode 100644 pgconn/ctxwatch/context_watcher_test.go create mode 100644 pgproto3/bind_test.go create mode 100644 pgproto3/pgproto3_private_test.go create mode 100644 pgproto3/query_test.go create mode 100644 pgtype/derived_types_test.go create mode 100644 pgtype/uint64.go create mode 100644 pgtype/uint64_test.go create mode 100644 pgtype/xml.go create mode 100644 pgtype/xml_test.go create mode 100644 pgx_test.go create mode 100644 pgxpool/tracer.go create mode 100644 pgxpool/tracer_test.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f15a52fb9..ec221f59d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,10 @@ name: CI -on: workflow_dispatch +on: + push: + branches: [master] + pull_request: + branches: [master] jobs: test: @@ -9,10 +13,10 @@ jobs: strategy: matrix: - go-version: ["1.20", "1.21"] - pg-version: [12, 13, 14, 15, 16, cockroachdb] + go-version: ["1.23", "1.24"] + pg-version: [13, 14, 15, 16, 17, cockroachdb] include: - - pg-version: 12 + - pg-version: 13 pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" @@ -22,7 +26,7 @@ jobs: pgx-test-tls-conn-string: 
"host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" pgx-ssl-password: certpw pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" - - pg-version: 13 + - pg-version: 14 pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" @@ -32,7 +36,7 @@ jobs: pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" pgx-ssl-password: certpw pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" - - pg-version: 14 + - pg-version: 15 pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" @@ -42,7 +46,7 @@ jobs: pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" pgx-ssl-password: certpw pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" - - pg-version: 15 + - pg-version: 16 pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" @@ -52,7 +56,7 @@ jobs: pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret 
sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test" pgx-ssl-password: certpw pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test" - - pg-version: 16 + - pg-version: 17 pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test" @@ -102,7 +106,8 @@ jobs: git diff --exit-code - name: Test - run: go test -v -race ./... + # parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner. + run: go test -parallel=1 -race ./... env: PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }} PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} @@ -120,7 +125,7 @@ jobs: runs-on: windows-latest strategy: matrix: - go-version: ["1.20", "1.21"] + go-version: ["1.23", "1.24"] steps: - name: Setup PostgreSQL @@ -145,6 +150,7 @@ jobs: shell: bash - name: Test - run: go test -v -race ./... + # parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner. + run: go test -parallel=1 -race ./... env: PGX_TEST_DATABASE: ${{ steps.postgres.outputs.connection-uri }} diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 000000000..ca74c703a --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,21 @@ +# See for configurations: https://golangci-lint.run/usage/configuration/ +version: 2 + +# See: https://golangci-lint.run/usage/formatters/ +formatters: + default: none + enable: + - gofmt # https://pkg.go.dev/cmd/gofmt + - gofumpt # https://github.com/mvdan/gofumpt + + settings: + gofmt: + simplify: true # Simplify code: gofmt with `-s` option. + + gofumpt: + # Module path which contains the source code being formatted. 
+ # Default: "" + module-path: github.com/jackc/pgx/v5 # Should match with module in go.mod + # Choose whether to use the extra rules. + # Default: false + extra-rules: true diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fcbc2473..6c9c99b5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,121 @@ +# 5.7.6 (September 8, 2025) + +* Use ParseConfigError in pgx.ParseConfig and pgxpool.ParseConfig (Yurasov Ilia) +* Add PrepareConn hook to pgxpool (Jonathan Hall) +* Reduce allocations in QueryContext (Dominique Lefevre) +* Add MarshalJSON and UnmarshalJSON for pgtype.Uint32 (Panos Koutsovasilis) +* Configure ping behavior on pgxpool with ShouldPing (Christian Kiely) +* zeronull int types implement Int64Valuer and Int64Scanner (Li Zeghong) +* Fix panic when receiving terminate connection message during CopyFrom (Michal Drausowski) +* Fix statement cache not being invalidated on error during batch (Muhammadali Nazarov) + +# 5.7.5 (May 17, 2025) + +* Support sslnegotiation connection option (divyam234) +* Update golang.org/x/crypto to v0.37.0. This placates security scanners that were unable to see that pgx did not use the behavior affected by https://pkg.go.dev/vuln/GO-2025-3487. +* TraceLog now logs Acquire and Release at the debug level (dave sinclair) +* Add support for PGTZ environment variable +* Add support for PGOPTIONS environment variable +* Unpin memory used by Rows quicker +* Remove PlanScan memoization. This resolves a rare issue where scanning could be broken for one type by first scanning another. The problem was in the memoization system and benchmarking revealed that memoization was not providing any meaningful benefit. 
+ +# 5.7.4 (March 24, 2025) + +* Fix / revert change to scanning JSON `null` (Felix Röhrich) + +# 5.7.3 (March 21, 2025) + +* Expose EmptyAcquireWaitTime in pgxpool.Stat (vamshiaruru32) +* Improve SQL sanitizer performance (ninedraft) +* Fix Scan confusion with json(b), sql.Scanner, and automatic dereferencing (moukoublen, felix-roehrich) +* Fix Values() for xml type always returning nil instead of []byte +* Add ability to send Flush message in pipeline mode (zenkovev) +* Fix pgtype.Timestamp's JSON behavior to match PostgreSQL (pconstantinou) +* Better error messages when scanning structs (logicbomb) +* Fix handling of error on batch write (bonnefoa) +* Match libpq's connection fallback behavior more closely (felix-roehrich) +* Add MinIdleConns to pgxpool (djahandarie) + +# 5.7.2 (December 21, 2024) + +* Fix prepared statement already exists on batch prepare failure +* Add commit query to tx options (Lucas Hild) +* Fix pgtype.Timestamp json unmarshal (Shean de Montigny-Desautels) +* Add message body size limits in frontend and backend (zene) +* Add xid8 type +* Ensure planning encodes and scans cannot infinitely recurse +* Implement pgtype.UUID.String() (Konstantin Grachev) +* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev) +* Update golang.org/x/crypto +* Fix json(b) columns prefer sql.Scanner interface like database/sql (Ludovico Russo) + +# 5.7.1 (September 10, 2024) + +* Fix data race in tracelog.TraceLog +* Update puddle to v2.2.2. This removes the import of nanotime via linkname. 
+* Update golang.org/x/crypto and golang.org/x/text + +# 5.7.0 (September 7, 2024) + +* Add support for sslrootcert=system (Yann Soubeyrand) +* Add LoadTypes to load multiple types in a single SQL query (Nick Farrell) +* Add XMLCodec supports encoding + scanning XML column type like json (nickcruess-soda) +* Add MultiTrace (Stepan Rabotkin) +* Add TraceLogConfig with customizable TimeKey (stringintech) +* pgx.ErrNoRows wraps sql.ErrNoRows to aid in database/sql compatibility with native pgx functions (merlin) +* Support scanning binary formatted uint32 into string / TextScanner (jennifersp) +* Fix interval encoding to allow 0s and avoid extra spaces (Carlos Pérez-Aradros Herce) +* Update pgservicefile - fixes panic when parsing invalid file +* Better error message when reading past end of batch +* Don't print url when url.Parse returns an error (Kevin Biju) +* Fix snake case name normalization collision in RowToStructByName with db tag (nolandseigler) +* Fix: Scan and encode types with underlying types of arrays + +# 5.6.0 (May 25, 2024) + +* Add StrictNamedArgs (Tomas Zahradnicek) +* Add support for macaddr8 type (Carlos Pérez-Aradros Herce) +* Add SeverityUnlocalized field to PgError / Notice +* Performance optimization of RowToStructByPos/Name (Zach Olstein) +* Allow customizing context canceled behavior for pgconn +* Add ScanLocation to pgtype.Timestamp[tz]Codec +* Add custom data to pgconn.PgConn +* Fix ResultReader.Read() to handle nil values +* Do not encode interval microseconds when they are 0 (Carlos Pérez-Aradros Herce) +* pgconn.SafeToRetry checks for wrapped errors (tjasko) +* Failed connection attempts include all errors +* Optimize LargeObject.Read (Mitar) +* Add tracing for connection acquire and release from pool (ngavinsir) +* Fix encode driver.Valuer not called when nil +* Add support for custom JSON marshal and unmarshal (Mitar) +* Use Go default keepalive for TCP connections (Hans-Joachim Kliemeck) + +# 5.5.5 (March 9, 2024) + +Use spaces 
instead of parentheses for SQL sanitization. + +This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as +`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed. + +# 5.5.4 (March 4, 2024) + +Fix CVE-2024-27304 + +SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer +overflow in the calculated message size can cause the one large message to be sent as multiple messages under the +attacker's control. + +Thanks to Paul Gerste for reporting this issue. + +* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix) +* Fix simple protocol encoding of json.RawMessage +* Fix *Pipeline.getResults should close pipeline on error +* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman) +* Fix deallocation of invalidated cached statements in a transaction +* Handle invalid sslkey file +* Fix scan float4 into sql.Scanner +* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads. + # 5.5.3 (February 3, 2024) * Fix: prepared statement already exists diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6ed3205ce..c975a9372 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -29,6 +29,7 @@ Create and setup a test database: export PGDATABASE=pgx_test createdb psql -c 'create extension hstore;' +psql -c 'create extension ltree;' psql -c 'create domain uint64 as numeric(20,0);' ``` diff --git a/README.md b/README.md index cb72fc08c..a57ff84ab 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,8 @@ Details about the upstream pgx driver - which hold true for this driver as well pgx is a pure Go driver and toolkit for PostgreSQL. -The pgx driver is a low-level, high performance interface. It also includes an adapter for the standard `database/sql` interface. 
+The pgx driver is a low-level, high performance interface that exposes PostgreSQL-specific features such as `LISTEN` / +`NOTIFY` and `COPY`. It also includes an adapter for the standard `database/sql` interface. The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers, @@ -109,11 +110,11 @@ import ( "fmt" "os" +<<<<<<< HEAD "github.com/yugabyte/pgx/v5" +======= + "github.com/jackc/pgx/v5" ) - -func main() { - // urlExample := "postgres://username:password@localhost:5433/database_name" conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err) @@ -146,6 +147,10 @@ See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-starte * `COPY` protocol support for faster bulk data loads * Tracing and logging support * Connection pool with after-connect hook for arbitrary connection setup +<<<<<<< HEAD +======= +* `LISTEN` / `NOTIFY` +>>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e * Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings * `hstore` support * `json` and `jsonb` support @@ -158,7 +163,12 @@ See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-starte ## Choosing Between the pgx and database/sql Interfaces +<<<<<<< HEAD The pgx interface is faster. +======= +The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available +through the `database/sql` interface. +>>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e The pgx interface is recommended when: @@ -169,7 +179,11 @@ It is also possible to use the `database/sql` interface and convert a connection ## Testing +<<<<<<< HEAD See CONTRIBUTING.md for setup instructions. 
+======= +See [CONTRIBUTING.md](./CONTRIBUTING.md) for setup instructions. +>>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e ## Architecture @@ -177,7 +191,11 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube. ## Supported Go and PostgreSQL Versions +<<<<<<< HEAD pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.20 and higher and PostgreSQL 12 and higher. +======= +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.23 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). +>>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e ## Version Policy @@ -189,7 +207,11 @@ pgx follows semantic versioning for the documented public API on stable releases pglogrepl provides functionality to act as a client for PostgreSQL logical replication. +<<<<<<< HEAD pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.16 and higher and PostgreSQL 10 and higher. 
+======= +### [github.com/jackc/pgmock](https://github.com/jackc/pgmock) +>>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler). @@ -203,12 +225,25 @@ pgerrcode contains constants for the PostgreSQL error codes. ## Adapters for 3rd Party Types +<<<<<<< HEAD ### [github.com/yugabyte/pgx/v4/pgxpool](https://github.com/yugabyte/pgx/tree/master/pgxpool) ### [github.com/yugabyte/pgx/v4/stdlib](https://github.com/yugabyte/pgx/tree/master/stdlib) * [https://github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer) +======= +* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid) +* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal) +* [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos)) +* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid) + + +## Adapters for 3rd Party Tracers + +* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer) +* [github.com/exaring/otelpgx](https://github.com/exaring/otelpgx) +>>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e ## Adapters for 3rd Party Loggers @@ -238,7 +273,11 @@ Library for scanning data from a database into Go structs and more. A carefully designed SQL client for making using SQL easier, more productive, and less error-prone on Golang. 
+<<<<<<< HEAD ### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) +======= +### [github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) +>>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e Adds GSSAPI / Kerberos authentication support. @@ -251,6 +290,28 @@ Explicit data mapping and scanning library for Go structs and slices. Type safe and flexible package for scanning database data into Go types. Supports, structs, maps, slices and custom mapping functions. +<<<<<<< HEAD ### [https://github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx) Code first migration library for native pgx (no database/sql abstraction). +======= +### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx) + +Code first migration library for native pgx (no database/sql abstraction). + +### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring) + +A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry. + +### [https://github.com/nikolayk812/pgx-outbox](https://github.com/nikolayk812/pgx-outbox) + +Simple Golang implementation for transactional outbox pattern for PostgreSQL using jackc/pgx driver. + +### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy) + +Simplifies working with the pgx library, providing convenient scanning of nested structures. + +## [https://github.com/KoNekoD/pgx-colon-query-rewriter](https://github.com/KoNekoD/pgx-colon-query-rewriter) + +Implementation of the pgx query rewriter to use ':' instead of '@' in named query parameters. +>>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e diff --git a/Rakefile b/Rakefile index d957573e9..3e3aa5030 100644 --- a/Rakefile +++ b/Rakefile @@ -2,7 +2,7 @@ require "erb" rule '.go' => '.go.erb' do |task| erb = ERB.new(File.read(task.source)) - File.write(task.name, "// Do not edit. 
Generated from #{task.source}\n" + erb.result(binding)) + File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding)) sh "goimports", "-w", task.name end diff --git a/batch.go b/batch.go index 5ae987101..1401e1388 100644 --- a/batch.go +++ b/batch.go @@ -12,7 +12,7 @@ import ( type QueuedQuery struct { SQL string Arguments []any - fn batchItemFunc + Fn batchItemFunc sd *pgconn.StatementDescription } @@ -20,7 +20,7 @@ type batchItemFunc func(br BatchResults) error // Query sets fn to be called when the response to qq is received. func (qq *QueuedQuery) Query(fn func(rows Rows) error) { - qq.fn = func(br BatchResults) error { + qq.Fn = func(br BatchResults) error { rows, _ := br.Query() defer rows.Close() @@ -36,15 +36,19 @@ func (qq *QueuedQuery) Query(fn func(rows Rows) error) { // Query sets fn to be called when the response to qq is received. func (qq *QueuedQuery) QueryRow(fn func(row Row) error) { - qq.fn = func(br BatchResults) error { + qq.Fn = func(br BatchResults) error { row := br.QueryRow() return fn(row) } } // Exec sets fn to be called when the response to qq is received. +// +// Note: for simple batch insert uses where it is not required to handle +// each potential error individually, it's sufficient to not set any callbacks, +// and just handle the return value of BatchResults.Close. func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) { - qq.fn = func(br BatchResults) error { + qq.Fn = func(br BatchResults) error { ct, err := br.Exec() if err != nil { return err @@ -60,9 +64,13 @@ type Batch struct { QueuedQueries []*QueuedQuery } -// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. -// The only pgx option argument that is supported is QueryRewriter. Queries are executed using the -// connection's DefaultQueryExecMode. +// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. 
The only pgx option +// argument that is supported is QueryRewriter. Queries are executed using the connection's DefaultQueryExecMode. +// +// While query can contain multiple statements if the connection's DefaultQueryExecMode is QueryModeSimple, this should +// be avoided. QueuedQuery.Fn must not be set as it will only be called for the first query. That is, QueuedQuery.Query, +// QueuedQuery.QueryRow, and QueuedQuery.Exec must not be called. In addition, any error messages or tracing that +// include the current query may reference the wrong query. func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery { qq := &QueuedQuery{ SQL: query, @@ -79,7 +87,7 @@ func (b *Batch) Len() int { type BatchResults interface { // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer - // calling Exec on the QueuedQuery. + // calling Exec on the QueuedQuery, or just calling Close. Exec() (pgconn.CommandTag, error) // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer @@ -94,6 +102,9 @@ type BatchResults interface { // QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an // error or the batch encounters an error subsequent callback functions will not be called. // + // For simple batch inserts inside a transaction or similar queries, it's sufficient to not set any callbacks, + // and just handle the return value of Close. + // // Close must be called before the underlying connection can be used again. Any error that occurred during a batch // operation may have made it impossible to resyncronize the connection with the server. In this case the underlying // connection will have been closed. 
@@ -128,7 +139,7 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) { if !br.mrr.NextResult() { err := br.mrr.Close() if err == nil { - err = errors.New("no result") + err = errors.New("no more results in batch") } if br.conn.batchTracer != nil { br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{ @@ -180,7 +191,7 @@ func (br *batchResults) Query() (Rows, error) { if !br.mrr.NextResult() { rows.err = br.mrr.Close() if rows.err == nil { - rows.err = errors.New("no result") + rows.err = errors.New("no more results in batch") } rows.closed = true @@ -203,7 +214,6 @@ func (br *batchResults) Query() (Rows, error) { func (br *batchResults) QueryRow() Row { rows, _ := br.Query() return (*connRow)(rows.(*baseRows)) - } // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to @@ -216,6 +226,8 @@ func (br *batchResults) Close() error { } br.endTraced = true } + + invalidateCachesOnBatchResultsError(br.conn, br.b, br.err) }() if br.err != nil { @@ -228,8 +240,8 @@ func (br *batchResults) Close() error { // Read and run fn for all remaining items for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { - if br.b.QueuedQueries[br.qqIdx].fn != nil { - err := br.b.QueuedQueries[br.qqIdx].fn(br) + if br.b.QueuedQueries[br.qqIdx].Fn != nil { + err := br.b.QueuedQueries[br.qqIdx].Fn(br) if err != nil { br.err = err } @@ -287,7 +299,10 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) { return pgconn.CommandTag{}, br.err } - query, arguments, _ := br.nextQueryAndArgs() + query, arguments, err := br.nextQueryAndArgs() + if err != nil { + return pgconn.CommandTag{}, err + } results, err := br.pipeline.GetResults() if err != nil { @@ -330,9 +345,9 @@ func (br *pipelineBatchResults) Query() (Rows, error) { return &baseRows{err: br.err, closed: true}, br.err } - query, arguments, ok := br.nextQueryAndArgs() - if !ok { - query = "batch query" + query, 
arguments, err := br.nextQueryAndArgs() + if err != nil { + return &baseRows{err: err, closed: true}, err } rows := br.conn.getRows(br.ctx, query, arguments) @@ -371,7 +386,6 @@ func (br *pipelineBatchResults) Query() (Rows, error) { func (br *pipelineBatchResults) QueryRow() Row { rows, _ := br.Query() return (*connRow)(rows.(*baseRows)) - } // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to @@ -384,6 +398,8 @@ func (br *pipelineBatchResults) Close() error { } br.endTraced = true } + + invalidateCachesOnBatchResultsError(br.conn, br.b, br.err) }() if br.err == nil && br.lastRows != nil && br.lastRows.err != nil { @@ -397,8 +413,8 @@ func (br *pipelineBatchResults) Close() error { // Read and run fn for all remaining items for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { - if br.b.QueuedQueries[br.qqIdx].fn != nil { - err := br.b.QueuedQueries[br.qqIdx].fn(br) + if br.b.QueuedQueries[br.qqIdx].Fn != nil { + err := br.b.QueuedQueries[br.qqIdx].Fn(br) if err != nil { br.err = err } @@ -421,13 +437,33 @@ func (br *pipelineBatchResults) earlyError() error { return br.err } -func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) { - if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) { - bi := br.b.QueuedQueries[br.qqIdx] - query = bi.SQL - args = bi.Arguments - ok = true - br.qqIdx++ +func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, err error) { + if br.b == nil { + return "", nil, errors.New("no reference to batch") + } + + if br.qqIdx >= len(br.b.QueuedQueries) { + return "", nil, errors.New("no more results in batch") + } + + bi := br.b.QueuedQueries[br.qqIdx] + br.qqIdx++ + return bi.SQL, bi.Arguments, nil +} + +// invalidates statement and description caches on batch results error +func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) { + if err != nil && conn != nil && b != nil { 
+ if sc := conn.statementCache; sc != nil { + for _, bi := range b.QueuedQueries { + sc.Invalidate(bi.SQL) + } + } + + if sc := conn.descriptionCache; sc != nil { + for _, bi := range b.QueuedQueries { + sc.Invalidate(bi.SQL) + } + } } - return } diff --git a/batch_test.go b/batch_test.go index 716b5eb04..06f62f672 100644 --- a/batch_test.go +++ b/batch_test.go @@ -290,6 +290,45 @@ func TestConnSendBatchMany(t *testing.T) { }) } +// https://github.com/jackc/pgx/issues/1801#issuecomment-2203784178 +func TestConnSendBatchReadResultsWhenNothingQueued(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + br := conn.SendBatch(ctx, batch) + commandTag, err := br.Exec() + require.Equal(t, "", commandTag.String()) + require.EqualError(t, err, "no more results in batch") + err = br.Close() + require.NoError(t, err) + }) +} + +func TestConnSendBatchReadMoreResultsThanQueriesSent(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + batch := &pgx.Batch{} + batch.Queue("select 1") + br := conn.SendBatch(ctx, batch) + commandTag, err := br.Exec() + require.Equal(t, "SELECT 1", commandTag.String()) + require.NoError(t, err) + commandTag, err = br.Exec() + require.Equal(t, "", commandTag.String()) + require.EqualError(t, err, "no more results in batch") + err = br.Close() + require.NoError(t, err) + }) +} + func TestConnSendBatchWithPreparedStatement(t *testing.T) { t.Parallel() @@ -449,7 +488,10 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { defer cancel() pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx 
context.Context, t testing.TB, conn *pgx.Conn) { batch := &pgx.Batch{} batch.Queue("select n from generate_series(0,5) n") batch.Queue("select n from generate_series(0,5) n") @@ -500,7 +542,6 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { if err != nil { t.Fatal(err) } - }) } @@ -511,7 +552,6 @@ func TestConnSendBatchQueryError(t *testing.T) { defer cancel() pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - batch := &pgx.Batch{} batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") batch.Queue("select n from generate_series(0,5) n") @@ -541,7 +581,6 @@ func TestConnSendBatchQueryError(t *testing.T) { if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { t.Errorf("br.Close() => %v, want error code %v", err, 22012) } - }) } @@ -552,7 +591,6 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) { defer cancel() pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - batch := &pgx.Batch{} batch.Queue("select 1 1") @@ -568,7 +606,6 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) { if err == nil { t.Error("Expected error") } - }) } @@ -579,7 +616,6 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) { defer cancel() pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - sql := `create temporary table ledger( id serial primary key, description varchar not null, @@ -608,7 +644,6 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) { } br.Close() - }) } @@ -619,7 +654,6 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) { defer cancel() pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - sql := `create temporary table ledger( id serial
primary key, description varchar not null, @@ -648,7 +682,6 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) { } br.Close() - }) } @@ -659,7 +692,6 @@ func TestTxSendBatch(t *testing.T) { defer cancel() pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - sql := `create temporary table ledger1( id serial primary key, description varchar not null @@ -718,7 +750,6 @@ func TestTxSendBatch(t *testing.T) { if err != nil { t.Fatal(err) } - }) } @@ -729,7 +760,6 @@ func TestTxSendBatchRollback(t *testing.T) { defer cancel() pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - sql := `create temporary table ledger1( id serial primary key, description varchar not null @@ -756,7 +786,6 @@ func TestTxSendBatchRollback(t *testing.T) { if count != 0 { t.Errorf("count => %v, want %v", count, 0) } - }) } @@ -816,7 +845,6 @@ func TestConnBeginBatchDeferredError(t *testing.T) { defer cancel() pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") mustExec(t, conn, `create temporary table t ( @@ -855,7 +883,6 @@ func TestConnBeginBatchDeferredError(t *testing.T) { if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" { t.Fatalf("expected error 23505, got %v", err) } - }) } @@ -969,6 +996,36 @@ func TestSendBatchSimpleProtocol(t *testing.T) { assert.False(t, rows.Next()) } +// https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887 +func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx 
context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test") + + mustExec(t, conn, `create temporary table foo(col1 text primary key);`) + + batch := &pgx.Batch{} + batch.Queue("select col1 from foo") + batch.Queue("select col1 from baz") + err := conn.SendBatch(ctx, batch).Close() + require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`) + + mustExec(t, conn, `create temporary table baz(col1 text primary key);`) + + // Since table baz now exists, the batch should succeed. + + batch = &pgx.Batch{} + batch.Queue("select col1 from foo") + batch.Queue("select col1 from baz") + err = conn.SendBatch(ctx, batch).Close() + require.NoError(t, err) + }) +} + func ExampleConn_SendBatch() { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() diff --git a/bench_test.go b/bench_test.go index 38283e63c..28edcd02a 100644 --- a/bench_test.go +++ b/bench_test.go @@ -516,7 +516,6 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc } return rowCount, nil - } func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { @@ -535,7 +534,8 @@ func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { src := newBenchmarkWriteTableCopyFromSrc(n) _, err := multiInsert(conn, "t", - []string{"varchar_1", + []string{ + "varchar_1", "varchar_2", "varchar_null_1", "date_1", @@ -547,7 +547,8 @@ func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { "tstz_2", "bool_1", "bool_2", - "bool_3"}, + "bool_3", + }, src) if err != nil { b.Fatal(err) @@ -568,7 +569,8 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { _, err := conn.CopyFrom(context.Background(), pgx.Identifier{"t"}, - []string{"varchar_1", + []string{ + "varchar_1", "varchar_2", "varchar_null_1", "date_1", @@ -580,7 +582,8 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) { "tstz_2", "bool_1", "bool_2", - "bool_3"}, + "bool_3", + }, src) if err != nil { 
b.Fatal(err) @@ -615,6 +618,10 @@ func BenchmarkWrite5RowsViaBatchInsert(b *testing.B) { benchmarkWriteNRowsViaBatchInsert(b, 5) } +func BenchmarkWrite5RowsViaBatchInsert(b *testing.B) { + benchmarkWriteNRowsViaBatchInsert(b, 5) +} + func BenchmarkWrite5RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 5) } @@ -630,6 +637,10 @@ func BenchmarkWrite10RowsViaBatchInsert(b *testing.B) { benchmarkWriteNRowsViaBatchInsert(b, 10) } +func BenchmarkWrite10RowsViaBatchInsert(b *testing.B) { + benchmarkWriteNRowsViaBatchInsert(b, 10) +} + func BenchmarkWrite10RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 10) } @@ -645,6 +656,10 @@ func BenchmarkWrite100RowsViaBatchInsert(b *testing.B) { benchmarkWriteNRowsViaBatchInsert(b, 100) } +func BenchmarkWrite100RowsViaBatchInsert(b *testing.B) { + benchmarkWriteNRowsViaBatchInsert(b, 100) +} + func BenchmarkWrite100RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 100) } @@ -676,6 +691,10 @@ func BenchmarkWrite10000RowsViaBatchInsert(b *testing.B) { benchmarkWriteNRowsViaBatchInsert(b, 10000) } +func BenchmarkWrite10000RowsViaBatchInsert(b *testing.B) { + benchmarkWriteNRowsViaBatchInsert(b, 10000) +} + func BenchmarkWrite10000RowsViaCopy(b *testing.B) { benchmarkWriteNRowsViaCopy(b, 10000) } @@ -944,6 +963,7 @@ type BenchRowSimple struct { BirthDate time.Time Weight int32 Height int32 + Tags []string UpdateTime time.Time } @@ -957,13 +977,13 @@ func BenchmarkSelectRowsScanSimple(b *testing.B) { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { br := &BenchRowSimple{} for i := 0; i < b.N; i++ { - rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) + rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from 
generate_series(100001, 100000 + $1) n", rowCount) if err != nil { b.Fatal(err) } for rows.Next() { - rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) + rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime) } if rows.Err() != nil { @@ -982,6 +1002,7 @@ type BenchRowStringBytes struct { BirthDate time.Time Weight int32 Height int32 + Tags []string UpdateTime time.Time } @@ -995,13 +1016,13 @@ func BenchmarkSelectRowsScanStringBytes(b *testing.B) { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { br := &BenchRowStringBytes{} for i := 0; i < b.N; i++ { - rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) + rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) if err != nil { b.Fatal(err) } for rows.Next() { - rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) + rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime) } if rows.Err() != nil { @@ -1020,6 +1041,7 @@ type BenchRowDecoder struct { BirthDate pgtype.Date Weight pgtype.Int4 Height pgtype.Int4 + Tags pgtype.FlatArray[string] UpdateTime pgtype.Timestamptz } @@ -1040,12 +1062,11 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) { } for _, format := range formats { b.Run(format.name, func(b *testing.B) { - br := &BenchRowDecoder{} for i := 0; i < b.N; i++ { rows, err := conn.Query( context.Background(), - "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz 
from generate_series(100001, 100000 + $1) n", + "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", pgx.QueryResultFormats{format.code}, rowCount, ) @@ -1054,7 +1075,7 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) { } for rows.Next() { - rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) + rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime) } if rows.Err() != nil { @@ -1076,7 +1097,7 @@ func BenchmarkSelectRowsPgConnExecText(b *testing.B) { for _, rowCount := range rowCounts { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { for i := 0; i < b.N; i++ { - mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount)) + mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount)) for mrr.NextResult() { rr := mrr.ResultReader() for rr.NextRow() { @@ -1113,11 +1134,11 @@ func BenchmarkSelectRowsPgConnExecParams(b *testing.B) { for i := 0; i < b.N; i++ { rr := conn.PgConn().ExecParams( context.Background(), - "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", + "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", [][]byte{[]byte(strconv.FormatInt(rowCount, 10))}, nil, nil, - []int16{format.code, pgx.TextFormatCode, 
pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code}, + []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code}, ) for rr.NextRow() { rr.Values() @@ -1134,13 +1155,107 @@ func BenchmarkSelectRowsPgConnExecParams(b *testing.B) { } } +func BenchmarkSelectRowsSimpleCollectRowsRowToStructByPos(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + for i := 0; i < b.N; i++ { + rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) + benchRows, err := pgx.CollectRows(rows, pgx.RowToStructByPos[BenchRowSimple]) + if err != nil { + b.Fatal(err) + } + if len(benchRows) != int(rowCount) { + b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows)) + } + } + }) + } +} + +func BenchmarkSelectRowsSimpleAppendRowsRowToStructByPos(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + benchRows := make([]BenchRowSimple, 0, rowCount) + for i := 0; i < b.N; i++ { + benchRows = benchRows[:0] + rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) + var err error + benchRows, err = pgx.AppendRows(benchRows, rows, pgx.RowToStructByPos[BenchRowSimple]) + if err != nil { + b.Fatal(err) + } + if len(benchRows) != 
int(rowCount) { + b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows)) + } + } + }) + } +} + +func BenchmarkSelectRowsSimpleCollectRowsRowToStructByName(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + for i := 0; i < b.N; i++ { + rows, _ := conn.Query(context.Background(), "select n as id, 'Adam' as first_name, 'Smith ' || n as last_name, 'male' as sex, '1952-06-16'::date as birth_date, 258 as weight, 72 as height, '{foo,bar,baz}'::text[] as tags, '2001-01-28 01:02:03-05'::timestamptz as update_time from generate_series(100001, 100000 + $1) n", rowCount) + benchRows, err := pgx.CollectRows(rows, pgx.RowToStructByName[BenchRowSimple]) + if err != nil { + b.Fatal(err) + } + if len(benchRows) != int(rowCount) { + b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows)) + } + } + }) + } +} + +func BenchmarkSelectRowsSimpleAppendRowsRowToStructByName(b *testing.B) { + conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(b, conn) + + rowCounts := getSelectRowsCounts(b) + + for _, rowCount := range rowCounts { + b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { + benchRows := make([]BenchRowSimple, 0, rowCount) + for i := 0; i < b.N; i++ { + benchRows = benchRows[:0] + rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) + var err error + benchRows, err = pgx.AppendRows(benchRows, rows, pgx.RowToStructByPos[BenchRowSimple]) + if err != nil { + b.Fatal(err) + } + if len(benchRows) != int(rowCount) { + b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows)) + } + } + }) + } +} + func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) { 
conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(b, conn) rowCounts := getSelectRowsCounts(b) - _, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) + _, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) if err != nil { b.Fatal(err) } @@ -1162,7 +1277,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) { "ps1", [][]byte{[]byte(strconv.FormatInt(rowCount, 10))}, nil, - []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code}, + []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code}, ) for rr.NextRow() { rr.Values() @@ -1241,7 +1356,7 @@ func BenchmarkSelectRowsRawPrepared(b *testing.B) { conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")).PgConn() defer conn.Close(context.Background()) - _, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) + _, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil) if err != nil { b.Fatal(err) } @@ -1264,7 +1379,7 @@ func BenchmarkSelectRowsRawPrepared(b *testing.B) { "ps1", [][]byte{[]byte(strconv.FormatInt(rowCount, 10))}, nil, - []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, 
format.code, format.code}, + []int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code}, ) _, err := rr.Close() require.NoError(b, err) diff --git a/ci/setup_test.bash b/ci/setup_test.bash index 66ba07d4d..d591c512c 100755 --- a/ci/setup_test.bash +++ b/ci/setup_test.bash @@ -42,8 +42,8 @@ fi if [[ "${PGVERSION-}" =~ ^cockroach ]] then - wget -qO- https://binaries.cockroachdb.com/cockroach-v23.1.3.linux-amd64.tgz | tar xvz - sudo mv cockroach-v23.1.3.linux-amd64/cockroach /usr/local/bin/ + wget -qO- https://binaries.cockroachdb.com/cockroach-v24.3.3.linux-amd64.tgz | tar xvz + sudo mv cockroach-v24.3.3.linux-amd64/cockroach /usr/local/bin/ cockroach start-single-node --insecure --background --listen-addr=localhost cockroach sql --insecure -e 'create database pgx_test' fi diff --git a/conn.go b/conn.go index 9593a6049..07223db22 100644 --- a/conn.go +++ b/conn.go @@ -3,6 +3,7 @@ package pgx import ( "context" "crypto/sha256" + "database/sql" "encoding/hex" "errors" "fmt" @@ -10,7 +11,6 @@ import ( "strings" "time" - "github.com/yugabyte/pgx/v5/internal/anynil" "github.com/yugabyte/pgx/v5/internal/sanitize" "github.com/yugabyte/pgx/v5/internal/stmtcache" "github.com/yugabyte/pgx/v5/pgconn" @@ -92,8 +101,6 @@ type Conn struct { wbuf []byte eqb ExtendedQueryBuilder - - closeCntUpdated bool } // Identifier a PostgreSQL identifier or name. Identifiers can be composed of @@ -112,13 +119,31 @@ func (ident Identifier) Sanitize() string { var ( // ErrNoRows occurs when rows are expected but none are returned. - ErrNoRows = errors.New("no rows in result set") + ErrNoRows = newProxyErr(sql.ErrNoRows, "no rows in result set") // ErrTooManyRows occurs when more rows than expected are returned. ErrTooManyRows = errors.New("too many rows in result set") ) -var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") -var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") +func newProxyErr(background error, msg string) error { + return &proxyError{ + msg: msg, + background: background, + } +} + +type proxyError struct { + msg string + background error +} + +func (err *proxyError) Error() string { return err.msg } + +func (err *proxyError) Unwrap() error { return err.background } + +var ( + errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache") + errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache") +) // Connect establishes a connection with a PostgreSQL server with a connection string. See // pgconn.Connect for details. @@ -175,7 +210,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con delete(config.RuntimeParams, "statement_cache_capacity") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse statement_cache_capacity", err) } statementCacheCapacity = int(n) } @@ -185,30 +220,11 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con delete(config.RuntimeParams, "description_cache_capacity") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse description_cache_capacity", err) } descriptionCacheCapacity = int(n) } - defaultQueryExecMode := QueryExecModeCacheStatement - if s, ok := config.RuntimeParams["default_query_exec_mode"]; ok { - delete(config.RuntimeParams, "default_query_exec_mode") - switch s { - case "cache_statement": - defaultQueryExecMode = QueryExecModeCacheStatement - case "cache_describe": - defaultQueryExecMode = QueryExecModeCacheDescribe - case "describe_exec": - defaultQueryExecMode = QueryExecModeDescribeExec - case "exec": - defaultQueryExecMode = QueryExecModeExec - case "simple_protocol": - defaultQueryExecMode = QueryExecModeSimpleProtocol - default: - return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s) - } - } - var
loadBalance string = "false" if s, ok := config.RuntimeParams["load_balance"]; ok { delete(config.RuntimeParams, "load_balance") @@ -250,7 +267,7 @@ refreshInterval = int64(refresh) } } else { - return nil, fmt.Errorf("invalid refresh_interval: %v", err) + return nil, pgconn.NewParseConfigError(connString, "invalid refresh_interval", err) } } @@ -260,7 +277,7 @@ if b, err := strconv.ParseBool(s); err == nil { fallbackToTopologyKeysOnly = b } else { - return nil, fmt.Errorf("invalid fallback_to_topology_keys_only: %v", err) + return nil, pgconn.NewParseConfigError(connString, "invalid fallback_to_topology_keys_only", err) } } @@ -272,7 +289,26 @@ failedHostReconnectDelaySecs = int64(reconnect) } } else { - return nil, fmt.Errorf("invalid failed_host_reconnect_delay_secs: %v", err) + return nil, pgconn.NewParseConfigError(connString, "invalid failed_host_reconnect_delay_secs", err) } } + + defaultQueryExecMode := QueryExecModeCacheStatement + if s, ok := config.RuntimeParams["default_query_exec_mode"]; ok { + delete(config.RuntimeParams, "default_query_exec_mode") + switch s { + case "cache_statement": + defaultQueryExecMode = QueryExecModeCacheStatement + case "cache_describe": + defaultQueryExecMode = QueryExecModeCacheDescribe + case "describe_exec": + defaultQueryExecMode = QueryExecModeDescribeExec + case "exec": + defaultQueryExecMode = QueryExecModeExec + case "simple_protocol": + defaultQueryExecMode = QueryExecModeSimpleProtocol + default: + return nil, pgconn.NewParseConfigError(connString, "invalid default_query_exec_mode", err) + } + } @@ -511,7 +545,7 @@ func (c *Conn) IsClosed() bool { return c.pgConn.IsClosed() } -func (c *Conn) die(err error) { +func (c *Conn) die() { if c.IsClosed() { return } @@ -679,14 +713,6 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription return result.CommandTag, result.Err } -type unknownArgumentTypeQueryExecModeExecError struct { - arg any -} - -func (e *unknownArgumentTypeQueryExecModeExecError) Error() string { - return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg) -} - func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) { err := c.eqb.Build(c.typeMap, nil, args) if err != nil { @@ -733,7 +759,7 @@ const ( // to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the // statement description on the first round trip and then uses it to execute the query on the second round trip. This // may cause problems with connection poolers that switch the underlying connection between round trips. It is safe - // even when the the database schema is modified concurrently. + // even when the database schema is modified concurrently. QueryExecModeDescribeExec // Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol @@ -741,21 +767,33 @@ // registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are // unregistered or ambiguous. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know // the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot. + // + // On rare occasions user defined types may behave differently when encoded in the text format instead of the binary + // format.
For example, this could happen if a "type RomanNumeral int32" implements fmt.Stringer to format integers as + // Roman numerals (e.g. 7 is VII). The binary format would properly encode the integer 7 as the binary value for 7. + // But the text format would encode the integer 7 as the string "VII". As QueryExecModeExec uses the text format, it + // is possible that changing query mode from another mode to QueryExecModeExec could change the behavior of the query. + // This should not occur with types pgx supports directly and can be avoided by registering the types with + // pgtype.Map.RegisterDefaultPgType and implementing the appropriate type interfaces. In the cas of RomanNumeral, it + // should implement pgtype.Int64Valuer. QueryExecModeExec - // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. - // Queries are executed in a single round trip. Type mappings can be registered with + // Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. This is + // especially significant for []byte values. []byte values are encoded as PostgreSQL bytea. string must be used + // instead for text type values including json and jsonb. Type mappings can be registered with // pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous. - // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use - // a map[string]string directly as an argument. This mode cannot. + // e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a + // map[string]string directly as an argument. This mode cannot. Queries are executed in a single round trip. 
// - // QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor - // exceptions such as behavior when multiple result returning queries are erroneously sent in a single string. + // QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec. This includes + // the warning regarding differences in text format and binary format encoding with user defined types. There may be + // other minor exceptions such as behavior when multiple result returning queries are erroneously sent in a single + // string. // // QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer - // QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol - // should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does - // not support the extended protocol. + // QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol should + // only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does not + // support the extended protocol. QueryExecModeSimpleProtocol ) @@ -864,7 +902,6 @@ optionLoop: } c.eqb.reset() - anynil.NormalizeSlice(args) rows := c.getRows(ctx, sql, args) var err error @@ -954,7 +991,6 @@ func (c *Conn) getStatementDescription( mode QueryExecMode, sql string, ) (sd *pgconn.StatementDescription, err error) { - switch mode { case QueryExecModeCacheStatement: if c.statementCache == nil { @@ -997,6 +1033,9 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row { // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection // is used again. 
+// +// Depending on the QueryExecMode, all queries may be prepared before any are executed. This means that creating a table +// and using it in a subsequent query in the same batch can fail. func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) { if c.batchTracer != nil { ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b}) @@ -1219,47 +1258,64 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d // Prepare any needed queries if len(distinctNewQueries) > 0 { - for _, sd := range distinctNewQueries { - pipeline.SendPrepare(sd.Name, sd.SQL, nil) - } + err := func() (err error) { + for _, sd := range distinctNewQueries { + pipeline.SendPrepare(sd.Name, sd.SQL, nil) + } - err := pipeline.Sync() - if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} - } + // Store all statements we are preparing into the cache. It's fine if it overflows because HandleInvalidated will + // clean them up later. + if sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Put(sd) + } + } + + // If something goes wrong preparing the statements, we need to invalidate the cache entries we just added. + defer func() { + if err != nil && sdCache != nil { + for _, sd := range distinctNewQueries { + sdCache.Invalidate(sd.SQL) + } + } + }() + + err = pipeline.Sync() + if err != nil { + return err + } + + for _, sd := range distinctNewQueries { + results, err := pipeline.GetResults() + if err != nil { + return err + } + + resultSD, ok := results.(*pgconn.StatementDescription) + if !ok { + return fmt.Errorf("expected statement description, got %T", results) + } + + // Fill in the previously empty / pending statement descriptions. 
+ sd.ParamOIDs = resultSD.ParamOIDs + sd.Fields = resultSD.Fields + } - for _, sd := range distinctNewQueries { results, err := pipeline.GetResults() if err != nil { - return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} + return err } - resultSD, ok := results.(*pgconn.StatementDescription) + _, ok := results.(*pgconn.PipelineSync) if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true} + return fmt.Errorf("expected sync, got %T", results) } - // Fill in the previously empty / pending statement descriptions. - sd.ParamOIDs = resultSD.ParamOIDs - sd.Fields = resultSD.Fields - } - - results, err := pipeline.GetResults() + return nil + }() if err != nil { return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} } - - _, ok := results.(*pgconn.PipelineSync) - if !ok { - return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true} - } - } - - // Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later. - if sdCache != nil { - for _, sd := range distinctNewQueries { - sdCache.Put(sd) - } } // Queue the queries. 
@@ -1463,7 +1519,7 @@ order by attnum`, } func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error { - if c.pgConn.TxStatus() != 'I' { + if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' { return nil } diff --git a/conn_test.go b/conn_test.go index d81b1614c..311b8c536 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "bytes" "context" + "database/sql" "os" "strings" "sync" @@ -411,7 +412,6 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) { if commandTag.String() != "INSERT 0 1" { t.Fatalf("Unexpected results from Exec: %v", commandTag) } - } func TestPrepare(t *testing.T) { @@ -1088,7 +1088,7 @@ func TestLoadRangeType(t *testing.T) { conn.TypeMap().RegisterType(newRangeType) conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange") - var inputRangeType = pgtype.Range[float64]{ + inputRangeType := pgtype.Range[float64]{ Lower: 1.0, Upper: 2.0, LowerType: pgtype.Inclusive, @@ -1128,7 +1128,7 @@ func TestLoadMultiRangeType(t *testing.T) { conn.TypeMap().RegisterType(newMultiRangeType) conn.TypeMap().RegisterDefaultPgType(pgtype.Multirange[pgtype.Range[float64]]{}, "examplefloatmultirange") - var inputMultiRangeType = pgtype.Multirange[pgtype.Range[float64]]{ + inputMultiRangeType := pgtype.Multirange[pgtype.Range[float64]]{ { Lower: 1.0, Upper: 2.0, @@ -1292,6 +1292,177 @@ func TestStmtCacheInvalidationTx(t *testing.T) { ensureConnValid(t, conn) } +func TestStmtCacheInvalidationConnWithBatch(t *testing.T) { + ctx := context.Background() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Test fails due to different CRDB behavior") + } + + // create a table and fill it with some data + _, err := conn.Exec(ctx, ` + DROP TABLE IF EXISTS drop_cols; + CREATE TABLE drop_cols ( + id SERIAL PRIMARY KEY NOT NULL, + f1 int NOT NULL, + f2 int NOT NULL + ); + `) 
+ require.NoError(t, err) + _, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)") + require.NoError(t, err) + + getSQL := "SELECT * FROM drop_cols WHERE id = $1" + + // This query will populate the statement cache. We don't care about the result. + rows, err := conn.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + + // Now, change the schema of the table out from under the statement, making it invalid. + _, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") + require.NoError(t, err) + + // We must get an error the first time we try to re-execute a bad statement. + // It is up to the application to determine if it wants to try again. We punt to + // the application because there is no clear recovery path in the case of failed transactions + // or batch operations and because automatic retry is tricky and we don't want to get + // it wrong at such an importaint layer of the stack. + batch := &pgx.Batch{} + batch.Queue(getSQL, 1) + br := conn.SendBatch(ctx, batch) + rows, err = br.Query() + require.Error(t, err) + rows.Next() + nextErr := rows.Err() + rows.Close() + err = br.Close() + require.Error(t, err) + for _, err := range []error{nextErr, rows.Err()} { + if err == nil { + t.Fatal(`expected "cached plan must not change result type": no error`) + } + if !strings.Contains(err.Error(), "cached plan must not change result type") { + t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error()) + } + } + + // On retry, the statement should have been flushed from the cache. 
+ batch = &pgx.Batch{} + batch.Queue(getSQL, 1) + br = conn.SendBatch(ctx, batch) + rows, err = br.Query() + require.NoError(t, err) + rows.Next() + err = rows.Err() + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + err = br.Close() + require.NoError(t, err) + + ensureConnValid(t, conn) +} + +func TestStmtCacheInvalidationTxWithBatch(t *testing.T) { + ctx := context.Background() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip("Server has non-standard prepare in errored transaction behavior (https://github.com/cockroachdb/cockroach/issues/84140)") + } + + // create a table and fill it with some data + _, err := conn.Exec(ctx, ` + DROP TABLE IF EXISTS drop_cols; + CREATE TABLE drop_cols ( + id SERIAL PRIMARY KEY NOT NULL, + f1 int NOT NULL, + f2 int NOT NULL + ); + `) + require.NoError(t, err) + _, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)") + require.NoError(t, err) + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + + getSQL := "SELECT * FROM drop_cols WHERE id = $1" + + // This query will populate the statement cache. We don't care about the result. + rows, err := tx.Query(ctx, getSQL, 1) + require.NoError(t, err) + rows.Close() + require.NoError(t, rows.Err()) + + // Now, change the schema of the table out from under the statement, making it invalid. + _, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1") + require.NoError(t, err) + + // We must get an error the first time we try to re-execute a bad statement. + // It is up to the application to determine if it wants to try again. We punt to + // the application because there is no clear recovery path in the case of failed transactions + // or batch operations and because automatic retry is tricky and we don't want to get + // it wrong at such an importaint layer of the stack. 
+	batch := &pgx.Batch{}
+	batch.Queue(getSQL, 1)
+	br := tx.SendBatch(ctx, batch)
+	rows, err = br.Query()
+	require.Error(t, err)
+	rows.Next()
+	nextErr := rows.Err()
+	rows.Close()
+	err = br.Close()
+	require.Error(t, err)
+	for _, err := range []error{nextErr, rows.Err()} {
+		if err == nil {
+			t.Fatal(`expected "cached plan must not change result type": no error`)
+		}
+		if !strings.Contains(err.Error(), "cached plan must not change result type") {
+			t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
+		}
+	}
+
+	batch = &pgx.Batch{}
+	batch.Queue(getSQL, 1)
+	br = tx.SendBatch(ctx, batch)
+	rows, err = br.Query()
+	require.Error(t, err)
+	rows.Close()
+	err = rows.Err()
+	// Retries within the same transaction are errors (really anything except a rollback
+	// will be an error in this transaction).
+	require.Error(t, err)
+	rows.Close()
+	err = br.Close()
+	require.Error(t, err)
+
+	err = tx.Rollback(ctx)
+	require.NoError(t, err)
+
+	// once we've rolled back, retries will work
+	batch = &pgx.Batch{}
+	batch.Queue(getSQL, 1)
+	br = conn.SendBatch(ctx, batch)
+	rows, err = br.Query()
+	require.NoError(t, err)
+	rows.Next()
+	err = rows.Err()
+	require.NoError(t, err)
+	rows.Close()
+	err = br.Close()
+	require.NoError(t, err)
+
+	ensureConnValid(t, conn)
+}
+
 func TestInsertDurationInterval(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
 	defer cancel()
@@ -1369,3 +1540,51 @@ func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) {
 		require.EqualValues(t, 1, n)
 	})
 }
+
+// https://github.com/jackc/pgx/issues/1847
+func TestConnDeallocateInvalidatedCachedStatementsInTransactionWithBatch(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+
+	connString := os.Getenv("PGX_TEST_DATABASE")
+	config := mustParseConfig(t, connString)
+	config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
+	config.StatementCacheCapacity = 2
+
+	conn, err := pgx.ConnectConfig(ctx, config)
+	require.NoError(t, err)
+
+	tx, err := conn.Begin(ctx)
+	require.NoError(t, err)
+	defer tx.Rollback(ctx)
+
+	_, err = tx.Exec(ctx, "select $1::int + 1", 1)
+	require.NoError(t, err)
+
+	_, err = tx.Exec(ctx, "select $1::int + 2", 1)
+	require.NoError(t, err)
+
+	// This should invalidate the first cached statement.
+	_, err = tx.Exec(ctx, "select $1::int + 3", 1)
+	require.NoError(t, err)
+
+	batch := &pgx.Batch{}
+	batch.Queue("select $1::int + 1", 1)
+	err = tx.SendBatch(ctx, batch).Close()
+	require.NoError(t, err)
+
+	err = tx.Rollback(ctx)
+	require.NoError(t, err)
+
+	ensureConnValid(t, conn)
+}
+
+func TestErrNoRows(t *testing.T) {
+	t.Parallel()
+
+	// ensure we preserve old error message
+	require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
+
+	require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
+}
diff --git a/copy_from_test.go b/copy_from_test.go
index 87ba169c1..e123cf5db 100644
--- a/copy_from_test.go
+++ b/copy_from_test.go
@@ -76,7 +76,6 @@ func TestConnCopyWithAllQueryExecModes(t *testing.T) {
 }
 
 func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) {
-
 	for _, mode := range pgxtest.KnownOIDQueryExecModes {
 		t.Run(mode.String(), func(t *testing.T) {
 			ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
diff --git a/derived_types.go b/derived_types.go
new file mode 100644
index 000000000..72c0a2423
--- /dev/null
+++ b/derived_types.go
@@ -0,0 +1,256 @@
+package pgx
+
+import (
+	"context"
+	"fmt"
+	"regexp"
+	"strconv"
+	"strings"
+
+	"github.com/jackc/pgx/v5/pgtype"
+)
+
+/*
+buildLoadDerivedTypesSQL generates the correct query for retrieving type information.
+
+	pgVersion: the major version of the PostgreSQL server
+	typeNames: the names of the types to load. If nil, load all types.
+*/ +func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string { + supportsMultirange := (pgVersion >= 14) + var typeNamesClause string + + if typeNames == nil { + // This should not occur; this will not return any types + typeNamesClause = "= ''" + } else { + typeNamesClause = "= ANY($1)" + } + parts := make([]string, 0, 10) + + // Each of the type names provided might be found in pg_class or pg_type. + // Additionally, it may or may not include a schema portion. + parts = append(parts, ` +WITH RECURSIVE +-- find the OIDs in pg_class which match one of the provided type names +selected_classes(oid,reltype) AS ( + -- this query uses the namespace search path, so will match type names without a schema prefix + SELECT pg_class.oid, pg_class.reltype + FROM pg_catalog.pg_class + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace + WHERE pg_catalog.pg_table_is_visible(pg_class.oid) + AND relname `, typeNamesClause, ` +UNION ALL + -- this query will only match type names which include the schema prefix + SELECT pg_class.oid, pg_class.reltype + FROM pg_class + INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid) + WHERE nspname || '.' || relname `, typeNamesClause, ` +), +selected_types(oid) AS ( + -- collect the OIDs from pg_types which correspond to the selected classes + SELECT reltype AS oid + FROM selected_classes +UNION ALL + -- as well as any other type names which match our criteria + SELECT pg_type.oid + FROM pg_type + LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid) + WHERE typname `, typeNamesClause, ` + OR nspname || '.' 
|| typname `, typeNamesClause, ` +), +-- this builds a parent/child mapping of objects, allowing us to know +-- all the child (ie: dependent) types that a parent (type) requires +-- As can be seen, there are 3 ways this can occur (the last of which +-- is due to being a composite class, where the composite fields are children) +pc(parent, child) AS ( + SELECT parent.oid, parent.typelem + FROM pg_type parent + WHERE parent.typtype = 'b' AND parent.typelem != 0 +UNION ALL + SELECT parent.oid, parent.typbasetype + FROM pg_type parent + WHERE parent.typtypmod = -1 AND parent.typbasetype != 0 +UNION ALL + SELECT pg_type.oid, atttypid + FROM pg_attribute + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) + WHERE NOT attisdropped + AND attnum > 0 +), +-- Now construct a recursive query which includes a 'depth' element. +-- This is used to ensure that the "youngest" children are registered before +-- their parents. +relationships(parent, child, depth) AS ( + SELECT DISTINCT 0::OID, selected_types.oid, 0 + FROM selected_types +UNION ALL + SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1 + FROM selected_classes c + inner join pg_type ON (c.reltype = pg_type.oid) + inner join pg_attribute on (c.oid = pg_attribute.attrelid) +UNION ALL + SELECT pc.parent, pc.child, relationships.depth + 1 + FROM pc + INNER JOIN relationships ON (pc.parent = relationships.child) +), +-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration +composite AS ( + SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids + FROM pg_attribute + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) + WHERE NOT attisdropped + AND attnum > 0 + GROUP BY pg_type.oid +) +-- Bring together this information, showing all the information which 
might possibly be required +-- to complete the registration, applying filters to only show the items which relate to the selected +-- types/classes. +SELECT typname, + pg_namespace.nspname, + typtype, + typbasetype, + typelem, + pg_type.oid,`) + if supportsMultirange { + parts = append(parts, ` + COALESCE(multirange.rngtypid, 0) AS rngtypid,`) + } else { + parts = append(parts, ` + 0 AS rngtypid,`) + } + parts = append(parts, ` + COALESCE(pg_range.rngsubtype, 0) AS rngsubtype, + attnames, atttypids + FROM relationships + INNER JOIN pg_type ON (pg_type.oid = relationships.child) + LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`) + if supportsMultirange { + parts = append(parts, ` + LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`) + } + + parts = append(parts, ` + LEFT OUTER JOIN composite USING (oid) + LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid) + WHERE NOT (typtype = 'b' AND typelem = 0)`) + parts = append(parts, ` + GROUP BY typname, pg_namespace.nspname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`) + if supportsMultirange { + parts = append(parts, ` + multirange.rngtypid,`) + } + parts = append(parts, ` + attnames, atttypids + ORDER BY MAX(depth) desc, typname;`) + return strings.Join(parts, "") +} + +type derivedTypeInfo struct { + Oid, Typbasetype, Typelem, Rngsubtype, Rngtypid uint32 + TypeName, Typtype, NspName string + Attnames []string + Atttypids []uint32 +} + +// LoadTypes performs a single (complex) query, returning all the required +// information to register the named types, as well as any other types directly +// or indirectly required to complete the registration. +// The result of this call can be passed into RegisterTypes to complete the process. 
+func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) { + m := c.TypeMap() + if len(typeNames) == 0 { + return nil, fmt.Errorf("No type names were supplied.") + } + + // Disregard server version errors. This will result in + // the SQL not support recent structures such as multirange + serverVersion, _ := serverVersion(c) + sql := buildLoadDerivedTypesSQL(serverVersion, typeNames) + rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames) + if err != nil { + return nil, fmt.Errorf("While generating load types query: %w", err) + } + defer rows.Close() + result := make([]*pgtype.Type, 0, 100) + for rows.Next() { + ti := derivedTypeInfo{} + err = rows.Scan(&ti.TypeName, &ti.NspName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids) + if err != nil { + return nil, fmt.Errorf("While scanning type information: %w", err) + } + var type_ *pgtype.Type + switch ti.Typtype { + case "b": // array + dt, ok := m.TypeForOID(ti.Typelem) + if !ok { + return nil, fmt.Errorf("Array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName) + } + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}} + case "c": // composite + var fields []pgtype.CompositeCodecField + for i, fieldName := range ti.Attnames { + dt, ok := m.TypeForOID(ti.Atttypids[i]) + if !ok { + return nil, fmt.Errorf("Unknown field for composite type %q: field %q (OID %v) is not already registered.", ti.TypeName, fieldName, ti.Atttypids[i]) + } + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}} + case "d": // domain + dt, ok := m.TypeForOID(ti.Typbasetype) + if !ok { + return nil, fmt.Errorf("Domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName) + } + + type_ = 
&pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec} + case "e": // enum + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}} + case "r": // range + dt, ok := m.TypeForOID(ti.Rngsubtype) + if !ok { + return nil, fmt.Errorf("Range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}} + case "m": // multirange + dt, ok := m.TypeForOID(ti.Rngtypid) + if !ok { + return nil, fmt.Errorf("Multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName) + } + + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}} + default: + return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName) + } + + // the type_ is imposible to be null + m.RegisterType(type_) + if ti.NspName != "" { + nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec} + m.RegisterType(nspType) + result = append(result, nspType) + } + result = append(result, type_) + } + return result, nil +} + +// serverVersion returns the postgresql server version. 
+func serverVersion(c *Conn) (int64, error) { + serverVersionStr := c.PgConn().ParameterStatus("server_version") + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) + // if not PostgreSQL do nothing + if serverVersionStr == "" { + return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr) + } + + version, err := strconv.ParseInt(serverVersionStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("postgres version parsing failed: %w", err) + } + return version, nil +} diff --git a/derived_types_test.go b/derived_types_test.go new file mode 100644 index 000000000..6fb6e1d36 --- /dev/null +++ b/derived_types_test.go @@ -0,0 +1,40 @@ +package pgx_test + +import ( + "context" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" +) + +func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, ` +drop type if exists dtype_test; +drop domain if exists anotheruint64; + +create domain anotheruint64 as numeric(20,0); +create type dtype_test as ( + a text, + b int4, + c anotheruint64, + d anotheruint64[] +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop type dtype_test") + defer conn.Exec(ctx, "drop domain anotheruint64") + + types, err := conn.LoadTypes(ctx, []string{"dtype_test"}) + require.NoError(t, err) + require.Len(t, types, 6) + require.Equal(t, types[0].Name, "public.anotheruint64") + require.Equal(t, types[1].Name, "anotheruint64") + require.Equal(t, types[2].Name, "public._anotheruint64") + require.Equal(t, types[3].Name, "_anotheruint64") + require.Equal(t, types[4].Name, "public.dtype_test") + require.Equal(t, types[5].Name, "dtype_test") + }) +} diff --git a/doc.go b/doc.go index 0590d2d43..6f5c917b9 100644 --- 
a/doc.go +++ b/doc.go @@ -11,9 +11,10 @@ The primary way of establishing a connection is with [pgx.Connect]: conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) -The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified -here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the connection with -[ConnectConfig] to configure settings such as tracing that cannot be configured with a connection string. +The database connection string can be in URL or key/value format. Both PostgreSQL settings and pgx settings can be +specified here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the +connection with [ConnectConfig] to configure settings such as tracing that cannot be configured with a connection +string. Connection Pool @@ -23,8 +24,8 @@ github.com/yugabyte/pgx/v5/pgxpool for a concurrency safe connection pool. Query Interface pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and -ForEachRow that are a simpler and safer way of processing rows than manually calling rows.Next(), rows.Scan, and -rows.Err(). +ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(), +rows.Scan, and rows.Err(). CollectRows can be used collect all returned rows into a slice. @@ -155,17 +156,34 @@ When you already have a typed array using CopyFromSlice can be more convenient. CopyFrom can be faster than an insert with as few as 5 rows. +Listen and Notify + +pgx can listen to the PostgreSQL notification system with the `Conn.WaitForNotification` method. It blocks until a +notification is received or the context is canceled. 
+ + _, err := conn.Exec(context.Background(), "listen channelname") + if err != nil { + return err + } + + notification, err := conn.WaitForNotification(context.Background()) + if err != nil { + return err + } + // do something with notification + + Tracing and Logging -pgx supports tracing by setting ConnConfig.Tracer. +pgx supports tracing by setting ConnConfig.Tracer. To combine several tracers you can use the multitracer.Tracer. In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer. -For debug tracing of the actual PostgreSQL wire protocol messages see github.com/yugabyte/pgx/v5/pgproto3. +For debug tracing of the actual PostgreSQL wire protocol messages see github.com/jackc/pgx/v5/pgproto3. Lower Level PostgreSQL Functionality -github.com/yugabyte/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn in +github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn is implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer. 
PgBouncer diff --git a/extended_query_builder.go b/extended_query_builder.go index dc89a4973..5b5531cc7 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -1,10 +1,8 @@ package pgx import ( - "database/sql/driver" "fmt" - "github.com/yugabyte/pgx/v5/internal/anynil" "github.com/yugabyte/pgx/v5/pgconn" "github.com/yugabyte/pgx/v5/pgtype" ) @@ -23,10 +21,15 @@ type ExtendedQueryBuilder struct { func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error { eqb.reset() - anynil.NormalizeSlice(args) - if sd == nil { - return eqb.appendParamsForQueryExecModeExec(m, args) + for i := range args { + err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i]) + if err != nil { + err = fmt.Errorf("failed to encode args[%d]: %w", i, err) + return err + } + } + return nil } if len(sd.ParamOIDs) != len(args) { @@ -113,10 +116,6 @@ func (eqb *ExtendedQueryBuilder) reset() { } func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) { - if anynil.Is(arg) { - return nil, nil - } - if eqb.paramValueBytes == nil { eqb.paramValueBytes = make([]byte, 0, 128) } @@ -145,74 +144,3 @@ func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui return m.FormatCodeForOID(oid) } - -// appendParamsForQueryExecModeExec appends the args to eqb. -// -// Parameters must be encoded in the text format because of differences in type conversion between timestamps and -// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the -// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both -// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL -// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date. 
-// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion -// before converting it to date. This means that dates can be shifted by one day. In text format without that double -// type conversion it takes the date directly and ignores time zone (i.e. it works). -// -// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is -// no way to safely use binary or to specify the parameter OIDs. -func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error { - for _, arg := range args { - if arg == nil { - err := eqb.appendParam(m, 0, TextFormatCode, arg) - if err != nil { - return err - } - } else { - dt, ok := m.TypeForValue(arg) - if !ok { - var tv pgtype.TextValuer - if tv, ok = arg.(pgtype.TextValuer); ok { - t, err := tv.TextValue() - if err != nil { - return err - } - - dt, ok = m.TypeForOID(pgtype.TextOID) - if ok { - arg = t - } - } - } - if !ok { - var dv driver.Valuer - if dv, ok = arg.(driver.Valuer); ok { - v, err := dv.Value() - if err != nil { - return err - } - dt, ok = m.TypeForValue(v) - if ok { - arg = v - } - } - } - if !ok { - var str fmt.Stringer - if str, ok = arg.(fmt.Stringer); ok { - dt, ok = m.TypeForOID(pgtype.TextOID) - if ok { - arg = str.String() - } - } - } - if !ok { - return &unknownArgumentTypeQueryExecModeExecError{arg: arg} - } - err := eqb.appendParam(m, dt.OID, TextFormatCode, arg) - if err != nil { - return err - } - } - } - - return nil -} diff --git a/go.mod b/go.mod index 61ef3416c..2a48d7eac 100644 --- a/go.mod +++ b/go.mod @@ -1,22 +1,21 @@ module github.com/yugabyte/pgx/v5 -go 1.19 +go 1.23.0 require ( github.com/jackc/pgpassfile v1.0.0 - github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 - github.com/jackc/puddle/v2 v2.2.1 - github.com/stretchr/testify v1.8.4 - golang.org/x/crypto v0.20.0 - golang.org/x/text v0.14.0 + github.com/jackc/pgservicefile 
v0.0.0-20240606120523-5a60cdf6a761 + github.com/jackc/puddle/v2 v2.2.2 + github.com/stretchr/testify v1.8.1 + golang.org/x/crypto v0.37.0 + golang.org/x/sync v0.13.0 + golang.org/x/text v0.24.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/kr/pretty v0.1.0 // indirect - github.com/kr/text v0.2.0 // indirect + github.com/kr/pretty v0.3.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - golang.org/x/sync v0.6.0 // indirect - gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index fee89d19a..e112808c0 100644 --- a/go.sum +++ b/go.sum @@ -4,32 +4,42 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -golang.org/x/crypto v0.20.0 h1:jmAMJJZXr5KiCw05dfYK9QnqaqKLYXijU23lsEdcQqg= -golang.org/x/crypto v0.20.0/go.mod h1:Xwo95rrVNIoSMx9wa1JroENMToLWn3RNVrTBpLHgZPQ= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/text v0.14.0 
h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/sanitize/benchmmark.sh b/internal/sanitize/benchmmark.sh new file mode 100644 index 000000000..ec0f7b03a --- /dev/null +++ 
b/internal/sanitize/benchmmark.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +current_branch=$(git rev-parse --abbrev-ref HEAD) +if [ "$current_branch" == "HEAD" ]; then + current_branch=$(git rev-parse HEAD) +fi + +restore_branch() { + echo "Restoring original branch/commit: $current_branch" + git checkout "$current_branch" +} +trap restore_branch EXIT + +# Check if there are uncommitted changes +if ! git diff --quiet || ! git diff --cached --quiet; then + echo "There are uncommitted changes. Please commit or stash them before running this script." + exit 1 +fi + +# Ensure that at least one commit argument is passed +if [ "$#" -lt 1 ]; then + echo "Usage: $0 ... " + exit 1 +fi + +commits=("$@") +benchmarks_dir=benchmarks + +if ! mkdir -p "${benchmarks_dir}"; then + echo "Unable to create dir for benchmarks data" + exit 1 +fi + +# Benchmark results +bench_files=() + +# Run benchmark for each listed commit +for i in "${!commits[@]}"; do + commit="${commits[i]}" + git checkout "$commit" || { + echo "Failed to checkout $commit" + exit 1 + } + + # Sanitized commmit message + commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_') + + # Benchmark data will go there + bench_file="${benchmarks_dir}/${i}_${commit_message}.bench" + + if ! go test -bench=. 
-count=10 >"$bench_file"; then + echo "Benchmarking failed for commit $commit" + exit 1 + fi + + bench_files+=("$bench_file") +done + +# go install golang.org/x/perf/cmd/benchstat[@latest] +benchstat "${bench_files[@]}" diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index f9091cd48..033dde2b8 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -4,8 +4,10 @@ import ( "bytes" "encoding/hex" "fmt" + "slices" "strconv" "strings" + "sync" "time" "unicode/utf8" ) @@ -24,18 +26,33 @@ type Query struct { // https://github.com/jackc/pgx/issues/1380 const replacementcharacterwidth = 3 +const maxBufSize = 16384 // 16 Ki + +var bufPool = &pool[*bytes.Buffer]{ + new: func() *bytes.Buffer { + return &bytes.Buffer{} + }, + reset: func(b *bytes.Buffer) bool { + n := b.Len() + b.Reset() + return n < maxBufSize + }, +} + +var null = []byte("null") + func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) - buf := &bytes.Buffer{} + buf := bufPool.get() + defer bufPool.put(buf) for _, part := range q.Parts { - var str string switch part := part.(type) { case string: - str = part + buf.WriteString(part) case int: argIdx := part - 1 - + var p []byte if argIdx < 0 { return "", fmt.Errorf("first sql argument must be > 0") } @@ -43,30 +60,41 @@ func (q *Query) Sanitize(args ...any) (string, error) { if argIdx >= len(args) { return "", fmt.Errorf("insufficient arguments") } + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + buf.WriteByte(' ') + arg := args[argIdx] switch arg := arg.(type) { case nil: - str = "null" + p = null case int64: - str = strconv.FormatInt(arg, 10) + p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10) case float64: - str = strconv.FormatFloat(arg, 'f', -1, 64) + p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64) case bool: - str = strconv.FormatBool(arg) + p = 
strconv.AppendBool(buf.AvailableBuffer(), arg) case []byte: - str = QuoteBytes(arg) + p = QuoteBytes(buf.AvailableBuffer(), arg) case string: - str = QuoteString(arg) + p = QuoteString(buf.AvailableBuffer(), arg) case time.Time: - str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") + p = arg.Truncate(time.Microsecond). + AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'") default: return "", fmt.Errorf("invalid arg type: %T", arg) } argUse[argIdx] = true + + buf.Write(p) + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + buf.WriteByte(' ') default: return "", fmt.Errorf("invalid Part type: %T", part) } - buf.WriteString(str) } for i, used := range argUse { @@ -78,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) { } func NewQuery(sql string) (*Query, error) { - l := &sqlLexer{ - src: sql, - stateFn: rawState, + query := &Query{} + query.init(sql) + + return query, nil +} + +var sqlLexerPool = &pool[*sqlLexer]{ + new: func() *sqlLexer { + return &sqlLexer{} + }, + reset: func(sl *sqlLexer) bool { + *sl = sqlLexer{} + return true + }, +} + +func (q *Query) init(sql string) { + parts := q.Parts[:0] + if parts == nil { + // dirty, but fast heuristic to preallocate for ~90% usecases + n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1 + parts = make([]Part, 0, n) } + l := sqlLexerPool.get() + defer sqlLexerPool.put(l) + + l.src = sql + l.stateFn = rawState + l.parts = parts + for l.stateFn != nil { l.stateFn = l.stateFn(l) } - query := &Query{Parts: l.parts} - - return query, nil + q.Parts = l.parts } -func QuoteString(str string) string { - return "'" + strings.ReplaceAll(str, "'", "''") + "'" +func QuoteString(dst []byte, str string) []byte { + const quote = '\'' + + // Preallocate space for the worst case scenario + dst = slices.Grow(dst, len(str)*2+2) + + // Add opening quote + dst = append(dst, 
quote) + + // Iterate through the string without allocating + for i := 0; i < len(str); i++ { + if str[i] == quote { + dst = append(dst, quote, quote) + } else { + dst = append(dst, str[i]) + } + } + + // Add closing quote + dst = append(dst, quote) + + return dst } -func QuoteBytes(buf []byte) string { - return `'\x` + hex.EncodeToString(buf) + "'" +func QuoteBytes(dst, buf []byte) []byte { + if len(buf) == 0 { + return append(dst, `'\x'`...) + } + + // Calculate required length + requiredLen := 3 + hex.EncodedLen(len(buf)) + 1 + + // Ensure dst has enough capacity + if cap(dst)-len(dst) < requiredLen { + newDst := make([]byte, len(dst), len(dst)+requiredLen) + copy(newDst, dst) + dst = newDst + } + + // Record original length and extend slice + origLen := len(dst) + dst = dst[:origLen+requiredLen] + + // Add prefix + dst[origLen] = '\'' + dst[origLen+1] = '\\' + dst[origLen+2] = 'x' + + // Encode bytes directly into dst + hex.Encode(dst[origLen+3:len(dst)-1], buf) + + // Add suffix + dst[len(dst)-1] = '\'' + + return dst } type sqlLexer struct { @@ -315,13 +416,52 @@ func multilineCommentState(l *sqlLexer) stateFn { } } +var queryPool = &pool[*Query]{ + new: func() *Query { + return &Query{} + }, + reset: func(q *Query) bool { + n := len(q.Parts) + q.Parts = q.Parts[:0] + return n < 64 // drop too large queries + }, +} + // SanitizeSQL replaces placeholder values with args. It quotes and escapes args // as necessary. This function is only safe when standard_conforming_strings is // on. func SanitizeSQL(sql string, args ...any) (string, error) { - query, err := NewQuery(sql) - if err != nil { - return "", err - } + query := queryPool.get() + query.init(sql) + defer queryPool.put(query) + + return query.Sanitize(args...)
} + +type pool[E any] struct { + p sync.Pool + new func() E + reset func(E) bool +} + +func (pool *pool[E]) get() E { + v, ok := pool.p.Get().(E) + if !ok { + v = pool.new() + } + + return v +} + +func (p *pool[E]) put(v E) { + if p.reset(v) { + p.p.Put(v) + } +} diff --git a/internal/sanitize/sanitize_bench_test.go b/internal/sanitize/sanitize_bench_test.go new file mode 100644 index 000000000..baa742b11 --- /dev/null +++ b/internal/sanitize/sanitize_bench_test.go @@ -0,0 +1,62 @@ +// sanitize_benchmark_test.go +package sanitize_test + +import ( + "testing" + "time" + + "github.com/yugabyte/pgx/v5/internal/sanitize" +) + +var benchmarkSanitizeResult string + +const benchmarkQuery = "" + + `SELECT * + FROM "water_containers" + WHERE NOT "id" = $1 -- int64 + AND "tags" NOT IN $2 -- nil + AND "volume" > $3 -- float64 + AND "transportable" = $4 -- bool + AND position($5 IN "sign") -- bytes + AND "label" LIKE $6 -- string + AND "created_at" > $7; -- time.Time` + +var benchmarkArgs = []any{ + int64(12345), + nil, + float64(500), + true, + []byte("8BADF00D"), + "kombucha's han'dy awokowa", + time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC), +} + +func BenchmarkSanitize(b *testing.B) { + query, err := sanitize.NewQuery(benchmarkQuery) + if err != nil { + b.Fatalf("failed to create query: %v", err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...) + if err != nil { + b.Fatalf("failed to sanitize query: %v", err) + } + } +} + +var benchmarkNewSQLResult string + +func BenchmarkSanitizeSQL(b *testing.B) { + b.ReportAllocs() + var err error + for i := 0; i < b.N; i++ { + benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
+ if err != nil { + b.Fatalf("failed to sanitize SQL: %v", err) + } + } +} diff --git a/internal/sanitize/sanitize_fuzz_test.go b/internal/sanitize/sanitize_fuzz_test.go new file mode 100644 index 000000000..2f0c41223 --- /dev/null +++ b/internal/sanitize/sanitize_fuzz_test.go @@ -0,0 +1,55 @@ +package sanitize_test + +import ( + "strings" + "testing" + + "github.com/yugabyte/pgx/v5/internal/sanitize" +) + +func FuzzQuoteString(f *testing.F) { + const prefix = "prefix" + f.Add("new\nline") + f.Add("sample text") + f.Add("sample q'u'o't'e's") + f.Add("select 'quoted $42', $1") + + f.Fuzz(func(t *testing.T, input string) { + got := string(sanitize.QuoteString([]byte(prefix), input)) + want := oldQuoteString(input) + + quoted, ok := strings.CutPrefix(got, prefix) + if !ok { + t.Fatalf("result has no prefix") + } + + if want != quoted { + t.Errorf("got %q", got) + t.Fatalf("want %q", want) + } + }) +} + +func FuzzQuoteBytes(f *testing.F) { + const prefix = "prefix" + f.Add([]byte(nil)) + f.Add([]byte("\n")) + f.Add([]byte("sample text")) + f.Add([]byte("sample q'u'o't'e's")) + f.Add([]byte("select 'quoted $42', $1")) + + f.Fuzz(func(t *testing.T, input []byte) { + got := string(sanitize.QuoteBytes([]byte(prefix), input)) + want := oldQuoteBytes(input) + + quoted, ok := strings.CutPrefix(got, prefix) + if !ok { + t.Fatalf("result has no prefix") + } + + if want != quoted { + t.Errorf("got %q", got) + t.Fatalf("want %q", want) + } + }) +} diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go index a8a50b3f4..a7eb462c2 100644 --- a/internal/sanitize/sanitize_test.go +++ b/internal/sanitize/sanitize_test.go @@ -1,6 +1,8 @@ package sanitize_test import ( + "encoding/hex" + "strings" + "testing" "time" @@ -132,47 +134,57 @@ func TestQuerySanitize(t *testing.T) { { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{int64(42)}, - expected: `select 42`, + expected: `select 42 `, }, { query: sanitize.Query{Parts: 
[]sanitize.Part{"select ", 1}}, args: []any{float64(1.23)}, - expected: `select 1.23`, + expected: `select 1.23 `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{true}, - expected: `select true`, + expected: `select true `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{[]byte{0, 1, 2, 3, 255}}, - expected: `select '\x00010203ff'`, + expected: `select '\x00010203ff' `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{nil}, - expected: `select null`, + expected: `select null `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{"foobar"}, - expected: `select 'foobar'`, + expected: `select 'foobar' `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{"foo'bar"}, - expected: `select 'foo''bar'`, + expected: `select 'foo''bar' `, }, { query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, args: []any{`foo\'bar`}, - expected: `select 'foo\''bar'`, + expected: `select 'foo\''bar' `, }, { query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}}, args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, - expected: `insert '2020-03-01 23:59:59.999999Z'`, + expected: `insert '2020-03-01 23:59:59.999999Z' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}}, + args: []any{int64(-1)}, + expected: `select 1- -1 `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}}, + args: []any{float64(-1)}, + expected: `select 1- -1 `, }, } @@ -217,3 +229,55 @@ func TestQuerySanitize(t *testing.T) { } } } + +func TestQuoteString(t *testing.T) { + tc := func(name, input string) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := string(sanitize.QuoteString(nil, input)) + want := oldQuoteString(input) + + if got != want { + t.Errorf("got: %s", got) + t.Fatalf("want: %s", want) + } + }) + } + + tc("empty", "") + tc("text", "abcd") + tc("with quotes", `one's hat is 
always a cat`) +} + +// This function was used before optimizations. +// You should keep for testing purposes - we want to ensure there are no breaking changes. +func oldQuoteString(str string) string { + return "'" + strings.ReplaceAll(str, "'", "''") + "'" +} + +func TestQuoteBytes(t *testing.T) { + tc := func(name string, input []byte) { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := string(sanitize.QuoteBytes(nil, input)) + want := oldQuoteBytes(input) + + if got != want { + t.Errorf("got: %s", got) + t.Fatalf("want: %s", want) + } + }) + } + + tc("nil", nil) + tc("empty", []byte{}) + tc("text", []byte("abcd")) +} + +// This function was used before optimizations. +// You should keep for testing purposes - we want to ensure there are no breaking changes. +func oldQuoteBytes(buf []byte) string { + return `'\x` + hex.EncodeToString(buf) + "'" +} diff --git a/internal/stmtcache/lru_cache.go b/internal/stmtcache/lru_cache.go index 7d961e41e..3f33d9dce 100644 --- a/internal/stmtcache/lru_cache.go +++ b/internal/stmtcache/lru_cache.go @@ -31,7 +31,6 @@ func (c *LRUCache) Get(key string) *pgconn.StatementDescription { } return nil - } // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or diff --git a/large_objects.go b/large_objects.go index a3028b638..9d21afdce 100644 --- a/large_objects.go +++ b/large_objects.go @@ -4,6 +4,8 @@ import ( "context" "errors" "io" + + "github.com/yugabyte/pgx/v5/pgtype" ) // The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. 
See definition of @@ -115,9 +117,10 @@ func (o *LargeObject) Read(p []byte) (int, error) { expected = maxLargeObjectMessageLength } - var res []byte + res := pgtype.PreallocBytes(p[nTotal:]) err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res) - copy(p[nTotal:], res) + // We compute expected so that it always fits into p, so it should never happen + // that PreallocBytes's ScanBytes had to allocate a new slice. nTotal += len(res) if err != nil { return nTotal, err diff --git a/multitracer/tracer.go b/multitracer/tracer.go new file mode 100644 index 000000000..acff17398 --- /dev/null +++ b/multitracer/tracer.go @@ -0,0 +1,152 @@ +// Package multitracer provides a Tracer that can combine several tracers into one. +package multitracer + +import ( + "context" + + "github.com/yugabyte/pgx/v5" + "github.com/yugabyte/pgx/v5/pgxpool" +) + +// Tracer can combine several tracers into one. +// You can use New to automatically split tracers by interface. +type Tracer struct { + QueryTracers []pgx.QueryTracer + BatchTracers []pgx.BatchTracer + CopyFromTracers []pgx.CopyFromTracer + PrepareTracers []pgx.PrepareTracer + ConnectTracers []pgx.ConnectTracer + PoolAcquireTracers []pgxpool.AcquireTracer + PoolReleaseTracers []pgxpool.ReleaseTracer +} + +// New returns new Tracer from tracers with automatically split tracers by interface. 
+func New(tracers ...pgx.QueryTracer) *Tracer { + var t Tracer + + for _, tracer := range tracers { + t.QueryTracers = append(t.QueryTracers, tracer) + + if batchTracer, ok := tracer.(pgx.BatchTracer); ok { + t.BatchTracers = append(t.BatchTracers, batchTracer) + } + + if copyFromTracer, ok := tracer.(pgx.CopyFromTracer); ok { + t.CopyFromTracers = append(t.CopyFromTracers, copyFromTracer) + } + + if prepareTracer, ok := tracer.(pgx.PrepareTracer); ok { + t.PrepareTracers = append(t.PrepareTracers, prepareTracer) + } + + if connectTracer, ok := tracer.(pgx.ConnectTracer); ok { + t.ConnectTracers = append(t.ConnectTracers, connectTracer) + } + + if poolAcquireTracer, ok := tracer.(pgxpool.AcquireTracer); ok { + t.PoolAcquireTracers = append(t.PoolAcquireTracers, poolAcquireTracer) + } + + if poolReleaseTracer, ok := tracer.(pgxpool.ReleaseTracer); ok { + t.PoolReleaseTracers = append(t.PoolReleaseTracers, poolReleaseTracer) + } + } + + return &t +} + +func (t *Tracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + for _, tracer := range t.QueryTracers { + ctx = tracer.TraceQueryStart(ctx, conn, data) + } + + return ctx +} + +func (t *Tracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + for _, tracer := range t.QueryTracers { + tracer.TraceQueryEnd(ctx, conn, data) + } +} + +func (t *Tracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + for _, tracer := range t.BatchTracers { + ctx = tracer.TraceBatchStart(ctx, conn, data) + } + + return ctx +} + +func (t *Tracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { + for _, tracer := range t.BatchTracers { + tracer.TraceBatchQuery(ctx, conn, data) + } +} + +func (t *Tracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + for _, tracer := range t.BatchTracers { + tracer.TraceBatchEnd(ctx, conn, data) + 
} +} + +func (t *Tracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + for _, tracer := range t.CopyFromTracers { + ctx = tracer.TraceCopyFromStart(ctx, conn, data) + } + + return ctx +} + +func (t *Tracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + for _, tracer := range t.CopyFromTracers { + tracer.TraceCopyFromEnd(ctx, conn, data) + } +} + +func (t *Tracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + for _, tracer := range t.PrepareTracers { + ctx = tracer.TracePrepareStart(ctx, conn, data) + } + + return ctx +} + +func (t *Tracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + for _, tracer := range t.PrepareTracers { + tracer.TracePrepareEnd(ctx, conn, data) + } +} + +func (t *Tracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + for _, tracer := range t.ConnectTracers { + ctx = tracer.TraceConnectStart(ctx, data) + } + + return ctx +} + +func (t *Tracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { + for _, tracer := range t.ConnectTracers { + tracer.TraceConnectEnd(ctx, data) + } +} + +func (t *Tracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context { + for _, tracer := range t.PoolAcquireTracers { + ctx = tracer.TraceAcquireStart(ctx, pool, data) + } + + return ctx +} + +func (t *Tracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { + for _, tracer := range t.PoolAcquireTracers { + tracer.TraceAcquireEnd(ctx, pool, data) + } +} + +func (t *Tracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) { + for _, tracer := range t.PoolReleaseTracers { + tracer.TraceRelease(pool, data) + } +} diff --git a/multitracer/tracer_test.go b/multitracer/tracer_test.go new 
file mode 100644 index 000000000..aa5ccd080 --- /dev/null +++ b/multitracer/tracer_test.go @@ -0,0 +1,115 @@ +package multitracer_test + +import ( + "context" + "testing" + + "github.com/yugabyte/pgx/v5" + "github.com/yugabyte/pgx/v5/multitracer" + "github.com/yugabyte/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" +) + +type testFullTracer struct{} + +func (tt *testFullTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + return ctx +} + +func (tt *testFullTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { +} + +func (tt *testFullTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context { + return ctx +} + +func (tt *testFullTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) { +} + +func (tt *testFullTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { +} + +func (tt *testFullTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + return ctx +} + +func (tt *testFullTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { +} + +func (tt *testFullTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context { + return ctx +} + +func (tt *testFullTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { +} + +func (tt *testFullTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context { + return ctx +} + +func (tt *testFullTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { +} + +func (tt *testFullTracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context { + return ctx +} + +func (tt *testFullTracer) TraceAcquireEnd(ctx context.Context, pool 
*pgxpool.Pool, data pgxpool.TraceAcquireEndData) { +} + +func (tt *testFullTracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) { +} + +type testCopyTracer struct{} + +func (tt *testCopyTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + return ctx +} + +func (tt *testCopyTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { +} + +func (tt *testCopyTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context { + return ctx +} + +func (tt *testCopyTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { +} + +func TestNew(t *testing.T) { + t.Parallel() + + fullTracer := &testFullTracer{} + copyTracer := &testCopyTracer{} + + mt := multitracer.New(fullTracer, copyTracer) + require.Equal( + t, + &multitracer.Tracer{ + QueryTracers: []pgx.QueryTracer{ + fullTracer, + copyTracer, + }, + BatchTracers: []pgx.BatchTracer{ + fullTracer, + }, + CopyFromTracers: []pgx.CopyFromTracer{ + fullTracer, + copyTracer, + }, + PrepareTracers: []pgx.PrepareTracer{ + fullTracer, + }, + ConnectTracers: []pgx.ConnectTracer{ + fullTracer, + }, + PoolAcquireTracers: []pgxpool.AcquireTracer{ + fullTracer, + }, + PoolReleaseTracers: []pgxpool.ReleaseTracer{ + fullTracer, + }, + }, + mt, + ) +} diff --git a/named_args.go b/named_args.go index 8367fc63a..c88991ee4 100644 --- a/named_args.go +++ b/named_args.go @@ -2,6 +2,7 @@ package pgx import ( "context" + "fmt" "strconv" "strings" "unicode/utf8" @@ -21,6 +22,34 @@ type NamedArgs map[string]any // RewriteQuery implements the QueryRewriter interface. 
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(na, sql, false) +} + +// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all +// named arguments that the sql query uses, and no extra arguments. +type StrictNamedArgs map[string]any + +// RewriteQuery implements the QueryRewriter interface. +func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) { + return rewriteQuery(sna, sql, true) +} + +type namedArg string + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. + stateFn stateFn + parts []any + + nameToOrdinal map[namedArg]int +} + +type stateFn func(*sqlLexer) stateFn + +func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) { l := &sqlLexer{ src: sql, stateFn: rawState, @@ -44,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar newArgs = make([]any, len(l.nameToOrdinal)) for name, ordinal := range l.nameToOrdinal { - newArgs[ordinal-1] = na[string(name)] + var found bool + newArgs[ordinal-1], found = na[string(name)] + if isStrict && !found { + return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name) + } } - return sb.String(), newArgs, nil -} - -type namedArg string - -type sqlLexer struct { - src string - start int - pos int - nested int // multiline comment nesting level. 
- stateFn stateFn - parts []any + if isStrict { + for name := range na { + if _, found := l.nameToOrdinal[namedArg(name)]; !found { + return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name) + } + } + } - nameToOrdinal map[namedArg]int + return sb.String(), newArgs, nil } -type stateFn func(*sqlLexer) stateFn - func rawState(l *sqlLexer) stateFn { for { r, width := utf8.DecodeRuneInString(l.src[l.pos:]) diff --git a/named_args_test.go b/named_args_test.go index 94d692f80..c4d193c4c 100644 --- a/named_args_test.go +++ b/named_args_test.go @@ -93,6 +93,18 @@ func TestNamedArgsRewriteQuery(t *testing.T) { where id = $1;`, expectedArgs: []any{int32(42)}, }, + { + sql: "extra provided argument", + namedArgs: pgx.NamedArgs{"extra": int32(1)}, + expectedSQL: "extra provided argument", + expectedArgs: []any{}, + }, + { + sql: "@missing argument", + namedArgs: pgx.NamedArgs{}, + expectedSQL: "$1 argument", + expectedArgs: []any{nil}, + }, // test comments and quotes } { @@ -102,3 +114,49 @@ func TestNamedArgsRewriteQuery(t *testing.T) { assert.Equalf(t, tt.expectedArgs, args, "%d", i) } } + +func TestStrictNamedArgsRewriteQuery(t *testing.T) { + t.Parallel() + + for i, tt := range []struct { + sql string + namedArgs pgx.StrictNamedArgs + expectedSQL string + expectedArgs []any + isExpectedError bool + }{ + { + sql: "no arguments", + namedArgs: pgx.StrictNamedArgs{}, + expectedSQL: "no arguments", + expectedArgs: []any{}, + isExpectedError: false, + }, + { + sql: "@all @matches", + namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)}, + expectedSQL: "$1 $2", + expectedArgs: []any{int32(1), int32(2)}, + isExpectedError: false, + }, + { + sql: "extra provided argument", + namedArgs: pgx.StrictNamedArgs{"extra": int32(1)}, + isExpectedError: true, + }, + { + sql: "@missing argument", + namedArgs: pgx.StrictNamedArgs{}, + isExpectedError: true, + }, + } { + sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, 
tt.sql, nil) + if tt.isExpectedError { + assert.Errorf(t, err, "%d", i) + } else { + require.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.expectedSQL, sql, "%d", i) + assert.Equalf(t, tt.expectedArgs, args, "%d", i) + } + } +} diff --git a/pgbouncer_test.go b/pgbouncer_test.go index 8141d9673..3b87f3ccd 100644 --- a/pgbouncer_test.go +++ b/pgbouncer_test.go @@ -72,5 +72,4 @@ func testPgbouncer(t *testing.T, config *pgx.ConnConfig, workers, iterations int for i := 0; i < workers; i++ { <-doneChan } - } diff --git a/pgconn/auth_scram.go b/pgconn/auth_scram.go index 7fda00a3f..45fca1a2a 100644 --- a/pgconn/auth_scram.go +++ b/pgconn/auth_scram.go @@ -263,7 +263,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte { return buf } -func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte { +func computeServerSignature(saltedPassword, authMessage []byte) []byte { serverKey := computeHMAC(saltedPassword, []byte("Server Key")) serverSignature := computeHMAC(serverKey, authMessage) buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) diff --git a/pgconn/config.go b/pgconn/config.go index f897f3f30..43fa17818 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -19,12 +19,15 @@ import ( "github.com/jackc/pgpassfile" "github.com/jackc/pgservicefile" + "github.com/yugabyte/pgx/v5/pgconn/ctxwatch" "github.com/yugabyte/pgx/v5/pgproto3" ) -type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error -type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error -type GetSSLPasswordFunc func(ctx context.Context) string +type ( + AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error + ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error + GetSSLPasswordFunc func(ctx context.Context) string +) // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A // manually initialized Config will cause ConnectConfig to panic. @@ -39,12 +42,19 @@ type Config struct { DialFunc DialFunc // e.g. net.Dialer.DialContext LookupFunc LookupFunc // e.g. net.Resolver.LookupHost BuildFrontend BuildFrontendFunc - RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) + + // BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called + // when a context passed to a PgConn method is canceled. + BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler + + RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) KerberosSrvName string KerberosSpn string Fallbacks []*FallbackConfig + SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct + // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // It can be used to validate that the server is acceptable.
If this returns an error the connection is closed and the next // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. @@ -70,7 +80,7 @@ type Config struct { // ParseConfigOptions contains options that control how a config is built such as GetSSLPassword. type ParseConfigOptions struct { - // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function + // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function // PQsetSSLKeyPassHook_OpenSSL. GetSSLPassword GetSSLPasswordFunc } @@ -112,6 +122,14 @@ type FallbackConfig struct { TLSConfig *tls.Config // nil disables TLS } +// connectOneConfig is the configuration for a single attempt to connect to a single host. +type connectOneConfig struct { + network string + address string + originalHostname string // original hostname before resolving + tlsConfig *tls.Config // nil disables TLS +} + // isAbsolutePath checks if the provided value is an absolute path either // beginning with a forward slash (as on Linux-based systems) or with a capital // letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows). @@ -146,11 +164,11 @@ func NetworkAddress(host string, port uint16) (network, address string) { // ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It // uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely -// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). -// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be -// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. +// matches the parsing behavior of libpq. 
connString may either be in URL format or keyword = value format. See +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty +// to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. // -// # Example DSN +// # Example Keyword/Value // user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca // // # Example URL @@ -163,13 +181,13 @@ func NetworkAddress(host string, port uint16) (network, address string) { // // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated // values that will be tried in order. This can be used as part of a high availability system. See -// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. // // # Example URL // postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb // // ParseConfig currently recognizes the following environment variable and their parameter key word equivalents passed -// via database URL or DSN: +// via database URL or keyword/value: // // PGHOST // PGPORT @@ -184,13 +202,15 @@ func NetworkAddress(host string, port uint16) (network, address string) { // PGSSLKEY // PGSSLROOTCERT // PGSSLPASSWORD +// PGOPTIONS // PGAPPNAME // PGCONNECT_TIMEOUT // PGTARGETSESSIONATTRS +// PGTZ // -// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. +// See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables. // -// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. 
They are +// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are // usually but not always the environment variable name downcased and without the "PG" prefix. // // Important Security Notes: @@ -198,7 +218,7 @@ func NetworkAddress(host string, port uint16) (network, address string) { // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if // not set. // -// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of +// See http://www.postgresql.org/docs/current/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of // security each sslmode provides. // // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of @@ -233,16 +253,16 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con connStringSettings := make(map[string]string) if connString != "" { var err error - // connString may be a database URL or a DSN + // connString may be a database URL or in PostgreSQL keyword/value format if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { connStringSettings, err = parseURLSettings(connString) if err != nil { return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err} } } else { - connStringSettings, err = parseDSNSettings(connString) + connStringSettings, err = parseKeywordValueSettings(connString) if err != nil { - return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as DSN", err: err} + return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err} } } } @@ -266,6 +286,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { return 
pgproto3.NewFrontend(r, w) }, + BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler { + return &DeadlineContextWatcherHandler{Conn: pgConn.conn} + }, OnPgError: func(_ *PgConn, pgErr *PgError) bool { // we want to automatically close any fatal errors if strings.EqualFold(pgErr.Severity, "FATAL") { @@ -301,6 +324,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con "sslkey": {}, "sslcert": {}, "sslrootcert": {}, + "sslnegotiation": {}, "sslpassword": {}, "sslsni": {}, "krbspn": {}, @@ -369,6 +393,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con config.Port = fallbacks[0].Port config.TLSConfig = fallbacks[0].TLSConfig config.Fallbacks = fallbacks[1:] + config.SSLNegotiation = settings["sslnegotiation"] passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) if err == nil { @@ -432,9 +457,12 @@ func parseEnvSettings() map[string]string { "PGSSLSNI": "sslsni", "PGSSLROOTCERT": "sslrootcert", "PGSSLPASSWORD": "sslpassword", + "PGSSLNEGOTIATION": "sslnegotiation", "PGTARGETSESSIONATTRS": "target_session_attrs", "PGSERVICE": "service", "PGSERVICEFILE": "servicefile", + "PGTZ": "timezone", + "PGOPTIONS": "options", } for envname, realname := range nameMap { @@ -450,14 +478,17 @@ func parseEnvSettings() map[string]string { func parseURLSettings(connString string) (map[string]string, error) { settings := make(map[string]string) - url, err := url.Parse(connString) + parsedURL, err := url.Parse(connString) if err != nil { + if urlErr := new(url.Error); errors.As(err, &urlErr) { + return nil, urlErr.Err + } return nil, err } - if url.User != nil { - settings["user"] = url.User.Username() - if password, present := url.User.Password(); present { + if parsedURL.User != nil { + settings["user"] = parsedURL.User.Username() + if password, present := parsedURL.User.Password(); present { settings["password"] = password } } @@ -465,7 +496,7 @@ func parseURLSettings(connString string) 
(map[string]string, error) { // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. var hosts []string var ports []string - for _, host := range strings.Split(url.Host, ",") { + for _, host := range strings.Split(parsedURL.Host, ",") { if host == "" { continue } @@ -491,7 +522,7 @@ func parseURLSettings(connString string) (map[string]string, error) { settings["port"] = strings.Join(ports, ",") } - database := strings.TrimLeft(url.Path, "/") + database := strings.TrimLeft(parsedURL.Path, "/") if database != "" { settings["database"] = database } @@ -500,7 +531,7 @@ func parseURLSettings(connString string) (map[string]string, error) { "dbname": "database", } - for k, v := range url.Query() { + for k, v := range parsedURL.Query() { if k2, present := nameMap[k]; present { k = k2 } @@ -517,7 +548,7 @@ func isIPOnly(host string) bool { var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} -func parseDSNSettings(s string) (map[string]string, error) { +func parseKeywordValueSettings(s string) (map[string]string, error) { settings := make(map[string]string) nameMap := map[string]string{ @@ -528,7 +559,7 @@ func parseDSNSettings(s string) (map[string]string, error) { var key, val string eqIdx := strings.IndexRune(s, '=') if eqIdx < 0 { - return nil, errors.New("invalid dsn") + return nil, errors.New("invalid keyword/value") } key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") @@ -580,7 +611,7 @@ func parseDSNSettings(s string) (map[string]string, error) { } if key == "" { - return nil, errors.New("invalid dsn") + return nil, errors.New("invalid keyword/value") } settings[key] = val @@ -626,6 +657,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P sslkey := settings["sslkey"] sslpassword := settings["sslpassword"] sslsni := settings["sslsni"] + sslnegotiation := settings["sslnegotiation"] // Match libpq default behavior if sslmode == "" { @@ -637,6 +669,43 @@ func 
configTLS(settings map[string]string, thisHost string, parseConfigOptions P tlsConfig := &tls.Config{} + if sslnegotiation == "direct" { + tlsConfig.NextProtos = []string{"postgresql"} + if sslmode == "prefer" { + sslmode = "require" + } + } + + if sslrootcert != "" { + var caCertPool *x509.CertPool + + if sslrootcert == "system" { + var err error + + caCertPool, err = x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("unable to load system certificate pool: %w", err) + } + + sslmode = "verify-full" + } else { + caCertPool = x509.NewCertPool() + + caPath := sslrootcert + caCert, err := os.ReadFile(caPath) + if err != nil { + return nil, fmt.Errorf("unable to read CA file: %w", err) + } + + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, errors.New("unable to add CA to cert pool") + } + } + + tlsConfig.RootCAs = caCertPool + tlsConfig.ClientCAs = caCertPool + } + switch sslmode { case "disable": return []*tls.Config{nil}, nil @@ -646,7 +715,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P // According to PostgreSQL documentation, if a root CA file exists, // the behavior of sslmode=require should be the same as that of verify-ca // - // See https://www.postgresql.org/docs/12/libpq-ssl.html + // See https://www.postgresql.org/docs/current/libpq-ssl.html if sslrootcert != "" { goto nextCase } @@ -694,23 +763,6 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P return nil, errors.New("sslmode is invalid") } - if sslrootcert != "" { - caCertPool := x509.NewCertPool() - - caPath := sslrootcert - caCert, err := os.ReadFile(caPath) - if err != nil { - return nil, fmt.Errorf("unable to read CA file: %w", err) - } - - if !caCertPool.AppendCertsFromPEM(caCert) { - return nil, errors.New("unable to add CA to cert pool") - } - - tlsConfig.RootCAs = caCertPool - tlsConfig.ClientCAs = caCertPool - } - if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { return nil, 
errors.New(`both "sslcert" and "sslkey" are required`) } @@ -721,6 +773,9 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P return nil, fmt.Errorf("unable to read sslkey: %w", err) } block, _ := pem.Decode(buf) + if block == nil { + return nil, errors.New("failed to decode sslkey") + } var pemKey []byte var decryptedKey []byte var decryptedError error @@ -731,8 +786,8 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P if sslpassword != "" { decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) } - //if sslpassword not provided or has decryption error when use it - //try to find sslpassword with callback function + // if sslpassword not provided or has decryption error when use it + // try to find sslpassword with callback function if sslpassword == "" || decryptedError != nil { if parseConfigOptions.GetSSLPassword != nil { sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) @@ -797,7 +852,8 @@ func parsePort(s string) (uint16, error) { } func makeDefaultDialer() *net.Dialer { - return &net.Dialer{KeepAlive: 5 * time.Minute} + // rely on GOLANG KeepAlive settings + return &net.Dialer{} } func makeDefaultResolver() *net.Resolver { @@ -824,12 +880,12 @@ func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc { // ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-write. 
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) == "on" { + if string(result[0].Rows[0][0]) == "on" { return errors.New("read only connection") } @@ -839,12 +895,12 @@ func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgC // ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=read-only. func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) != "on" { + if string(result[0].Rows[0][0]) != "on" { return errors.New("connection is not read only") } @@ -854,12 +910,12 @@ func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgCo // ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=standby. 
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) != "t" { + if string(result[0].Rows[0][0]) != "t" { return errors.New("server is not in hot standby mode") } @@ -869,12 +925,12 @@ func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgCon // ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=primary. func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) == "t" { + if string(result[0].Rows[0][0]) == "t" { return errors.New("server is in standby mode") } @@ -884,12 +940,12 @@ func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgCon // ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible // target_session_attrs=prefer-standby. 
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { - result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() - if result.Err != nil { - return result.Err + result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll() + if err != nil { + return err } - if string(result.Rows[0][0]) != "t" { + if string(result[0].Rows[0][0]) != "t" { return &NotPreferredError{err: errors.New("server is not in hot standby mode")} } diff --git a/pgconn/config_test.go b/pgconn/config_test.go index ec511ad91..17a7f5f4c 100644 --- a/pgconn/config_test.go +++ b/pgconn/config_test.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "os/user" + "path/filepath" "runtime" "strconv" "strings" @@ -336,7 +340,7 @@ func TestParseConfig(t *testing.T) { }, }, { - name: "DSN everything", + name: "Key/value everything", connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5", config: &pgconn.Config{ User: "jack", @@ -353,7 +357,7 @@ func TestParseConfig(t *testing.T) { }, }, { - name: "DSN with escaped single quote", + name: "Key/value with escaped single quote", connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable", config: &pgconn.Config{ User: "jack's", @@ -366,7 +370,7 @@ func TestParseConfig(t *testing.T) { }, }, { - name: "DSN with escaped backslash", + name: "Key/value with escaped backslash", connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable", config: &pgconn.Config{ User: "jack", @@ -379,7 +383,7 @@ func TestParseConfig(t
*testing.T) { }, }, { - name: "DSN with single quoted values", + name: "Key/value with single quoted values", connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'", config: &pgconn.Config{ User: "jack", @@ -391,7 +395,7 @@ func TestParseConfig(t *testing.T) { }, }, { - name: "DSN with single quoted value with escaped single quote", + name: "Key/value with single quoted value with escaped single quote", connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'", config: &pgconn.Config{ User: "jack's", @@ -403,7 +407,7 @@ func TestParseConfig(t *testing.T) { }, }, { - name: "DSN with empty single quoted value", + name: "Key/value with empty single quoted value", connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'", config: &pgconn.Config{ User: "jack", @@ -415,7 +419,7 @@ func TestParseConfig(t *testing.T) { }, }, { - name: "DSN with space between key and value", + name: "Key/value with space between key and value", connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'", config: &pgconn.Config{ User: "jack", @@ -491,7 +495,7 @@ func TestParseConfig(t *testing.T) { }, }, { - name: "DSN multiple hosts one port", + name: "Key/value multiple hosts one port", connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable", config: &pgconn.Config{ User: "jack", @@ -516,7 +520,7 @@ func TestParseConfig(t *testing.T) { }, }, { - name: "DSN multiple hosts multiple ports", + name: "Key/value multiple hosts multiple ports", connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable", config: &pgconn.Config{ User: "jack", @@ -566,7 +570,8 @@ func TestParseConfig(t *testing.T) { TLSConfig: &tls.Config{ InsecureSkipVerify: true, ServerName: "bar", - }}, + }, + }, { Host: "bar", Port: defaultPort, @@ -578,7 +583,8 @@ func TestParseConfig(t *testing.T) { TLSConfig: &tls.Config{ InsecureSkipVerify: true, 
ServerName: "baz", - }}, + }, + }, { Host: "baz", Port: defaultPort, @@ -772,18 +778,18 @@ func TestParseConfig(t *testing.T) { } // https://github.com/jackc/pgconn/issues/47 -func TestParseConfigDSNWithTrailingEmptyEqualDoesNotPanic(t *testing.T) { +func TestParseConfigKVWithTrailingEmptyEqualDoesNotPanic(t *testing.T) { _, err := pgconn.ParseConfig("host= user= password= port= database=") require.NoError(t, err) } -func TestParseConfigDSNLeadingEqual(t *testing.T) { +func TestParseConfigKVLeadingEqual(t *testing.T) { _, err := pgconn.ParseConfig("= user=jack") require.Error(t, err) } // https://github.com/jackc/pgconn/issues/49 -func TestParseConfigDSNTrailingBackslash(t *testing.T) { +func TestParseConfigKVTrailingBackslash(t *testing.T) { _, err := pgconn.ParseConfig(`x=x\`) require.Error(t, err) assert.Contains(t, err.Error(), "invalid backslash") @@ -931,20 +937,7 @@ func TestParseConfigEnvLibpq(t *testing.T) { } } - pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT", "PGSSLSNI"} - - savedEnv := make(map[string]string) - for _, n := range pgEnvvars { - savedEnv[n] = os.Getenv(n) - } - defer func() { - for k, v := range savedEnv { - err := os.Setenv(k, v) - if err != nil { - t.Fatalf("Unable to restore environment: %v", err) - } - } - }() + pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT", "PGSSLSNI", "PGTZ", "PGOPTIONS"} tests := []struct { name string @@ -983,6 +976,8 @@ func TestParseConfigEnvLibpq(t *testing.T) { "PGCONNECT_TIMEOUT": "10", "PGSSLMODE": "disable", "PGAPPNAME": "pgxtest", + "PGTZ": "America/New_York", + "PGOPTIONS": "-c search_path=myschema", }, config: &pgconn.Config{ Host: "123.123.123.123", @@ -992,7 +987,7 @@ func TestParseConfigEnvLibpq(t *testing.T) { Password: "baz", ConnectTimeout: 10 * time.Second, TLSConfig: nil, - RuntimeParams: map[string]string{"application_name": "pgxtest"}, + 
RuntimeParams: map[string]string{"application_name": "pgxtest", "timezone": "America/New_York", "options": "-c search_path=myschema"}, }, }, { @@ -1015,14 +1010,8 @@ func TestParseConfigEnvLibpq(t *testing.T) { } for i, tt := range tests { - for _, n := range pgEnvvars { - err := os.Unsetenv(n) - require.NoError(t, err) - } - - for k, v := range tt.envvars { - err := os.Setenv(k, v) - require.NoError(t, err) + for _, env := range pgEnvvars { + t.Setenv(env, tt.envvars[env]) } config, err := pgconn.ParseConfig("") @@ -1038,16 +1027,11 @@ func TestParseConfigReadsPgPassfile(t *testing.T) { skipOnWindows(t) t.Parallel() - tf, err := os.CreateTemp("", "") + tfName := filepath.Join(t.TempDir(), "config") + err := os.WriteFile(tfName, []byte("test1:5432:curlydb:curly:nyuknyuknyuk"), 0o600) require.NoError(t, err) - defer tf.Close() - defer os.Remove(tf.Name()) - - _, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk")) - require.NoError(t, err) - - connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name()) + connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tfName) expected := &pgconn.Config{ User: "curly", Password: "nyuknyuknyuk", @@ -1068,13 +1052,9 @@ func TestParseConfigReadsPgServiceFile(t *testing.T) { skipOnWindows(t) t.Parallel() - tf, err := os.CreateTemp("", "") - require.NoError(t, err) - - defer tf.Close() - defer os.Remove(tf.Name()) + tfName := filepath.Join(t.TempDir(), "config") - _, err = tf.Write([]byte(` + err := os.WriteFile(tfName, []byte(` [abc] host=abc.example.com port=9999 @@ -1086,7 +1066,7 @@ host = def.example.com dbname = defdb user = defuser application_name = spaced string -`)) +`), 0o600) require.NoError(t, err) defaultPort := getDefaultPort(t) @@ -1098,7 +1078,7 @@ application_name = spaced string }{ { name: "abc", - connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"), + connString: 
fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tfName, "abc"), config: &pgconn.Config{ Host: "abc.example.com", Database: "abcdb", @@ -1120,7 +1100,7 @@ application_name = spaced string }, { name: "def", - connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"), + connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tfName, "def"), config: &pgconn.Config{ Host: "def.example.com", Port: defaultPort, @@ -1142,7 +1122,7 @@ application_name = spaced string }, { name: "conn string has precedence", - connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"), + connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tfName, "abc"), config: &pgconn.Config{ Host: "other.example.com", Database: "abcdb", diff --git a/pgconn/ctxwatch/context_watcher.go b/pgconn/ctxwatch/context_watcher.go new file mode 100644 index 000000000..db8884eb8 --- /dev/null +++ b/pgconn/ctxwatch/context_watcher.go @@ -0,0 +1,80 @@ +package ctxwatch + +import ( + "context" + "sync" +) + +// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a +// time. +type ContextWatcher struct { + handler Handler + unwatchChan chan struct{} + + lock sync.Mutex + watchInProgress bool + onCancelWasCalled bool +} + +// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. +// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and +// onCancel called. +func NewContextWatcher(handler Handler) *ContextWatcher { + cw := &ContextWatcher{ + handler: handler, + unwatchChan: make(chan struct{}), + } + + return cw +} + +// Watch starts watching ctx. If ctx is canceled then the onCancel function passed to NewContextWatcher will be called. 
+func (cw *ContextWatcher) Watch(ctx context.Context) { + cw.lock.Lock() + defer cw.lock.Unlock() + + if cw.watchInProgress { + panic("Watch already in progress") + } + + cw.onCancelWasCalled = false + + if ctx.Done() != nil { + cw.watchInProgress = true + go func() { + select { + case <-ctx.Done(): + cw.handler.HandleCancel(ctx) + cw.onCancelWasCalled = true + <-cw.unwatchChan + case <-cw.unwatchChan: + } + }() + } else { + cw.watchInProgress = false + } +} + +// Unwatch stops watching the previously watched context. If the onCancel function passed to NewContextWatcher was +// called then onUnwatchAfterCancel will also be called. +func (cw *ContextWatcher) Unwatch() { + cw.lock.Lock() + defer cw.lock.Unlock() + + if cw.watchInProgress { + cw.unwatchChan <- struct{}{} + if cw.onCancelWasCalled { + cw.handler.HandleUnwatchAfterCancel() + } + cw.watchInProgress = false + } +} + +type Handler interface { + // HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the + // context that was canceled. + HandleCancel(canceledCtx context.Context) + + // HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched. 
+ HandleUnwatchAfterCancel() +} diff --git a/pgconn/ctxwatch/context_watcher_test.go b/pgconn/ctxwatch/context_watcher_test.go new file mode 100644 index 000000000..a18e7339e --- /dev/null +++ b/pgconn/ctxwatch/context_watcher_test.go @@ -0,0 +1,185 @@ +package ctxwatch_test + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgconn/ctxwatch" + "github.com/stretchr/testify/require" +) + +type testHandler struct { + handleCancel func(context.Context) + handleUnwatchAfterCancel func() +} + +func (h *testHandler) HandleCancel(ctx context.Context) { + h.handleCancel(ctx) +} + +func (h *testHandler) HandleUnwatchAfterCancel() { + h.handleUnwatchAfterCancel() +} + +func TestContextWatcherContextCancelled(t *testing.T) { + canceledChan := make(chan struct{}) + cleanupCalled := false + cw := ctxwatch.NewContextWatcher(&testHandler{ + handleCancel: func(context.Context) { + canceledChan <- struct{}{} + }, handleUnwatchAfterCancel: func() { + cleanupCalled = true + }, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + + select { + case <-canceledChan: + case <-time.NewTimer(time.Second).C: + t.Fatal("Timed out waiting for cancel func to be called") + } + + cw.Unwatch() + + require.True(t, cleanupCalled, "Cleanup func was not called") +} + +func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) { + cw := ctxwatch.NewContextWatcher(&testHandler{ + handleCancel: func(context.Context) { + t.Error("cancel func should not have been called") + }, handleUnwatchAfterCancel: func() { + t.Error("cleanup func should not have been called") + }, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cw.Unwatch() + cancel() +} + +func TestContextWatcherMultipleWatchPanics(t *testing.T) { + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) + + ctx, cancel := 
context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + defer cw.Unwatch() + + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + require.Panics(t, func() { cw.Watch(ctx2) }, "Expected panic when Watch called multiple times") +} + +func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) + cw.Unwatch() // unwatch when not / never watching + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cw.Watch(ctx) + cw.Unwatch() + cw.Unwatch() // double unwatch +} + +func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) { + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + cw.Watch(ctx) + + go cw.Unwatch() + go cw.Unwatch() + + <-ctx.Done() +} + +func TestContextWatcherStress(t *testing.T) { + var cancelFuncCalls int64 + var cleanupFuncCalls int64 + + cw := ctxwatch.NewContextWatcher(&testHandler{ + handleCancel: func(context.Context) { + atomic.AddInt64(&cancelFuncCalls, 1) + }, handleUnwatchAfterCancel: func() { + atomic.AddInt64(&cleanupFuncCalls, 1) + }, + }) + + cycleCount := 100000 + + for i := 0; i < cycleCount; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + if i%2 == 0 { + cancel() + } + + // Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix. 
+ if i%333 == 0 { + // on Windows Sleep takes more time than expected so we try to get here less frequently to avoid + // the CI takes a long time + time.Sleep(time.Nanosecond) + } + + cw.Unwatch() + if i%2 == 1 { + cancel() + } + } + + actualCancelFuncCalls := atomic.LoadInt64(&cancelFuncCalls) + actualCleanupFuncCalls := atomic.LoadInt64(&cleanupFuncCalls) + + if actualCancelFuncCalls == 0 { + t.Fatal("actualCancelFuncCalls == 0") + } + + maxCancelFuncCalls := int64(cycleCount) / 2 + if actualCancelFuncCalls > maxCancelFuncCalls { + t.Errorf("cancel func calls should be no more than %d but was %d", actualCancelFuncCalls, maxCancelFuncCalls) + } + + if actualCancelFuncCalls != actualCleanupFuncCalls { + t.Errorf("cancel func calls (%d) should be equal to cleanup func calls (%d) but was not", actualCancelFuncCalls, actualCleanupFuncCalls) + } +} + +func BenchmarkContextWatcherUncancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) + + for i := 0; i < b.N; i++ { + cw.Watch(context.Background()) + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancelled(b *testing.B) { + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) + + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithCancel(context.Background()) + cw.Watch(ctx) + cancel() + cw.Unwatch() + } +} + +func BenchmarkContextWatcherCancellable(b *testing.B) { + cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for i := 0; i < b.N; i++ { + cw.Watch(ctx) + cw.Unwatch() + } +} diff --git a/pgconn/doc.go b/pgconn/doc.go index e3242cf4e..701375019 100644 --- a/pgconn/doc.go +++ b/pgconn/doc.go @@ -5,8 +5,8 @@ nearly the same level is the C library libpq. 
Establishing a Connection -Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for -libpq style environment variables. +Use Connect to establish a connection. It accepts a connection string in URL or keyword/value format and will read the +environment for libpq style environment variables. Executing a Query @@ -20,13 +20,17 @@ result. The ReadAll method reads all query results into memory. Pipeline Mode -Pipeline mode allows sending queries without having read the results of previously sent queries. It allows -control of exactly how many and when network round trips occur. +Pipeline mode allows sending queries without having read the results of previously sent queries. It allows control of +exactly how many and when network round trips occur. Context Support -All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the -method immediately returns. In most circumstances, this will close the underlying connection. +All potentially blocking operations take a context.Context. The default behavior when a context is canceled is for the +method to immediately return. In most circumstances, this will also close the underlying connection. This behavior can +be customized by using BuildContextWatcherHandler on the Config to create a ctxwatch.Handler with different behavior. +This can be especially useful when queries that are frequently canceled and the overhead of creating new connections is +a problem. DeadlineContextWatcherHandler and CancelRequestContextWatcherHandler can be used to introduce a delay before +interrupting the query in such a way as to close the connection. The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the client to abort. 
diff --git a/pgconn/errors.go b/pgconn/errors.go index c315739a9..d968d3f03 100644 --- a/pgconn/errors.go +++ b/pgconn/errors.go @@ -12,13 +12,14 @@ import ( // SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server. func SafeToRetry(err error) bool { - if e, ok := err.(interface{ SafeToRetry() bool }); ok { - return e.SafeToRetry() + var retryableErr interface{ SafeToRetry() bool } + if errors.As(err, &retryableErr) { + return retryableErr.SafeToRetry() } return false } -// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a +// Timeout checks if err was caused by a timeout. To be specific, it is true if err was caused within pgconn by a // context.DeadlineExceeded or an implementer of net.Error where Timeout() is true. func Timeout(err error) bool { var timeoutErr *errTimeout @@ -26,26 +27,27 @@ func Timeout(err error) bool { } // PgError represents an error reported by the PostgreSQL server. See -// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for +// http://www.postgresql.org/docs/current/static/protocol-error-fields.html for // detailed field description. 
type PgError struct { - Severity string - Code string - Message string - Detail string - Hint string - Position int32 - InternalPosition int32 - InternalQuery string - Where string - SchemaName string - TableName string - ColumnName string - DataTypeName string - ConstraintName string - File string - Line int32 - Routine string + Severity string + SeverityUnlocalized string + Code string + Message string + Detail string + Hint string + Position int32 + InternalPosition int32 + InternalQuery string + Where string + SchemaName string + TableName string + ColumnName string + DataTypeName string + ConstraintName string + File string + Line int32 + Routine string } func (pe *PgError) Error() string { @@ -60,23 +62,37 @@ func (pe *PgError) SQLState() string { // ConnectError is the error returned when a connection attempt fails. type ConnectError struct { Config *Config // The configuration that was used in the connection attempt. - msg string err error } func (e *ConnectError) Error() string { - sb := &strings.Builder{} - fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.Config.Host, e.Config.User, e.Config.Database, e.msg) - if e.err != nil { - fmt.Fprintf(sb, " (%s)", e.err.Error()) + prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database) + details := e.err.Error() + if strings.Contains(details, "\n") { + return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t") + } else { + return prefix + " " + details } - return sb.String() } func (e *ConnectError) Unwrap() error { return e.err } +type perDialConnectError struct { + address string + originalHostname string + err error +} + +func (e *perDialConnectError) Error() string { + return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error()) +} + +func (e *perDialConnectError) Unwrap() error { + return e.err +} + type connLockError struct { status string } @@ -96,6 +112,14 @@ type ParseConfigError struct { err error } +func 
NewParseConfigError(conn, msg string, err error) error { + return &ParseConfigError{ + ConnString: conn, + msg: msg, + err: err, + } +} + func (e *ParseConfigError) Error() string { // Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only // return a static string. That would ensure that the error message cannot leak a password. The ConnString field would @@ -195,10 +219,10 @@ func redactPW(connString string) string { return redactURL(u) } } - quotedDSN := regexp.MustCompile(`password='[^']*'`) - connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx") - plainDSN := regexp.MustCompile(`password=[^ ]*`) - connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx") + quotedKV := regexp.MustCompile(`password='[^']*'`) + connString = quotedKV.ReplaceAllLiteralString(connString, "password=xxxxx") + plainKV := regexp.MustCompile(`password=[^ ]*`) + connString = plainKV.ReplaceAllLiteralString(connString, "password=xxxxx") brokenURL := regexp.MustCompile(`:[^:@]+?@`) connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@") return connString diff --git a/pgconn/errors_test.go b/pgconn/errors_test.go index dd094d786..3f923a988 100644 --- a/pgconn/errors_test.go +++ b/pgconn/errors_test.go @@ -19,18 +19,18 @@ func TestConfigError(t *testing.T) { expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg", }, { - name: "dsn with password unquoted", + name: "keyword/value with password unquoted", err: pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil), expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", }, { - name: "dsn with password quoted", + name: "keyword/value with password quoted", err: pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil), expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg", }, { name: "weird url", - err: 
pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil), + err: pgconn.NewParseConfigError("postgresql://foo::password@host:1:", "msg", nil), expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg", }, { diff --git a/pgconn/export_test.go b/pgconn/export_test.go index 6726e4cf7..9c0e02e74 100644 --- a/pgconn/export_test.go +++ b/pgconn/export_test.go @@ -1,11 +1,3 @@ // File export_test exports some methods for better testing. package pgconn - -func NewParseConfigError(conn, msg string, err error) error { - return &ParseConfigError{ - ConnString: conn, - msg: msg, - err: err, - } -} diff --git a/pgconn/krb5.go b/pgconn/krb5.go index cc373ae51..da87c5f0d 100644 --- a/pgconn/krb5.go +++ b/pgconn/krb5.go @@ -28,7 +28,7 @@ func RegisterGSSProvider(newGSSArg NewGSSFunc) { // GSS provides GSSAPI authentication (e.g., Kerberos). type GSS interface { - GetInitToken(host string, service string) ([]byte, error) + GetInitToken(host, service string) ([]byte, error) GetInitTokenFromSPN(spn string) ([]byte, error) Continue(inToken []byte) (done bool, outToken []byte, err error) } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index 3d525be0f..bfa73cf41 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1,6 +1,7 @@ package pgconn import ( + "container/list" "context" "crypto/md5" "crypto/tls" @@ -82,6 +83,8 @@ type PgConn struct { slowWriteTimer *time.Timer bgReaderStarted chan struct{} + customData map[string]any + config *Config status byte // One of connStatus* constants @@ -103,8 +106,9 @@ type PgConn struct { cleanupDone chan struct{} } -// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) -// to provide configuration. See documentation for [ParseConfig] for details. ctx can be used to cancel a connect attempt. +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value +// format) to provide configuration. 
See documentation for [ParseConfig] for details. ctx can be used to cancel a +// connect attempt. func Connect(ctx context.Context, connString string) (*PgConn, error) { config, err := ParseConfig(connString) if err != nil { @@ -114,9 +118,9 @@ func Connect(ctx context.Context, connString string) (*PgConn, error) { return ConnectConfig(ctx, config) } -// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) -// and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. ctx can be -// used to cancel a connect attempt. +// Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or keyword/value +// format) and ParseConfigOptions to provide additional configuration. See documentation for [ParseConfig] for details. +// ctx can be used to cancel a connect attempt. func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptions ParseConfigOptions) (*PgConn, error) { config, err := ParseConfigWithOptions(connString, parseConfigOptions) if err != nil { @@ -131,113 +135,77 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // authentication error will terminate the chain of attempts (like libpq: -// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, -// if all attempts fail the last error is returned. -func ConnectConfig(octx context.Context, config *Config) (pgConn *PgConn, err error) { +// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. +func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) { // Default values are set in ParseConfig. 
Enforce initial creation by ParseConfig rather than setting defaults from // zero values. if !config.createdByParseConfig { panic("config must be created by ParseConfig") } - // Simplify usage by treating primary config and fallbacks the same. - fallbackConfigs := []*FallbackConfig{ - { - Host: config.Host, - Port: config.Port, - TLSConfig: config.TLSConfig, - }, - } - fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) - ctx := octx - fallbackConfigs, err = expandWithIPs(ctx, config.LookupFunc, fallbackConfigs) - if err != nil { - return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: err} - } + var allErrors []error - if len(fallbackConfigs) == 0 { - return nil, &ConnectError{Config: config, msg: "hostname resolving error", err: errors.New("ip addr wasn't found")} - } - - foundBestServer := false - var fallbackConfig *FallbackConfig - for i, fc := range fallbackConfigs { - // ConnectTimeout restricts the whole connection process. - if config.ConnectTimeout != 0 { - // create new context first time or when previous host was different - if i == 0 || (fallbackConfigs[i].Host != fallbackConfigs[i-1].Host) { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) - defer cancel() - } - } else { - ctx = octx - } - pgConn, err = connect(ctx, config, fc, false) - if err == nil { - foundBestServer = true - break - } else if pgerr, ok := err.(*PgError); ok { - err = &ConnectError{Config: config, msg: "server error", err: pgerr} - const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password - const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings - const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist - const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege - if pgerr.Code == ERRCODE_INVALID_PASSWORD || - pgerr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && fc.TLSConfig != nil || - pgerr.Code == 
ERRCODE_INVALID_CATALOG_NAME || - pgerr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { - break - } - } else if cerr, ok := err.(*ConnectError); ok { - if _, ok := cerr.err.(*NotPreferredError); ok { - fallbackConfig = fc - } - } + connectConfigs, errs := buildConnectOneConfigs(ctx, config) + if len(errs) > 0 { + allErrors = append(allErrors, errs...) } - if !foundBestServer && fallbackConfig != nil { - pgConn, err = connect(ctx, config, fallbackConfig, true) - if pgerr, ok := err.(*PgError); ok { - err = &ConnectError{Config: config, msg: "server error", err: pgerr} - } + if len(connectConfigs) == 0 { + return nil, &ConnectError{Config: config, err: fmt.Errorf("hostname resolving error: %w", errors.Join(allErrors...))} } - if err != nil { - return nil, err // no need to wrap in connectError because it will already be wrapped in all cases except PgError + pgConn, errs := connectPreferred(ctx, config, connectConfigs) + if len(errs) > 0 { + allErrors = append(allErrors, errs...) + return nil, &ConnectError{Config: config, err: errors.Join(allErrors...)} } if config.AfterConnect != nil { err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "AfterConnect error", err: err} + return nil, &ConnectError{Config: config, err: fmt.Errorf("AfterConnect error: %w", err)} } } return pgConn, nil } -func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*FallbackConfig) ([]*FallbackConfig, error) { - var configs []*FallbackConfig +// buildConnectOneConfigs resolves hostnames and builds a list of connectOneConfigs to try connecting to. It returns a +// slice of successfully resolved connectOneConfigs and a slice of errors. It is possible for both slices to contain +// values if some hosts were successfully resolved and others were not. 
+func buildConnectOneConfigs(ctx context.Context, config *Config) ([]*connectOneConfig, []error) { + // Simplify usage by treating primary config and fallbacks the same. + fallbackConfigs := []*FallbackConfig{ + { + Host: config.Host, + Port: config.Port, + TLSConfig: config.TLSConfig, + }, + } + fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) + + var configs []*connectOneConfig - var lookupErrors []error + var allErrors []error - for _, fb := range fallbacks { + for _, fb := range fallbackConfigs { // skip resolve for unix sockets if isAbsolutePath(fb.Host) { - configs = append(configs, &FallbackConfig{ - Host: fb.Host, - Port: fb.Port, - TLSConfig: fb.TLSConfig, + network, address := NetworkAddress(fb.Host, fb.Port) + configs = append(configs, &connectOneConfig{ + network: network, + address: address, + originalHostname: fb.Host, + tlsConfig: fb.TLSConfig, }) continue } - ips, err := lookupFn(ctx, fb.Host) + ips, err := config.LookupFunc(ctx, fb.Host) if err != nil { - lookupErrors = append(lookupErrors, err) + allErrors = append(allErrors, err) continue } @@ -246,63 +214,137 @@ func expandWithIPs(ctx context.Context, lookupFn LookupFunc, fallbacks []*Fallba if err == nil { port, err := strconv.ParseUint(splitPort, 10, 16) if err != nil { - return nil, fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err) + return nil, []error{fmt.Errorf("error parsing port (%s) from lookup: %w", splitPort, err)} } - configs = append(configs, &FallbackConfig{ - Host: splitIP, - Port: uint16(port), - TLSConfig: fb.TLSConfig, + network, address := NetworkAddress(splitIP, uint16(port)) + configs = append(configs, &connectOneConfig{ + network: network, + address: address, + originalHostname: fb.Host, + tlsConfig: fb.TLSConfig, }) } else { - configs = append(configs, &FallbackConfig{ - Host: ip, - Port: fb.Port, - TLSConfig: fb.TLSConfig, + network, address := NetworkAddress(ip, fb.Port) + configs = append(configs, &connectOneConfig{ + network: network, 
+ address: address, + originalHostname: fb.Host, + tlsConfig: fb.TLSConfig, }) } } } - // See https://github.com/jackc/pgx/issues/1464. When Go 1.20 can be used in pgx consider using errors.Join so all - // errors are reported. - if len(configs) == 0 && len(lookupErrors) > 0 { - return nil, lookupErrors[0] + return configs, allErrors +} + +// connectPreferred attempts to connect to the preferred host from connectOneConfigs. The connections are attempted in +// order. If a connection is successful it is returned. If no connection is successful then all errors are returned. If +// a connection attempt returns a [NotPreferredError], then that host will be used if no other hosts are successful. +func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*connectOneConfig) (*PgConn, []error) { + octx := ctx + var allErrors []error + + var fallbackConnectOneConfig *connectOneConfig + for i, c := range connectOneConfigs { + // ConnectTimeout restricts the whole connection process. 
+ if config.ConnectTimeout != 0 { + // create new context first time or when previous host was different + if i == 0 || (connectOneConfigs[i].address != connectOneConfigs[i-1].address) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(octx, config.ConnectTimeout) + defer cancel() + } + } else { + ctx = octx + } + + pgConn, err := connectOne(ctx, config, c, false) + if pgConn != nil { + return pgConn, nil + } + + allErrors = append(allErrors, err) + + var pgErr *PgError + if errors.As(err, &pgErr) { + // pgx will try next host even if libpq does not in certain cases (see #2246) + // consider change for the next major version + + const ERRCODE_INVALID_PASSWORD = "28P01" + const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist + const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege + + // auth failed due to invalid password, db does not exist or user has no permission + if pgErr.Code == ERRCODE_INVALID_PASSWORD || + pgErr.Code == ERRCODE_INVALID_CATALOG_NAME || + pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE { + return nil, allErrors + } + } + + var npErr *NotPreferredError + if errors.As(err, &npErr) { + fallbackConnectOneConfig = c + } } - return configs, nil + if fallbackConnectOneConfig != nil { + pgConn, err := connectOne(ctx, config, fallbackConnectOneConfig, true) + if err == nil { + return pgConn, nil + } + allErrors = append(allErrors, err) + } + + return nil, allErrors } -func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig, +// connectOne makes one connection attempt to a single host. 
+func connectOne(ctx context.Context, config *Config, connectConfig *connectOneConfig, ignoreNotPreferredErr bool, ) (*PgConn, error) { pgConn := new(PgConn) pgConn.config = config pgConn.cleanupDone = make(chan struct{}) + pgConn.customData = make(map[string]any) var err error - network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) - netConn, err := config.DialFunc(ctx, network, address) - if err != nil { - return nil, &ConnectError{Config: config, msg: "dial error", err: normalizeTimeoutError(ctx, err)} + + newPerDialConnectError := func(msg string, err error) *perDialConnectError { + err = normalizeTimeoutError(ctx, err) + e := &perDialConnectError{address: connectConfig.address, originalHostname: connectConfig.originalHostname, err: fmt.Errorf("%s: %w", msg, err)} + return e } - pgConn.conn = netConn - pgConn.contextWatcher = newContextWatcher(netConn) - pgConn.contextWatcher.Watch(ctx) + pgConn.conn, err = config.DialFunc(ctx, connectConfig.network, connectConfig.address) + if err != nil { + return nil, newPerDialConnectError("dial error", err) + } - if fallbackConfig.TLSConfig != nil { - nbTLSConn, err := startTLS(netConn, fallbackConfig.TLSConfig) + if connectConfig.tlsConfig != nil { + pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn}) + pgConn.contextWatcher.Watch(ctx) + var ( + tlsConn net.Conn + err error + ) + if config.SSLNegotiation == "direct" { + tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig) + } else { + tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig) + } pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS. 
if err != nil { - netConn.Close() - return nil, &ConnectError{Config: config, msg: "tls error", err: normalizeTimeoutError(ctx, err)} + pgConn.conn.Close() + return nil, newPerDialConnectError("tls error", err) } - pgConn.conn = nbTLSConn - pgConn.contextWatcher = newContextWatcher(nbTLSConn) - pgConn.contextWatcher.Watch(ctx) + pgConn.conn = tlsConn } + pgConn.contextWatcher = ctxwatch.NewContextWatcher(config.BuildContextWatcherHandler(pgConn)) + pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() pgConn.parameterStatuses = make(map[string]string) @@ -336,7 +378,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.frontend.Send(&startupMsg) if err := pgConn.flushWithPotentialWriteReadDeadlock(); err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed to write startup message", err: normalizeTimeoutError(ctx, err)} + return nil, newPerDialConnectError("failed to write startup message", err) } for { @@ -344,9 +386,9 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if err != nil { pgConn.conn.Close() if err, ok := err.(*PgError); ok { - return nil, err + return nil, newPerDialConnectError("server error", err) } - return nil, &ConnectError{Config: config, msg: "failed to receive message", err: normalizeTimeoutError(ctx, err)} + return nil, newPerDialConnectError("failed to receive message", err) } switch msg := msg.(type) { @@ -359,26 +401,26 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig err = pgConn.txPasswordMessage(pgConn.config.Password) if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err} + return nil, newPerDialConnectError("failed to write password message", err) } case *pgproto3.AuthenticationMD5Password: digestedPassword := "md5" + hexMD5(hexMD5(pgConn.config.Password+pgConn.config.User)+string(msg.Salt[:])) err = 
pgConn.txPasswordMessage(digestedPassword) if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed to write password message", err: err} + return nil, newPerDialConnectError("failed to write password message", err) } case *pgproto3.AuthenticationSASL: err = pgConn.scramAuth(msg.AuthMechanisms) if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed SASL auth", err: err} + return nil, newPerDialConnectError("failed SASL auth", err) } case *pgproto3.AuthenticationGSS: err = pgConn.gssAuth() if err != nil { pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "failed GSS auth", err: err} + return nil, newPerDialConnectError("failed GSS auth", err) } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle @@ -396,7 +438,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig return pgConn, nil } pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "ValidateConnect failed", err: err} + return nil, newPerDialConnectError("ValidateConnect failed", err) } } return pgConn, nil @@ -404,21 +446,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() - return nil, ErrorResponseToPgError(msg) + return nil, newPerDialConnectError("server error", ErrorResponseToPgError(msg)) default: pgConn.conn.Close() - return nil, &ConnectError{Config: config, msg: "received unexpected message", err: err} + return nil, newPerDialConnectError("received unexpected message", err) } } } -func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher { - return ctxwatch.NewContextWatcher( - func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { conn.SetDeadline(time.Time{}) }, - ) -} - func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103}) if 
err != nil { @@ -928,23 +963,24 @@ func (pgConn *PgConn) Deallocate(ctx context.Context, name string) error { // ErrorResponseToPgError converts a wire protocol error message to a *PgError. func ErrorResponseToPgError(msg *pgproto3.ErrorResponse) *PgError { return &PgError{ - Severity: msg.Severity, - Code: string(msg.Code), - Message: string(msg.Message), - Detail: string(msg.Detail), - Hint: msg.Hint, - Position: msg.Position, - InternalPosition: msg.InternalPosition, - InternalQuery: string(msg.InternalQuery), - Where: string(msg.Where), - SchemaName: string(msg.SchemaName), - TableName: string(msg.TableName), - ColumnName: string(msg.ColumnName), - DataTypeName: string(msg.DataTypeName), - ConstraintName: msg.ConstraintName, - File: string(msg.File), - Line: msg.Line, - Routine: string(msg.Routine), + Severity: msg.Severity, + SeverityUnlocalized: msg.SeverityUnlocalized, + Code: string(msg.Code), + Message: string(msg.Message), + Detail: string(msg.Detail), + Hint: msg.Hint, + Position: msg.Position, + InternalPosition: msg.InternalPosition, + InternalQuery: string(msg.InternalQuery), + Where: string(msg.Where), + SchemaName: string(msg.SchemaName), + TableName: string(msg.TableName), + ColumnName: string(msg.ColumnName), + DataTypeName: string(msg.DataTypeName), + ConstraintName: msg.ConstraintName, + File: string(msg.File), + Line: msg.Line, + Routine: string(msg.Routine), } } @@ -955,7 +991,8 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { // CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there -// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 +// is no way to be sure a query was canceled. 
+// See https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-CANCELING-REQUESTS func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be @@ -987,10 +1024,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { defer cancelConn.Close() if ctx != context.Background() { - contextWatcher := ctxwatch.NewContextWatcher( - func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, - func() { cancelConn.SetDeadline(time.Time{}) }, - ) + contextWatcher := ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: cancelConn}) contextWatcher.Watch(ctx) defer contextWatcher.Unwatch() } @@ -1107,7 +1141,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { // binary format. If resultFormats is nil all results will be in text format. // // ResultReader must be closed before PgConn can be used again. -func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { +func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) *ResultReader { result := pgConn.execExtendedPrefix(ctx, paramValues) if result.closed { return result @@ -1133,7 +1167,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [] // binary format. If resultFormats is nil all results will be in text format. // // ResultReader must be closed before PgConn can be used again. 
-func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { +func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader { result := pgConn.execExtendedPrefix(ctx, paramValues) if result.closed { return result @@ -1340,7 +1374,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co close(pgConn.cleanupDone) return CommandTag{}, normalizeTimeoutError(ctx, err) } - msg, _ := pgConn.receiveMessage() + // peekMessage never returns err in the bufferingReceive mode - it only forwards the bufferingReceive variables. + // Therefore, the only case for receiveMessage to return err is during handling of the ErrorResponse message type + // and using pgOnError handler to determine the connection is no longer valid (and thus closing the conn). + msg, serverError := pgConn.receiveMessage() + if serverError != nil { + close(abortCopyChan) + return CommandTag{}, serverError + } switch msg := msg.(type) { case *pgproto3.ErrorResponse: @@ -1387,9 +1428,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. 
type MultiResultReader struct { - pgConn *PgConn - ctx context.Context - pipeline *Pipeline + pgConn *PgConn + ctx context.Context rr *ResultReader @@ -1422,12 +1462,8 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) switch msg := msg.(type) { case *pgproto3.ReadyForQuery: mrr.closed = true - if mrr.pipeline != nil { - mrr.pipeline.expectedReadyForQueryCount-- - } else { - mrr.pgConn.contextWatcher.Unwatch() - mrr.pgConn.unlock() - } + mrr.pgConn.contextWatcher.Unwatch() + mrr.pgConn.unlock() case *pgproto3.ErrorResponse: mrr.err = ErrorResponseToPgError(msg) } @@ -1523,8 +1559,10 @@ func (rr *ResultReader) Read() *Result { values := rr.Values() row := make([][]byte, len(values)) for i := range row { - row[i] = make([]byte, len(values[i])) - copy(row[i], values[i]) + if values[i] != nil { + row[i] = make([]byte, len(values[i])) + copy(row[i], values[i]) + } } br.Rows = append(br.Rows, row) } @@ -1649,7 +1687,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error case *pgproto3.EmptyQueryResponse: rr.concludeCommand(CommandTag{}, nil) case *pgproto3.ErrorResponse: - rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg)) + pgErr := ErrorResponseToPgError(msg) + if rr.pipeline != nil { + rr.pipeline.state.HandleError(pgErr) + } + rr.concludeCommand(CommandTag{}, pgErr) } return msg, nil @@ -1674,25 +1716,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte + err error } // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. 
-func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) +func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) { + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + if batch.err != nil { + return + } batch.ExecPrepared("", paramValues, paramFormats, resultFormats) } // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. -func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { - batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) - batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) - batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) +func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) { + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } } // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. 
This is a simpler way of executing // multiple queries in a single round trip than using pipeline mode. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { + if batch.err != nil { + return &MultiResultReader{ + closed: true, + err: batch.err, + } + } + if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, @@ -1718,15 +1790,23 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR pgConn.contextWatcher.Watch(ctx) } - batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf) + if batch.err != nil { + pgConn.contextWatcher.Unwatch() + multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err) + multiResult.closed = true + pgConn.asyncClose() + return multiResult + } pgConn.enterPotentialWriteReadDeadlock() defer pgConn.exitPotentialWriteReadDeadlock() _, err := pgConn.conn.Write(batch.buf) if err != nil { + pgConn.contextWatcher.Unwatch() + multiResult.err = normalizeTimeoutError(multiResult.ctx, err) multiResult.closed = true - multiResult.err = err - pgConn.unlock() + pgConn.asyncClose() return multiResult } @@ -1843,6 +1923,11 @@ func (pgConn *PgConn) SyncConn(ctx context.Context) error { return errors.New("SyncConn: conn never synchronized") } +// CustomData returns a map that can be used to associate custom data with the connection. +func (pgConn *PgConn) CustomData() map[string]any { + return pgConn.customData +} + // HijackedConn is the result of hijacking a connection. // // Due to the necessary exposure of internal implementation details, it is not covered by the semantic versioning @@ -1855,6 +1940,7 @@ type HijackedConn struct { TxStatus byte Frontend *pgproto3.Frontend Config *Config + CustomData map[string]any } // Hijack extracts the internal connection data. pgConn must be in an idle state. 
SyncConn should be called immediately @@ -1877,6 +1963,7 @@ func (pgConn *PgConn) Hijack() (*HijackedConn, error) { TxStatus: pgConn.txStatus, Frontend: pgConn.frontend, Config: pgConn.config, + CustomData: pgConn.customData, }, nil } @@ -1896,13 +1983,14 @@ func Construct(hc *HijackedConn) (*PgConn, error) { txStatus: hc.TxStatus, frontend: hc.Frontend, config: hc.Config, + customData: hc.CustomData, status: connStatusIdle, cleanupDone: make(chan struct{}), } - pgConn.contextWatcher = newContextWatcher(pgConn.conn) + pgConn.contextWatcher = ctxwatch.NewContextWatcher(hc.Config.BuildContextWatcherHandler(pgConn)) pgConn.bgReader = bgreader.New(pgConn.conn) pgConn.slowWriteTimer = time.AfterFunc(time.Duration(math.MaxInt64), func() { @@ -1932,9 +2020,7 @@ type Pipeline struct { conn *PgConn ctx context.Context - expectedReadyForQueryCount int - pendingSync bool - + state pipelineState err error closed bool } @@ -1945,6 +2031,122 @@ type PipelineSync struct{} // CloseComplete is returned by GetResults when a CloseComplete message is received. 
type CloseComplete struct{} +type pipelineRequestType int + +const ( + pipelineNil pipelineRequestType = iota + pipelinePrepare + pipelineQueryParams + pipelineQueryPrepared + pipelineDeallocate + pipelineSyncRequest + pipelineFlushRequest +) + +type pipelineRequestEvent struct { + RequestType pipelineRequestType + WasSentToServer bool + BeforeFlushOrSync bool +} + +type pipelineState struct { + requestEventQueue list.List + lastRequestType pipelineRequestType + pgErr *PgError + expectedReadyForQueryCount int +} + +func (s *pipelineState) Init() { + s.requestEventQueue.Init() + s.lastRequestType = pipelineNil +} + +func (s *pipelineState) RegisterSendingToServer() { + for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() { + val := elem.Value.(pipelineRequestEvent) + if val.WasSentToServer { + return + } + val.WasSentToServer = true + elem.Value = val + } +} + +func (s *pipelineState) registerFlushingBufferOnServer() { + for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() { + val := elem.Value.(pipelineRequestEvent) + if val.BeforeFlushOrSync { + return + } + val.BeforeFlushOrSync = true + elem.Value = val + } +} + +func (s *pipelineState) PushBackRequestType(req pipelineRequestType) { + if req == pipelineNil { + return + } + + if req != pipelineFlushRequest { + s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req}) + } + if req == pipelineFlushRequest || req == pipelineSyncRequest { + s.registerFlushingBufferOnServer() + } + s.lastRequestType = req + + if req == pipelineSyncRequest { + s.expectedReadyForQueryCount++ + } +} + +func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType { + for { + elem := s.requestEventQueue.Front() + if elem == nil { + return pipelineNil + } + val := elem.Value.(pipelineRequestEvent) + if !(val.WasSentToServer && val.BeforeFlushOrSync) { + return pipelineNil + } + + s.requestEventQueue.Remove(elem) + if val.RequestType == pipelineSyncRequest { + s.pgErr = nil + } + if 
s.pgErr == nil { + return val.RequestType + } + } +} + +func (s *pipelineState) HandleError(err *PgError) { + s.pgErr = err +} + +func (s *pipelineState) HandleReadyForQuery() { + s.expectedReadyForQueryCount-- +} + +func (s *pipelineState) PendingSync() bool { + var notPendingSync bool + + if elem := s.requestEventQueue.Back(); elem != nil { + val := elem.Value.(pipelineRequestEvent) + notPendingSync = (val.RequestType == pipelineSyncRequest) && val.WasSentToServer + } else { + notPendingSync = (s.lastRequestType == pipelineSyncRequest) || (s.lastRequestType == pipelineNil) + } + + return !notPendingSync +} + +func (s *pipelineState) ExpectedReadyForQuery() int { + return s.expectedReadyForQueryCount +} + // StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent // to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection // to normal mode. While in pipeline mode, no methods that communicate with the server may be called except @@ -1953,16 +2155,21 @@ type CloseComplete struct{} // Prefer ExecBatch when only sending one group of queries at once. func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline { if err := pgConn.lock(); err != nil { - return &Pipeline{ + pipeline := &Pipeline{ closed: true, err: err, } + pipeline.state.Init() + + return pipeline } pgConn.pipeline = Pipeline{ conn: pgConn, ctx: ctx, } + pgConn.pipeline.state.Init() + pipeline := &pgConn.pipeline if ctx != context.Background() { @@ -1985,10 +2192,10 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name}) + p.state.PushBackRequestType(pipelinePrepare) } // SendDeallocate deallocates a prepared statement. 
@@ -1996,34 +2203,65 @@ func (p *Pipeline) SendDeallocate(name string) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name}) + p.state.PushBackRequestType(pipelineDeallocate) } // SendQueryParams is the pipeline version of *PgConn.QueryParams. -func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { +func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}) p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(pipelineQueryParams) } // SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared. -func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { +func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) { if p.closed { return } - p.pendingSync = true p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}) p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'}) p.conn.frontend.SendExecute(&pgproto3.Execute{}) + p.state.PushBackRequestType(pipelineQueryPrepared) +} + +// SendFlushRequest sends a request for the server to flush its output buffer. 
+// +// The server flushes its output buffer automatically as a result of Sync being called, +// or on any request when not in pipeline mode; this function is useful to cause the server +// to flush its output buffer in pipeline mode without establishing a synchronization point. +// Note that the request is not itself flushed to the server automatically; use Flush if +// necessary. This copies the behavior of libpq PQsendFlushRequest. +func (p *Pipeline) SendFlushRequest() { + if p.closed { + return + } + + p.conn.frontend.Send(&pgproto3.Flush{}) + p.state.PushBackRequestType(pipelineFlushRequest) +} + +// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message +// without flushing the send buffer. This serves as the delimiter of an implicit +// transaction and an error recovery point. +// +// Note that the request is not itself flushed to the server automatically; use Flush if +// necessary. This copies the behavior of libpq PQsendPipelineSync. +func (p *Pipeline) SendPipelineSync() { + if p.closed { + return + } + + p.conn.frontend.SendSync(&pgproto3.Sync{}) + p.state.PushBackRequestType(pipelineSyncRequest) } // Flush flushes the queued requests without establishing a synchronization point. @@ -2048,28 +2286,14 @@ func (p *Pipeline) Flush() error { return err } + p.state.RegisterSendingToServer() return nil } // Sync establishes a synchronization point and flushes the queued requests. func (p *Pipeline) Sync() error { - if p.closed { - if p.err != nil { - return p.err - } - return errors.New("pipeline closed") - } - - p.conn.frontend.SendSync(&pgproto3.Sync{}) - err := p.Flush() - if err != nil { - return err - } - - p.pendingSync = false - p.expectedReadyForQueryCount++ - - return nil + p.SendPipelineSync() + return p.Flush() } // GetResults gets the next results. 
If results are present, results may be a *ResultReader, *StatementDescription, or @@ -2083,7 +2307,7 @@ func (p *Pipeline) GetResults() (results any, err error) { return nil, errors.New("pipeline closed") } - if p.expectedReadyForQueryCount == 0 { + if p.state.ExtractFrontRequestType() == pipelineNil { return nil, nil } @@ -2094,6 +2318,8 @@ func (p *Pipeline) getResults() (results any, err error) { for { msg, err := p.conn.receiveMessage() if err != nil { + p.closed = true + p.err = err p.conn.asyncClose() return nil, normalizeTimeoutError(p.ctx, err) } @@ -2126,13 +2352,13 @@ func (p *Pipeline) getResults() (results any, err error) { case *pgproto3.CloseComplete: return &CloseComplete{}, nil case *pgproto3.ReadyForQuery: - p.expectedReadyForQueryCount-- + p.state.HandleReadyForQuery() return &PipelineSync{}, nil case *pgproto3.ErrorResponse: pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) return nil, pgErr } - } } @@ -2162,6 +2388,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) { // These should never happen here. But don't take chances that could lead to a deadlock. case *pgproto3.ErrorResponse: pgErr := ErrorResponseToPgError(msg) + p.state.HandleError(pgErr) return nil, pgErr case *pgproto3.CommandComplete: p.conn.asyncClose() @@ -2181,7 +2408,7 @@ func (p *Pipeline) Close() error { p.closed = true - if p.pendingSync { + if p.state.PendingSync() { p.conn.asyncClose() p.err = errors.New("pipeline has unsynced requests") p.conn.contextWatcher.Unwatch() @@ -2190,7 +2417,7 @@ func (p *Pipeline) Close() error { return p.err } - for p.expectedReadyForQueryCount > 0 { + for p.state.ExpectedReadyForQuery() > 0 { _, err := p.getResults() if err != nil { p.err = err @@ -2207,3 +2434,71 @@ func (p *Pipeline) Close() error { return p.err } + +// DeadlineContextWatcherHandler handles canceled contexts by setting a deadline on a net.Conn. 
+type DeadlineContextWatcherHandler struct { + Conn net.Conn + + // DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled. + DeadlineDelay time.Duration +} + +func (h *DeadlineContextWatcherHandler) HandleCancel(ctx context.Context) { + h.Conn.SetDeadline(time.Now().Add(h.DeadlineDelay)) +} + +func (h *DeadlineContextWatcherHandler) HandleUnwatchAfterCancel() { + h.Conn.SetDeadline(time.Time{}) +} + +// CancelRequestContextWatcherHandler handles canceled contexts by sending a cancel request to the server. It also sets +// a deadline on a net.Conn as a fallback. +type CancelRequestContextWatcherHandler struct { + Conn *PgConn + + // CancelRequestDelay is the delay before sending the cancel request to the server. + CancelRequestDelay time.Duration + + // DeadlineDelay is the delay to set on the deadline set on net.Conn when the context is canceled. + DeadlineDelay time.Duration + + cancelFinishedChan chan struct{} + handleUnwatchAfterCancelCalled func() +} + +func (h *CancelRequestContextWatcherHandler) HandleCancel(context.Context) { + h.cancelFinishedChan = make(chan struct{}) + var handleUnwatchedAfterCancelCalledCtx context.Context + handleUnwatchedAfterCancelCalledCtx, h.handleUnwatchAfterCancelCalled = context.WithCancel(context.Background()) + + deadline := time.Now().Add(h.DeadlineDelay) + h.Conn.conn.SetDeadline(deadline) + + go func() { + defer close(h.cancelFinishedChan) + + select { + case <-handleUnwatchedAfterCancelCalledCtx.Done(): + return + case <-time.After(h.CancelRequestDelay): + } + + cancelRequestCtx, cancel := context.WithDeadline(handleUnwatchedAfterCancelCalledCtx, deadline) + defer cancel() + h.Conn.CancelRequest(cancelRequestCtx) + + // CancelRequest is inherently racy. Even though the cancel request has been received by the server at this point, + // it hasn't necessarily been delivered to the other connection. 
If we immediately return and the connection is + // immediately used then it is possible the CancelRequest will actually cancel our next query. The + // TestCancelRequestContextWatcherHandler Stress test can produce this error without the sleep below. The sleep time + // is arbitrary, but should be sufficient to prevent this error case. + time.Sleep(100 * time.Millisecond) + }() +} + +func (h *CancelRequestContextWatcherHandler) HandleUnwatchAfterCancel() { + h.handleUnwatchAfterCancelCalled() + <-h.cancelFinishedChan + + h.Conn.conn.SetDeadline(time.Time{}) +} diff --git a/pgconn/pgconn_private_test.go b/pgconn/pgconn_private_test.go index 5659bc9ef..a0c15c27a 100644 --- a/pgconn/pgconn_private_test.go +++ b/pgconn/pgconn_private_test.go @@ -9,7 +9,7 @@ import ( func TestCommandTag(t *testing.T) { t.Parallel() - var tests = []struct { + tests := []struct { commandTag CommandTag rowsAffected int64 isInsert bool diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index a98338aeb..f0e816d59 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -14,6 +14,7 @@ import ( "os" "strconv" "strings" + "sync/atomic" "testing" "time" @@ -24,6 +25,7 @@ import ( "github.com/yugabyte/pgx/v5/internal/pgio" "github.com/yugabyte/pgx/v5/internal/pgmock" "github.com/yugabyte/pgx/v5/pgconn" + "github.com/yugabyte/pgx/v5/pgconn/ctxwatch" "github.com/yugabyte/pgx/v5/pgproto3" "github.com/yugabyte/pgx/v5/pgtype" ) @@ -356,10 +358,8 @@ func TestConnectInvalidUser(t *testing.T) { _, err = pgconn.ConnectConfig(ctx, config) require.Error(t, err) - pgErr, ok := errors.Unwrap(err).(*pgconn.PgError) - if !ok { - t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err) - } + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) if pgErr.Code != "28000" && pgErr.Code != "28P01" { t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) } @@ -528,6 +528,22 @@ func TestConnectWithFallback(t *testing.T) { closeConn(t, 
conn) } +func TestConnectFailsWithResolveFailureAndFailedConnectionAttempts(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + conn, err := pgconn.Connect(ctx, "host=localhost,127.0.0.1,foo.invalid port=1,2,3 sslmode=disable") + require.Error(t, err) + require.Nil(t, conn) + + require.ErrorContains(t, err, "lookup foo.invalid") + // Not testing the entire string as depending on IPv4 or IPv6 support localhost may resolve to 127.0.0.1 or ::1. + require.ErrorContains(t, err, ":1 (localhost): dial error:") + require.ErrorContains(t, err, ":2 (127.0.0.1): dial error:") +} + func TestConnectWithValidateConnect(t *testing.T) { t.Parallel() @@ -1175,6 +1191,24 @@ func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) { ensureConnValid(t, pgConn) } +// https://github.com/jackc/pgx/issues/1987 +func TestResultReaderReadNil(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + result := pgConn.ExecParams(ctx, "select null::text", nil, nil, nil, nil).Read() + require.NoError(t, result.Err) + require.Nil(t, result.Rows[0][0]) + + ensureConnValid(t, pgConn) +} + func TestConnExecPrepared(t *testing.T) { t.Parallel() @@ -1387,6 +1421,52 @@ func TestConnExecBatch(t *testing.T) { assert.Equal(t, "SELECT 1", results[2].CommandTag.String()) } +type mockConnection struct { + net.Conn + writeLatency *time.Duration +} + +func (m mockConnection) Write(b []byte) (n int, err error) { + time.Sleep(*m.writeLatency) + return m.Conn.Write(b) +} + +func TestConnExecBatchWriteError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + var 
mockConn mockConnection + writeLatency := 0 * time.Second + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := net.Dial(network, address) + mockConn = mockConnection{conn, &writeLatency} + return mockConn, err + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + batch := &pgconn.Batch{} + pgConn.Conn() + + ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel2() + + batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil) + writeLatency = 2 * time.Second + mrr := pgConn.ExecBatch(ctx2, batch) + err = mrr.Close() + require.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + require.True(t, pgConn.IsClosed()) +} + func TestConnExecBatchDeferredError(t *testing.T) { t.Parallel() @@ -1556,9 +1636,9 @@ func TestConnOnNotice(t *testing.T) { config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) require.NoError(t, err) - var msg string - config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) { - msg = notice.Message + var notice *pgconn.Notice + config.OnNotice = func(c *pgconn.PgConn, n *pgconn.Notice) { + notice = n } config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the message we expect. 
@@ -1576,7 +1656,8 @@ begin end$$;`) err = multiResult.Close() require.NoError(t, err) - assert.Equal(t, "hello, world", msg) + assert.Equal(t, "NOTICE", notice.SeverityUnlocalized) + assert.Equal(t, "hello, world", notice.Message) ensureConnValid(t, pgConn) } @@ -2049,6 +2130,63 @@ func TestConnCopyFromPrecanceled(t *testing.T) { ensureConnValid(t, pgConn) } +// https://github.com/jackc/pgx/issues/2364 +func TestConnCopyFromConnectionTerminated(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + if pgConn.ParameterStatus("crdb_version") != "" { + t.Skip("Server does not support pg_terminate_backend") + } + + closerConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + time.AfterFunc(500*time.Millisecond, func() { + // defer inside of AfterFunc instead of outer test function because outer function can finish while Read is still in + // progress which could cause closerConn to be closed too soon. 
+ defer closeConn(t, closerConn) + err := closerConn.ExecParams(ctx, "select pg_terminate_backend($1)", [][]byte{[]byte(fmt.Sprintf("%d", pgConn.PID()))}, nil, nil, nil).Read().Err + require.NoError(t, err) + }) + + _, err = pgConn.Exec(ctx, `create temporary table foo( + a int4, + b varchar + )`).ReadAll() + require.NoError(t, err) + + r, w := io.Pipe() + go func() { + for i := 0; i < 5_000; i++ { + a := strconv.Itoa(i) + b := "foo " + a + " bar" + _, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b))) + if err != nil { + return + } + time.Sleep(time.Millisecond) + } + }() + + copySql := "COPY foo FROM STDIN WITH (FORMAT csv)" + ct, err := pgConn.CopyFrom(ctx, r, copySql) + assert.Equal(t, int64(0), ct.RowsAffected()) + assert.Error(t, err) + + assert.True(t, pgConn.IsClosed()) + select { + case <-pgConn.CleanupDone(): + case <-time.After(5 * time.Second): + t.Fatal("Connection cleanup exceeded maximum time") + } +} + func TestConnCopyFromGzipReader(t *testing.T) { t.Parallel() @@ -2069,8 +2207,9 @@ func TestConnCopyFromGzipReader(t *testing.T) { )`).ReadAll() require.NoError(t, err) - f, err := os.CreateTemp("", "*") + f, err := os.CreateTemp(t.TempDir(), "*") require.NoError(t, err) + defer f.Close() gw := gzip.NewWriter(f) @@ -2103,12 +2242,6 @@ func TestConnCopyFromGzipReader(t *testing.T) { err = gr.Close() require.NoError(t, err) - err = f.Close() - require.NoError(t, err) - - err = os.Remove(f.Name()) - require.NoError(t, err) - result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read() require.NoError(t, result.Err) @@ -2680,7 +2813,7 @@ func TestPipelinePrepare(t *testing.T) { sd, ok := results.(*pgconn.StatementDescription) require.Truef(t, ok, "expected StatementDescription, got: %#v", results) require.Len(t, sd.Fields, 1) - require.Equal(t, string(sd.Fields[0].Name), "a") + require.Equal(t, "a", string(sd.Fields[0].Name)) require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) results, err = pipeline.GetResults() @@ -2688,7 
+2821,7 @@ func TestPipelinePrepare(t *testing.T) { sd, ok = results.(*pgconn.StatementDescription) require.Truef(t, ok, "expected StatementDescription, got: %#v", results) require.Len(t, sd.Fields, 1) - require.Equal(t, string(sd.Fields[0].Name), "b") + require.Equal(t, "b", string(sd.Fields[0].Name)) require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) results, err = pipeline.GetResults() @@ -2696,7 +2829,7 @@ func TestPipelinePrepare(t *testing.T) { sd, ok = results.(*pgconn.StatementDescription) require.Truef(t, ok, "expected StatementDescription, got: %#v", results) require.Len(t, sd.Fields, 1) - require.Equal(t, string(sd.Fields[0].Name), "c") + require.Equal(t, "c", string(sd.Fields[0].Name)) require.Equal(t, []uint32{}, sd.ParamOIDs) results, err = pipeline.GetResults() @@ -2750,7 +2883,7 @@ func TestPipelinePrepareError(t *testing.T) { sd, ok := results.(*pgconn.StatementDescription) require.Truef(t, ok, "expected StatementDescription, got: %#v", results) require.Len(t, sd.Fields, 1) - require.Equal(t, string(sd.Fields[0].Name), "a") + require.Equal(t, "a", string(sd.Fields[0].Name)) require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) results, err = pipeline.GetResults() @@ -2794,7 +2927,7 @@ func TestPipelinePrepareAndDeallocate(t *testing.T) { sd, ok := results.(*pgconn.StatementDescription) require.Truef(t, ok, "expected StatementDescription, got: %#v", results) require.Len(t, sd.Fields, 1) - require.Equal(t, string(sd.Fields[0].Name), "a") + require.Equal(t, "a", string(sd.Fields[0].Name)) require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) results, err = pipeline.GetResults() @@ -2931,7 +3064,7 @@ func TestPipelinePrepareQuery(t *testing.T) { sd, ok := results.(*pgconn.StatementDescription) require.Truef(t, ok, "expected StatementDescription, got: %#v", results) require.Len(t, sd.Fields, 1) - require.Equal(t, string(sd.Fields[0].Name), "msg") + require.Equal(t, "msg", string(sd.Fields[0].Name)) require.Equal(t, []uint32{pgtype.TextOID}, 
sd.ParamOIDs) results, err = pipeline.GetResults() @@ -3076,7 +3209,7 @@ func TestPipelineQueryErrorBetweenSyncs(t *testing.T) { ensureConnValid(t, pgConn) } -func TestPipelineCloseReadsUnreadResults(t *testing.T) { +func TestPipelineFlushForSingleRequests(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) @@ -3087,34 +3220,95 @@ func TestPipelineCloseReadsUnreadResults(t *testing.T) { defer closeConn(t, pgConn) pipeline := pgConn.StartPipeline(ctx) - pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) - pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) - pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) - err = pipeline.Sync() - require.NoError(t, err) - pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) - pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) - err = pipeline.Sync() + pipeline.SendPrepare("ps", "select $1::text as msg", nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() require.NoError(t, err) results, err := pipeline.GetResults() require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "msg", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) rr, ok := results.(*pgconn.ResultReader) require.Truef(t, ok, "expected ResultReader, got: %#v", results) readResult := rr.Read() require.NoError(t, readResult.Err) require.Len(t, readResult.Rows, 1) require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "hello", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + 
require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendDeallocate("ps") + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.CloseComplete) + require.Truef(t, ok, "expected CloseComplete, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) require.Equal(t, "1", string(readResult.Rows[0][0])) + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + err = pipeline.Close() require.NoError(t, err) ensureConnValid(t, pgConn) } -func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) { +func TestPipelineFlushForRequestSeries(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) @@ -3125,17 +3319,31 @@ func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) { defer closeConn(t, pgConn) pipeline := pgConn.StartPipeline(ctx) - pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) - pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) - pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + pipeline.SendPrepare("ps", 
"select $1::bigint as num", nil) err = pipeline.Sync() require.NoError(t, err) - pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) - pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) - results, err := pipeline.GetResults() require.NoError(t, err) + sd, ok := results.(*pgconn.StatementDescription) + require.Truef(t, ok, "expected StatementDescription, got: %#v", results) + require.Len(t, sd.Fields, 1) + require.Equal(t, "num", string(sd.Fields[0].Name)) + require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("1")}, nil, nil) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("2")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) rr, ok := results.(*pgconn.ResultReader) require.Truef(t, ok, "expected ResultReader, got: %#v", results) readResult := rr.Read() @@ -3144,46 +3352,325 @@ func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) { require.Len(t, readResult.Rows[0], 1) require.Equal(t, "1", string(readResult.Rows[0][0])) - err = pipeline.Close() - require.EqualError(t, err, "pipeline has unsynced requests") -} - -func TestConnOnPgError(t *testing.T) { - t.Parallel() + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "2", string(readResult.Rows[0][0])) - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) - defer cancel() + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) 
- config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("3")}, nil, nil) + err = pipeline.Flush() require.NoError(t, err) - config.OnPgError = func(c *pgconn.PgConn, pgErr *pgconn.PgError) bool { - require.NotNil(t, c) - require.NotNil(t, pgErr) - // close connection on undefined tables only - if pgErr.Code == "42P01" { - return false - } - return true - } - pgConn, err := pgconn.ConnectConfig(ctx, config) + results, err = pipeline.GetResults() require.NoError(t, err) - defer closeConn(t, pgConn) + require.Nil(t, results) - _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() - assert.NoError(t, err) - assert.False(t, pgConn.IsClosed()) + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("4")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) - _, err = pgConn.Exec(ctx, "select 1/0").ReadAll() - assert.Error(t, err) - assert.False(t, pgConn.IsClosed()) + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "3", string(readResult.Rows[0][0])) - _, err = pgConn.Exec(ctx, "select * from non_existant_table").ReadAll() - assert.Error(t, err) - assert.True(t, pgConn.IsClosed()) -} + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "4", string(readResult.Rows[0][0])) -func Example() { + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, 
[][]byte{[]byte("5")}, nil, nil) + pipeline.SendFlushRequest() + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("6")}, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "6", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineFlushWithError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 1/(3-n) from 
generate_series(1,10) n`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + var pgErr *pgconn.PgError + require.ErrorAs(t, readResult.Err, &pgErr) + require.Equal(t, "22012", pgErr.Code) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendPipelineSync() + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + pipeline.SendFlushRequest() + err = pipeline.Flush() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + results, err = pipeline.GetResults() + require.NoError(t, err) + rr, ok = results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult = rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "5", string(readResult.Rows[0][0])) + + results, err = 
pipeline.GetResults() + require.NoError(t, err) + require.Nil(t, results) + + err = pipeline.Sync() + require.NoError(t, err) + + results, err = pipeline.GetResults() + require.NoError(t, err) + _, ok = results.(*pgconn.PipelineSync) + require.Truef(t, ok, "expected PipelineSync, got: %#v", results) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineCloseReadsUnreadResults(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + err = pipeline.Close() + require.NoError(t, err) + + ensureConnValid(t, pgConn) +} + +func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, pgConn) + + pipeline := pgConn.StartPipeline(ctx) + pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil) + 
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil) + err = pipeline.Sync() + require.NoError(t, err) + + pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil) + pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil) + + results, err := pipeline.GetResults() + require.NoError(t, err) + rr, ok := results.(*pgconn.ResultReader) + require.Truef(t, ok, "expected ResultReader, got: %#v", results) + readResult := rr.Read() + require.NoError(t, readResult.Err) + require.Len(t, readResult.Rows, 1) + require.Len(t, readResult.Rows[0], 1) + require.Equal(t, "1", string(readResult.Rows[0][0])) + + err = pipeline.Close() + require.EqualError(t, err, "pipeline has unsynced requests") +} + +func TestConnOnPgError(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.OnPgError = func(c *pgconn.PgConn, pgErr *pgconn.PgError) bool { + require.NotNil(t, c) + require.NotNil(t, pgErr) + // close connection on undefined tables only + if pgErr.Code == "42P01" { + return false + } + return true + } + + pgConn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + _, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll() + assert.NoError(t, err) + assert.False(t, pgConn.IsClosed()) + + _, err = pgConn.Exec(ctx, "select 1/0").ReadAll() + assert.Error(t, err) + assert.False(t, pgConn.IsClosed()) + + _, err = pgConn.Exec(ctx, "select * from non_existant_table").ReadAll() + assert.Error(t, err) + assert.True(t, pgConn.IsClosed()) +} + +func TestConnCustomData(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer closeConn(t, 
pgConn) + + pgConn.CustomData()["foo"] = "bar" + assert.Equal(t, "bar", pgConn.CustomData()["foo"]) + + ensureConnValid(t, pgConn) +} + +func Example() { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() @@ -3363,9 +3850,9 @@ func TestSNISupport(t *testing.T) { return } - srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil)) - srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)) - srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)) + srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))) + srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))) + srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))) serverSNINameChan <- sniHost }() @@ -3377,9 +3864,9 @@ func TestSNISupport(t *testing.T) { select { case sniHost := <-serverSNINameChan: if tt.sni_set { - require.Equal(t, sniHost, "localhost") + require.Equal(t, "localhost", sniHost) } else { - require.Equal(t, sniHost, "") + require.Equal(t, "", sniHost) } case err = <-serverErrChan: t.Fatalf("server failed with error: %+v", err) @@ -3389,3 +3876,418 @@ func TestSNISupport(t *testing.T) { }) } } + +func TestConnectWithDirectSSLNegotiation(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + connString string + expectDirectNego bool + }{ + { + name: "Default negotiation (postgres)", + connString: "sslmode=require", + expectDirectNego: false, + }, + { + name: "Direct negotiation", + connString: "sslmode=require sslnegotiation=direct", + expectDirectNego: true, + }, + { + name: "Explicit postgres negotiation", + connString: "sslmode=require sslnegotiation=postgres", + expectDirectNego: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer 
ln.Close() + + _, port, err := net.SplitHostPort(ln.Addr().String()) + require.NoError(t, err) + + var directNegoObserved atomic.Bool + + serverErrCh := make(chan error, 1) + go func() { + defer close(serverErrCh) + + conn, err := ln.Accept() + if err != nil { + serverErrCh <- fmt.Errorf("accept error: %w", err) + return + } + defer conn.Close() + + conn.SetDeadline(time.Now().Add(5 * time.Second)) + + firstByte := make([]byte, 1) + _, err = conn.Read(firstByte) + if err != nil { + serverErrCh <- fmt.Errorf("read first byte error: %w", err) + return + } + + // Check if TLS Client Hello (direct) or PostgreSQL SSLRequest + isDirect := firstByte[0] >= 20 && firstByte[0] <= 23 + directNegoObserved.Store(isDirect) + + var tlsConn *tls.Conn + + if !isDirect { + // Handle standard PostgreSQL SSL negotiation + // Read the rest of the SSL request message + sslRequestRemainder := make([]byte, 7) + _, err = io.ReadFull(conn, sslRequestRemainder) + if err != nil { + serverErrCh <- fmt.Errorf("read ssl request remainder error: %w", err) + return + } + + // Send SSL acceptance response + _, err = conn.Write([]byte("S")) + if err != nil { + serverErrCh <- fmt.Errorf("write ssl acceptance error: %w", err) + return + } + + // Setup TLS server without needing to reuse the first byte + cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM)) + if err != nil { + serverErrCh <- fmt.Errorf("cert error: %w", err) + return + } + + tlsConn = tls.Server(conn, &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + } else { + // Handle direct TLS negotiation + // Setup TLS server with the first byte already read + cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM)) + if err != nil { + serverErrCh <- fmt.Errorf("cert error: %w", err) + return + } + + // Use a wrapper to inject the first byte back into the TLS handshake + bufConn := &prefixConn{ + Conn: conn, + prefixData: firstByte, + } + + tlsConn = tls.Server(bufConn, &tls.Config{ + Certificates: 
[]tls.Certificate{cert}, + }) + } + + // Complete TLS handshake + if err := tlsConn.Handshake(); err != nil { + serverErrCh <- fmt.Errorf("TLS handshake error: %w", err) + return + } + defer tlsConn.Close() + + err = script.Run(pgproto3.NewBackend(tlsConn, tlsConn)) + if err != nil { + serverErrCh <- fmt.Errorf("pgmock run error: %w", err) + return + } + }() + + connStr := fmt.Sprintf("%s host=localhost port=%s sslmode=require sslinsecure=1", + tt.connString, port) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + conn, err := pgconn.Connect(ctx, connStr) + + require.NoError(t, err) + + defer conn.Close(ctx) + + err = <-serverErrCh + require.NoError(t, err) + + require.Equal(t, tt.expectDirectNego, directNegoObserved.Load()) + }) + } +} + +// prefixConn implements a net.Conn that prepends some data to the first Read +type prefixConn struct { + net.Conn + prefixData []byte + prefixConsumed bool +} + +func (c *prefixConn) Read(b []byte) (n int, err error) { + if !c.prefixConsumed && len(c.prefixData) > 0 { + n = copy(b, c.prefixData) + c.prefixData = c.prefixData[n:] + c.prefixConsumed = len(c.prefixData) == 0 + return n, nil + } + return c.Conn.Read(b) +} + +// https://github.com/jackc/pgx/issues/1920 +func TestFatalErrorReceivedInPipelineMode(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + steps := pgmock.AcceptUnauthenticatedConnRequestSteps() + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{})) + steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{})) + steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: 
[]pgproto3.FieldDescription{ + {Name: []byte("mock")}, + }})) + steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) + // We shouldn't get anything after the first fatal error. But the reported issue was with PgBouncer so maybe that + // causes the issue. Anyway, a FATAL error after the connection had already been killed could cause a panic. + steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"})) + + script := &pgmock.Script{Steps: steps} + + ln, err := net.Listen("tcp", "127.0.0.1:") + require.NoError(t, err) + defer ln.Close() + + serverKeepAlive := make(chan struct{}) + defer close(serverKeepAlive) + + serverErrChan := make(chan error, 1) + go func() { + defer close(serverErrChan) + + conn, err := ln.Accept() + if err != nil { + serverErrChan <- err + return + } + defer conn.Close() + + err = conn.SetDeadline(time.Now().Add(59 * time.Second)) + if err != nil { + serverErrChan <- err + return + } + + err = script.Run(pgproto3.NewBackend(conn, conn)) + if err != nil { + serverErrChan <- err + return + } + + <-serverKeepAlive + }() + + parts := strings.Split(ln.Addr().String(), ":") + host := parts[0] + port := parts[1] + connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port) + + ctx, cancel = context.WithTimeout(ctx, 59*time.Second) + defer cancel() + conn, err := pgconn.Connect(ctx, connStr) + require.NoError(t, err) + + pipeline := conn.StartPipeline(ctx) + pipeline.SendPrepare("s1", "select 1", nil) + pipeline.SendPrepare("s2", "select 2", nil) + pipeline.SendPrepare("s3", "select 3", nil) + err = pipeline.Sync() + require.NoError(t, err) + + _, err = pipeline.GetResults() + require.NoError(t, err) + _, err = pipeline.GetResults() + require.Error(t, err) + + err = pipeline.Close() + require.Error(t, err) +} + +func mustEncode(buf []byte, err error) []byte { + if err != nil { + panic(err) + } + return buf +} + +func TestDeadlineContextWatcherHandler(t 
*testing.T) { + t.Parallel() + + t.Run("DeadlineExceeded with zero DeadlineDelay", func(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn()} + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = pgConn.Exec(ctx, "select 1, pg_sleep(1)").ReadAll() + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.True(t, pgConn.IsClosed()) + }) + + t.Run("DeadlineExceeded with DeadlineDelay", func(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn(), DeadlineDelay: 500 * time.Millisecond} + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err = pgConn.Exec(ctx, "select 1, pg_sleep(0.250)").ReadAll() + require.NoError(t, err) + + ensureConnValid(t, pgConn) + }) +} + +func TestCancelRequestContextWatcherHandler(t *testing.T) { + t.Parallel() + + t.Run("DeadlineExceeded cancels request after CancelRequestDelay", func(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.CancelRequestContextWatcherHandler{ + Conn: conn, + CancelRequestDelay: 250 * 
time.Millisecond, + DeadlineDelay: 5000 * time.Millisecond, + } + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = pgConn.Exec(ctx, "select 1, pg_sleep(3)").ReadAll() + require.Error(t, err) + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + + ensureConnValid(t, pgConn) + }) + + t.Run("DeadlineExceeded - do not send cancel request when query finishes in grace period", func(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.CancelRequestContextWatcherHandler{ + Conn: conn, + CancelRequestDelay: 1000 * time.Millisecond, + DeadlineDelay: 5000 * time.Millisecond, + } + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, err = pgConn.Exec(ctx, "select 1, pg_sleep(0.250)").ReadAll() + require.NoError(t, err) + + ensureConnValid(t, pgConn) + }) + + t.Run("DeadlineExceeded sets conn deadline with DeadlineDelay", func(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.CancelRequestContextWatcherHandler{ + Conn: conn, + CancelRequestDelay: 5000 * time.Millisecond, // purposely setting this higher than DeadlineDelay to ensure the cancel request never happens. 
+ DeadlineDelay: 250 * time.Millisecond, + } + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = pgConn.Exec(ctx, "select 1, pg_sleep(1)").ReadAll() + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.True(t, pgConn.IsClosed()) + }) + + for i := 0; i < 10; i++ { + t.Run(fmt.Sprintf("Stress %d", i), func(t *testing.T) { + t.Parallel() + + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler { + return &pgconn.CancelRequestContextWatcherHandler{ + Conn: conn, + CancelRequestDelay: 5 * time.Millisecond, + DeadlineDelay: 1000 * time.Millisecond, + } + } + config.ConnectTimeout = 5 * time.Second + + pgConn, err := pgconn.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer closeConn(t, pgConn) + + for i := 0; i < 20; i++ { + func() { + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond) + defer cancel() + pgConn.Exec(ctx, "select 1, pg_sleep(0.010)").ReadAll() + time.Sleep(100 * time.Millisecond) // ensure a cancel request that was a little late doesn't interrupt ensureConnValid. + ensureConnValid(t, pgConn) + }() + } + }) + } +} diff --git a/pgproto3/authentication_cleartext_password.go b/pgproto3/authentication_cleartext_password.go index e4a9b8b43..baae6538c 100644 --- a/pgproto3/authentication_cleartext_password.go +++ b/pgproto3/authentication_cleartext_password.go @@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 
-func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 8) +func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_gss.go b/pgproto3/authentication_gss.go index 583026efb..731d8a9d6 100644 --- a/pgproto3/authentication_gss.go +++ b/pgproto3/authentication_gss.go @@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error { return nil } -func (a *AuthenticationGSS) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 4) +func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeGSS) - return dst + return finishMessage(dst, sp) } func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/authentication_gss_continue.go b/pgproto3/authentication_gss_continue.go index 035974545..c22da48d8 100644 --- a/pgproto3/authentication_gss_continue.go +++ b/pgproto3/authentication_gss_continue.go @@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error { return nil } -func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, int32(len(a.Data))+8) +func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeGSSCont) dst = append(dst, a.Data...) 
- return dst + return finishMessage(dst, sp) } func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/authentication_md5_password.go b/pgproto3/authentication_md5_password.go index 7eeeafb6b..3747f680b 100644 --- a/pgproto3/authentication_md5_password.go +++ b/pgproto3/authentication_md5_password.go @@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 12) +func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeMD5Password) dst = append(dst, src.Salt[:]...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_ok.go b/pgproto3/authentication_ok.go index 6039204ad..881db3b8c 100644 --- a/pgproto3/authentication_ok.go +++ b/pgproto3/authentication_ok.go @@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationOk) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 8) +func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeOk) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_sasl.go b/pgproto3/authentication_sasl.go index bcc6c2ddb..7ddbebe59 100644 --- a/pgproto3/authentication_sasl.go +++ b/pgproto3/authentication_sasl.go @@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error { } // Encode encodes src into dst. 
dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASL) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASL) for _, s := range src.AuthMechanisms { @@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte { } dst = append(dst, 0) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_sasl_continue.go b/pgproto3/authentication_sasl_continue.go index 910af1536..a9f217018 100644 --- a/pgproto3/authentication_sasl_continue.go +++ b/pgproto3/authentication_sasl_continue.go @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) - dst = append(dst, src.Data...) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_sasl_final.go b/pgproto3/authentication_sasl_final.go index 4f3f917fd..6811ce442 100644 --- a/pgproto3/authentication_sasl_final.go +++ b/pgproto3/authentication_sasl_final.go @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 
-func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) - dst = append(dst, src.Data...) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Unmarshaler. diff --git a/pgproto3/backend.go b/pgproto3/backend.go index efa909c3a..28cff049a 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -16,7 +16,8 @@ type Backend struct { // before it is actually transmitted (i.e. before Flush). tracer *tracer - wbuf []byte + wbuf []byte + encodeError error // Frontend message flyweights bind Bind @@ -55,11 +56,21 @@ func NewBackend(r io.Reader, w io.Writer) *Backend { return &Backend{cr: cr, w: w} } -// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is -// called. +// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error +// encountered will be returned from Flush. func (b *Backend) Send(msg BackendMessage) { + if b.encodeError != nil { + return + } + prevLen := len(b.wbuf) - b.wbuf = msg.Encode(b.wbuf) + newBuf, err := msg.Encode(b.wbuf) + if err != nil { + b.encodeError = err + return + } + b.wbuf = newBuf + if b.tracer != nil { b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg) } @@ -67,6 +78,12 @@ func (b *Backend) Send(msg BackendMessage) { // Flush writes any pending messages to the frontend (i.e. the client). 
func (b *Backend) Flush() error { + if err := b.encodeError; err != nil { + b.encodeError = nil + b.wbuf = b.wbuf[:0] + return &writeError{err: err, safeToRetry: true} + } + n, err := b.w.Write(b.wbuf) const maxLen = 1024 @@ -158,7 +175,13 @@ func (b *Backend) Receive() (FrontendMessage, error) { } b.msgType = header[0] - b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 + + msgLength := int(binary.BigEndian.Uint32(header[1:])) + if msgLength < 4 { + return nil, fmt.Errorf("invalid message length: %d", msgLength) + } + + b.bodyLen = msgLength - 4 if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen { return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen} } @@ -265,9 +288,10 @@ func (b *Backend) SetAuthType(authType uint32) error { return nil } -// SetMaxBodyLen sets the maximum length of a message body in octets. If a message body exceeds this length, Receive will return -// an error. This is useful for protecting against malicious clients that send large messages with the intent of -// causing memory exhaustion. +// SetMaxBodyLen sets the maximum length of a message body in octets. +// If a message body exceeds this length, Receive will return an error. +// This is useful for protecting against malicious clients that send +// large messages with the intent of causing memory exhaustion. // The default value is 0. // If maxBodyLen is 0, then no maximum is enforced. func (b *Backend) SetMaxBodyLen(maxBodyLen int) { diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go index 704cda18a..b8bddfe27 100644 --- a/pgproto3/backend_key_data.go +++ b/pgproto3/backend_key_data.go @@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 
-func (src *BackendKeyData) Encode(dst []byte) []byte { - dst = append(dst, 'K') - dst = pgio.AppendUint32(dst, 12) +func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'K') dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.SecretKey) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go index 8147accd4..dd37ab7d8 100644 --- a/pgproto3/backend_test.go +++ b/pgproto3/backend_test.go @@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) { "username": "tester", }, } - dst := []byte{} - dst = want.Encode(dst) + dst, err := want.Encode([]byte{}) + require.NoError(t, err) server := &interruptReader{} server.push(dst) diff --git a/pgproto3/bind.go b/pgproto3/bind.go index db606cde5..1a53e70fa 100644 --- a/pgproto3/bind.go +++ b/pgproto3/bind.go @@ -108,21 +108,25 @@ func (dst *Bind) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Bind) Encode(dst []byte) []byte { - dst = append(dst, 'B') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *Bind) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'B') dst = append(dst, src.DestinationPortal...) dst = append(dst, 0) dst = append(dst, src.PreparedStatement...) 
dst = append(dst, 0) + if len(src.ParameterFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many parameter format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) for _, fc := range src.ParameterFormatCodes { dst = pgio.AppendInt16(dst, fc) } + if len(src.Parameters) > math.MaxUint16 { + return nil, errors.New("too many parameters") + } dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) for _, p := range src.Parameters { if p == nil { @@ -134,14 +138,15 @@ func (src *Bind) Encode(dst []byte) []byte { dst = append(dst, p...) } + if len(src.ResultFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many result format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) for _, fc := range src.ResultFormatCodes { dst = pgio.AppendInt16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/bind_complete.go b/pgproto3/bind_complete.go index 3be256c89..bacf30d88 100644 --- a/pgproto3/bind_complete.go +++ b/pgproto3/bind_complete.go @@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *BindComplete) Encode(dst []byte) []byte { - return append(dst, '2', 0, 0, 0, 4) +func (src *BindComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '2', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/bind_test.go b/pgproto3/bind_test.go new file mode 100644 index 000000000..6ec0e0245 --- /dev/null +++ b/pgproto3/bind_test.go @@ -0,0 +1,20 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) { + t.Parallel() + + // Maximum allowed size. 
+ _, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil) + require.NoError(t, err) + + // 1 byte too big + _, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil) + require.Error(t, err) +} diff --git a/pgproto3/cancel_request.go b/pgproto3/cancel_request.go index c3e054037..074b0b246 100644 --- a/pgproto3/cancel_request.go +++ b/pgproto3/cancel_request.go @@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *CancelRequest) Encode(dst []byte) []byte { +func (src *CancelRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 16) dst = pgio.AppendInt32(dst, cancelRequestCode) dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.SecretKey) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/close.go b/pgproto3/close.go index ce60b6ca5..0b50f27cb 100644 --- a/pgproto3/close.go +++ b/pgproto3/close.go @@ -4,8 +4,6 @@ import ( "bytes" "encoding/json" "errors" - - "github.com/yugabyte/pgx/v5/internal/pgio" ) type Close struct { @@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Close) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Close) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'C') dst = append(dst, src.ObjectType) dst = append(dst, src.Name...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. 
diff --git a/pgproto3/close_complete.go b/pgproto3/close_complete.go index 1d7b8f085..833f7a12c 100644 --- a/pgproto3/close_complete.go +++ b/pgproto3/close_complete.go @@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CloseComplete) Encode(dst []byte) []byte { - return append(dst, '3', 0, 0, 0, 4) +func (src *CloseComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '3', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go index a55fa3024..eba70947d 100644 --- a/pgproto3/command_complete.go +++ b/pgproto3/command_complete.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/yugabyte/pgx/v5/internal/pgio" ) type CommandComplete struct { @@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CommandComplete) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *CommandComplete) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'C') dst = append(dst, src.CommandTag...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go index 5a17e910e..05dc1f2ac 100644 --- a/pgproto3/copy_both_response.go +++ b/pgproto3/copy_both_response.go @@ -44,19 +44,18 @@ func (dst *CopyBothResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 
-func (src *CopyBothResponse) Encode(dst []byte) []byte { - dst = append(dst, 'W') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'W') dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_both_response_test.go b/pgproto3/copy_both_response_test.go index f9cb40242..0d327a56a 100644 --- a/pgproto3/copy_both_response_test.go +++ b/pgproto3/copy_both_response_test.go @@ -3,8 +3,9 @@ package pgproto3_test import ( "testing" - "github.com/stretchr/testify/assert" "github.com/yugabyte/pgx/v5/pgproto3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestEncodeDecode(t *testing.T) { @@ -13,6 +14,7 @@ func TestEncodeDecode(t *testing.T) { err := dstResp.Decode(srcBytes[5:]) assert.NoError(t, err, "No errors on decode") dstBytes := []byte{} - dstBytes = dstResp.Encode(dstBytes) + dstBytes, err = dstResp.Encode(dstBytes) + require.NoError(t, err) assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match") } diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go index 6d236cfd4..89ecdd4dd 100644 --- a/pgproto3/copy_data.go +++ b/pgproto3/copy_data.go @@ -3,8 +3,6 @@ package pgproto3 import ( "encoding/hex" "encoding/json" - - "github.com/yugabyte/pgx/v5/internal/pgio" ) type CopyData struct { @@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 
-func (src *CopyData) Encode(dst []byte) []byte { - dst = append(dst, 'd') - dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) +func (src *CopyData) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'd') dst = append(dst, src.Data...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_done.go b/pgproto3/copy_done.go index 0e13282bf..c3421a9b5 100644 --- a/pgproto3/copy_done.go +++ b/pgproto3/copy_done.go @@ -4,8 +4,7 @@ import ( "encoding/json" ) -type CopyDone struct { -} +type CopyDone struct{} // Backend identifies this message as sendable by the PostgreSQL backend. func (*CopyDone) Backend() {} @@ -24,8 +23,8 @@ func (dst *CopyDone) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyDone) Encode(dst []byte) []byte { - return append(dst, 'c', 0, 0, 0, 4) +func (src *CopyDone) Encode(dst []byte) ([]byte, error) { + return append(dst, 'c', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_fail.go b/pgproto3/copy_fail.go index 868ba950c..72a85fd09 100644 --- a/pgproto3/copy_fail.go +++ b/pgproto3/copy_fail.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/yugabyte/pgx/v5/internal/pgio" ) type CopyFail struct { @@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyFail) Encode(dst []byte) []byte { - dst = append(dst, 'f') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *CopyFail) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'f') dst = append(dst, src.Message...) 
dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go index 2d462beaa..d0601f1eb 100644 --- a/pgproto3/copy_in_response.go +++ b/pgproto3/copy_in_response.go @@ -44,20 +44,19 @@ func (dst *CopyInResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyInResponse) Encode(dst []byte) []byte { - dst = append(dst, 'G') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'G') dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go index c88de1fb4..6851bc817 100644 --- a/pgproto3/copy_out_response.go +++ b/pgproto3/copy_out_response.go @@ -43,21 +43,20 @@ func (dst *CopyOutResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 
-func (src *CopyOutResponse) Encode(dst []byte) []byte { - dst = append(dst, 'H') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'H') dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go index 7448f3e35..c1202226c 100644 --- a/pgproto3/data_row.go +++ b/pgproto3/data_row.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" + "math" "github.com/yugabyte/pgx/v5/internal/pgio" ) @@ -63,11 +65,12 @@ func (dst *DataRow) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *DataRow) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *DataRow) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'D') + if len(src.Values) > math.MaxUint16 { + return nil, errors.New("too many values") + } dst = pgio.AppendUint16(dst, uint16(len(src.Values))) for _, v := range src.Values { if v == nil { @@ -79,9 +82,7 @@ func (src *DataRow) Encode(dst []byte) []byte { dst = append(dst, v...) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. 
diff --git a/pgproto3/describe.go b/pgproto3/describe.go index 7ab567773..89feff215 100644 --- a/pgproto3/describe.go +++ b/pgproto3/describe.go @@ -4,8 +4,6 @@ import ( "bytes" "encoding/json" "errors" - - "github.com/yugabyte/pgx/v5/internal/pgio" ) type Describe struct { @@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Describe) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Describe) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'D') dst = append(dst, src.ObjectType) dst = append(dst, src.Name...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/empty_query_response.go b/pgproto3/empty_query_response.go index 2b85e744b..cb6cca073 100644 --- a/pgproto3/empty_query_response.go +++ b/pgproto3/empty_query_response.go @@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *EmptyQueryResponse) Encode(dst []byte) []byte { - return append(dst, 'I', 0, 0, 0, 4) +func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) { + return append(dst, 'I', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go index 45c9a9810..6ef9bd061 100644 --- a/pgproto3/error_response.go +++ b/pgproto3/error_response.go @@ -2,7 +2,6 @@ package pgproto3 import ( "bytes" - "encoding/binary" "encoding/json" "strconv" ) @@ -111,119 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error { } // Encode encodes src into dst. 
dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ErrorResponse) Encode(dst []byte) []byte { - return append(dst, src.marshalBinary('E')...) +func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'E') + dst = src.appendFields(dst) + return finishMessage(dst, sp) } -func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte(typeByte) - buf.Write(bigEndian.Uint32(0)) - +func (src *ErrorResponse) appendFields(dst []byte) []byte { if src.Severity != "" { - buf.WriteByte('S') - buf.WriteString(src.Severity) - buf.WriteByte(0) + dst = append(dst, 'S') + dst = append(dst, src.Severity...) + dst = append(dst, 0) } if src.SeverityUnlocalized != "" { - buf.WriteByte('V') - buf.WriteString(src.SeverityUnlocalized) - buf.WriteByte(0) + dst = append(dst, 'V') + dst = append(dst, src.SeverityUnlocalized...) + dst = append(dst, 0) } if src.Code != "" { - buf.WriteByte('C') - buf.WriteString(src.Code) - buf.WriteByte(0) + dst = append(dst, 'C') + dst = append(dst, src.Code...) + dst = append(dst, 0) } if src.Message != "" { - buf.WriteByte('M') - buf.WriteString(src.Message) - buf.WriteByte(0) + dst = append(dst, 'M') + dst = append(dst, src.Message...) + dst = append(dst, 0) } if src.Detail != "" { - buf.WriteByte('D') - buf.WriteString(src.Detail) - buf.WriteByte(0) + dst = append(dst, 'D') + dst = append(dst, src.Detail...) + dst = append(dst, 0) } if src.Hint != "" { - buf.WriteByte('H') - buf.WriteString(src.Hint) - buf.WriteByte(0) + dst = append(dst, 'H') + dst = append(dst, src.Hint...) + dst = append(dst, 0) } if src.Position != 0 { - buf.WriteByte('P') - buf.WriteString(strconv.Itoa(int(src.Position))) - buf.WriteByte(0) + dst = append(dst, 'P') + dst = append(dst, strconv.Itoa(int(src.Position))...) 
+ dst = append(dst, 0) } if src.InternalPosition != 0 { - buf.WriteByte('p') - buf.WriteString(strconv.Itoa(int(src.InternalPosition))) - buf.WriteByte(0) + dst = append(dst, 'p') + dst = append(dst, strconv.Itoa(int(src.InternalPosition))...) + dst = append(dst, 0) } if src.InternalQuery != "" { - buf.WriteByte('q') - buf.WriteString(src.InternalQuery) - buf.WriteByte(0) + dst = append(dst, 'q') + dst = append(dst, src.InternalQuery...) + dst = append(dst, 0) } if src.Where != "" { - buf.WriteByte('W') - buf.WriteString(src.Where) - buf.WriteByte(0) + dst = append(dst, 'W') + dst = append(dst, src.Where...) + dst = append(dst, 0) } if src.SchemaName != "" { - buf.WriteByte('s') - buf.WriteString(src.SchemaName) - buf.WriteByte(0) + dst = append(dst, 's') + dst = append(dst, src.SchemaName...) + dst = append(dst, 0) } if src.TableName != "" { - buf.WriteByte('t') - buf.WriteString(src.TableName) - buf.WriteByte(0) + dst = append(dst, 't') + dst = append(dst, src.TableName...) + dst = append(dst, 0) } if src.ColumnName != "" { - buf.WriteByte('c') - buf.WriteString(src.ColumnName) - buf.WriteByte(0) + dst = append(dst, 'c') + dst = append(dst, src.ColumnName...) + dst = append(dst, 0) } if src.DataTypeName != "" { - buf.WriteByte('d') - buf.WriteString(src.DataTypeName) - buf.WriteByte(0) + dst = append(dst, 'd') + dst = append(dst, src.DataTypeName...) + dst = append(dst, 0) } if src.ConstraintName != "" { - buf.WriteByte('n') - buf.WriteString(src.ConstraintName) - buf.WriteByte(0) + dst = append(dst, 'n') + dst = append(dst, src.ConstraintName...) + dst = append(dst, 0) } if src.File != "" { - buf.WriteByte('F') - buf.WriteString(src.File) - buf.WriteByte(0) + dst = append(dst, 'F') + dst = append(dst, src.File...) + dst = append(dst, 0) } if src.Line != 0 { - buf.WriteByte('L') - buf.WriteString(strconv.Itoa(int(src.Line))) - buf.WriteByte(0) + dst = append(dst, 'L') + dst = append(dst, strconv.Itoa(int(src.Line))...) 
+ dst = append(dst, 0) } if src.Routine != "" { - buf.WriteByte('R') - buf.WriteString(src.Routine) - buf.WriteByte(0) + dst = append(dst, 'R') + dst = append(dst, src.Routine...) + dst = append(dst, 0) } for k, v := range src.UnknownFields { - buf.WriteByte(k) - buf.WriteString(v) - buf.WriteByte(0) + dst = append(dst, k) + dst = append(dst, v...) + dst = append(dst, 0) } - buf.WriteByte(0) - - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + dst = append(dst, 0) - return buf.Bytes() + return dst } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/example/pgfortune/server.go b/pgproto3/example/pgfortune/server.go index 6be76a726..47c9ac5b5 100644 --- a/pgproto3/example/pgfortune/server.go +++ b/pgproto3/example/pgfortune/server.go @@ -46,7 +46,7 @@ func (p *PgFortuneBackend) Run() error { return fmt.Errorf("error generating query response: %w", err) } - buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ { Name: []byte("fortune"), TableOID: 0, @@ -56,10 +56,10 @@ func (p *PgFortuneBackend) Run() error { TypeModifier: -1, Format: 0, }, - }}).Encode(nil) - buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf) - buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf) - buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) + }}).Encode(nil)) + buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)) + buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)) + buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)) _, err = p.conn.Write(buf) if err != nil { return fmt.Errorf("error writing query response: %w", err) @@ -80,8 +80,8 @@ func (p *PgFortuneBackend) handleStartup() error { switch startupMessage.(type) { case *pgproto3.StartupMessage: - buf := (&pgproto3.AuthenticationOk{}).Encode(nil) - buf = 
(&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) + buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)) + buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)) _, err = p.conn.Write(buf) if err != nil { return fmt.Errorf("error sending ready for query: %w", err) @@ -102,3 +102,10 @@ func (p *PgFortuneBackend) handleStartup() error { func (p *PgFortuneBackend) Close() error { return p.conn.Close() } + +func mustEncode(buf []byte, err error) []byte { + if err != nil { + panic(err) + } + return buf +} diff --git a/pgproto3/execute.go b/pgproto3/execute.go index 2d273f3c3..c73cbefa4 100644 --- a/pgproto3/execute.go +++ b/pgproto3/execute.go @@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Execute) Encode(dst []byte) []byte { - dst = append(dst, 'E') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Execute) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'E') dst = append(dst, src.Portal...) dst = append(dst, 0) - dst = pgio.AppendUint32(dst, src.MaxRows) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/flush.go b/pgproto3/flush.go index 2725f6894..e5dc1fbbd 100644 --- a/pgproto3/flush.go +++ b/pgproto3/flush.go @@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Flush) Encode(dst []byte) []byte { - return append(dst, 'H', 0, 0, 0, 4) +func (src *Flush) Encode(dst []byte) ([]byte, error) { + return append(dst, 'H', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. 
diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 60c34ef02..056e547cd 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -18,7 +18,8 @@ type Frontend struct { // idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq. tracer *tracer - wbuf []byte + wbuf []byte + encodeError error // Backend message flyweights authenticationOk AuthenticationOk @@ -53,6 +54,7 @@ type Frontend struct { portalSuspended PortalSuspended bodyLen int + maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error. msgType byte partialMsg bool authType uint32 @@ -64,16 +66,26 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend { return &Frontend{cr: cr, w: w} } -// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is -// called. +// Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error +// encountered will be returned from Flush. // // Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods // such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an // extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden // behind an interface. func (f *Frontend) Send(msg FrontendMessage) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg) } @@ -81,6 +93,12 @@ func (f *Frontend) Send(msg FrontendMessage) { // Flush writes any pending messages to the backend (i.e. the server). 
func (f *Frontend) Flush() error { + if err := f.encodeError; err != nil { + f.encodeError = nil + f.wbuf = f.wbuf[:0] + return &writeError{err: err, safeToRetry: true} + } + if len(f.wbuf) == 0 { return nil } @@ -116,71 +134,141 @@ func (f *Frontend) Untrace() { f.tracer = nil } -// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendBind(msg *Bind) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendParse(msg *Parse) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. 
func (f *Frontend) SendClose(msg *Close) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is +// called. Any error encountered will be returned from Flush. func (f *Frontend) SendDescribe(msg *Describe) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendExecute sends an Execute message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called. +// Any error encountered will be returned from Flush. func (f *Frontend) SendExecute(msg *Execute) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. 
func (f *Frontend) SendSync(msg *Sync) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendQuery(msg *Query) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg) } @@ -230,6 +318,9 @@ func (f *Frontend) Receive() (BackendMessage, error) { } f.bodyLen = msgLength - 4 + if f.maxBodyLen > 0 && f.bodyLen > f.maxBodyLen { + return nil, &ExceededMaxBodyLenErr{f.maxBodyLen, f.bodyLen} + } f.partialMsg = true } @@ -365,3 +456,13 @@ func (f *Frontend) GetAuthType() uint32 { func (f *Frontend) ReadBufferLen() int { return f.cr.wp - f.cr.rp } + +// SetMaxBodyLen sets the maximum length of a message body in octets. +// If a message body exceeds this length, Receive will return an error. +// This is useful for protecting against a corrupted server that sends +// messages with incorrect length, which can cause memory exhaustion. +// The default value is 0. +// If maxBodyLen is 0, then no maximum is enforced. 
+func (f *Frontend) SetMaxBodyLen(maxBodyLen int) { + f.maxBodyLen = maxBodyLen +} diff --git a/pgproto3/frontend_test.go b/pgproto3/frontend_test.go index 3468fe406..2c41424af 100644 --- a/pgproto3/frontend_test.go +++ b/pgproto3/frontend_test.go @@ -115,3 +115,21 @@ func TestErrorResponse(t *testing.T) { require.NoError(t, err) assert.Equal(t, want, got) } + +func TestFrontendReceiveExceededMaxBodyLen(t *testing.T) { + t.Parallel() + + client := &interruptReader{} + client.push([]byte{'D', 0, 0, 10, 10}) + + frontend := pgproto3.NewFrontend(client, nil) + + // Set max body len to 5 + frontend.SetMaxBodyLen(5) + + // Receive regular msg + msg, err := frontend.Receive() + assert.Nil(t, msg) + var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr + assert.ErrorAs(t, err, &invalidBodyLenErr) +} diff --git a/pgproto3/function_call.go b/pgproto3/function_call.go index f4f95d928..856b53666 100644 --- a/pgproto3/function_call.go +++ b/pgproto3/function_call.go @@ -2,6 +2,8 @@ package pgproto3 import ( "encoding/binary" + "errors" + "math" "github.com/yugabyte/pgx/v5/internal/pgio" ) @@ -71,15 +73,21 @@ func (dst *FunctionCall) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 
-func (src *FunctionCall) Encode(dst []byte) []byte { - dst = append(dst, 'F') - sp := len(dst) - dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end +func (src *FunctionCall) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'F') dst = pgio.AppendUint32(dst, src.Function) + + if len(src.ArgFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many arg format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) for _, argFormatCode := range src.ArgFormatCodes { dst = pgio.AppendUint16(dst, argFormatCode) } + + if len(src.Arguments) > math.MaxUint16 { + return nil, errors.New("too many arguments") + } dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) for _, argument := range src.Arguments { if argument == nil { @@ -90,6 +98,5 @@ func (src *FunctionCall) Encode(dst []byte) []byte { } } dst = pgio.AppendUint16(dst, src.ResultFormatCode) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return dst + return finishMessage(dst, sp) } diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go index 13e93dc1a..6411ec033 100644 --- a/pgproto3/function_call_response.go +++ b/pgproto3/function_call_response.go @@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *FunctionCallResponse) Encode(dst []byte) []byte { - dst = append(dst, 'V') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'V') if src.Result == nil { dst = pgio.AppendInt32(dst, -1) @@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte { dst = append(dst, src.Result...) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. 
diff --git a/pgproto3/function_call_test.go b/pgproto3/function_call_test.go index 8c08bb240..2a70fd308 100644 --- a/pgproto3/function_call_test.go +++ b/pgproto3/function_call_test.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "reflect" "testing" + + "github.com/stretchr/testify/require" ) func TestFunctionCall_EncodeDecode(t *testing.T) { @@ -30,7 +32,8 @@ func TestFunctionCall_EncodeDecode(t *testing.T) { Arguments: tt.fields.Arguments, ResultFormatCode: tt.fields.ResultFormatCode, } - encoded := src.Encode([]byte{}) + encoded, err := src.Encode([]byte{}) + require.NoError(t, err) dst := &FunctionCall{} // Check the header msgTypeCode := encoded[0] @@ -44,7 +47,7 @@ func TestFunctionCall_EncodeDecode(t *testing.T) { t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded)) } // Check decoding works as expected - err := dst.Decode(encoded[5:]) + err = dst.Decode(encoded[5:]) if err != nil { if !tt.wantErr { t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) diff --git a/pgproto3/gss_enc_request.go b/pgproto3/gss_enc_request.go index 641f485cc..ba6b002d3 100644 --- a/pgproto3/gss_enc_request.go +++ b/pgproto3/gss_enc_request.go @@ -10,8 +10,7 @@ import ( const gssEncReqNumber = 80877104 -type GSSEncRequest struct { -} +type GSSEncRequest struct{} // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*GSSEncRequest) Frontend() {} @@ -31,10 +30,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *GSSEncRequest) Encode(dst []byte) []byte { +func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 8) dst = pgio.AppendInt32(dst, gssEncReqNumber) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. 
diff --git a/pgproto3/gss_response.go b/pgproto3/gss_response.go index 6d79045b1..10d937759 100644 --- a/pgproto3/gss_response.go +++ b/pgproto3/gss_response.go @@ -2,8 +2,6 @@ package pgproto3 import ( "encoding/json" - - "github.com/yugabyte/pgx/v5/internal/pgio" ) type GSSResponse struct { @@ -18,11 +16,10 @@ func (g *GSSResponse) Decode(data []byte) error { return nil } -func (g *GSSResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(g.Data))) +func (g *GSSResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, g.Data...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/json_test.go b/pgproto3/json_test.go index 8fad4f882..677221249 100644 --- a/pgproto3/json_test.go +++ b/pgproto3/json_test.go @@ -332,7 +332,7 @@ func TestJSONUnmarshalRowDescription(t *testing.T) { } func TestJSONUnmarshalBind(t *testing.T) { - var testCases = []struct { + testCases := []struct { desc string data []byte }{ @@ -348,7 +348,7 @@ func TestJSONUnmarshalBind(t *testing.T) { for _, tc := range testCases { t.Run(tc.desc, func(t *testing.T) { - var want = Bind{ + want := Bind{ PreparedStatement: "lrupsc_1_0", ParameterFormatCodes: []int16{0}, Parameters: [][]byte{[]byte("ABC-123")}, diff --git a/pgproto3/no_data.go b/pgproto3/no_data.go index d8f85d38a..cbcaad40c 100644 --- a/pgproto3/no_data.go +++ b/pgproto3/no_data.go @@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NoData) Encode(dst []byte) []byte { - return append(dst, 'n', 0, 0, 0, 4) +func (src *NoData) Encode(dst []byte) ([]byte, error) { + return append(dst, 'n', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. 
diff --git a/pgproto3/notice_response.go b/pgproto3/notice_response.go index 4ac28a791..497aba6dd 100644 --- a/pgproto3/notice_response.go +++ b/pgproto3/notice_response.go @@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NoticeResponse) Encode(dst []byte) []byte { - return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) +func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'N') + dst = (*ErrorResponse)(src).appendFields(dst) + return finishMessage(dst, sp) } diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go index a91ca868b..62e31b0d3 100644 --- a/pgproto3/notification_response.go +++ b/pgproto3/notification_response.go @@ -45,20 +45,14 @@ func (dst *NotificationResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NotificationResponse) Encode(dst []byte) []byte { - dst = append(dst, 'A') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'A') dst = pgio.AppendUint32(dst, src.PID) dst = append(dst, src.Channel...) dst = append(dst, 0) dst = append(dst, src.Payload...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. 
diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go index c660cc13e..29f493a80 100644 --- a/pgproto3/parameter_description.go +++ b/pgproto3/parameter_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/yugabyte/pgx/v5/internal/pgio" ) @@ -39,19 +41,18 @@ func (dst *ParameterDescription) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParameterDescription) Encode(dst []byte) []byte { - dst = append(dst, 't') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 't') + if len(src.ParameterOIDs) > math.MaxUint16 { + return nil, errors.New("too many parameter oids") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go index 057719e08..9ee0720b5 100644 --- a/pgproto3/parameter_status.go +++ b/pgproto3/parameter_status.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/yugabyte/pgx/v5/internal/pgio" ) type ParameterStatus struct { @@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParameterStatus) Encode(dst []byte) []byte { - dst = append(dst, 'S') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'S') dst = append(dst, src.Name...) dst = append(dst, 0) dst = append(dst, src.Value...) 
dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/parse.go b/pgproto3/parse.go index 5748ad08b..e921feef6 100644 --- a/pgproto3/parse.go +++ b/pgproto3/parse.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/yugabyte/pgx/v5/internal/pgio" ) @@ -52,27 +54,25 @@ func (dst *Parse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Parse) Encode(dst []byte) []byte { - dst = append(dst, 'P') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *Parse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'P') dst = append(dst, src.Name...) dst = append(dst, 0) dst = append(dst, src.Query...) dst = append(dst, 0) + if len(src.ParameterOIDs) > math.MaxUint16 { + return nil, errors.New("too many parameter oids") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) } +<<<<<<< HEAD pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return dst -} - -// MarshalJSON implements encoding/json.Marshaler. func (src Parse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/pgproto3/parse_complete.go b/pgproto3/parse_complete.go index 92c9498b6..cff9e27d0 100644 --- a/pgproto3/parse_complete.go +++ b/pgproto3/parse_complete.go @@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. 
-func (src *ParseComplete) Encode(dst []byte) []byte { - return append(dst, '1', 0, 0, 0, 4) +func (src *ParseComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '1', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/password_message.go b/pgproto3/password_message.go index e935f9e41..67b78515d 100644 --- a/pgproto3/password_message.go +++ b/pgproto3/password_message.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/yugabyte/pgx/v5/internal/pgio" ) type PasswordMessage struct { @@ -14,7 +12,7 @@ type PasswordMessage struct { // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*PasswordMessage) Frontend() {} -// Frontend identifies this message as an authentication response. +// InitialResponse identifies this message as an authentication response. func (*PasswordMessage) InitialResponse() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message @@ -32,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *PasswordMessage) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) - +func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, src.Password...) dst = append(dst, 0) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go index 8df383c2c..b72ca39b6 100644 --- a/pgproto3/pgproto3.go +++ b/pgproto3/pgproto3.go @@ -4,8 +4,14 @@ import ( "encoding/hex" "errors" "fmt" + + "github.com/yugabyte/pgx/v5/internal/pgio" ) +// maxMessageBodyLen is the maximum length of a message body in bytes. 
See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL +// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff. +const maxMessageBodyLen = (0x3fffffff - 1) + // Message is the interface implemented by an object that can decode and encode // a particular PostgreSQL message. type Message interface { @@ -14,7 +20,7 @@ type Message interface { Decode(data []byte) error // Encode appends itself to dst and returns the new buffer. - Encode(dst []byte) []byte + Encode(dst []byte) ([]byte, error) } // FrontendMessage is a message sent by the frontend (i.e. the client). @@ -92,3 +98,23 @@ func getValueFromJSON(v map[string]string) ([]byte, error) { } return nil, errors.New("unknown protocol representation") } + +// beginMessage begins a new message of type t. It appends the message type and a placeholder for the message length to +// dst. It returns the new buffer and the position of the message length placeholder. +func beginMessage(dst []byte, t byte) ([]byte, int) { + dst = append(dst, t) + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + return dst, sp +} + +// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to +// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer. 
+func finishMessage(dst []byte, sp int) ([]byte, error) { + messageBodyLen := len(dst[sp:]) + if messageBodyLen > maxMessageBodyLen { + return nil, errors.New("message body too large") + } + pgio.SetInt32(dst[sp:], int32(messageBodyLen)) + return dst, nil +} diff --git a/pgproto3/pgproto3_private_test.go b/pgproto3/pgproto3_private_test.go new file mode 100644 index 000000000..15da1eafb --- /dev/null +++ b/pgproto3/pgproto3_private_test.go @@ -0,0 +1,3 @@ +package pgproto3 + +const MaxMessageBodyLen = maxMessageBodyLen diff --git a/pgproto3/portal_suspended.go b/pgproto3/portal_suspended.go index 1a9e7bfb1..9e2f8cbc4 100644 --- a/pgproto3/portal_suspended.go +++ b/pgproto3/portal_suspended.go @@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *PortalSuspended) Encode(dst []byte) []byte { - return append(dst, 's', 0, 0, 0, 4) +func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) { + return append(dst, 's', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/query.go b/pgproto3/query.go index b805f59e3..aebdfde89 100644 --- a/pgproto3/query.go +++ b/pgproto3/query.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/yugabyte/pgx/v5/internal/pgio" ) type Query struct { @@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Query) Encode(dst []byte) []byte { - dst = append(dst, 'Q') - dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) - +func (src *Query) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'Q') dst = append(dst, src.String...) dst = append(dst, 0) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. 
diff --git a/pgproto3/query_test.go b/pgproto3/query_test.go new file mode 100644 index 000000000..9551fc14d --- /dev/null +++ b/pgproto3/query_test.go @@ -0,0 +1,20 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func TestQueryBiggerThanMaxMessageBodyLen(t *testing.T) { + t.Parallel() + + // Maximum allowed size. 4 bytes for size and 1 byte for 0 terminated string. + _, err := (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-5))}).Encode(nil) + require.NoError(t, err) + + // 1 byte too big + _, err = (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-4))}).Encode(nil) + require.Error(t, err) +} diff --git a/pgproto3/ready_for_query.go b/pgproto3/ready_for_query.go index 67a39be39..a56af9fb2 100644 --- a/pgproto3/ready_for_query.go +++ b/pgproto3/ready_for_query.go @@ -25,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ReadyForQuery) Encode(dst []byte) []byte { - return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) +func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) { + return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go index 680476485..bfbbf0f82 100644 --- a/pgproto3/row_description.go +++ b/pgproto3/row_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/yugabyte/pgx/v5/internal/pgio" ) @@ -54,7 +56,6 @@ func (*RowDescription) Backend() {} // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // type identifier and 4 byte message length. 
func (dst *RowDescription) Decode(src []byte) error { - if len(src) < 2 { return &invalidMessageFormatErr{messageType: "RowDescription"} } @@ -99,11 +100,12 @@ func (dst *RowDescription) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *RowDescription) Encode(dst []byte) []byte { - dst = append(dst, 'T') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *RowDescription) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'T') + if len(src.Fields) > math.MaxUint16 { + return nil, errors.New("too many fields") + } dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) for _, fd := range src.Fields { dst = append(dst, fd.Name...) @@ -117,9 +119,7 @@ func (src *RowDescription) Encode(dst []byte) []byte { dst = pgio.AppendInt16(dst, fd.Format) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/sasl_initial_response.go b/pgproto3/sasl_initial_response.go index 2eef83b44..9eb1b6a4b 100644 --- a/pgproto3/sasl_initial_response.go +++ b/pgproto3/sasl_initial_response.go @@ -6,7 +6,7 @@ import ( "encoding/json" "errors" - "github.com/yugabyte/pgx/v5/internal/pgio" + "github.com/jackc/pgx/v5/internal/pgio" ) type SASLInitialResponse struct { @@ -39,10 +39,8 @@ func (dst *SASLInitialResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *SASLInitialResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, []byte(src.AuthMechanism)...) 
dst = append(dst, 0) @@ -50,9 +48,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte { dst = pgio.AppendInt32(dst, int32(len(src.Data))) dst = append(dst, src.Data...) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/sasl_response.go b/pgproto3/sasl_response.go index 3cc055720..1b604c254 100644 --- a/pgproto3/sasl_response.go +++ b/pgproto3/sasl_response.go @@ -3,8 +3,6 @@ package pgproto3 import ( "encoding/hex" "encoding/json" - - "github.com/yugabyte/pgx/v5/internal/pgio" ) type SASLResponse struct { @@ -22,13 +20,10 @@ func (dst *SASLResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *SASLResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) - +func (src *SASLResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, src.Data...) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/ssl_request.go b/pgproto3/ssl_request.go index 10794d7e6..95a114f69 100644 --- a/pgproto3/ssl_request.go +++ b/pgproto3/ssl_request.go @@ -10,8 +10,7 @@ import ( const sslRequestNumber = 80877103 -type SSLRequest struct { -} +type SSLRequest struct{} // Frontend identifies this message as sendable by a PostgreSQL frontend. func (*SSLRequest) Frontend() {} @@ -31,10 +30,10 @@ func (dst *SSLRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *SSLRequest) Encode(dst []byte) []byte { +func (src *SSLRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 8) dst = pgio.AppendInt32(dst, sslRequestNumber) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. 
diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go index 06e808789..c141b27f4 100644 --- a/pgproto3/startup_message.go +++ b/pgproto3/startup_message.go @@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *StartupMessage) Encode(dst []byte) []byte { +func (src *StartupMessage) Encode(dst []byte) ([]byte, error) { sp := len(dst) dst = pgio.AppendInt32(dst, -1) @@ -77,13 +77,11 @@ func (src *StartupMessage) Encode(dst []byte) []byte { } dst = append(dst, 0) +<<<<<<< HEAD pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) return dst -} - // MarshalJSON implements encoding/json.Marshaler. -func (src StartupMessage) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ProtocolVersion uint32 diff --git a/pgproto3/sync.go b/pgproto3/sync.go index 5db8e07ac..ea4fc9594 100644 --- a/pgproto3/sync.go +++ b/pgproto3/sync.go @@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Sync) Encode(dst []byte) []byte { - return append(dst, 'S', 0, 0, 0, 4) +func (src *Sync) Encode(dst []byte) ([]byte, error) { + return append(dst, 'S', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/terminate.go b/pgproto3/terminate.go index 135191eae..35a9dc837 100644 --- a/pgproto3/terminate.go +++ b/pgproto3/terminate.go @@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Terminate) Encode(dst []byte) []byte { - return append(dst, 'X', 0, 0, 0, 4) +func (src *Terminate) Encode(dst []byte) ([]byte, error) { + return append(dst, 'X', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. 
diff --git a/pgtype/array.go b/pgtype/array.go index f848ec143..916c0adb2 100644 --- a/pgtype/array.go +++ b/pgtype/array.go @@ -374,8 +374,8 @@ func quoteArrayElementIfNeeded(src string) string { return src } -// Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves -// PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed. +// Array represents a PostgreSQL array for T. It implements the [ArrayGetter] and [ArraySetter] interfaces. It preserves +// PostgreSQL dimensions and custom lower bounds. Use [FlatArray] if these are not needed. type Array[T any] struct { Elements []T Dims []ArrayDimension @@ -419,8 +419,8 @@ func (a Array[T]) ScanIndexType() any { return new(T) } -// FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions -// and custom lower bounds. Use Array to preserve these. +// FlatArray implements the [ArrayGetter] and [ArraySetter] interfaces for any slice of T. It ignores PostgreSQL dimensions +// and custom lower bounds. Use [Array] to preserve these. type FlatArray[T any] []T func (a FlatArray[T]) Dimensions() []ArrayDimension { diff --git a/pgtype/array_codec.go b/pgtype/array_codec.go index a714b9757..b06bf8f88 100644 --- a/pgtype/array_codec.go +++ b/pgtype/array_codec.go @@ -6,7 +6,6 @@ import ( "fmt" "reflect" - "github.com/yugabyte/pgx/v5/internal/anynil" "github.com/yugabyte/pgx/v5/internal/pgio" ) @@ -230,7 +229,7 @@ func (c *ArrayCodec) PlanScan(m *Map, oid uint32, format int16, target any) Scan // target / arrayScanner might be a pointer to a nil. If it is create one so we can call ScanIndexType to plan the // scan of the elements. 
- if anynil.Is(target) { + if isNil, _ := isNilDriverValuer(target); isNil { arrayScanner = reflect.New(reflect.TypeOf(target).Elem()).Interface().(ArraySetter) } diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index fd67fcbaa..0019cb3d2 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -256,7 +256,10 @@ func TestArrayCodecScanMultipleDimensions(t *testing.T) { skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { +<<<<<<< HEAD +======= +>>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e rows, err := conn.Query(ctx, `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) require.NoError(t, err) diff --git a/pgtype/bits.go b/pgtype/bits.go index 3c523c0a3..02e4d787e 100644 --- a/pgtype/bits.go +++ b/pgtype/bits.go @@ -23,16 +23,18 @@ type Bits struct { Valid bool } +// ScanBits implements the [BitsScanner] interface. func (b *Bits) ScanBits(v Bits) error { *b = v return nil } +// BitsValue implements the [BitsValuer] interface. func (b Bits) BitsValue() (Bits, error) { return b, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Bits) Scan(src any) error { if src == nil { *dst = Bits{} @@ -47,7 +49,7 @@ func (dst *Bits) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
func (src Bits) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -127,7 +129,6 @@ func (encodePlanBitsCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -176,8 +177,10 @@ func (scanPlanBinaryBitsToBitsScanner) Scan(src []byte, dst any) error { bitLen := int32(binary.BigEndian.Uint32(src)) rp := 4 + buf := make([]byte, len(src[rp:])) + copy(buf, src[rp:]) - return scanner.ScanBits(Bits{Bytes: src[rp:], Len: bitLen, Valid: true}) + return scanner.ScanBits(Bits{Bytes: buf, Len: bitLen, Valid: true}) } type scanPlanTextAnyToBitsScanner struct{} diff --git a/pgtype/bool.go b/pgtype/bool.go index 71caffa74..b74fe4414 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -22,16 +22,18 @@ type Bool struct { Valid bool } +// ScanBool implements the [BoolScanner] interface. func (b *Bool) ScanBool(v Bool) error { *b = v return nil } +// BoolValue implements the [BoolValuer] interface. func (b Bool) BoolValue() (Bool, error) { return b, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Bool) Scan(src any) error { if src == nil { *dst = Bool{} @@ -61,7 +63,7 @@ func (dst *Bool) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Bool) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -70,6 +72,7 @@ func (src Bool) Value() (driver.Value, error) { return src.Bool, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. 
func (src Bool) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -82,6 +85,7 @@ func (src Bool) MarshalJSON() ([]byte, error) { } } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Bool) UnmarshalJSON(b []byte) error { var v *bool err := json.Unmarshal(b, &v) @@ -328,7 +332,7 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error { return s.ScanBool(Bool{Bool: v, Valid: true}) } -// https://www.postgresql.org/docs/11/datatype-boolean.html +// https://www.postgresql.org/docs/current/datatype-boolean.html func planTextToBool(src []byte) (bool, error) { s := string(bytes.ToLower(bytes.TrimSpace(src))) diff --git a/pgtype/box.go b/pgtype/box.go index d38180e9a..8f869744a 100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -24,16 +24,18 @@ type Box struct { Valid bool } +// ScanBox implements the [BoxScanner] interface. func (b *Box) ScanBox(v Box) error { *b = v return nil } +// BoxValue implements the [BoxValuer] interface. func (b Box) BoxValue() (Box, error) { return b, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Box) Scan(src any) error { if src == nil { *dst = Box{} @@ -48,7 +50,7 @@ func (dst *Box) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface.
func (src Box) Value() (driver.Value, error) { if !src.Valid { return nil, nil diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index b39d3fa10..248756d70 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -881,7 +881,6 @@ func (a *anyMultiDimSliceArray) SetDimensions(dimensions []ArrayDimension) error return nil } - } func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value { diff --git a/pgtype/bytea.go b/pgtype/bytea.go index a247705e9..6c4f0c5ea 100644 --- a/pgtype/bytea.go +++ b/pgtype/bytea.go @@ -148,7 +148,6 @@ func (encodePlanBytesCodecTextBytesValuer) Encode(value any, buf []byte) (newBuf } func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index cc7f75931..86b8a12ff 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -92,7 +92,7 @@ func TestPreallocBytes(t *testing.T) { require.Len(t, buf, 2) require.Equal(t, 4, cap(buf)) - require.Equal(t, buf, []byte{1, 2}) + require.Equal(t, []byte{1, 2}, buf) require.Equal(t, []byte{1, 2, 7, 8}, origBuf) @@ -112,7 +112,7 @@ func TestUndecodedBytes(t *testing.T) { require.NoError(t, err) require.Len(t, buf, 4) - require.Equal(t, buf, []byte{0, 0, 0, 1}) + require.Equal(t, []byte{0, 0, 0, 1}, buf) }) } @@ -132,6 +132,6 @@ func TestByteaCodecDecodeDatabaseSQLValue(t *testing.T) { require.NoError(t, err) require.Len(t, buf, 4) - require.Equal(t, buf, []byte{0xa1, 0xb2, 0xc3, 0xd4}) +
require.Equal(t, []byte{0xa1, 0xb2, 0xc3, 0xd4}, buf) }) } diff --git a/pgtype/circle.go b/pgtype/circle.go index b539dda49..72e7e9722 100644 --- a/pgtype/circle.go +++ b/pgtype/circle.go @@ -25,16 +25,18 @@ type Circle struct { Valid bool } +// ScanCircle implements the [CircleScanner] interface. func (c *Circle) ScanCircle(v Circle) error { *c = v return nil } +// CircleValue implements the [CircleValuer] interface. func (c Circle) CircleValue() (Circle, error) { return c, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Circle) Scan(src any) error { if src == nil { *dst = Circle{} @@ -49,7 +51,7 @@ func (dst *Circle) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Circle) Value() (driver.Value, error) { if !src.Valid { return nil, nil diff --git a/pgtype/composite_test.go b/pgtype/composite_test.go index e25fe985b..ea2ecb916 100644 --- a/pgtype/composite_test.go +++ b/pgtype/composite_test.go @@ -5,16 +5,13 @@ import ( "fmt" "testing" - "github.com/stretchr/testify/require" pgx "github.com/yugabyte/pgx/v5" "github.com/yugabyte/pgx/v5/pgtype" + "github.com/stretchr/testify/require" ) func TestCompositeCodecTranscode(t *testing.T) { - skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") - defaultConnTestRunner.RunTest(context.Background(), t,
func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop type if exists ct_test; create type ct_test as ( @@ -91,10 +88,7 @@ func (p *point3d) ScanIndex(i int) any { } func TestCompositeCodecTranscodeStruct(t *testing.T) { - skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") - defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop type if exists point3d; create type point3d as ( @@ -128,10 +122,7 @@ create type point3d as ( } func TestCompositeCodecTranscodeStructWrapper(t *testing.T) { - skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") - defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop type if exists point3d; create type point3d as ( @@ -169,10 +160,7 @@ create type point3d as ( } func TestCompositeCodecDecodeValue(t *testing.T) { - skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") - defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop type if exists point3d; create type point3d as ( @@ -214,10 +202,9 @@ create type point3d as ( // // https://github.com/jackc/pgx/issues/1576 func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) { - skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + skipCockroachDB(t, "Server does not support composite types from table definitions") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop table if exists point3d; create table point3d ( 
diff --git a/pgtype/date.go b/pgtype/date.go index 0e2579205..8824b50db 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -26,11 +26,13 @@ type Date struct { Valid bool } +// ScanDate implements the [DateScanner] interface. func (d *Date) ScanDate(v Date) error { *d = v return nil } +// DateValue implements the [DateValuer] interface. func (d Date) DateValue() (Date, error) { return d, nil } @@ -40,7 +42,7 @@ const ( infinityDayOffset = 2147483647 ) -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Date) Scan(src any) error { if src == nil { *dst = Date{} @@ -58,7 +60,7 @@ func (dst *Date) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Date) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -70,6 +72,7 @@ func (src Date) Value() (driver.Value, error) { return src.Time, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Date) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -89,6 +92,7 @@ func (src Date) MarshalJSON() ([]byte, error) { return json.Marshal(s) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. 
func (dst *Date) UnmarshalJSON(b []byte) error { var s *string err := json.Unmarshal(b, &s) diff --git a/pgtype/derived_types_test.go b/pgtype/derived_types_test.go new file mode 100644 index 000000000..05d109bda --- /dev/null +++ b/pgtype/derived_types_test.go @@ -0,0 +1,61 @@ +package pgtype_test + +import ( + "context" + "testing" + + pgx "github.com/yugabyte/pgx/v5" + "github.com/yugabyte/pgx/v5/pgtype" + "github.com/stretchr/testify/require" +) + +func TestDerivedTypes(t *testing.T) { + skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)") + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + _, err := conn.Exec(ctx, ` +drop type if exists dt_test; +drop domain if exists dt_uint64; + +create domain dt_uint64 as numeric(20,0); +create type dt_test as ( + a text, + b dt_uint64, + c dt_uint64[] +);`) + require.NoError(t, err) + defer conn.Exec(ctx, "drop domain dt_uint64") + defer conn.Exec(ctx, "drop type dt_test") + + dtypes, err := conn.LoadTypes(ctx, []string{"dt_test"}) + require.Len(t, dtypes, 6) + require.Equal(t, dtypes[0].Name, "public.dt_uint64") + require.Equal(t, dtypes[1].Name, "dt_uint64") + require.Equal(t, dtypes[2].Name, "public._dt_uint64") + require.Equal(t, dtypes[3].Name, "_dt_uint64") + require.Equal(t, dtypes[4].Name, "public.dt_test") + require.Equal(t, dtypes[5].Name, "dt_test") + require.NoError(t, err) + conn.TypeMap().RegisterTypes(dtypes) + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code:
pgx.BinaryFormatCode}, + } + + for _, format := range formats { + var a string + var b uint64 + var c *[]uint64 + + row := conn.QueryRow(ctx, "select $1::dt_test", pgx.QueryResultFormats{format.code}, pgtype.CompositeFields{"hi", uint64(42), []uint64{10, 20, 30}}) + err := row.Scan(pgtype.CompositeFields{&a, &b, &c}) + require.NoError(t, err) + require.EqualValuesf(t, "hi", a, "%v", format.name) + require.EqualValuesf(t, 42, b, "%v", format.name) + } + }) +} diff --git a/pgtype/doc.go b/pgtype/doc.go index ec9270acb..83dfc5de5 100644 --- a/pgtype/doc.go +++ b/pgtype/doc.go @@ -53,6 +53,9 @@ similar fashion to database/sql. The second is to use a pointer to a pointer. return err } +When using nullable pgtype types as parameters for queries, one has to remember to explicitly set their Valid field to +true, otherwise the parameter's value will be NULL. + JSON Support pgtype automatically marshals and unmarshals data from json and jsonb PostgreSQL types. @@ -139,6 +142,16 @@ Compatibility with database/sql pgtype also includes support for custom types implementing the database/sql.Scanner and database/sql/driver.Valuer interfaces. +Encoding Typed Nils + +pgtype encodes untyped and typed nils (e.g. nil and []byte(nil)) to the SQL NULL value without going through the Codec +system. This means that Codecs and other encoding logic do not have to handle nil or *T(nil). + +However, database/sql compatibility requires Value to be called on T(nil) when T implements driver.Valuer. Therefore, +driver.Valuer values are only considered NULL when *T(nil) where driver.Valuer is implemented on T not on *T. See +https://github.com/golang/go/issues/8415 and +https://github.com/golang/go/commit/0ce1d79a6a771f7449ec493b993ed2a720917870. + Child Records pgtype's support for arrays and composite records can be used to load records and their children in a single query. See @@ -146,11 +159,16 @@ example_child_records_test.go for an example. 
Overview of Scanning Implementation -The first step is to use the OID to lookup the correct Codec. If the OID is unavailable, Map will try to find the OID -from previous calls of Map.RegisterDefaultPgType. The Map will call the Codec's PlanScan method to get a plan for -scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types are -interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner and -PointValuer interfaces. +The first step is to use the OID to lookup the correct Codec. The Map will call the Codec's PlanScan method to get a +plan for scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types +are interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner +and PointValuer interfaces. + +If a Go value is not supported directly by a Codec then Map will try see if it is a sql.Scanner. If is then that +interface will be used to scan the value. Most sql.Scanners require the input to be in the text format (e.g. UUIDs and +numeric). However, pgx will typically have received the value in the binary format. In this case the binary value will be +parsed, reencoded as text, and then passed to the sql.Scanner. This may incur additional overhead for query results with +a large number of affected values. If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again. For example, Int8Codec does not support scanning into a renamed type (e.g. type myInt64 int64). 
But Map will detect that diff --git a/pgtype/enum_codec_test.go b/pgtype/enum_codec_test.go index bc0d96972..ac3f29295 100644 --- a/pgtype/enum_codec_test.go +++ b/pgtype/enum_codec_test.go @@ -10,7 +10,6 @@ import ( func TestEnumCodec(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop type if exists enum_test; create type enum_test as enum ('foo', 'bar', 'baz');`) @@ -47,7 +46,6 @@ create type enum_test as enum ('foo', 'bar', 'baz');`) func TestEnumCodecValues(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { - _, err := conn.Exec(ctx, `drop type if exists enum_test; create type enum_test as enum ('foo', 'bar', 'baz');`) @@ -64,6 +62,6 @@ create type enum_test as enum ('foo', 'bar', 'baz');`) require.True(t, rows.Next()) values, err := rows.Values() require.NoError(t, err) - require.Equal(t, values, []any{"foo"}) + require.Equal(t, []any{"foo"}, values) }) } diff --git a/pgtype/float4.go b/pgtype/float4.go index f03d46af7..42db76d83 100644 --- a/pgtype/float4.go +++ b/pgtype/float4.go @@ -16,26 +16,29 @@ type Float4 struct { Valid bool } -// ScanFloat64 implements the Float64Scanner interface. +// ScanFloat64 implements the [Float64Scanner] interface. func (f *Float4) ScanFloat64(n Float8) error { *f = Float4{Float32: float32(n.Float64), Valid: n.Valid} return nil } +// Float64Value implements the [Float64Valuer] interface. func (f Float4) Float64Value() (Float8, error) { return Float8{Float64: float64(f.Float32), Valid: f.Valid}, nil } +// ScanInt64 implements the [Int64Scanner] interface. func (f *Float4) ScanInt64(n Int8) error { *f = Float4{Float32: float32(n.Int64), Valid: n.Valid} return nil } +// Int64Value implements the [Int64Valuer] interface. 
func (f Float4) Int64Value() (Int8, error) { return Int8{Int64: int64(f.Float32), Valid: f.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (f *Float4) Scan(src any) error { if src == nil { *f = Float4{} @@ -58,7 +61,7 @@ func (f *Float4) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (f Float4) Value() (driver.Value, error) { if !f.Valid { return nil, nil @@ -66,6 +69,7 @@ func (f Float4) Value() (driver.Value, error) { return float64(f.Float32), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (f Float4) MarshalJSON() ([]byte, error) { if !f.Valid { return []byte("null"), nil @@ -73,6 +77,7 @@ func (f Float4) MarshalJSON() ([]byte, error) { return json.Marshal(f.Float32) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (f *Float4) UnmarshalJSON(b []byte) error { var n *float32 err := json.Unmarshal(b, &n) @@ -170,7 +175,6 @@ func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value any, buf []byte) (new } func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -297,12 +301,12 @@ func (c Float4Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, sr return nil, nil } - var n float64 + var n float32 err := codecScan(c, m, oid, format, src, &n) if err != nil { return nil, err } - return n, nil + return float64(n), nil } func (c Float4Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { diff --git a/pgtype/float8.go b/pgtype/float8.go index 0eab7652c..be781a5e4 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -24,26 +24,29 @@ type Float8 struct { Valid bool } -// ScanFloat64 implements the Float64Scanner interface. 
+// ScanFloat64 implements the [Float64Scanner] interface. func (f *Float8) ScanFloat64(n Float8) error { *f = n return nil } +// Float64Value implements the [Float64Valuer] interface. func (f Float8) Float64Value() (Float8, error) { return f, nil } +// ScanInt64 implements the [Int64Scanner] interface. func (f *Float8) ScanInt64(n Int8) error { *f = Float8{Float64: float64(n.Int64), Valid: n.Valid} return nil } +// Int64Value implements the [Int64Valuer] interface. func (f Float8) Int64Value() (Int8, error) { return Int8{Int64: int64(f.Float64), Valid: f.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (f *Float8) Scan(src any) error { if src == nil { *f = Float8{} @@ -66,7 +69,7 @@ func (f *Float8) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (f Float8) Value() (driver.Value, error) { if !f.Valid { return nil, nil @@ -74,6 +77,7 @@ func (f Float8) Value() (driver.Value, error) { return f.Float64, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (f Float8) MarshalJSON() ([]byte, error) { if !f.Valid { return []byte("null"), nil @@ -81,6 +85,7 @@ func (f Float8) MarshalJSON() ([]byte, error) { return json.Marshal(f.Float64) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. 
func (f *Float8) UnmarshalJSON(b []byte) error { var n *float64 err := json.Unmarshal(b, &n) diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 54f8dbac3..4a19338b0 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -22,16 +22,18 @@ type HstoreValuer interface { // associated with its keys. type Hstore map[string]*string +// ScanHstore implements the [HstoreScanner] interface. func (h *Hstore) ScanHstore(v Hstore) error { *h = v return nil } +// HstoreValue implements the [HstoreValuer] interface. func (h Hstore) HstoreValue() (Hstore, error) { return h, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (h *Hstore) Scan(src any) error { if src == nil { *h = nil @@ -46,7 +48,7 @@ func (h *Hstore) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface.
func (h Hstore) Value() (driver.Value, error) { if h == nil { return nil, nil @@ -298,7 +300,7 @@ func (p *hstoreParser) consume() (b byte, end bool) { return b, false } -func unexpectedByteErr(actualB byte, expectedB byte) error { +func unexpectedByteErr(actualB, expectedB byte) error { return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB) } @@ -316,7 +318,7 @@ func (p *hstoreParser) consumeExpectedByte(expectedB byte) error { // consumeExpected2 consumes two expected bytes or returns an error. // This was a bit faster than using a string argument (better inlining? Not sure). -func (p *hstoreParser) consumeExpected2(one byte, two byte) error { +func (p *hstoreParser) consumeExpected2(one, two byte) error { if p.pos+2 > len(p.str) { return errors.New("unexpected end of string") } diff --git a/pgtype/hstore_test.go b/pgtype/hstore_test.go index de3e75d3a..7462fb7ad 100644 --- a/pgtype/hstore_test.go +++ b/pgtype/hstore_test.go @@ -306,12 +306,13 @@ func TestRoundTrip(t *testing.T) { }) } } - } func BenchmarkHstoreEncode(b *testing.B) { - h := pgtype.Hstore{"a x": stringPtr("100"), "b": stringPtr("200"), "c": stringPtr("300"), - "d": stringPtr("400"), "e": stringPtr("500")} + h := pgtype.Hstore{ + "a x": stringPtr("100"), "b": stringPtr("200"), "c": stringPtr("300"), + "d": stringPtr("400"), "e": stringPtr("500"), + } serializeConfigs := []struct { name string diff --git a/pgtype/inet.go b/pgtype/inet.go index 6ca10ea07..6363cf44b 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -24,7 +24,7 @@ type NetipPrefixValuer interface { NetipPrefixValue() (netip.Prefix, error) } -// InetCodec handles both
inet and cidr PostgreSQL types. The preferred Go types are netip.Prefix and netip.Addr. If +// InetCodec handles both inet and cidr PostgreSQL types. The preferred Go types are [netip.Prefix] and [netip.Addr]. If // IsValid() is false then they are treated as SQL NULL. type InetCodec struct{} diff --git a/pgtype/int.go b/pgtype/int.go index 62403dada..40fe27b28 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -1,4 +1,5 @@ -// Do not edit. Generated from pgtype/int.go.erb +// Code generated from pgtype/int.go.erb. DO NOT EDIT. + package pgtype import ( @@ -25,7 +26,7 @@ type Int2 struct { Valid bool } -// ScanInt64 implements the Int64Scanner interface. +// ScanInt64 implements the [Int64Scanner] interface. func (dst *Int2) ScanInt64(n Int8) error { if !n.Valid { *dst = Int2{} @@ -43,11 +44,12 @@ func (dst *Int2) ScanInt64(n Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Int2) Int64Value() (Int8, error) { return Int8{Int64: int64(n.Int16), Valid: n.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Int2) Scan(src any) error { if src == nil { *dst = Int2{} @@ -86,7 +88,7 @@ func (dst *Int2) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int2) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -94,6 +96,7 @@ func (src Int2) Value() (driver.Value, error) { return int64(src.Int16), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (src Int2) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -101,6 +104,7 @@ func (src Int2) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int16), 10)), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Int2) UnmarshalJSON(b []byte) error { var n *int16 err := json.Unmarshal(b, &n) @@ -585,7 +589,7 @@ type Int4 struct { Valid bool } -// ScanInt64 implements the Int64Scanner interface. +// ScanInt64 implements the [Int64Scanner] interface. func (dst *Int4) ScanInt64(n Int8) error { if !n.Valid { *dst = Int4{} @@ -603,11 +607,12 @@ func (dst *Int4) ScanInt64(n Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Int4) Int64Value() (Int8, error) { return Int8{Int64: int64(n.Int32), Valid: n.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Int4) Scan(src any) error { if src == nil { *dst = Int4{} @@ -646,7 +651,7 @@ func (dst *Int4) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int4) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -654,6 +659,7 @@ func (src Int4) Value() (driver.Value, error) { return int64(src.Int32), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Int4) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -661,6 +667,7 @@ func (src Int4) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int32), 10)), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Int4) UnmarshalJSON(b []byte) error { var n *int32 err := json.Unmarshal(b, &n) @@ -1156,7 +1163,7 @@ type Int8 struct { Valid bool } -// ScanInt64 implements the Int64Scanner interface. 
+// ScanInt64 implements the [Int64Scanner] interface. func (dst *Int8) ScanInt64(n Int8) error { if !n.Valid { *dst = Int8{} @@ -1174,11 +1181,12 @@ func (dst *Int8) ScanInt64(n Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Int8) Int64Value() (Int8, error) { return Int8{Int64: int64(n.Int64), Valid: n.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Int8) Scan(src any) error { if src == nil { *dst = Int8{} @@ -1217,7 +1225,7 @@ func (dst *Int8) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int8) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -1225,6 +1233,7 @@ func (src Int8) Value() (driver.Value, error) { return int64(src.Int64), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Int8) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -1232,6 +1241,7 @@ func (src Int8) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int64), 10)), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Int8) UnmarshalJSON(b []byte) error { var n *int64 err := json.Unmarshal(b, &n) diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index cc675b903..93ebca530 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -27,7 +27,7 @@ type Int<%= pg_byte_size %> struct { Valid bool } -// ScanInt64 implements the Int64Scanner interface. +// ScanInt64 implements the [Int64Scanner] interface. func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { if !n.Valid { *dst = Int<%= pg_byte_size %>{} @@ -45,11 +45,12 @@ func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. 
func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) { return Int8{Int64: int64(n.Int<%= pg_bit_size %>), Valid: n.Valid}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Int<%= pg_byte_size %>) Scan(src any) error { if src == nil { *dst = Int<%= pg_byte_size %>{} @@ -88,7 +89,7 @@ func (dst *Int<%= pg_byte_size %>) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -96,6 +97,7 @@ func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { return int64(src.Int<%= pg_bit_size %>), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -103,6 +105,7 @@ func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) { return []byte(strconv.FormatInt(int64(src.Int<%= pg_bit_size %>), 10)), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Int<%= pg_byte_size %>) UnmarshalJSON(b []byte) error { var n *int<%= pg_bit_size %> err := json.Unmarshal(b, &n) diff --git a/pgtype/int_test.go b/pgtype/int_test.go index e8d53ffc3..e1ee05cdd 100644 --- a/pgtype/int_test.go +++ b/pgtype/int_test.go @@ -1,4 +1,5 @@ -// Do not edit. Generated from pgtype/int_test.go.erb +// Code generated from pgtype/int_test.go.erb. DO NOT EDIT. + package pgtype_test import ( diff --git a/pgtype/integration_benchmark_test.go b/pgtype/integration_benchmark_test.go index 9715a51cf..f1775ec87 100644 --- a/pgtype/integration_benchmark_test.go +++ b/pgtype/integration_benchmark_test.go @@ -1,3 +1,5 @@ +// Code generated from pgtype/integration_benchmark_test.go.erb. DO NOT EDIT. 
+ package pgtype_test import ( diff --git a/pgtype/integration_benchmark_test.go.erb b/pgtype/integration_benchmark_test.go.erb index b17e4db63..00c53d766 100644 --- a/pgtype/integration_benchmark_test.go.erb +++ b/pgtype/integration_benchmark_test.go.erb @@ -25,7 +25,7 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go rows, _ := conn.Query( ctx, `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`, - []any{pgx.QueryResultFormats{<%= format_code %>}}, + pgx.QueryResultFormats{<%= format_code %>}, ) _, err := pgx.ForEachRow(rows, []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func() error { return nil }) if err != nil { @@ -49,7 +49,7 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array rows, _ := conn.Query( ctx, `select array_agg(n) from generate_series(1, <%= array_size %>) n`, - []any{pgx.QueryResultFormats{<%= format_code %>}}, + pgx.QueryResultFormats{<%= format_code %>}, ) _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil }) if err != nil { diff --git a/pgtype/interval.go b/pgtype/interval.go index 124eea3e0..d71f34107 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -33,16 +33,18 @@ type Interval struct { Valid bool } +// ScanInterval implements the [IntervalScanner] interface. func (interval *Interval) ScanInterval(v Interval) error { *interval = v return nil } +// IntervalValue implements the [IntervalValuer] interface. func (interval Interval) IntervalValue() (Interval, error) { return interval, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. 
func (interval *Interval) Scan(src any) error { if src == nil { *interval = Interval{} @@ -57,7 +59,7 @@ func (interval *Interval) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (interval Interval) Value() (driver.Value, error) { if !interval.Valid { return nil, nil @@ -144,15 +146,19 @@ func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte, hours := absMicroseconds / microsecondsPerHour minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond - microseconds := absMicroseconds % microsecondsPerSecond - timeStr := fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, microseconds) + timeStr := fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) buf = append(buf, timeStr...) + + microseconds := absMicroseconds % microsecondsPerSecond + if microseconds != 0 { + buf = append(buf, fmt.Sprintf(".%06d", microseconds)...) 
+ } + return buf, nil } func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go index 2e7afa5b1..ca2d90fe8 100644 --- a/pgtype/interval_test.go +++ b/pgtype/interval_test.go @@ -136,3 +136,22 @@ func TestIntervalCodec(t *testing.T) { {nil, new(pgtype.Interval), isExpectedEq(pgtype.Interval{})}, }) } + +func TestIntervalTextEncode(t *testing.T) { + m := pgtype.NewMap() + + successfulTests := []struct { + source pgtype.Interval + result string + }{ + {source: pgtype.Interval{Months: 2, Days: 1, Microseconds: 0, Valid: true}, result: "2 mon 1 day 00:00:00"}, + {source: pgtype.Interval{Months: 0, Days: 0, Microseconds: 0, Valid: true}, result: "00:00:00"}, + {source: pgtype.Interval{Months: 0, Days: 0, Microseconds: 6 * 60 * 1000000, Valid: true}, result: "00:06:00"}, + {source: pgtype.Interval{Months: 0, Days: 1, Microseconds: 6*60*1000000 + 30, Valid: true}, result: "1 day 00:06:00.000030"}, + } + for i, tt := range successfulTests { + buf, err := m.Encode(pgtype.DateOID, pgtype.TextFormatCode, tt.source, nil) + assert.NoErrorf(t, err, "%d", i) + assert.Equalf(t, tt.result, string(buf), "%d", i) + } +} diff --git a/pgtype/json.go b/pgtype/json.go index 3f1a750f6..60aa2b71d 100644 --- a/pgtype/json.go +++ b/pgtype/json.go @@ -8,28 +8,36 @@ import ( "reflect" ) -type JSONCodec struct{} +type JSONCodec struct { + Marshal func(v any) ([]byte, error) + Unmarshal func(data []byte, v any) error +} -func (JSONCodec) FormatSupported(format int16) bool { +func (*JSONCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } -func (JSONCodec) PreferredFormat() int16 { +func (*JSONCodec) PreferredFormat() int16 { return TextFormatCode } -func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (c *JSONCodec) PlanEncode(m *Map, oid uint32, 
format int16, value any) EncodePlan { switch value.(type) { case string: return encodePlanJSONCodecEitherFormatString{} case []byte: return encodePlanJSONCodecEitherFormatByteSlice{} + // Handle json.RawMessage specifically because if it is run through json.Marshal it may be mutated. + // e.g. `{"foo": "bar"}` -> `{"foo":"bar"}`. + case json.RawMessage: + return encodePlanJSONCodecEitherFormatJSONRawMessage{} + // Cannot rely on driver.Valuer being handled later because anything can be marshalled. // // https://github.com/jackc/pgx/issues/1430 // - // Check for driver.Valuer must come before json.Marshaler so that it is guaranteed to beused + // Check for driver.Valuer must come before json.Marshaler so that it is guaranteed to be used // when both are implemented https://github.com/jackc/pgx/issues/1805 case driver.Valuer: return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} @@ -39,7 +47,9 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod // // https://github.com/jackc/pgx/issues/1681 case json.Marshaler: - return encodePlanJSONCodecEitherFormatMarshal{} + return &encodePlanJSONCodecEitherFormatMarshal{ + marshal: c.Marshal, + } } // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the @@ -56,7 +66,30 @@ func (c JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Encod } } - return encodePlanJSONCodecEitherFormatMarshal{} + return &encodePlanJSONCodecEitherFormatMarshal{ + marshal: c.Marshal, + } +} + +// JSON needs its on scan plan for pointers to handle 'null'::json(b). +// Consider making pointerPointerScanPlan more flexible in the future. 
+type jsonPointerScanPlan struct { + next ScanPlan +} + +func (p jsonPointerScanPlan) Scan(src []byte, dst any) error { + el := reflect.ValueOf(dst).Elem() + if src == nil || string(src) == "null" { + el.SetZero() + return nil + } + + el.Set(reflect.New(el.Type().Elem())) + if p.next != nil { + return p.next.Scan(src, el.Interface()) + } + + return nil } type encodePlanJSONCodecEitherFormatString struct{} @@ -79,10 +112,24 @@ func (encodePlanJSONCodecEitherFormatByteSlice) Encode(value any, buf []byte) (n return buf, nil } -type encodePlanJSONCodecEitherFormatMarshal struct{} +type encodePlanJSONCodecEitherFormatJSONRawMessage struct{} + +func (encodePlanJSONCodecEitherFormatJSONRawMessage) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonBytes := value.(json.RawMessage) + if jsonBytes == nil { + return nil, nil + } + + buf = append(buf, jsonBytes...) + return buf, nil +} + +type encodePlanJSONCodecEitherFormatMarshal struct { + marshal func(v any) ([]byte, error) +} -func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { - jsonBytes, err := json.Marshal(value) +func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { + jsonBytes, err := e.marshal(value) if err != nil { return nil, err } @@ -91,40 +138,36 @@ func (encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (new return buf, nil } -func (JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *JSONCodec) PlanScan(m *Map, oid uint32, formatCode int16, target any) ScanPlan { + return c.planScan(m, oid, formatCode, target, 0) +} + +// JSON cannot fallback to pointerPointerScanPlan because of 'null'::json(b), +// so we need to duplicate the logic here. 
+func (c *JSONCodec) planScan(m *Map, oid uint32, formatCode int16, target any, depth int) ScanPlan { + if depth > 8 { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + switch target.(type) { case *string: - return scanPlanAnyToString{} - - case **string: - // This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better - // solution would be. - // - // https://github.com/jackc/pgx/issues/1470 -- **string - // https://github.com/jackc/pgx/issues/1691 -- ** anything else - - if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { - if nextPlan := m.planScan(oid, format, nextDst); nextPlan != nil { - if _, failed := nextPlan.(*scanPlanFail); !failed { - wrapperPlan.SetNext(nextPlan) - return wrapperPlan - } - } - } - + return &scanPlanAnyToString{} case *[]byte: - return scanPlanJSONToByteSlice{} + return &scanPlanJSONToByteSlice{} case BytesScanner: - return scanPlanBinaryBytesToBytesScanner{} - - // Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence. 
- // - // https://github.com/jackc/pgx/issues/1418 + return &scanPlanBinaryBytesToBytesScanner{} case sql.Scanner: - return &scanPlanSQLScanner{formatCode: format} + return &scanPlanSQLScanner{formatCode: formatCode} } - return scanPlanJSONToJSONUnmarshal{} + rv := reflect.ValueOf(target) + if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Pointer { + var plan jsonPointerScanPlan + plan.next = c.planScan(m, oid, formatCode, rv.Elem().Interface(), depth+1) + return plan + } else { + return &scanPlanJSONToJSONUnmarshal{unmarshal: c.Unmarshal} + } } type scanPlanAnyToString struct{} @@ -149,16 +192,11 @@ func (scanPlanJSONToByteSlice) Scan(src []byte, dst any) error { return nil } -type scanPlanJSONToBytesScanner struct{} - -func (scanPlanJSONToBytesScanner) Scan(src []byte, dst any) error { - scanner := (dst).(BytesScanner) - return scanner.ScanBytes(src) +type scanPlanJSONToJSONUnmarshal struct { + unmarshal func(data []byte, v any) error } -type scanPlanJSONToJSONUnmarshal struct{} - -func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { +func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { if src == nil { dstValue := reflect.ValueOf(dst) if dstValue.Kind() == reflect.Ptr { @@ -173,13 +211,18 @@ func (scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error { return fmt.Errorf("cannot scan NULL into %T", dst) } - elem := reflect.ValueOf(dst).Elem() + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Pointer || v.IsNil() { + return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst) + } + + elem := v.Elem() elem.Set(reflect.Zero(elem.Type())) - return json.Unmarshal(src, dst) + return s.unmarshal(src, dst) } -func (c JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -189,12 +232,12 @@ func (c 
JSONCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src return dstBuf, nil } -func (c JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *JSONCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } var dst any - err := json.Unmarshal(src, &dst) + err := c.Unmarshal(src, &dst) return dst, err } diff --git a/pgtype/json_test.go b/pgtype/json_test.go index 65c0826db..d25fc7969 100644 --- a/pgtype/json_test.go +++ b/pgtype/json_test.go @@ -6,11 +6,14 @@ import ( "database/sql/driver" "encoding/json" "errors" + "fmt" + "reflect" "testing" - "github.com/stretchr/testify/require" pgx "github.com/yugabyte/pgx/v5" + "github.com/yugabyte/pgx/v5/pgtype" "github.com/yugabyte/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) func isExpectedEqMap(a any) func(any) bool { @@ -46,6 +49,7 @@ func TestJSONCodec(t *testing.T) { Age int `json:"age"` } + var str string pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{ {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))}, {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, @@ -61,6 +65,11 @@ func TestJSONCodec(t *testing.T) { // Test driver.Valuer is used before json.Marshaler (https://github.com/jackc/pgx/issues/1805) {Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))}, + // Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146) + {Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))}, + + // Test driver.Scanner without pointer receiver (https://github.com/jackc/pgx/issues/2204) + {NonPointerJSONScanner{V: stringPtr("{}")}, NonPointerJSONScanner{V: &str}, func(a any) bool { return str == "{}" }}, }) pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{ @@ -107,6 +116,52 @@ func 
(i Issue1805) MarshalJSON() ([]byte, error) { return nil, errors.New("MarshalJSON called") } +type Issue2146 int + +func (i *Issue2146) Scan(src any) error { + var source []byte + switch src.(type) { + case string: + source = []byte(src.(string)) + case []byte: + source = src.([]byte) + default: + return errors.New("unknown source type") + } + var newI int + if err := json.Unmarshal(source, &newI); err != nil { + return err + } + *i = Issue2146(newI + 1) + return nil +} + +func (i Issue2146) Value() (driver.Value, error) { + b, err := json.Marshal(int(i - 1)) + return string(b), err +} + +type NonPointerJSONScanner struct { + V *string +} + +func (i NonPointerJSONScanner) Scan(src any) error { + switch c := src.(type) { + case string: + *i.V = c + case []byte: + *i.V = string(c) + default: + return errors.New("unknown source type") + } + + return nil +} + +func (i NonPointerJSONScanner) Value() (driver.Value, error) { + return i.V, nil +} + // https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648 func TestJSONCodecUnmarshalSQLNull(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { @@ -137,11 +192,15 @@ func TestJSONCodecUnmarshalSQLNull(t *testing.T) { // A string cannot scan a NULL. str := "foobar" err = conn.QueryRow(ctx, "select null::json").Scan(&str) - require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string") + fieldName := "json" + if conn.PgConn().ParameterStatus("crdb_version") != "" { + fieldName = "jsonb" // Seems like CockroachDB treats json as jsonb. + } + require.EqualError(t, err, fmt.Sprintf("can't scan into dest[0] (col: %s): cannot scan NULL into *string", fieldName)) // A non-string cannot scan a NULL. 
err = conn.QueryRow(ctx, "select null::json").Scan(&n) - require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *int") + require.EqualError(t, err, fmt.Sprintf("can't scan into dest[0] (col: %s): cannot scan NULL into *int", fieldName)) }) } @@ -224,3 +283,80 @@ func TestJSONCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) { require.Equal(t, `{"custom":"thing"}`, jsonStr) }) } + +func TestJSONCodecCustomMarshal(t *testing.T) { + skipCockroachDB(t, "CockroachDB treats json as jsonb. This causes it to format differently than PostgreSQL.") + + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "json", OID: pgtype.JSONOID, Codec: &pgtype.JSONCodec{ + Marshal: func(v any) ([]byte, error) { + return []byte(`{"custom":"value"}`), nil + }, + Unmarshal: func(data []byte, v any) error { + return json.Unmarshal([]byte(`{"custom":"value"}`), v) + }, + }, + }) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{ + // There is no space between "custom" and "value" in json type. 
+ {map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom":"value"}`)}, + {[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool { + return reflect.DeepEqual(v, map[string]any{"custom": "value"}) + }}, + }) +} + +func TestJSONCodecScanToNonPointerValues(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + n := 44 + err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n) + require.Error(t, err) + + var i *int + err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i) + require.Error(t, err) + + m := 0 + err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m) + require.NoError(t, err) + require.Equal(t, 42, m) + }) +} + +func TestJSONCodecScanNull(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var dest struct{} + err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot scan NULL into *struct {}") + + err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&dest) + require.NoError(t, err) + + var destPointer *struct{} + err = conn.QueryRow(ctx, "select null::jsonb").Scan(&destPointer) + require.NoError(t, err) + require.Nil(t, destPointer) + + err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&destPointer) + require.NoError(t, err) + require.Nil(t, destPointer) + + var raw json.RawMessage + require.NoError(t, conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&raw)) + require.Equal(t, json.RawMessage("null"), raw) + }) +} + +func TestJSONCodecScanNullToPointerToSQLScanner(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var dest *Issue2146 + err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest) + require.NoError(t, err) + require.Nil(t, dest) + }) +} diff --git a/pgtype/jsonb.go b/pgtype/jsonb.go index 
25555e7ff..4d4eb58e5 100644 --- a/pgtype/jsonb.go +++ b/pgtype/jsonb.go @@ -2,29 +2,31 @@ package pgtype import ( "database/sql/driver" - "encoding/json" "fmt" ) -type JSONBCodec struct{} +type JSONBCodec struct { + Marshal func(v any) ([]byte, error) + Unmarshal func(data []byte, v any) error +} -func (JSONBCodec) FormatSupported(format int16) bool { +func (*JSONBCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } -func (JSONBCodec) PreferredFormat() int16 { +func (*JSONBCodec) PreferredFormat() int16 { return TextFormatCode } -func (JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (c *JSONBCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { switch format { case BinaryFormatCode: - plan := JSONCodec{}.PlanEncode(m, oid, TextFormatCode, value) + plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, TextFormatCode, value) if plan != nil { return &encodePlanJSONBCodecBinaryWrapper{textPlan: plan} } case TextFormatCode: - return JSONCodec{}.PlanEncode(m, oid, format, value) + return (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanEncode(m, oid, format, value) } return nil @@ -39,15 +41,15 @@ func (plan *encodePlanJSONBCodecBinaryWrapper) Encode(value any, buf []byte) (ne return plan.textPlan.Encode(value, buf) } -func (JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { +func (c *JSONBCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: - plan := JSONCodec{}.PlanScan(m, oid, TextFormatCode, target) + plan := (&JSONCodec{Marshal: c.Marshal, Unmarshal: c.Unmarshal}).PlanScan(m, oid, TextFormatCode, target) if plan != nil { return &scanPlanJSONBCodecBinaryUnwrapper{textPlan: plan} } case TextFormatCode: - return JSONCodec{}.PlanScan(m, oid, format, target) + return (&JSONCodec{Marshal: c.Marshal, Unmarshal: 
c.Unmarshal}).PlanScan(m, oid, format, target) } return nil @@ -73,7 +75,7 @@ func (plan *scanPlanJSONBCodecBinaryUnwrapper) Scan(src []byte, dst any) error { return plan.textPlan.Scan(src[1:], dst) } -func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -100,7 +102,7 @@ func (c JSONBCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src } } -func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } @@ -122,6 +124,6 @@ func (c JSONBCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (a } var dst any - err := json.Unmarshal(src, &dst) + err := c.Unmarshal(src, &dst) return dst, err } diff --git a/pgtype/jsonb_test.go b/pgtype/jsonb_test.go index 1c6336e53..b9b6f79ae 100644 --- a/pgtype/jsonb_test.go +++ b/pgtype/jsonb_test.go @@ -2,11 +2,14 @@ package pgtype_test import ( "context" + "encoding/json" + "reflect" "testing" - "github.com/stretchr/testify/require" pgx "github.com/yugabyte/pgx/v5" + "github.com/yugabyte/pgx/v5/pgtype" "github.com/yugabyte/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) func TestJSONBTranscode(t *testing.T) { @@ -63,11 +66,11 @@ func TestJSONBCodecUnmarshalSQLNull(t *testing.T) { // A string cannot scan a NULL. str := "foobar" err = conn.QueryRow(ctx, "select null::jsonb").Scan(&str) - require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string") + require.EqualError(t, err, "can't scan into dest[0] (col: jsonb): cannot scan NULL into *string") // A non-string cannot scan a NULL. 
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&n) - require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *int") + require.EqualError(t, err, "can't scan into dest[0] (col: jsonb): cannot scan NULL into *int") }) } @@ -80,3 +83,27 @@ func TestJSONBCodecEncodeJSONMarshalerThatCanBeWrapped(t *testing.T) { require.Equal(t, `{"custom": "thing"}`, jsonStr) // Note that unlike json, jsonb reformats the JSON string. }) } + +func TestJSONBCodecCustomMarshal(t *testing.T) { + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "jsonb", OID: pgtype.JSONBOID, Codec: &pgtype.JSONBCodec{ + Marshal: func(v any) ([]byte, error) { + return []byte(`{"custom":"value"}`), nil + }, + Unmarshal: func(data []byte, v any) error { + return json.Unmarshal([]byte(`{"custom":"value"}`), v) + }, + }, + }) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{ + // There is space between "custom" and "value" in jsonb type. + {map[string]any{"something": "else"}, new(string), isExpectedEq(`{"custom": "value"}`)}, + {[]byte(`{"something":"else"}`), new(map[string]any), func(v any) bool { + return reflect.DeepEqual(v, map[string]any{"custom": "value"}) + }}, + }) +} diff --git a/pgtype/line.go b/pgtype/line.go index 4438dc625..77e993ab0 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -24,11 +24,13 @@ type Line struct { Valid bool } +// ScanLine implements the [LineScanner] interface. func (line *Line) ScanLine(v Line) error { *line = v return nil } +// LineValue implements the [LineValuer] interface. func (line Line) LineValue() (Line, error) { return line, nil } @@ -37,7 +36,7 @@ func (line *Line) Set(src any) error { return fmt.Errorf("cannot convert %v to Line", src) } -// Scan implements the database/sql Scanner interface. 
+// Scan implements the [database/sql.Scanner] interface. func (line *Line) Scan(src any) error { if src == nil { *line = Line{} @@ -52,7 +51,7 @@ func (line *Line) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (line Line) Value() (driver.Value, error) { if !line.Valid { return nil, nil @@ -129,7 +128,7 @@ func (encodePlanLineCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/lseg.go b/pgtype/lseg.go index c0b4ef093..2f22f9126 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -24,16 +24,18 @@ type Lseg struct { Valid bool } +// ScanLseg implements the [LsegScanner] interface. func (lseg *Lseg) ScanLseg(v Lseg) error { *lseg = v return nil } +// LsegValue implements the [LsegValuer] interface. func (lseg Lseg) LsegValue() (Lseg, error) { return lseg, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (lseg *Lseg) Scan(src any) error { if src == nil { *lseg = Lseg{} @@ -48,7 +50,7 @@ func (lseg *Lseg) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
func (lseg Lseg) Value() (driver.Value, error) { if !lseg.Valid { return nil, nil @@ -127,7 +129,7 @@ func (encodePlanLsegCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/macaddr_test.go b/pgtype/macaddr_test.go index 074f1f8ea..4fcfcfdd1 100644 --- a/pgtype/macaddr_test.go +++ b/pgtype/macaddr_test.go @@ -48,4 +48,23 @@ func TestMacaddrCodec(t *testing.T) { }, {nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))}, }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "macaddr8", []pgxtest.ValueRoundTripTest{ + { + mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"), + new(net.HardwareAddr), + isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab:01:08")), + }, + { + "01:23:45:67:89:ab:01:08", + new(net.HardwareAddr), + isExpectedEqHardwareAddr(mustParseMacaddr(t, "01:23:45:67:89:ab:01:08")), + }, + { + mustParseMacaddr(t, "01:23:45:67:89:ab:01:08"), + new(string), + isExpectedEq("01:23:45:67:89:ab:01:08"), + }, + {nil, new(*net.HardwareAddr), isExpectedEq((*net.HardwareAddr)(nil))}, + }) } diff --git a/pgtype/multirange.go b/pgtype/multirange.go index 15a68d521..ef454eb61 100644 --- a/pgtype/multirange.go +++ b/pgtype/multirange.go @@ -374,7 +374,6 @@ parseValueLoop: } return elements, nil - } func parseRange(buf *bytes.Buffer) (string, error) { @@ -403,8 +402,8 @@ func parseRange(buf *bytes.Buffer) (string, error) { // Multirange is a generic multirange type. // -// T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to -// enforce the RangeScanner constraint. +// T should implement [RangeValuer] and *T should implement [RangeScanner]. 
However, there does not appear to be a way to +// enforce the [RangeScanner] constraint. type Multirange[T RangeValuer] []T func (r Multirange[T]) IsNull() bool { diff --git a/pgtype/multirange_test.go b/pgtype/multirange_test.go index c08e3a432..13f120025 100644 --- a/pgtype/multirange_test.go +++ b/pgtype/multirange_test.go @@ -71,7 +71,7 @@ func TestMultirangeCodecDecodeValue(t *testing.T) { skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { for _, tt := range []struct { sql string expected any diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 2377ec15f..80efbe90c 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -27,16 +27,20 @@ const ( pgNumericNegInfSign = 0xf000 ) -var big0 *big.Int = big.NewInt(0) -var big1 *big.Int = big.NewInt(1) -var big10 *big.Int = big.NewInt(10) -var big100 *big.Int = big.NewInt(100) -var big1000 *big.Int = big.NewInt(1000) +var ( + big0 *big.Int = big.NewInt(0) + big1 *big.Int = big.NewInt(1) + big10 *big.Int = big.NewInt(10) + big100 *big.Int = big.NewInt(100) + big1000 *big.Int = big.NewInt(1000) +) -var bigNBase *big.Int = big.NewInt(nbase) -var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) -var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) -var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) +var ( + bigNBase *big.Int = big.NewInt(nbase) + bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) + bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) + bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase) +) type NumericScanner interface { ScanNumeric(v Numeric) error @@ -54,15 +58,18 @@ type Numeric struct { Valid bool } +// ScanNumeric implements the [NumericScanner] interface. 
func (n *Numeric) ScanNumeric(v Numeric) error { *n = v return nil } +// NumericValue implements the [NumericValuer] interface. func (n Numeric) NumericValue() (Numeric, error) { return n, nil } +// Float64Value implements the [Float64Valuer] interface. func (n Numeric) Float64Value() (Float8, error) { if !n.Valid { return Float8{}, nil @@ -92,6 +99,7 @@ func (n Numeric) Float64Value() (Float8, error) { return Float8{Float64: f, Valid: true}, nil } +// ScanInt64 implements the [Int64Scanner] interface. func (n *Numeric) ScanInt64(v Int8) error { if !v.Valid { *n = Numeric{} @@ -102,6 +110,7 @@ func (n *Numeric) ScanInt64(v Int8) error { return nil } +// Int64Value implements the [Int64Valuer] interface. func (n Numeric) Int64Value() (Int8, error) { if !n.Valid { return Int8{}, nil @@ -203,7 +212,7 @@ func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) { return accum, rp, digits } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (n *Numeric) Scan(src any) error { if src == nil { *n = Numeric{} @@ -218,7 +227,7 @@ func (n *Numeric) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (n Numeric) Value() (driver.Value, error) { if !n.Valid { return nil, nil @@ -231,6 +240,7 @@ func (n Numeric) Value() (driver.Value, error) { return string(buf), err } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (n Numeric) MarshalJSON() ([]byte, error) { if !n.Valid { return []byte("null"), nil @@ -243,6 +253,7 @@ func (n Numeric) MarshalJSON() ([]byte, error) { return n.numberTextBytes(), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. 
func (n *Numeric) UnmarshalJSON(src []byte) error { if bytes.Equal(src, []byte(`null`)) { *n = Numeric{} @@ -553,7 +564,7 @@ func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { } func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index 56f61403b..f201f8606 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -198,7 +198,7 @@ func TestNumericMarshalJSON(t *testing.T) { skipCockroachDB(t, "server formats numeric text format differently") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { for i, tt := range []struct { decString string }{ diff --git a/pgtype/path.go b/pgtype/path.go index 1771c68b2..c2fe94ad6 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -25,16 +25,18 @@ type Path struct { Valid bool } +// ScanPath implements the [PathScanner] interface. func (path *Path) ScanPath(v Path) error { *path = v return nil } +// PathValue implements the [PathValuer] interface. func (path Path) PathValue() (Path, error) { return path, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (path *Path) Scan(src any) error { if src == nil { *path = Path{} @@ -49,7 +51,7 @@ func (path *Path) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
func (path Path) Value() (driver.Value, error) { if !path.Valid { return nil, nil @@ -154,7 +156,7 @@ func (encodePlanPathCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 08833f876..56ef19e02 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -26,7 +26,10 @@ const ( XIDOID = 28 CIDOID = 29 JSONOID = 114 + XMLOID = 142 + XMLArrayOID = 143 JSONArrayOID = 199 + XID8ArrayOID = 271 PointOID = 600 LsegOID = 601 PathOID = 602 @@ -41,6 +44,7 @@ const ( CircleOID = 718 CircleArrayOID = 719 UnknownOID = 705 + Macaddr8OID = 774 MacaddrOID = 829 InetOID = 869 BoolArrayOID = 1000 @@ -114,6 +118,7 @@ const ( TstzmultirangeOID = 4534 DatemultirangeOID = 4535 Int8multirangeOID = 4536 + XID8OID = 5069 Int4multirangeArrayOID = 6150 NummultirangeArrayOID = 6151 TsmultirangeArrayOID = 6152 @@ -197,7 +202,6 @@ type Map struct { reflectTypeToType map[reflect.Type]*Type - memoizedScanPlans map[uint32]map[reflect.Type][2]ScanPlan memoizedEncodePlans map[uint32]map[reflect.Type][2]EncodePlan // TryWrapEncodePlanFuncs is a slice of functions that will wrap a value that cannot be encoded by the Codec. Every @@ -213,6 +217,15 @@ type Map struct { TryWrapScanPlanFuncs []TryWrapScanPlanFunc } +// Copy returns a new Map containing the same registered types.
+func (m *Map) Copy() *Map { + newMap := NewMap() + for _, type_ := range m.oidToType { + newMap.RegisterType(type_) + } + return newMap +} + func NewMap() *Map { defaultMapInitOnce.Do(initDefaultMap) @@ -222,7 +235,6 @@ func NewMap() *Map { reflectTypeToName: make(map[reflect.Type]string), oidToFormatCode: make(map[uint32]int16), - memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ @@ -247,6 +259,13 @@ func NewMap() *Map { } } +// RegisterTypes registers multiple data types in the sequence they are provided. +func (m *Map) RegisterTypes(types []*Type) { + for _, t := range types { + m.RegisterType(t) + } +} + // RegisterType registers a data type with the Map. t must not be mutated after it is registered. func (m *Map) RegisterType(t *Type) { m.oidToType[t.OID] = t @@ -255,9 +274,6 @@ func (m *Map) RegisterType(t *Type) { // Invalidated by type registration m.reflectTypeToType = nil - for k := range m.memoizedScanPlans { - delete(m.memoizedScanPlans, k) - } for k := range m.memoizedEncodePlans { delete(m.memoizedEncodePlans, k) } @@ -271,9 +287,6 @@ func (m *Map) RegisterDefaultPgType(value any, name string) { // Invalidated by type registration m.reflectTypeToType = nil - for k := range m.memoizedScanPlans { - delete(m.memoizedScanPlans, k) - } for k := range m.memoizedEncodePlans { delete(m.memoizedEncodePlans, k) } @@ -430,14 +443,14 @@ func (plan *scanPlanFail) Scan(src []byte, dst any) error { // As a horrible hack try all types to find anything that can scan into dst. for oid := range plan.m.oidToType { // using planScan instead of Scan or PlanScan to avoid polluting the planned scan cache. 
- plan := plan.m.planScan(oid, plan.formatCode, dst) + plan := plan.m.planScan(oid, plan.formatCode, dst, 0) if _, ok := plan.(*scanPlanFail); !ok { return plan.Scan(src, dst) } } for oid := range defaultMap.oidToType { if _, ok := plan.m.oidToType[oid]; !ok { - plan := plan.m.planScan(oid, plan.formatCode, dst) + plan := plan.m.planScan(oid, plan.formatCode, dst, 0) if _, ok := plan.(*scanPlanFail); !ok { return plan.Scan(src, dst) } @@ -554,17 +567,24 @@ func TryFindUnderlyingTypeScanPlan(dst any) (plan WrappedScanPlanNextSetter, nex elemValue = dstValue.Elem() } nextDstType := elemKindToPointerTypes[elemValue.Kind()] - if nextDstType == nil && elemValue.Kind() == reflect.Slice { - if elemValue.Type().Elem().Kind() == reflect.Uint8 { - var v *[]byte - nextDstType = reflect.TypeOf(v) + if nextDstType == nil { + if elemValue.Kind() == reflect.Slice { + if elemValue.Type().Elem().Kind() == reflect.Uint8 { + var v *[]byte + nextDstType = reflect.TypeOf(v) + } + } + + // Get underlying type of any array. + // https://github.com/jackc/pgx/issues/2107 + if elemValue.Kind() == reflect.Array { + nextDstType = reflect.PointerTo(reflect.ArrayOf(elemValue.Len(), elemValue.Type().Elem())) } } - if nextDstType != nil && dstValue.Type() != nextDstType { + if nextDstType != nil && dstValue.Type() != nextDstType && dstValue.CanConvert(nextDstType) { return &underlyingTypeScanPlan{dstType: dstValue.Type(), nextDstType: nextDstType}, dstValue.Convert(nextDstType).Interface(), true } - } return nil, nil, false @@ -1038,24 +1058,14 @@ func (plan *wrapPtrArrayReflectScanPlan) Scan(src []byte, target any) error { // PlanScan prepares a plan to scan a value into target. 
func (m *Map) PlanScan(oid uint32, formatCode int16, target any) ScanPlan { - oidMemo := m.memoizedScanPlans[oid] - if oidMemo == nil { - oidMemo = make(map[reflect.Type][2]ScanPlan) - m.memoizedScanPlans[oid] = oidMemo - } - targetReflectType := reflect.TypeOf(target) - typeMemo := oidMemo[targetReflectType] - plan := typeMemo[formatCode] - if plan == nil { - plan = m.planScan(oid, formatCode, target) - typeMemo[formatCode] = plan - oidMemo[targetReflectType] = typeMemo - } - - return plan + return m.planScan(oid, formatCode, target, 0) } -func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { +func (m *Map) planScan(oid uint32, formatCode int16, target any, depth int) ScanPlan { + if depth > 8 { + return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} + } + if target == nil { return &scanPlanFail{m: m, oid: oid, formatCode: formatCode} } @@ -1115,7 +1125,7 @@ func (m *Map) planScan(oid uint32, formatCode int16, target any) ScanPlan { for _, f := range m.TryWrapScanPlanFuncs { if wrapperPlan, nextDst, ok := f(target); ok { - if nextPlan := m.planScan(oid, formatCode, nextDst); nextPlan != nil { + if nextPlan := m.planScan(oid, formatCode, nextDst, depth+1); nextPlan != nil { if _, failed := nextPlan.(*scanPlanFail); !failed { wrapperPlan.SetNext(nextPlan) return wrapperPlan @@ -1172,9 +1182,18 @@ func codecDecodeToTextFormat(codec Codec, m *Map, oid uint32, format int16, src } } -// PlanEncode returns an Encode plan for encoding value into PostgreSQL format for oid and format. If no plan can be +// PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be // found then nil is returned. func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { + return m.planEncodeDepth(oid, format, value, 0) +} + +func (m *Map) planEncodeDepth(oid uint32, format int16, value any, depth int) EncodePlan { + // Guard against infinite recursion. 
+ if depth > 8 { + return nil + } + oidMemo := m.memoizedEncodePlans[oid] if oidMemo == nil { oidMemo = make(map[reflect.Type][2]EncodePlan) @@ -1184,7 +1203,7 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { typeMemo := oidMemo[targetReflectType] plan := typeMemo[format] if plan == nil { - plan = m.planEncode(oid, format, value) + plan = m.planEncode(oid, format, value, depth) typeMemo[format] = plan oidMemo[targetReflectType] = typeMemo } @@ -1192,7 +1211,7 @@ func (m *Map) PlanEncode(oid uint32, format int16, value any) EncodePlan { return plan } -func (m *Map) planEncode(oid uint32, format int16, value any) EncodePlan { +func (m *Map) planEncode(oid uint32, format int16, value any, depth int) EncodePlan { if format == TextFormatCode { switch value.(type) { case string: @@ -1223,7 +1242,7 @@ func (m *Map) planEncode(oid uint32, format int16, value any) EncodePlan { for _, f := range m.TryWrapEncodePlanFuncs { if wrapperPlan, nextValue, ok := f(value); ok { - if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil { + if nextPlan := m.planEncodeDepth(oid, format, nextValue, depth+1); nextPlan != nil { wrapperPlan.SetNext(nextPlan) return wrapperPlan } @@ -1330,7 +1349,7 @@ func (plan *derefPointerEncodePlan) Encode(value any, buf []byte) (newBuf []byte } // TryWrapDerefPointerEncodePlan tries to dereference a pointer. e.g. If value was of type *string then a wrapper plan -// would be returned that derefences the value. +// would be returned that dereferences the value. func TryWrapDerefPointerEncodePlan(value any) (plan WrappedEncodePlanNextSetter, nextValue any, ok bool) { if _, ok := value.(driver.Valuer); ok { return nil, nil, false @@ -1404,6 +1423,15 @@ func TryWrapFindUnderlyingTypeEncodePlan(value any) (plan WrappedEncodePlanNextS return &underlyingTypeEncodePlan{nextValueType: byteSliceType}, refValue.Convert(byteSliceType).Interface(), true } + // Get underlying type of any array. 
+ // https://github.com/jackc/pgx/issues/2107 + if refValue.Kind() == reflect.Array { + underlyingArrayType := reflect.ArrayOf(refValue.Len(), refValue.Type().Elem()) + if refValue.Type() != underlyingArrayType { + return &underlyingTypeEncodePlan{nextValueType: underlyingArrayType}, refValue.Convert(underlyingArrayType).Interface(), true + } + } + return nil, nil, false } @@ -1911,8 +1939,17 @@ func newEncodeError(value any, m *Map, oid uint32, formatCode int16, err error) // (nil, nil). The caller of Encode is responsible for writing the correct NULL value or the length of the data // written. func (m *Map) Encode(oid uint32, formatCode int16, value any, buf []byte) (newBuf []byte, err error) { - if value == nil { - return nil, nil + if isNil, callNilDriverValuer := isNilDriverValuer(value); isNil { + if callNilDriverValuer { + newBuf, err = (&encodePlanDriverValuer{m: m, oid: oid, formatCode: formatCode}).Encode(value, buf) + if err != nil { + return nil, newEncodeError(value, m, oid, formatCode, err) + } + + return newBuf, nil + } else { + return nil, nil + } } plan := m.PlanEncode(oid, formatCode, value) @@ -1967,3 +2004,36 @@ func (w *sqlScannerWrapper) Scan(src any) error { return w.m.Scan(t.OID, TextFormatCode, bufSrc, w.v) } + +var valuerReflectType = reflect.TypeFor[driver.Valuer]() + +// isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement +// driver.Valuer if it is only implemented by T. 
+func isNilDriverValuer(value any) (isNil, callNilDriverValuer bool) { + if value == nil { + return true, false + } + + refVal := reflect.ValueOf(value) + kind := refVal.Kind() + switch kind { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice: + if !refVal.IsNil() { + return false, false + } + + if _, ok := value.(driver.Valuer); ok { + if kind == reflect.Ptr { + // The type assertion will succeed if driver.Valuer is implemented on T or *T. Check if it is implemented on *T + // by checking if it is not implemented on T. + return true, !refVal.Type().Elem().Implements(valuerReflectType) + } else { + return true, true + } + } + + return true, false + default: + return false, false + } +} diff --git a/pgtype/pgtype_default.go b/pgtype/pgtype_default.go index 58f4b92c7..5648d89bf 100644 --- a/pgtype/pgtype_default.go +++ b/pgtype/pgtype_default.go @@ -1,6 +1,8 @@ package pgtype import ( + "encoding/json" + "encoding/xml" "net" "net/netip" "reflect" @@ -21,7 +23,6 @@ func initDefaultMap() { reflectTypeToName: make(map[reflect.Type]string), oidToFormatCode: make(map[uint32]int16), - memoizedScanPlans: make(map[uint32]map[reflect.Type][2]ScanPlan), memoizedEncodePlans: make(map[uint32]map[reflect.Type][2]EncodePlan), TryWrapEncodePlanFuncs: []TryWrapEncodePlanFunc{ @@ -64,11 +65,12 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "int4", OID: Int4OID, Codec: Int4Codec{}}) defaultMap.RegisterType(&Type{Name: "int8", OID: Int8OID, Codec: Int8Codec{}}) defaultMap.RegisterType(&Type{Name: "interval", OID: IntervalOID, Codec: IntervalCodec{}}) - defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: JSONCodec{}}) - defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID, Codec: JSONBCodec{}}) + defaultMap.RegisterType(&Type{Name: "json", OID: JSONOID, Codec: &JSONCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}}) + defaultMap.RegisterType(&Type{Name: "jsonb", OID: JSONBOID,
Codec: &JSONBCodec{Marshal: json.Marshal, Unmarshal: json.Unmarshal}}) defaultMap.RegisterType(&Type{Name: "jsonpath", OID: JSONPathOID, Codec: &TextFormatOnlyCodec{TextCodec{}}}) defaultMap.RegisterType(&Type{Name: "line", OID: LineOID, Codec: LineCodec{}}) defaultMap.RegisterType(&Type{Name: "lseg", OID: LsegOID, Codec: LsegCodec{}}) + defaultMap.RegisterType(&Type{Name: "macaddr8", OID: Macaddr8OID, Codec: MacaddrCodec{}}) defaultMap.RegisterType(&Type{Name: "macaddr", OID: MacaddrOID, Codec: MacaddrCodec{}}) defaultMap.RegisterType(&Type{Name: "name", OID: NameOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "numeric", OID: NumericOID, Codec: NumericCodec{}}) @@ -80,13 +82,33 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}}) defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}}) - defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: TimestampCodec{}}) - defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: TimestamptzCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}}) + defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}}) defaultMap.RegisterType(&Type{Name: "unknown", OID: UnknownOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "uuid", OID: UUIDOID, Codec: UUIDCodec{}}) defaultMap.RegisterType(&Type{Name: "varbit", OID: VarbitOID, Codec: BitsCodec{}}) defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}}) defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}}) + defaultMap.RegisterType(&Type{Name: "xid8", OID: XID8OID, Codec: Uint64Codec{}}) + defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{ + Marshal: xml.Marshal, + // xml.Unmarshal does not support 
unmarshalling into *any. However, XMLCodec.DecodeValue calls Unmarshal with a + // *any. Wrap xml.Unmarshal with a function that copies the data into a new byte slice in this case. Not implementing + // directly in XMLCodec.DecodeValue to allow for the unlikely possibility that someone uses an alternative XML + // unmarshaler that does support unmarshalling into *any. + // + // https://github.com/jackc/pgx/issues/2227 + // https://github.com/jackc/pgx/pull/2228 + Unmarshal: func(data []byte, v any) error { + if v, ok := v.(*any); ok { + dstBuf := make([]byte, len(data)) + copy(dstBuf, data) + *v = dstBuf + return nil + } + return xml.Unmarshal(data, v) + }, + }}) // Range types defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}}) @@ -151,6 +173,8 @@ func initDefaultMap() { defaultMap.RegisterType(&Type{Name: "_varbit", OID: VarbitArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarbitOID]}}) defaultMap.RegisterType(&Type{Name: "_varchar", OID: VarcharArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[VarcharOID]}}) defaultMap.RegisterType(&Type{Name: "_xid", OID: XIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XIDOID]}}) + defaultMap.RegisterType(&Type{Name: "_xid8", OID: XID8ArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XID8OID]}}) + defaultMap.RegisterType(&Type{Name: "_xml", OID: XMLArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[XMLOID]}}) // Integer types that directly map to a PostgreSQL type registerDefaultPgTypeVariants[int16](defaultMap, "int2") @@ -173,6 +197,7 @@ func initDefaultMap() { registerDefaultPgTypeVariants[time.Time](defaultMap, "timestamptz") registerDefaultPgTypeVariants[time.Duration](defaultMap, "interval") registerDefaultPgTypeVariants[string](defaultMap, "text") + registerDefaultPgTypeVariants[json.RawMessage](defaultMap, "json") registerDefaultPgTypeVariants[[]byte](defaultMap, "bytea")
registerDefaultPgTypeVariants[net.IP](defaultMap, "inet") diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 64ee40395..f63330e44 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "os" + "reflect" "regexp" "strconv" "testing" @@ -33,16 +34,19 @@ func init() { } // Test for renamed types -type _string string -type _bool bool -type _int8 int8 -type _int16 int16 -type _int16Slice []int16 -type _int32Slice []int32 -type _int64Slice []int64 -type _float32Slice []float32 -type _float64Slice []float64 -type _byteSlice []byte +type ( + _string string + _bool bool + _uint8 uint8 + _int8 int8 + _int16 int16 + _int16Slice []int16 + _int32Slice []int32 + _int64Slice []int64 + _float32Slice []float32 + _float64Slice []float64 + _byteSlice []byte +) // unregisteredOID represents an actual type that is not registered. Cannot use 0 because that represents that the type // is not known (e.g. when using the simple protocol). @@ -308,7 +312,7 @@ func TestPointerPointerStructScan(t *testing.T) { plan := m.PlanScan(pgt.OID, pgtype.TextFormatCode, &c) err := plan.Scan([]byte("(1)"), &c) require.NoError(t, err) - require.Equal(t, c.ID, 1) + require.Equal(t, 1, c.ID) } // https://github.com/jackc/pgx/issues/1263 @@ -453,6 +457,14 @@ func TestMapScanNullToWrongType(t *testing.T) { assert.False(t, pn.Valid) } +func TestScanToSliceOfRenamedUint8(t *testing.T) { + m := pgtype.NewMap() + var ruint8 []_uint8 + err := m.Scan(pgtype.Int2ArrayOID, pgx.TextFormatCode, []byte("{2,4}"), &ruint8) + assert.NoError(t, err) + assert.Equal(t, []_uint8{2, 4}, ruint8) +} + func TestMapScanTextToBool(t *testing.T) { tests := []struct { name string @@ -520,7 +532,8 @@ func TestMapEncodePlanCacheUUIDTypeConfusion(t *testing.T) { 0, 0, 0, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 16, - 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0} + 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + } m := pgtype.NewMap() buf, 
err := m.Encode(pgtype.UUIDArrayOID, pgtype.BinaryFormatCode, @@ -537,6 +550,29 @@ func TestMapEncodePlanCacheUUIDTypeConfusion(t *testing.T) { require.Error(t, err) } +// https://github.com/jackc/pgx/issues/1763 +func TestMapEncodeRawJSONIntoUnknownOID(t *testing.T) { + m := pgtype.NewMap() + buf, err := m.Encode(0, pgtype.TextFormatCode, json.RawMessage(`{"foo": "bar"}`), nil) + require.NoError(t, err) + require.Equal(t, []byte(`{"foo": "bar"}`), buf) +} + +// PlanScan previously used a cache to improve performance. However, the cache could get confused in certain cases. The +// example below was one such failure case. +func TestCachedPlanScanConfusion(t *testing.T) { + m := pgtype.NewMap() + var err error + + var tags any + err = m.Scan(pgtype.TextArrayOID, pgx.TextFormatCode, []byte("{foo,bar,baz}"), &tags) + require.NoError(t, err) + + var cells [][]string + err = m.Scan(pgtype.TextArrayOID, pgx.TextFormatCode, []byte("{{foo,bar},{baz,quz}}"), &cells) + require.NoError(t, err) +} + func BenchmarkMapScanInt4IntoBinaryDecoder(b *testing.B) { m := pgtype.NewMap() src := []byte{0, 0, 0, 42} @@ -614,3 +650,10 @@ func isExpectedEq(a any) func(any) bool { return a == v } } + +func isPtrExpectedEq(a any) func(any) bool { + return func(v any) bool { + val := reflect.ValueOf(v) + return a == val.Elem().Interface() + } +} diff --git a/pgtype/point.go b/pgtype/point.go index c717de844..40275c1b0 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -30,11 +30,13 @@ type Point struct { Valid bool } +// ScanPoint implements the [PointScanner] interface. func (p *Point) ScanPoint(v Point) error { *p = v return nil } +// PointValue implements the [PointValuer] interface. func (p Point) PointValue() (Point, error) { return p, nil } @@ -68,7 +70,7 @@ func parsePoint(src []byte) (*Point, error) { return &Point{P: Vec2{x, y}, Valid: true}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. 
func (dst *Point) Scan(src any) error { if src == nil { *dst = Point{} @@ -83,7 +85,7 @@ func (dst *Point) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Point) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -96,6 +98,7 @@ func (src Point) Value() (driver.Value, error) { return string(buf), err } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Point) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -108,6 +111,7 @@ func (src Point) MarshalJSON() ([]byte, error) { return buff.Bytes(), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Point) UnmarshalJSON(point []byte) error { p, err := parsePoint(point) if err != nil { @@ -178,7 +182,7 @@ func (encodePlanPointCodecText) Encode(value any, buf []byte) (newBuf []byte, er } func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/polygon.go b/pgtype/polygon.go index b19cf33a4..f8df718e0 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -24,16 +24,18 @@ type Polygon struct { Valid bool } +// ScanPolygon implements the [PolygonScanner] interface. func (p *Polygon) ScanPolygon(v Polygon) error { *p = v return nil } +// PolygonValue implements the [PolygonValuer] interface. func (p Polygon) PolygonValue() (Polygon, error) { return p, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface.
func (p *Polygon) Scan(src any) error { if src == nil { *p = Polygon{} @@ -48,7 +50,7 @@ func (p *Polygon) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (p Polygon) Value() (driver.Value, error) { if !p.Valid { return nil, nil @@ -139,7 +141,7 @@ func (encodePlanPolygonCodecText) Encode(value any, buf []byte) (newBuf []byte, } func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/range.go b/pgtype/range.go index 16427cccd..a73d4d7f9 100644 --- a/pgtype/range.go +++ b/pgtype/range.go @@ -191,11 +191,13 @@ type untypedBinaryRange struct { // 18 = [ = 10010 // 24 = = 11000 -const emptyMask = 1 -const lowerInclusiveMask = 2 -const upperInclusiveMask = 4 -const lowerUnboundedMask = 8 -const upperUnboundedMask = 16 +const ( + emptyMask = 1 + lowerInclusiveMask = 2 + upperInclusiveMask = 4 + lowerUnboundedMask = 8 + upperUnboundedMask = 16 +) func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) { ubr := &untypedBinaryRange{} @@ -273,7 +275,7 @@ func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) { } return ubr, nil } // Range is a generic range type.
diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index efd5866c1..f004fba3a 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -75,7 +75,7 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { var r pgtype.Range[pgtype.Int4] err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r) @@ -129,7 +129,7 @@ func TestRangeCodecDecodeValue(t *testing.T) { skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { for _, tt := range []struct { sql string expected any diff --git a/pgtype/record_codec.go b/pgtype/record_codec.go index b3b166045..90b9bd4bb 100644 --- a/pgtype/record_codec.go +++ b/pgtype/record_codec.go @@ -121,5 +121,4 @@ func (RecordCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (an default: return nil, fmt.Errorf("unknown format code %d", format) } - } diff --git a/pgtype/text.go b/pgtype/text.go index 021ee331b..e08b12549 100644 --- a/pgtype/text.go +++ b/pgtype/text.go @@ -19,16 +19,18 @@ type Text struct { Valid bool } +// ScanText implements the [TextScanner] interface. func (t *Text) ScanText(v Text) error { *t = v return nil } +// TextValue implements the [TextValuer] interface. func (t Text) TextValue() (Text, error) { return t, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface.
func (dst *Text) Scan(src any) error { if src == nil { *dst = Text{} @@ -47,7 +49,7 @@ func (dst *Text) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Text) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -55,6 +57,7 @@ func (src Text) Value() (driver.Value, error) { return src.String, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src Text) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -63,6 +66,7 @@ func (src Text) MarshalJSON() ([]byte, error) { return json.Marshal(src.String) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (dst *Text) UnmarshalJSON(b []byte) error { var s *string err := json.Unmarshal(b, &s) @@ -146,7 +150,6 @@ func (encodePlanTextCodecTextValuer) Encode(value any, buf []byte) (newBuf []byt } func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case TextFormatCode, BinaryFormatCode: switch target.(type) { diff --git a/pgtype/tid.go b/pgtype/tid.go index a65bbe434..03a4ce714 100644 --- a/pgtype/tid.go +++ b/pgtype/tid.go @@ -35,16 +35,18 @@ type TID struct { Valid bool } +// ScanTID implements the [TIDScanner] interface. func (b *TID) ScanTID(v TID) error { *b = v return nil } +// TIDValue implements the [TIDValuer] interface. func (b TID) TIDValue() (TID, error) { return b, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *TID) Scan(src any) error { if src == nil { *dst = TID{} @@ -59,7 +61,7 @@ func (dst *TID) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
func (src TID) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -131,7 +133,6 @@ func (encodePlanTIDCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/time.go b/pgtype/time.go index a8f62407a..8a303b9aa 100644 --- a/pgtype/time.go +++ b/pgtype/time.go @@ -19,24 +19,28 @@ type TimeValuer interface { // Time represents the PostgreSQL time type. The PostgreSQL time is a time of day without time zone. // -// Time is represented as the number of microseconds since midnight in the same way that PostgreSQL does. Other time -// and date types in pgtype can use time.Time as the underlying representation. However, pgtype.Time type cannot due -// to needing to handle 24:00:00. time.Time converts that to 00:00:00 on the following day. +// Time is represented as the number of microseconds since midnight in the same way that PostgreSQL does. Other time and +// date types in pgtype can use time.Time as the underlying representation. However, pgtype.Time type cannot due to +// needing to handle 24:00:00. time.Time converts that to 00:00:00 on the following day. +// +// The time with time zone type is not supported. Use of time with time zone is discouraged by the PostgreSQL documentation. type Time struct { Microseconds int64 // Number of microseconds since midnight Valid bool } +// ScanTime implements the [TimeScanner] interface. func (t *Time) ScanTime(v Time) error { *t = v return nil } +// TimeValue implements the [TimeValuer] interface. func (t Time) TimeValue() (Time, error) { return t, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. 
func (t *Time) Scan(src any) error { if src == nil { *t = Time{} @@ -45,13 +49,18 @@ func (t *Time) Scan(src any) error { switch src := src.(type) { case string: - return scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t) + err := scanPlanTextAnyToTimeScanner{}.Scan([]byte(src), t) + if err != nil { + t.Microseconds = 0 + t.Valid = false + } + return err } return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (t Time) Value() (driver.Value, error) { if !t.Valid { return nil, nil @@ -130,12 +139,13 @@ func (encodePlanTimeCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { case TimeScanner: return scanPlanBinaryTimeToTimeScanner{} + case TextScanner: + return scanPlanBinaryTimeToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -165,6 +175,34 @@ func (scanPlanBinaryTimeToTimeScanner) Scan(src []byte, dst any) error { return scanner.ScanTime(Time{Microseconds: usec, Valid: true}) } +type scanPlanBinaryTimeToTextScanner struct{} + +func (scanPlanBinaryTimeToTextScanner) Scan(src []byte, dst any) error { + ts, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return ts.ScanText(Text{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for time: %v", len(src)) + } + + usec := int64(binary.BigEndian.Uint64(src)) + + tim := Time{Microseconds: usec, Valid: true} + + buf, err := TimeCodec{}.PlanEncode(nil, 0, TextFormatCode, tim).Encode(tim, nil) + if err != nil { + return err + } + + return ts.ScanText(Text{String: string(buf), Valid: true}) +} + type scanPlanTextAnyToTimeScanner struct{} func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error { @@ -176,7 +214,7 @@ func (scanPlanTextAnyToTimeScanner) Scan(src 
[]byte, dst any) error { s := string(src) - if len(s) < 8 { + if len(s) < 8 || s[2] != ':' || s[5] != ':' { return fmt.Errorf("cannot decode %v into Time", s) } @@ -199,6 +237,10 @@ func (scanPlanTextAnyToTimeScanner) Scan(src []byte, dst any) error { usec += seconds * microsecondsPerSecond if len(s) > 9 { + if s[8] != '.' || len(s) > 15 { + return fmt.Errorf("cannot decode %v into Time", s) + } + fraction := s[9:] n, err := strconv.ParseInt(fraction, 10, 64) if err != nil { diff --git a/pgtype/time_test.go b/pgtype/time_test.go index 4a1da4422..5e990b720 100644 --- a/pgtype/time_test.go +++ b/pgtype/time_test.go @@ -2,11 +2,13 @@ package pgtype_test import ( "context" + "strconv" "testing" "time" "github.com/yugabyte/pgx/v5/pgtype" "github.com/yugabyte/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" ) func TestTimeCodec(t *testing.T) { @@ -45,3 +47,69 @@ func TestTimeCodec(t *testing.T) { {nil, new(pgtype.Time), isExpectedEq(pgtype.Time{})}, }) } + +func TestTimeTextScanner(t *testing.T) { + var pgTime pgtype.Time + + assert.NoError(t, pgTime.Scan("07:37:16")) + assert.Equal(t, true, pgTime.Valid) + assert.Equal(t, int64(7*time.Hour+37*time.Minute+16*time.Second), pgTime.Microseconds*int64(time.Microsecond)) + + assert.NoError(t, pgTime.Scan("15:04:05")) + assert.Equal(t, true, pgTime.Valid) + assert.Equal(t, int64(15*time.Hour+4*time.Minute+5*time.Second), pgTime.Microseconds*int64(time.Microsecond)) + + // parsing of fractional digits + assert.NoError(t, pgTime.Scan("15:04:05.00")) + assert.Equal(t, true, pgTime.Valid) + assert.Equal(t, int64(15*time.Hour+4*time.Minute+5*time.Second), pgTime.Microseconds*int64(time.Microsecond)) + + const mirco = "789123" + const woFraction = int64(4*time.Hour + 5*time.Minute + 6*time.Second) // time without fraction + for i := 0; i <= len(mirco); i++ { + assert.NoError(t, pgTime.Scan("04:05:06."+mirco[:i])) + assert.Equal(t, true, pgTime.Valid) + + frac, _ := strconv.ParseInt(mirco[:i], 10, 64) + for k := i; k < 6; k++ { 
+ frac *= 10 + } + assert.Equal(t, woFraction+frac*int64(time.Microsecond), pgTime.Microseconds*int64(time.Microsecond)) + } + + // parsing of too long fraction errors + assert.Error(t, pgTime.Scan("04:05:06.7891234")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + // parsing of timetz errors + assert.Error(t, pgTime.Scan("04:05:06.789-08")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + assert.Error(t, pgTime.Scan("04:05:06-08:00")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + // parsing of date errors + assert.Error(t, pgTime.Scan("1997-12-17")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + // parsing of text errors + assert.Error(t, pgTime.Scan("12345678")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + assert.Error(t, pgTime.Scan("12-34-56")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + assert.Error(t, pgTime.Scan("12:34-56")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) + + assert.Error(t, pgTime.Scan("12-34:56")) + assert.Equal(t, false, pgTime.Valid) + assert.Equal(t, int64(0), pgTime.Microseconds) +} diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index d95b31471..8bf3babc3 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -11,7 +11,10 @@ import ( "github.com/yugabyte/pgx/v5/internal/pgio" ) -const pgTimestampFormat = "2006-01-02 15:04:05.999999999" +const ( + pgTimestampFormat = "2006-01-02 15:04:05.999999999" + jsonISO8601 = "2006-01-02T15:04:05.999999999" +) type TimestampScanner interface { ScanTimestamp(v Timestamp) error @@ -28,16 +31,18 @@ type Timestamp struct { Valid bool } +// ScanTimestamp implements the [TimestampScanner] interface. 
func (ts *Timestamp) ScanTimestamp(v Timestamp) error { *ts = v return nil } +// TimestampValue implements the [TimestampValuer] interface. func (ts Timestamp) TimestampValue() (Timestamp, error) { return ts, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (ts *Timestamp) Scan(src any) error { if src == nil { *ts = Timestamp{} @@ -46,7 +51,7 @@ func (ts *Timestamp) Scan(src any) error { switch src := src.(type) { case string: - return scanPlanTextTimestampToTimestampScanner{}.Scan([]byte(src), ts) + return (&scanPlanTextTimestampToTimestampScanner{}).Scan([]byte(src), ts) case time.Time: *ts = Timestamp{Time: src, Valid: true} return nil @@ -55,7 +60,7 @@ func (ts *Timestamp) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (ts Timestamp) Value() (driver.Value, error) { if !ts.Valid { return nil, nil @@ -67,6 +72,7 @@ func (ts Timestamp) Value() (driver.Value, error) { return ts.Time, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (ts Timestamp) MarshalJSON() ([]byte, error) { if !ts.Valid { return []byte("null"), nil @@ -76,7 +82,7 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) { switch ts.InfinityModifier { case Finite: - s = ts.Time.Format(time.RFC3339Nano) + s = ts.Time.Format(jsonISO8601) case Infinity: s = "infinity" case NegativeInfinity: @@ -86,6 +92,7 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) { return json.Marshal(s) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. 
func (ts *Timestamp) UnmarshalJSON(b []byte) error { var s *string err := json.Unmarshal(b, &s) @@ -104,29 +111,41 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error { case "-infinity": *ts = Timestamp{Valid: true, InfinityModifier: -Infinity} default: - // PostgreSQL uses ISO 8601 for to_json function and casting from a string to timestamptz - tim, err := time.Parse(time.RFC3339Nano, *s) - if err != nil { - return err + // Parse time with or without timezonr + tss := *s + // PostgreSQL uses ISO 8601 without timezone for to_json function and casting from a string to timestampt + tim, err := time.Parse(time.RFC3339Nano, tss) + if err == nil { + *ts = Timestamp{Time: tim, Valid: true} + return nil } - - *ts = Timestamp{Time: tim, Valid: true} + tim, err = time.ParseInLocation(jsonISO8601, tss, time.UTC) + if err == nil { + *ts = Timestamp{Time: tim, Valid: true} + return nil + } + ts.Valid = false + return fmt.Errorf("cannot unmarshal %s to timestamp with layout %s or %s (%w)", + *s, time.RFC3339Nano, jsonISO8601, err) } - return nil } -type TimestampCodec struct{} +type TimestampCodec struct { + // ScanLocation is the location that the time is assumed to be in for scanning. This is different from + // TimestamptzCodec.ScanLocation in that this setting does change the instant in time that the timestamp represents. 
+ ScanLocation *time.Location +} -func (TimestampCodec) FormatSupported(format int16) bool { +func (*TimestampCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } -func (TimestampCodec) PreferredFormat() int16 { +func (*TimestampCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (*TimestampCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(TimestampValuer); !ok { return nil } @@ -220,27 +239,26 @@ func discardTimeZone(t time.Time) time.Time { return t } -func (TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - +func (c *TimestampCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case TimestampScanner: - return scanPlanBinaryTimestampToTimestampScanner{} + return &scanPlanBinaryTimestampToTimestampScanner{location: c.ScanLocation} } case TextFormatCode: switch target.(type) { case TimestampScanner: - return scanPlanTextTimestampToTimestampScanner{} + return &scanPlanTextTimestampToTimestampScanner{location: c.ScanLocation} } } return nil } -type scanPlanBinaryTimestampToTimestampScanner struct{} +type scanPlanBinaryTimestampToTimestampScanner struct{ location *time.Location } -func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { +func (plan *scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestampScanner) if src == nil { @@ -264,15 +282,18 @@ func (scanPlanBinaryTimestampToTimestampScanner) Scan(src []byte, dst any) error microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), ).UTC() + if plan.location != nil { + tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), 
tim.Second(), tim.Nanosecond(), plan.location) + } ts = Timestamp{Time: tim, Valid: true} } return scanner.ScanTimestamp(ts) } -type scanPlanTextTimestampToTimestampScanner struct{} +type scanPlanTextTimestampToTimestampScanner struct{ location *time.Location } -func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { +func (plan *scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestampScanner) if src == nil { @@ -302,13 +323,17 @@ func (scanPlanTextTimestampToTimestampScanner) Scan(src []byte, dst any) error { tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) } + if plan.location != nil { + tim = time.Date(tim.Year(), tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), plan.location) + } + ts = Timestamp{Time: tim, Valid: true} } return scanner.ScanTimestamp(ts) } -func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -326,7 +351,7 @@ func (c TimestampCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, return ts.Time, nil } -func (c TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *TimestampCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/timestamp_test.go b/pgtype/timestamp_test.go index e3aff3291..dab0987b6 100644 --- a/pgtype/timestamp_test.go +++ b/pgtype/timestamp_test.go @@ -2,13 +2,15 @@ package pgtype_test import ( "context" + "encoding/json" "testing" "time" - "github.com/stretchr/testify/require" pgx "github.com/yugabyte/pgx/v5" "github.com/yugabyte/pgx/v5/pgtype" "github.com/yugabyte/pgx/v5/pgxtest" + 
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestTimestampCodec(t *testing.T) { @@ -38,6 +40,42 @@ func TestTimestampCodec(t *testing.T) { }) } +func TestTimestampCodecWithScanLocationUTC(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamp", + OID: pgtype.TimestampOID, + Codec: &pgtype.TimestampCodec{ScanLocation: time.UTC}, + }) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + // Have to use pgtype.Timestamp instead of time.Time as source because otherwise the simple and exec query exec + // modes will encode the time for timestamptz. That is, they will convert it from local time zone. 
+ {pgtype.Timestamp{Time: time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), Valid: true}, new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))}, + }) +} + +func TestTimestampCodecWithScanLocationLocal(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamp", + OID: pgtype.TimestampOID, + Codec: &pgtype.TimestampCodec{ScanLocation: time.Local}, + }) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamp", []pgxtest.ValueRoundTripTest{ + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local))}, + }) +} + // https://github.com/jackc/pgx/v4/pgtype/pull/128 func TestTimestampTranscodeBigTimeBinary(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { @@ -64,13 +102,23 @@ func TestTimestampCodecDecodeTextInvalid(t *testing.T) { } func TestTimestampMarshalJSON(t *testing.T) { + tsStruct := struct { + TS pgtype.Timestamp `json:"ts"` + }{} + + tm := time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC) + tsString := "\"" + tm.Format("2006-01-02T15:04:05") + "\"" // `"2012-03-29T10:05:45"` + var pgt pgtype.Timestamp + _ = pgt.Scan(tm) + successfulTests := []struct { source pgtype.Timestamp result string }{ {source: pgtype.Timestamp{}, result: "null"}, - {source: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC), Valid: true}, result: "\"2012-03-29T10:05:45Z\""}, - {source: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.UTC), Valid: true}, result: "\"2012-03-29T10:05:45.555Z\""}, + {source: pgtype.Timestamp{Time: tm, Valid: true}, result: tsString}, + 
{source: pgt, result: tsString}, + {source: pgtype.Timestamp{Time: tm.Add(time.Second * 555 / 1000), Valid: true}, result: `"2012-03-29T10:05:45.555"`}, {source: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""}, {source: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""}, } @@ -80,20 +128,40 @@ func TestTimestampMarshalJSON(t *testing.T) { t.Errorf("%d: %v", i, err) } - if string(r) != tt.result { + if !assert.Equal(t, tt.result, string(r)) { t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r)) } + tsStruct.TS = tt.source + b, err := json.Marshal(tsStruct) + assert.NoErrorf(t, err, "failed to marshal %v %s", tt.source, err) + t2 := tsStruct + t2.TS = pgtype.Timestamp{} // Clear out the value so that we can compare after unmarshalling + err = json.Unmarshal(b, &t2) + assert.NoErrorf(t, err, "failed to unmarshal %v with %s", tt.source, err) + assert.True(t, tsStruct.TS.Time.Unix() == t2.TS.Time.Unix()) } } +func TestTimestampUnmarshalJSONErrors(t *testing.T) { + tsStruct := struct { + TS pgtype.Timestamp `json:"ts"` + }{} + goodJson1 := []byte(`{"ts":"2012-03-29T10:05:45"}`) + assert.NoError(t, json.Unmarshal(goodJson1, &tsStruct)) + goodJson2 := []byte(`{"ts":"2012-03-29T10:05:45Z"}`) + assert.NoError(t, json.Unmarshal(goodJson2, &tsStruct)) + badJson := []byte(`{"ts":"2012-03-29"}`) + assert.Error(t, json.Unmarshal(badJson, &tsStruct)) +} + func TestTimestampUnmarshalJSON(t *testing.T) { successfulTests := []struct { source string result pgtype.Timestamp }{ {source: "null", result: pgtype.Timestamp{}}, - {source: "\"2012-03-29T10:05:45Z\"", result: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC), Valid: true}}, - {source: "\"2012-03-29T10:05:45.555Z\"", result: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.UTC), Valid: true}}, + {source: "\"2012-03-29T10:05:45\"", result: 
pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC), Valid: true}}, + {source: "\"2012-03-29T10:05:45.555\"", result: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.UTC), Valid: true}}, {source: "\"infinity\"", result: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}}, {source: "\"-infinity\"", result: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}}, } diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 6d9b80cb4..71c73a786 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -11,10 +11,12 @@ import ( "github.com/yugabyte/pgx/v5/internal/pgio" ) -const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" -const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" -const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" -const microsecFromUnixEpochToY2K = 946684800 * 1000000 +const ( + pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" + pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" + pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" + microsecFromUnixEpochToY2K = 946684800 * 1000000 +) const ( negativeInfinityMicrosecondOffset = -9223372036854775808 @@ -36,16 +38,18 @@ type Timestamptz struct { Valid bool } +// ScanTimestamptz implements the [TimestamptzScanner] interface. func (tstz *Timestamptz) ScanTimestamptz(v Timestamptz) error { *tstz = v return nil } +// TimestamptzValue implements the [TimestamptzValuer] interface. func (tstz Timestamptz) TimestamptzValue() (Timestamptz, error) { return tstz, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. 
func (tstz *Timestamptz) Scan(src any) error { if src == nil { *tstz = Timestamptz{} @@ -54,7 +58,7 @@ func (tstz *Timestamptz) Scan(src any) error { switch src := src.(type) { case string: - return scanPlanTextTimestamptzToTimestamptzScanner{}.Scan([]byte(src), tstz) + return (&scanPlanTextTimestamptzToTimestamptzScanner{}).Scan([]byte(src), tstz) case time.Time: *tstz = Timestamptz{Time: src, Valid: true} return nil @@ -63,7 +67,7 @@ func (tstz *Timestamptz) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (tstz Timestamptz) Value() (driver.Value, error) { if !tstz.Valid { return nil, nil @@ -75,6 +79,7 @@ func (tstz Timestamptz) Value() (driver.Value, error) { return tstz.Time, nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (tstz Timestamptz) MarshalJSON() ([]byte, error) { if !tstz.Valid { return []byte("null"), nil @@ -94,6 +99,7 @@ func (tstz Timestamptz) MarshalJSON() ([]byte, error) { return json.Marshal(s) } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (tstz *Timestamptz) UnmarshalJSON(b []byte) error { var s *string err := json.Unmarshal(b, &s) @@ -124,17 +130,21 @@ func (tstz *Timestamptz) UnmarshalJSON(b []byte) error { return nil } -type TimestamptzCodec struct{} +type TimestamptzCodec struct { + // ScanLocation is the location to return scanned timestamptz values in. This does not change the instant in time that + // the timestamptz represents. 
+ ScanLocation *time.Location +} -func (TimestamptzCodec) FormatSupported(format int16) bool { +func (*TimestamptzCodec) FormatSupported(format int16) bool { return format == TextFormatCode || format == BinaryFormatCode } -func (TimestamptzCodec) PreferredFormat() int16 { +func (*TimestamptzCodec) PreferredFormat() int16 { return BinaryFormatCode } -func (TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { +func (*TimestamptzCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { if _, ok := value.(TimestamptzValuer); !ok { return nil } @@ -220,27 +230,26 @@ func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []by return buf, nil } -func (TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - +func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { switch format { case BinaryFormatCode: switch target.(type) { case TimestamptzScanner: - return scanPlanBinaryTimestamptzToTimestamptzScanner{} + return &scanPlanBinaryTimestamptzToTimestamptzScanner{location: c.ScanLocation} } case TextFormatCode: switch target.(type) { case TimestamptzScanner: - return scanPlanTextTimestamptzToTimestamptzScanner{} + return &scanPlanTextTimestamptzToTimestamptzScanner{location: c.ScanLocation} } } return nil } -type scanPlanBinaryTimestamptzToTimestamptzScanner struct{} +type scanPlanBinaryTimestamptzToTimestamptzScanner struct{ location *time.Location } -func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { +func (plan *scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestamptzScanner) if src == nil { @@ -264,15 +273,18 @@ func (scanPlanBinaryTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) e microsecFromUnixEpochToY2K/1000000+microsecSinceY2K/1000000, (microsecFromUnixEpochToY2K%1000000*1000)+(microsecSinceY2K%1000000*1000), ) + if plan.location 
!= nil { + tim = tim.In(plan.location) + } tstz = Timestamptz{Time: tim, Valid: true} } return scanner.ScanTimestamptz(tstz) } -type scanPlanTextTimestamptzToTimestamptzScanner struct{} +type scanPlanTextTimestamptzToTimestamptzScanner struct{ location *time.Location } -func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { +func (plan *scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) error { scanner := (dst).(TimestamptzScanner) if src == nil { @@ -312,13 +324,17 @@ func (scanPlanTextTimestamptzToTimestamptzScanner) Scan(src []byte, dst any) err tim = time.Date(year, tim.Month(), tim.Day(), tim.Hour(), tim.Minute(), tim.Second(), tim.Nanosecond(), tim.Location()) } + if plan.location != nil { + tim = tim.In(plan.location) + } + tstz = Timestamptz{Time: tim, Valid: true} } return scanner.ScanTimestamptz(tstz) } -func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { +func (c *TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { if src == nil { return nil, nil } @@ -336,7 +352,7 @@ func (c TimestamptzCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int1 return tstz.Time, nil } -func (c TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { +func (c *TimestamptzCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { if src == nil { return nil, nil } diff --git a/pgtype/timestamptz_test.go b/pgtype/timestamptz_test.go index 942028553..c220d6589 100644 --- a/pgtype/timestamptz_test.go +++ b/pgtype/timestamptz_test.go @@ -38,6 +38,40 @@ func TestTimestamptzCodec(t *testing.T) { }) } +func TestTimestamptzCodecWithLocationUTC(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + + connTestRunner := defaultConnTestRunner + 
connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamptz", + OID: pgtype.TimestamptzOID, + Codec: &pgtype.TimestamptzCodec{ScanLocation: time.UTC}, + }) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{ + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))}, + }) +} + +func TestTimestamptzCodecWithLocationLocal(t *testing.T) { + skipCockroachDB(t, "Server does not support infinite timestamps (see https://github.com/cockroachdb/cockroach/issues/41564)") + + connTestRunner := defaultConnTestRunner + connTestRunner.AfterConnect = func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + conn.TypeMap().RegisterType(&pgtype.Type{ + Name: "timestamptz", + OID: pgtype.TimestamptzOID, + Codec: &pgtype.TimestamptzCodec{ScanLocation: time.Local}, + }) + } + + pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, nil, "timestamptz", []pgxtest.ValueRoundTripTest{ + {time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local), new(time.Time), isExpectedEq(time.Date(2000, 1, 1, 0, 0, 0, 0, time.Local))}, + }) +} + // https://github.com/jackc/pgx/v4/pgtype/pull/128 func TestTimestamptzTranscodeBigTimeBinary(t *testing.T) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { diff --git a/pgtype/uint32.go b/pgtype/uint32.go index 618a16bdb..085832f7a 100644 --- a/pgtype/uint32.go +++ b/pgtype/uint32.go @@ -3,6 +3,7 @@ package pgtype import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "math" "strconv" @@ -24,16 +25,18 @@ type Uint32 struct { Valid bool } +// ScanUint32 implements the [Uint32Scanner] interface. func (n *Uint32) ScanUint32(v Uint32) error { *n = v return nil } +// Uint32Value implements the [Uint32Valuer] interface. 
func (n Uint32) Uint32Value() (Uint32, error) { return n, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Uint32) Scan(src any) error { if src == nil { *dst = Uint32{} @@ -67,7 +70,7 @@ func (dst *Uint32) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Uint32) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -75,6 +78,31 @@ func (src Uint32) Value() (driver.Value, error) { return int64(src.Uint32), nil } +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src Uint32) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + return json.Marshal(src.Uint32) +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Uint32) UnmarshalJSON(b []byte) error { + var n *uint32 + err := json.Unmarshal(b, &n) + if err != nil { + return err + } + + if n == nil { + *dst = Uint32{} + } else { + *dst = Uint32{Uint32: *n, Valid: true} + } + + return nil +} + type Uint32Codec struct{} func (Uint32Codec) FormatSupported(format int16) bool { @@ -197,7 +225,6 @@ func (encodePlanUint32CodecTextInt64Valuer) Encode(value any, buf []byte) (newBu } func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { - switch format { case BinaryFormatCode: switch target.(type) { @@ -205,6 +232,8 @@ func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPl return scanPlanBinaryUint32ToUint32{} case Uint32Scanner: return scanPlanBinaryUint32ToUint32Scanner{} + case TextScanner: + return scanPlanBinaryUint32ToTextScanner{} } case TextFormatCode: switch target.(type) { @@ -282,6 +311,26 @@ func (scanPlanBinaryUint32ToUint32Scanner) Scan(src []byte, dst any) error { return s.ScanUint32(Uint32{Uint32: n, Valid: true}) } +type 
scanPlanBinaryUint32ToTextScanner struct{}
+
+func (scanPlanBinaryUint32ToTextScanner) Scan(src []byte, dst any) error {
+	s, ok := (dst).(TextScanner)
+	if !ok {
+		return ErrScanTargetTypeChanged
+	}
+
+	if src == nil {
+		return s.ScanText(Text{})
+	}
+
+	if len(src) != 4 {
+		return fmt.Errorf("invalid length for uint32: %v", len(src))
+	}
+
+	n := uint64(binary.BigEndian.Uint32(src))
+	return s.ScanText(Text{String: strconv.FormatUint(n, 10), Valid: true})
+}
+
 type scanPlanTextAnyToUint32Scanner struct{}
 
 func (scanPlanTextAnyToUint32Scanner) Scan(src []byte, dst any) error {
diff --git a/pgtype/uint32_test.go b/pgtype/uint32_test.go
index d9a924b3c..23000bd59 100644
--- a/pgtype/uint32_test.go
+++ b/pgtype/uint32_test.go
@@ -17,5 +17,6 @@ func TestUint32Codec(t *testing.T) {
 		},
 		{pgtype.Uint32{}, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})},
 		{nil, new(pgtype.Uint32), isExpectedEq(pgtype.Uint32{})},
+		{"1147", new(string), isExpectedEq("1147")},
 	})
 }
diff --git a/pgtype/uint64.go b/pgtype/uint64.go
new file mode 100644
index 000000000..68fd16613
--- /dev/null
+++ b/pgtype/uint64.go
@@ -0,0 +1,323 @@
+package pgtype
+
+import (
+	"database/sql/driver"
+	"encoding/binary"
+	"fmt"
+	"math"
+	"strconv"
+
+	"github.com/yugabyte/pgx/v5/internal/pgio"
+)
+
+type Uint64Scanner interface {
+	ScanUint64(v Uint64) error
+}
+
+type Uint64Valuer interface {
+	Uint64Value() (Uint64, error)
+}
+
+// Uint64 is the core type that is used to represent PostgreSQL types such as XID8.
+type Uint64 struct {
+	Uint64 uint64
+	Valid  bool
+}
+
+// ScanUint64 implements the [Uint64Scanner] interface.
+func (n *Uint64) ScanUint64(v Uint64) error {
+	*n = v
+	return nil
+}
+
+// Uint64Value implements the [Uint64Valuer] interface.
+func (n Uint64) Uint64Value() (Uint64, error) {
+	return n, nil
+}
+
+// Scan implements the [database/sql.Scanner] interface.
+func (dst *Uint64) Scan(src any) error { + if src == nil { + *dst = Uint64{} + return nil + } + + var n uint64 + + switch src := src.(type) { + case int64: + if src < 0 { + return fmt.Errorf("%d is less than the minimum value for Uint64", src) + } + n = uint64(src) + case string: + un, err := strconv.ParseUint(src, 10, 64) + if err != nil { + return err + } + n = un + default: + return fmt.Errorf("cannot scan %T", src) + } + + *dst = Uint64{Uint64: n, Valid: true} + + return nil +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Uint64) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + // If the value is greater than the maximum value for int64, return it as a string instead of losing data or returning + // an error. + if src.Uint64 > math.MaxInt64 { + return strconv.FormatUint(src.Uint64, 10), nil + } + + return int64(src.Uint64), nil +} + +type Uint64Codec struct{} + +func (Uint64Codec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (Uint64Codec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (Uint64Codec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch format { + case BinaryFormatCode: + switch value.(type) { + case uint64: + return encodePlanUint64CodecBinaryUint64{} + case Uint64Valuer: + return encodePlanUint64CodecBinaryUint64Valuer{} + case Int64Valuer: + return encodePlanUint64CodecBinaryInt64Valuer{} + } + case TextFormatCode: + switch value.(type) { + case uint64: + return encodePlanUint64CodecTextUint64{} + case Int64Valuer: + return encodePlanUint64CodecTextInt64Valuer{} + } + } + + return nil +} + +type encodePlanUint64CodecBinaryUint64 struct{} + +func (encodePlanUint64CodecBinaryUint64) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(uint64) + return pgio.AppendUint64(buf, v), nil +} + +type encodePlanUint64CodecBinaryUint64Valuer struct{} + +func 
(encodePlanUint64CodecBinaryUint64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint64Valuer).Uint64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return pgio.AppendUint64(buf, v.Uint64), nil +} + +type encodePlanUint64CodecBinaryInt64Valuer struct{} + +func (encodePlanUint64CodecBinaryInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint64", v.Int64) + } + + return pgio.AppendUint64(buf, uint64(v.Int64)), nil +} + +type encodePlanUint64CodecTextUint64 struct{} + +func (encodePlanUint64CodecTextUint64) Encode(value any, buf []byte) (newBuf []byte, err error) { + v := value.(uint64) + return append(buf, strconv.FormatUint(uint64(v), 10)...), nil +} + +type encodePlanUint64CodecTextUint64Valuer struct{} + +func (encodePlanUint64CodecTextUint64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Uint64Valuer).Uint64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + return append(buf, strconv.FormatUint(v.Uint64, 10)...), nil +} + +type encodePlanUint64CodecTextInt64Valuer struct{} + +func (encodePlanUint64CodecTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, err error) { + v, err := value.(Int64Valuer).Int64Value() + if err != nil { + return nil, err + } + + if !v.Valid { + return nil, nil + } + + if v.Int64 < 0 { + return nil, fmt.Errorf("%d is less than minimum value for uint64", v.Int64) + } + + return append(buf, strconv.FormatInt(v.Int64, 10)...), nil +} + +func (Uint64Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case *uint64: + return scanPlanBinaryUint64ToUint64{} + case 
Uint64Scanner: + return scanPlanBinaryUint64ToUint64Scanner{} + case TextScanner: + return scanPlanBinaryUint64ToTextScanner{} + } + case TextFormatCode: + switch target.(type) { + case *uint64: + return scanPlanTextAnyToUint64{} + case Uint64Scanner: + return scanPlanTextAnyToUint64Scanner{} + } + } + + return nil +} + +func (c Uint64Codec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + var n uint64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return int64(n), nil +} + +func (c Uint64Codec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var n uint64 + err := codecScan(c, m, oid, format, src, &n) + if err != nil { + return nil, err + } + return n, nil +} + +type scanPlanBinaryUint64ToUint64 struct{} + +func (scanPlanBinaryUint64ToUint64) Scan(src []byte, dst any) error { + if src == nil { + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint64: %v", len(src)) + } + + p := (dst).(*uint64) + *p = binary.BigEndian.Uint64(src) + + return nil +} + +type scanPlanBinaryUint64ToUint64Scanner struct{} + +func (scanPlanBinaryUint64ToUint64Scanner) Scan(src []byte, dst any) error { + s, ok := (dst).(Uint64Scanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanUint64(Uint64{}) + } + + if len(src) != 8 { + return fmt.Errorf("invalid length for uint64: %v", len(src)) + } + + n := binary.BigEndian.Uint64(src) + + return s.ScanUint64(Uint64{Uint64: n, Valid: true}) +} + +type scanPlanBinaryUint64ToTextScanner struct{} + +func (scanPlanBinaryUint64ToTextScanner) Scan(src []byte, dst any) error { + s, ok := (dst).(TextScanner) + if !ok { + return ErrScanTargetTypeChanged + } + + if src == nil { + return s.ScanText(Text{}) + } + + if len(src) != 8 { + return 
fmt.Errorf("invalid length for uint64: %v", len(src))
+	}
+
+	n := uint64(binary.BigEndian.Uint64(src))
+	return s.ScanText(Text{String: strconv.FormatUint(n, 10), Valid: true})
+}
+
+type scanPlanTextAnyToUint64Scanner struct{}
+
+func (scanPlanTextAnyToUint64Scanner) Scan(src []byte, dst any) error {
+	s, ok := (dst).(Uint64Scanner)
+	if !ok {
+		return ErrScanTargetTypeChanged
+	}
+
+	if src == nil {
+		return s.ScanUint64(Uint64{})
+	}
+
+	n, err := strconv.ParseUint(string(src), 10, 64)
+	if err != nil {
+		return err
+	}
+
+	return s.ScanUint64(Uint64{Uint64: n, Valid: true})
+}
diff --git a/pgtype/uint64_test.go b/pgtype/uint64_test.go
new file mode 100644
index 000000000..33c2622d5
--- /dev/null
+++ b/pgtype/uint64_test.go
@@ -0,0 +1,30 @@
+package pgtype_test
+
+import (
+	"context"
+	"testing"
+
+	"github.com/yugabyte/pgx/v5/pgtype"
+	"github.com/yugabyte/pgx/v5/pgxtest"
+)
+
+func TestUint64Codec(t *testing.T) {
+	skipCockroachDB(t, "Server does not support xid8 (https://github.com/cockroachdb/cockroach/issues/36815)")
+	skipPostgreSQLVersionLessThan(t, 13)
+
+	pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "xid8", []pgxtest.ValueRoundTripTest{
+		{
+			pgtype.Uint64{Uint64: 1 << 36, Valid: true},
+			new(pgtype.Uint64),
+			isExpectedEq(pgtype.Uint64{Uint64: 1 << 36, Valid: true}),
+		},
+		{pgtype.Uint64{}, new(pgtype.Uint64), isExpectedEq(pgtype.Uint64{})},
+		{nil, new(pgtype.Uint64), isExpectedEq(pgtype.Uint64{})},
+		{
+			uint64(1 << 36),
+			new(uint64),
+			isExpectedEq(uint64(1 << 36)),
+		},
+		{"1147", new(string), isExpectedEq("1147")},
+	})
+}
diff --git a/pgtype/uuid.go b/pgtype/uuid.go
index d57c0f2fa..83d0c4127 100644
--- a/pgtype/uuid.go
+++ b/pgtype/uuid.go
@@ -20,11 +20,13 @@ type UUID struct {
 	Valid bool
 }
 
+// ScanUUID implements the [UUIDScanner] interface.
 func (b *UUID) ScanUUID(v UUID) error {
 	*b = v
 	return nil
 }
 
+// UUIDValue implements the [UUIDValuer] interface.
func (b UUID) UUIDValue() (UUID, error) { return b, nil } @@ -67,7 +69,7 @@ func encodeUUID(src [16]byte) string { return string(buf[:]) } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *UUID) Scan(src any) error { if src == nil { *dst = UUID{} @@ -87,7 +89,7 @@ func (dst *UUID) Scan(src any) error { return fmt.Errorf("cannot scan %T", src) } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src UUID) Value() (driver.Value, error) { if !src.Valid { return nil, nil @@ -96,6 +98,15 @@ func (src UUID) Value() (driver.Value, error) { return encodeUUID(src.Bytes), nil } +func (src UUID) String() string { + if !src.Valid { + return "" + } + + return encodeUUID(src.Bytes) +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. func (src UUID) MarshalJSON() ([]byte, error) { if !src.Valid { return []byte("null"), nil @@ -108,6 +119,7 @@ func (src UUID) MarshalJSON() ([]byte, error) { return buff.Bytes(), nil } +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. 
func (dst *UUID) UnmarshalJSON(src []byte) error { if bytes.Equal(src, []byte("null")) { *dst = UUID{} diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index ef4397be3..52a1a068d 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -5,11 +5,13 @@ import ( "reflect" "testing" - "github.com/stretchr/testify/require" "github.com/yugabyte/pgx/v5/pgtype" "github.com/yugabyte/pgx/v5/pgxtest" + "github.com/stretchr/testify/require" ) +type renamedUUIDByteArray [16]byte + func TestUUIDCodec(t *testing.T) { pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "uuid", []pgxtest.ValueRoundTripTest{ { @@ -43,6 +45,16 @@ func TestUUIDCodec(t *testing.T) { new(pgtype.UUID), isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), }, + { + renamedUUIDByteArray{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(pgtype.UUID), + isExpectedEq(pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Valid: true}), + }, + { + []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + new(renamedUUIDByteArray), + isExpectedEq(renamedUUIDByteArray{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}), + }, { []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, new(pgtype.UUID), @@ -51,6 +63,38 @@ func TestUUIDCodec(t *testing.T) { }) } +func TestUUID_String(t *testing.T) { + tests := []struct { + name string + src pgtype.UUID + want string + }{ + { + name: "first", + src: pgtype.UUID{ + Bytes: [16]byte{29, 72, 90, 122, 109, 24, 69, 153, 140, 108, 52, 66, 86, 22, 136, 122}, + Valid: true, + }, + want: "1d485a7a-6d18-4599-8c6c-34425616887a", + }, + { + name: "third", + src: pgtype.UUID{ + Bytes: [16]byte{}, + }, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.src.String() + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", got, tt.want) + } + }) + } +} + 
func TestUUID_MarshalJSON(t *testing.T) { tests := []struct { name string diff --git a/pgtype/xml.go b/pgtype/xml.go new file mode 100644 index 000000000..79e3698a4 --- /dev/null +++ b/pgtype/xml.go @@ -0,0 +1,198 @@ +package pgtype + +import ( + "database/sql" + "database/sql/driver" + "encoding/xml" + "fmt" + "reflect" +) + +type XMLCodec struct { + Marshal func(v any) ([]byte, error) + Unmarshal func(data []byte, v any) error +} + +func (*XMLCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (*XMLCodec) PreferredFormat() int16 { + return TextFormatCode +} + +func (c *XMLCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + switch value.(type) { + case string: + return encodePlanXMLCodecEitherFormatString{} + case []byte: + return encodePlanXMLCodecEitherFormatByteSlice{} + + // Cannot rely on driver.Valuer being handled later because anything can be marshalled. + // + // https://github.com/jackc/pgx/issues/1430 + // + // Check for driver.Valuer must come before xml.Marshaler so that it is guaranteed to be used + // when both are implemented https://github.com/jackc/pgx/issues/1805 + case driver.Valuer: + return &encodePlanDriverValuer{m: m, oid: oid, formatCode: format} + + // Must come before trying wrap encode plans because a pointer to a struct may be unwrapped to a struct that can be + // marshalled. + // + // https://github.com/jackc/pgx/issues/1681 + case xml.Marshaler: + return &encodePlanXMLCodecEitherFormatMarshal{ + marshal: c.Marshal, + } + } + + // Because anything can be marshalled the normal wrapping in Map.PlanScan doesn't get a chance to run. So try the + // appropriate wrappers here. 
+ for _, f := range []TryWrapEncodePlanFunc{ + TryWrapDerefPointerEncodePlan, + TryWrapFindUnderlyingTypeEncodePlan, + } { + if wrapperPlan, nextValue, ok := f(value); ok { + if nextPlan := c.PlanEncode(m, oid, format, nextValue); nextPlan != nil { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + return &encodePlanXMLCodecEitherFormatMarshal{ + marshal: c.Marshal, + } +} + +type encodePlanXMLCodecEitherFormatString struct{} + +func (encodePlanXMLCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) { + xmlString := value.(string) + buf = append(buf, xmlString...) + return buf, nil +} + +type encodePlanXMLCodecEitherFormatByteSlice struct{} + +func (encodePlanXMLCodecEitherFormatByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) { + xmlBytes := value.([]byte) + if xmlBytes == nil { + return nil, nil + } + + buf = append(buf, xmlBytes...) + return buf, nil +} + +type encodePlanXMLCodecEitherFormatMarshal struct { + marshal func(v any) ([]byte, error) +} + +func (e *encodePlanXMLCodecEitherFormatMarshal) Encode(value any, buf []byte) (newBuf []byte, err error) { + xmlBytes, err := e.marshal(value) + if err != nil { + return nil, err + } + + buf = append(buf, xmlBytes...) + return buf, nil +} + +func (c *XMLCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch target.(type) { + case *string: + return scanPlanAnyToString{} + + case **string: + // This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better + // solution would be. 
+ // + // https://github.com/jackc/pgx/issues/1470 -- **string + // https://github.com/jackc/pgx/issues/1691 -- ** anything else + + if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok { + if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil { + if _, failed := nextPlan.(*scanPlanFail); !failed { + wrapperPlan.SetNext(nextPlan) + return wrapperPlan + } + } + } + + case *[]byte: + return scanPlanXMLToByteSlice{} + case BytesScanner: + return scanPlanBinaryBytesToBytesScanner{} + + // Cannot rely on sql.Scanner being handled later because scanPlanXMLToXMLUnmarshal will take precedence. + // + // https://github.com/jackc/pgx/issues/1418 + case sql.Scanner: + return &scanPlanSQLScanner{formatCode: format} + } + + return &scanPlanXMLToXMLUnmarshal{ + unmarshal: c.Unmarshal, + } +} + +type scanPlanXMLToByteSlice struct{} + +func (scanPlanXMLToByteSlice) Scan(src []byte, dst any) error { + dstBuf := dst.(*[]byte) + if src == nil { + *dstBuf = nil + return nil + } + + *dstBuf = make([]byte, len(src)) + copy(*dstBuf, src) + return nil +} + +type scanPlanXMLToXMLUnmarshal struct { + unmarshal func(data []byte, v any) error +} + +func (s *scanPlanXMLToXMLUnmarshal) Scan(src []byte, dst any) error { + if src == nil { + dstValue := reflect.ValueOf(dst) + if dstValue.Kind() == reflect.Ptr { + el := dstValue.Elem() + switch el.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Map, reflect.Interface, reflect.Struct: + el.Set(reflect.Zero(el.Type())) + return nil + } + } + + return fmt.Errorf("cannot scan NULL into %T", dst) + } + + elem := reflect.ValueOf(dst).Elem() + elem.Set(reflect.Zero(elem.Type())) + + return s.unmarshal(src, dst) +} + +func (c *XMLCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + if src == nil { + return nil, nil + } + + dstBuf := make([]byte, len(src)) + copy(dstBuf, src) + return dstBuf, nil +} + +func (c *XMLCodec) DecodeValue(m *Map, oid uint32, format int16, src 
[]byte) (any, error) { + if src == nil { + return nil, nil + } + + var dst any + err := c.Unmarshal(src, &dst) + return dst, err +} diff --git a/pgtype/xml_test.go b/pgtype/xml_test.go new file mode 100644 index 000000000..2c0b899a5 --- /dev/null +++ b/pgtype/xml_test.go @@ -0,0 +1,128 @@ +package pgtype_test + +import ( + "context" + "database/sql" + "encoding/xml" + "testing" + + pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxtest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type xmlStruct struct { + XMLName xml.Name `xml:"person"` + Name string `xml:"name"` + Age int `xml:"age,attr"` +} + +func TestXMLCodec(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support XML.") + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "xml", []pgxtest.ValueRoundTripTest{ + {nil, new(*xmlStruct), isExpectedEq((*xmlStruct)(nil))}, + {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))}, + {map[string]any(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))}, + {nil, new([]byte), isExpectedEqBytes([]byte(nil))}, + + // Test sql.Scanner. + {"", new(sql.NullString), isExpectedEq(sql.NullString{String: "", Valid: true})}, + + // Test driver.Valuer. 
+ {sql.NullString{String: "", Valid: true}, new(sql.NullString), isExpectedEq(sql.NullString{String: "", Valid: true})}, + }) + + pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "xml", []pgxtest.ValueRoundTripTest{ + {[]byte(``), new([]byte), isExpectedEqBytes([]byte(``))}, + {[]byte(``), new([]byte), isExpectedEqBytes([]byte(``))}, + {[]byte(``), new(string), isExpectedEq(``)}, + {[]byte(``), new([]byte), isExpectedEqBytes([]byte(``))}, + {[]byte(``), new(string), isExpectedEq(``)}, + {[]byte(""), new([]byte), isExpectedEqBytes([]byte(""))}, + {xmlStruct{Name: "Adam", Age: 10}, new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})}, + {xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10}, new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})}, + {[]byte(`Adam`), new(xmlStruct), isExpectedEq(xmlStruct{XMLName: xml.Name{Local: "person"}, Name: "Adam", Age: 10})}, + }) +} + +// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648 +func TestXMLCodecUnmarshalSQLNull(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support XML.") + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + // Byte arrays are nilified + slice := []byte{10, 4} + err := conn.QueryRow(ctx, "select null::xml").Scan(&slice) + assert.NoError(t, err) + assert.Nil(t, slice) + + // Non-pointer structs are zeroed + m := xmlStruct{Name: "Adam"} + err = conn.QueryRow(ctx, "select null::xml").Scan(&m) + assert.NoError(t, err) + assert.Empty(t, m) + + // Pointers to structs are nilified + pm := &xmlStruct{Name: "Adam"} + err = conn.QueryRow(ctx, "select null::xml").Scan(&pm) + assert.NoError(t, err) + assert.Nil(t, pm) + + // Pointer to pointer are nilified + n := "" + p := &n + err = conn.QueryRow(ctx, "select null::xml").Scan(&p) + assert.NoError(t, err) + 
assert.Nil(t, p) + + // A string cannot scan a NULL. + str := "foobar" + err = conn.QueryRow(ctx, "select null::xml").Scan(&str) + assert.EqualError(t, err, "can't scan into dest[0] (col: xml): cannot scan NULL into *string") + }) +} + +func TestXMLCodecPointerToPointerToString(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support XML.") + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var s *string + err := conn.QueryRow(ctx, "select ''::xml").Scan(&s) + require.NoError(t, err) + require.NotNil(t, s) + require.Equal(t, "", *s) + + err = conn.QueryRow(ctx, "select null::xml").Scan(&s) + require.NoError(t, err) + require.Nil(t, s) + }) +} + +func TestXMLCodecDecodeValue(t *testing.T) { + skipCockroachDB(t, "CockroachDB does not support XML.") + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { + for _, tt := range []struct { + sql string + expected any + }{ + { + sql: `select 'bar'::xml`, + expected: []byte("bar"), + }, + } { + t.Run(tt.sql, func(t *testing.T) { + rows, err := conn.Query(ctx, tt.sql) + require.NoError(t, err) + + for rows.Next() { + values, err := rows.Values() + require.NoError(t, err) + require.Len(t, values, 1) + require.Equal(t, tt.expected, values[0]) + } + + require.NoError(t, rows.Err()) + }) + } + }) +} diff --git a/pgtype/zeronull/float8.go b/pgtype/zeronull/float8.go index 70bbb8d69..f6f04b7f3 100644 --- a/pgtype/zeronull/float8.go +++ b/pgtype/zeronull/float8.go @@ -8,9 +8,10 @@ import ( type Float8 float64 +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. func (Float8) SkipUnderlyingTypePlan() {} -// ScanFloat64 implements the Float64Scanner interface. +// ScanFloat64 implements the [pgtype.Float64Scanner] interface. 
func (f *Float8) ScanFloat64(n pgtype.Float8) error { if !n.Valid { *f = 0 @@ -22,6 +23,7 @@ func (f *Float8) ScanFloat64(n pgtype.Float8) error { return nil } +// Float64Value implements the [pgtype.Float64Valuer] interface. func (f Float8) Float64Value() (pgtype.Float8, error) { if f == 0 { return pgtype.Float8{}, nil @@ -29,7 +31,7 @@ func (f Float8) Float64Value() (pgtype.Float8, error) { return pgtype.Float8{Float64: float64(f), Valid: true}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (f *Float8) Scan(src any) error { if src == nil { *f = 0 @@ -47,7 +49,7 @@ func (f *Float8) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (f Float8) Value() (driver.Value, error) { if f == 0 { return nil, nil diff --git a/pgtype/zeronull/int.go b/pgtype/zeronull/int.go index 403609dab..40475e97e 100644 --- a/pgtype/zeronull/int.go +++ b/pgtype/zeronull/int.go @@ -1,4 +1,5 @@ -// Do not edit. Generated from pgtype/zeronull/int.go.erb +// Code generated from pgtype/zeronull/int.go.erb. DO NOT EDIT. + package zeronull import ( @@ -11,27 +12,36 @@ import ( type Int2 int16 +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. func (Int2) SkipUnderlyingTypePlan() {} -// ScanInt64 implements the Int64Scanner interface. -func (dst *Int2) ScanInt64(n int64, valid bool) error { - if !valid { +// ScanInt64 implements the [pgtype.Int64Scanner] interface. 
+func (dst *Int2) ScanInt64(n pgtype.Int8) error { + if !n.Valid { *dst = 0 return nil } - if n < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) + if n.Int64 < math.MinInt16 { + return fmt.Errorf("%d is less than minimum value for Int2", n.Int64) } - if n > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) + if n.Int64 > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64) } - *dst = Int2(n) + *dst = Int2(n.Int64) return nil } -// Scan implements the database/sql Scanner interface. +// Int64Value implements the [pgtype.Int64Valuer] interface. +func (src Int2) Int64Value() (pgtype.Int8, error) { + if src == 0 { + return pgtype.Int8{}, nil + } + return pgtype.Int8{Int64: int64(src), Valid: true}, nil +} + +// Scan implements the [database/sql.Scanner] interface. func (dst *Int2) Scan(src any) error { if src == nil { *dst = 0 @@ -49,7 +59,7 @@ func (dst *Int2) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int2) Value() (driver.Value, error) { if src == 0 { return nil, nil @@ -59,27 +69,36 @@ func (src Int2) Value() (driver.Value, error) { type Int4 int32 +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. func (Int4) SkipUnderlyingTypePlan() {} -// ScanInt64 implements the Int64Scanner interface. -func (dst *Int4) ScanInt64(n int64, valid bool) error { - if !valid { +// ScanInt64 implements the [pgtype.Int64Scanner] interface. 
+func (dst *Int4) ScanInt64(n pgtype.Int8) error { + if !n.Valid { *dst = 0 return nil } - if n < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", n) + if n.Int64 < math.MinInt32 { + return fmt.Errorf("%d is less than minimum value for Int4", n.Int64) } - if n > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", n) + if n.Int64 > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64) } - *dst = Int4(n) + *dst = Int4(n.Int64) return nil } -// Scan implements the database/sql Scanner interface. +// Int64Value implements the [pgtype.Int64Valuer] interface. +func (src Int4) Int64Value() (pgtype.Int8, error) { + if src == 0 { + return pgtype.Int8{}, nil + } + return pgtype.Int8{Int64: int64(src), Valid: true}, nil +} + +// Scan implements the [database/sql.Scanner] interface. func (dst *Int4) Scan(src any) error { if src == nil { *dst = 0 @@ -97,7 +116,7 @@ func (dst *Int4) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int4) Value() (driver.Value, error) { if src == 0 { return nil, nil @@ -107,27 +126,36 @@ func (src Int4) Value() (driver.Value, error) { type Int8 int64 +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. func (Int8) SkipUnderlyingTypePlan() {} -// ScanInt64 implements the Int64Scanner interface. -func (dst *Int8) ScanInt64(n int64, valid bool) error { - if !valid { +// ScanInt64 implements the [pgtype.Int64Scanner] interface. 
+func (dst *Int8) ScanInt64(n pgtype.Int8) error { + if !n.Valid { *dst = 0 return nil } - if n < math.MinInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", n) + if n.Int64 < math.MinInt64 { + return fmt.Errorf("%d is less than minimum value for Int8", n.Int64) } - if n > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", n) + if n.Int64 > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64) } - *dst = Int8(n) + *dst = Int8(n.Int64) return nil } -// Scan implements the database/sql Scanner interface. +// Int64Value implements the [pgtype.Int64Valuer] interface. +func (src Int8) Int64Value() (pgtype.Int8, error) { + if src == 0 { + return pgtype.Int8{}, nil + } + return pgtype.Int8{Int64: int64(src), Valid: true}, nil +} + +// Scan implements the [database/sql.Scanner] interface. func (dst *Int8) Scan(src any) error { if src == nil { *dst = 0 @@ -145,7 +173,7 @@ func (dst *Int8) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int8) Value() (driver.Value, error) { if src == 0 { return nil, nil diff --git a/pgtype/zeronull/int.go.erb b/pgtype/zeronull/int.go.erb index b7bc72a59..026d7763f 100644 --- a/pgtype/zeronull/int.go.erb +++ b/pgtype/zeronull/int.go.erb @@ -12,27 +12,36 @@ import ( <% pg_bit_size = pg_byte_size * 8 %> type Int<%= pg_byte_size %> int<%= pg_bit_size %> +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. func (Int<%= pg_byte_size %>) SkipUnderlyingTypePlan() {} -// ScanInt64 implements the Int64Scanner interface. -func (dst *Int<%= pg_byte_size %>) ScanInt64(n int64, valid bool) error { - if !valid { +// ScanInt64 implements the [pgtype.Int64Scanner] interface. 
+func (dst *Int<%= pg_byte_size %>) ScanInt64(n pgtype.Int8) error { + if !n.Valid { *dst = 0 return nil } - if n < math.MinInt<%= pg_bit_size %> { - return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + if n.Int64 < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is less than minimum value for Int<%= pg_byte_size %>", n.Int64) } - if n > math.MaxInt<%= pg_bit_size %> { - return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + if n.Int64 > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64) } - *dst = Int<%= pg_byte_size %>(n) + *dst = Int<%= pg_byte_size %>(n.Int64) return nil } -// Scan implements the database/sql Scanner interface. +// Int64Value implements the [pgtype.Int64Valuer] interface. +func (src Int<%= pg_byte_size %>) Int64Value() (pgtype.Int8, error) { + if src == 0 { + return pgtype.Int8{}, nil + } + return pgtype.Int8{Int64: int64(src), Valid: true}, nil +} + +// Scan implements the [database/sql.Scanner] interface. func (dst *Int<%= pg_byte_size %>) Scan(src any) error { if src == nil { *dst = 0 @@ -50,7 +59,7 @@ func (dst *Int<%= pg_byte_size %>) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { if src == 0 { return nil, nil diff --git a/pgtype/zeronull/int_test.go b/pgtype/zeronull/int_test.go index cf7c99eb8..989c3f35c 100644 --- a/pgtype/zeronull/int_test.go +++ b/pgtype/zeronull/int_test.go @@ -1,4 +1,5 @@ -// Do not edit. Generated from pgtype/zeronull/int_test.go.erb +// Code generated from pgtype/zeronull/int_test.go.erb. DO NOT EDIT. 
+ package zeronull_test import ( diff --git a/pgtype/zeronull/text.go b/pgtype/zeronull/text.go index 6edd1adb2..f0e1e019d 100644 --- a/pgtype/zeronull/text.go +++ b/pgtype/zeronull/text.go @@ -8,9 +8,10 @@ import ( type Text string +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. func (Text) SkipUnderlyingTypePlan() {} -// ScanText implements the TextScanner interface. +// ScanText implements the [pgtype.TextScanner] interface. func (dst *Text) ScanText(v pgtype.Text) error { if !v.Valid { *dst = "" @@ -22,7 +23,7 @@ func (dst *Text) ScanText(v pgtype.Text) error { return nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (dst *Text) Scan(src any) error { if src == nil { *dst = "" @@ -40,7 +41,7 @@ func (dst *Text) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (src Text) Value() (driver.Value, error) { if src == "" { return nil, nil diff --git a/pgtype/zeronull/timestamp.go b/pgtype/zeronull/timestamp.go index 9d3144aff..e8000cb82 100644 --- a/pgtype/zeronull/timestamp.go +++ b/pgtype/zeronull/timestamp.go @@ -10,8 +10,10 @@ import ( type Timestamp time.Time +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. func (Timestamp) SkipUnderlyingTypePlan() {} +// ScanTimestamp implements the [pgtype.TimestampScanner] interface. func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error { if !v.Valid { *ts = Timestamp{} @@ -31,6 +33,7 @@ func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error { } } +// TimestampValue implements the [pgtype.TimestampValuer] interface. 
func (ts Timestamp) TimestampValue() (pgtype.Timestamp, error) { if time.Time(ts).IsZero() { return pgtype.Timestamp{}, nil @@ -39,7 +42,7 @@ func (ts Timestamp) TimestampValue() (pgtype.Timestamp, error) { return pgtype.Timestamp{Time: time.Time(ts), Valid: true}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (ts *Timestamp) Scan(src any) error { if src == nil { *ts = Timestamp{} @@ -57,7 +60,7 @@ func (ts *Timestamp) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. func (ts Timestamp) Value() (driver.Value, error) { if time.Time(ts).IsZero() { return nil, nil diff --git a/pgtype/zeronull/timestamptz.go b/pgtype/zeronull/timestamptz.go index 78184842d..e617345fe 100644 --- a/pgtype/zeronull/timestamptz.go +++ b/pgtype/zeronull/timestamptz.go @@ -10,12 +10,9 @@ import ( type Timestamptz time.Time -func (Timestamptz) SkipUnderlyingTypePlan() {} - func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error { if !v.Valid { *ts = Timestamptz{} - return nil } switch v.InfinityModifier { @@ -31,6 +28,7 @@ func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error { } } +// TimestamptzValue implements the [pgtype.TimestamptzValuer] interface. func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) { if time.Time(ts).IsZero() { return pgtype.Timestamptz{}, nil @@ -39,7 +37,7 @@ func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) { return pgtype.Timestamptz{Time: time.Time(ts), Valid: true}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (ts *Timestamptz) Scan(src any) error { if src == nil { *ts = Timestamptz{} @@ -57,7 +55,7 @@ func (ts *Timestamptz) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. 
+// Value implements the [database/sql/driver.Valuer] interface. func (ts Timestamptz) Value() (driver.Value, error) { if time.Time(ts).IsZero() { return nil, nil diff --git a/pgtype/zeronull/uuid.go b/pgtype/zeronull/uuid.go index 611a3645c..b04944c27 100644 --- a/pgtype/zeronull/uuid.go +++ b/pgtype/zeronull/uuid.go @@ -8,9 +8,10 @@ import ( type UUID [16]byte +// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface. func (UUID) SkipUnderlyingTypePlan() {} -// ScanUUID implements the UUIDScanner interface. +// ScanUUID implements the [pgtype.UUIDScanner] interface. func (u *UUID) ScanUUID(v pgtype.UUID) error { if !v.Valid { *u = UUID{} @@ -22,6 +23,7 @@ func (u *UUID) ScanUUID(v pgtype.UUID) error { return nil } +// UUIDValue implements the [pgtype.UUIDValuer] interface. func (u UUID) UUIDValue() (pgtype.UUID, error) { if u == (UUID{}) { return pgtype.UUID{}, nil @@ -29,7 +31,7 @@ func (u UUID) UUIDValue() (pgtype.UUID, error) { return pgtype.UUID{Bytes: u, Valid: true}, nil } -// Scan implements the database/sql Scanner interface. +// Scan implements the [database/sql.Scanner] interface. func (u *UUID) Scan(src any) error { if src == nil { *u = UUID{} @@ -47,7 +49,7 @@ func (u *UUID) Scan(src any) error { return nil } -// Value implements the database/sql/driver Valuer interface. +// Value implements the [database/sql/driver.Valuer] interface. 
func (u UUID) Value() (driver.Value, error) { if u == (UUID{}) { return nil, nil diff --git a/pgx_test.go b/pgx_test.go new file mode 100644 index 000000000..51b4bbc4e --- /dev/null +++ b/pgx_test.go @@ -0,0 +1,22 @@ +package pgx_test + +import ( + "context" + "os" + "testing" + + "github.com/jackc/pgx/v5" + _ "github.com/jackc/pgx/v5/stdlib" +) + +func skipCockroachDB(t testing.TB, msg string) { + conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + if err != nil { + t.Fatal(err) + } + defer conn.Close(context.Background()) + + if conn.PgConn().ParameterStatus("crdb_version") != "" { + t.Skip(msg) + } +} diff --git a/pgxpool/common_test.go b/pgxpool/common_test.go index b1cee1fb7..4465a637f 100644 --- a/pgxpool/common_test.go +++ b/pgxpool/common_test.go @@ -97,7 +97,8 @@ func testCopyFrom(t *testing.T, ctx context.Context, db interface { execer queryer copyFromer -}) { +}, +) { _, err := db.Exec(ctx, `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`) require.NoError(t, err) @@ -141,12 +142,14 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName // Can't test function equality, so just test that they are set or not. 
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName) + assert.Equalf(t, expected.PrepareConn == nil, actual.PrepareConn == nil, "%s - PrepareConn", testName) assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName) assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName) assert.Equalf(t, expected.MaxConnIdleTime, actual.MaxConnIdleTime, "%s - MaxConnIdleTime", testName) assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName) assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName) + assert.Equalf(t, expected.MinIdleConns, actual.MinIdleConns, "%s - MinIdleConns", testName) assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName) assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName) diff --git a/pgxpool/conn.go b/pgxpool/conn.go index 66d5f06d0..399a38331 100644 --- a/pgxpool/conn.go +++ b/pgxpool/conn.go @@ -26,6 +26,10 @@ func (c *Conn) Release() { res := c.res c.res = nil + if c.p.releaseTracer != nil { + c.p.releaseTracer.TraceRelease(c.p, TraceReleaseData{Conn: conn}) + } + if conn.IsClosed() || conn.PgConn().IsBusy() || conn.PgConn().TxStatus() != 'I' { res.Destroy() // Signal to the health check to run since we just destroyed a connections diff --git a/pgxpool/doc.go b/pgxpool/doc.go index 06cc63d5f..099443bca 100644 --- a/pgxpool/doc.go +++ b/pgxpool/doc.go @@ -8,7 +8,7 @@ The primary way of creating a pool is with [pgxpool.New]: pool, err := pgxpool.New(context.Background(), os.Getenv("DATABASE_URL")) -The database connection string can be in URL or DSN format. PostgreSQL settings, pgx settings, and pool settings can be +The database connection string can be in URL or keyword/value format. 
PostgreSQL settings, pgx settings, and pool settings can be specified here. In addition, a config struct can be created by [ParseConfig]. config, err := pgxpool.ParseConfig(os.Getenv("DATABASE_URL")) diff --git a/pgxpool/pool.go b/pgxpool/pool.go index 5265f629d..b7612afa3 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -2,7 +2,7 @@ package pgxpool import ( "context" - "fmt" + "errors" "math/rand" "runtime" "strconv" @@ -10,16 +10,19 @@ import ( "sync/atomic" "time" - "github.com/jackc/puddle/v2" + "github.com/yugabyte/puddle/v2" "github.com/yugabyte/pgx/v5" "github.com/yugabyte/pgx/v5/pgconn" ) -var defaultMaxConns = int32(4) -var defaultMinConns = int32(0) -var defaultMaxConnLifetime = time.Hour -var defaultMaxConnIdleTime = time.Minute * 30 -var defaultHealthCheckPeriod = time.Minute +var ( + defaultMaxConns = int32(4) + defaultMinConns = int32(0) + defaultMinIdleConns = int32(0) + defaultMaxConnLifetime = time.Hour + defaultMaxConnIdleTime = time.Minute * 30 + defaultHealthCheckPeriod = time.Minute +) type connResource struct { conn *pgx.Conn @@ -83,10 +86,12 @@ type Pool struct { config *Config beforeConnect func(context.Context, *pgx.ConnConfig) error afterConnect func(context.Context, *pgx.Conn) error - beforeAcquire func(context.Context, *pgx.Conn) bool + prepareConn func(context.Context, *pgx.Conn) (bool, error) afterRelease func(*pgx.Conn) bool beforeClose func(*pgx.Conn) + shouldPing func(context.Context, ShouldPingParams) bool minConns int32 + minIdleConns int32 maxConns int32 maxConnLifetime time.Duration maxConnLifetimeJitter time.Duration @@ -95,10 +100,19 @@ type Pool struct { healthCheckChan chan struct{} + acquireTracer AcquireTracer + releaseTracer ReleaseTracer + closeOnce sync.Once closeChan chan struct{} } +// ShouldPingParams are the parameters passed to ShouldPing. +type ShouldPingParams struct { + Conn *pgx.Conn + IdleDuration time.Duration +} + // Config is the configuration struct for creating a pool. 
It must be created by [ParseConfig] and then it can be // modified. type Config struct { @@ -114,8 +128,23 @@ type Config struct { // BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the // acquisition or false to indicate that the connection should be destroyed and a different connection should be // acquired. + // + // Deprecated: Use PrepareConn instead. If both PrepareConn and BeforeAcquire are set, PrepareConn will take + // precedence, ignoring BeforeAcquire. BeforeAcquire func(context.Context, *pgx.Conn) bool + // PrepareConn is called before a connection is acquired from the pool. If this function returns true, the connection + // is considered valid, otherwise the connection is destroyed. If the function returns a non-nil error, the instigating + // query will fail with the returned error. + // + // Specifically, this means that: + // + // - If it returns true and a nil error, the query proceeds as normal. + // - If it returns true and an error, the connection will be returned to the pool, and the instigating query will fail with the returned error. + // - If it returns false, and an error, the connection will be destroyed, and the query will fail with the returned error. + // - If it returns false and a nil error, the connection will be destroyed, and the instigating query will be retried on a new connection. + PrepareConn func(context.Context, *pgx.Conn) (bool, error) + // AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to // return the connection to the pool or false to destroy the connection. AfterRelease func(*pgx.Conn) bool @@ -123,6 +152,10 @@ type Config struct { // BeforeClose is called right before a connection is closed and removed from the pool. BeforeClose func(*pgx.Conn) + // ShouldPing is called after a connection is acquired from the pool. If it returns true, the connection is pinged to check for liveness. 
+ // If this func is not set, the default behavior is to ping connections that have been idle for at least 1 second. + ShouldPing func(context.Context, ShouldPingParams) bool + // MaxConnLifetime is the duration since creation after which a connection will be automatically closed. MaxConnLifetime time.Duration @@ -141,6 +174,13 @@ type Config struct { // to create new connections. MinConns int32 + // MinIdleConns is the minimum number of idle connections in the pool. You can increase this to ensure that + // there are always idle connections available. This can help reduce tail latencies during request processing, + // as you can avoid the latency of establishing a new connection while handling requests. It is superior + // to MinConns for this purpose. + // Similar to MinConns, the pool might temporarily dip below MinIdleConns after connection closes. + MinIdleConns int32 + // HealthCheckPeriod is the duration between checks of the health of idle connections. HealthCheckPeriod time.Duration @@ -178,14 +218,22 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { panic("config must be created by ParseConfig") } + prepareConn := config.PrepareConn + if prepareConn == nil && config.BeforeAcquire != nil { + prepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) { + return config.BeforeAcquire(ctx, conn), nil + } + } + p := &Pool{ config: config, beforeConnect: config.BeforeConnect, afterConnect: config.AfterConnect, - beforeAcquire: config.BeforeAcquire, + prepareConn: prepareConn, afterRelease: config.AfterRelease, beforeClose: config.BeforeClose, minConns: config.MinConns, + minIdleConns: config.MinIdleConns, maxConns: config.MaxConns, maxConnLifetime: config.MaxConnLifetime, maxConnLifetimeJitter: config.MaxConnLifetimeJitter, @@ -195,6 +243,22 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { closeChan: make(chan struct{}), } + if t, ok := config.ConnConfig.Tracer.(AcquireTracer); ok { + 
p.acquireTracer = t + } + + if t, ok := config.ConnConfig.Tracer.(ReleaseTracer); ok { + p.releaseTracer = t + } + + if config.ShouldPing != nil { + p.shouldPing = config.ShouldPing + } else { + p.shouldPing = func(ctx context.Context, params ShouldPingParams) bool { + return params.IdleDuration > time.Second + } + } + var err error p.p, err = puddle.NewPool( &puddle.Config[*connResource]{ @@ -260,7 +324,8 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { } go func() { - p.createIdleResources(ctx, int(p.minConns)) + targetIdleResources := max(int(p.minConns), int(p.minIdleConns)) + p.createIdleResources(ctx, targetIdleResources) p.backgroundHealthCheck() }() @@ -270,20 +335,20 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) { // ParseConfig builds a Config from connString. It parses connString with the same behavior as [pgx.ParseConfig] with the // addition of the following variables: // -// - pool_max_conns: integer greater than 0 -// - pool_min_conns: integer 0 or greater -// - pool_max_conn_lifetime: duration string -// - pool_max_conn_idle_time: duration string -// - pool_health_check_period: duration string -// - pool_max_conn_lifetime_jitter: duration string +// - pool_max_conns: integer greater than 0 (default 4) +// - pool_min_conns: integer 0 or greater (default 0) +// - pool_max_conn_lifetime: duration string (default 1 hour) +// - pool_max_conn_idle_time: duration string (default 30 minutes) +// - pool_health_check_period: duration string (default 1 minute) +// - pool_max_conn_lifetime_jitter: duration string (default 0) // // See Config for definitions of these arguments. 
// -// # Example DSN -// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10 +// # Example Keyword/Value +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca pool_max_conns=10 pool_max_conn_lifetime=1h30m // // # Example URL -// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca&pool_max_conns=10 +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca&pool_max_conns=10&pool_max_conn_lifetime=1h30m func ParseConfig(connString string) (*Config, error) { connConfig, err := pgx.ParseConfig(connString) if err != nil { @@ -299,10 +364,10 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_max_conns") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("cannot parse pool_max_conns: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conns", err) } if n < 1 { - return nil, fmt.Errorf("pool_max_conns too small: %d", n) + return nil, pgconn.NewParseConfigError(connString, "pool_max_conns too small", err) } config.MaxConns = int32(n) } else { @@ -316,18 +381,29 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_min_conns") n, err := strconv.ParseInt(s, 10, 32) if err != nil { - return nil, fmt.Errorf("cannot parse pool_min_conns: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_min_conns", err) } config.MinConns = int32(n) } else { config.MinConns = defaultMinConns } + if s, ok := config.ConnConfig.Config.RuntimeParams["pool_min_idle_conns"]; ok { + delete(connConfig.Config.RuntimeParams, "pool_min_idle_conns") + n, err := strconv.ParseInt(s, 10, 32) + if err != nil { + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_min_idle_conns", err) + } + config.MinIdleConns = int32(n) + } else { + config.MinIdleConns = defaultMinIdleConns + } + if s, ok 
:= config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok { delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime") d, err := time.ParseDuration(s) if err != nil { - return nil, fmt.Errorf("invalid pool_max_conn_lifetime: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_lifetime", err) } config.MaxConnLifetime = d } else { @@ -338,7 +414,7 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time") d, err := time.ParseDuration(s) if err != nil { - return nil, fmt.Errorf("invalid pool_max_conn_idle_time: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_idle_time", err) } config.MaxConnIdleTime = d } else { @@ -349,7 +425,7 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_health_check_period") d, err := time.ParseDuration(s) if err != nil { - return nil, fmt.Errorf("invalid pool_health_check_period: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_health_check_period", err) } config.HealthCheckPeriod = d } else { @@ -360,7 +436,7 @@ func ParseConfig(connString string) (*Config, error) { delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime_jitter") d, err := time.ParseDuration(s) if err != nil { - return nil, fmt.Errorf("invalid pool_max_conn_lifetime_jitter: %w", err) + return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_lifetime_jitter", err) } config.MaxConnLifetimeJitter = d } @@ -461,7 +537,9 @@ func (p *Pool) checkMinConns() error { // TotalConns can include ones that are being destroyed but we should have // sleep(500ms) around all of the destroys to help prevent that from throwing // off this check - toCreate := p.minConns - p.Stat().TotalConns() + + // Create the number of connections needed to get to both minConns and minIdleConns + toCreate := 
max(p.minConns-p.Stat().TotalConns(), p.minIdleConns-p.Stat().IdleConns()) if toCreate > 0 { return p.createIdleResources(context.Background(), int(toCreate)) } @@ -498,8 +576,22 @@ func (p *Pool) createIdleResources(parentCtx context.Context, targetResources in } // Acquire returns a connection (*Conn) from the Pool -func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { - for { +func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) { + if p.acquireTracer != nil { + ctx = p.acquireTracer.TraceAcquireStart(ctx, p, TraceAcquireStartData{}) + defer func() { + var conn *pgx.Conn + if c != nil { + conn = c.Conn() + } + p.acquireTracer.TraceAcquireEnd(ctx, p, TraceAcquireEndData{Conn: conn, Err: err}) + }() + } + + // Try to acquire from the connection pool up to maxConns + 1 times, so that + // any that fatal errors would empty the pool and still at least try 1 fresh + // connection. + for range p.maxConns + 1 { res, err := p.p.Acquire(ctx) if err != nil { return nil, err @@ -507,7 +599,8 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { cr := res.Value() - if res.IdleDuration() > time.Second { + shouldPingParams := ShouldPingParams{Conn: cr.conn, IdleDuration: res.IdleDuration()} + if p.shouldPing(ctx, shouldPingParams) { err := cr.conn.Ping(ctx) if err != nil { res.Destroy() @@ -515,12 +608,25 @@ func (p *Pool) Acquire(ctx context.Context) (*Conn, error) { } } - if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { - return cr.getConn(p, res), nil + if p.prepareConn != nil { + ok, err := p.prepareConn(ctx, cr.conn) + if !ok { + res.Destroy() + } + if err != nil { + if ok { + res.Release() + } + return nil, err + } + if !ok { + continue + } } - res.Destroy() + return cr.getConn(p, res), nil } + return nil, errors.New("pgxpool: detected infinite loop acquiring connection; likely bug in PrepareConn or BeforeAcquire hook") } // AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. 
It has no effect on the @@ -543,11 +649,14 @@ func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn { conns := make([]*Conn, 0, len(resources)) for _, res := range resources { cr := res.Value() - if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) { - conns = append(conns, cr.getConn(p, res)) - } else { - res.Destroy() + if p.prepareConn != nil { + ok, err := p.prepareConn(ctx, cr.conn) + if !ok || err != nil { + res.Destroy() + continue + } } + conns = append(conns, cr.getConn(p, res)) } return conns diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index 9ecc8ff9d..b393aee0f 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -43,10 +43,11 @@ func TestConnectConfig(t *testing.T) { func TestParseConfigExtractsPoolArguments(t *testing.T) { t.Parallel() - config, err := pgxpool.ParseConfig("pool_max_conns=42 pool_min_conns=1") + config, err := pgxpool.ParseConfig("pool_max_conns=42 pool_min_conns=1 pool_min_idle_conns=2") assert.NoError(t, err) assert.EqualValues(t, 42, config.MaxConns) assert.EqualValues(t, 1, config.MinConns) + assert.EqualValues(t, 2, config.MinIdleConns) assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_max_conns") assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_min_conns") } @@ -203,6 +204,64 @@ func TestPoolAcquireChecksIdleConns(t *testing.T) { require.NotContains(t, pids, cPID) } +func TestPoolAcquireChecksIdleConnsWithShouldPing(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + controllerConn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer controllerConn.Close(ctx) + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + // Replace the default ShouldPing func + var shouldPingLastCalledWith *pgxpool.ShouldPingParams + config.ShouldPing = func(ctx context.Context, params pgxpool.ShouldPingParams) bool { + 
shouldPingLastCalledWith = ¶ms + return false + } + + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + + // All conns are dead they don't know it and neither does the pool. + require.EqualValues(t, 3, pool.Stat().TotalConns()) + + // Wait long enough so the pool will realize it needs to check the connections. + time.Sleep(time.Second) + + // Pool should try all existing connections and find them dead, then create a new connection which should successfully ping. + err = pool.Ping(ctx) + require.NoError(t, err) + + // The original 3 conns should have been terminated and the a new conn established for the ping. + require.EqualValues(t, 1, pool.Stat().TotalConns()) + c, err := pool.Acquire(ctx) + require.NoError(t, err) + + cPID := c.Conn().PgConn().PID() + c.Release() + + time.Sleep(time.Millisecond * 200) + + c, err = pool.Acquire(ctx) + require.NoError(t, err) + conn := c.Conn() + + require.NotNil(t, shouldPingLastCalledWith) + assert.Equal(t, conn, shouldPingLastCalledWith.Conn) + assert.InDelta(t, time.Millisecond*200, shouldPingLastCalledWith.IdleDuration, float64(time.Millisecond*100)) + + c.Release() +} + func TestPoolAcquireFunc(t *testing.T) { t.Parallel() @@ -329,6 +388,64 @@ func TestPoolBeforeAcquire(t *testing.T) { assert.EqualValues(t, 12, acquireAttempts) } +func TestPoolPrepareConn(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + acquireAttempts := 0 + + config.PrepareConn = func(context.Context, *pgx.Conn) (bool, error) { + acquireAttempts++ + var err error + if acquireAttempts%3 == 0 { + err = errors.New("PrepareConn error") + } + return acquireAttempts%2 == 0, err + } + + db, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + t.Cleanup(db.Close) + + 
var errorCount int + conns := make([]*pgxpool.Conn, 0, 4) + for { + conn, err := db.Acquire(ctx) + if err != nil { + errorCount++ + continue + } + conns = append(conns, conn) + if len(conns) == 4 { + break + } + } + const wantErrorCount = 3 + assert.Equal(t, wantErrorCount, errorCount, "Acquire() should have failed %d times", wantErrorCount) + + for _, c := range conns { + c.Release() + } + waitForReleaseToComplete() + + assert.EqualValues(t, len(conns)*2+wantErrorCount-1, acquireAttempts) + + conns = db.AcquireAllIdle(ctx) + assert.Len(t, conns, 1) + + for _, c := range conns { + c.Release() + } + waitForReleaseToComplete() + + assert.EqualValues(t, 14, acquireAttempts) +} + func TestPoolAfterRelease(t *testing.T) { t.Parallel() @@ -676,7 +793,6 @@ func TestPoolQuery(t *testing.T) { stats = pool.Stat() assert.EqualValues(t, 0, stats.AcquiredConns()) assert.EqualValues(t, 1, stats.TotalConns()) - } func TestPoolQueryRow(t *testing.T) { @@ -1077,6 +1193,7 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { config.MinConns = int32(12) config.MaxConns = int32(15) +<<<<<<< HEAD acquireAttempts := int64(0) connectAttempts := int64(0) @@ -1097,6 +1214,28 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { for i := 0; i < 500; i++ { time.Sleep(10 * time.Millisecond) +======= + + acquireAttempts := int64(0) + connectAttempts := int64(0) + + config.PrepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) { + atomic.AddInt64(&acquireAttempts, 1) + return true, nil + } + config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { + atomic.AddInt64(&connectAttempts, 1) + return nil + } + + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer pool.Close() + + for i := 0; i < 500; i++ { + time.Sleep(10 * time.Millisecond) + +>>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e stat := pool.Stat() if stat.IdleConns() == 12 && stat.AcquireCount() == 0 && stat.TotalConns() == 12 && 
atomic.LoadInt64(&acquireAttempts) == 0 && atomic.LoadInt64(&connectAttempts) == 12 { return @@ -1104,7 +1243,10 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { } t.Fatal("did not reach min pool size") +<<<<<<< HEAD +======= +>>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e } func TestPoolSendBatchBatchCloseTwice(t *testing.T) { diff --git a/pgxpool/stat.go b/pgxpool/stat.go index cfa0c4c56..e02b6ac39 100644 --- a/pgxpool/stat.go +++ b/pgxpool/stat.go @@ -82,3 +82,10 @@ func (s *Stat) MaxLifetimeDestroyCount() int64 { func (s *Stat) MaxIdleDestroyCount() int64 { return s.idleDestroyCount } + +// EmptyAcquireWaitTime returns the cumulative time waited for successful acquires +// from the pool for a resource to be released or constructed because the pool was +// empty. +func (s *Stat) EmptyAcquireWaitTime() time.Duration { + return s.s.EmptyAcquireWaitTime() +} diff --git a/pgxpool/tracer.go b/pgxpool/tracer.go new file mode 100644 index 000000000..78b9d15a2 --- /dev/null +++ b/pgxpool/tracer.go @@ -0,0 +1,33 @@ +package pgxpool + +import ( + "context" + + "github.com/jackc/pgx/v5" +) + +// AcquireTracer traces Acquire. +type AcquireTracer interface { + // TraceAcquireStart is called at the beginning of Acquire. + // The returned context is used for the rest of the call and will be passed to the TraceAcquireEnd. + TraceAcquireStart(ctx context.Context, pool *Pool, data TraceAcquireStartData) context.Context + // TraceAcquireEnd is called when a connection has been acquired. + TraceAcquireEnd(ctx context.Context, pool *Pool, data TraceAcquireEndData) +} + +type TraceAcquireStartData struct{} + +type TraceAcquireEndData struct { + Conn *pgx.Conn + Err error +} + +// ReleaseTracer traces Release. +type ReleaseTracer interface { + // TraceRelease is called at the beginning of Release. 
+ TraceRelease(pool *Pool, data TraceReleaseData) +} + +type TraceReleaseData struct { + Conn *pgx.Conn +} diff --git a/pgxpool/tracer_test.go b/pgxpool/tracer_test.go new file mode 100644 index 000000000..10724d94c --- /dev/null +++ b/pgxpool/tracer_test.go @@ -0,0 +1,130 @@ +package pgxpool_test + +import ( + "context" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/stretchr/testify/require" +) + +type testTracer struct { + traceAcquireStart func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context + traceAcquireEnd func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) + traceRelease func(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) +} + +type ctxKey string + +func (tt *testTracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context { + if tt.traceAcquireStart != nil { + return tt.traceAcquireStart(ctx, pool, data) + } + return ctx +} + +func (tt *testTracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { + if tt.traceAcquireEnd != nil { + tt.traceAcquireEnd(ctx, pool, data) + } +} + +func (tt *testTracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) { + if tt.traceRelease != nil { + tt.traceRelease(pool, data) + } +} + +func (tt *testTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + return ctx +} + +func (tt *testTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { +} + +func TestTraceAcquire(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.ConnConfig.Tracer = tracer + + pool, err := 
pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer pool.Close() + + traceAcquireStartCalled := false + tracer.traceAcquireStart = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context { + traceAcquireStartCalled = true + require.NotNil(t, pool) + return context.WithValue(ctx, ctxKey("fromTraceAcquireStart"), "foo") + } + + traceAcquireEndCalled := false + tracer.traceAcquireEnd = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { + traceAcquireEndCalled = true + require.Equal(t, "foo", ctx.Value(ctxKey("fromTraceAcquireStart"))) + require.NotNil(t, pool) + require.NotNil(t, data.Conn) + require.NoError(t, data.Err) + } + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + defer c.Release() + require.True(t, traceAcquireStartCalled) + require.True(t, traceAcquireEndCalled) + + traceAcquireStartCalled = false + traceAcquireEndCalled = false + tracer.traceAcquireEnd = func(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { + traceAcquireEndCalled = true + require.NotNil(t, pool) + require.Nil(t, data.Conn) + require.Error(t, data.Err) + } + + ctx, cancel = context.WithCancel(ctx) + cancel() + _, err = pool.Acquire(ctx) + require.ErrorIs(t, err, context.Canceled) + require.True(t, traceAcquireStartCalled) + require.True(t, traceAcquireEndCalled) +} + +func TestTraceRelease(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.ConnConfig.Tracer = tracer + + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + defer pool.Close() + + traceReleaseCalled := false + tracer.traceRelease = func(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) { + traceReleaseCalled = true + require.NotNil(t, pool) + require.NotNil(t, data.Conn) 
+ } + + c, err := pool.Acquire(ctx) + require.NoError(t, err) + c.Release() + require.True(t, traceReleaseCalled) +} diff --git a/pgxpool/tx.go b/pgxpool/tx.go index 8e004722b..fa42ed0f5 100644 --- a/pgxpool/tx.go +++ b/pgxpool/tx.go @@ -18,9 +18,10 @@ func (tx *Tx) Begin(ctx context.Context) (pgx.Tx, error) { return tx.t.Begin(ctx) } -// Commit commits the transaction and returns the associated connection back to the Pool. Commit will return ErrTxClosed -// if the Tx is already closed, but is otherwise safe to call multiple times. If the commit fails with a rollback status -// (e.g. the transaction was already in a broken state) then ErrTxCommitRollback will be returned. +// Commit commits the transaction and returns the associated connection back to the Pool. Commit will return an error +// where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is otherwise safe to call multiple times. If +// the commit fails with a rollback status (e.g. the transaction was already in a broken state) then ErrTxCommitRollback +// will be returned. func (tx *Tx) Commit(ctx context.Context) error { err := tx.t.Commit(ctx) if tx.c != nil { @@ -30,9 +31,9 @@ func (tx *Tx) Commit(ctx context.Context) error { return err } -// Rollback rolls back the transaction and returns the associated connection back to the Pool. Rollback will return ErrTxClosed -// if the Tx is already closed, but is otherwise safe to call multiple times. Hence, defer tx.Rollback() is safe even if -// tx.Commit() will be called first in a non-error condition. +// Rollback rolls back the transaction and returns the associated connection back to the Pool. Rollback will return +// where an error where errors.Is(ErrTxClosed) is true if the Tx is already closed, but is otherwise safe to call +// multiple times. Hence, defer tx.Rollback() is safe even if tx.Commit() will be called first in a non-error condition. 
func (tx *Tx) Rollback(ctx context.Context) error { err := tx.t.Rollback(ctx) if tx.c != nil { diff --git a/query_test.go b/query_test.go index 88d7da718..46bc0efe4 100644 --- a/query_test.go +++ b/query_test.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "database/sql" + "database/sql/driver" + "encoding/json" "errors" "fmt" "os" @@ -418,7 +420,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) { t.Fatal("Expected Rows to have an error after an improper read but it didn't") } - if rows.Err().Error() != "can't scan into dest[0]: cannot scan int4 (OID 23) in binary format into *time.Time" { + if rows.Err().Error() != "can't scan into dest[0] (col: n): cannot scan int4 (OID 23) in binary format into *time.Time" { t.Fatalf("Expected different Rows.Err(): %v", rows.Err()) } @@ -560,13 +562,12 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) { } if _, ok := rows.Err().(*pgconn.PgError); !ok { - t.Fatalf("Expected pgx.PgError, got %v", rows.Err()) + t.Fatalf("Expected pgconn.PgError, got %v", rows.Err()) } ensureConnValid(t, conn) }() } - } func TestQueryEncodeError(t *testing.T) { @@ -1171,6 +1172,161 @@ func TestConnQueryDatabaseSQLDriverValuerWithAutoGeneratedPointerReceiver(t *tes ensureConnValid(t, conn) } +type nilPointerAsEmptyJSONObject struct { + ID string + Name string +} + +func (v *nilPointerAsEmptyJSONObject) Value() (driver.Value, error) { + if v == nil { + return "{}", nil + } + + return json.Marshal(v) +} + +// https://github.com/jackc/pgx/issues/1566 +func TestConnQueryDatabaseSQLDriverValuerCalledOnNilPointerImplementers(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table t(v json not null)") + + var v *nilPointerAsEmptyJSONObject + commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var s string + err = 
conn.QueryRow(context.Background(), "select v from t").Scan(&s) + require.NoError(t, err) + require.Equal(t, "{}", s) + + _, err = conn.Exec(context.Background(), `delete from t`) + require.NoError(t, err) + + v = &nilPointerAsEmptyJSONObject{ID: "1", Name: "foo"} + commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var v2 *nilPointerAsEmptyJSONObject + err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2) + require.NoError(t, err) + require.Equal(t, v, v2) + + ensureConnValid(t, conn) +} + +type nilSliceAsEmptySlice []byte + +func (j nilSliceAsEmptySlice) Value() (driver.Value, error) { + if len(j) == 0 { + return []byte("[]"), nil + } + + return []byte(j), nil +} + +func (j *nilSliceAsEmptySlice) UnmarshalJSON(data []byte) error { + *j = bytes.Clone(data) + return nil +} + +// https://github.com/jackc/pgx/issues/1860 +func TestConnQueryDatabaseSQLDriverValuerCalledOnNilSliceImplementers(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, "create temporary table t(v json not null)") + + var v nilSliceAsEmptySlice + commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var s string + err = conn.QueryRow(context.Background(), "select v from t").Scan(&s) + require.NoError(t, err) + require.Equal(t, "[]", s) + + _, err = conn.Exec(context.Background(), `delete from t`) + require.NoError(t, err) + + v = nilSliceAsEmptySlice(`{"name": "foo"}`) + commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v) + require.NoError(t, err) + require.Equal(t, "INSERT 0 1", commandTag.String()) + + var v2 nilSliceAsEmptySlice + err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2) + require.NoError(t, err) 
+	require.Equal(t, v, v2)
+
+	ensureConnValid(t, conn)
+}
+
+type nilMapAsEmptyObject map[string]any
+
+func (j nilMapAsEmptyObject) Value() (driver.Value, error) {
+	if j == nil {
+		return []byte("{}"), nil
+	}
+
+	return json.Marshal(j)
+}
+
+func (j *nilMapAsEmptyObject) UnmarshalJSON(data []byte) error {
+	var m map[string]any
+	err := json.Unmarshal(data, &m)
+	if err != nil {
+		return err
+	}
+
+	*j = m
+
+	return nil
+}
+
+// https://github.com/jackc/pgx/pull/2019#discussion_r1605806751
+func TestConnQueryDatabaseSQLDriverValuerCalledOnNilMapImplementers(t *testing.T) {
+	t.Parallel()
+
+	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(t, conn)
+
+	mustExec(t, conn, "create temporary table t(v json not null)")
+
+	var v nilMapAsEmptyObject
+	commandTag, err := conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
+	require.NoError(t, err)
+	require.Equal(t, "INSERT 0 1", commandTag.String())
+
+	var s string
+	err = conn.QueryRow(context.Background(), "select v from t").Scan(&s)
+	require.NoError(t, err)
+	require.Equal(t, "{}", s)
+
+	_, err = conn.Exec(context.Background(), `delete from t`)
+	require.NoError(t, err)
+
+	v = nilMapAsEmptyObject{"name": "foo"}
+	commandTag, err = conn.Exec(context.Background(), `insert into t(v) values($1)`, v)
+	require.NoError(t, err)
+	require.Equal(t, "INSERT 0 1", commandTag.String())
+
+	var v2 nilMapAsEmptyObject
+	err = conn.QueryRow(context.Background(), "select v from t").Scan(&v2)
+	require.NoError(t, err)
+	require.Equal(t, v, v2)
+
+	ensureConnValid(t, conn)
+}
+
 func TestConnQueryDatabaseSQLDriverScannerWithBinaryPgTypeThatAcceptsSameType(t *testing.T) {
 	t.Parallel()
 
@@ -2051,7 +2207,7 @@ insert into products (name, price) values
 	}
 
 	rows, err := conn.Query(ctx, "select name, price from products where price < $1 order by price desc", 12)
 
 	// It is unnecessary to check err. 
If an error occurred it will be returned by rows.Err() later. But in rare // cases it may be useful to detect the error as early as possible. if err != nil { diff --git a/rows.go b/rows.go index 4e46b7578..e9ceb5004 100644 --- a/rows.go +++ b/rows.go @@ -40,22 +40,19 @@ type Rows interface { // when there was an error executing the query. FieldDescriptions() []pgconn.FieldDescription - // Next prepares the next row for reading. It returns true if there is another - // row and false if no more rows are available or a fatal error has occurred. - // It automatically closes rows when all rows are read. + // Next prepares the next row for reading. It returns true if there is another row and false if no more rows are + // available or a fatal error has occurred. It automatically closes rows upon returning false (whether due to all rows + // having been read or due to an error). // - // Callers should check rows.Err() after rows.Next() returns false to detect - // whether result-set reading ended prematurely due to an error. See - // Conn.Query for details. + // Callers should check rows.Err() after rows.Next() returns false to detect whether result-set reading ended + // prematurely due to an error. See Conn.Query for details. // - // For simpler error handling, consider using the higher-level pgx v5 - // CollectRows() and ForEachRow() helpers instead. + // For simpler error handling, consider using the higher-level pgx v5 CollectRows() and ForEachRow() helpers instead. Next() bool - // Scan reads the values from the current row into dest values positionally. - // dest can include pointers to core types, values implementing the Scanner - // interface, and nil. nil will skip the value entirely. It is an error to - // call Scan without first calling Next() and checking that it returned true. + // Scan reads the values from the current row into dest values positionally. dest can include pointers to core types, + // values implementing the Scanner interface, and nil. 
nil will skip the value entirely. It is an error to call Scan + // without first calling Next() and checking that it returned true. Rows is automatically closed upon error. Scan(dest ...any) error // Values returns the decoded row values. As with Scan(), it is an error to @@ -187,6 +184,17 @@ func (rows *baseRows) Close() { } else if rows.queryTracer != nil { rows.queryTracer.TraceQueryEnd(rows.ctx, rows.conn, TraceQueryEndData{rows.commandTag, rows.err}) } + + // Zero references to other memory allocations. This allows them to be GC'd even when the Rows still referenced. In + // particular, when using pgxpool GC could be delayed as pgxpool.poolRows are allocated in large slices. + // + // https://github.com/jackc/pgx/pull/2269 + rows.values = nil + rows.scanPlans = nil + rows.scanTypes = nil + rows.ctx = nil + rows.sql = "" + rows.args = nil } func (rows *baseRows) CommandTag() pgconn.CommandTag { @@ -271,7 +279,7 @@ func (rows *baseRows) Scan(dest ...any) error { err := rows.scanPlans[i].Scan(values[i], dst) if err != nil { - err = ScanArgError{ColumnIndex: i, Err: err} + err = ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err} rows.fatal(err) return err } @@ -333,11 +341,16 @@ func (rows *baseRows) Conn() *Conn { type ScanArgError struct { ColumnIndex int + FieldName string Err error } func (e ScanArgError) Error() string { - return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) + if e.FieldName == "?column?" 
{ // Don't include the fieldname if it's unknown + return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err) + } + + return fmt.Sprintf("can't scan into dest[%d] (col: %s): %v", e.ColumnIndex, e.FieldName, e.Err) } func (e ScanArgError) Unwrap() error { @@ -365,7 +378,7 @@ func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, v err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d) if err != nil { - return ScanArgError{ColumnIndex: i, Err: err} + return ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err} } } @@ -418,6 +431,8 @@ type CollectableRow interface { type RowToFunc[T any] func(row CollectableRow) (T, error) // AppendRows iterates through rows, calling fn for each row, and appending the results into a slice of T. +// +// This function closes the rows automatically on return. func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) { defer rows.Close() @@ -437,12 +452,16 @@ func AppendRows[T any, S ~[]T](slice S, rows Rows, fn RowToFunc[T]) (S, error) { } // CollectRows iterates through rows, calling fn for each row, and collecting the results into a slice of T. +// +// This function closes the rows automatically on return. func CollectRows[T any](rows Rows, fn RowToFunc[T]) ([]T, error) { - return AppendRows([]T(nil), rows, fn) + return AppendRows([]T{}, rows, fn) } // CollectOneRow calls fn for the first row in rows and returns the result. If no rows are found returns an error where errors.Is(ErrNoRows) is true. // CollectOneRow is to CollectRows as QueryRow is to Query. +// +// This function closes the rows automatically on return. func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { defer rows.Close() @@ -461,6 +480,8 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { return value, err } + // The defer rows.Close() won't have executed yet. 
If the query returned more than one row, rows would still be open. + // rows.Close() must be called before rows.Err() so we explicitly call it here. rows.Close() return value, rows.Err() } @@ -468,6 +489,8 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { // CollectExactlyOneRow calls fn for the first row in rows and returns the result. // - If no rows are found returns an error where errors.Is(ErrNoRows) is true. // - If more than 1 row is found returns an error where errors.Is(ErrTooManyRows) is true. +// +// This function closes the rows automatically on return. func CollectExactlyOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) { defer rows.Close() @@ -536,12 +559,12 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error { return nil } -// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row +// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number of public fields as row // has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be // ignored. func RowToStructByPos[T any](row CollectableRow) (T, error) { var value T - err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row) return value, err } @@ -550,7 +573,7 @@ func RowToStructByPos[T any](row CollectableRow) (T, error) { // the field will be ignored. 
func RowToAddrOfStructByPos[T any](row CollectableRow) (*T, error) { var value T - err := row.Scan(&positionalStructRowScanner{ptrToStruct: &value}) + err := (&positionalStructRowScanner{ptrToStruct: &value}).ScanRow(row) return &value, err } @@ -558,46 +581,60 @@ type positionalStructRowScanner struct { ptrToStruct any } -func (rs *positionalStructRowScanner) ScanRow(rows Rows) error { - dst := rs.ptrToStruct - dstValue := reflect.ValueOf(dst) - if dstValue.Kind() != reflect.Ptr { - return fmt.Errorf("dst not a pointer") +func (rs *positionalStructRowScanner) ScanRow(rows CollectableRow) error { + typ := reflect.TypeOf(rs.ptrToStruct).Elem() + fields := lookupStructFields(typ) + if len(rows.RawValues()) > len(fields) { + return fmt.Errorf( + "got %d values, but dst struct has only %d fields", + len(rows.RawValues()), + len(fields), + ) } - - dstElemValue := dstValue.Elem() - scanTargets := rs.appendScanTargets(dstElemValue, nil) - - if len(rows.RawValues()) > len(scanTargets) { - return fmt.Errorf("got %d values, but dst struct has only %d fields", len(rows.RawValues()), len(scanTargets)) - } - + scanTargets := setupStructScanTargets(rs.ptrToStruct, fields) return rows.Scan(scanTargets...) 
} -func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any) []any { - dstElemType := dstElemValue.Type() +// Map from reflect.Type -> []structRowField +var positionalStructFieldMap sync.Map - if scanTargets == nil { - scanTargets = make([]any, 0, dstElemType.NumField()) +func lookupStructFields(t reflect.Type) []structRowField { + if cached, ok := positionalStructFieldMap.Load(t); ok { + return cached.([]structRowField) } - for i := 0; i < dstElemType.NumField(); i++ { - sf := dstElemType.Field(i) + fieldStack := make([]int, 0, 1) + fields := computeStructFields(t, make([]structRowField, 0, t.NumField()), &fieldStack) + fieldsIface, _ := positionalStructFieldMap.LoadOrStore(t, fields) + return fieldsIface.([]structRowField) +} + +func computeStructFields( + t reflect.Type, + fields []structRowField, + fieldStack *[]int, +) []structRowField { + tail := len(*fieldStack) + *fieldStack = append(*fieldStack, 0) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + (*fieldStack)[tail] = i // Handle anonymous struct embedding, but do not try to handle embedded pointers. if sf.Anonymous && sf.Type.Kind() == reflect.Struct { - scanTargets = rs.appendScanTargets(dstElemValue.Field(i), scanTargets) + fields = computeStructFields(sf.Type, fields, fieldStack) } else if sf.PkgPath == "" { dbTag, _ := sf.Tag.Lookup(structTagKey) if dbTag == "-" { // Field is ignored, skip it. continue } - scanTargets = append(scanTargets, dstElemValue.Field(i).Addr().Interface()) + fields = append(fields, structRowField{ + path: append([]int(nil), *fieldStack...), + }) } } - - return scanTargets + *fieldStack = (*fieldStack)[:tail] + return fields } // RowToStructByName returns a T scanned from row. T must be a struct. T must have the same number of named public @@ -605,7 +642,7 @@ func (rs *positionalStructRowScanner) appendScanTargets(dstElemValue reflect.Val // column name can be overridden with a "db" struct tag. 
If the "db" struct tag is "-" then the field will be ignored. func RowToStructByName[T any](row CollectableRow) (T, error) { var value T - err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) + err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row) return value, err } @@ -615,7 +652,7 @@ func RowToStructByName[T any](row CollectableRow) (T, error) { // then the field will be ignored. func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { var value T - err := row.Scan(&namedStructRowScanner{ptrToStruct: &value}) + err := (&namedStructRowScanner{ptrToStruct: &value}).ScanRow(row) return &value, err } @@ -624,7 +661,7 @@ func RowToAddrOfStructByName[T any](row CollectableRow) (*T, error) { // column name can be overridden with a "db" struct tag. If the "db" struct tag is "-" then the field will be ignored. func RowToStructByNameLax[T any](row CollectableRow) (T, error) { var value T - err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) + err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row) return value, err } @@ -634,7 +671,7 @@ func RowToStructByNameLax[T any](row CollectableRow) (T, error) { // then the field will be ignored. 
func RowToAddrOfStructByNameLax[T any](row CollectableRow) (*T, error) { var value T - err := row.Scan(&namedStructRowScanner{ptrToStruct: &value, lax: true}) + err := (&namedStructRowScanner{ptrToStruct: &value, lax: true}).ScanRow(row) return &value, err } @@ -643,64 +680,123 @@ type namedStructRowScanner struct { lax bool } -func (rs *namedStructRowScanner) ScanRow(rows Rows) error { - dst := rs.ptrToStruct - dstValue := reflect.ValueOf(dst) - if dstValue.Kind() != reflect.Ptr { - return fmt.Errorf("dst not a pointer") - } - - dstElemValue := dstValue.Elem() - scanTargets, err := rs.appendScanTargets(dstElemValue, nil, rows.FieldDescriptions()) +func (rs *namedStructRowScanner) ScanRow(rows CollectableRow) error { + typ := reflect.TypeOf(rs.ptrToStruct).Elem() + fldDescs := rows.FieldDescriptions() + namedStructFields, err := lookupNamedStructFields(typ, fldDescs) if err != nil { return err } - - for i, t := range scanTargets { - if t == nil { - return fmt.Errorf("struct doesn't have corresponding row field %s", rows.FieldDescriptions()[i].Name) - } + if !rs.lax && namedStructFields.missingField != "" { + return fmt.Errorf("cannot find field %s in returned row", namedStructFields.missingField) } - + fields := namedStructFields.fields + scanTargets := setupStructScanTargets(rs.ptrToStruct, fields) return rows.Scan(scanTargets...) } -const structTagKey = "db" +// Map from namedStructFieldMap -> *namedStructFields +var namedStructFieldMap sync.Map -func fieldPosByName(fldDescs []pgconn.FieldDescription, field string) (i int) { - i = -1 - for i, desc := range fldDescs { - - // Snake case support. 
- field = strings.ReplaceAll(field, "_", "") - descName := strings.ReplaceAll(desc.Name, "_", "") +type namedStructFieldsKey struct { + t reflect.Type + colNames string +} - if strings.EqualFold(descName, field) { - return i - } - } - return +type namedStructFields struct { + fields []structRowField + // missingField is the first field from the struct without a corresponding row field. + // This is used to construct the correct error message for non-lax queries. + missingField string } -func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, scanTargets []any, fldDescs []pgconn.FieldDescription) ([]any, error) { - var err error - dstElemType := dstElemValue.Type() +func lookupNamedStructFields( + t reflect.Type, + fldDescs []pgconn.FieldDescription, +) (*namedStructFields, error) { + key := namedStructFieldsKey{ + t: t, + colNames: joinFieldNames(fldDescs), + } + if cached, ok := namedStructFieldMap.Load(key); ok { + return cached.(*namedStructFields), nil + } - if scanTargets == nil { - scanTargets = make([]any, len(fldDescs)) + // We could probably do two-levels of caching, where we compute the key -> fields mapping + // for a type only once, cache it by type, then use that to compute the column -> fields + // mapping for a given set of columns. 
+ fieldStack := make([]int, 0, 1) + fields, missingField := computeNamedStructFields( + fldDescs, + t, + make([]structRowField, len(fldDescs)), + &fieldStack, + ) + for i, f := range fields { + if f.path == nil { + return nil, fmt.Errorf( + "struct doesn't have corresponding row field %s", + fldDescs[i].Name, + ) + } } - for i := 0; i < dstElemType.NumField(); i++ { - sf := dstElemType.Field(i) + fieldsIface, _ := namedStructFieldMap.LoadOrStore( + key, + &namedStructFields{fields: fields, missingField: missingField}, + ) + return fieldsIface.(*namedStructFields), nil +} + +func joinFieldNames(fldDescs []pgconn.FieldDescription) string { + switch len(fldDescs) { + case 0: + return "" + case 1: + return fldDescs[0].Name + } + + totalSize := len(fldDescs) - 1 // Space for separator bytes. + for _, d := range fldDescs { + totalSize += len(d.Name) + } + var b strings.Builder + b.Grow(totalSize) + b.WriteString(fldDescs[0].Name) + for _, d := range fldDescs[1:] { + b.WriteByte(0) // Join with NUL byte as it's (presumably) not a valid column character. + b.WriteString(d.Name) + } + return b.String() +} + +func computeNamedStructFields( + fldDescs []pgconn.FieldDescription, + t reflect.Type, + fields []structRowField, + fieldStack *[]int, +) ([]structRowField, string) { + var missingField string + tail := len(*fieldStack) + *fieldStack = append(*fieldStack, 0) + for i := 0; i < t.NumField(); i++ { + sf := t.Field(i) + (*fieldStack)[tail] = i if sf.PkgPath != "" && !sf.Anonymous { // Field is unexported, skip it. continue } // Handle anonymous struct embedding, but do not try to handle embedded pointers. 
if sf.Anonymous && sf.Type.Kind() == reflect.Struct { - scanTargets, err = rs.appendScanTargets(dstElemValue.Field(i), scanTargets, fldDescs) - if err != nil { - return nil, err + var missingSubField string + fields, missingSubField = computeNamedStructFields( + fldDescs, + sf.Type, + fields, + fieldStack, + ) + if missingField == "" { + missingField = missingSubField } } else { dbTag, dbTagPresent := sf.Tag.Lookup(structTagKey) @@ -715,19 +811,60 @@ func (rs *namedStructRowScanner) appendScanTargets(dstElemValue reflect.Value, s if !dbTagPresent { colName = sf.Name } - fpos := fieldPosByName(fldDescs, colName) + fpos := fieldPosByName(fldDescs, colName, !dbTagPresent) if fpos == -1 { - if rs.lax { - continue + if missingField == "" { + missingField = colName } - return nil, fmt.Errorf("cannot find field %s in returned row", colName) + continue + } + fields[fpos] = structRowField{ + path: append([]int(nil), *fieldStack...), + } + } + } + *fieldStack = (*fieldStack)[:tail] + + return fields, missingField +} + +const structTagKey = "db" + +func fieldPosByName(fldDescs []pgconn.FieldDescription, field string, normalize bool) (i int) { + i = -1 + + if normalize { + field = strings.ReplaceAll(field, "_", "") + } + for i, desc := range fldDescs { + if normalize { + if strings.EqualFold(strings.ReplaceAll(desc.Name, "_", ""), field) { + return i } - if fpos >= len(scanTargets) && !rs.lax { - return nil, fmt.Errorf("cannot find field %s in returned row", colName) + } else { + if desc.Name == field { + return i } - scanTargets[fpos] = dstElemValue.Field(i).Addr().Interface() } } + return +} - return scanTargets, err +// structRowField describes a field of a struct. +// +// TODO: It would be a bit more efficient to track the path using the pointer +// offset within the (outermost) struct and use unsafe.Pointer arithmetic to +// construct references when scanning rows. However, it's not clear it's worth +// using unsafe for this. 
+type structRowField struct { + path []int +} + +func setupStructScanTargets(receiver any, fields []structRowField) []any { + scanTargets := make([]any, len(fields)) + v := reflect.ValueOf(receiver).Elem() + for i, f := range fields { + scanTargets[i] = v.FieldByIndex(f.path).Addr().Interface() + } + return scanTargets } diff --git a/rows_test.go b/rows_test.go index 6d3d25c4e..5ba699668 100644 --- a/rows_test.go +++ b/rows_test.go @@ -175,6 +175,21 @@ func TestCollectRows(t *testing.T) { }) } +func TestCollectRowsEmpty(t *testing.T) { + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select n from generate_series(1, 0) n`) + numbers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (int32, error) { + var n int32 + err := row.Scan(&n) + return n, err + }) + require.NoError(t, err) + require.NotNil(t, numbers) + + assert.Empty(t, numbers) + }) +} + // This example uses CollectRows with a manually written collector function. In most cases RowTo, RowToAddrOf, // RowToStructByPos, RowToAddrOfStructByPos, or another generic function would be used. 
func ExampleCollectRows() { @@ -652,6 +667,41 @@ func TestRowToStructByName(t *testing.T) { }) } +func TestRowToStructByNameDbTags(t *testing.T) { + type person struct { + Last string `db:"last_name"` + First string `db:"first_name"` + Age int32 `db:"age"` + AccountID string `db:"account_id"` + AnotherAccountID string `db:"account__id"` + } + + defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + rows, _ := conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, 'd5e49d3f' as account_id, '5e49d321' as account__id from generate_series(0, 9) n`) + slice, err := pgx.CollectRows(rows, pgx.RowToStructByName[person]) + assert.NoError(t, err) + + assert.Len(t, slice, 10) + for i := range slice { + assert.Equal(t, "Smith", slice[i].Last) + assert.Equal(t, "John", slice[i].First) + assert.EqualValues(t, i, slice[i].Age) + assert.Equal(t, "d5e49d3f", slice[i].AccountID) + assert.Equal(t, "5e49d321", slice[i].AnotherAccountID) + } + + // check missing fields in a returned row + rows, _ = conn.Query(ctx, `select 'Smith' as last_name, n as age from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToStructByName[person]) + assert.ErrorContains(t, err, "cannot find field first_name in returned row") + + // check missing field in a destination struct + rows, _ = conn.Query(ctx, `select 'John' as first_name, 'Smith' as last_name, n as age, 'd5e49d3f' as account_id, '5e49d321' as account__id, null as ignore from generate_series(0, 9) n`) + _, err = pgx.CollectRows(rows, pgx.RowToAddrOfStructByName[person]) + assert.ErrorContains(t, err, "struct doesn't have corresponding row field ignore") + }) +} + func TestRowToStructByNameEmbeddedStruct(t *testing.T) { type Name struct { Last string `db:"last_name"` diff --git a/stdlib/sql.go b/stdlib/sql.go index c92abc997..0154e787a 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -7,7 +7,7 @@ // return err // } // -// Or from a DSN string. 
+// Or from a keyword/value string. // // db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5433 database=pgx_test sslmode=disable") // if err != nil { @@ -75,6 +75,7 @@ import ( "math" "math/rand" "reflect" + "slices" "strconv" "strings" "sync" @@ -98,7 +99,7 @@ func init() { // if pgx driver was already registered by different pgx major version then we // skip registration under the default name. - if !contains(sql.Drivers(), "pgx") { + if !slices.Contains(sql.Drivers(), "pgx") { sql.Register("pgx", pgxDriver) } sql.Register("pgx/v5", pgxDriver) @@ -120,17 +121,6 @@ func init() { } } -// TODO replace by slices.Contains when experimental package will be merged to stdlib -// https://pkg.go.dev/golang.org/x/exp/slices#Contains -func contains(list []string, y string) bool { - for _, x := range list { - if x == y { - return true - } - } - return false -} - // OptionOpenDB options for configuring the driver when opening a new db pool. type OptionOpenDB func(*connector) @@ -226,7 +216,8 @@ func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { // OpenDBFromPool creates a new *sql.DB from the given *pgxpool.Pool. Note that this method automatically sets the // maximum number of idle connections in *sql.DB to zero, since they must be managed from the *pgxpool.Pool. This is -// required to avoid acquiring all the connections from the pgxpool and starving any direct users of the pgxpool. +// required to avoid acquiring all the connections from the pgxpool and starving any direct users of the pgxpool. Note +// that closing the returned *sql.DB will not close the *pgxpool.Pool. func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDB) *sql.DB { c := GetPoolConnector(pool, opts...) 
db := sql.OpenDB(c) @@ -480,7 +471,8 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam return nil, driver.ErrBadConn } - args := namedValueToInterface(argsV) + args := make([]any, len(argsV)) + convertNamedArguments(args, argsV) commandTag, err := c.conn.Exec(ctx, query, args...) // if we got a network error before we had a chance to send the query, retry @@ -497,8 +489,9 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na return nil, driver.ErrBadConn } - args := []any{databaseSQLResultFormats} - args = append(args, namedValueToInterface(argsV)...) + args := make([]any, 1+len(argsV)) + args[0] = databaseSQLResultFormats + convertNamedArguments(args[1:], argsV) rows, err := c.conn.Query(ctx, query, args...) if err != nil { @@ -805,6 +798,16 @@ func (r *Rows) Next(dest []driver.Value) error { } return d.Value() } + case pgtype.XMLOID: + var d []byte + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d, nil + } default: var d string scanPlan := m.PlanScan(dataTypeOID, format, &d) @@ -847,28 +850,14 @@ func (r *Rows) Next(dest []driver.Value) error { return nil } -func valueToInterface(argsV []driver.Value) []any { - args := make([]any, 0, len(argsV)) - for _, v := range argsV { - if v != nil { - args = append(args, v.(any)) - } else { - args = append(args, nil) - } - } - return args -} - -func namedValueToInterface(argsV []driver.NamedValue) []any { - args := make([]any, 0, len(argsV)) - for _, v := range argsV { +func convertNamedArguments(args []any, argsV []driver.NamedValue) { + for i, v := range argsV { if v.Value != nil { - args = append(args, v.Value.(any)) + args[i] = v.Value.(any) } else { - args = append(args, nil) + args[i] = nil } } - return args } type wrapTx struct { diff --git a/testsetup/generate_certs.go b/testsetup/generate_certs.go index 
945c6c5e2..7f478d4f5 100644
--- a/testsetup/generate_certs.go
+++ b/testsetup/generate_certs.go
@@ -106,12 +106,12 @@ func main() {
 		panic(err)
 	}
 
-	writeEncryptedPrivateKey("pgx_sslcert.key", clientCertPrivKey, "certpw")
+	err = writeEncryptedPrivateKey("pgx_sslcert.key", clientCertPrivKey, "certpw")
 	if err != nil {
 		panic(err)
 	}
 
-	writeCertificate("pgx_sslcert.crt", clientBytes)
+	err = writeCertificate("pgx_sslcert.crt", clientBytes)
 	if err != nil {
 		panic(err)
 	}
diff --git a/tracelog/tracelog.go b/tracelog/tracelog.go
index d45143923..6cec720c7 100644
--- a/tracelog/tracelog.go
+++ b/tracelog/tracelog.go
@@ -6,10 +6,12 @@ import (
 	"encoding/hex"
 	"errors"
 	"fmt"
+	"sync"
 	"time"
 	"unicode/utf8"
 
 	"github.com/yugabyte/pgx/v5"
+	"github.com/yugabyte/pgx/v5/pgxpool"
 )
 
 // LogLevel represents the pgx logging level. See LogLevel* constants for
@@ -102,7 +104,7 @@ func logQueryArgs(args []any) []any {
 			}
 		case string:
 			if len(v) > 64 {
-				var l int = 0
+				l := 0
 				for w := 0; l < 64; l += w {
 					_, w = utf8.DecodeRuneInString(v[l:])
 				}
@@ -117,11 +119,38 @@ func logQueryArgs(args []any) []any {
 	return logArgs
 }
 
-// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, and pgx.CopyFromTracer. All fields are
-// required.
+// TraceLogConfig holds the configuration for key names
+type TraceLogConfig struct {
+	TimeKey string
+}
+
+// DefaultTraceLogConfig returns the default configuration for TraceLog
+func DefaultTraceLogConfig() *TraceLogConfig {
+	return &TraceLogConfig{
+		TimeKey: "time",
+	}
+}
+
+// TraceLog implements pgx.QueryTracer, pgx.BatchTracer, pgx.ConnectTracer, pgx.CopyFromTracer, pgxpool.AcquireTracer,
+// and pgxpool.ReleaseTracer. Logger and LogLevel are required. 
Config will be automatically initialized on the +// first use if nil. type TraceLog struct { Logger Logger LogLevel LogLevel + + Config *TraceLogConfig + ensureConfigOnce sync.Once +} + +// ensureConfig initializes the Config field with default values if it is nil. +func (tl *TraceLog) ensureConfig() { + tl.ensureConfigOnce.Do( + func() { + if tl.Config == nil { + tl.Config = DefaultTraceLogConfig() + } + }, + ) } type ctxKey int @@ -133,6 +162,7 @@ const ( tracelogCopyFromCtxKey tracelogConnectCtxKey tracelogPrepareCtxKey + tracelogAcquireCtxKey ) type traceQueryData struct { @@ -150,6 +180,7 @@ func (tl *TraceLog) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pg } func (tl *TraceLog) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + tl.ensureConfig() queryData := ctx.Value(tracelogQueryCtxKey).(*traceQueryData) endTime := time.Now() @@ -157,13 +188,13 @@ func (tl *TraceLog) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx. if data.Err != nil { if tl.shouldLog(LogLevelError) { - tl.log(ctx, conn, LogLevelError, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "err": data.Err, "time": interval}) + tl.log(ctx, conn, LogLevelError, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "err": data.Err, tl.Config.TimeKey: interval}) } return } if tl.shouldLog(LogLevelInfo) { - tl.log(ctx, conn, LogLevelInfo, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), "time": interval, "commandTag": data.CommandTag.String()}) + tl.log(ctx, conn, LogLevelInfo, "Query", map[string]any{"sql": queryData.sql, "args": logQueryArgs(queryData.args), tl.Config.TimeKey: interval, "commandTag": data.CommandTag.String()}) } } @@ -191,6 +222,7 @@ func (tl *TraceLog) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pg } func (tl *TraceLog) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) { + 
tl.ensureConfig() queryData := ctx.Value(tracelogBatchCtxKey).(*traceBatchData) endTime := time.Now() @@ -198,13 +230,13 @@ func (tl *TraceLog) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx. if data.Err != nil { if tl.shouldLog(LogLevelError) { - tl.log(ctx, conn, LogLevelError, "BatchClose", map[string]any{"err": data.Err, "time": interval}) + tl.log(ctx, conn, LogLevelError, "BatchClose", map[string]any{"err": data.Err, tl.Config.TimeKey: interval}) } return } if tl.shouldLog(LogLevelInfo) { - tl.log(ctx, conn, LogLevelInfo, "BatchClose", map[string]any{"time": interval}) + tl.log(ctx, conn, LogLevelInfo, "BatchClose", map[string]any{tl.Config.TimeKey: interval}) } } @@ -223,6 +255,7 @@ func (tl *TraceLog) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data } func (tl *TraceLog) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) { + tl.ensureConfig() copyFromData := ctx.Value(tracelogCopyFromCtxKey).(*traceCopyFromData) endTime := time.Now() @@ -230,13 +263,13 @@ func (tl *TraceLog) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data p if data.Err != nil { if tl.shouldLog(LogLevelError) { - tl.log(ctx, conn, LogLevelError, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval}) + tl.log(ctx, conn, LogLevelError, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, tl.Config.TimeKey: interval}) } return } if tl.shouldLog(LogLevelInfo) { - tl.log(ctx, conn, LogLevelInfo, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, "time": interval, "rowCount": data.CommandTag.RowsAffected()}) + tl.log(ctx, conn, LogLevelInfo, "CopyFrom", map[string]any{"tableName": copyFromData.TableName, "columnNames": copyFromData.ColumnNames, "err": data.Err, tl.Config.TimeKey: interval, "rowCount": 
data.CommandTag.RowsAffected()}) } } @@ -253,6 +286,7 @@ func (tl *TraceLog) TraceConnectStart(ctx context.Context, data pgx.TraceConnect } func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) { + tl.ensureConfig() connectData := ctx.Value(tracelogConnectCtxKey).(*traceConnectData) endTime := time.Now() @@ -261,11 +295,11 @@ func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEn if data.Err != nil { if tl.shouldLog(LogLevelError) { tl.Logger.Log(ctx, LogLevelError, "Connect", map[string]any{ - "host": connectData.connConfig.Host, - "port": connectData.connConfig.Port, - "database": connectData.connConfig.Database, - "time": interval, - "err": data.Err, + "host": connectData.connConfig.Host, + "port": connectData.connConfig.Port, + "database": connectData.connConfig.Database, + tl.Config.TimeKey: interval, + "err": data.Err, }) } return @@ -274,10 +308,10 @@ func (tl *TraceLog) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEn if data.Conn != nil { if tl.shouldLog(LogLevelInfo) { tl.log(ctx, data.Conn, LogLevelInfo, "Connect", map[string]any{ - "host": connectData.connConfig.Host, - "port": connectData.connConfig.Port, - "database": connectData.connConfig.Database, - "time": interval, + "host": connectData.connConfig.Host, + "port": connectData.connConfig.Port, + "database": connectData.connConfig.Database, + tl.Config.TimeKey: interval, }) } } @@ -298,6 +332,7 @@ func (tl *TraceLog) TracePrepareStart(ctx context.Context, _ *pgx.Conn, data pgx } func (tl *TraceLog) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) { + tl.ensureConfig() prepareData := ctx.Value(tracelogPrepareCtxKey).(*tracePrepareData) endTime := time.Now() @@ -305,13 +340,51 @@ func (tl *TraceLog) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pg if data.Err != nil { if tl.shouldLog(LogLevelError) { - tl.log(ctx, conn, LogLevelError, "Prepare", map[string]any{"name": prepareData.name, 
"sql": prepareData.sql, "err": data.Err, "time": interval}) + tl.log(ctx, conn, LogLevelError, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, "err": data.Err, tl.Config.TimeKey: interval}) } return } if tl.shouldLog(LogLevelInfo) { - tl.log(ctx, conn, LogLevelInfo, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, "time": interval, "alreadyPrepared": data.AlreadyPrepared}) + tl.log(ctx, conn, LogLevelInfo, "Prepare", map[string]any{"name": prepareData.name, "sql": prepareData.sql, tl.Config.TimeKey: interval, "alreadyPrepared": data.AlreadyPrepared}) + } +} + +type traceAcquireData struct { + startTime time.Time +} + +func (tl *TraceLog) TraceAcquireStart(ctx context.Context, _ *pgxpool.Pool, _ pgxpool.TraceAcquireStartData) context.Context { + return context.WithValue(ctx, tracelogAcquireCtxKey, &traceAcquireData{ + startTime: time.Now(), + }) +} + +func (tl *TraceLog) TraceAcquireEnd(ctx context.Context, _ *pgxpool.Pool, data pgxpool.TraceAcquireEndData) { + tl.ensureConfig() + acquireData := ctx.Value(tracelogAcquireCtxKey).(*traceAcquireData) + + endTime := time.Now() + interval := endTime.Sub(acquireData.startTime) + + if data.Err != nil { + if tl.shouldLog(LogLevelError) { + tl.Logger.Log(ctx, LogLevelError, "Acquire", map[string]any{"err": data.Err, tl.Config.TimeKey: interval}) + } + return + } + + if data.Conn != nil { + if tl.shouldLog(LogLevelDebug) { + tl.log(ctx, data.Conn, LogLevelDebug, "Acquire", map[string]any{tl.Config.TimeKey: interval}) + } + } +} + +func (tl *TraceLog) TraceRelease(_ *pgxpool.Pool, data pgxpool.TraceReleaseData) { + if tl.shouldLog(LogLevelDebug) { + // there is no context on the TraceRelease callback + tl.log(context.Background(), data.Conn, LogLevelDebug, "Release", map[string]any{}) } } diff --git a/tracelog/tracelog_test.go b/tracelog/tracelog_test.go index 4f4ec8c9a..7180e8cf6 100644 --- a/tracelog/tracelog_test.go +++ b/tracelog/tracelog_test.go @@ -6,14 +6,17 @@ 
import ( "log" "os" "strings" + "sync" "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/yugabyte/pgx/v5" + "github.com/yugabyte/pgx/v5/pgxpool" "github.com/yugabyte/pgx/v5/pgxtest" "github.com/yugabyte/pgx/v5/tracelog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" ) var defaultConnTestRunner pgxtest.ConnTestRunner @@ -35,18 +38,29 @@ type testLog struct { type testLogger struct { logs []testLog + + mux sync.Mutex } func (l *testLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) { + l.mux.Lock() + defer l.mux.Unlock() + data["ctxdata"] = ctx.Value("ctxdata") l.logs = append(l.logs, testLog{lvl: level, msg: msg, data: data}) } func (l *testLogger) Clear() { + l.mux.Lock() + defer l.mux.Unlock() + l.logs = l.logs[0:0] } func (l *testLogger) FilterByMsg(msg string) (res []testLog) { + l.mux.Lock() + defer l.mux.Unlock() + for _, log := range l.logs { if log.msg == msg { res = append(res, log) @@ -348,7 +362,6 @@ func TestLogBatchStatementsOnExec(t *testing.T) { assert.Equal(t, "BatchQuery", logger.logs[1].msg) assert.Equal(t, "drop table foo", logger.logs[1].data["sql"]) assert.Equal(t, "BatchClose", logger.logs[2].msg) - }) } @@ -391,6 +404,88 @@ func TestLogBatchStatementsOnBatchResultClose(t *testing.T) { }) } +func TestLogAcquire(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + + poolConfig, err := pgxpool.ParseConfig(config.ConnString()) + require.NoError(t, err) + + poolConfig.ConnConfig = config + pool1, err := pgxpool.NewWithConfig(ctx, poolConfig) + require.NoError(t, err) + defer pool1.Close() + + conn1, err := 
pool1.Acquire(ctx) + require.NoError(t, err) + defer conn1.Release() + require.Len(t, logger.logs, 2) // Has both the Connect and Acquire logs + require.Equal(t, "Acquire", logger.logs[1].msg) + require.Equal(t, tracelog.LogLevelDebug, logger.logs[1].lvl) + + logger.Clear() + + // create a 2nd pool with a bad host to verify the error handling + poolConfig, err = pgxpool.ParseConfig("host=/invalid") + require.NoError(t, err) + poolConfig.ConnConfig.Tracer = tracer + + pool2, err := pgxpool.NewWithConfig(ctx, poolConfig) + require.NoError(t, err) + defer pool2.Close() + + conn2, err := pool2.Acquire(ctx) + require.Error(t, err) + require.Nil(t, conn2) + require.Len(t, logger.logs, 2) + require.Equal(t, "Acquire", logger.logs[1].msg) + require.Equal(t, tracelog.LogLevelError, logger.logs[1].lvl) +} + +func TestLogRelease(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + + poolConfig, err := pgxpool.ParseConfig(config.ConnString()) + require.NoError(t, err) + + poolConfig.ConnConfig = config + pool1, err := pgxpool.NewWithConfig(ctx, poolConfig) + require.NoError(t, err) + defer pool1.Close() + + conn1, err := pool1.Acquire(ctx) + require.NoError(t, err) + + logger.Clear() + conn1.Release() + require.Len(t, logger.logs, 1) + require.Equal(t, "Release", logger.logs[0].msg) + require.Equal(t, tracelog.LogLevelDebug, logger.logs[0].lvl) +} + func TestLogPrepare(t *testing.T) { t.Parallel() @@ -457,3 +552,42 @@ func TestLogPrepare(t *testing.T) { require.Equal(t, err, logger.logs[0].data["err"]) }) } + +// https://github.com/jackc/pgx/pull/2120 +func TestConcurrentUsage(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + 
logger := &testLogger{} + tracer := &tracelog.TraceLog{ + Logger: logger, + LogLevel: tracelog.LogLevelTrace, + } + + config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + config.ConnConfig.Tracer = tracer + + for i := 0; i < 50; i++ { + func() { + pool, err := pgxpool.NewWithConfig(ctx, config) + require.NoError(t, err) + + defer pool.Close() + + eg := errgroup.Group{} + + for i := 0; i < 5; i++ { + eg.Go(func() error { + _, err := pool.Exec(ctx, `select 1`) + return err + }) + } + + err = eg.Wait() + require.NoError(t, err) + }() + } +} diff --git a/tracer_test.go b/tracer_test.go index 4b73884c1..413629dea 100644 --- a/tracer_test.go +++ b/tracer_test.go @@ -566,3 +566,43 @@ func TestTraceConnect(t *testing.T) { require.True(t, traceConnectStartCalled) require.True(t, traceConnectEndCalled) } + +// Ensure tracer runs within a transaction. +// +// https://github.com/jackc/pgx/issues/2304 +func TestTraceWithinTx(t *testing.T) { + t.Parallel() + + tracer := &testTracer{} + + ctr := defaultConnTestRunner + ctr.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig { + config := defaultConnTestRunner.CreateConfig(ctx, t) + config.Tracer = tracer + return config + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var queries []string + tracer.traceQueryStart = func(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + queries = append(queries, data.SQL) + return ctx + } + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + _, err = tx.Exec(ctx, `select $1::text`, "testing") + require.NoError(t, err) + err = tx.Commit(ctx) + require.NoError(t, err) + + require.Len(t, queries, 3) + require.Equal(t, `begin`, queries[0]) + require.Equal(t, `select $1::text`, queries[1]) + require.Equal(t, `commit`, 
queries[2]) + }) +} diff --git a/tx.go b/tx.go index 3d674f623..9f1d906eb 100644 --- a/tx.go +++ b/tx.go @@ -3,7 +3,6 @@ package pgx import ( "context" "errors" - "fmt" "strconv" "strings" @@ -48,6 +47,8 @@ type TxOptions struct { // BeginQuery is the SQL query that will be executed to begin the transaction. This allows using non-standard syntax // such as BEGIN PRIORITY HIGH with CockroachDB. If set this will override the other settings. BeginQuery string + // CommitQuery is the SQL query that will be executed to commit the transaction. + CommitQuery string } var emptyTxOptions TxOptions @@ -101,11 +102,14 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) { if err != nil { // begin should never fail unless there is an underlying connection issue or // a context timeout. In either case, the connection is possibly broken. - c.die(errors.New("failed to begin transaction")) + c.die() return nil, err } - return &dbTx{conn: c}, nil + return &dbTx{ + conn: c, + commitQuery: txOptions.CommitQuery, + }, nil } // Tx represents a database transaction. @@ -154,6 +158,7 @@ type dbTx struct { conn *Conn savepointNum int64 closed bool + commitQuery string } // Begin starts a pseudo nested transaction implemented with a savepoint. 
@@ -177,7 +182,12 @@ func (tx *dbTx) Commit(ctx context.Context) error { return ErrTxClosed } - commandTag, err := tx.conn.Exec(ctx, "commit") + commandSQL := "commit" + if tx.commitQuery != "" { + commandSQL = tx.commitQuery + } + + commandTag, err := tx.conn.Exec(ctx, commandSQL) tx.closed = true if err != nil { if tx.conn.PgConn().TxStatus() != 'I' { @@ -205,7 +215,7 @@ func (tx *dbTx) Rollback(ctx context.Context) error { tx.closed = true if err != nil { // A rollback failure leaves the connection in an undefined state - tx.conn.die(fmt.Errorf("rollback failed: %w", err)) + tx.conn.die() return err } diff --git a/values.go b/values.go index dbb756495..6ea0600fb 100644 --- a/values.go +++ b/values.go @@ -15,10 +15,6 @@ const ( ) func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) { - if anynil.Is(arg) { - return nil, nil - } - buf, err := m.Encode(0, TextFormatCode, arg, []byte{}) if err != nil { return nil, err @@ -30,10 +26,6 @@ func convertSimpleArgument(m *pgtype.Map, arg any) (any, error) { } func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { - if anynil.Is(arg) { - return pgio.AppendInt32(buf, -1), nil - } - sp := len(buf) buf = pgio.AppendInt32(buf, -1) argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) diff --git a/values_test.go b/values_test.go index fda033b29..55b623692 100644 --- a/values_test.go +++ b/values_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "bytes" "context" + "fmt" "net" "os" "reflect" @@ -116,7 +117,6 @@ func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) { testJSONInt16ArrayFailureDueToOverflow(t, conn, typename) testJSONStruct(t, conn, typename) } - } func testJSONString(t testing.TB, conn *pgx.Conn, typename string) { @@ -215,8 +215,13 @@ func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typena input := []int{1, 2, 234432} var output []int16 err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output) - if err == nil || 
err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" { - t.Errorf("%s: Expected *json.UnmarkalTypeError, but got %v", typename, err) + fieldName := typename + if conn.PgConn().ParameterStatus("crdb_version") != "" && typename == "json" { + fieldName = "jsonb" // Seems like CockroachDB treats json as jsonb. + } + expectedMessage := fmt.Sprintf("can't scan into dest[0] (col: %s): json: cannot unmarshal number 234432 into Go value of type int16", fieldName) + if err == nil || err.Error() != expectedMessage { + t.Errorf("%s: Expected *json.UnmarshalTypeError, but got %v", typename, err) } } @@ -591,7 +596,9 @@ func TestArrayDecoding(t *testing.T) { assert func(testing.TB, any, any) }{ { - "select $1::bool[]", []bool{true, false, true}, &[]bool{}, + "select $1::bool[]", + []bool{true, false, true}, + &[]bool{}, func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]bool))) { t.Errorf("failed to encode bool[]") @@ -599,7 +606,9 @@ func TestArrayDecoding(t *testing.T) { }, }, { - "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, + "select $1::smallint[]", + []int16{2, 4, 484, 32767}, + &[]int16{}, func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]int16))) { t.Errorf("failed to encode smallint[]") @@ -607,7 +616,9 @@ func TestArrayDecoding(t *testing.T) { }, }, { - "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, + "select $1::smallint[]", + []uint16{2, 4, 484, 32767}, + &[]uint16{}, func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { t.Errorf("failed to encode smallint[]") @@ -615,7 +626,9 @@ func TestArrayDecoding(t *testing.T) { }, }, { - "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, + "select $1::int[]", + []int32{2, 4, 484}, + &[]int32{}, func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]int32))) { t.Errorf("failed to encode int[]") @@ -623,7 +636,9 @@ func 
TestArrayDecoding(t *testing.T) { }, }, { - "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, + "select $1::int[]", + []uint32{2, 4, 484, 2147483647}, + &[]uint32{}, func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { t.Errorf("failed to encode int[]") @@ -631,7 +646,9 @@ func TestArrayDecoding(t *testing.T) { }, }, { - "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, + "select $1::bigint[]", + []int64{2, 4, 484, 9223372036854775807}, + &[]int64{}, func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]int64))) { t.Errorf("failed to encode bigint[]") @@ -639,7 +656,9 @@ func TestArrayDecoding(t *testing.T) { }, }, { - "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, + "select $1::bigint[]", + []uint64{2, 4, 484, 9223372036854775807}, + &[]uint64{}, func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { t.Errorf("failed to encode bigint[]") @@ -647,7 +666,9 @@ func TestArrayDecoding(t *testing.T) { }, }, { - "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, + "select $1::text[]", + []string{"it's", "over", "9000!"}, + &[]string{}, func(t testing.TB, query, scan any) { if !reflect.DeepEqual(query, *(scan.(*[]string))) { t.Errorf("failed to encode text[]") @@ -655,7 +676,9 @@ func TestArrayDecoding(t *testing.T) { }, }, { - "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, + "select $1::timestamptz[]", + []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 0o0)}, + &[]time.Time{}, func(t testing.TB, query, scan any) { queryTimeSlice := query.([]time.Time) scanTimeSlice := *(scan.(*[]time.Time)) @@ -666,7 +689,9 @@ func TestArrayDecoding(t *testing.T) { }, }, { - "select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{}, + "select $1::bytea[]", + [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, + &[][]byte{}, func(t 
testing.TB, query, scan any) { queryBytesSliceSlice := query.([][]byte) scanBytesSliceSlice := *(scan.(*[][]byte)) From 5ac3b0d069dceffcd3a99f1f1a2d62610f3dfed1 Mon Sep 17 00:00:00 2001 From: Harsh Daryani Date: Thu, 30 Oct 2025 14:19:11 +0000 Subject: [PATCH 2/6] Resolve merge conflicts --- conn.go | 24 ++++-------------------- derived_types.go | 2 +- internal/sanitize/sanitize.go | 7 ------- pgconn/pgconn.go | 2 +- pgproto3/bind.go | 2 ++ pgproto3/copy_both_response.go | 1 + pgproto3/copy_in_response.go | 1 + pgproto3/copy_out_response.go | 1 + pgproto3/parse.go | 5 +++-- pgproto3/sasl_initial_response.go | 2 +- pgproto3/startup_message.go | 6 +++--- pgtype/bool.go | 8 -------- pgtype/box.go | 4 ---- pgtype/composite.go | 4 ---- pgtype/date.go | 4 ---- pgtype/float8.go | 4 ---- pgtype/hstore.go | 4 ---- pgtype/inet.go | 4 ---- pgtype/line.go | 7 +++---- pgtype/lseg.go | 4 ---- pgtype/numeric.go | 4 ---- pgtype/path.go | 4 ---- pgtype/point.go | 4 ---- pgtype/polygon.go | 4 ---- pgtype/range.go | 4 ---- pgtype/uint64.go | 2 +- pgxpool/pool.go | 2 +- pgxpool/tracer.go | 2 +- rows.go | 1 + values.go | 1 - 30 files changed, 25 insertions(+), 99 deletions(-) diff --git a/conn.go b/conn.go index 07223db22..c0ed4566e 100644 --- a/conn.go +++ b/conn.go @@ -32,10 +32,6 @@ type ConnConfig struct { // query exec mode. StatementCacheCapacity int - // StatementCacheCapacity is maximum size of the statement cache used when executing a query with "cache_statement" - // query exec mode. - StatementCacheCapacity int - // DescriptionCacheCapacity is the maximum size of the description cache used when executing a query with // "cache_describe" query exec mode. DescriptionCacheCapacity int @@ -60,11 +56,6 @@ type ParseConfigOptions struct { pgconn.ParseConfigOptions } -// ParseConfigOptions contains options that control how a config is built such as getsslpassword. 
-type ParseConfigOptions struct { - pgconn.ParseConfigOptions -} - // Copy returns a deep copy of the config that is safe to use and modify. // The only exception is the tls.Config: // according to the tls.Config docs it must not be modified after creation. @@ -101,6 +92,8 @@ type Conn struct { wbuf []byte eqb ExtendedQueryBuilder + + closeCntUpdated bool } // Identifier a PostgreSQL identifier or name. Identifiers can be composed of @@ -173,16 +166,6 @@ func ConnectWithOptions(ctx context.Context, connString string, options ParseCon } } -// ConnectWithOptions behaves exactly like Connect with the addition of options. At the present options is only used to -// provide a GetSSLPassword function. -func ConnectWithOptions(ctx context.Context, connString string, options ParseConfigOptions) (*Conn, error) { - connConfig, err := ParseConfigWithOptions(connString, options) - if err != nil { - return nil, err - } - return connect(ctx, connConfig) -} - // ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct. // connConfig must have been created by ParseConfig. 
func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { @@ -256,7 +239,6 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con } else { return nil, err } - descriptionCacheCapacity = int(n) } refreshInterval := int64(REFRESH_INTERVAL_SECONDS) @@ -290,6 +272,8 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con } } else { return nil, pgconn.NewParseConfigError(connString, "invalid failed_host_reconnect_delay_secs", err) + } + } defaultQueryExecMode := QueryExecModeCacheStatement if s, ok := config.RuntimeParams["default_query_exec_mode"]; ok { diff --git a/derived_types.go b/derived_types.go index 72c0a2423..f11b40d32 100644 --- a/derived_types.go +++ b/derived_types.go @@ -7,7 +7,7 @@ import ( "strconv" "strings" - "github.com/jackc/pgx/v5/pgtype" + "github.com/yugabyte/pgx/v5/pgtype" ) /* diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go index 033dde2b8..b516817cb 100644 --- a/internal/sanitize/sanitize.go +++ b/internal/sanitize/sanitize.go @@ -431,17 +431,10 @@ var queryPool = &pool[*Query]{ // as necessary. This function is only safe when standard_conforming_strings is // on. func SanitizeSQL(sql string, args ...any) (string, error) { -<<<<<<< HEAD - query, err := NewQuery(sql) - if err != nil { - return "", err - } -======= query := queryPool.get() query.init(sql) defer queryPool.put(query) ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e return query.Sanitize(args...) 
} diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index bfa73cf41..1df8f1eff 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -19,8 +19,8 @@ import ( "github.com/yugabyte/pgx/v5/internal/iobufpool" "github.com/yugabyte/pgx/v5/internal/pgio" + "github.com/yugabyte/pgx/v5/pgconn/ctxwatch" "github.com/yugabyte/pgx/v5/pgconn/internal/bgreader" - "github.com/yugabyte/pgx/v5/pgconn/internal/ctxwatch" "github.com/yugabyte/pgx/v5/pgproto3" ) diff --git a/pgproto3/bind.go b/pgproto3/bind.go index 1a53e70fa..1070ea3b3 100644 --- a/pgproto3/bind.go +++ b/pgproto3/bind.go @@ -5,7 +5,9 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" "fmt" + "math" "github.com/yugabyte/pgx/v5/internal/pgio" ) diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go index 05dc1f2ac..cb411494b 100644 --- a/pgproto3/copy_both_response.go +++ b/pgproto3/copy_both_response.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "math" "github.com/yugabyte/pgx/v5/internal/pgio" ) diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go index d0601f1eb..67cd6df91 100644 --- a/pgproto3/copy_in_response.go +++ b/pgproto3/copy_in_response.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "math" "github.com/yugabyte/pgx/v5/internal/pgio" ) diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go index 6851bc817..f32c9185e 100644 --- a/pgproto3/copy_out_response.go +++ b/pgproto3/copy_out_response.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "math" "github.com/yugabyte/pgx/v5/internal/pgio" ) diff --git a/pgproto3/parse.go b/pgproto3/parse.go index e921feef6..8b6e1cbb1 100644 --- a/pgproto3/parse.go +++ b/pgproto3/parse.go @@ -70,9 +70,10 @@ func (src *Parse) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendUint32(dst, oid) } -<<<<<<< HEAD - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + return finishMessage(dst, sp) +} +// MarshalJSON implements 
encoding/json.Marshaler. func (src Parse) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string diff --git a/pgproto3/sasl_initial_response.go b/pgproto3/sasl_initial_response.go index 9eb1b6a4b..f4835bd31 100644 --- a/pgproto3/sasl_initial_response.go +++ b/pgproto3/sasl_initial_response.go @@ -6,7 +6,7 @@ import ( "encoding/json" "errors" - "github.com/jackc/pgx/v5/internal/pgio" + "github.com/yugabyte/pgx/v5/internal/pgio" ) type SASLInitialResponse struct { diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go index c141b27f4..935fbf2e3 100644 --- a/pgproto3/startup_message.go +++ b/pgproto3/startup_message.go @@ -77,11 +77,11 @@ func (src *StartupMessage) Encode(dst []byte) ([]byte, error) { } dst = append(dst, 0) -<<<<<<< HEAD - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) + return finishMessage(dst, sp) +} - return dst // MarshalJSON implements encoding/json.Marshaler. +func (src StartupMessage) MarshalJSON() ([]byte, error) { return json.Marshal(struct { Type string ProtocolVersion uint32 diff --git a/pgtype/bool.go b/pgtype/bool.go index b74fe4414..955f01fe8 100644 --- a/pgtype/bool.go +++ b/pgtype/bool.go @@ -204,10 +204,6 @@ func (encodePlanBoolCodecTextBool) Encode(value any, buf []byte) (newBuf []byte, } func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e switch format { case BinaryFormatCode: switch target.(type) { @@ -335,11 +331,7 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error { return s.ScanBool(Bool{Bool: v, Valid: true}) } -<<<<<<< HEAD -// https://www.postgresql.org/docs/11/datatype-boolean.html -======= // https://www.postgresql.org/docs/current/datatype-boolean.html ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e func planTextToBool(src []byte) (bool, error) { s := string(bytes.ToLower(bytes.TrimSpace(src))) diff --git a/pgtype/box.go b/pgtype/box.go index 8f869744a..d51c6433b 
100644 --- a/pgtype/box.go +++ b/pgtype/box.go @@ -129,10 +129,6 @@ func (encodePlanBoxCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/composite.go b/pgtype/composite.go index b3197359d..071f510ed 100644 --- a/pgtype/composite.go +++ b/pgtype/composite.go @@ -276,10 +276,6 @@ func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byt default: return nil, fmt.Errorf("unknown format code %d", format) } -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e } type CompositeBinaryScanner struct { diff --git a/pgtype/date.go b/pgtype/date.go index 8824b50db..a50b7fd8c 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -227,10 +227,6 @@ func (encodePlanDateCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/float8.go b/pgtype/float8.go index be781a5e4..f0d9ec3af 100644 --- a/pgtype/float8.go +++ b/pgtype/float8.go @@ -213,10 +213,6 @@ func (encodePlanTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, e } func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/hstore.go b/pgtype/hstore.go index 4a19338b0..e5d1f8313 100644 --- a/pgtype/hstore.go +++ b/pgtype/hstore.go @@ -164,10 +164,6 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e } func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target 
any) ScanPlan { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/inet.go b/pgtype/inet.go index 6363cf44b..b92edb239 100644 --- a/pgtype/inet.go +++ b/pgtype/inet.go @@ -107,10 +107,6 @@ func (encodePlanInetCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/line.go b/pgtype/line.go index 77e993ab0..63794c4e3 100644 --- a/pgtype/line.go +++ b/pgtype/line.go @@ -25,7 +25,10 @@ type Line struct { } // ScanLine implements the [LineScanner] interface. +func (line *Line) ScanLine(v Line) error { + *line = v return nil +} // LineValue implements the [LineValuer] interface. func (line Line) LineValue() (Line, error) { @@ -128,10 +131,6 @@ func (encodePlanLineCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/lseg.go b/pgtype/lseg.go index 2f22f9126..031222187 100644 --- a/pgtype/lseg.go +++ b/pgtype/lseg.go @@ -129,10 +129,6 @@ func (encodePlanLsegCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/numeric.go b/pgtype/numeric.go index 80efbe90c..93183aa91 100644 --- a/pgtype/numeric.go +++ b/pgtype/numeric.go @@ -564,10 +564,6 @@ func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) { } func (NumericCodec) 
PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/path.go b/pgtype/path.go index c2fe94ad6..9893969db 100644 --- a/pgtype/path.go +++ b/pgtype/path.go @@ -156,10 +156,6 @@ func (encodePlanPathCodecText) Encode(value any, buf []byte) (newBuf []byte, err } func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/point.go b/pgtype/point.go index 40275c1b0..de85eddc2 100644 --- a/pgtype/point.go +++ b/pgtype/point.go @@ -182,10 +182,6 @@ func (encodePlanPointCodecText) Encode(value any, buf []byte) (newBuf []byte, er } func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/polygon.go b/pgtype/polygon.go index f8df718e0..ed1c1f239 100644 --- a/pgtype/polygon.go +++ b/pgtype/polygon.go @@ -141,10 +141,6 @@ func (encodePlanPolygonCodecText) Encode(value any, buf []byte) (newBuf []byte, } func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e switch format { case BinaryFormatCode: switch target.(type) { diff --git a/pgtype/range.go b/pgtype/range.go index a73d4d7f9..62d699905 100644 --- a/pgtype/range.go +++ b/pgtype/range.go @@ -275,10 +275,6 @@ func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) { } return ubr, nil -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e } // Range is a generic range type. 
diff --git a/pgtype/uint64.go b/pgtype/uint64.go index 68fd16613..1475a1ab9 100644 --- a/pgtype/uint64.go +++ b/pgtype/uint64.go @@ -7,7 +7,7 @@ import ( "math" "strconv" - "github.com/jackc/pgx/v5/internal/pgio" + "github.com/yugabyte/pgx/v5/internal/pgio" ) type Uint64Scanner interface { diff --git a/pgxpool/pool.go b/pgxpool/pool.go index b7612afa3..4a963b9c3 100644 --- a/pgxpool/pool.go +++ b/pgxpool/pool.go @@ -10,7 +10,7 @@ import ( "sync/atomic" "time" - "github.com/yugabyte/puddle/v2" + "github.com/jackc/puddle/v2" "github.com/yugabyte/pgx/v5" "github.com/yugabyte/pgx/v5/pgconn" ) diff --git a/pgxpool/tracer.go b/pgxpool/tracer.go index 78b9d15a2..9ae4343ce 100644 --- a/pgxpool/tracer.go +++ b/pgxpool/tracer.go @@ -3,7 +3,7 @@ package pgxpool import ( "context" - "github.com/jackc/pgx/v5" + "github.com/yugabyte/pgx/v5" ) // AcquireTracer traces Acquire. diff --git a/rows.go b/rows.go index e9ceb5004..eb3aaa1e9 100644 --- a/rows.go +++ b/rows.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "sync" "time" "github.com/yugabyte/pgx/v5/pgconn" diff --git a/values.go b/values.go index 6ea0600fb..37ec5d799 100644 --- a/values.go +++ b/values.go @@ -4,7 +4,6 @@ import ( "errors" "github.com/yugabyte/pgx/v5/pgtype" - "github.com/yugabyte/pgx/v5/internal/anynil" "github.com/yugabyte/pgx/v5/internal/pgio" ) From 98ce80a306d15c69066aa63232d4933d04ec61f6 Mon Sep 17 00:00:00 2001 From: Harsh Daryani Date: Tue, 11 Nov 2025 04:15:06 +0000 Subject: [PATCH 3/6] Resolving more merge conflicts --- batch_test.go | 4 ---- bench_test.go | 12 ------------ conn_test.go | 3 --- derived_types_test.go | 2 +- large_objects.go | 2 +- pgx_test.go | 4 ++-- query_test.go | 4 ---- 7 files changed, 4 insertions(+), 27 deletions(-) diff --git a/batch_test.go b/batch_test.go index 06f62f672..4cfc0f39d 100644 --- a/batch_test.go +++ b/batch_test.go @@ -488,10 +488,6 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) { defer cancel() pgxtest.RunWithQueryExecModes(ctx, t, 
defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e batch := &pgx.Batch{} batch.Queue("select n from generate_series(0,5) n") batch.Queue("select n from generate_series(0,5) n") diff --git a/bench_test.go b/bench_test.go index 28edcd02a..36dc818d8 100644 --- a/bench_test.go +++ b/bench_test.go @@ -614,9 +614,6 @@ func BenchmarkWrite5RowsViaInsert(b *testing.B) { func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 5) } -func BenchmarkWrite5RowsViaBatchInsert(b *testing.B) { - benchmarkWriteNRowsViaBatchInsert(b, 5) -} func BenchmarkWrite5RowsViaBatchInsert(b *testing.B) { benchmarkWriteNRowsViaBatchInsert(b, 5) @@ -633,9 +630,6 @@ func BenchmarkWrite10RowsViaInsert(b *testing.B) { func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 10) } -func BenchmarkWrite10RowsViaBatchInsert(b *testing.B) { - benchmarkWriteNRowsViaBatchInsert(b, 10) -} func BenchmarkWrite10RowsViaBatchInsert(b *testing.B) { benchmarkWriteNRowsViaBatchInsert(b, 10) @@ -652,9 +646,6 @@ func BenchmarkWrite100RowsViaInsert(b *testing.B) { func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 100) } -func BenchmarkWrite100RowsViaBatchInsert(b *testing.B) { - benchmarkWriteNRowsViaBatchInsert(b, 100) -} func BenchmarkWrite100RowsViaBatchInsert(b *testing.B) { benchmarkWriteNRowsViaBatchInsert(b, 100) @@ -687,9 +678,6 @@ func BenchmarkWrite10000RowsViaInsert(b *testing.B) { func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) { benchmarkWriteNRowsViaMultiInsert(b, 10000) } -func BenchmarkWrite10000RowsViaBatchInsert(b *testing.B) { - benchmarkWriteNRowsViaBatchInsert(b, 10000) -} func BenchmarkWrite10000RowsViaBatchInsert(b *testing.B) { benchmarkWriteNRowsViaBatchInsert(b, 10000) diff --git a/conn_test.go b/conn_test.go index 311b8c536..290628db8 100644 --- a/conn_test.go 
+++ b/conn_test.go @@ -1540,8 +1540,6 @@ func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) { require.EqualValues(t, 1, n) }) } -<<<<<<< HEAD -======= // https://github.com/jackc/pgx/issues/1847 func TestConnDeallocateInvalidatedCachedStatementsInTransactionWithBatch(t *testing.T) { @@ -1590,4 +1588,3 @@ func TestErrNoRows(t *testing.T) { require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows") } ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e diff --git a/derived_types_test.go b/derived_types_test.go index 6fb6e1d36..4a5384ae3 100644 --- a/derived_types_test.go +++ b/derived_types_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/jackc/pgx/v5" + "github.com/yugabyte/pgx/v5" "github.com/stretchr/testify/require" ) diff --git a/large_objects.go b/large_objects.go index 9d21afdce..c4f10a8b9 100644 --- a/large_objects.go +++ b/large_objects.go @@ -5,7 +5,7 @@ import ( "errors" "io" - "github.com/jackc/pgx/v5/pgtype" + "github.com/yugabyte/pgx/v5/pgtype" ) // The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of diff --git a/pgx_test.go b/pgx_test.go index 51b4bbc4e..d12d75fd3 100644 --- a/pgx_test.go +++ b/pgx_test.go @@ -5,8 +5,8 @@ import ( "os" "testing" - "github.com/jackc/pgx/v5" - _ "github.com/jackc/pgx/v5/stdlib" + "github.com/yugabyte/pgx/v5" + _ "github.com/yugabyte/pgx/v5/stdlib" ) func skipCockroachDB(t testing.TB, msg string) { diff --git a/query_test.go b/query_test.go index 46bc0efe4..5bce9fb8f 100644 --- a/query_test.go +++ b/query_test.go @@ -2207,10 +2207,6 @@ insert into products (name, price) values } rows, err := conn.Query(ctx, "select name, price from products where price < $1 order by price desc", 12) -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e // It is unnecessary to check err. If an error occurred it will be returned by rows.Err() later. 
But in rare // cases it may be useful to detect the error as early as possible. if err != nil { From 796142ed09f40d41e5e76282315765f759bdc35c Mon Sep 17 00:00:00 2001 From: Harsh Daryani Date: Tue, 11 Nov 2025 04:55:35 +0000 Subject: [PATCH 4/6] Add check for undefined loadBalance value --- conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn.go b/conn.go index c0ed4566e..cd1a9370e 100644 --- a/conn.go +++ b/conn.go @@ -173,7 +173,7 @@ func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { // connections with the same config. See https://github.com/jackc/pgx/issues/618. connConfig = connConfig.Copy() - if connConfig.loadBalance != "false" { + if connConfig.loadBalance != "false" && connConfig.loadBalance != "" { return connectLoadBalanced(ctx, connConfig) } else { return connect(ctx, connConfig) From be212e5e58e8f1296e9edf8a4fb0cf2a33c52457 Mon Sep 17 00:00:00 2001 From: Harsh Daryani Date: Thu, 13 Nov 2025 05:35:13 +0000 Subject: [PATCH 5/6] Resolving merge conflicts --- README.md | 45 +------------------------------------ pgconn/benchmark_test.go | 16 ------------- pgconn/config_test.go | 4 ---- pgtype/array_codec_test.go | 4 ---- pgtype/multirange_test.go | 4 ---- pgtype/numeric_test.go | 4 ---- pgtype/range_codec_test.go | 8 ------- pgxpool/pool_test.go | 27 ---------------------- testsetup/generate_certs.go | 4 ---- 9 files changed, 1 insertion(+), 115 deletions(-) diff --git a/README.md b/README.md index a57ff84ab..61b48f75b 100644 --- a/README.md +++ b/README.md @@ -110,10 +110,7 @@ import ( "fmt" "os" -<<<<<<< HEAD "github.com/yugabyte/pgx/v5" -======= - "github.com/jackc/pgx/v5" ) conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { @@ -147,10 +144,6 @@ See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-starte * `COPY` protocol support for faster bulk data loads * Tracing and logging support * Connection pool with after-connect hook for arbitrary 
connection setup -<<<<<<< HEAD -======= -* `LISTEN` / `NOTIFY` ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e * Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings * `hstore` support * `json` and `jsonb` support @@ -163,12 +156,7 @@ See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-starte ## Choosing Between the pgx and database/sql Interfaces -<<<<<<< HEAD The pgx interface is faster. -======= -The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available -through the `database/sql` interface. ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e The pgx interface is recommended when: @@ -179,11 +167,7 @@ It is also possible to use the `database/sql` interface and convert a connection ## Testing -<<<<<<< HEAD -See CONTRIBUTING.md for setup instructions. -======= See [CONTRIBUTING.md](./CONTRIBUTING.md) for setup instructions. ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e ## Architecture @@ -191,11 +175,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube. ## Supported Go and PostgreSQL Versions -<<<<<<< HEAD -pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.20 and higher and PostgreSQL 12 and higher. -======= -pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.23 and higher and PostgreSQL 13 and higher. 
pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/). ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e +pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.23 and higher and PostgreSQL 13 and higher. ## Version Policy @@ -207,11 +187,7 @@ pgx follows semantic versioning for the documented public API on stable releases pglogrepl provides functionality to act as a client for PostgreSQL logical replication. -<<<<<<< HEAD -pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.16 and higher and PostgreSQL 10 and higher. -======= ### [github.com/jackc/pgmock](https://github.com/jackc/pgmock) ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e pgmock offers the ability to create a server that mocks the PostgreSQL wire protocol. This is used internally to test pgx by purposely inducing unusual errors. pgproto3 and pgmock together provide most of the foundational tooling required to implement a PostgreSQL proxy or MitM (such as for a custom connection pooler). @@ -225,14 +201,6 @@ pgerrcode contains constants for the PostgreSQL error codes. 
## Adapters for 3rd Party Types -<<<<<<< HEAD -### [github.com/yugabyte/pgx/v4/pgxpool](https://github.com/yugabyte/pgx/tree/master/pgxpool) - - -### [github.com/yugabyte/pgx/v4/stdlib](https://github.com/yugabyte/pgx/tree/master/stdlib) - -* [https://github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer) -======= * [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid) * [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal) * [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos)) @@ -243,7 +211,6 @@ pgerrcode contains constants for the PostgreSQL error codes. * [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer) * [github.com/exaring/otelpgx](https://github.com/exaring/otelpgx) ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e ## Adapters for 3rd Party Loggers @@ -273,11 +240,7 @@ Library for scanning data from a database into Go structs and more. A carefully designed SQL client for making using SQL easier, more productive, and less error-prone on Golang. -<<<<<<< HEAD -### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) -======= ### [github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5) ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e Adds GSSAPI / Kerberos authentication support. @@ -290,11 +253,6 @@ Explicit data mapping and scanning library for Go structs and slices. Type safe and flexible package for scanning database data into Go types. Supports, structs, maps, slices and custom mapping functions. -<<<<<<< HEAD -### [https://github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx) - -Code first migration library for native pgx (no database/sql abstraction). 
-======= ### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx) Code first migration library for native pgx (no database/sql abstraction). @@ -314,4 +272,3 @@ Simplifies working with the pgx library, providing convenient scanning of nested ## [https://github.com/KoNekoD/pgx-colon-query-rewriter](https://github.com/KoNekoD/pgx-colon-query-rewriter) Implementation of the pgx query rewriter to use ':' instead of '@' in named query parameters. ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e diff --git a/pgconn/benchmark_test.go b/pgconn/benchmark_test.go index 81893a047..86c76a51c 100644 --- a/pgconn/benchmark_test.go +++ b/pgconn/benchmark_test.go @@ -78,10 +78,6 @@ func BenchmarkExec(b *testing.B) { } } _, err = rr.Close() -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e if err != nil { b.Fatal(err) } @@ -130,10 +126,6 @@ func BenchmarkExecPossibleToCancel(b *testing.B) { } } _, err = rr.Close() -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e if err != nil { b.Fatal(err) } @@ -190,10 +182,6 @@ func BenchmarkExecPrepared(b *testing.B) { } } _, err = rr.Close() -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e if err != nil { b.Fatal(err) } @@ -236,10 +224,6 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) { } } _, err = rr.Close() -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e if err != nil { b.Fatal(err) } diff --git a/pgconn/config_test.go b/pgconn/config_test.go index 17a7f5f4c..a12ade0e8 100644 --- a/pgconn/config_test.go +++ b/pgconn/config_test.go @@ -133,10 +133,6 @@ func TestParseConfig(t *testing.T) { name: "sslmode prefer", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", config: &pgconn.Config{ -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e User: "jack", Password: "secret", Host: "localhost", diff --git a/pgtype/array_codec_test.go b/pgtype/array_codec_test.go index 
0019cb3d2..3e1355470 100644 --- a/pgtype/array_codec_test.go +++ b/pgtype/array_codec_test.go @@ -256,10 +256,6 @@ func TestArrayCodecScanMultipleDimensions(t *testing.T) { skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e rows, err := conn.Query(ctx, `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) require.NoError(t, err) diff --git a/pgtype/multirange_test.go b/pgtype/multirange_test.go index 13f120025..98ddd0af0 100644 --- a/pgtype/multirange_test.go +++ b/pgtype/multirange_test.go @@ -71,10 +71,6 @@ func TestMultirangeCodecDecodeValue(t *testing.T) { skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e for _, tt := range []struct { sql string expected any diff --git a/pgtype/numeric_test.go b/pgtype/numeric_test.go index f201f8606..06d1b93bb 100644 --- a/pgtype/numeric_test.go +++ b/pgtype/numeric_test.go @@ -198,10 +198,6 @@ func TestNumericMarshalJSON(t *testing.T) { skipCockroachDB(t, "server formats numeric text format differently") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e for i, tt := range []struct { decString string }{ diff --git a/pgtype/range_codec_test.go b/pgtype/range_codec_test.go index f004fba3a..b2a1cbc3d 100644 --- a/pgtype/range_codec_test.go +++ b/pgtype/range_codec_test.go @@ -75,10 +75,6 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) { skipCockroachDB(t, "Server does not 
support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e var r pgtype.Range[pgtype.Int4] err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r) @@ -132,10 +128,6 @@ func TestRangeCodecDecodeValue(t *testing.T) { skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e for _, tt := range []struct { sql string expected any diff --git a/pgxpool/pool_test.go b/pgxpool/pool_test.go index b393aee0f..196484cbc 100644 --- a/pgxpool/pool_test.go +++ b/pgxpool/pool_test.go @@ -1193,28 +1193,6 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { config.MinConns = int32(12) config.MaxConns = int32(15) -<<<<<<< HEAD - - acquireAttempts := int64(0) - connectAttempts := int64(0) - - config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool { - atomic.AddInt64(&acquireAttempts, 1) - return true - } - config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { - atomic.AddInt64(&connectAttempts, 1) - return nil - } - - pool, err := pgxpool.NewWithConfig(ctx, config) - require.NoError(t, err) - defer pool.Close() - - for i := 0; i < 500; i++ { - time.Sleep(10 * time.Millisecond) - -======= acquireAttempts := int64(0) connectAttempts := int64(0) @@ -1235,7 +1213,6 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { for i := 0; i < 500; i++ { time.Sleep(10 * time.Millisecond) ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e stat := pool.Stat() if stat.IdleConns() == 12 && stat.AcquireCount() == 0 && stat.TotalConns() == 12 && 
atomic.LoadInt64(&acquireAttempts) == 0 && atomic.LoadInt64(&connectAttempts) == 12 { return @@ -1243,10 +1220,6 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) { } t.Fatal("did not reach min pool size") -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e } func TestPoolSendBatchBatchCloseTwice(t *testing.T) { diff --git a/testsetup/generate_certs.go b/testsetup/generate_certs.go index 7f478d4f5..d465b6c52 100644 --- a/testsetup/generate_certs.go +++ b/testsetup/generate_certs.go @@ -161,10 +161,6 @@ func writeEncryptedPrivateKey(path string, privateKey *rsa.PrivateKey, password } return nil -<<<<<<< HEAD - -======= ->>>>>>> a2fca037434a0a7096b095d4ed87cdffb03b626e } func writeCertificate(path string, certBytes []byte) error { From e68042b0f18ce8d86beab7d4084e6ed7d0df1b40 Mon Sep 17 00:00:00 2001 From: Harsh Daryani Date: Mon, 17 Nov 2025 04:16:53 +0000 Subject: [PATCH 6/6] Readme update --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 61b48f75b..bc77b9037 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,9 @@ import ( "github.com/yugabyte/pgx/v5" ) + +func main() { + // urlExample := "postgres://username:password@localhost:5433/database_name" conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL")) if err != nil { fmt.Fprintf(os.Stderr, "Unable to connect to database: %v\n", err)