diff --git a/database/connection.go b/database/connection.go index 8f08e64c..4a9a48ce 100644 --- a/database/connection.go +++ b/database/connection.go @@ -1,7 +1,6 @@ package database import ( - "database/sql" "fmt" "log/slog" @@ -47,27 +46,17 @@ func (c *Connection) Close() bool { return true } -func (c *Connection) Ping() { - var driver *sql.DB - - slog.Info("Database ping started", "separator", "---------") - - if conn, err := c.driver.DB(); err != nil { - slog.Error("Error retrieving the db driver", "error", err.Error()) - - return - } else { - driver = conn - slog.Info("Database driver acquired", "type", fmt.Sprintf("%T", driver)) +func (c *Connection) Ping() error { + conn, err := c.driver.DB() + if err != nil { + return fmt.Errorf("error retrieving the db driver: %w", err) } - if err := driver.Ping(); err != nil { - slog.Error("Error pinging the db driver", "error", err.Error()) + if err := conn.Ping(); err != nil { + return fmt.Errorf("error pinging the db driver: %w", err) } - slog.Info("Database driver is healthy", "stats", driver.Stats()) - - slog.Info("Database ping completed", "separator", "---------") + return nil } func (c *Connection) Sql() *gorm.DB { diff --git a/handler/ping.go b/handler/keep_alive.go similarity index 60% rename from handler/ping.go rename to handler/keep_alive.go index 437a3150..521a5235 100644 --- a/handler/ping.go +++ b/handler/keep_alive.go @@ -11,15 +11,15 @@ import ( "github.com/oullin/pkg/portal" ) -type PingHandler struct { +type KeepAliveHandler struct { env *env.Ping } -func MakePingHandler(e *env.Ping) PingHandler { - return PingHandler{env: e} +func MakeKeepAliveHandler(e *env.Ping) KeepAliveHandler { + return KeepAliveHandler{env: e} } -func (h PingHandler) Handle(w baseHttp.ResponseWriter, r *baseHttp.Request) *http.ApiError { +func (h KeepAliveHandler) Handle(w baseHttp.ResponseWriter, r *baseHttp.Request) *http.ApiError { user, pass, ok := r.BasicAuth() if !ok || h.env.HasInvalidCreds(user, pass) { @@ -29,16 +29,16 @@ func (h PingHandler) Handle(w baseHttp.ResponseWriter, r *baseHttp.Request) *htt ) } - resp := http.MakeResponseFrom("0.0.1", w, r) + resp := http.MakeNoCacheResponse(w, r) now := time.Now().UTC() - data := payload.PingResponse{ + data := payload.KeepAliveResponse{ Message: "pong", DateTime: now.Format(portal.DatesLayout), } if err := resp.RespondOk(data); err != nil { - return http.LogInternalError("could not encode ping response", err) + return http.LogInternalError("could not encode keep-alive response", err) } return nil diff --git a/handler/keep_alive_db.go b/handler/keep_alive_db.go new file mode 100644 index 00000000..b589377c --- /dev/null +++ b/handler/keep_alive_db.go @@ -0,0 +1,51 @@ +package handler + +import ( + "fmt" + baseHttp "net/http" + "time" + + "github.com/oullin/database" + "github.com/oullin/handler/payload" + "github.com/oullin/metal/env" + "github.com/oullin/pkg/http" + "github.com/oullin/pkg/portal" +) + +type KeepAliveDBHandler struct { + env *env.Ping + db *database.Connection +} + +func MakeKeepAliveDBHandler(e *env.Ping, db *database.Connection) KeepAliveDBHandler { + return KeepAliveDBHandler{env: e, db: db} +} + +func (h KeepAliveDBHandler) Handle(w baseHttp.ResponseWriter, r *baseHttp.Request) *http.ApiError { + user, pass, ok := r.BasicAuth() + + if !ok || h.env.HasInvalidCreds(user, pass) { + return http.LogUnauthorisedError( + "invalid credentials", + fmt.Errorf("invalid credentials"), + ) + } + + if err := h.db.Ping(); err != nil { + return http.LogInternalError("database ping failed", err) + } + + resp := http.MakeNoCacheResponse(w, r) + now := time.Now().UTC() + + data := payload.KeepAliveResponse{ + Message: "pong", + DateTime: now.Format(portal.DatesLayout), + } + + if err := resp.RespondOk(data); err != nil { + return http.LogInternalError("could not encode keep-alive response", err) + } + + return nil +} diff --git a/handler/keep_alive_db_test.go b/handler/keep_alive_db_test.go new file mode 100644 index 00000000..2ce01b2d --- /dev/null +++ b/handler/keep_alive_db_test.go @@ -0,0 +1,61 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/oullin/handler/payload" + handlertests "github.com/oullin/handler/tests" + "github.com/oullin/metal/env" + "github.com/oullin/pkg/portal" +) + +func TestKeepAliveDBHandler(t *testing.T) { + db, _ := handlertests.MakeTestDB(t) + e := env.Ping{Username: "user", Password: "pass"} + h := MakeKeepAliveDBHandler(&e, db) + + t.Run("valid credentials", func(t *testing.T) { + req := httptest.NewRequest("GET", "/ping-db", nil) + req.SetBasicAuth("user", "pass") + rec := httptest.NewRecorder() + if err := h.Handle(rec, req); err != nil { + t.Fatalf("handle err: %v", err) + } + if rec.Code != http.StatusOK { + t.Fatalf("status %d", rec.Code) + } + var resp payload.KeepAliveResponse + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if resp.Message != "pong" { + t.Fatalf("unexpected message: %s", resp.Message) + } + if _, err := time.Parse(portal.DatesLayout, resp.DateTime); err != nil { + t.Fatalf("invalid datetime: %v", err) + } + }) + + t.Run("invalid credentials", func(t *testing.T) { + req := httptest.NewRequest("GET", "/ping-db", nil) + req.SetBasicAuth("bad", "creds") + rec := httptest.NewRecorder() + if err := h.Handle(rec, req); err == nil || err.Status != http.StatusUnauthorized { + t.Fatalf("expected unauthorized, got %#v", err) + } + }) + + t.Run("db ping failure", func(t *testing.T) { + db.Close() + req := httptest.NewRequest("GET", "/ping-db", nil) + req.SetBasicAuth("user", "pass") + rec := httptest.NewRecorder() + if err := h.Handle(rec, req); err == nil || err.Status != http.StatusInternalServerError { + t.Fatalf("expected internal error, got %#v", err) + } + }) +} diff --git a/handler/ping_test.go b/handler/keep_alive_test.go similarity index 91% rename from handler/ping_test.go rename to handler/keep_alive_test.go index cf2fe289..a0296be2 100644 --- a/handler/ping_test.go +++ b/handler/keep_alive_test.go @@ -12,9 +12,9 @@ import ( "github.com/oullin/pkg/portal" ) -func TestPingHandler(t *testing.T) { +func TestKeepAliveHandler(t *testing.T) { e := env.Ping{Username: "user", Password: "pass"} - h := MakePingHandler(&e) + h := MakeKeepAliveHandler(&e) t.Run("valid credentials", func(t *testing.T) { req := httptest.NewRequest("GET", "/ping", nil) @@ -26,7 +26,7 @@ func TestPingHandler(t *testing.T) { if rec.Code != http.StatusOK { t.Fatalf("status %d", rec.Code) } - var resp payload.PingResponse + var resp payload.KeepAliveResponse if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { t.Fatalf("decode: %v", err) } diff --git a/handler/payload/ping.go b/handler/payload/keep_alive.go similarity index 73% rename from handler/payload/ping.go rename to handler/payload/keep_alive.go index d05f6805..1e142d32 100644 --- a/handler/payload/ping.go +++ b/handler/payload/keep_alive.go @@ -1,6 +1,6 @@ package payload -type PingResponse struct { +type KeepAliveResponse struct { Message string `json:"message"` DateTime string `json:"date_time"` } diff --git a/main.go b/main.go index 19e9178f..2f2ec70b 100644 --- a/main.go +++ b/main.go @@ -35,7 +35,9 @@ func main() { app.Boot() // --- Testing - app.GetDB().Ping() + if err := app.GetDB().Ping(); err != nil { + slog.Error("database ping failed", "error", err) + } slog.Info("Starting new server on :" + app.GetEnv().Network.HttpPort) // --- diff --git a/metal/kernel/app.go b/metal/kernel/app.go index 968d4dfc..e110064b 100644 --- a/metal/kernel/app.go +++ b/metal/kernel/app.go @@ -69,7 +69,8 @@ func (a *App) Boot() { router := *a.router - router.Ping() + router.KeepAlive() + router.KeepAliveDB() router.Profile() router.Experience() router.Projects() diff --git a/metal/kernel/router.go b/metal/kernel/router.go index 5e3df29b..0760117a 100644 --- a/metal/kernel/router.go +++ b/metal/kernel/router.go @@ -80,8 +80,8 @@ func (r *Router) Signature() { r.Mux.HandleFunc("POST /generate-signature", generate) } -func (r *Router) Ping() { - abstract := handler.MakePingHandler(&r.Env.Ping) +func (r *Router) KeepAlive() { + abstract := handler.MakeKeepAliveHandler(&r.Env.Ping) apiHandler := http.MakeApiHandler( r.Pipeline.Chain(abstract.Handle), @@ -90,6 +90,16 @@ func (r *Router) Ping() { r.Mux.HandleFunc("GET /ping", apiHandler) } +func (r *Router) KeepAliveDB() { + abstract := handler.MakeKeepAliveDBHandler(&r.Env.Ping, r.Db) + + apiHandler := http.MakeApiHandler( + r.Pipeline.Chain(abstract.Handle), + ) + + r.Mux.HandleFunc("GET /ping-db", apiHandler) +} + func (r *Router) Profile() { addStaticRoute(r, "/profile", "./storage/fixture/profile.json", handler.MakeProfileHandler) } diff --git a/metal/kernel/router_keep_alive_db_test.go b/metal/kernel/router_keep_alive_db_test.go new file mode 100644 index 00000000..baac81b6 --- /dev/null +++ b/metal/kernel/router_keep_alive_db_test.go @@ -0,0 +1,42 @@ +package kernel + +import ( + "net/http" + "net/http/httptest" + "testing" + + handlertests "github.com/oullin/handler/tests" + "github.com/oullin/metal/env" + "github.com/oullin/pkg/middleware" +) + +func TestKeepAliveDBRoute(t *testing.T) { + db, _ := handlertests.MakeTestDB(t) + r := Router{ + Env: &env.Environment{Ping: env.Ping{Username: "user", Password: "pass"}}, + Db: db, + Mux: http.NewServeMux(), + Pipeline: middleware.Pipeline{PublicMiddleware: middleware.MakePublicMiddleware("", false)}, + } + r.KeepAliveDB() + + t.Run("valid credentials", func(t *testing.T) { + req := httptest.NewRequest("GET", "/ping-db", nil) + req.SetBasicAuth("user", "pass") + rec := httptest.NewRecorder() + r.Mux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code) + } + }) + + t.Run("invalid credentials", func(t *testing.T) { + req := httptest.NewRequest("GET", "/ping-db", nil) + req.SetBasicAuth("bad", "creds") + rec := httptest.NewRecorder() + r.Mux.ServeHTTP(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code) + } + }) +} diff --git a/metal/kernel/router_ping_test.go b/metal/kernel/router_keep_alive_test.go similarity index 94% rename from metal/kernel/router_ping_test.go rename to metal/kernel/router_keep_alive_test.go index 692b3d47..bb00b728 100644 --- a/metal/kernel/router_ping_test.go +++ b/metal/kernel/router_keep_alive_test.go @@ -9,13 +9,13 @@ import ( "github.com/oullin/pkg/middleware" ) -func TestPingRoute(t *testing.T) { +func TestKeepAliveRoute(t *testing.T) { r := Router{ Env: &env.Environment{Ping: env.Ping{Username: "user", Password: "pass"}}, Mux: http.NewServeMux(), Pipeline: middleware.Pipeline{PublicMiddleware: middleware.MakePublicMiddleware("", false)}, } - r.Ping() + r.KeepAlive() t.Run("valid credentials", func(t *testing.T) { req := httptest.NewRequest("GET", "/ping", nil) diff --git a/pkg/http/response.go b/pkg/http/response.go index 636f2d48..188d41f8 100644 --- a/pkg/http/response.go +++ b/pkg/http/response.go @@ -38,6 +38,21 @@ func MakeResponseFrom(salt string, writer baseHttp.ResponseWriter, request *base } } +func MakeNoCacheResponse(writer baseHttp.ResponseWriter, request *baseHttp.Request) *Response { + cacheControl := "no-store" + + return &Response{ + writer: writer, + request: request, + cacheControl: cacheControl, + headers: func(w baseHttp.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Cache-Control", cacheControl) + }, + } +} + func (r *Response) WithHeaders(callback func(w baseHttp.ResponseWriter)) { callback(r.writer) } @@ -53,6 +68,10 @@ func (r *Response) RespondOk(payload any) error { } func (r *Response) HasCache() bool { + if r.etag == "" { + return false + } + request := r.request match := strings.TrimSpace( diff --git a/pkg/http/response_test.go b/pkg/http/response_test.go index 35acd4b2..a505543d 100644 --- a/pkg/http/response_test.go +++ b/pkg/http/response_test.go @@ -31,6 +31,33 @@ func TestResponse_RespondOkAndHasCache(t *testing.T) { } } +func TestResponse_NoCache(t *testing.T) { + req := httptest.NewRequest("GET", "/", nil) + rec := httptest.NewRecorder() + + r := MakeNoCacheResponse(rec, req) + + if err := r.RespondOk(map[string]string{"a": "b"}); err != nil { + t.Fatalf("respond: %v", err) + } + + if rec.Code != http.StatusOK { + t.Fatalf("status %d", rec.Code) + } + + if rec.Header().Get("Cache-Control") != "no-store" { + t.Fatalf("unexpected cache-control: %s", rec.Header().Get("Cache-Control")) + } + + if rec.Header().Get("ETag") != "" { + t.Fatalf("etag should be empty") + } + + if r.HasCache() { + t.Fatalf("expected no cache") + } +} + func TestResponse_WithHeaders(t *testing.T) { req := httptest.NewRequest("GET", "/", nil) rec := httptest.NewRecorder()