Skip to content

Commit

Permalink
authenticate: remove ecjson (#3688)
Browse files Browse the repository at this point in the history
  • Loading branch information
calebdoxsey committed Oct 20, 2022
1 parent 61506c1 commit 75634df
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 206 deletions.
43 changes: 17 additions & 26 deletions authenticate/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
return tt.provider, nil
})),
state: atomicutil.NewValue(&authenticateState{
sharedCipher: sharedCipher,
sessionStore: tt.session,
redirectURL: uriParseHelper("https://some.example"),
sharedEncoder: tt.encoder,
encryptedEncoder: tt.encoder,
sharedCipher: sharedCipher,
sessionStore: tt.session,
redirectURL: uriParseHelper("https://some.example"),
sharedEncoder: tt.encoder,
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return &databroker.GetResponse{
Expand Down Expand Up @@ -308,9 +307,8 @@ func TestAuthenticate_SignOut(t *testing.T) {
return tt.provider, nil
})),
state: atomicutil.NewValue(&authenticateState{
sessionStore: tt.sessionStore,
encryptedEncoder: mock.Encoder{},
sharedEncoder: mock.Encoder{},
sessionStore: tt.sessionStore,
sharedEncoder: mock.Encoder{},
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return &databroker.GetResponse{
Expand Down Expand Up @@ -411,10 +409,6 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
if err != nil {
t.Fatal(err)
}
signer, err := jws.NewHS256Signer(nil)
if err != nil {
t.Fatal(err)
}
authURL, _ := url.Parse(tt.authenticateURL)
a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
Expand All @@ -429,11 +423,10 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
return nil, nil
},
},
directoryClient: new(mockDirectoryServiceClient),
redirectURL: authURL,
sessionStore: tt.session,
cookieCipher: aead,
encryptedEncoder: signer,
directoryClient: new(mockDirectoryServiceClient),
redirectURL: authURL,
sessionStore: tt.session,
cookieCipher: aead,
}),
options: config.NewAtomicOptions(),
}
Expand Down Expand Up @@ -558,12 +551,11 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
return tt.provider, nil
})),
state: atomicutil.NewValue(&authenticateState{
cookieSecret: cryptutil.NewKey(),
redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
sessionStore: tt.session,
cookieCipher: aead,
encryptedEncoder: signer,
sharedEncoder: signer,
cookieSecret: cryptutil.NewKey(),
redirectURL: uriParseHelper("https://authenticate.corp.beyondperimeter.com"),
sessionStore: tt.session,
cookieCipher: aead,
sharedEncoder: signer,
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return &databroker.GetResponse{
Expand Down Expand Up @@ -697,9 +689,8 @@ func TestAuthenticate_userInfo(t *testing.T) {
a := &Authenticate{
options: o,
state: atomicutil.NewValue(&authenticateState{
sessionStore: tt.sessionStore,
encryptedEncoder: signer,
sharedEncoder: signer,
sessionStore: tt.sessionStore,
sharedEncoder: signer,
dataBrokerClient: mockDataBrokerServiceClient{
get: func(ctx context.Context, in *databroker.GetRequest, opts ...grpc.CallOption) (*databroker.GetResponse, error) {
return &databroker.GetResponse{
Expand Down
10 changes: 1 addition & 9 deletions authenticate/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ import (

"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/sessions/header"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc"
Expand All @@ -40,8 +38,6 @@ type authenticateState struct {
cookieSecret []byte
// cookieCipher is the cipher to use to encrypt/decrypt session data
cookieCipher cipher.AEAD
// encryptedEncoder is the encoder used to marshal and unmarshal session data
encryptedEncoder encoding.MarshalUnmarshaler
// sessionStore is the session store used to persist a user's session
sessionStore sessions.SessionStore
// sessionLoaders are a collection of session loaders to attempt to pull
Expand Down Expand Up @@ -110,10 +106,6 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
return nil, err
}

state.encryptedEncoder = ecjson.New(state.cookieCipher)

headerStore := header.NewStore(state.encryptedEncoder)

cookieStore, err := cookie.NewStore(func() cookie.Options {
return cookie.Options{
Name: cfg.Options.CookieName,
Expand All @@ -128,7 +120,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
}

state.sessionStore = cookieStore
state.sessionLoaders = []sessions.SessionLoader{headerStore, cookieStore}
state.sessionLoaders = []sessions.SessionLoader{cookieStore}
state.jwk = new(jose.JSONWebKeySet)
signingKey, err := cfg.Options.GetSigningKey()
if err != nil {
Expand Down
121 changes: 0 additions & 121 deletions internal/encoding/ecjson/ecjson.go

This file was deleted.

47 changes: 21 additions & 26 deletions internal/sessions/cookie/cookie_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,21 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/stretchr/testify/require"

"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/cryptutil"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)

func TestNewStore(t *testing.T) {
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
encoder := ecjson.New(cipher)
key := cryptutil.NewKey()
encoder, err := jws.NewHS256Signer(key)
require.NoError(t, err)
tests := []struct {
name string
opts *Options
Expand Down Expand Up @@ -58,11 +57,9 @@ func TestNewStore(t *testing.T) {
}

func TestNewCookieLoader(t *testing.T) {
cipher, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
encoder := ecjson.New(cipher)
key := cryptutil.NewKey()
encoder, err := jws.NewHS256Signer(key)
require.NoError(t, err)
tests := []struct {
name string
opts *Options
Expand Down Expand Up @@ -96,10 +93,9 @@ func TestNewCookieLoader(t *testing.T) {
}

func TestStore_SaveSession(t *testing.T) {
c, err := cryptutil.NewAEADCipher(cryptutil.NewKey())
if err != nil {
t.Fatal(err)
}
key := cryptutil.NewKey()
encoder, err := jws.NewHS256Signer(key)
require.NoError(t, err)

hugeString := make([]byte, 4097)
if _, err := rand.Read(hugeString); err != nil {
Expand All @@ -113,13 +109,13 @@ func TestStore_SaveSession(t *testing.T) {
wantErr bool
wantLoadErr bool
}{
{"good", &sessions.State{ID: "xyz"}, ecjson.New(c), ecjson.New(c), false, false},
{"good", &sessions.State{ID: "xyz"}, encoder, encoder, false, false},
{"bad cipher", &sessions.State{ID: "xyz"}, nil, nil, true, true},
{"huge cookie", &sessions.State{ID: "xyz", Subject: fmt.Sprintf("%x", hugeString)}, ecjson.New(c), ecjson.New(c), false, false},
{"marshal error", &sessions.State{ID: "xyz"}, mock.Encoder{MarshalError: errors.New("error")}, ecjson.New(c), true, true},
{"nil encoder cannot save non string type", &sessions.State{ID: "xyz"}, nil, ecjson.New(c), true, true},
{"good marshal string directly", cryptutil.NewBase64Key(), nil, ecjson.New(c), false, true},
{"good marshal bytes directly", cryptutil.NewKey(), nil, ecjson.New(c), false, true},
{"huge cookie", &sessions.State{ID: "xyz", Subject: fmt.Sprintf("%x", hugeString)}, encoder, encoder, false, false},
{"marshal error", &sessions.State{ID: "xyz"}, mock.Encoder{MarshalError: errors.New("error")}, encoder, true, true},
{"nil encoder cannot save non string type", &sessions.State{ID: "xyz"}, nil, encoder, true, true},
{"good marshal string directly", cryptutil.NewBase64Key(), nil, encoder, false, true},
{"good marshal bytes directly", cryptutil.NewKey(), nil, encoder, false, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -148,14 +144,13 @@ func TestStore_SaveSession(t *testing.T) {
r.AddCookie(cookie)
}

enc := ecjson.New(c)
jwt, err := s.LoadSession(r)
if (err != nil) != tt.wantLoadErr {
t.Errorf("LoadSession() error = %v, wantErr %v", err, tt.wantLoadErr)
return
}
var state sessions.State
enc.Unmarshal([]byte(jwt), &state)
encoder.Unmarshal([]byte(jwt), &state)

cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(sessions.State{}),
Expand Down
14 changes: 6 additions & 8 deletions internal/sessions/cookie/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import (
"strings"
"testing"

"github.com/pomerium/pomerium/internal/sessions"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"

"github.com/pomerium/pomerium/internal/encoding/ecjson"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/pkg/cryptutil"
)

Expand Down Expand Up @@ -49,11 +49,9 @@ func TestVerifier(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cipher, err := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
encoder := ecjson.New(cipher)
if err != nil {
t.Fatal(err)
}
key := cryptutil.NewKey()
encoder, err := jws.NewHS256Signer(key)
require.NoError(t, err)
encSession, err := encoder.Marshal(&tt.state)
if err != nil {
t.Fatal(err)
Expand Down
Loading

0 comments on commit 75634df

Please sign in to comment.