Skip to content

Commit

Permalink
fix: use global table name in queries (#515)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed May 8, 2023
1 parent 7ce30b7 commit 8759239
Show file tree
Hide file tree
Showing 17 changed files with 139 additions and 161 deletions.
4 changes: 2 additions & 2 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

func init() {
store, _ = dialect.NewStore(dialect.Postgres, TableName())
store, _ = dialect.NewStore(dialect.Postgres)
}

var store dialect.Store
Expand Down Expand Up @@ -36,6 +36,6 @@ func SetDialect(s string) error {
return fmt.Errorf("%q: unknown dialect", s)
}
var err error
store, err = dialect.NewStore(d, TableName())
store, err = dialect.NewStore(d)
return err
}
24 changes: 11 additions & 13 deletions internal/dialect/dialectquery/clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ package dialectquery

import "fmt"

type Clickhouse struct {
Table string
}
type Clickhouse struct{}

var _ Querier = (*Clickhouse)(nil)

func (c *Clickhouse) CreateTable() string {
func (c *Clickhouse) CreateTable(tableName string) string {
q := `CREATE TABLE IF NOT EXISTS %s (
version_id Int64,
is_applied UInt8,
Expand All @@ -17,25 +15,25 @@ func (c *Clickhouse) CreateTable() string {
)
ENGINE = MergeTree()
ORDER BY (date)`
return fmt.Sprintf(q, c.Table)
return fmt.Sprintf(q, tableName)
}

func (c *Clickhouse) InsertVersion() string {
func (c *Clickhouse) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)`
return fmt.Sprintf(q, c.Table)
return fmt.Sprintf(q, tableName)
}

func (c *Clickhouse) DeleteVersion() string {
func (c *Clickhouse) DeleteVersion(tableName string) string {
q := `ALTER TABLE %s DELETE WHERE version_id = $1 SETTINGS mutations_sync = 2`
return fmt.Sprintf(q, c.Table)
return fmt.Sprintf(q, tableName)
}

func (c *Clickhouse) GetMigrationByVersion() string {
func (c *Clickhouse) GetMigrationByVersion(tableName string) string {
q := `SELECT tstamp, is_applied FROM %s WHERE version_id = $1 ORDER BY tstamp DESC LIMIT 1`
return fmt.Sprintf(q, c.Table)
return fmt.Sprintf(q, tableName)
}

func (c *Clickhouse) ListMigrations() string {
func (c *Clickhouse) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied FROM %s ORDER BY version_id DESC`
return fmt.Sprintf(q, c.Table)
return fmt.Sprintf(q, tableName)
}
10 changes: 5 additions & 5 deletions internal/dialect/dialectquery/dialectquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,25 @@ package dialectquery
// specific query.
type Querier interface {
// CreateTable returns the SQL query string to create the db version table.
CreateTable() string
CreateTable(tableName string) string

// InsertVersion returns the SQL query string to insert a new version into
// the db version table.
InsertVersion() string
InsertVersion(tableName string) string

// DeleteVersion returns the SQL query string to delete a version from
// the db version table.
DeleteVersion() string
DeleteVersion(tableName string) string

// GetMigrationByVersion returns the SQL query string to get a single
// migration by version.
//
// The query should return the timestamp and is_applied columns.
GetMigrationByVersion() string
GetMigrationByVersion(tableName string) string

// ListMigrations returns the SQL query string to list all migrations in
// descending order by id.
//
// The query should return the version_id and is_applied columns.
ListMigrations() string
ListMigrations(tableName string) string
}
24 changes: 11 additions & 13 deletions internal/dialect/dialectquery/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,37 @@ package dialectquery

import "fmt"

type Mysql struct {
Table string
}
type Mysql struct{}

var _ Querier = (*Mysql)(nil)

