Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion frontend/src/lib/hooks/oidc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ export type OIDCValues = {
redirect_uri: string;
state: string;
nonce: string;
code_challenge: string;
code_challenge_method: string;
};

interface IuseOIDCParams {
Expand All @@ -14,7 +16,12 @@ interface IuseOIDCParams {
missingParams: string[];
}

const optionalParams: string[] = ["state", "nonce"];
const optionalParams: string[] = [
"state",
"nonce",
"code_challenge",
"code_challenge_method",
];

export function useOIDCParams(params: URLSearchParams): IuseOIDCParams {
let compiled: string = "";
Expand All @@ -28,6 +35,8 @@ export function useOIDCParams(params: URLSearchParams): IuseOIDCParams {
redirect_uri: params.get("redirect_uri") ?? "",
state: params.get("state") ?? "",
nonce: params.get("nonce") ?? "",
code_challenge: params.get("code_challenge") ?? "",
code_challenge_method: params.get("code_challenge_method") ?? "",
};

for (const key of Object.keys(values)) {
Expand Down
1 change: 1 addition & 0 deletions internal/assets/migrations/000007_oidc_pkce.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "oidc_codes" DROP COLUMN "code_challenge";
1 change: 1 addition & 0 deletions internal/assets/migrations/000007_oidc_pkce.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE "oidc_codes" ADD COLUMN "code_challenge" TEXT DEFAULT "";
2 changes: 2 additions & 0 deletions internal/controller/context_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import (
"github.com/steveiliop56/tinyauth/internal/config"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
)

func TestContextController(t *testing.T) {
tlog.NewTestLogger().Init()
controllerConfig := controller.ContextControllerConfig{
Providers: []controller.Provider{
{
Expand Down
2 changes: 2 additions & 0 deletions internal/controller/health_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ import (

"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
)

func TestHealthController(t *testing.T) {
tlog.NewTestLogger().Init()
tests := []struct {
description string
path string
Expand Down
11 changes: 11 additions & 0 deletions internal/controller/oidc_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type TokenRequest struct {
RefreshToken string `form:"refresh_token" url:"refresh_token"`
ClientSecret string `form:"client_secret" url:"client_secret"`
ClientID string `form:"client_id" url:"client_id"`
CodeVerifier string `form:"code_verifier" url:"code_verifier"`
}

type CallbackError struct {
Expand Down Expand Up @@ -308,6 +309,16 @@ func (controller *OIDCController) Token(c *gin.Context) {
return
}

ok := controller.oidc.ValidatePKCE(entry.CodeChallenge, req.CodeVerifier)

if !ok {
tlog.App.Warn().Msg("PKCE validation failed")
c.JSON(400, gin.H{
"error": "invalid_grant",
})
return
}

tokenRes, err := controller.oidc.GenerateAccessToken(c, client, entry)

if err != nil {
Expand Down
225 changes: 225 additions & 0 deletions internal/controller/oidc_controller_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package controller_test

import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"net/http/httptest"
"net/url"
Expand All @@ -15,11 +17,13 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestOIDCController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()

oidcServiceCfg := service.OIDCServiceConfig{
Expand Down Expand Up @@ -431,6 +435,227 @@ func TestOIDCController(t *testing.T) {
assert.False(t, ok, "Did not expect email claim in userinfo response")
},
},
{
description: "Ensure plain PKCE succeeds",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: "some-challenge",
// Not setting a code challenge method should default to "plain"
CodeChallengeMethod: "",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)

req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)

var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)

redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)

queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")

code := queryParams.Get("code")
assert.NotEmpty(t, code)

// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
CodeVerifier: "some-challenge",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)

req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)

assert.Equal(t, 200, recorder.Code)
},
},
{
description: "Ensure S256 PKCE succeeds",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: codeChallengeEncoded,
CodeChallengeMethod: "S256",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)

req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)

var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)

redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)

queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")

code := queryParams.Get("code")
assert.NotEmpty(t, code)

// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
CodeVerifier: "some-challenge",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)

req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)

assert.Equal(t, 200, recorder.Code)
},
},
{
description: "Ensure request with invalid PKCE fails",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: codeChallengeEncoded,
CodeChallengeMethod: "S256",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)

req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)

var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)

redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)

queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state")

code := queryParams.Get("code")
assert.NotEmpty(t, code)

// Now exchange the code for a token
recorder = httptest.NewRecorder()
tokenReqBody := controller.TokenRequest{
GrantType: "authorization_code",
Code: code,
RedirectURI: "https://test.example.com/callback",
CodeVerifier: "some-challenge-1",
}
reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err)

req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.SetBasicAuth("some-client-id", "some-client-secret")
router.ServeHTTP(recorder, req)

assert.Equal(t, 400, recorder.Code)
},
},
{
description: "Ensure request with invalid challenge method fails",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
hasher := sha256.New()
hasher.Write([]byte("some-challenge"))
codeChallenge := hasher.Sum(nil)
codeChallengeEncoded := base64.RawURLEncoding.EncodeToString(codeChallenge)
reqBody := service.AuthorizeRequest{
Scope: "openid",
ResponseType: "code",
ClientID: "some-client-id",
RedirectURI: "https://test.example.com/callback",
State: "some-state",
Nonce: "some-nonce",
CodeChallenge: codeChallengeEncoded,
CodeChallengeMethod: "foo",
}
reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err)

req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)

var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)

redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI)
assert.NoError(t, err)

queryParams := url.Query()
error := queryParams.Get("error")
assert.NotEmpty(t, error)
},
Comment thread
steveiliop56 marked this conversation as resolved.
},
}

app := bootstrap.NewBootstrapApp(config.Config{})
Expand Down
3 changes: 1 addition & 2 deletions internal/controller/proxy_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
)

func TestProxyController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()

authServiceCfg := service.AuthServiceConfig{
Expand Down Expand Up @@ -390,8 +391,6 @@ func TestProxyController(t *testing.T) {
},
}

tlog.NewSimpleLogger().Init()

oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)

app := bootstrap.NewBootstrapApp(config.Config{})
Expand Down
2 changes: 2 additions & 0 deletions internal/controller/resources_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ import (

"github.com/gin-gonic/gin"
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestResourcesController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()

resourcesControllerCfg := controller.ResourcesControllerConfig{
Expand Down
3 changes: 1 addition & 2 deletions internal/controller/user_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
)

func TestUserController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()

authServiceCfg := service.AuthServiceConfig{
Expand Down Expand Up @@ -274,8 +275,6 @@ func TestUserController(t *testing.T) {
},
}

tlog.NewSimpleLogger().Init()

oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)

app := bootstrap.NewBootstrapApp(config.Config{})
Expand Down
2 changes: 2 additions & 0 deletions internal/controller/well_known_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ import (
"github.com/steveiliop56/tinyauth/internal/controller"
"github.com/steveiliop56/tinyauth/internal/repository"
"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestWellKnownController(t *testing.T) {
tlog.NewTestLogger().Init()
tempDir := t.TempDir()

oidcServiceCfg := service.OIDCServiceConfig{
Expand Down
Loading
Loading