/
session.go
220 lines (192 loc) · 5.82 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
package db
import (
"database/sql"
"fmt"
"io/ioutil"
"github.com/quasoft/pgmig/mig"
// Import postgres DB driver
_ "github.com/lib/pq"
)
// Session holds the settings needed to reach one PostgreSQL database and,
// after a successful Connect, the open connection pool.
type Session struct {
	// Connection parameters handed to buildConnString by Connect.
	Host, Port, Database, Username, SslMode string
	// ChangelogName names the table in which migrations are recorded.
	ChangelogName string
	// Interactive is forwarded to getPassword when connecting.
	Interactive bool

	db *sql.DB // assigned by Connect; nil until then
}

// NewSession returns a zero-valued Session. Populate its exported fields
// before calling Connect.
func NewSession() *Session {
	return new(Session)
}
// Connect opens a connection pool to the database described by the session's
// fields and verifies that the server answers by running "SELECT 1".
//
// The password comes from getPassword(s.Interactive). On success the pool is
// stored in the session for use by the other methods; on failure the pool is
// closed and a wrapped error is returned.
func (s *Session) Connect() error {
	password := getPassword(s.Interactive)
	connStr := buildConnString(s.Host, s.Port, s.Database, s.Username, password, s.SslMode)

	db, err := sql.Open("postgres", connStr)
	if err != nil {
		return fmt.Errorf("could not open DB connection: %w", err)
	}

	// sql.Open may not establish a connection at all, so a real round trip
	// is needed to prove the server is reachable.
	var one int
	if err := db.QueryRow("SELECT 1;").Scan(&one); err != nil {
		_ = db.Close() // best effort; the ping failure is the error worth reporting
		return fmt.Errorf("could not ping DB: %w", err)
	}
	if one != 1 {
		// Report the odd result explicitly instead of a "<nil>" error.
		_ = db.Close()
		return fmt.Errorf("could not ping DB: unexpected result %d from SELECT 1", one)
	}

	s.db = db
	return nil
}
// Disconnect releases the connection pool opened by Connect. For a session
// that was never connected it does nothing and returns nil.
func (s *Session) Disconnect() error {
	if db := s.db; db != nil {
		return db.Close()
	}
	return nil
}
// EnsureChangelogExists creates the changelog table named by s.ChangelogName
// unless it already exists (CREATE TABLE IF NOT EXISTS). Each row holds a
// migration version (unique), its file name, the database user and timestamp
// filled in by column defaults, and a state flag that starts out false.
func (s *Session) EnsureChangelogExists() error {
	table := sanitizeIdentifier(s.ChangelogName)

	// TODO: Remove unused fields from table structure
	// %[1]s reuses the one table argument for the table and both constraint names.
	ddl := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS "%[1]s" (
id serial,
version integer NOT NULL,
file_name varchar(2048) NOT NULL,
applied_by varchar(100) NOT NULL DEFAULT CURRENT_USER,
date_time timestamp without time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
state bool NOT NULL DEFAULT false,
CONSTRAINT "%[1]s_pkey" PRIMARY KEY(id),
CONSTRAINT "%[1]s_version_unique" UNIQUE(version)
)`, table)

	_, err := s.db.Exec(ddl)
	return err
}
// insertLog adds a changelog row for migration m. Only version and file name
// are supplied, so state takes its column default (false) until the row is
// updated. The changelog has UNIQUE(version), so inserting an already
// recorded version returns an error.
//
// The table name is double-quoted, as in the CREATE TABLE statement and the
// other changelog queries. Without the quotes, PostgreSQL folds the identifier
// to lower case, and a ChangelogName with upper-case letters would miss the
// table that was created.
// NOTE(review): assumes sanitizeIdentifier returns an unquoted name, as its
// use inside "%s" elsewhere implies; confirm.
func (s *Session) insertLog(m mig.File) error {
	query := fmt.Sprintf(
		`INSERT INTO "%s" (version, file_name) VALUES($1, $2)`,
		sanitizeIdentifier(s.ChangelogName),
	)
	_, err := s.db.Exec(query, m.Ver, m.FileName)
	return err
}
// updateLog sets the state column of the changelog row for version migVer.
// state = true marks the migration as applied. If no row has that version,
// the UPDATE affects zero rows and no error is returned.
//
// The table name is double-quoted, matching the CREATE TABLE statement and the
// other changelog queries, so mixed-case changelog names resolve to the same
// table.
// NOTE(review): assumes sanitizeIdentifier returns an unquoted name; confirm.
func (s *Session) updateLog(migVer int, state bool) error {
	query := fmt.Sprintf(
		`UPDATE "%s" SET state = $1 WHERE version = $2`,
		sanitizeIdentifier(s.ChangelogName),
	)
	_, err := s.db.Exec(query, state, migVer)
	return err
}
// lastMigratedVer reports the highest version in the changelog table whose
// state is true, i.e. the newest migration recorded as applied. It yields 0
// when no such row exists.
func (s *Session) lastMigratedVer() (int, error) {
	var ver int
	row := s.db.QueryRow(fmt.Sprintf(
		`SELECT COALESCE(MAX(version), 0) FROM "%s" WHERE state = true`,
		sanitizeIdentifier(s.ChangelogName),
	))

	switch err := row.Scan(&ver); {
	case err == sql.ErrNoRows:
		// Defensive only: an aggregate query always produces one row.
		return 0, nil
	case err != nil:
		return 0, fmt.Errorf("could not get version of last migration from changelog table: %v", err)
	default:
		return ver, nil
	}
}
// failed reports whether the changelog holds a row for migVer whose state is
// still false, meaning the migration was recorded but never marked applied.
func (s *Session) failed(migVer int) (bool, error) {
	table := sanitizeIdentifier(s.ChangelogName)
	query := fmt.Sprintf(`SELECT COUNT(*) FROM "%s" WHERE state = false AND version = $1`, table)

	var matches int
	if err := s.db.QueryRow(query, migVer).Scan(&matches); err != nil {
		return false, fmt.Errorf("could not check in changelog %s if migration #%d failed: %v", s.ChangelogName, migVer, err)
	}
	return matches > 0, nil
}
// wasApplied reports whether the changelog holds a row for migVer whose state
// is true, meaning the migration is recorded as successfully applied.
func (s *Session) wasApplied(migVer int) (bool, error) {
	var n int
	err := s.db.
		QueryRow(
			fmt.Sprintf(`SELECT COUNT(*) FROM "%s" WHERE state = true AND version = $1`, sanitizeIdentifier(s.ChangelogName)),
			migVer,
		).
		Scan(&n)
	if err != nil {
		return false, fmt.Errorf("could not check in changelog %s if migration #%d was applied: %v", s.ChangelogName, migVer, err)
	}
	return n > 0, nil
}
// Apply runs the SQL script of migration m and records it in the changelog
// table, all within one database transaction.
//
// It inserts the changelog row for m.Ver, or reuses it when an earlier run left
// it in the failed state. It then executes the script and marks the row as
// applied. If any step fails, the transaction is rolled back, so neither the
// script's effects nor the changelog change are kept.
//
// NOTE(review): the script now really executes inside the transaction
// (previously it ran on a separate pooled connection). Statements PostgreSQL
// refuses to run in a transaction block, such as CREATE INDEX CONCURRENTLY,
// will therefore fail here.
func (s *Session) Apply(m mig.File) error {
	script, err := ioutil.ReadFile(m.Path)
	if err != nil {
		return fmt.Errorf("could not read migration file %s: %w", m.FileName, err)
	}

	tx, err := s.db.Begin()
	if err != nil {
		return fmt.Errorf("could not open transaction: %w", err)
	}
	// Roll back on every early return. After a successful Commit, Rollback
	// just returns sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	hasFailed, err := s.failedInTx(tx, m.Ver)
	if err != nil {
		return fmt.Errorf("could not check state of migration #%d for file %s: %w", m.Ver, m.FileName, err)
	}
	// A row left in the failed state is reused; inserting again would violate
	// the changelog's UNIQUE(version) constraint.
	if !hasFailed {
		if err := s.insertLogInTx(tx, m); err != nil {
			return fmt.Errorf("could not add migration #%d for file %s to changelog, rolling back...: %w", m.Ver, m.FileName, err)
		}
	}

	if _, err := tx.Exec(string(script)); err != nil {
		return fmt.Errorf("could not execute migration #%d from file %s: %w", m.Ver, m.FileName, err)
	}

	if err := s.markAppliedInTx(tx, m.Ver); err != nil {
		return fmt.Errorf("could not mark migration #%d for file %s as completed in DB, rolling back...: %w", m.Ver, m.FileName, err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("could not commit migration #%d from file %s: %w", m.Ver, m.FileName, err)
	}
	return nil
}

// failedInTx reports, using tx, whether the changelog has a row for migVer
// whose state is false.
func (s *Session) failedInTx(tx *sql.Tx, migVer int) (bool, error) {
	query := fmt.Sprintf(
		`SELECT COUNT(*) FROM "%s" WHERE state = false AND version = $1`,
		sanitizeIdentifier(s.ChangelogName),
	)
	var cnt int
	if err := tx.QueryRow(query, migVer).Scan(&cnt); err != nil {
		return false, fmt.Errorf("could not check in changelog %s if migration #%d failed: %w", s.ChangelogName, migVer, err)
	}
	return cnt > 0, nil
}

// insertLogInTx inserts, using tx, a changelog row for m with the default
// (false) state.
func (s *Session) insertLogInTx(tx *sql.Tx, m mig.File) error {
	query := fmt.Sprintf(
		`INSERT INTO "%s" (version, file_name) VALUES($1, $2)`,
		sanitizeIdentifier(s.ChangelogName),
	)
	_, err := tx.Exec(query, m.Ver, m.FileName)
	return err
}

// markAppliedInTx sets state = true, using tx, on the changelog row for migVer.
func (s *Session) markAppliedInTx(tx *sql.Tx, migVer int) error {
	query := fmt.Sprintf(
		`UPDATE "%s" SET state = true WHERE version = $1`,
		sanitizeIdentifier(s.ChangelogName),
	)
	_, err := tx.Exec(query, migVer)
	return err
}
// PendingMigrations returns the migration files in dir that still need to be
// applied, in the order dir.Migrations lists them.
//
// Only files whose version is greater than lastMigratedVer (the highest
// version recorded as applied) are considered. Each candidate is also checked
// with wasApplied. A migration older than that version that never succeeded
// is therefore NOT reported as pending. On any error the result is nil.
func (s *Session) PendingMigrations(dir *mig.Dir) ([]mig.File, error) {
	lastVer, err := s.lastMigratedVer()
	if err != nil {
		return nil, fmt.Errorf("could not determine version of last migration: %w", err)
	}

	allMigrations, err := dir.Migrations()
	if err != nil {
		return nil, fmt.Errorf("could not list migration files: %w", err)
	}

	var pending []mig.File
	for _, m := range allMigrations {
		if m.Ver <= lastVer {
			continue
		}
		// NOTE(review): since lastVer is the maximum applied version, this
		// check can only be true if another process applied m concurrently;
		// it is kept as a defensive re-check.
		applied, err := s.wasApplied(m.Ver)
		if err != nil {
			// Don't hand back a half-built list alongside the error.
			return nil, err
		}
		if !applied {
			pending = append(pending, m)
		}
	}
	return pending, nil
}