-
Notifications
You must be signed in to change notification settings - Fork 4
/
db.go
333 lines (282 loc) · 10.2 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
package db
import (
"database/sql"
"errors"
"fmt"
"os"
"strings"
"github.com/sirgwain/craig-stars/config"
"github.com/sirgwain/craig-stars/cs"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/jmoiron/sqlx"
"github.com/jmoiron/sqlx/reflectx"
"github.com/mattn/go-sqlite3"
sqldblogger "github.com/simukti/sqldb-logger"
)
// DBConn represents a connection to the database.
//
// Clients handed out by a DBConn are either read only or read/write. A
// read/write client can be wrapped in a transaction, but be warned: this locks
// the database until the transaction completes or fails.
type DBConn interface {
	// Connect sets up the connection from the given configuration.
	Connect(config *config.Config) error
	// Close shuts the connection down.
	Close() error

	// NewReadClient creates a new read client.
	NewReadClient() Client
	// NewReadWriteClient creates a new client that can read and write.
	NewReadWriteClient() Client

	// For write clients we use transactions: BeginTransaction hands back a
	// Client bound to a new transaction, and Rollback/Commit finish it.
	BeginTransaction() (Client, error)
	Rollback(c Client) error
	Commit(c Client) error

	// WrapInTransaction wraps a function call inside a transaction.
	WrapInTransaction(wrap func(c Client) error) error
}
// Client is the interface used to make all calls that query or modify the
// database. Methods are grouped below by the entity they operate on.
type Client interface {
	// private transaction management methods used by DBConn Rollback, Commit
	rollback() error
	commit() error

	// private method used during DBConn Connect to upgrade a client
	ensureUpgrade() error

	// users
	GetUsers() ([]cs.User, error)
	GetUser(id int64) (*cs.User, error)
	GetUserByUsername(username string) (*cs.User, error)
	GetGuestUser(hash string) (*cs.User, error)
	GetGuestUserForGame(gameID int64, playerNum int) (*cs.User, error)
	GetGuestUsersForGame(gameID int64) ([]cs.User, error)
	CreateUser(user *cs.User) error
	UpdateUser(user *cs.User) error
	DeleteUser(id int64) error
	DeleteGameUsers(gameID int64) error
	GetUsersForGame(gameID int64) ([]cs.User, error)

	// races
	GetRaces() ([]cs.Race, error)
	GetRacesForUser(userID int64) ([]cs.Race, error)
	GetRace(id int64) (*cs.Race, error)
	CreateRace(race *cs.Race) error
	UpdateRace(race *cs.Race) error
	DeleteRace(id int64) error
	DeleteUserRaces(userID int64) error

	// tech stores and rules
	GetTechStores() ([]cs.TechStore, error)
	CreateTechStore(tech *cs.TechStore) error
	GetTechStore(id int64) (*cs.TechStore, error)
	GetRulesForGame(gameID int64) (*cs.Rules, error)

	// games
	GetGames() ([]cs.Game, error)
	GetGamesWithPlayers() ([]cs.GameWithPlayers, error)
	GetGamesForHost(userID int64) ([]cs.GameWithPlayers, error)
	GetGamesForUser(userID int64) ([]cs.GameWithPlayers, error)
	GetOpenGames() ([]cs.GameWithPlayers, error)
	GetOpenGamesByHash(hash string) ([]cs.GameWithPlayers, error)
	GetGame(id int64) (*cs.GameWithPlayers, error)
	GetGameWithPlayersStatus(gameID int64) (*cs.GameWithPlayers, error)
	GetFullGame(id int64) (*cs.FullGame, error)
	CreateGame(game *cs.Game) error
	UpdateGame(game *cs.Game) error
	UpdateGameState(gameID int64, state cs.GameState) error
	UpdateFullGame(fullGame *cs.FullGame) error
	UpdateGameHost(gameID int64, hostId int64) error
	DeleteGame(id int64) error
	DeleteUserGames(hostID int64) error

	// players
	GetPlayers() ([]cs.Player, error)
	GetPlayersForUser(userID int64) ([]cs.Player, error)
	GetPlayer(id int64) (*cs.Player, error)
	GetLightPlayerForGame(gameID, userID int64) (*cs.Player, error)
	GetPlayersStatusForGame(gameID int64) ([]*cs.Player, error)
	GetPlayerForGame(gameID, userID int64) (*cs.Player, error)
	GetPlayerIntelsForGame(gameID, userID int64) (*cs.PlayerIntels, error)
	GetPlayerByNum(gameID int64, num int) (*cs.Player, error)
	GetFullPlayerForGame(gameID, userID int64) (*cs.FullPlayer, error)
	GetPlayerMapObjects(gameID, userID int64) (*cs.PlayerMapObjects, error)
	GetPlayerWithDesignsForGame(gameID int64, num int) (*cs.Player, error)
	CreatePlayer(player *cs.Player) error
	UpdatePlayer(player *cs.Player) error
	SubmitPlayerTurn(gameID int64, num int, submittedTurn bool) error
	UpdatePlayerOrders(player *cs.Player) error
	UpdatePlayerRelations(player *cs.Player) error
	UpdatePlayerSpec(player *cs.Player) error
	UpdatePlayerPlans(player *cs.Player) error
	UpdatePlayerSalvageIntels(player *cs.Player) error
	UpdateLightPlayer(player *cs.Player) error
	UpdatePlayerUserId(player *cs.Player) error
	DeletePlayer(id int64) error

	// ship designs
	GetShipDesignsForPlayer(gameID int64, playerNum int) ([]*cs.ShipDesign, error)
	GetShipDesign(id int64) (*cs.ShipDesign, error)
	GetShipDesignByNum(gameID int64, playerNum, num int) (*cs.ShipDesign, error)
	CreateShipDesign(shipDesign *cs.ShipDesign) error
	UpdateShipDesign(shipDesign *cs.ShipDesign) error
	DeleteShipDesign(id int64) error

	// planets
	GetPlanet(id int64) (*cs.Planet, error)
	GetPlanetByNum(gameID int64, num int) (*cs.Planet, error)
	GetPlanetsForPlayer(gameID int64, playerNum int) ([]*cs.Planet, error)
	UpdatePlanet(planet *cs.Planet) error
	UpdatePlanetSpec(planet *cs.Planet) error

	// fleets
	GetFleet(id int64) (*cs.Fleet, error)
	GetFleetByNum(gameID int64, playerNum int, num int) (*cs.Fleet, error)
	GetFleetsByNums(gameID int64, playerNum int, nums []int) ([]*cs.Fleet, error)
	CreateFleet(fleet *cs.Fleet) error
	UpdateFleet(fleet *cs.Fleet) error
	CreateUpdateOrDeleteFleets(gameID int64, fleets []*cs.Fleet) error
	DeleteFleet(id int64) error
	GetFleetsForPlayer(gameID int64, playerNum int) ([]*cs.Fleet, error)
	GetFleetsOrbitingPlanet(gameID int64, planetNum int) ([]*cs.Fleet, error)

	// mine fields
	GetMineField(id int64) (*cs.MineField, error)
	GetMineFieldsForPlayer(gameID int64, playerNum int) ([]*cs.MineField, error)
	UpdateMineField(fleet *cs.MineField) error

	// mineral packets
	GetMineralPacket(id int64) (*cs.MineralPacket, error)
	GetMineralPacketsForPlayer(gameID int64, playerNum int) ([]*cs.MineralPacket, error)

	// salvage
	GetSalvagesForGame(gameID int64) ([]*cs.Salvage, error)
	GetSalvagesForPlayer(gameID int64, playerNum int) ([]*cs.Salvage, error)
	GetSalvageByNum(gameID int64, num int) (*cs.Salvage, error)
	CreateSalvage(salvage *cs.Salvage) error
	UpdateSalvage(salvage *cs.Salvage) error
}
// dbConn is the DBConn implementation handed out by NewConn. Its fields are
// populated by Connect.
type dbConn struct {
	// dbRead backs read clients; dbWrite backs writes and transactions.
	// For in memory databases Connect points both at the same handle.
	dbRead, dbWrite *sqlx.DB

	// set by Connect when the corresponding configured filename contains
	// ":memory:"
	databaseInMemory, usersInMemory bool
}
// client is the concrete type behind the Client values created in this file.
type client struct {
	// reader runs queries; it is either the read database handle or a
	// transaction
	reader sqlReader
	// writer runs statements; it is left unset for clients created by
	// NewReadClient
	writer sqlWriter
	// tx is only set for clients created by newTransactionClient; rollback
	// and commit operate on it
	tx *sqlx.Tx
	// converter is filled with the converter c that is declared outside this
	// block
	converter Converter
}
// sqlReader is the query surface a client needs for reads. Both the read
// database handle and a transaction are assigned to it in this file.
type sqlReader interface {
	// Select runs query with args and stores the result in dest.
	Select(dest any, query string, args ...any) error
	// Get runs query with args and stores the result in dest.
	Get(dest any, query string, args ...any) error
	// Rebind returns a rewritten form of query.
	Rebind(query string) string
}
// sqlWriter is the statement surface a client needs for writes. Both the write
// database handle and a transaction are assigned to it in this file.
type sqlWriter interface {
	// NamedExec executes query, taking its parameters from arg.
	NamedExec(query string, arg any) (sql.Result, error)
	// Exec executes query with positional args.
	Exec(query string, args ...any) (sql.Result, error)
}
// NewConn returns a fresh, zero-valued DBConn. Its database handles are only
// populated by Connect.
func NewConn() DBConn {
	return new(dbConn)
}
// NewReadClient creates a Client that queries through the read database
// handle. No writer is configured on the returned client.
func (conn *dbConn) NewReadClient() Client {
	readOnly := new(client)
	readOnly.reader = conn.dbRead
	readOnly.converter = c // c is the converter declared outside this block
	return readOnly
}
// NewReadWriteClient creates a Client that queries through the read database
// handle and executes statements through the write database handle. The
// returned client is not bound to a transaction.
func (conn *dbConn) NewReadWriteClient() Client {
	readWrite := new(client)
	readWrite.reader = conn.dbRead
	readWrite.writer = conn.dbWrite
	readWrite.converter = c // c is the converter declared outside this block
	return readWrite
}
// newTransactionClient creates a new client from a transaction. The
// transaction serves as both reader and writer, and is kept on the client so
// rollback and commit can finish it.
func newTransactionClient(tx *sqlx.Tx) *client {
	txClient := new(client)
	txClient.tx = tx
	txClient.reader = tx
	txClient.writer = tx
	txClient.converter = c // c is the converter declared outside this block
	return txClient
}
// BeginTransaction starts a transaction on the write database handle and
// returns a Client bound to it. If the transaction cannot be started the error
// is returned and the Client is nil.
func (conn *dbConn) BeginTransaction() (Client, error) {
	tx, beginErr := conn.dbWrite.Beginx()
	if beginErr == nil {
		return newTransactionClient(tx), nil
	}
	return nil, beginErr
}
// Rollback asks the given client to roll back and returns whatever error the
// client reports.
func (conn *dbConn) Rollback(c Client) error {
	err := c.rollback()
	return err
}
// Commit asks the given client to commit and returns whatever error the client
// reports.
func (conn *dbConn) Commit(c Client) error {
	err := c.commit()
	return err
}
// WrapInTransaction is a helper to wrap a series of db calls in a transaction.
//
// It begins a transaction, passes the transaction-bound Client to wrap, and
// commits if wrap returns nil. If wrap or the commit fails, the transaction is
// rolled back and that failure is returned. A rollback that itself fails is
// appended to the returned error's message; the original failure stays
// reachable through errors.Is / errors.As. If wrap panics, the transaction is
// rolled back on a best-effort basis before the panic continues to unwind.
func (conn *dbConn) WrapInTransaction(wrap func(c Client) error) error {
	c, err := conn.BeginTransaction()
	if err != nil {
		return err
	}

	// settled flips to true once wrap (and the commit, if any) have returned
	// normally. Until then the deferred rollback is the safety net that
	// releases the transaction when one of them panics.
	settled := false
	defer func() {
		if !settled {
			_ = conn.Rollback(c) // best effort: we are already unwinding a panic
		}
	}()

	err = wrap(c)
	if err == nil {
		err = conn.Commit(c)
	}
	settled = true
	if err == nil {
		return nil
	}

	// A transaction that has already finished (for example after a failed
	// commit) reports sql.ErrTxDone on rollback; that is not a rollback failure.
	if rbErr := conn.Rollback(c); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
		return fmt.Errorf("%w (rollback also failed: %v)", err, rbErr)
	}
	return err
}
// Connect prepares the database described by cfg and opens the read and write
// handles stored on the receiver.
//
// Steps, in order:
//  1. record whether the main and users databases are in memory (their
//     filename contains ":memory:")
//  2. for a file based database with Recreate set, delete the existing file
//  3. run migrations (mustMigrate)
//  4. open the handles, attaching the users database and enabling foreign
//     keys on every new connection
//  5. for in memory databases run setupInMemoryDatabase
//  6. unless SkipUpgrade is set, run mustUpgrade
//
// An error is returned when the existing database file cannot be inspected or
// deleted. NOTE(review): mustMigrate, setupInMemoryDatabase and mustUpgrade are
// defined elsewhere and return nothing here; by their naming the must* helpers
// presumably panic on failure — confirm.
func (c *dbConn) Connect(cfg *config.Config) error {
	c.databaseInMemory = strings.Contains(cfg.Database.Filename, ":memory:")
	c.usersInMemory = strings.Contains(cfg.Database.UsersFilename, ":memory:")

	// if we are using a file based db and are configured to recreate it,
	// remove the existing file before migrating
	if !c.databaseInMemory && cfg.Database.Recreate {
		// check if the db exists
		info, err := os.Stat(cfg.Database.Filename)
		if err != nil && !errors.Is(err, os.ErrNotExist) {
			return err
		}

		// delete the db if it is there. A failure here must not be ignored,
		// otherwise we would silently keep using the old database. The file
		// vanishing between Stat and Remove is fine.
		if info != nil {
			log.Debug().Msgf("Deleting existing database %s", cfg.Database.Filename)
			if err := os.Remove(cfg.Database.Filename); err != nil && !errors.Is(err, os.ErrNotExist) {
				return fmt.Errorf("deleting existing database %s: %w", cfg.Database.Filename, err)
			}
		}
	}

	// make sure the database is up to date
	c.mustMigrate(cfg)

	// create a new logger for logging database calls
	var zlogger zerolog.Logger
	if cfg.Database.DebugLogging {
		zlogger = zerolog.New(os.Stderr).With().Timestamp().Logger().Output(zerolog.ConsoleWriter{Out: os.Stderr}).Level(zerolog.DebugLevel)
	} else {
		zlogger = zerolog.New(os.Stderr).With().Timestamp().Logger().Level(zerolog.WarnLevel)
	}
	loggerAdapter := newLoggerWithLogger(&zlogger)

	// dsn is like file::memory:?cache=shared, or file:data.db?_journal=WAL
	dsn := fmt.Sprintf("file:%s%s", cfg.Database.Filename, cfg.Database.ReadConnectionParams)
	log.Debug().Msgf("Connecting to database %s", dsn)

	// the filename is embedded in a SQL string literal; double any single
	// quote so a path containing one cannot end the literal early
	attachUsers := fmt.Sprintf("ATTACH DATABASE '%s' as users;", strings.ReplaceAll(cfg.Database.UsersFilename, "'", "''"))

	// runs for every new sqlite connection
	connectHook := func(conn *sqlite3.SQLiteConn) error {
		log.Debug().Msgf("Attaching Users database %s", cfg.Database.UsersFilename)
		if _, err := conn.Exec(attachUsers, nil); err != nil {
			return err
		}
		if _, err := conn.Exec("PRAGMA foreign_keys = ON;", nil); err != nil {
			return err
		}
		return nil
	}

	dbRead := sqldblogger.OpenDriver(dsn, &sqlite3.SQLiteDriver{ConnectHook: connectHook}, loggerAdapter)
	c.dbRead = sqlx.NewDb(dbRead, "sqlite3")

	if c.databaseInMemory {
		// no separate write connection for in memory dbs
		c.dbWrite = c.dbRead
	} else {
		// only open the write handle when it is actually used, so in memory
		// setups don't leave an extra handle behind that is never closed
		dbWrite := sqldblogger.OpenDriver(dsn, &sqlite3.SQLiteDriver{ConnectHook: connectHook}, loggerAdapter)
		c.dbWrite = sqlx.NewDb(dbWrite, "sqlite3")
		c.dbWrite.SetMaxOpenConns(1)
	}

	// Create a new mapper which will use the struct field tag "json" instead of "db"
	c.dbRead.Mapper = reflectx.NewMapperFunc("json", strings.ToLower)
	c.dbWrite.Mapper = reflectx.NewMapperFunc("json", strings.ToLower)

	// do some special processing for in memory databases
	if c.databaseInMemory {
		c.setupInMemoryDatabase()
	}

	// make sure the data is updated
	if !cfg.Database.SkipUpgrade {
		c.mustUpgrade()
	}

	log.Info().Msg("connect() complete")
	return nil
}
// Close closes the read and write database handles.
//
// Both handles are always attempted, so a failure closing the read handle no
// longer leaves the write handle open. Handles that were never opened (Connect
// was not called) are skipped. If only one close fails, that error is returned
// unchanged; if both fail, the read error is returned with the write error
// appended to its message.
func (c *dbConn) Close() error {
	var readErr, writeErr error

	if c.dbRead != nil {
		readErr = c.dbRead.Close()
	}

	// in memory databases share one handle for reads and writes, so there is
	// nothing further to close in that case
	if c.dbWrite != nil && c.dbWrite != c.dbRead {
		writeErr = c.dbWrite.Close()
	}

	switch {
	case readErr != nil && writeErr != nil:
		return fmt.Errorf("%w (closing write connection also failed: %v)", readErr, writeErr)
	case readErr != nil:
		return readErr
	default:
		return writeErr
	}
}
// rollback rolls back the transaction this client is bound to.
//
// Only clients created by newTransactionClient carry a transaction. For any
// other client an error is returned rather than calling through a nil
// transaction, which would panic.
func (c *client) rollback() error {
	if c.tx == nil {
		return errors.New("rollback: client is not in a transaction")
	}
	return c.tx.Rollback()
}
// commit commits the transaction this client is bound to.
//
// Only clients created by newTransactionClient carry a transaction. For any
// other client an error is returned rather than calling through a nil
// transaction, which would panic.
func (c *client) commit() error {
	if c.tx == nil {
		return errors.New("commit: client is not in a transaction")
	}
	return c.tx.Commit()
}