Skip to content

Commit

Permalink
feat: add WithAutoCreateDatabase option
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Jul 28, 2022
1 parent 546111f commit 8bf4958
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 11 deletions.
3 changes: 1 addition & 2 deletions ch/chschema/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ import (
)

const (
discardUnknownColumnsFlag = internal.Flag(1) << iota
columnarFlag
columnarFlag = internal.Flag(1) << iota
afterScanBlockHookFlag
)

Expand Down
16 changes: 13 additions & 3 deletions ch/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ import (

const (
discardUnknownColumnsFlag internal.Flag = 1 << iota
autoCreateDatabaseFlag
)

type Config struct {
chpool.Config

Compression bool

Network string
Addr string
User string
Password string
Expand All @@ -42,6 +42,11 @@ type Config struct {
MaxRetryBackoff time.Duration
}

func (cfg *Config) clone() *Config {
clone := *cfg
return &clone
}

func (cfg *Config) netDialer() *net.Dialer {
return &net.Dialer{
Timeout: cfg.DialTimeout,
Expand All @@ -62,7 +67,6 @@ func defaultConfig() *Config {

Compression: true,

Network: "tcp",
Addr: "localhost:9000",
User: "default",
Database: "default",
Expand Down Expand Up @@ -93,7 +97,13 @@ func WithCompression(enabled bool) Option {
}
}

// WithAddr configures TCP host:port or Unix socket depending on Network.
func WithAutoCreateDatabase(enabled bool) Option {
return func(db *DB) {
db.flags.Set(autoCreateDatabaseFlag)
}
}

// WithAddr configures TCP host:port.
func WithAddr(addr string) Option {
return func(db *DB) {
db.cfg.Addr = addr
Expand Down
42 changes: 37 additions & 5 deletions ch/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,21 @@ type DB struct {
}

func Connect(opts ...Option) *DB {
db := &DB{
cfg: defaultConfig(),
db := newDB(defaultConfig(), opts...)
if db.flags.Has(autoCreateDatabaseFlag) {
db.autoCreateDatabase()
}
return db
}

func newDB(cfg *Config, opts ...Option) *DB {
db := &DB{
cfg: cfg,
}
for _, opt := range opts {
opt(db)
}
db.pool = newConnPool(db.cfg)

return db
}

Expand All @@ -53,12 +59,12 @@ func newConnPool(cfg *Config) *chpool.ConnPool {
if cfg.TLSConfig != nil {
return tls.DialWithDialer(
cfg.netDialer(),
cfg.Network,
"tcp",
cfg.Addr,
cfg.TLSConfig,
)
}
return cfg.netDialer().DialContext(ctx, cfg.Network, cfg.Addr)
return cfg.netDialer().DialContext(ctx, "tcp", cfg.Addr)
}
return chpool.New(&poolcfg)
}
Expand Down Expand Up @@ -106,6 +112,32 @@ func (db *DB) Stats() DBStats {
}
}

func (db *DB) autoCreateDatabase() {
ctx := context.Background()

switch err := db.Ping(ctx); err := err.(type) {
case nil: // all is good
return
case *Error:
if err.Code != 81 { // 81 - database does not exist
return
}
default:
// ignore the error
return
}

cfg := db.cfg.clone()
cfg.Database = ""

tmp := newDB(cfg)
defer tmp.Close()

if _, err := tmp.Exec("CREATE DATABASE IF NOT EXISTS ?", Ident(db.cfg.Database)); err != nil {
internal.Logger.Printf("create database %q failed: %s", db.cfg.Database, err)
}
}

func (db *DB) getConn(ctx context.Context) (*chpool.Conn, error) {
cn, err := db.pool.Get(ctx)
if err != nil {
Expand Down
26 changes: 25 additions & 1 deletion ch/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func chDB(opts ...ch.Option) *ch.DB {
dsn = "clickhouse://localhost:9000/test?sslmode=disable"
}

opts = append(opts, ch.WithDSN(dsn))
opts = append(opts, ch.WithDSN(dsn), ch.WithAutoCreateDatabase(true))
db := ch.Connect(opts...)
db.AddQueryHook(chdebug.NewQueryHook(
chdebug.WithEnabled(false),
Expand All @@ -32,6 +32,30 @@ func chDB(opts ...ch.Option) *ch.DB {
return db
}

func TestAutoCreateDatabase(t *testing.T) {
ctx := context.Background()
dbName := "auto_create_database"

{
db := ch.Connect()
defer db.Close()

_, err := db.Exec("DROP DATABASE IF EXISTS ?", ch.Ident(dbName))
require.NoError(t, err)
}

{
db := ch.Connect(
ch.WithDatabase(dbName),
ch.WithAutoCreateDatabase(true),
)
defer db.Close()

err := db.Ping(ctx)
require.NoError(t, err)
}
}

func TestCHError(t *testing.T) {
ctx := context.Background()

Expand Down

0 comments on commit 8bf4958

Please sign in to comment.