Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
43 changes: 33 additions & 10 deletions backend/api/v1/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ import (
spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes"

"github.com/bytebase/bytebase/backend/api/auth"
"github.com/bytebase/bytebase/backend/common"
"github.com/bytebase/bytebase/backend/common/log"
"github.com/bytebase/bytebase/backend/component/config"
storepb "github.com/bytebase/bytebase/backend/generated-go/store"
v1pb "github.com/bytebase/bytebase/backend/generated-go/v1"
"github.com/bytebase/bytebase/backend/generated-go/v1/v1connect"
"github.com/bytebase/bytebase/backend/store"
)

Expand All @@ -30,15 +33,19 @@ var (
maskedString string
)

// ACLInterceptor is the v1 ACL interceptor for gRPC server.
// AuditInterceptor is the v1 audit interceptor for gRPC server.
type AuditInterceptor struct {
store *store.Store
store *store.Store
secret string
profile *config.Profile
}

// NewAuditInterceptor returns a new v1 API ACL interceptor.
func NewAuditInterceptor(store *store.Store) *AuditInterceptor {
// NewAuditInterceptor returns a new v1 API audit interceptor.
func NewAuditInterceptor(store *store.Store, secret string, profile *config.Profile) *AuditInterceptor {
return &AuditInterceptor{
store: store,
store: store,
secret: secret,
profile: profile,
}
}

Expand All @@ -59,7 +66,7 @@ func (in *AuditInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc
if !common.IsNil(response) {
respMsg = response.Any()
}
if err := createAuditLogConnect(ctx, req.Any(), respMsg, req.Spec().Procedure, in.store, serviceData, rerr, req.Header(), req.Peer().Addr, latency); err != nil {
if err := createAuditLogConnect(ctx, req.Any(), respMsg, req.Spec().Procedure, in.store, in.secret, in.profile, serviceData, rerr, req.Header(), req.Peer().Addr, latency); err != nil {
slog.Warn("audit interceptor: failed to create audit log", log.BBError(err), slog.String("method", req.Spec().Procedure))
}
}
Expand Down Expand Up @@ -120,14 +127,14 @@ func (c *auditConnectStreamingConn) Send(resp any) error {
// Create audit log for each message pair
if c.curRequest != nil {
latency := time.Since(c.startTime)
if auditErr := createAuditLogConnect(c.ctx, c.curRequest, resp, c.method, c.interceptor.store, nil, nil, c.RequestHeader(), c.Peer().Addr, latency); auditErr != nil {
if auditErr := createAuditLogConnect(c.ctx, c.curRequest, resp, c.method, c.interceptor.store, c.interceptor.secret, c.interceptor.profile, nil, nil, c.RequestHeader(), c.Peer().Addr, latency); auditErr != nil {
return auditErr
}
}
return nil
}

func createAuditLogConnect(ctx context.Context, request, response any, method string, storage *store.Store, serviceData *anypb.Any, rerr error, headers http.Header, peerAddr string, latency time.Duration) error {
func createAuditLogConnect(ctx context.Context, request, response any, method string, storage *store.Store, secret string, profile *config.Profile, serviceData *anypb.Any, rerr error, headers http.Header, peerAddr string, latency time.Duration) error {
requestString, err := getRequestString(request)
if err != nil {
return errors.Wrapf(err, "failed to get request string")
Expand All @@ -141,6 +148,7 @@ func createAuditLogConnect(ctx context.Context, request, response any, method st
if u, ok := GetUserFromContext(ctx); ok {
user = common.FormatUserUID(u.ID)
} else {
// Try to get user from successful login response.
if loginResponse, ok := response.(*v1pb.LoginResponse); ok {
user = loginResponse.GetUser().GetName()
}
Expand Down Expand Up @@ -169,10 +177,26 @@ func createAuditLogConnect(ctx context.Context, request, response any, method st

createAuditLogCtx := context.WithoutCancel(ctx)
for _, parent := range parents {
resource := getRequestResource(request)
// For login requests, if resource is empty, try to get email from user context or MFA temp token.
// This handles MFA phase where request doesn't have email field.
if resource == "" && method == v1connect.AuthServiceLoginProcedure {
if u, ok := GetUserFromContext(ctx); ok {
resource = u.Email
} else if loginRequest, ok := request.(*v1pb.LoginRequest); ok && loginRequest.MfaTempToken != nil {
// Extract user from MFA temp token and fetch email from database.
if userID, err := auth.GetUserIDFromMFATempToken(*loginRequest.MfaTempToken, profile.Mode, secret); err == nil {
if u, err := storage.GetUserByID(createAuditLogCtx, userID); err == nil && u != nil {
resource = u.Email
}
}
}
}

p := &storepb.AuditLog{
Parent: parent,
Method: method,
Resource: getRequestResource(request),
Resource: resource,
Severity: storepb.AuditLog_INFO,
User: user,
Request: requestString,
Expand Down Expand Up @@ -477,7 +501,6 @@ func redactQueryResponse(r *v1pb.QueryResponse) *v1pb.QueryResponse {
Latency: result.Latency,
Statement: result.Statement,
DetailedError: result.DetailedError,
AllowExport: result.AllowExport,
Masked: redactMaskingReasons(result.Masked), // Redact icon data
})
}
Expand Down
122 changes: 115 additions & 7 deletions backend/api/v1/auth_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/bytebase/bytebase/backend/api/auth"
"github.com/bytebase/bytebase/backend/common"
"github.com/bytebase/bytebase/backend/common/log"
"github.com/bytebase/bytebase/backend/common/qb"
"github.com/bytebase/bytebase/backend/component/config"
"github.com/bytebase/bytebase/backend/component/iam"
"github.com/bytebase/bytebase/backend/component/state"
Expand All @@ -35,8 +36,32 @@ import (
"github.com/bytebase/bytebase/backend/store"
)

const (
// mfaTempTokenDuration is the duration for MFA temporary tokens.
// Following industry standards (Okta: 5 minutes, Auth0: 10 minutes, AWS Cognito: 3 minutes).
// A short duration reduces the attack window for TOTP brute-force attempts.
mfaTempTokenDuration = 5 * time.Minute

// Login rate limiting configuration.
// Password phase: 10 failed attempts within 10 minutes.
passwordMaxFailedAttempts = 10 //nolint:unused // Will be used for password rate limiting
passwordLockoutWindow = 10 * time.Minute //nolint:unused // Will be used for password rate limiting

// MFA phase: 5 failed attempts within 5 minutes.
mfaMaxFailedAttempts = 5
mfaLockoutWindow = 5 * time.Minute

// Error messages for authentication failures.
// These constants are used both for error responses and for querying audit logs during rate limiting.
errMsgInvalidCredentials = "invalid email or password"
errMsgInvalidMFACode = "invalid MFA code"
errMsgInvalidRecoveryCode = "invalid recovery code"
errMsgTooManyPassword = "too many failed login attempts, please try again later" //nolint:unused // Will be used for password rate limiting
errMsgTooManyMFA = "too many failed MFA attempts, please try again later"
)

var (
invalidUserOrPasswordError = connect.NewError(connect.CodeUnauthenticated, errors.Errorf("the email or password is not valid"))
invalidCredentialsError = connect.NewError(connect.CodeUnauthenticated, errors.Errorf(errMsgInvalidCredentials))
)

// AuthService implements the auth service.
Expand Down Expand Up @@ -98,7 +123,12 @@ func (s *AuthService) Login(ctx context.Context, req *connect.Request[v1pb.Login
return nil, connect.NewError(connect.CodeInternal, errors.Wrapf(err, "failed to find user, error"))
}
if user == nil {
return nil, invalidUserOrPasswordError
return nil, invalidCredentialsError
}

// Check if user is locked out due to too many failed MFA attempts.
if err := s.checkMFALockout(ctx, user.Email); err != nil {
return nil, err
}

if request.OtpCode != nil {
Expand Down Expand Up @@ -145,7 +175,7 @@ func (s *AuthService) Login(ctx context.Context, req *connect.Request[v1pb.Login
userMFAEnabled := loginUser.MFAConfig != nil && loginUser.MFAConfig.OtpSecret != ""
// We only allow MFA login (2-step) when the feature is enabled and user has enabled MFA.
if s.licenseService.IsFeatureEnabled(v1pb.PlanFeature_FEATURE_TWO_FA) == nil && !mfaSecondLogin && userMFAEnabled {
mfaTempToken, err := auth.GenerateMFATempToken(loginUser.Name, loginUser.ID, s.profile.Mode, s.secret, tokenDuration)
mfaTempToken, err := auth.GenerateMFATempToken(loginUser.Name, loginUser.ID, s.profile.Mode, s.secret, mfaTempTokenDuration)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, errors.Errorf("failed to generate MFA temp token"))
}
Expand Down Expand Up @@ -269,17 +299,22 @@ func (s *AuthService) Logout(ctx context.Context, req *connect.Request[v1pb.Logo
}

func (s *AuthService) getAndVerifyUser(ctx context.Context, request *v1pb.LoginRequest) (*store.UserMessage, error) {
// Check if user is locked out due to too many failed password attempts.
if err := s.checkPasswordLockout(ctx, request.Email); err != nil {
return nil, err
}

user, err := s.store.GetUserByEmail(ctx, request.Email)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, errors.Wrapf(err, "failed to get user by email %q", request.Email))
}
if user == nil {
return nil, invalidUserOrPasswordError
return nil, invalidCredentialsError
}
// Compare the stored hashed password, with the hashed version of the password that was received.
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(request.Password)); err != nil {
// If the two passwords don't match, return a 401 status.
return nil, invalidUserOrPasswordError
return nil, invalidCredentialsError
}
return user, nil
}
Expand Down Expand Up @@ -471,9 +506,82 @@ func (s *AuthService) userCountGuard(ctx context.Context) error {
return nil
}

// countRecentLoginFailures counts the number of failed login attempts for a given email
// within the specified time window, matching any of the provided error messages.
func (s *AuthService) countRecentLoginFailures(ctx context.Context, email string, window time.Duration, errMessages ...string) (int, error) {
if len(errMessages) == 0 {
return 0, errors.New("at least one error message is required")
}

windowStart := time.Now().Add(-window)

// Build filter query for login failures.
filterQ := qb.Q().Space("TRUE")
filterQ.And("payload->>'method' = ?", "/bytebase.v1.AuthService/Login")
filterQ.And("payload->>'resource' = ?", email)
filterQ.And("(payload->'status'->>'code')::int != 0")

// Build OR condition for error messages.
if len(errMessages) == 1 {
filterQ.And("payload->'status'->>'message' = ?", errMessages[0])
} else {
// For multiple messages, build: (msg = ? OR msg = ? OR ...)
orConditions := qb.Q()
for i, msg := range errMessages {
if i == 0 {
orConditions.Space("payload->'status'->>'message' = ?", msg)
} else {
orConditions.Or("payload->'status'->>'message' = ?", msg)
}
}
filterQ.And("(?)", orConditions)
}

filterQ.And("created_at >= ?", windowStart)

logs, err := s.store.SearchAuditLogs(ctx, &store.AuditLogFind{
FilterQ: filterQ,
})
if err != nil {
return 0, errors.Wrapf(err, "failed to search audit logs for login failures")
}

return len(logs), nil
}

// checkPasswordLockout checks if the user has exceeded the password failure rate limit.
// Returns an error if the user is locked out due to too many failed password attempts.
func (s *AuthService) checkPasswordLockout(ctx context.Context, email string) error {
count, err := s.countRecentLoginFailures(ctx, email, passwordLockoutWindow, errMsgInvalidCredentials)
if err != nil {
return connect.NewError(connect.CodeInternal, errors.Wrapf(err, "failed to count recent password failures"))
}

if count >= passwordMaxFailedAttempts {
return connect.NewError(connect.CodeResourceExhausted, errors.Errorf(errMsgTooManyPassword))
}

return nil
}

// checkMFALockout checks if the user has exceeded the MFA failure rate limit.
// Returns an error if the user is locked out due to too many failed MFA attempts.
func (s *AuthService) checkMFALockout(ctx context.Context, email string) error {
count, err := s.countRecentLoginFailures(ctx, email, mfaLockoutWindow, errMsgInvalidMFACode, errMsgInvalidRecoveryCode)
if err != nil {
return connect.NewError(connect.CodeInternal, errors.Wrapf(err, "failed to count recent MFA failures"))
}

if count >= mfaMaxFailedAttempts {
return connect.NewError(connect.CodeResourceExhausted, errors.Errorf(errMsgTooManyMFA))
}

return nil
}

func challengeMFACode(user *store.UserMessage, mfaCode string) error {
if !validateWithCodeAndSecret(mfaCode, user.MFAConfig.OtpSecret) {
return connect.NewError(connect.CodeUnauthenticated, errors.Errorf("invalid MFA code"))
return connect.NewError(connect.CodeUnauthenticated, errors.Errorf(errMsgInvalidMFACode))
}
return nil
}
Expand All @@ -495,7 +603,7 @@ func (s *AuthService) challengeRecoveryCode(ctx context.Context, user *store.Use
return nil
}
}
return connect.NewError(connect.CodeUnauthenticated, errors.Errorf("invalid recovery code"))
return connect.NewError(connect.CodeUnauthenticated, errors.Errorf(errMsgInvalidRecoveryCode))
}

// validateWithCodeAndSecret validates the given code against the given secret.
Expand Down
12 changes: 1 addition & 11 deletions backend/api/v1/sql_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,6 @@ func (s *SQLService) AdminExecute(ctx context.Context, stream *connect.BidiStrea
}
} else {
response.Results = result
for _, result := range response.Results {
// The AdminExecute requires bb.sql.admin permission, so we can presume the users have enough permission to export.
result.AllowExport = true
}
}

if err := stream.Send(response); err != nil {
Expand Down Expand Up @@ -212,7 +208,7 @@ func (s *SQLService) Query(ctx context.Context, req *connect.Request[v1pb.QueryR
if request.Schema != nil {
queryContext.Schema = *request.Schema
}
results, spans, duration, queryErr := queryRetry(
results, _, duration, queryErr := queryRetry(
ctx,
s.store,
user,
Expand Down Expand Up @@ -261,12 +257,6 @@ func (s *SQLService) Query(ctx context.Context, req *connect.Request[v1pb.QueryR
return nil, connect.NewError(code, errors.New(queryErr.Error()))
}

for _, result := range results {
// AllowExport is a validate only check.
checkErr := s.accessCheck(ctx, instance, database, user, spans, request.Explain)
result.AllowExport = checkErr == nil
}

slog.Debug("request finished",
slog.Duration("duration", time.Since(startTime)),
slog.String("instance", instance.ResourceID),
Expand Down
16 changes: 3 additions & 13 deletions backend/generated-go/v1/sql_service.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions backend/generated-go/v1/sql_service_equal.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading