-
Notifications
You must be signed in to change notification settings - Fork 3
/
database.go
74 lines (68 loc) · 1.87 KB
/
database.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
// The MIT License (MIT)
//
// Copyright (c) 2017 Arnaud Vazard
//
// See LICENSE file.
// Package database manages the database used by the bot
package database
import (
"database/sql"
"errors"
_ "github.com/mattes/migrate/driver/sqlite3"
"github.com/mattes/migrate/migrate"
_ "github.com/mattn/go-sqlite3"
"log"
"os"
)
// dbPtr is the package-level connection handle. NewDatabase assigns it once
// migrations have been applied, and AddUser reads it; it is nil until then.
var dbPtr *sql.DB
// NewDatabase opens the SQLite database used by the bot, applies the pending
// migrations and returns the connection handle.
//
// If databaseName is an empty string the default path "./storage/db.sqlite" is
// used (the "./storage" directory is created when missing); otherwise
// databaseName is used as the path of the database file.
// If migrationsFolder is an empty string it defaults to "database/migrations".
// If reset is true the database file is deleted before being opened, so the
// migrations recreate it from scratch.
//
// The returned handle is also stored in the package-level connection that
// AddUser uses. Every failure is fatal: it is logged and the process exits
// through log.Fatal.
func NewDatabase(databaseName string, migrationsFolder string, reset bool) *sql.DB {
	// Use default name if not specified
	if databaseName == "" {
		ensureStorageDir()
		databaseName = "./storage/db.sqlite"
	}

	if reset {
		// A missing file already satisfies the reset; any other failure means
		// the old data would silently survive, so treat it as fatal.
		if err := os.Remove(databaseName); err != nil && !os.IsNotExist(err) {
			log.Fatal(err)
		}
	}

	db, err := sql.Open("sqlite3", databaseName)
	if err != nil {
		log.Fatal(err)
	}
	// sql.Open only validates its arguments; Ping establishes a real
	// connection so an unusable path is reported here, not on first query.
	if err = db.Ping(); err != nil {
		log.Fatal(err)
	}

	if migrationsFolder == "" {
		migrationsFolder = "database/migrations"
	}

	// Apply migrations
	allErrors, ok := migrate.UpSync("sqlite3://"+databaseName, migrationsFolder)
	if !ok {
		for _, err := range allErrors {
			log.Println(err)
		}
		log.Fatal("Error while applying migrations, exiting ...")
	}

	dbPtr = db
	return db
}

// ensureStorageDir makes sure "./storage" exists and is a directory, creating
// it when missing. It exits via log.Fatal if the path exists but is not a
// directory, or if it cannot be inspected or created.
func ensureStorageDir() {
	storage, err := os.Stat("./storage")
	switch {
	case err == nil:
		if !storage.IsDir() {
			log.Fatal("\"storage\" exist but is not a directory")
		}
	case os.IsNotExist(err):
		// os.ModeDir carries no permission bits; the owner needs rwx so the
		// SQLite file (and its journal) can be created inside the directory.
		if err := os.MkdirAll("./storage", 0755); err != nil {
			log.Fatal(err)
		}
	default:
		log.Fatal(err)
	}
}
// AddUser stores the (nick, email) pair in the User table.
//
// The statement is an INSERT OR REPLACE with positional values, so the pair is
// bound to the table's columns in declaration order. NOTE(review): confirm
// against the User migration that the first two columns are nick and email.
// When the new row conflicts with an existing one on a uniqueness constraint,
// SQLite replaces the existing row instead of failing.
//
// It returns an error if NewDatabase has not set up the package connection
// yet, or if executing the statement fails (that error is also logged along
// with the SQL text).
func AddUser(nick, email string) (err error) {
	const insertUser = `INSERT OR REPLACE INTO User VALUES ($1, $2)`

	db := dbPtr
	if db == nil {
		return errors.New("Database pointer is nil")
	}

	if _, err = db.Exec(insertUser, nick, email); err != nil {
		log.Printf("%q: %s\n", err, insertUser)
		return err
	}
	return nil
}