func (m *Mysql) CreateTable() string {
func (m *Mysql) CreateTable(tableName string) string {
q := `CREATE TABLE %s (
id serial NOT NULL,
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
)`
return fmt.Sprintf(q, m.Table)
return fmt.Sprintf(q, tableName)
}

func (m *Mysql) InsertVersion() string {
func (m *Mysql) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`
return fmt.Sprintf(q, m.Table)
return fmt.Sprintf(q, tableName)
}

func (m *Mysql) DeleteVersion() string {
func (m *Mysql) DeleteVersion(tableName string) string {
q := `DELETE FROM %s WHERE version_id=?`
return fmt.Sprintf(q, m.Table)
return fmt.Sprintf(q, tableName)
}

func (m *Mysql) GetMigrationByVersion() string {
func (m *Mysql) GetMigrationByVersion(tableName string) string {
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1`
return fmt.Sprintf(q, m.Table)
return fmt.Sprintf(q, tableName)
}

func (m *Mysql) ListMigrations() string {
func (m *Mysql) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, m.Table)
return fmt.Sprintf(q, tableName)
}
24 changes: 11 additions & 13 deletions internal/dialect/dialectquery/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,37 @@ package dialectquery

import "fmt"

type Postgres struct {
Table string
}
type Postgres struct{}

var _ Querier = (*Postgres)(nil)

func (p *Postgres) CreateTable() string {
func (p *Postgres) CreateTable(tableName string) string {
q := `CREATE TABLE %s (
id serial NOT NULL,
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
)`
return fmt.Sprintf(q, p.Table)
return fmt.Sprintf(q, tableName)
}

func (p *Postgres) InsertVersion() string {
func (p *Postgres) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)`
return fmt.Sprintf(q, p.Table)
return fmt.Sprintf(q, tableName)
}

func (p *Postgres) DeleteVersion() string {
func (p *Postgres) DeleteVersion(tableName string) string {
q := `DELETE FROM %s WHERE version_id=$1`
return fmt.Sprintf(q, p.Table)
return fmt.Sprintf(q, tableName)
}

func (p *Postgres) GetMigrationByVersion() string {
func (p *Postgres) GetMigrationByVersion(tableName string) string {
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1`
return fmt.Sprintf(q, p.Table)
return fmt.Sprintf(q, tableName)
}

func (p *Postgres) ListMigrations() string {
func (p *Postgres) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, p.Table)
return fmt.Sprintf(q, tableName)
}
24 changes: 11 additions & 13 deletions internal/dialect/dialectquery/redshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,37 @@ package dialectquery

import "fmt"

type Redshift struct {
Table string
}
type Redshift struct{}

var _ Querier = (*Redshift)(nil)

func (r *Redshift) CreateTable() string {
func (r *Redshift) CreateTable(tableName string) string {
q := `CREATE TABLE %s (
id integer NOT NULL identity(1, 1),
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default sysdate,
PRIMARY KEY(id)
)`
return fmt.Sprintf(q, r.Table)
return fmt.Sprintf(q, tableName)
}

func (r *Redshift) InsertVersion() string {
func (r *Redshift) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES ($1, $2)`
return fmt.Sprintf(q, r.Table)
return fmt.Sprintf(q, tableName)
}

func (r *Redshift) DeleteVersion() string {
func (r *Redshift) DeleteVersion(tableName string) string {
q := `DELETE FROM %s WHERE version_id=$1`
return fmt.Sprintf(q, r.Table)
return fmt.Sprintf(q, tableName)
}

func (r *Redshift) GetMigrationByVersion() string {
func (r *Redshift) GetMigrationByVersion(tableName string) string {
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=$1 ORDER BY tstamp DESC LIMIT 1`
return fmt.Sprintf(q, r.Table)
return fmt.Sprintf(q, tableName)
}

func (r *Redshift) ListMigrations() string {
func (r *Redshift) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, r.Table)
return fmt.Sprintf(q, tableName)
}
24 changes: 11 additions & 13 deletions internal/dialect/dialectquery/sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,36 @@ package dialectquery

