Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,4 @@ func main() {

## LICENSE

See `LICENSE` file for details.
See [LICENSE](/LICENSE) file for details.
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ module github.com/randlabs/go-postgres

go 1.19

require github.com/jackc/pgx/v5 v5.5.1
require github.com/jackc/pgx/v5 v5.5.4

require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/sync v0.5.0 // indirect
golang.org/x/text v0.14.0 // indirect
)
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand All @@ -14,8 +14,8 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE=
golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
Expand Down
102 changes: 94 additions & 8 deletions postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"context"
"errors"
"fmt"
"net/url"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -58,16 +61,19 @@ const (

// New creates a new postgresql database driver.
func New(ctx context.Context, opts Options) (*Database, error) {
var sslMode string

// Create database object
db := Database{}
db.err.mutex = sync.Mutex{}

// Setup basic configuration options
// Validate options
if len(opts.Host) == 0 {
return nil, errors.New("invalid host")
}
if len(opts.User) == 0 {
return nil, errors.New("invalid user name")
}
if len(opts.Name) == 0 {
return nil, errors.New("invalid database name")
}
sslMode := "disable"
switch opts.SSLMode {
case SSLModeDisable:
sslMode = "disable"
case SSLModeAllow:
sslMode = "prefer"
case SSLModeRequired:
Expand All @@ -76,6 +82,10 @@ func New(ctx context.Context, opts Options) (*Database, error) {
return nil, errors.New("invalid SSL mode")
}

// Create database object
db := Database{}
db.err.mutex = sync.Mutex{}

connString := fmt.Sprintf(
"host='%s' port=%d user='%s' password='%s' dbname='%s' sslmode=%s",
encodeDSN(opts.Host), opts.Port, encodeDSN(opts.User), encodeDSN(opts.Password), encodeDSN(opts.Name),
Expand Down Expand Up @@ -110,6 +120,82 @@ func New(ctx context.Context, opts Options) (*Database, error) {
return &db, nil
}

// NewFromURL creates a new postgresql database driver from an URL
func NewFromURL(ctx context.Context, rawUrl string) (*Database, error) {
opts := Options{}

u, err := url.ParseRequestURI(rawUrl)
if err != nil {
return nil, errors.New("invalid url provided")
}

// Check schema
if u.Scheme != "pg" && u.Scheme != "postgres" && u.Scheme != "postgresql" {
return nil, errors.New("invalid url schema")
}

// Check host name and port
opts.Host = u.Hostname()
if len(opts.Host) == 0 {
return nil, errors.New("invalid host")
}
s := u.Port()
if len(s) == 0 {
opts.Port = 5432
} else {
val, err2 := strconv.Atoi(s)
if err2 != nil || val < 1 || val > 65535 {
return nil, errors.New("invalid port")
}
opts.Port = uint16(val)
}

// Check user and password
if u.User == nil {
return nil, errors.New("invalid user name")
}
opts.User = u.User.Username()
if len(opts.User) == 0 {
return nil, errors.New("invalid user name")
}

// Check database name
if len(u.Path) < 1 || (!strings.HasPrefix(u.Path, "/")) || strings.Index(u.Path[1:], "/") >= 0 {
return nil, errors.New("invalid database name")
}
opts.Name = u.Path[1:]

// Check ssl mode
opts.SSLMode = SSLModeDisable
switch u.Query().Get("sslmode") {
case "allow":
opts.SSLMode = SSLModeAllow

case "required":
opts.SSLMode = SSLModeRequired

case "disabled":
fallthrough
case "":

default:
return nil, errors.New("invalid SSL mode")
}

// Check max connections count
s = u.Query().Get("maxconn")
if len(s) > 0 {
val, err2 := strconv.Atoi(s)
if err2 != nil || val < 0 {
return nil, errors.New("invalid max connections count")
}
opts.MaxConns = int32(val)
}

// Create
return New(ctx, opts)
}

// Close shutdown the connection pool
func (db *Database) Close() {
if db.pool != nil {
Expand Down
39 changes: 20 additions & 19 deletions postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type TestJSON struct {
}

var (
pgUrl string
pgHost string
pgPort uint
pgUsername string
Expand All @@ -82,10 +83,11 @@ var (
// -----------------------------------------------------------------------------

func init() {
flag.StringVar(&pgUrl, "url", "", "Specifies the Postgres URL.")
flag.StringVar(&pgHost, "host", "127.0.0.1", "Specifies the Postgres server host. (Defaults to '127.0.0.1')")
flag.UintVar(&pgPort, "port", 5432, "Specifies the Postgres server port. (Defaults to 5432)")
flag.StringVar(&pgUsername, "user", "postgres", "Specifies the user name. (Defaults to 'postgres')")
flag.StringVar(&pgPassword, "password", "", "Specifies the user passwonrd.")
flag.StringVar(&pgPassword, "password", "", "Specifies the user password.")
flag.StringVar(&pgDatabaseName, "db", "", "Specifies the database name.")

testJSON = TestJSON{
Expand All @@ -102,56 +104,55 @@ func init() {
// -----------------------------------------------------------------------------

func TestPostgres(t *testing.T) {
var db *postgres.Database
var err error

// Parse and check command-line parameters
flag.Parse()
checkSettings(t)

ctx := context.Background()

// Create database driver
db, err := postgres.New(context.Background(), postgres.Options{
Host: pgHost,
Port: uint16(pgPort),
User: pgUsername,
Password: pgPassword,
Name: pgDatabaseName,
})
if len(pgUrl) > 0 {
db, err = postgres.NewFromURL(ctx, pgUrl)
} else {
db, err = postgres.New(ctx, postgres.Options{
Host: pgHost,
Port: uint16(pgPort),
User: pgUsername,
Password: pgPassword,
Name: pgDatabaseName,
})
}
if err != nil {
t.Fatalf("%v", err.Error())
}
// We comment the next defer line because we want to do a clean database pool shutdown on errors and
// calling fatal exits the process.
// defer db.Close()

ctx := context.Background()
defer db.Close()

t.Log("Creating test table")
err = createTestTable(ctx, db)
if err != nil {
db.Close()
t.Fatalf("%v", err.Error())
}

t.Log("Inserting test data")
err = insertTestData(ctx, db)
if err != nil {
db.Close()
t.Fatalf("%v", err.Error())
}

t.Log("Reading test data")
err = readTestData(ctx, db)
if err != nil {
db.Close()
t.Fatalf("%v", err.Error())
}

t.Log("Reading test data (multi-row)")
err = readMultiTestData(ctx, db)
if err != nil {
db.Close()
t.Fatalf("%v", err.Error())
}

db.Close()
}

// -----------------------------------------------------------------------------
Expand Down