-
Notifications
You must be signed in to change notification settings - Fork 518
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: create a generic store and stub out dialect queries (#477)
- Loading branch information
Showing
18 changed files
with
671 additions
and
493 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,364 +1,41 @@ | ||
package goose | ||
|
||
import ( | ||
"database/sql" | ||
"fmt" | ||
|
||
"github.com/pressly/goose/v3/internal/dialect" | ||
) | ||
|
||
// SQLDialect abstracts the details of specific SQL dialects | ||
// for goose's few SQL specific statements | ||
type SQLDialect interface { | ||
createVersionTableSQL() string // sql string to create the db version table | ||
insertVersionSQL() string // sql string to insert the initial version table row | ||
deleteVersionSQL() string // sql string to delete version | ||
migrationSQL() string // sql string to retrieve migrations | ||
dbVersionQuery(db *sql.DB) (*sql.Rows, error) | ||
func init() { | ||
store, _ = dialect.NewStore(dialect.Postgres, TableName()) | ||
} | ||
|
||
var dialect SQLDialect = &PostgresDialect{} | ||
|
||
// GetDialect gets the SQLDialect | ||
func GetDialect() SQLDialect { | ||
return dialect | ||
} | ||
var store dialect.Store | ||
|
||
// SetDialect sets the SQLDialect | ||
func SetDialect(d string) error { | ||
switch d { | ||
// SetDialect sets the dialect to use for the goose package. | ||
func SetDialect(s string) error { | ||
var d dialect.Dialect | ||
switch s { | ||
case "postgres", "pgx": | ||
dialect = &PostgresDialect{} | ||
d = dialect.Postgres | ||
case "mysql": | ||
dialect = &MySQLDialect{} | ||
d = dialect.Mysql | ||
case "sqlite3", "sqlite": | ||
dialect = &Sqlite3Dialect{} | ||
d = dialect.Sqlite3 | ||
case "mssql": | ||
dialect = &SqlServerDialect{} | ||
d = dialect.Sqlserver | ||
case "redshift": | ||
dialect = &RedshiftDialect{} | ||
d = dialect.Redshift | ||
case "tidb": | ||
dialect = &TiDBDialect{} | ||
d = dialect.Tidb | ||
case "clickhouse": | ||
dialect = &ClickHouseDialect{} | ||
d = dialect.Clickhouse | ||
case "vertica": | ||
dialect = &VerticaDialect{} | ||
d = dialect.Vertica | ||
default: | ||
return fmt.Errorf("%q: unknown dialect", d) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
//////////////////////////// | ||
// Postgres | ||
//////////////////////////// | ||
|
||
// PostgresDialect struct. | ||
type PostgresDialect struct{} | ||
|
||
func (pg PostgresDialect) createVersionTableSQL() string { | ||
return fmt.Sprintf(`CREATE TABLE %s ( | ||
id serial NOT NULL, | ||
version_id bigint NOT NULL, | ||
is_applied boolean NOT NULL, | ||
tstamp timestamp NULL default now(), | ||
PRIMARY KEY(id) | ||
);`, TableName()) | ||
} | ||
|
||
func (pg PostgresDialect) insertVersionSQL() string { | ||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName()) | ||
} | ||
|
||
func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { | ||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return rows, err | ||
} | ||
|
||
func (m PostgresDialect) migrationSQL() string { | ||
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName()) | ||
} | ||
|
||
func (pg PostgresDialect) deleteVersionSQL() string { | ||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName()) | ||
} | ||
|
||
//////////////////////////// | ||
// MySQL | ||
//////////////////////////// | ||
|
||
// MySQLDialect struct. | ||
type MySQLDialect struct{} | ||
|
||
func (m MySQLDialect) createVersionTableSQL() string { | ||
return fmt.Sprintf(`CREATE TABLE %s ( | ||
id serial NOT NULL, | ||
version_id bigint NOT NULL, | ||
is_applied boolean NOT NULL, | ||
tstamp timestamp NULL default now(), | ||
PRIMARY KEY(id) | ||
);`, TableName()) | ||
} | ||
|
||
func (m MySQLDialect) insertVersionSQL() string { | ||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) | ||
} | ||
|
||
func (m MySQLDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { | ||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return rows, err | ||
} | ||
|
||
func (m MySQLDialect) migrationSQL() string { | ||
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) | ||
} | ||
|
||
func (m MySQLDialect) deleteVersionSQL() string { | ||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) | ||
} | ||
|
||
//////////////////////////// | ||
// MSSQL | ||
//////////////////////////// | ||
|
||
// SqlServerDialect struct. | ||
type SqlServerDialect struct{} | ||
|
||
func (m SqlServerDialect) createVersionTableSQL() string { | ||
return fmt.Sprintf(`CREATE TABLE %s ( | ||
id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, | ||
version_id BIGINT NOT NULL, | ||
is_applied BIT NOT NULL, | ||
tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP | ||
);`, TableName()) | ||
} | ||
|
||
func (m SqlServerDialect) insertVersionSQL() string { | ||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2);", TableName()) | ||
} | ||
|
||
func (m SqlServerDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { | ||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY id DESC", TableName())) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return rows, err | ||
} | ||
|
||
func (m SqlServerDialect) migrationSQL() string { | ||
const tpl = ` | ||
WITH Migrations AS | ||
( | ||
SELECT tstamp, is_applied, | ||
ROW_NUMBER() OVER (ORDER BY tstamp) AS 'RowNumber' | ||
FROM %s | ||
WHERE version_id=@p1 | ||
) | ||
SELECT tstamp, is_applied | ||
FROM Migrations | ||
WHERE RowNumber BETWEEN 1 AND 2 | ||
ORDER BY tstamp DESC | ||
` | ||
return fmt.Sprintf(tpl, TableName()) | ||
} | ||
|
||
func (m SqlServerDialect) deleteVersionSQL() string { | ||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=@p1;", TableName()) | ||
} | ||
|
||
//////////////////////////// | ||
// sqlite3 | ||
//////////////////////////// | ||
|
||
// Sqlite3Dialect struct. | ||
type Sqlite3Dialect struct{} | ||
|
||
func (m Sqlite3Dialect) createVersionTableSQL() string { | ||
return fmt.Sprintf(`CREATE TABLE %s ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
version_id INTEGER NOT NULL, | ||
is_applied INTEGER NOT NULL, | ||
tstamp TIMESTAMP DEFAULT (datetime('now')) | ||
);`, TableName()) | ||
} | ||
|
||
func (m Sqlite3Dialect) insertVersionSQL() string { | ||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) | ||
} | ||
|
||
func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { | ||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return rows, err | ||
} | ||
|
||
func (m Sqlite3Dialect) migrationSQL() string { | ||
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) | ||
} | ||
|
||
func (m Sqlite3Dialect) deleteVersionSQL() string { | ||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) | ||
} | ||
|
||
//////////////////////////// | ||
// Redshift | ||
//////////////////////////// | ||
|
||
// RedshiftDialect struct. | ||
type RedshiftDialect struct{} | ||
|
||
func (rs RedshiftDialect) createVersionTableSQL() string { | ||
return fmt.Sprintf(`CREATE TABLE %s ( | ||
id integer NOT NULL identity(1, 1), | ||
version_id bigint NOT NULL, | ||
is_applied boolean NOT NULL, | ||
tstamp timestamp NULL default sysdate, | ||
PRIMARY KEY(id) | ||
);`, TableName()) | ||
} | ||
|
||
func (rs RedshiftDialect) insertVersionSQL() string { | ||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2);", TableName()) | ||
} | ||
|
||
func (rs RedshiftDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { | ||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return rows, err | ||
} | ||
|
||
func (m RedshiftDialect) migrationSQL() string { | ||
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1", TableName()) | ||
} | ||
|
||
func (rs RedshiftDialect) deleteVersionSQL() string { | ||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=$1;", TableName()) | ||
} | ||
|
||
//////////////////////////// | ||
// TiDB | ||
//////////////////////////// | ||
|
||
// TiDBDialect struct. | ||
type TiDBDialect struct{} | ||
|
||
func (m TiDBDialect) createVersionTableSQL() string { | ||
return fmt.Sprintf(`CREATE TABLE %s ( | ||
id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT UNIQUE, | ||
version_id bigint NOT NULL, | ||
is_applied boolean NOT NULL, | ||
tstamp timestamp NULL default now(), | ||
PRIMARY KEY(id) | ||
);`, TableName()) | ||
} | ||
|
||
func (m TiDBDialect) insertVersionSQL() string { | ||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) | ||
} | ||
|
||
func (m TiDBDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { | ||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return rows, err | ||
} | ||
|
||
func (m TiDBDialect) migrationSQL() string { | ||
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) | ||
} | ||
|
||
func (m TiDBDialect) deleteVersionSQL() string { | ||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) | ||
} | ||
|
||
//////////////////////////// | ||
// ClickHouse | ||
//////////////////////////// | ||
|
||
// ClickHouseDialect struct. | ||
type ClickHouseDialect struct{} | ||
|
||
func (m ClickHouseDialect) createVersionTableSQL() string { | ||
return fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s ( | ||
version_id Int64, | ||
is_applied UInt8, | ||
date Date default now(), | ||
tstamp DateTime default now() | ||
) | ||
ENGINE = MergeTree() | ||
ORDER BY (date)`, TableName()) | ||
} | ||
|
||
func (m ClickHouseDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { | ||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied FROM %s ORDER BY version_id DESC", TableName())) | ||
if err != nil { | ||
return nil, err | ||
return fmt.Errorf("%q: unknown dialect", s) | ||
} | ||
return rows, err | ||
} | ||
|
||
func (m ClickHouseDialect) insertVersionSQL() string { | ||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)", TableName()) | ||
} | ||
|
||
func (m ClickHouseDialect) migrationSQL() string { | ||
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id = $1 ORDER BY tstamp DESC LIMIT 1", TableName()) | ||
} | ||
|
||
func (m ClickHouseDialect) deleteVersionSQL() string { | ||
return fmt.Sprintf("ALTER TABLE %s DELETE WHERE version_id = $1 SETTINGS mutations_sync = 2", TableName()) | ||
} | ||
|
||
//////////////////////////// | ||
// Vertica | ||
//////////////////////////// | ||
|
||
// VerticaDialect struct. | ||
type VerticaDialect struct{} | ||
|
||
func (v VerticaDialect) createVersionTableSQL() string { | ||
return fmt.Sprintf(`CREATE TABLE %s ( | ||
id identity(1,1) NOT NULL, | ||
version_id bigint NOT NULL, | ||
is_applied boolean NOT NULL, | ||
tstamp timestamp NULL default now(), | ||
PRIMARY KEY(id) | ||
);`, TableName()) | ||
} | ||
|
||
func (v VerticaDialect) insertVersionSQL() string { | ||
return fmt.Sprintf("INSERT INTO %s (version_id, is_applied) VALUES (?, ?);", TableName()) | ||
} | ||
|
||
func (v VerticaDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) { | ||
rows, err := db.Query(fmt.Sprintf("SELECT version_id, is_applied from %s ORDER BY id DESC", TableName())) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return rows, err | ||
} | ||
|
||
func (m VerticaDialect) migrationSQL() string { | ||
return fmt.Sprintf("SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1", TableName()) | ||
} | ||
|
||
func (v VerticaDialect) deleteVersionSQL() string { | ||
return fmt.Sprintf("DELETE FROM %s WHERE version_id=?;", TableName()) | ||
var err error | ||
store, err = dialect.NewStore(d, TableName()) | ||
return err | ||
} |
Oops, something went wrong.