diff --git a/README.md b/README.md index 3b3b507..ad4fa6a 100644 --- a/README.md +++ b/README.md @@ -74,4 +74,4 @@ func main() { ## LICENSE -See `LICENSE` file for details. +See [LICENSE](/LICENSE) file for details. diff --git a/go.mod b/go.mod index 197046b..4626edd 100644 --- a/go.mod +++ b/go.mod @@ -2,13 +2,13 @@ module github.com/randlabs/go-postgres go 1.19 -require github.com/jackc/pgx/v5 v5.5.1 +require github.com/jackc/pgx/v5 v5.5.4 require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect - golang.org/x/crypto v0.17.0 // indirect + golang.org/x/crypto v0.21.0 // indirect golang.org/x/sync v0.5.0 // indirect golang.org/x/text v0.14.0 // indirect ) diff --git a/go.sum b/go.sum index 2d0bd63..fd4ef9e 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI= -github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= +github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -14,8 +14,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= diff --git a/postgres.go b/postgres.go index 666040f..30f7eb4 100644 --- a/postgres.go +++ b/postgres.go @@ -4,6 +4,9 @@ import ( "context" "errors" "fmt" + "net/url" + "strconv" + "strings" "sync" "time" @@ -58,16 +61,19 @@ const ( // New creates a new postgresql database driver. func New(ctx context.Context, opts Options) (*Database, error) { - var sslMode string - - // Create database object - db := Database{} - db.err.mutex = sync.Mutex{} - - // Setup basic configuration options + // Validate options + if len(opts.Host) == 0 { + return nil, errors.New("invalid host") + } + if len(opts.User) == 0 { + return nil, errors.New("invalid user name") + } + if len(opts.Name) == 0 { + return nil, errors.New("invalid database name") + } + sslMode := "disable" switch opts.SSLMode { case SSLModeDisable: - sslMode = "disable" case SSLModeAllow: sslMode = "prefer" case SSLModeRequired: @@ -76,6 +82,10 @@ func New(ctx context.Context, opts Options) (*Database, error) { return nil, errors.New("invalid SSL mode") } + // Create database object + db := Database{} + db.err.mutex = sync.Mutex{} + connString := fmt.Sprintf( "host='%s' port=%d user='%s' password='%s' dbname='%s' sslmode=%s", encodeDSN(opts.Host), opts.Port, encodeDSN(opts.User), encodeDSN(opts.Password), encodeDSN(opts.Name), @@ -110,6 +120,82 @@ func New(ctx context.Context, opts Options) (*Database, error) { return &db, nil } +// NewFromURL creates a new postgresql database driver from an URL +func NewFromURL(ctx context.Context, rawUrl string) (*Database, error) { + opts := Options{} + + u, err := url.ParseRequestURI(rawUrl) + if err != nil { + return nil, errors.New("invalid url provided") + } + + // Check schema + if u.Scheme != "pg" && u.Scheme != "postgres" && u.Scheme != "postgresql" { + return nil, errors.New("invalid url schema") + } + + // Check host name and port + opts.Host = u.Hostname() + if len(opts.Host) == 0 { + return nil, errors.New("invalid host") + } + s := u.Port() + if len(s) == 0 { + opts.Port = 5432 + } else { + val, err2 := strconv.Atoi(s) + if err2 != nil || val < 1 || val > 65535 { + return nil, errors.New("invalid port") + } + opts.Port = uint16(val) + } + + // Check user and password + if u.User == nil { + return nil, errors.New("invalid user name") + } + opts.User = u.User.Username() + if len(opts.User) == 0 { + return nil, errors.New("invalid user name") + } + + // Check database name + if len(u.Path) < 1 || (!strings.HasPrefix(u.Path, "/")) || strings.Index(u.Path[1:], "/") >= 0 { + return nil, errors.New("invalid database name") + } + opts.Name = u.Path[1:] + + // Check ssl mode + opts.SSLMode = SSLModeDisable + switch u.Query().Get("sslmode") { + case "allow": + opts.SSLMode = SSLModeAllow + + case "required": + opts.SSLMode = SSLModeRequired + + case "disabled": + fallthrough + case "": + + default: + return nil, errors.New("invalid SSL mode") + } + + // Check max connections count + s = u.Query().Get("maxconn") + if len(s) > 0 { + val, err2 := strconv.Atoi(s) + if err2 != nil || val < 0 { + return nil, errors.New("invalid max connections count") + } + opts.MaxConns = int32(val) + } + + // Create + return New(ctx, opts) +} + // Close shutdown the connection pool func (db *Database) Close() { if db.pool != nil { diff --git a/postgres_test.go b/postgres_test.go index 980dfc8..e8515e3 100644 --- a/postgres_test.go +++ b/postgres_test.go @@ -66,6 +66,7 @@ type TestJSON struct { } var ( + pgUrl string pgHost string pgPort uint pgUsername string @@ -82,10 +83,11 @@ var ( // ----------------------------------------------------------------------------- func init() { + flag.StringVar(&pgUrl, "url", "", "Specifies the Postgres URL.") flag.StringVar(&pgHost, "host", "127.0.0.1", "Specifies the Postgres server host. (Defaults to '127.0.0.1')") flag.UintVar(&pgPort, "port", 5432, "Specifies the Postgres server port. (Defaults to 5432)") flag.StringVar(&pgUsername, "user", "postgres", "Specifies the user name. (Defaults to 'postgres')") - flag.StringVar(&pgPassword, "password", "", "Specifies the user passwonrd.") + flag.StringVar(&pgPassword, "password", "", "Specifies the user password.") flag.StringVar(&pgDatabaseName, "db", "", "Specifies the database name.") testJSON = TestJSON{ @@ -102,56 +104,55 @@ func init() { // ----------------------------------------------------------------------------- func TestPostgres(t *testing.T) { + var db *postgres.Database + var err error + // Parse and check command-line parameters flag.Parse() checkSettings(t) + ctx := context.Background() + // Create database driver - db, err := postgres.New(context.Background(), postgres.Options{ - Host: pgHost, - Port: uint16(pgPort), - User: pgUsername, - Password: pgPassword, - Name: pgDatabaseName, - }) + if len(pgUrl) > 0 { + db, err = postgres.NewFromURL(ctx, pgUrl) + } else { + db, err = postgres.New(ctx, postgres.Options{ + Host: pgHost, + Port: uint16(pgPort), + User: pgUsername, + Password: pgPassword, + Name: pgDatabaseName, + }) + } if err != nil { t.Fatalf("%v", err.Error()) } - // We comment the next defer line because we want to do a clean database pool shutdown on errors and - // calling fatal exits the process. - // defer db.Close() - - ctx := context.Background() + defer db.Close() t.Log("Creating test table") err = createTestTable(ctx, db) if err != nil { - db.Close() t.Fatalf("%v", err.Error()) } t.Log("Inserting test data") err = insertTestData(ctx, db) if err != nil { - db.Close() t.Fatalf("%v", err.Error()) } t.Log("Reading test data") err = readTestData(ctx, db) if err != nil { - db.Close() t.Fatalf("%v", err.Error()) } t.Log("Reading test data (multi-row)") err = readMultiTestData(ctx, db) if err != nil { - db.Close() t.Fatalf("%v", err.Error()) } - - db.Close() } // -----------------------------------------------------------------------------