Skip to content
Open
7 changes: 6 additions & 1 deletion internal/controller/oauth_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,12 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
name = user.Name
} else {
controller.log.App.Debug().Msg("No name from OAuth provider, generating from email")
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
parts := strings.SplitN(user.Email, "@", 2)
if len(parts) == 2 {
name = fmt.Sprintf("%s (%s)", utils.Capitalize(parts[0]), parts[1])
} else {
name = utils.Capitalize(user.Email)
}
}

var username string
Expand Down
4 changes: 2 additions & 2 deletions internal/controller/oidc_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
client, ok := controller.oidc.GetClient(req.ClientID)

if !ok {
controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
controller.authorizeError(c, fmt.Errorf("client not found: %s", req.ClientID), "Client not found", "The client ID is invalid", "", "", "")
return
}

Expand Down Expand Up @@ -288,7 +288,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
if err != nil {
if err := controller.oidc.DeleteTokenByCodeHash(c, controller.oidc.Hash(req.Code)); err != nil {
controller.log.App.Error().Err(err).Msg("Failed to delete code")
controller.log.App.Error().Err(err).Msg("Failed to revoke tokens for replayed code")
}
if errors.Is(err, service.ErrCodeNotFound) {
controller.log.App.Warn().Msg("Code not found")
Expand Down
3 changes: 1 addition & 2 deletions internal/controller/well_known_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,7 @@ func (controller *WellKnownController) OpenIDConnectConfiguration(c *gin.Context
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
RequestParameterSupported: true,
RequestObjectSigningAlgValuesSupported: []string{"none"},
RequestParameterSupported: false,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah #864 is a false positive, the frontend does the JWT parsing so it works just fine.

})
}

Expand Down
3 changes: 1 addition & 2 deletions internal/controller/well_known_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ func TestWellKnownController(t *testing.T) {
TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"},
ClaimsSupported: []string{"sub", "updated_at", "name", "preferred_username", "email", "email_verified", "groups", "phone_number", "phone_number_verified", "address", "given_name", "family_name", "middle_name", "nickname", "profile", "picture", "website", "gender", "birthdate", "zoneinfo", "locale"},
ServiceDocumentation: "https://tinyauth.app/docs/guides/oidc",
RequestParameterSupported: true,
RequestObjectSigningAlgValuesSupported: []string{"none"},
Comment on lines -60 to -61
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RequestParameterSupported: false,
}

assert.Equal(t, expected, res)
Expand Down
61 changes: 33 additions & 28 deletions internal/service/auth_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -773,46 +773,49 @@ func (auth *AuthService) ensureOAuthSessionLimit() {
auth.oauthMutex.Lock()
defer auth.oauthMutex.Unlock()

if len(auth.oauthPendingSessions) >= MaxOAuthPendingSessions {

cleanupIds := make([]string, 0, OAuthCleanupCount)
if len(auth.oauthPendingSessions) <= MaxOAuthPendingSessions {
return
}

for range OAuthCleanupCount {
oldestId := ""
oldestTime := int64(0)
type entry struct {
id string
expiresAt int64
}

for id, session := range auth.oauthPendingSessions {
if oldestTime == 0 {
oldestId = id
oldestTime = session.ExpiresAt.Unix()
continue
}
if slices.Contains(cleanupIds, id) {
continue
}
if session.ExpiresAt.Unix() < oldestTime {
oldestId = id
oldestTime = session.ExpiresAt.Unix()
}
}
entries := make([]entry, 0, len(auth.oauthPendingSessions))
for id, session := range auth.oauthPendingSessions {
entries = append(entries, entry{id, session.ExpiresAt.Unix()})
}

cleanupIds = append(cleanupIds, oldestId)
slices.SortFunc(entries, func(a, b entry) int {
if a.expiresAt < b.expiresAt {
return -1
}

for _, id := range cleanupIds {
delete(auth.oauthPendingSessions, id)
if a.expiresAt > b.expiresAt {
return 1
}
return 0
})

for _, e := range entries[:OAuthCleanupCount] {
delete(auth.oauthPendingSessions, e.id)
}
}

func (auth *AuthService) lockdownMode() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
auth.lockdownCtx = ctx
auth.lockdownCancelFunc = cancel

auth.loginMutex.Lock()

if auth.lockdown != nil && auth.lockdown.Active {
auth.loginMutex.Unlock()
cancel()
return
}

auth.lockdownCtx = ctx
auth.lockdownCancelFunc = cancel

auth.log.App.Warn().Msg("Too many failed login attempts, entering lockdown mode")

auth.lockdown = &Lockdown{
Expand All @@ -825,10 +828,12 @@ func (auth *AuthService) lockdownMode() {
auth.loginAttempts = make(map[string]*LoginAttempt)

timer := time.NewTimer(time.Until(auth.lockdown.ActiveUntil))
defer timer.Stop()

auth.loginMutex.Unlock()

defer cancel()
defer timer.Stop()

select {
case <-timer.C:
// Timer expired, end lockdown
Expand Down
1 change: 1 addition & 0 deletions internal/service/oauth_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func NewOAuthService(config model.OAuthServiceConfig, id string, ctx context.Con
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: config.Insecure,
MinVersion: tls.VersionTLS12,
},
},
}
Expand Down
55 changes: 10 additions & 45 deletions internal/service/oidc_service.go
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to keep the public key loaded from file, not that it makes a big difference but we should let users specify their own. Call it more predictable behavior. Code-wise just use the loaded public key instead of deriving one from the private key.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package service

import (
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
Expand Down Expand Up @@ -121,7 +120,6 @@ type OIDCService struct {

clients map[string]model.OIDCClientConfig
privateKey *rsa.PrivateKey
publicKey crypto.PublicKey
issuer string
}

Expand Down Expand Up @@ -194,49 +192,17 @@ func NewOIDCService(
}
}

var publicKey crypto.PublicKey

fpublicKey, err := os.ReadFile(config.OIDC.PublicKeyPath)

if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("failed to read public key: %w", err)
der := x509.MarshalPKCS1PublicKey(&privateKey.PublicKey)
if der == nil {
return nil, errors.New("failed to marshal public key")
}

if errors.Is(err, os.ErrNotExist) {
publicKey = privateKey.Public()
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
if der == nil {
return nil, errors.New("failed to marshal public key")
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: der,
})
log.App.Trace().Str("type", "RSA PUBLIC KEY").Msg("Generated public RSA key")
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil {
return nil, err
}
} else {
block, _ := pem.Decode(fpublicKey)
if block == nil {
return nil, errors.New("failed to decode public key")
}
log.App.Trace().Str("type", block.Type).Msg("Loaded public key")
switch block.Type {
case "RSA PUBLIC KEY":
publicKey, err = x509.ParsePKCS1PublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
case "PUBLIC KEY":
publicKey, err = x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse public key: %w", err)
}
default:
return nil, fmt.Errorf("unsupported public key type: %s", block.Type)
}
encoded := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: der,
})
err = os.WriteFile(config.OIDC.PublicKeyPath, encoded, 0644)
if err != nil {
return nil, err
}

// We will reorganize the client into a map with the client ID as the key
Expand Down Expand Up @@ -271,7 +237,6 @@ func NewOIDCService(

clients: clients,
privateKey: privateKey,
publicKey: publicKey,
issuer: issuer,
}

Expand Down