Skip to content

Commit

Permalink
refactor: create a generic store and stub out dialect queries (#477)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman authored Mar 14, 2023
1 parent 33106fc commit c462979
Show file tree
Hide file tree
Showing 18 changed files with 671 additions and 493 deletions.
365 changes: 21 additions & 344 deletions dialect.go
Original file line number Diff line number Diff line change
@@ -1,364 +1,41 @@
package goose

import (
"database/sql"
"fmt"

"github.com/pressly/goose/v3/internal/dialect"
)

// SQLDialect abstracts the details of specific SQL dialects
// for goose's few SQL specific statements
type SQLDialect interface {
createVersionTableSQL() string // sql string to create the db version table
insertVersionSQL() string // sql string to insert the initial version table row
deleteVersionSQL() string // sql string to delete version
migrationSQL() string // sql string to retrieve migrations
dbVersionQuery(db *sql.DB) (*sql.Rows, error)
func init() {
store, _ = dialect.NewStore(dialect.Postgres, TableName())
}

var dialect SQLDialect = &PostgresDialect{}

// GetDialect gets the SQLDialect
func GetDialect() SQLDialect {
return dialect
}
var store dialect.Store

// SetDialect sets the SQLDialect
func SetDialect(d string) error {
switch d {
// SetDialect sets the dialect to use for the goose package.
func SetDialect(s string) error {
var d dialect.Dialect
switch s {
case "postgres", "pgx":
dialect = &PostgresDialect{}
d = dialect.Postgres
case "mysql":
dialect = &MySQLDialect{}
d = dialect.Mysql
case "sqlite3", "sqlite":
dialect = &Sqlite3Dialect{}
d = dialect.Sqlite3
case "mssql":
dialect = &SqlServerDialect{}
d = dialect.Sqlserver
case "redshift":
dialect = &RedshiftDialect{}
d = dialect.Redshift
case "tidb":
dialect = &TiDBDialect{}
d = dialect.Tidb
case "clickhouse":
dialect = &ClickHouseDialect{}
d = dialect.Clickhouse
case "vertica":
dialect = &VerticaDialect{}
d = dialect.Vertica
default:
return fmt.Errorf("%q: unknown dialect", d)
}

return nil
}

////////////////////////////
// Postgres
////////////////////////////

// PostgresDialect struct.
type PostgresDialect struct{}

func (pg PostgresDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id serial NOT NULL,
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
);`, TableName())
}

func (pg PostgresDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName())
}

func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}

return rows, err
}

func (m PostgresDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName())
}

func (pg PostgresDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName())
}

////////////////////////////
// MySQL
////////////////////////////

// MySQLDialect struct.
type MySQLDialect struct{}

func (m MySQLDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id serial NOT NULL,
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
);`, TableName())
}

func (m MySQLDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
}

func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}

return rows, err
}

func (m MySQLDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}

func (m MySQLDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
}

////////////////////////////
// MSSQL
////////////////////////////

// SqlServerDialect struct.
type SqlServerDialect struct{}

func (m SqlServerDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id INT NOT NULL IDENTITY(1,1) PRIMARY KEY,
version_id BIGINT NOT NULL,
is_applied BIT NOT NULL,
tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP
);`, TableName())
}

func (m SqlServerDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2);", TableName())
}

func (m SqlServerDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}

return rows, err
}

func (m SqlServerDialect) migrationSQL() string {
const tpl = `
WITH Migrations AS
(
SELECT tstamp, is_applied,
ROW_NUMBER() OVER (ORDER BY tstamp) AS 'RowNumber'
FROM %s
WHERE version_id=@p1
)
SELECT tstamp, is_applied
FROM Migrations
WHERE RowNumber BETWEEN 1 AND 2
ORDER BY tstamp DESC
`
return fmt.Sprintf(tpl, TableName())
}

func (m SqlServerDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=@p1;", TableName())
}

////////////////////////////
// sqlite3
////////////////////////////

// Sqlite3Dialect struct.
type Sqlite3Dialect struct{}

func (m Sqlite3Dialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id INTEGER PRIMARY KEY AUTOINCREMENT,
version_id INTEGER NOT NULL,
is_applied INTEGER NOT NULL,
tstamp TIMESTAMP DEFAULT (datetime('now'))
);`, TableName())
}

func (m Sqlite3Dialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
}

func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}

return rows, err
}

func (m Sqlite3Dialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}

func (m Sqlite3Dialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
}

////////////////////////////
// Redshift
////////////////////////////

// RedshiftDialect struct.
type RedshiftDialect struct{}

func (rs RedshiftDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id integer NOT NULL identity(1, 1),
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default sysdate,
PRIMARY KEY(id)
);`, TableName())
}

func (rs RedshiftDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName())
}

func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}

return rows, err
}

func (m RedshiftDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName())
}

func (rs RedshiftDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName())
}

////////////////////////////
// TiDB
////////////////////////////

// TiDBDialect struct.
type TiDBDialect struct{}

func (m TiDBDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE,
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
);`, TableName())
}

func (m TiDBDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
}

func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}

return rows, err
}

func (m TiDBDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}

func (m TiDBDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
}

////////////////////////////
// ClickHouse
////////////////////////////

// ClickHouseDialect struct.
type ClickHouseDialect struct{}

func (m ClickHouseDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (
version_id Int64,
is_applied UInt8,
date Date default now(),
tstamp DateTime default now()
)
ENGINE = MergeTree()
ORDER BY (date)`, TableName())
}

func (m ClickHouseDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY version_id DESC", TableName()))
if err != nil {
return nil, err
return fmt.Errorf("%q: unknown dialect", s)
}
return rows, err
}

func (m ClickHouseDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)", TableName())
}

func (m ClickHouseDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id = $1 ORDER BY tstamp DESC LIMIT 1", TableName())
}

func (m ClickHouseDialect) deleteVersionSQL() string {
return fmt.Sprintf("ALTER TABLE %s DELETE WHERE version_id = $1 SETTINGS mutations_sync = 2", TableName())
}

////////////////////////////
// Vertica
////////////////////////////

// VerticaDialect struct.
type VerticaDialect struct{}

func (v VerticaDialect) createVersionTableSQL() string {
return fmt.Sprintf(`CREATE TABLE %s (
id identity(1,1) NOT NULL,
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
);`, TableName())
}

func (v VerticaDialect) insertVersionSQL() string {
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName())
}

func (v VerticaDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName()))
if err != nil {
return nil, err
}

return rows, err
}

func (m VerticaDialect) migrationSQL() string {
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName())
}

func (v VerticaDialect) deleteVersionSQL() string {
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName())
var err error
store, err = dialect.NewStore(d, TableName())
return err
}
Loading

0 comments on commit c462979

Please sign in to comment.