diff --git a/src/a2sapi.go b/src/a2sapi.go index ea91361..f289b4b 100644 --- a/src/a2sapi.go +++ b/src/a2sapi.go @@ -67,12 +67,8 @@ func launch(isDebug bool) { } // Initialize the application-wide configuration config.InitConfig() - - // Verify that geolocation DB can be read (will panic if it cannot) - _, err := db.OpenCountryDB() - if err != nil { - os.Exit(1) - } + // Initialize the application-wide database connections (panic on failure) + db.InitDBs() if !runSilent { printStartInfo() diff --git a/src/db/country.go b/src/db/country.go index b5f81e2..bc54ed5 100644 --- a/src/db/country.go +++ b/src/db/country.go @@ -13,6 +13,11 @@ import ( "github.com/oschwald/maxminddb-golang" ) +// CDB represents a database containing geolocation information. +type CDB struct { + db *maxminddb.Reader +} + // This is an intermediate struct to represent the MaxMind DB format, not for JSON type mmdbformat struct { Country struct { @@ -37,11 +42,10 @@ func getDefaultCountryData() models.DbCountry { } // OpenCountryDB opens the country lookup database for reading. The caller of -// this function will be responsinble for calling a .Close() on the Reader pointer -// returned by this function. -func OpenCountryDB() (*maxminddb.Reader, error) { - // Note: the caller of this function needs to handle db.Close() - db, err := maxminddb.Open(constants.CountryDbFilePath) +// this function will be responsinble for calling .Close(). +func OpenCountryDB() (*CDB, error) { + // Note: the caller of this function needs to handle .Close() + conn, err := maxminddb.Open(constants.CountryDbFilePath) if err != nil { dir := "build_nix" if runtime.GOOS == "windows" { @@ -56,15 +60,23 @@ http://geolite.maxmind.com/download/geoip/database/GeoLite2-City.mmdb.gz and extract the "GeoLite2-City.mmdb" file into a directory called "db" in the same directory as the a2sapi executable. Error: %s`, dir, err)) } - return db, nil + return &CDB{db: conn}, nil +} + +// Close closes the country geolocation database. +func (cdb *CDB) Close() { + err := cdb.db.Close() + if err != nil { + logger.LogAppErrorf("Error closing country database: %s", err) + } } // GetCountryInfo attempts to retrieve the country information for a given IP, // returning the result as a country model object over the corresponding result channel. -func GetCountryInfo(ch chan<- models.DbCountry, db *maxminddb.Reader, ipstr string) { +func (cdb *CDB) GetCountryInfo(ch chan<- models.DbCountry, ipstr string) { ip := net.ParseIP(ipstr) c := &mmdbformat{} - err := db.Lookup(ip, c) + err := cdb.db.Lookup(ip, c) if err != nil { ch <- getDefaultCountryData() return diff --git a/src/db/country_test.go b/src/db/country_test.go index 8cf5f5d..736f18d 100644 --- a/src/db/country_test.go +++ b/src/db/country_test.go @@ -29,7 +29,7 @@ func TestGetCountryInfo(t *testing.T) { c := make(chan models.DbCountry, 1) ip := "192.211.62.11" cinfo := models.DbCountry{} - go GetCountryInfo(c, cdb, ip) + go cdb.GetCountryInfo(c, ip) cinfo = <-c if !strings.EqualFold(cinfo.CountryCode, "US") { t.Fatalf("Expected country code to be US for IP: %s, got: %s", @@ -37,7 +37,7 @@ func TestGetCountryInfo(t *testing.T) { } ip = "89.20.244.197" cinfo = models.DbCountry{} - go GetCountryInfo(c, cdb, ip) + go cdb.GetCountryInfo(c, ip) cinfo = <-c if !strings.EqualFold(cinfo.CountryCode, "NO") { t.Fatalf("Expected country code to be NO for IP: %s, got: %s", diff --git a/src/db/database.go b/src/db/database.go new file mode 100644 index 0000000..b57b81a --- /dev/null +++ b/src/db/database.go @@ -0,0 +1,56 @@ +package db + +// database.go - Database initilization. + +import ( + "a2sapi/src/constants" + "a2sapi/src/logger" + "a2sapi/src/util" + "fmt" +) + +// CountryDB is a package-level variable that contains a country +// geoelocation database connection. It is initialized once for re-usability +// when building server lists. +var CountryDB *CDB + +// ServerDB is a package-level variable that contains a server information +// database connection. It is initialized once for re-usability when building +// server lists. +var ServerDB *SDB + +// InitDBs initializes the geolocation and server information databases for +// re-use across server list builds. Panics on failure to initialize. +func InitDBs() { + if CountryDB != nil && ServerDB != nil { + return + } + + cdb, err := OpenCountryDB() + if err != nil { + panic(fmt.Sprintf("Unable to initialize country database connection: %s", + err)) + } + sdb, err := OpenServerDB() + if err != nil { + panic(fmt.Sprintf( + "Unable to initialize server information database connection: %s", err)) + } + // Set package-level variables + CountryDB = cdb + ServerDB = sdb +} + +func verifyServerDbPath() error { + if err := util.CreateDirectory(constants.DbDirectory); err != nil { + logger.LogAppError(err) + panic(fmt.Sprintf("Unable to create database directory %s: %s", + constants.DbDirectory, err)) + } + if err := createServerDBtable(constants.GetServerDBPath()); err != nil { + logger.LogAppErrorf("Unable to verify database path: %s", err) + panic("Unable to verify database path") + } + + return nil +} diff --git a/src/db/dbfile.go b/src/db/dbfile.go deleted file mode 100644 index 7a30606..0000000 --- a/src/db/dbfile.go +++ /dev/null @@ -1,24 +0,0 @@ -package db - -// dbfile.go - database file operations - -import ( - "a2sapi/src/constants" - "a2sapi/src/logger" - "a2sapi/src/util" - "fmt" -) - -func verifyServerDbPath() error { - if err := util.CreateDirectory(constants.DbDirectory); err != nil { - logger.LogAppError(err) - panic(fmt.Sprintf("Unable to create database directory %s: %s", - constants.DbDirectory, err)) - } - if err := createServerDB(constants.GetServerDBPath()); err != nil { - logger.LogAppErrorf("Unable to verify database path: %s", err) - panic("Unable to verify database path") - } - - return nil -} diff --git a/src/db/servers.go b/src/db/servers.go index cd9d2e8..c391a25 100644 --- a/src/db/servers.go +++ b/src/db/servers.go @@ -14,7 +14,12 @@ import ( _ "github.com/mattn/go-sqlite3" ) -func createServerDB(dbfile string) error { +// SDB represents a database containing the server ID and game information. +type SDB struct { + db *sql.DB +} + +func createServerDBtable(dbfile string) error { create := `CREATE TABLE servers ( server_id INTEGER NOT NULL, host TEXT NOT NULL, @@ -63,8 +68,8 @@ func createServerDB(dbfile string) error { return nil } -func serverExists(db *sql.DB, host string, game string) (bool, error) { - rows, err := db.Query( +func (sdb *SDB) serverExists(host string, game string) (bool, error) { + rows, err := sdb.db.Query( "SELECT host, game FROM servers WHERE host =? AND GAME =? LIMIT 1", host, game) if err != nil { @@ -88,8 +93,8 @@ func serverExists(db *sql.DB, host string, game string) (bool, error) { return false, nil } -func getHostAndGame(db *sql.DB, id string) (host, game string, err error) { - rows, err := db.Query("SELECT host, game FROM servers WHERE server_id =? LIMIT 1", +func (sdb *SDB) getHostAndGame(id string) (host, game string, err error) { + rows, err := sdb.db.Query("SELECT host, game FROM servers WHERE server_id =? LIMIT 1", id) if err != nil { return host, game, @@ -109,28 +114,36 @@ func getHostAndGame(db *sql.DB, id string) (host, game string, err error) { // OpenServerDB Opens a database connection to the server database file or if // that file does not exists, creates it and then opens a database connection to it. -func OpenServerDB() (*sql.DB, error) { +func OpenServerDB() (*SDB, error) { if err := verifyServerDbPath(); err != nil { // will panic if not verified return nil, logger.LogAppError(err) } - db, err := sql.Open("sqlite3", constants.GetServerDBPath()) + conn, err := sql.Open("sqlite3", constants.GetServerDBPath()) if err != nil { return nil, logger.LogAppError(err) } - return db, nil + return &SDB{db: conn}, nil +} + +// Close closes the server database's underlying connection. +func (sdb *SDB) Close() { + err := sdb.db.Close() + if err != nil { + logger.LogAppErrorf("Error closing server DB: %s", err) + } } // AddServersToDB inserts a specified host and port with its game name into the // server database. -func AddServersToDB(db *sql.DB, hostsgames map[string]string) { +func (sdb *SDB) AddServersToDB(hostsgames map[string]string) { toInsert := make(map[string]string, len(hostsgames)) for host, game := range hostsgames { // If direct queries are enabled, don't add 'Unspecified' game to server DB if game == filters.GameUnspecified.String() { continue } - exists, err := serverExists(db, host, game) + exists, err := sdb.serverExists(host, game) if err != nil { continue } @@ -139,7 +152,7 @@ func AddServersToDB(db *sql.DB, hostsgames map[string]string) { } toInsert[host] = game } - tx, err := db.Begin() + tx, err := sdb.db.Begin() if err != nil { logger.LogAppErrorf("AddServersToDB error creating tx: %s", err) return @@ -171,11 +184,11 @@ func AddServersToDB(db *sql.DB, hostsgames map[string]string) { // server detail list or the list of server details in response to a request // coming in over the API. It sends its results over a map channel consisting of // a host to id mapping. -func GetIDsForServerList(result chan map[string]int64, db *sql.DB, +func (sdb *SDB) GetIDsForServerList(result chan map[string]int64, hosts map[string]string) { m := make(map[string]int64, len(hosts)) for host, game := range hosts { - rows, err := db.Query( + rows, err := sdb.db.Query( "SELECT server_id FROM servers WHERE host =? AND game =? LIMIT 1", host, game) if err != nil { @@ -203,11 +216,11 @@ func GetIDsForServerList(result chan map[string]int64, db *sql.DB, // set of hosts (represented by query string values) from the server database // file in response to a query from the API. Sends the results over a DbServerID // channel for consumption. -func GetIDsAPIQuery(result chan *models.DbServerID, db *sql.DB, hosts []string) { +func (sdb *SDB) GetIDsAPIQuery(result chan *models.DbServerID, hosts []string) { m := &models.DbServerID{} for _, h := range hosts { logger.WriteDebug("DB: GetIDsAPIQuery, host: %s", h) - rows, err := db.Query( + rows, err := sdb.db.Query( "SELECT server_id, host, game FROM servers WHERE host LIKE ?", fmt.Sprintf("%%%s%%", h)) if err != nil { @@ -242,11 +255,11 @@ func GetIDsAPIQuery(result chan *models.DbServerID, db *sql.DB, hosts []string) // server database file in response to a user-specified API query for a given // set of server ID numbers. Sends the results over a channel consisting of a // host to game name string mapping. -func GetHostsAndGameFromIDAPIQuery(result chan map[string]string, db *sql.DB, +func (sdb *SDB) GetHostsAndGameFromIDAPIQuery(result chan map[string]string, ids []string) { hosts := make(map[string]string, len(ids)) for _, id := range ids { - host, game, err := getHostAndGame(db, id) + host, game, err := sdb.getHostAndGame(id) if err != nil { logger.LogAppErrorf("Error getting host from ID for API query: %s", err) return diff --git a/src/db/servers_test.go b/src/db/servers_test.go index 7095b90..b577856 100644 --- a/src/db/servers_test.go +++ b/src/db/servers_test.go @@ -17,8 +17,8 @@ func init() { testData["172.16.0.1"] = "QuakeLive" } -func TestCreateServerDB(t *testing.T) { - err := createServerDB(constants.TestServerDbFilePath) +func TestCreateServerDBtable(t *testing.T) { + err := createServerDBtable(constants.TestServerDbFilePath) if err != nil { t.Fatalf("Unable to create test DB file: %s", err) } @@ -30,7 +30,7 @@ func TestAddServersToDB(t *testing.T) { t.Fatalf("Unable to open test database: %s", err) } defer db.Close() - AddServersToDB(db, testData) + db.AddServersToDB(testData) } func TestGetIDsForServerList(t *testing.T) { @@ -40,7 +40,7 @@ func TestGetIDsForServerList(t *testing.T) { t.Fatalf("Unable to open test database: %s", err) } defer db.Close() - GetIDsForServerList(c, db, testData) + db.GetIDsForServerList(c, testData) result := <-c if len(result) != 2 { t.Fatalf("Expected 2 results, got: %d", len(result)) @@ -63,7 +63,7 @@ func TestGetIDsAPIQuery(t *testing.T) { defer db.Close() h1 := []string{"10.0.0.10"} h2 := []string{"172.16.0.1"} - GetIDsAPIQuery(c1, db, h1) + db.GetIDsAPIQuery(c1, h1) r1 := <-c1 if len(r1.Servers) != 1 { t.Fatalf("Expected 1 server, got: %d", len(r1.Servers)) @@ -71,7 +71,7 @@ func TestGetIDsAPIQuery(t *testing.T) { if !strings.EqualFold(r1.Servers[0].Game, "Reflex") { t.Fatalf("Expected result 1 to be Reflex, got: %v", r1.Servers[0].Game) } - GetIDsAPIQuery(c2, db, h2) + db.GetIDsAPIQuery(c2, h2) r2 := <-c2 if len(r2.Servers) != 1 { t.Fatalf("Expected 1 server, got: %d", len(r2.Servers)) @@ -89,7 +89,7 @@ func TestGetHostsAndGameFromIDAPIQuery(t *testing.T) { } defer db.Close() ids := []string{"1", "2"} - GetHostsAndGameFromIDAPIQuery(c, db, ids) + db.GetHostsAndGameFromIDAPIQuery(c, ids) result := <-c if len(result) != 2 { t.Fatalf("Expected 2 results, got: %d", len(result)) diff --git a/src/steam/listbuilder.go b/src/steam/listbuilder.go index 9c96710..a9d6de9 100644 --- a/src/steam/listbuilder.go +++ b/src/steam/listbuilder.go @@ -9,7 +9,6 @@ import ( "a2sapi/src/logger" "a2sapi/src/models" "a2sapi/src/steam/filters" - "database/sql" "net" "strconv" "strings" @@ -32,12 +31,6 @@ func buildServerList(data a2sData, addtoServerDB bool) (*models.APIServerList, FailedServers: make([]string, 0), } - cdb, err := db.OpenCountryDB() - if err != nil { - return nil, logger.LogAppError(err) - } - defer cdb.Close() - for host, game := range data.HostsGames { info, iok := data.Info[host] players, pok := data.Players[host] @@ -76,8 +69,8 @@ func buildServerList(data a2sData, addtoServerDB bool) (*models.APIServerList, Rules: rules, Info: info, } - // Hack for gametype support, which can be found in rules, info, or not - // at all depending on the game + // Gametype support: gametype can be found in rules, info, or not + // at all depending on the game (currently just for QuakeLive & Reflex) srv.Info.GameTypeShort, srv.Info.GameTypeFull = getGameType(game, srv) ip, port, serr := net.SplitHostPort(host) @@ -92,7 +85,7 @@ func buildServerList(data a2sData, addtoServerDB bool) (*models.APIServerList, srvDBhosts[host] = game.Name } loc := make(chan models.DbCountry, 1) - go db.GetCountryInfo(loc, cdb, ip) + go db.CountryDB.GetCountryInfo(loc, ip) srv.CountryInfo = <-loc } sl.Servers = append(sl.Servers, srv) @@ -108,12 +101,8 @@ func buildServerList(data a2sData, addtoServerDB bool) (*models.APIServerList, sl.FailedCount = len(sl.FailedServers) if len(srvDBhosts) != 0 { - sdb, err := db.OpenServerDB() - if err != nil { - return nil, logger.LogAppError(err) - } - go db.AddServersToDB(sdb, srvDBhosts) - sl = setServerIDForList(sdb, sl) + go db.ServerDB.AddServersToDB(srvDBhosts) + sl = setServerIDForList(sl) } logger.LogAppInfo( @@ -152,14 +141,13 @@ func removeBuggedPlayers(players []models.SteamPlayerInfo) models.FilteredPlayer return rpi } -func setServerIDForList(sdb *sql.DB, - sl *models.APIServerList) *models.APIServerList { +func setServerIDForList(sl *models.APIServerList) *models.APIServerList { toSet := make(map[string]string, len(sl.Servers)) for _, s := range sl.Servers { toSet[s.Host] = s.Game } result := make(chan map[string]int64, 1) - go db.GetIDsForServerList(result, sdb, toSet) + go db.ServerDB.GetIDsForServerList(result, toSet) m := <-result for _, s := range sl.Servers { diff --git a/src/test/test_funcs.go b/src/test/test_funcs.go index b2ca53e..a8daf20 100644 --- a/src/test/test_funcs.go +++ b/src/test/test_funcs.go @@ -3,6 +3,7 @@ package test import ( "a2sapi/src/config" "a2sapi/src/constants" + "a2sapi/src/db" "fmt" "os" ) @@ -27,6 +28,9 @@ func SetupEnvironment() { // Dump is not in test directory and needs config access deleteFiles(constants.DumpFileFullPath( config.Config.DebugConfig.ServerDumpFilename)) + + // Initialize database connections + db.InitDBs() } func deleteFiles(filepaths ...string) { diff --git a/src/web/retrievers.go b/src/web/retrievers.go index 360095e..0a426ea 100644 --- a/src/web/retrievers.go +++ b/src/web/retrievers.go @@ -13,17 +13,7 @@ import ( func getServerIDRetriever(w http.ResponseWriter, hosts []string) { m := make(chan *models.DbServerID, 1) - sdb, err := db.OpenServerDB() - if err != nil { - setNotFoundAndLog(w, err) - if err := json.NewEncoder(w).Encode(models.GetDefaultServerID()); err != nil { - writeJSONEncodeError(w, err) - return - } - return - } - defer sdb.Close() - go db.GetIDsAPIQuery(m, sdb, hosts) + go db.ServerDB.GetIDsAPIQuery(m, hosts) ids := <-m if len(ids.Servers) > 0 { if err := json.NewEncoder(w).Encode(ids); err != nil { @@ -41,16 +31,7 @@ func getServerIDRetriever(w http.ResponseWriter, hosts []string) { func queryServerIDRetriever(w http.ResponseWriter, ids []string) { s := make(chan map[string]string, len(ids)) - sdb, err := db.OpenServerDB() - if err != nil { - setNotFoundAndLog(w, err) - if err := json.NewEncoder(w).Encode(models.GetDefaultServerList()); err != nil { - writeJSONEncodeError(w, err) - } - return - } - defer sdb.Close() - db.GetHostsAndGameFromIDAPIQuery(s, sdb, ids) + db.ServerDB.GetHostsAndGameFromIDAPIQuery(s, ids) hostsgames := <-s if len(hostsgames) == 0 { w.WriteHeader(http.StatusNotFound)