import "fmt"

type Sqlite3 struct {
Table string
}
type Sqlite3 struct{}

var _ Querier = (*Sqlite3)(nil)

func (s *Sqlite3) CreateTable() string {
func (s *Sqlite3) CreateTable(tableName string) string {
q := `CREATE TABLE %s (
id INTEGER PRIMARY KEY AUTOINCREMENT,
version_id INTEGER NOT NULL,
is_applied INTEGER NOT NULL,
tstamp TIMESTAMP DEFAULT (datetime('now'))
)`
return fmt.Sprintf(q, s.Table)
return fmt.Sprintf(q, tableName)
}

func (s *Sqlite3) InsertVersion() string {
func (s *Sqlite3) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES (?, ?)`
return fmt.Sprintf(q, s.Table)
return fmt.Sprintf(q, tableName)
}

func (s *Sqlite3) DeleteVersion() string {
func (s *Sqlite3) DeleteVersion(tableName string) string {
q := `DELETE FROM %s WHERE version_id=?`
return fmt.Sprintf(q, s.Table)
return fmt.Sprintf(q, tableName)
}

func (s *Sqlite3) GetMigrationByVersion() string {
func (s *Sqlite3) GetMigrationByVersion(tableName string) string {
q := `SELECT tstamp, is_applied FROM %s WHERE version_id=? ORDER BY tstamp DESC LIMIT 1`
return fmt.Sprintf(q, s.Table)
return fmt.Sprintf(q, tableName)
}

func (s *Sqlite3) ListMigrations() string {
func (s *Sqlite3) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied from %s ORDER BY id DESC`
return fmt.Sprintf(q, s.Table)
return fmt.Sprintf(q, tableName)
}
24 changes: 11 additions & 13 deletions internal/dialect/dialectquery/sqlserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,31 @@ package dialectquery

import "fmt"

type Sqlserver struct {
Table string
}
type Sqlserver struct{}

var _ Querier = (*Sqlserver)(nil)

func (s *Sqlserver) CreateTable() string {
func (s *Sqlserver) CreateTable(tableName string) string {
q := `CREATE TABLE %s (
id INT NOT NULL IDENTITY(1,1) PRIMARY KEY,
version_id BIGINT NOT NULL,
is_applied BIT NOT NULL,
tstamp DATETIME NULL DEFAULT CURRENT_TIMESTAMP
)`
return fmt.Sprintf(q, s.Table)
return fmt.Sprintf(q, tableName)
}

func (s *Sqlserver) InsertVersion() string {
func (s *Sqlserver) InsertVersion(tableName string) string {
q := `INSERT INTO %s (version_id, is_applied) VALUES (@p1, @p2)`
return fmt.Sprintf(q, s.Table)
return fmt.Sprintf(q, tableName)
}

func (s *Sqlserver) DeleteVersion() string {
func (s *Sqlserver) DeleteVersion(tableName string) string {
q := `DELETE FROM %s WHERE version_id=@p1`
return fmt.Sprintf(q, s.Table)
return fmt.Sprintf(q, tableName)
}

func (s *Sqlserver) GetMigrationByVersion() string {
func (s *Sqlserver) GetMigrationByVersion(tableName string) string {
q := `
WITH Migrations AS
(
Expand All @@ -42,10 +40,10 @@ FROM Migrations
WHERE RowNumber BETWEEN 1 AND 2
ORDER BY tstamp DESC
`
return fmt.Sprintf(q, s.Table)
return fmt.Sprintf(q, tableName)
}

func (s *Sqlserver) ListMigrations() string {
func (s *Sqlserver) ListMigrations(tableName string) string {
q := `SELECT version_id, is_applied FROM %s ORDER BY id DESC`
return fmt.Sprintf(q, s.Table)
return fmt.Sprintf(q, tableName)
}
Loading

0 comments on commit 8759239

Please sign in to comment.