diff --git a/caddy/Caddyfile.local b/caddy/Caddyfile.local index 94bbf7eb..d1c84dbc 100644 --- a/caddy/Caddyfile.local +++ b/caddy/Caddyfile.local @@ -16,7 +16,7 @@ header { Access-Control-Allow-Origin "http://localhost:5173" # allows the Vue app (running on localhost:5173) to make requests. Access-Control-Allow-Methods "GET, POST, PUT, DELETE, OPTIONS" # Specifies which methods are allowed. - Access-Control-Allow-Headers "X-API-Key, X-API-Username, X-API-Signature, X-API-Timestamp, X-API-Nonce, X-Request-ID, Content-Type, User-Agent, If-None-Match" # allows the custom headers needed by the API. + Access-Control-Allow-Headers "X-API-Key, X-API-Username, X-API-Signature, X-API-Timestamp, X-API-Nonce, X-Request-ID, Content-Type, User-Agent, If-None-Match, X-API-Intended-Origin" # allows the custom headers needed by the API. Access-Control-Expose-Headers "ETag, X-Request-ID" } @@ -30,7 +30,7 @@ # Reflect the Origin back so it's always allowed header Access-Control-Allow-Origin "{http.request.header.Origin}" header Access-Control-Allow-Methods "GET, POST, PUT, DELETE, OPTIONS" - header Access-Control-Allow-Headers "X-API-Key, X-API-Username, X-API-Signature, X-API-Timestamp, X-API-Nonce, X-Request-ID, Content-Type, User-Agent, If-None-Match" + header Access-Control-Allow-Headers "X-API-Key, X-API-Username, X-API-Signature, X-API-Timestamp, X-API-Nonce, X-Request-ID, Content-Type, User-Agent, If-None-Match, X-API-Intended-Origin" header Access-Control-Max-Age "86400" respond 204 } diff --git a/caddy/Caddyfile.prod b/caddy/Caddyfile.prod index 10507694..a9614925 100644 --- a/caddy/Caddyfile.prod +++ b/caddy/Caddyfile.prod @@ -34,7 +34,7 @@ oullin.io { header { Access-Control-Allow-Origin "https://oullin.io" Access-Control-Allow-Methods "GET, POST, PUT, DELETE, OPTIONS" - Access-Control-Allow-Headers "X-API-Key, X-API-Username, X-API-Signature, X-API-Timestamp, X-API-Nonce, X-Request-ID, Content-Type, User-Agent, If-None-Match" + Access-Control-Allow-Headers "X-API-Key, X-API-Username, X-API-Signature, X-API-Timestamp, X-API-Nonce, X-Request-ID, Content-Type, User-Agent, If-None-Match, X-API-Intended-Origin" Access-Control-Expose-Headers "ETag, X-Request-ID" } @@ -47,7 +47,7 @@ oullin.io { # Reflect the Origin back so it's always allowed header Access-Control-Allow-Origin "{http.request.header.Origin}" header Access-Control-Allow-Methods "GET, POST, PUT, DELETE, OPTIONS" - header Access-Control-Allow-Headers "X-API-Key, X-API-Username, X-API-Signature, X-API-Timestamp, X-API-Nonce, X-Request-ID, Content-Type, User-Agent, If-None-Match" + header Access-Control-Allow-Headers "X-API-Key, X-API-Username, X-API-Signature, X-API-Timestamp, X-API-Nonce, X-Request-ID, Content-Type, User-Agent, If-None-Match, X-API-Intended-Origin" header Access-Control-Max-Age "86400" respond 204 } @@ -63,6 +63,7 @@ oullin.io { header_up Content-Type {http.request.header.Content-Type} header_up User-Agent {http.request.header.User-Agent} header_up If-None-Match {http.request.header.If-None-Match} + header_up X-API-Intended-Origin {http.request.header.X-API-Intended-Origin} transport http { dial_timeout 10s diff --git a/database/attrs.go b/database/attrs.go index b54168bb..2b9358f8 100644 --- a/database/attrs.go +++ b/database/attrs.go @@ -39,9 +39,9 @@ type CommentsAttrs struct { } type LikesAttrs struct { - UUID string `gorm:"type:uuid;unique;not null"` - PostID uint64 `gorm:"not null;index;uniqueIndex:idx_likes_post_user"` - UserID uint64 `gorm:"not null;index;uniqueIndex:idx_likes_post_user"` + UUID string + PostID uint64 + UserID uint64 } type NewsletterAttrs struct { diff --git a/database/connection.go b/database/connection.go index 96703095..8f08e64c 100644 --- a/database/connection.go +++ b/database/connection.go @@ -3,10 +3,11 @@ package database import ( "database/sql" "fmt" + "log/slog" + "github.com/oullin/metal/env" "gorm.io/driver/postgres" "gorm.io/gorm" - "log/slog" ) type Connection struct { diff --git a/database/infra/migrations/000002_api_keys.up.sql b/database/infra/migrations/000002_api_keys.up.sql index 7b48305e..aa453584 100644 --- a/database/infra/migrations/000002_api_keys.up.sql +++ b/database/infra/migrations/000002_api_keys.up.sql @@ -11,7 +11,7 @@ CREATE TABLE api_keys ( CONSTRAINT uq_account_keys UNIQUE (account_name, public_key, secret_key) ); -CREATE INDEX idx_account_name ON api_keys(account_name); -CREATE INDEX idx_public_key ON api_keys(public_key); -CREATE INDEX idx_secret_key ON api_keys(secret_key); -CREATE INDEX idx_deleted_at ON api_keys(deleted_at); +CREATE INDEX idx_api_keys_account_name ON api_keys(account_name); +CREATE INDEX idx_api_keys_public_key ON api_keys(public_key); +CREATE INDEX idx_api_keys_secret_key ON api_keys(secret_key); +CREATE INDEX idx_api_keys_deleted_at ON api_keys(deleted_at); diff --git a/database/infra/migrations/000003_api_keys_signatures.up.sql b/database/infra/migrations/000003_api_keys_signatures.up.sql new file mode 100644 index 00000000..ed5292ff --- /dev/null +++ b/database/infra/migrations/000003_api_keys_signatures.up.sql @@ -0,0 +1,25 @@ +CREATE TABLE api_key_signatures ( + id BIGSERIAL PRIMARY KEY, + uuid UUID UNIQUE NOT NULL, + api_key_id BIGINT NOT NULL, + signature BYTEA NOT NULL, + max_tries SMALLINT NOT NULL DEFAULT 1 CHECK (max_tries > 0), + current_tries SMALLINT NOT NULL DEFAULT 1 CHECK (current_tries > 0), + expires_at TIMESTAMP DEFAULT NULL, + expired_at TIMESTAMP DEFAULT NULL, + origin TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + deleted_at TIMESTAMP DEFAULT NULL, + + CONSTRAINT uq_api_key_signatures_signature UNIQUE (signature), + CONSTRAINT api_key_signatures_fk_api_key_id FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE CASCADE +); + +CREATE INDEX idx_api_key_signatures_api_key_id ON api_key_signatures(api_key_id); +CREATE INDEX idx_api_key_signatures_signature_created_at ON api_key_signatures(signature, created_at); +CREATE INDEX idx_api_key_signatures_origin ON api_key_signatures(origin); +CREATE INDEX idx_api_key_signatures_expires_at ON api_key_signatures(expires_at); +CREATE INDEX idx_api_key_signatures_expired_at ON api_key_signatures(expired_at); +CREATE INDEX idx_api_key_signatures_created_at ON api_key_signatures(created_at); +CREATE INDEX idx_api_key_signatures_deleted_at ON api_key_signatures(deleted_at); diff --git a/database/model.go b/database/model.go index 69c560c2..8945e16b 100644 --- a/database/model.go +++ b/database/model.go @@ -1,9 +1,10 @@ package database import ( - "gorm.io/gorm" "slices" "time" + + "gorm.io/gorm" ) const DriverName = "postgres" @@ -11,8 +12,8 @@ const DriverName = "postgres" var schemaTables = []string{ "users", "posts", "categories", "post_categories", "tags", "post_tags", - "post_views", "comments", - "likes", "newsletters", "api_keys", + "post_views", "comments", "likes", + "newsletters", "api_keys", "api_key_signatures", } func GetSchemaTables() []string { @@ -32,6 +33,25 @@ type APIKey struct { CreatedAt time.Time UpdatedAt time.Time DeletedAt gorm.DeletedAt `gorm:"index"` + + //Associations + APIKeySignature []APIKeySignatures `gorm:"foreignKey:APIKeyID"` +} + +type APIKeySignatures struct { + ID int64 `gorm:"primaryKey"` + UUID string `gorm:"type:uuid;unique;not null"` + APIKeyID int64 `gorm:"not null;index:idx_api_key_signatures_api_key_id"` + MaxTries int `gorm:"not null"` + CurrentTries int `gorm:"not null"` + APIKey APIKey `gorm:"foreignKey:APIKeyID;references:ID;constraint:OnDelete:CASCADE"` + Signature []byte `gorm:"not null;uniqueIndex:uq_api_key_signatures_signature"` + Origin string `gorm:"type:varchar(255);not null;index:idx_api_key_signatures_origin"` + ExpiresAt time.Time `gorm:"index:idx_api_key_signatures_expires_at"` + ExpiredAt *time.Time `gorm:"index:idx_api_key_signatures_expired_at"` + CreatedAt time.Time `gorm:"index:idx_api_key_signatures_created_at"` + UpdatedAt time.Time `gorm:"index:idx_api_key_signatures_updated_at"` + DeletedAt gorm.DeletedAt `gorm:"index:idx_api_key_signatures_deleted_at"` } type User struct { diff --git a/database/repository/api_keys.go b/database/repository/api_keys.go index 746a52b4..77e30753 100644 --- a/database/repository/api_keys.go +++ b/database/repository/api_keys.go @@ -2,10 +2,15 @@ package repository import ( "fmt" + "strings" + "time" + "github.com/google/uuid" "github.com/oullin/database" + "github.com/oullin/database/repository/repoentity" "github.com/oullin/pkg/gorm" - "strings" + "github.com/oullin/pkg/portal" + baseGorm "gorm.io/gorm" ) type ApiKeys struct { @@ -49,3 +54,137 @@ func (a ApiKeys) FindBy(accountName string) *database.APIKey { return nil } + +func (a ApiKeys) CreateSignatureFor(entity repoentity.APIKeyCreateSignatureFor) (*database.APIKeySignatures, error) { + var item *database.APIKeySignatures + + if item = a.FindActiveSignatureFor(entity.Key, entity.Origin); item != nil { + item.CurrentTries++ + a.DB.Sql().Save(&item) + + return item, nil + } + + now := time.Now() + signature := database.APIKeySignatures{ + CreatedAt: now, + UpdatedAt: now, + Signature: entity.Seed, + APIKeyID: entity.Key.ID, + ExpiresAt: entity.ExpiresAt, + UUID: uuid.NewString(), + MaxTries: portal.MaxSignaturesTries, + Origin: entity.Origin, + CurrentTries: 1, + } + + err := a.DB.Sql().Transaction(func(tx *baseGorm.DB) error { + username := entity.Key.AccountName + if result := a.DB.Sql().Create(&signature); gorm.HasDbIssues(result.Error) { + return fmt.Errorf("issue creating the given api keys signature [%s, %s]: ", username, result.Error) + } + + if result := a.DisablePreviousSignatures(entity.Key, signature.UUID, entity.Origin); result != nil { + return fmt.Errorf("issue disabling previous api keys signature [%s, %s]: ", username, result) + } + + return nil + }) + + if err != nil { + return nil, err + } + + return &signature, nil +} + +func (a ApiKeys) FindActiveSignatureFor(key *database.APIKey, origin string) *database.APIKeySignatures { + var item database.APIKeySignatures + + result := a.DB.Sql(). + Model(&database.APIKeySignatures{}). + Where("expired_at IS NULL"). + Where("api_key_id = ?", key.ID). + Where("origin = ?", origin). + Where("current_tries <= max_tries"). + Where("expires_at > ?", time.Now()). + First(&item) + + if gorm.HasDbIssues(result.Error) { + return nil + } + + if result.RowsAffected > 0 { + return &item + } + + return nil +} + +func (a ApiKeys) FindSignatureFrom(entity repoentity.FindSignatureFrom) *database.APIKeySignatures { + var item database.APIKeySignatures + + result := a.DB.Sql(). + Model(&database.APIKeySignatures{}). + Where("api_key_id = ?", entity.Key.ID). + Where("signature = ?", entity.Signature). + Where("expires_at >= ? ", entity.ServerTime). + Where("origin = ?", entity.Origin). + Where("expired_at IS NULL"). + Where("current_tries <= max_tries"). + First(&item) + + if gorm.HasDbIssues(result.Error) { + return nil + } + + if result.RowsAffected > 0 { + return &item + } + + return nil +} + +func (a ApiKeys) DisablePreviousSignatures(key *database.APIKey, signatureUUID, origin string) error { + query := a.DB.Sql(). + Model(&database.APIKeySignatures{}). + Where( + a.DB.Sql(). + Where("expired_at IS NULL").Or("current_tries > max_tries"), + ). + Where("api_key_id = ?", key.ID). + Where( + a.DB.Sql(). + Where("origin = ?", origin). + Or("TRIM(origin) = ''"), + ). + Where("uuid NOT IN (?)", []string{signatureUUID}). + Update("expired_at", time.Now()) + + if gorm.HasDbIssues(query.Error) { + return query.Error + } + + return nil +} + +func (a ApiKeys) IncreaseSignatureTries(signatureUUID string, currentTries int) error { + if currentTries >= portal.MaxSignaturesTries { + return nil + } + + response := a.DB.Sql(). + Model(&database.APIKeySignatures{}). + Where("uuid = ? AND current_tries < max_tries", signatureUUID). + UpdateColumn("current_tries", baseGorm.Expr("current_tries + 1")) + + if gorm.HasDbIssues(response.Error) { + return response.Error + } + + if response.RowsAffected == 0 { + return nil + } + + return nil +} diff --git a/database/repository/posts.go b/database/repository/posts.go index 0c6e5220..0b370dca 100644 --- a/database/repository/posts.go +++ b/database/repository/posts.go @@ -2,6 +2,7 @@ package repository import ( "fmt" + "github.com/google/uuid" "github.com/oullin/database" "github.com/oullin/database/repository/pagination" diff --git a/database/repository/repoentity/api_keys.go b/database/repository/repoentity/api_keys.go new file mode 100644 index 00000000..f4651655 --- /dev/null +++ b/database/repository/repoentity/api_keys.go @@ -0,0 +1,21 @@ +package repoentity + +import ( + "time" + + "github.com/oullin/database" +) + +type APIKeyCreateSignatureFor struct { + Key *database.APIKey + ExpiresAt time.Time + Seed []byte + Origin string +} + +type FindSignatureFrom struct { + Key *database.APIKey + Signature []byte + Origin string + ServerTime time.Time +} diff --git a/handler/categories.go b/handler/categories.go index 1ef2c451..4f7fbb2f 100644 --- a/handler/categories.go +++ b/handler/categories.go @@ -2,14 +2,15 @@ package handler import ( "encoding/json" + "log/slog" + baseHttp "net/http" + "github.com/oullin/database" "github.com/oullin/database/repository" "github.com/oullin/database/repository/pagination" "github.com/oullin/handler/paginate" "github.com/oullin/handler/payload" "github.com/oullin/pkg/http" - "log/slog" - baseHttp "net/http" ) type CategoriesHandler struct { diff --git a/handler/payload/signatures.go b/handler/payload/signatures.go new file mode 100644 index 00000000..d6fb7877 --- /dev/null +++ b/handler/payload/signatures.go @@ -0,0 +1,21 @@ +package payload + +type SignatureRequest struct { + Nonce string `json:"nonce" validate:"required,lowercase,hexadecimal,len=32"` + PublicKey string `json:"public_key" validate:"required,lowercase,min=64,max=67"` + Username string `json:"username" validate:"required,lowercase,min=5"` + Timestamp int64 `json:"timestamp" validate:"required,number,gte=1000000000,min=10"` + Origin string `json:"origin"` +} + +type SignatureResponse struct { + Signature string `json:"signature"` + MaxTries int `json:"max_tries"` + Cadence SignatureCadenceResponse `json:"cadence"` +} + +type SignatureCadenceResponse struct { + ReceivedAt string `json:"received_at"` + CreatedAt string `json:"created_at"` + ExpiresAt string `json:"expires_at"` +} diff --git a/handler/signatures.go b/handler/signatures.go new file mode 100644 index 00000000..6afee908 --- /dev/null +++ b/handler/signatures.go @@ -0,0 +1,131 @@ +package handler + +import ( + "encoding/json" + "fmt" + "io" + "log/slog" + baseHttp "net/http" + "time" + + "github.com/oullin/database" + "github.com/oullin/database/repository" + "github.com/oullin/database/repository/repoentity" + "github.com/oullin/handler/payload" + "github.com/oullin/pkg/auth" + "github.com/oullin/pkg/http" + "github.com/oullin/pkg/portal" +) + +type SignaturesHandler struct { + Validator *portal.Validator + ApiKeys *repository.ApiKeys +} + +func MakeSignaturesHandler(validator *portal.Validator, ApiKeys *repository.ApiKeys) SignaturesHandler { + return SignaturesHandler{ + Validator: validator, + ApiKeys: ApiKeys, + } +} + +func (s *SignaturesHandler) Generate(w baseHttp.ResponseWriter, r *baseHttp.Request) *http.ApiError { + defer portal.CloseWithLog(r.Body) + + var err error + var bodyBytes []byte + + bodyBytes, err = io.ReadAll(r.Body) + + if err != nil { + return http.LogBadRequestError("could not read signatures request body", err) + } + + var req payload.SignatureRequest + if err = json.Unmarshal(bodyBytes, &req); err != nil { + return http.LogBadRequestError("could not parse the given data.", err) + } + + if _, err = s.Validator.Rejects(req); err != nil { + return http.UnprocessableEntity("The given fields are invalid", s.Validator.GetErrors()) + } + + serverTime := time.Now() + receivedAt := time.Unix(req.Timestamp, 0) + req.Origin = r.Header.Get(portal.IntendedOriginHeader) + + if err = s.isRequestWithinTimeframe(serverTime, receivedAt); err != nil { + return http.LogBadRequestError(err.Error(), err) + } + + var keySignature *database.APIKeySignatures + if keySignature, err = s.CreateSignature(req, serverTime); err != nil { + return http.LogInternalError(err.Error(), err) + } + + response := payload.SignatureResponse{ + Signature: auth.SignatureToString(keySignature.Signature), + MaxTries: keySignature.MaxTries, + Cadence: payload.SignatureCadenceResponse{ + ReceivedAt: receivedAt.Format(portal.DatesLayout), + CreatedAt: keySignature.CreatedAt.Format(portal.DatesLayout), + ExpiresAt: keySignature.ExpiresAt.Format(portal.DatesLayout), + }, + } + + resp := http.MakeResponseFrom("0.0.1", w, r) + + if err = resp.RespondOk(response); err != nil { + slog.Error("Error marshaling JSON for signatures response", "error", err) + return nil + } + + return nil // A nil return indicates success. +} + +func (s *SignaturesHandler) isRequestWithinTimeframe(serverTime, receivedAt time.Time) error { + skew := 5 * time.Second + + earliestValidTime := serverTime.Add(-skew) + if receivedAt.Before(earliestValidTime) { + return fmt.Errorf("the request timestamp [%s] is too old", receivedAt.Format(portal.DatesLayout)) + } + + latestValidTime := serverTime.Add(skew) + if receivedAt.After(latestValidTime) { + return fmt.Errorf("the request timestamp [%s] is from the future", receivedAt.Format(portal.DatesLayout)) + } + + return nil +} + +func (s *SignaturesHandler) CreateSignature(request payload.SignatureRequest, serverTime time.Time) (*database.APIKeySignatures, error) { + var err error + var token *database.APIKey + var keySignature *database.APIKeySignatures + + if token = s.ApiKeys.FindBy(request.Username); token == nil { + return nil, fmt.Errorf("the given username [%s] was not found", request.Username) + } + + var seed []byte + if seed, err = auth.GenerateAESKey(); err != nil { + return nil, fmt.Errorf("unable to generate the signature seed. Please try again") + } + + expiresAt := serverTime.Add(time.Second * 30) + hash := auth.CreateSignature(seed, token.SecretKey) + + entity := repoentity.APIKeyCreateSignatureFor{ + Key: token, + ExpiresAt: expiresAt, + Seed: hash, + Origin: request.Origin, + } + + if keySignature, err = s.ApiKeys.CreateSignatureFor(entity); err != nil { + return nil, fmt.Errorf("unable to create the signature item. Please try again") + } + + return keySignature, nil +} diff --git a/main.go b/main.go index 08a5ee3d..19e9178f 100644 --- a/main.go +++ b/main.go @@ -46,16 +46,32 @@ func main() { } func serverHandler() baseHttp.Handler { - if app.IsProduction() { // CORS is handled by Caddy. + if app.IsProduction() { // Caddy handles CORS. return app.GetMux() } localhost := app.GetEnv().Network.GetHostURL() + headers := []string{ + "Accept", + "Authorization", + "Content-Type", + "X-CSRF-Token", + "User-Agent", + "X-API-Key", + "X-API-Username", + "X-API-Signature", + "X-API-Timestamp", + "X-API-Nonce", + "X-Request-ID", + "If-None-Match", + "X-API-Intended-Origin", //new + } + c := cors.New(cors.Options{ AllowedOrigins: []string{localhost, "http://localhost:5173"}, AllowedMethods: []string{baseHttp.MethodGet, baseHttp.MethodPost, baseHttp.MethodPut, baseHttp.MethodDelete, baseHttp.MethodOptions}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token", "User-Agent", "X-API-Key", "X-API-Username", "X-API-Signature", "X-API-Timestamp", "X-API-Nonce", "X-Request-ID", "If-None-Match"}, + AllowedHeaders: headers, AllowCredentials: true, Debug: true, }) diff --git a/metal/cli/accounts/factory.go b/metal/cli/accounts/factory.go index e340b637..ee7a9352 100644 --- a/metal/cli/accounts/factory.go +++ b/metal/cli/accounts/factory.go @@ -2,6 +2,7 @@ package accounts import ( "fmt" + "github.com/oullin/database" "github.com/oullin/database/repository" "github.com/oullin/metal/env" diff --git a/metal/cli/accounts/factory_test.go b/metal/cli/accounts/factory_test.go index 8bf58a68..c91da457 100644 --- a/metal/cli/accounts/factory_test.go +++ b/metal/cli/accounts/factory_test.go @@ -1,10 +1,10 @@ package accounts import ( - "github.com/oullin/metal/cli/clitest" "testing" "github.com/oullin/database" + "github.com/oullin/metal/cli/clitest" ) func TestMakeHandler(t *testing.T) { diff --git a/metal/cli/accounts/handler.go b/metal/cli/accounts/handler.go index cdd2e4f3..fa345d01 100644 --- a/metal/cli/accounts/handler.go +++ b/metal/cli/accounts/handler.go @@ -2,6 +2,7 @@ package accounts import ( "fmt" + "github.com/oullin/database" "github.com/oullin/pkg/auth" "github.com/oullin/pkg/cli" diff --git a/metal/cli/main.go b/metal/cli/main.go index a8cbbd84..c841cef2 100644 --- a/metal/cli/main.go +++ b/metal/cli/main.go @@ -1,15 +1,12 @@ package main import ( - "fmt" - "github.com/oullin/database" "github.com/oullin/metal/cli/accounts" "github.com/oullin/metal/cli/panel" "github.com/oullin/metal/cli/posts" "github.com/oullin/metal/env" "github.com/oullin/metal/kernel" - "github.com/oullin/pkg/auth" "github.com/oullin/pkg/cli" "github.com/oullin/pkg/portal" ) @@ -65,13 +62,6 @@ func main() { continue } - return - case 5: - if err = generateAppEncryptionKey(); err != nil { - cli.Errorln(err.Error()) - continue - } - return case 0: cli.Successln("Goodbye!") @@ -162,22 +152,3 @@ func generateApiAccountsHTTPSignature(menu panel.Menu) error { return nil } - -func generateAppEncryptionKey() error { - var err error - var key []byte - - if key, err = auth.GenerateAESKey(); err != nil { - return err - } - - decoded := fmt.Sprintf("%x", key) - - cli.Successln("\n The key was generated successfully.") - cli.Magentaln(fmt.Sprintf(" > Full key: %s", decoded)) - cli.Cyanln(fmt.Sprintf(" > First half : %s", decoded[:32])) - cli.Cyanln(fmt.Sprintf(" > Second half: %s", decoded[32:])) - fmt.Println(" ") - - return nil -} diff --git a/metal/cli/panel/menu.go b/metal/cli/panel/menu.go index 61302e9e..074c8d7f 100644 --- a/metal/cli/panel/menu.go +++ b/metal/cli/panel/menu.go @@ -89,8 +89,7 @@ func (p *Menu) Print() { p.PrintOption("1) Parse Blog Posts.", inner) p.PrintOption("2) Create new API account.", inner) p.PrintOption("3) Show API accounts.", inner) - p.PrintOption("4) Generate API accounts HTTP signature.", inner) - p.PrintOption("5) Generate app encryption key.", inner) + p.PrintOption("4) Generate API account HTTP key pair.", inner) p.PrintOption(" ", inner) p.PrintOption("0) Exit.", inner) diff --git a/metal/kernel/app.go b/metal/kernel/app.go index 679c1c72..b9627343 100644 --- a/metal/kernel/app.go +++ b/metal/kernel/app.go @@ -42,9 +42,10 @@ func MakeApp(env *env.Environment, validator *portal.Validator) (*App, error) { } router := Router{ - Env: env, - Db: db, - Mux: baseHttp.NewServeMux(), + Env: env, + Db: db, + Mux: baseHttp.NewServeMux(), + validator: validator, Pipeline: middleware.Pipeline{ Env: env, ApiKeys: &repository.ApiKeys{DB: db}, @@ -73,4 +74,5 @@ func (a *App) Boot() { router.Recommendations() router.Posts() router.Categories() + router.Signature() } diff --git a/metal/kernel/router.go b/metal/kernel/router.go index ca13d6b9..e8014b5f 100644 --- a/metal/kernel/router.go +++ b/metal/kernel/router.go @@ -9,6 +9,7 @@ import ( "github.com/oullin/metal/env" "github.com/oullin/pkg/http" "github.com/oullin/pkg/middleware" + "github.com/oullin/pkg/portal" ) type StaticRouteResource interface { @@ -22,10 +23,19 @@ func addStaticRoute[H StaticRouteResource](r *Router, path, file string, maker f } type Router struct { - Env *env.Environment - Mux *baseHttp.ServeMux - Pipeline middleware.Pipeline - Db *database.Connection + Env *env.Environment + Mux *baseHttp.ServeMux + Pipeline middleware.Pipeline + Db *database.Connection + validator *portal.Validator +} + +func (r *Router) PublicPipelineFor(apiHandler http.ApiHandler) baseHttp.HandlerFunc { + return http.MakeApiHandler( + r.Pipeline.Chain( + apiHandler, + ), + ) } func (r *Router) PipelineFor(apiHandler http.ApiHandler) baseHttp.HandlerFunc { @@ -62,6 +72,13 @@ func (r *Router) Categories() { r.Mux.HandleFunc("GET /categories", index) } +func (r *Router) Signature() { + abstract := handler.MakeSignaturesHandler(r.validator, r.Pipeline.ApiKeys) + generate := r.PublicPipelineFor(abstract.Generate) + + r.Mux.HandleFunc("POST /generate-signature", generate) +} + func (r *Router) Profile() { addStaticRoute(r, "/profile", "./storage/fixture/profile.json", handler.MakeProfileHandler) } diff --git a/pkg/auth/encryption.go b/pkg/auth/encryption.go index 4210f9b9..b55c94de 100644 --- a/pkg/auth/encryption.go +++ b/pkg/auth/encryption.go @@ -16,7 +16,7 @@ func GenerateAESKey() ([]byte, error) { key := make([]byte, EncryptionKeyLength) if _, err := rand.Read(key); err != nil { - return nil, fmt.Errorf("failed to generate random key: %w", err) + return []byte(""), err } return key, nil diff --git a/pkg/auth/signature.go b/pkg/auth/signature.go new file mode 100644 index 00000000..b122d042 --- /dev/null +++ b/pkg/auth/signature.go @@ -0,0 +1,26 @@ +package auth + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" +) + +func CreateSignature(message, secretKey []byte) []byte { + mac := hmac.New(sha256.New, secretKey) + mac.Write(message) + + return mac.Sum(nil) +} + +func SignatureToString(signature []byte) string { + return hex.EncodeToString(signature) +} + +func VerifySignature(message, secretKey, signature []byte) bool { + mac := hmac.New(sha256.New, secretKey) + _, _ = mac.Write(message) + expected := mac.Sum(nil) + + return hmac.Equal(expected, signature) +} diff --git a/pkg/cache/ttl_cache_useonce_test.go b/pkg/cache/ttl_cache_useonce_test.go index 91b53452..d66d2c4b 100644 --- a/pkg/cache/ttl_cache_useonce_test.go +++ b/pkg/cache/ttl_cache_useonce_test.go @@ -8,10 +8,10 @@ import ( // TestTTLCache_UseOnce verifies the behavior of UseOnce for first use, // repeated use before expiry and reuse after the TTL has elapsed. func TestTTLCache_UseOnce(t *testing.T) { - t.Parallel() - c := NewTTLCache() - key := "nonce" - ttl := 100 * time.Millisecond + t.Parallel() + c := NewTTLCache() + key := "nonce" + ttl := 100 * time.Millisecond t.Run("first use", func(t *testing.T) { if used := c.UseOnce(key, ttl); used { @@ -25,12 +25,12 @@ func TestTTLCache_UseOnce(t *testing.T) { } }) - t.Run("use after expiry", func(t *testing.T) { - time.Sleep(ttl + 50*time.Millisecond) - if used := c.UseOnce(key, ttl); used { - t.Fatalf("expected UseOnce to return false for an expired key") - } - }) + t.Run("use after expiry", func(t *testing.T) { + time.Sleep(ttl + 50*time.Millisecond) + if used := c.UseOnce(key, ttl); used { + t.Fatalf("expected UseOnce to return false for an expired key") + } + }) } // TestTTLCache_Mark_PrunesExpiredEntries ensures that calling Mark prunes diff --git a/pkg/gorm/support.go b/pkg/gorm/support.go index 4a4f98df..20ed6e5c 100644 --- a/pkg/gorm/support.go +++ b/pkg/gorm/support.go @@ -2,6 +2,7 @@ package gorm import ( "errors" + "gorm.io/gorm" ) diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 78951482..9d5a89fa 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -17,6 +17,7 @@ func MakeApiHandler(fn ApiHandler) baseHttp.HandlerFunc { resp := ErrorResponse{ Error: err.Message, Status: err.Status, + Data: err.Data, } if result := json.NewEncoder(w).Encode(resp); result != nil { diff --git a/pkg/http/response.go b/pkg/http/response.go index d589a94f..7fef9eea 100644 --- a/pkg/http/response.go +++ b/pkg/http/response.go @@ -3,6 +3,7 @@ package http import ( "encoding/json" "fmt" + "log/slog" baseHttp "net/http" "strings" ) @@ -72,6 +73,15 @@ func InternalError(msg string) *ApiError { } } +func LogInternalError(msg string, err error) *ApiError { + slog.Error(err.Error(), "error", err) + + return &ApiError{ + Message: fmt.Sprintf("Internal server error: %s", msg), + Status: baseHttp.StatusInternalServerError, + } +} + func BadRequestError(msg string) *ApiError { return &ApiError{ Message: fmt.Sprintf("Bad request error: %s", msg), @@ -79,6 +89,23 @@ func BadRequestError(msg string) *ApiError { } } +func LogBadRequestError(msg string, err error) *ApiError { + slog.Error(err.Error(), "error", err) + + return &ApiError{ + Message: fmt.Sprintf("Bad request error: %s", msg), + Status: baseHttp.StatusBadRequest, + } +} + +func UnprocessableEntity(msg string, errors map[string]any) *ApiError { + return &ApiError{ + Message: fmt.Sprintf("Unprocessable entity: %s", msg), + Status: baseHttp.StatusUnprocessableEntity, + Data: errors, + } +} + func NotFound(msg string) *ApiError { return &ApiError{ Message: fmt.Sprintf("Not found error: %s", msg), diff --git a/pkg/http/schema.go b/pkg/http/schema.go index 6f5329a9..08c879ef 100644 --- a/pkg/http/schema.go +++ b/pkg/http/schema.go @@ -3,13 +3,15 @@ package http import baseHttp "net/http" type ErrorResponse struct { - Error string `json:"error"` - Status int `json:"status"` + Error string `json:"error"` + Status int `json:"status"` + Data map[string]any `json:"data"` } type ApiError struct { - Message string `json:"message"` - Status int `json:"status"` + Message string `json:"message"` + Status int `json:"status"` + Data map[string]any `json:"data"` } func (e *ApiError) Error() string { diff --git a/pkg/middleware/headers.go b/pkg/middleware/headers.go new file mode 100644 index 00000000..c4dfecaf --- /dev/null +++ b/pkg/middleware/headers.go @@ -0,0 +1,12 @@ +package middleware + +type AuthTokenHeaders struct { + AccountName string + PublicKey string + Signature string + Timestamp string + Nonce string + ClientIP string + RequestID string + IntendedOriginURL string +} diff --git a/pkg/middleware/mwguards/mw_response_messages.go b/pkg/middleware/mwguards/mw_response_messages.go new file mode 100644 index 00000000..29e05fbe --- /dev/null +++ b/pkg/middleware/mwguards/mw_response_messages.go @@ -0,0 +1,118 @@ +package mwguards + +import ( + "log/slog" + baseHttp "net/http" + "strings" + + "github.com/oullin/pkg/http" +) + +func normaliseData(data ...map[string]any) map[string]any { + if data == nil || len(data) == 0 { + return map[string]any{} + } + + result := make(map[string]any, len(data)) + for _, d := range data { + for k, v := range d { + result[k] = v + } + } + + return result +} + +func normaliseMessages(message, logMessage string) (string, string) { + message = strings.TrimSpace(message) + + if strings.TrimSpace(logMessage) == "" { + logMessage = message + } + + return message, logMessage +} + +func InvalidRequestError(message, logMessage string, data ...map[string]any) *http.ApiError { + message, logMessage = normaliseMessages(message, logMessage) + + slog.Error(logMessage, "error") + + return &http.ApiError{ + Message: message, + Status: baseHttp.StatusUnauthorized, + Data: normaliseData(data...), + } +} + +func InvalidTokenFormatError(message, logMessage string, data ...map[string]any) *http.ApiError { + message, logMessage = normaliseMessages(message, logMessage) + + slog.Error(logMessage, "error") + + return &http.ApiError{ + Message: message, + Status: baseHttp.StatusUnauthorized, + Data: normaliseData(data...), + } +} + +func UnauthenticatedError(message, logMessage string, data ...map[string]any) *http.ApiError { + message, logMessage = normaliseMessages(message, logMessage) + + slog.Error(logMessage, "error") + + return &http.ApiError{ + Message: "2- Invalid credentials: " + logMessage, + Status: baseHttp.StatusUnauthorized, + Data: normaliseData(data...), + } +} + +func RateLimitedError(message, logMessage string, data ...map[string]any) *http.ApiError { + message, logMessage = normaliseMessages(message, logMessage) + + slog.Error(logMessage, "error") + + return &http.ApiError{ + Message: "Too many authentication attempts", + Status: baseHttp.StatusTooManyRequests, + Data: normaliseData(data...), + } +} + +func NotFound(message, logMessage string, data ...map[string]any) *http.ApiError { + message, logMessage = normaliseMessages(message, logMessage) + + slog.Error(logMessage, "error") + + return &http.ApiError{ + Message: message, + Status: baseHttp.StatusNotFound, + Data: normaliseData(data...), + } +} + +func TimestampTooOldError(message, logMessage string, data ...map[string]any) *http.ApiError { + message, logMessage = normaliseMessages(message, logMessage) + + slog.Error(logMessage, "error") + + return &http.ApiError{ + Message: "Request timestamp expired", + Status: baseHttp.StatusUnauthorized, + Data: normaliseData(data...), + } +} + +func TimestampTooNewError(message, logMessage string, data ...map[string]any) *http.ApiError { + message, logMessage = normaliseMessages(message, logMessage) + + slog.Error(logMessage, "error") + + return &http.ApiError{ + Message: "Request timestamp invalid", + Status: baseHttp.StatusUnauthorized, + Data: normaliseData(data...), + } +} diff --git a/pkg/middleware/mwguards/mw_token_guard.go b/pkg/middleware/mwguards/mw_token_guard.go new file mode 100644 index 00000000..c5bde87f --- /dev/null +++ b/pkg/middleware/mwguards/mw_token_guard.go @@ -0,0 +1,104 @@ +package mwguards + +import ( + "crypto/sha256" + "crypto/subtle" + "fmt" + + "github.com/oullin/database" + "github.com/oullin/database/repository" + "github.com/oullin/pkg/auth" +) + +type MWTokenGuard struct { + Error error + ApiKey *database.APIKey + TokenHandler *auth.TokenHandler + KeysRepository *repository.ApiKeys +} + +type MWTokenGuardData struct { + Username string + PublicKey string +} + +func NewMWTokenGuard(apiKeys *repository.ApiKeys, TokenHandler *auth.TokenHandler) MWTokenGuard { + return MWTokenGuard{ + KeysRepository: apiKeys, + TokenHandler: TokenHandler, + } +} + +func (g *MWTokenGuard) Rejects(data MWTokenGuardData) bool { + if g.HasInvalidDependencies() { + g.Error = fmt.Errorf("invalid mw-token guard dependencies") + + return true + } + + if err := g.AccountNotFound(data.Username); err != nil { + g.Error = err + + return true + } + + if g.HasInvalidFormat(data.PublicKey) { + return true + } + + return false +} + +func (g *MWTokenGuard) HasInvalidDependencies() bool { + return g == nil || g.KeysRepository == nil || g.TokenHandler == nil +} + +func (g *MWTokenGuard) AccountNotFound(username string) error { + var item *database.APIKey + + if item = g.KeysRepository.FindBy(username); item == nil { + return fmt.Errorf("account [%s] not found", username) + } + + g.ApiKey = item + + return nil +} + +func (g *MWTokenGuard) HasInvalidFormat(publicKey string) bool { + token, err := g.TokenHandler.DecodeTokensFor( + g.ApiKey.AccountName, + g.ApiKey.SecretKey, + g.ApiKey.PublicKey, + ) + + if err != nil { + g.Error = fmt.Errorf("unable to decode the given account [%s] keys", g.ApiKey.AccountName) + + return true + } + + pBytes := []byte(publicKey) + eBytes := []byte(token.PublicKey) + hP := sha256.Sum256(pBytes) + hE := sha256.Sum256(eBytes) + + if subtle.ConstantTimeCompare(hP[:], hE[:]) != 1 { + g.Error = fmt.Errorf("invalid provided public token: %s", auth.SafeDisplay(publicKey)) + + return true + } + + return false +} + +func (g *MWTokenGuard) GetError() error { + return g.Error +} + +func (receiver MWTokenGuardData) ToMap() map[string]any { + return map[string]any{ + "username": receiver.Username, + "public_key": receiver.PublicKey, + } +} diff --git a/pkg/middleware/pipeline.go b/pkg/middleware/pipeline.go index 81a9f5d4..8220f127 100644 --- a/pkg/middleware/pipeline.go +++ b/pkg/middleware/pipeline.go @@ -13,9 +13,6 @@ type Pipeline struct { TokenHandler *auth.TokenHandler } -// Chain applies a list of middleware handlers to a final ApiHandler. -// It builds the chain in reverse, so the first middleware -// in the list is the outermost one, executing first. func (m Pipeline) Chain(h http.ApiHandler, handlers ...http.Middleware) http.ApiHandler { for i := len(handlers) - 1; i >= 0; i-- { h = handlers[i](h) diff --git a/pkg/middleware/token_middleware.go b/pkg/middleware/token_middleware.go index 99c23575..58620fd8 100644 --- a/pkg/middleware/token_middleware.go +++ b/pkg/middleware/token_middleware.go @@ -1,155 +1,95 @@ package middleware import ( - "bytes" "context" - "crypto/sha256" - "crypto/subtle" - "io" - "log/slog" + "encoding/hex" + "fmt" baseHttp "net/http" "strings" "time" "github.com/oullin/database" "github.com/oullin/database/repository" + "github.com/oullin/database/repository/repoentity" "github.com/oullin/pkg/auth" "github.com/oullin/pkg/cache" "github.com/oullin/pkg/http" "github.com/oullin/pkg/limiter" + "github.com/oullin/pkg/middleware/mwguards" "github.com/oullin/pkg/portal" ) -const tokenHeader = "X-API-Key" -const usernameHeader = "X-API-Username" -const signatureHeader = "X-API-Signature" -const timestampHeader = "X-API-Timestamp" -const nonceHeader = "X-API-Nonce" -const requestIDHeader = "X-Request-ID" - -// Context keys for propagating auth info downstream -// Use unexported custom type to avoid collisions -type contextKey string - -const ( - authAccountNameKey contextKey = "auth.account_name" - requestIdKey contextKey = "request.id" -) - -// TokenCheckMiddleware authenticates signed API requests using account tokens. -// It validates required headers, enforces a timestamp skew window, prevents -// replay attacks via nonce tracking, compares tokens/signatures in constant time, -// and applies a basic failure-based rate limiter per client scope. -// -// Error handling: -// - Rate limiting errors return 429 Too Many Requests -// - Timestamp errors return 401 with specific messages for expired or future timestamps -// - Other authentication errors return 401 with generic messages type TokenCheckMiddleware struct { - // ApiKeys provides access to persisted API key records used to resolve - // account credentials (account name, public key, and secret key). - ApiKeys *repository.ApiKeys - - // TokenHandler performs encoding/decoding of tokens and signature creation/verification. - TokenHandler *auth.TokenHandler - - // nonceCache stores recently seen nonce's to prevent replaying the same request - // within the configured TTL window. - nonceCache *cache.TTLCache - - // rateLimiter throttles repeated authentication failures per "clientIP|account" scope. - rateLimiter *limiter.MemoryLimiter - - // clockSkew defines the allowed difference between client and server time when - // validating the request timestamp. - clockSkew time.Duration - - // now is an injectable time source for deterministic tests. If nil, time.Now is used. - now func() time.Time - - // disallowFuture, if true, rejects timestamps greater than the current server time, - // even if they are within the positive skew window. - disallowFuture bool - - // nonceTTL is how long nonce remains invalid after its first use (replay-protection window). - nonceTTL time.Duration - - // failWindow indicates the sliding time window used to evaluate authentication failures. - failWindow time.Duration - - // maxFailPerScope is the maximum number of failures allowed within failWindow for a given scope. maxFailPerScope int + disallowFuture bool + nonceTTL time.Duration + failWindow time.Duration + clockSkew time.Duration + nonceCache *cache.TTLCache + now func() time.Time + TokenHandler *auth.TokenHandler + ApiKeys *repository.ApiKeys + rateLimiter *limiter.MemoryLimiter } func MakeTokenMiddleware(tokenHandler *auth.TokenHandler, apiKeys *repository.ApiKeys) TokenCheckMiddleware { return TokenCheckMiddleware{ + maxFailPerScope: 10, + disallowFuture: true, ApiKeys: apiKeys, + now: time.Now, TokenHandler: tokenHandler, - nonceCache: cache.NewTTLCache(), - rateLimiter: limiter.NewMemoryLimiter(1*time.Minute, 10), clockSkew: 5 * time.Minute, - now: time.Now, - disallowFuture: true, - nonceTTL: 5 * time.Minute, failWindow: 1 * time.Minute, - maxFailPerScope: 10, + nonceTTL: 5 * time.Minute, + nonceCache: cache.NewTTLCache(), + rateLimiter: limiter.NewMemoryLimiter(1*time.Minute, 10), } } func (t TokenCheckMiddleware) Handle(next http.ApiHandler) http.ApiHandler { return func(w baseHttp.ResponseWriter, r *baseHttp.Request) *http.ApiError { - reqID := strings.TrimSpace(r.Header.Get(requestIDHeader)) - logger := slog.With("request_id", reqID, "path", r.URL.Path, "method", r.Method) + reqID := strings.TrimSpace(r.Header.Get(portal.RequestIDHeader)) - if reqID == "" || logger == nil { - return t.getInvalidRequestError() + if reqID == "" { + return mwguards.InvalidRequestError(fmt.Sprintf("Invalid request ID for URL [%s].", r.URL.Path), "") } - if depErr := t.guardDependencies(logger); depErr != nil { - return depErr + if err := t.GuardDependencies(); err != nil { + return err } - // Extract and validate required headers - accountName, publicToken, signature, ts, nonce, hdrErr := t.validateAndGetHeaders(r, logger) - if hdrErr != nil { - return hdrErr + headers, err := t.ValidateAndGetHeaders(r, reqID) + if err != nil { + return err } // Validate timestamp within allowed skew using ValidTimestamp helper - vt := NewValidTimestamp(ts, logger, t.now) + vt := NewValidTimestamp(headers.Timestamp, t.now) if tsErr := vt.Validate(t.clockSkew, t.disallowFuture); tsErr != nil { return tsErr } - // Read body and compute hash - bodyHash, bodyErr := t.readBodyHash(r, logger) - if bodyErr != nil { - return bodyErr + var apiKey *database.APIKey + if apiKey, err = t.HasInvalidFormat(headers); err != nil { + return err } - // Build a canonical request string - canonical := portal.BuildCanonical(r.Method, r.URL, accountName, publicToken, ts, nonce, bodyHash) - - clientIP := portal.ParseClientIP(r) - - if err := t.shallReject(logger, accountName, publicToken, signature, canonical, nonce, clientIP); err != nil { + if err = t.HasInvalidSignature(headers, apiKey); err != nil { return err } - // Update the request context - r = t.attachContext(r, accountName, reqID) - - logger.Info("authentication successful") + r = t.AttachContext(r, headers) return next(w, r) } } -func (t TokenCheckMiddleware) guardDependencies(logger *slog.Logger) *http.ApiError { +func (t TokenCheckMiddleware) GuardDependencies() *http.ApiError { missing := make([]string, 0, 4) if t.ApiKeys == nil { - missing = append(missing, "ApiKeys") + missing = append(missing, "KeysRepository") } if t.TokenHandler == nil { @@ -165,160 +105,132 @@ func (t TokenCheckMiddleware) guardDependencies(logger *slog.Logger) *http.ApiEr } if len(missing) > 0 { - logger.Error("token middleware missing dependencies", "missing", strings.Join(missing, ",")) - return t.getUnauthenticatedError() + return mwguards.UnauthenticatedError( + "token middleware missing dependencies", + "token middleware missing dependencies: "+strings.Join(missing, ",")+".", + map[string]any{ + "missing": missing, + }, + ) } return nil } -func (t TokenCheckMiddleware) validateAndGetHeaders(r *baseHttp.Request, logger *slog.Logger) (accountName, publicToken, signature, ts, nonce string, apiErr *http.ApiError) { - accountName = strings.TrimSpace(r.Header.Get(usernameHeader)) - publicToken = strings.TrimSpace(r.Header.Get(tokenHeader)) - signature = strings.TrimSpace(r.Header.Get(signatureHeader)) - ts = strings.TrimSpace(r.Header.Get(timestampHeader)) - nonce = strings.TrimSpace(r.Header.Get(nonceHeader)) +func (t TokenCheckMiddleware) ValidateAndGetHeaders(r *baseHttp.Request, requestId string) (AuthTokenHeaders, *http.ApiError) { + intendedOriginURL := strings.TrimSpace(r.Header.Get(portal.IntendedOriginHeader)) + accountName := strings.TrimSpace(r.Header.Get(portal.UsernameHeader)) + signature := strings.TrimSpace(r.Header.Get(portal.SignatureHeader)) + publicToken := strings.TrimSpace(r.Header.Get(portal.TokenHeader)) + ts := strings.TrimSpace(r.Header.Get(portal.TimestampHeader)) + nonce := strings.TrimSpace(r.Header.Get(portal.NonceHeader)) + ip := portal.ParseClientIP(r) - if accountName == "" || publicToken == "" || signature == "" || ts == "" || nonce == "" { - logger.Warn("missing authentication headers") - return "", "", "", "", "", t.getInvalidRequestError() + if accountName == "" || publicToken == "" || signature == "" || ts == "" || nonce == "" || ip == "" || intendedOriginURL == "" { + return AuthTokenHeaders{}, mwguards.InvalidRequestError( + "Invalid authentication headers / or missing headers", + "", + ) } if err := auth.ValidateTokenFormat(publicToken); err != nil { - logger.Warn("invalid token format") - return "", "", "", "", "", t.getInvalidTokenFormatError() - } - - return accountName, publicToken, signature, ts, nonce, nil + return AuthTokenHeaders{}, mwguards.InvalidTokenFormatError(err.Error(), "", map[string]any{}) + } + + return AuthTokenHeaders{ + Timestamp: ts, + ClientIP: ip, + Nonce: nonce, + Signature: signature, + RequestID: requestId, + AccountName: accountName, + PublicKey: publicToken, + IntendedOriginURL: intendedOriginURL, + }, nil } -func (t TokenCheckMiddleware) readBodyHash(r *baseHttp.Request, logger *slog.Logger) (string, *http.ApiError) { - if r.Body == nil { - return portal.Sha256Hex(nil), nil - } - - b, err := portal.ReadWithSizeLimit(r.Body) - if err != nil { - logger.Warn("unable to read body for signing") - return "", t.getInvalidRequestError() - } - - // restore for downstream handlers - r.Body = io.NopCloser(bytes.NewReader(b)) - - return portal.Sha256Hex(b), nil -} - -func (t TokenCheckMiddleware) attachContext(r *baseHttp.Request, accountName, reqID string) *baseHttp.Request { - ctx := context.WithValue(r.Context(), authAccountNameKey, accountName) - ctx = context.WithValue(r.Context(), requestIdKey, reqID) +func (t TokenCheckMiddleware) AttachContext(r *baseHttp.Request, headers AuthTokenHeaders) *baseHttp.Request { + ctx := context.WithValue(r.Context(), portal.AuthAccountNameKey, headers.AccountName) + ctx = context.WithValue(r.Context(), portal.RequestIDKey, headers.RequestID) return r.WithContext(ctx) } -func (t TokenCheckMiddleware) shallReject(logger *slog.Logger, accountName, publicToken, signature, canonical, nonce, clientIP string) *http.ApiError { - limiterKey := clientIP + "|" + strings.ToLower(accountName) +func (t TokenCheckMiddleware) HasInvalidFormat(headers AuthTokenHeaders) (*database.APIKey, *http.ApiError) { + limiterKey := headers.ClientIP + "|" + strings.ToLower(headers.AccountName) if t.rateLimiter.TooMany(limiterKey) { - logger.Warn("too many authentication failures", "ip", clientIP) - return t.getRateLimitedError() - } - - var item *database.APIKey - if item = t.ApiKeys.FindBy(accountName); item == nil { - t.rateLimiter.Fail(limiterKey) - logger.Warn("account not found") - return t.getUnauthenticatedError() + return nil, mwguards.RateLimitedError( + "Too many authentication attempts", + "Too many authentication attempts for key: "+limiterKey, + ) } - // Fetch account to understand its keys - token, err := t.TokenHandler.DecodeTokensFor( - item.AccountName, - item.SecretKey, - item.PublicKey, - ) + guard := mwguards.NewMWTokenGuard(t.ApiKeys, t.TokenHandler) - if err != nil { - t.rateLimiter.Fail(limiterKey) - logger.Error("failed to decode account keys", "account", item.AccountName, "error", err) - return t.getUnauthenticatedError() + rejectsRequest := mwguards.MWTokenGuardData{ + Username: headers.AccountName, + PublicKey: headers.PublicKey, } - // Constant-time compare (fixed-length by hashing) of provided public token vs stored one - pBytes := []byte(strings.TrimSpace(publicToken)) - eBytes := []byte(strings.TrimSpace(token.PublicKey)) - hP := sha256.Sum256(pBytes) - hE := sha256.Sum256(eBytes) - - if subtle.ConstantTimeCompare(hP[:], hE[:]) != 1 { + if guard.Rejects(rejectsRequest) { t.rateLimiter.Fail(limiterKey) - logger.Warn("public token mismatch", "account", item.AccountName) - return t.getUnauthenticatedError() - } - - // Compute local signature over canonical request and compare in constant time (hash to fixed-length first) - localSignature := auth.CreateSignatureFrom(canonical, token.PublicKey) //@todo Change! - hSig := sha256.Sum256([]byte(strings.TrimSpace(signature))) - hLocal := sha256.Sum256([]byte(localSignature)) - if subtle.ConstantTimeCompare(hSig[:], hLocal[:]) != 1 { - t.rateLimiter.Fail(limiterKey) - logger.Warn("signature mismatch", "account", item.AccountName) - return t.getUnauthenticatedError() + return nil, mwguards.UnauthenticatedError( + "Invalid public token", + guard.Error.Error(), + rejectsRequest.ToMap(), + ) } - // Nonce replay protection: atomically check-and-mark (UseOnce) if t.nonceCache != nil { - key := item.AccountName + "|" + nonce + key := strings.ToLower(headers.AccountName) + "|" + headers.Nonce if t.nonceCache.UseOnce(key, t.nonceTTL) { t.rateLimiter.Fail(limiterKey) - logger.Warn("replay detected: nonce already used", "account", item.AccountName) - return t.getUnauthenticatedError() + + return nil, mwguards.UnauthenticatedError( + "Invalid nonce", + "Invalid nonce using key: "+key, + map[string]any{"key": key, "limiter_key": limiterKey}, + ) } } - return nil + return guard.ApiKey, nil } -func (t TokenCheckMiddleware) getInvalidRequestError() *http.ApiError { - return &http.ApiError{ - Message: "Invalid authentication headers", - Status: baseHttp.StatusUnauthorized, - } -} +func (t TokenCheckMiddleware) HasInvalidSignature(headers AuthTokenHeaders, apiKey *database.APIKey) *http.ApiError { + var err error + var byteSignature []byte + limiterKey := headers.ClientIP + "|" + strings.ToLower(headers.AccountName) -func (t TokenCheckMiddleware) getInvalidTokenFormatError() *http.ApiError { - return &http.ApiError{ - Message: "Invalid credentials", - Status: baseHttp.StatusUnauthorized, - } -} + if byteSignature, err = hex.DecodeString(headers.Signature); err != nil { + t.rateLimiter.Fail(limiterKey) -func (t TokenCheckMiddleware) getUnauthenticatedError() *http.ApiError { - return &http.ApiError{ - Message: "Invalid credentials", - Status: baseHttp.StatusUnauthorized, + return mwguards.NotFound("error decoding signature string", "") } -} -func (t TokenCheckMiddleware) getRateLimitedError() *http.ApiError { - return &http.ApiError{ - Message: "Too many authentication attempts", - Status: baseHttp.StatusTooManyRequests, + entity := repoentity.FindSignatureFrom{ + Key: apiKey, + Signature: byteSignature, + Origin: headers.IntendedOriginURL, + ServerTime: time.Now(), } -} -func (t TokenCheckMiddleware) getTimestampTooOldError() *http.ApiError { - return &http.ApiError{ - Message: "Request timestamp expired", - Status: baseHttp.StatusUnauthorized, + signature := t.ApiKeys.FindSignatureFrom(entity) + + if signature == nil { + t.rateLimiter.Fail(limiterKey) + + return mwguards.NotFound("signature not found", "") } -} -func (t TokenCheckMiddleware) getTimestampTooNewError() *http.ApiError { - return &http.ApiError{ - Message: "Request timestamp invalid", - Status: baseHttp.StatusUnauthorized, + if err = t.ApiKeys.IncreaseSignatureTries(signature.UUID, signature.CurrentTries+1); err != nil { + t.rateLimiter.Fail(limiterKey) + + return mwguards.InvalidRequestError("could not increase signature tries", err.Error()) } + + return nil } diff --git a/pkg/middleware/token_middleware_additional_test.go b/pkg/middleware/token_middleware_additional_test.go index 19535a3c..c150668e 100644 --- a/pkg/middleware/token_middleware_additional_test.go +++ b/pkg/middleware/token_middleware_additional_test.go @@ -79,13 +79,13 @@ func makeRepo(t *testing.T, account string) (*repository.ApiKeys, *auth.TokenHan func TestTokenMiddlewareGuardDependencies(t *testing.T) { logger := slogNoop() tm := TokenCheckMiddleware{} - if err := tm.guardDependencies(logger); err == nil || err.Status != http.StatusUnauthorized { + if err := tm.GuardDependencies(logger); err == nil || err.Status != http.StatusUnauthorized { t.Fatalf("expected unauthorized when dependencies missing") } tm.ApiKeys, tm.TokenHandler, _ = makeRepo(t, "guard1") tm.nonceCache = cache.NewTTLCache() tm.rateLimiter = limiter.NewMemoryLimiter(time.Minute, 1) - if err := tm.guardDependencies(logger); err != nil { + if err := tm.GuardDependencies(logger); err != nil { t.Fatalf("expected no error when dependencies provided, got %#v", err) } } diff --git a/pkg/middleware/token_middleware_test.go b/pkg/middleware/token_middleware_test.go index 18b00b57..d4aa0943 100644 --- a/pkg/middleware/token_middleware_test.go +++ b/pkg/middleware/token_middleware_test.go @@ -95,7 +95,7 @@ func TestValidateAndGetHeaders_MissingAndInvalidFormat(t *testing.T) { logger := slogNoop() req := httptest.NewRequest("GET", "/", nil) // All empty - if _, _, _, _, _, apiErr := tm.validateAndGetHeaders(req, logger); apiErr == nil || apiErr.Status != http.StatusUnauthorized { + if _, _, _, _, _, apiErr := tm.ValidateAndGetHeaders(req, logger); apiErr == nil || apiErr.Status != http.StatusUnauthorized { t.Fatalf("expected error for missing headers") } @@ -105,7 +105,7 @@ func TestValidateAndGetHeaders_MissingAndInvalidFormat(t *testing.T) { req.Header.Set("X-API-Signature", "sig") req.Header.Set("X-API-Timestamp", "1700000000") req.Header.Set("X-API-Nonce", "n1") - if _, _, _, _, _, apiErr := tm.validateAndGetHeaders(req, logger); apiErr == nil || apiErr.Status != http.StatusUnauthorized { + if _, _, _, _, _, apiErr := tm.ValidateAndGetHeaders(req, logger); apiErr == nil || apiErr.Status != http.StatusUnauthorized { t.Fatalf("expected error for invalid token format") } } @@ -129,7 +129,7 @@ func TestReadBodyHash_RestoresBody(t *testing.T) { func TestAttachContext(t *testing.T) { tm := MakeTokenMiddleware(nil, nil) req := httptest.NewRequest("GET", "/", nil) - r := tm.attachContext(req, "Alice", "RID-123") + r := tm.AttachContext(req, "Alice", "RID-123") if r == req { t.Fatalf("expected a new request with updated context") } diff --git a/pkg/middleware/valid_timestamp.go b/pkg/middleware/valid_timestamp.go index 88cb2dbc..7877af7b 100644 --- a/pkg/middleware/valid_timestamp.go +++ b/pkg/middleware/valid_timestamp.go @@ -1,7 +1,6 @@ package middleware import ( - "log/slog" baseHttp "net/http" "strconv" "time" @@ -9,41 +8,25 @@ import ( "github.com/oullin/pkg/http" ) -// ValidTimestamp encapsulates timestamp validation context. -// It accepts: the raw timestamp string (ts), a logger, and a clock (now) function. -// Use Validate to check against a provided skew window and future-time policy. type ValidTimestamp struct { - // ts is the timestamp string (expected Unix epoch in seconds). - ts string - - // logger is used to record validation details. - logger *slog.Logger - - // now returns the current time; useful to inject a deterministic clock in tests. + ts string now func() time.Time } -func NewValidTimestamp(ts string, logger *slog.Logger, now func() time.Time) ValidTimestamp { +func NewValidTimestamp(ts string, now func() time.Time) ValidTimestamp { return ValidTimestamp{ - ts: ts, - logger: logger, - now: now, + ts: ts, + now: now, } } func (v ValidTimestamp) Validate(skew time.Duration, disallowFuture bool) *http.ApiError { - if v.logger == nil { - return &http.ApiError{Message: "Invalid timestamp headers tracker", Status: baseHttp.StatusUnauthorized} - } - if v.ts == "" { - v.logger.Warn("missing timestamp") return &http.ApiError{Message: "Invalid authentication headers", Status: baseHttp.StatusUnauthorized} } epoch, err := strconv.ParseInt(v.ts, 10, 64) if err != nil { - v.logger.Warn("invalid timestamp format") return &http.ApiError{Message: "Invalid authentication headers", Status: baseHttp.StatusUnauthorized} } @@ -66,12 +49,10 @@ func (v ValidTimestamp) Validate(skew time.Duration, disallowFuture bool) *http. } if epoch < minValue { - v.logger.Warn("timestamp outside allowed window: too old") return &http.ApiError{Message: "Request timestamp expired", Status: baseHttp.StatusUnauthorized} } if epoch > maxValue { - v.logger.Warn("timestamp outside allowed window: in the future") return &http.ApiError{Message: "Request timestamp invalid", Status: baseHttp.StatusUnauthorized} } diff --git a/pkg/portal/client_test.go b/pkg/portal/client_test.go index 64140a4e..0bc3b62b 100644 --- a/pkg/portal/client_test.go +++ b/pkg/portal/client_test.go @@ -12,7 +12,7 @@ func TestClientTransportAndGet(t *testing.T) { c := MakeDefaultClient(tr) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("hello")) + _, _ = w.Write([]byte("hello")) })) defer srv.Close() @@ -26,7 +26,7 @@ func TestClientTransportAndGet(t *testing.T) { func TestClientGetNil(t *testing.T) { var c *Client - _, err := c.Get(context.Background(), "http://example.com") + _, err := c.Get(context.Background(), "https://example.com") if err == nil { t.Fatalf("expected error") diff --git a/pkg/portal/consts.go b/pkg/portal/consts.go new file mode 100644 index 00000000..70bfe672 --- /dev/null +++ b/pkg/portal/consts.go @@ -0,0 +1,21 @@ +package portal + +const DatesLayout = "2006-01-02 15:04:05" +const MaxSignaturesTries = 10 + +// ---- Middleware / HTTP + +const TokenHeader = "X-API-Key" +const UsernameHeader = "X-API-Username" +const SignatureHeader = "X-API-Signature" +const TimestampHeader = "X-API-Timestamp" +const NonceHeader = "X-API-Nonce" +const RequestIDHeader = "X-Request-ID" +const IntendedOriginHeader = "X-API-Intended-Origin" + +// ---- Middleware / Context + +type contextKey string + +const AuthAccountNameKey contextKey = "auth.account_name" +const RequestIDKey contextKey = "request.id" diff --git a/pkg/portal/support.go b/pkg/portal/support.go index 44fbd374..99d49704 100644 --- a/pkg/portal/support.go +++ b/pkg/portal/support.go @@ -23,6 +23,24 @@ func CloseWithLog(c io.Closer) { } } +func GenerateURL(r *baseHttp.Request) string { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + if v := r.Header.Get("X-Forwarded-Proto"); v != "" { + scheme = v + } + + host := r.Host + if v := r.Header.Get("X-Forwarded-Host"); v != "" { + host = v + } + + return scheme + "://" + host + r.URL.RequestURI() +} + func Sha256Hex(b []byte) string { h := sha256.Sum256(b) return hex.EncodeToString(h[:])