diff --git a/CHANGELOG.md b/CHANGELOG.md index b3db5c4f..387de4c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [unreleased] +- Adds test to verify that session container uses overridden functions ### Added - Adds with-go-zero example: https://github.com/supertokens/supertokens-golang/issues/157 diff --git a/recipe/session/session_test.go b/recipe/session/session_test.go index 7bb85cfc..2ab24175 100644 --- a/recipe/session/session_test.go +++ b/recipe/session/session_test.go @@ -1013,3 +1013,65 @@ func TestSignoutWorksAfterSessionDeletedOnBackend(t *testing.T) { assert.Equal(t, cookieData["idRefreshTokenFromCookie"], "") assert.Equal(t, cookieData["idRefreshTokenFromHeader"], "remove") } + +func TestSessionContainerOverride(t *testing.T) { + customAntiCsrfVal := "VIA_TOKEN" + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + AppName: "SuperTokens", + WebsiteDomain: "supertokens.io", + APIDomain: "api.supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init(&sessmodels.TypeInput{ + AntiCsrf: &customAntiCsrfVal, + Override: &sessmodels.OverrideStruct{ + Functions: func(originalImplementation sessmodels.RecipeInterface) sessmodels.RecipeInterface { + oGetSessionInformation := *originalImplementation.GetSessionInformation + nGetSessionInformation := func(sessionHandle string, userContext supertokens.UserContext) (*sessmodels.SessionInformation, error) { + info, err := oGetSessionInformation(sessionHandle, userContext) + if err != nil { + return nil, err + } + info.SessionData = map[string]interface{}{ + "test": 1, + } + return info, nil + } + *originalImplementation.GetSessionInformation = nGetSessionInformation + return originalImplementation + }, + }, + }), + }, + } + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + err := supertokens.Init(configValue) + if err != nil { + t.Error(err.Error()) + } + res := MockResponseWriter{} + session, err := CreateNewSession(res, "testId", map[string]interface{}{}, map[string]interface{}{}) + assert.NoError(t, err) + + data, err := session.GetSessionData() + assert.NoError(t, err) + + assert.Equal(t, 1, data["test"]) +} + +type MockResponseWriter struct{} + +func (mw MockResponseWriter) Header() http.Header { + return http.Header{} +} +func (mw MockResponseWriter) Write([]byte) (int, error) { + return 0, nil +} +func (mw MockResponseWriter) WriteHeader(statusCode int) { +}