Skip to content

Commit

Permalink
feat: add token to session
Browse files Browse the repository at this point in the history
  • Loading branch information
aeneasr committed Aug 25, 2020
1 parent 8bfd5f2 commit 08c8c78
Show file tree
Hide file tree
Showing 26 changed files with 196 additions and 59 deletions.
4 changes: 4 additions & 0 deletions Makefile
Expand Up @@ -121,3 +121,7 @@ migrations-sync: .bin/cli
.PHONY: migrations-render
migrations-render: .bin/cli
cli dev pop migration render persistence/sql/migrations/templates persistence/sql/migrations/sql

.PHONY: migrations-render-replace
migrations-render-replace: .bin/cli
cli dev pop migration render -r persistence/sql/migrations/templates persistence/sql/migrations/sql
4 changes: 2 additions & 2 deletions internal/testhelpers/handler_mock.go
Expand Up @@ -35,7 +35,7 @@ func MockSetSession(t *testing.T, reg mockDeps, conf configuration.Provider) htt
i := identity.NewIdentity(configuration.DefaultIdentityTraitsSchemaID)
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(context.Background(), i))

require.NoError(t, reg.SessionManager().CreateToRequest(context.Background(), w, r, session.NewSession(i, conf, time.Now().UTC())))
require.NoError(t, reg.SessionManager().CreateAndIssueCookie(context.Background(), w, r, session.NewSession(i, conf, time.Now().UTC())))

w.WriteHeader(http.StatusOK)
}
Expand Down Expand Up @@ -114,7 +114,7 @@ func MockSessionCreateHandlerWithIdentity(t *testing.T, reg mockDeps, i *identit
require.Len(t, inserted.Credentials, len(i.Credentials))

return func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
require.NoError(t, reg.SessionManager().SaveToRequest(context.Background(), w, r, &sess))
require.NoError(t, reg.SessionManager().IssueCookie(context.Background(), w, r, &sess))
}, &sess
}

Expand Down
Empty file.
Empty file.
Empty file.
@@ -0,0 +1 @@
ALTER TABLE "sessions" DROP COLUMN "token";COMMIT TRANSACTION;BEGIN TRANSACTION;
@@ -0,0 +1,8 @@
DELETE FROM sessions;
ALTER TABLE "sessions" ADD COLUMN "token" VARCHAR (32);COMMIT TRANSACTION;BEGIN TRANSACTION;
ALTER TABLE "sessions" RENAME COLUMN "token" TO "_token_tmp";COMMIT TRANSACTION;BEGIN TRANSACTION;
ALTER TABLE "sessions" ADD COLUMN "token" VARCHAR (32);COMMIT TRANSACTION;BEGIN TRANSACTION;
UPDATE "sessions" SET "token" = "_token_tmp";COMMIT TRANSACTION;BEGIN TRANSACTION;
ALTER TABLE "sessions" DROP COLUMN "_token_tmp";COMMIT TRANSACTION;BEGIN TRANSACTION;
CREATE UNIQUE INDEX "sessions_token_uq_idx" ON "sessions" (token);COMMIT TRANSACTION;BEGIN TRANSACTION;
CREATE INDEX "sessions_token_idx" ON "sessions" (token);COMMIT TRANSACTION;BEGIN TRANSACTION;
@@ -0,0 +1 @@
ALTER TABLE `sessions` DROP COLUMN `token`;
@@ -0,0 +1,5 @@
DELETE FROM sessions;
ALTER TABLE `sessions` ADD COLUMN `token` VARCHAR (32);
ALTER TABLE `sessions` MODIFY `token` VARCHAR (32);
CREATE UNIQUE INDEX `sessions_token_uq_idx` ON `sessions` (`token`);
CREATE INDEX `sessions_token_idx` ON `sessions` (`token`);
@@ -0,0 +1 @@
ALTER TABLE "sessions" DROP COLUMN "token";
@@ -0,0 +1,5 @@
DELETE FROM sessions;
ALTER TABLE "sessions" ADD COLUMN "token" VARCHAR (32);
ALTER TABLE "sessions" ALTER COLUMN "token" TYPE VARCHAR (32), ALTER COLUMN "token" DROP NOT NULL;
CREATE UNIQUE INDEX "sessions_token_uq_idx" ON "sessions" (token);
CREATE INDEX "sessions_token_idx" ON "sessions" (token);
@@ -0,0 +1,15 @@
DROP INDEX IF EXISTS "sessions_token_idx";
DROP INDEX IF EXISTS "sessions_token_uq_idx";
ALTER TABLE "sessions" RENAME TO "_sessions_tmp";
CREATE TABLE "sessions" (
"id" TEXT PRIMARY KEY,
"issued_at" DATETIME NOT NULL DEFAULT 'CURRENT_TIMESTAMP',
"expires_at" DATETIME NOT NULL,
"authenticated_at" DATETIME NOT NULL,
"identity_id" char(36) NOT NULL,
"created_at" DATETIME NOT NULL,
"updated_at" DATETIME NOT NULL,
FOREIGN KEY (identity_id) REFERENCES identities (id) ON UPDATE NO ACTION ON DELETE CASCADE
);
INSERT INTO "sessions" (id, issued_at, expires_at, authenticated_at, identity_id, created_at, updated_at) SELECT id, issued_at, expires_at, authenticated_at, identity_id, created_at, updated_at FROM "_sessions_tmp";
DROP TABLE "_sessions_tmp";
@@ -0,0 +1,18 @@
DELETE FROM sessions;
ALTER TABLE "sessions" ADD COLUMN "token" TEXT;
ALTER TABLE "sessions" RENAME TO "_sessions_tmp";
CREATE TABLE "sessions" (
"id" TEXT PRIMARY KEY,
"issued_at" DATETIME NOT NULL DEFAULT 'CURRENT_TIMESTAMP',
"expires_at" DATETIME NOT NULL,
"authenticated_at" DATETIME NOT NULL,
"identity_id" char(36) NOT NULL,
"created_at" DATETIME NOT NULL,
"updated_at" DATETIME NOT NULL,
"token" TEXT,
FOREIGN KEY (identity_id) REFERENCES identities (id) ON UPDATE NO ACTION ON DELETE CASCADE
);
INSERT INTO "sessions" (id, issued_at, expires_at, authenticated_at, identity_id, created_at, updated_at, token) SELECT id, issued_at, expires_at, authenticated_at, identity_id, created_at, updated_at, token FROM "_sessions_tmp";
DROP TABLE "_sessions_tmp";
CREATE UNIQUE INDEX "sessions_token_uq_idx" ON "sessions" (token);
CREATE INDEX "sessions_token_idx" ON "sessions" (token);
@@ -0,0 +1 @@
drop_column("sessions", "token")
@@ -0,0 +1,7 @@
sql("DELETE FROM sessions;")

add_column("sessions", "token", "string", {"size": 32, "null": true})
change_column("sessions", "token", "string", {"size": 32, "null": false})

add_index("sessions", "token", {"unique": true, "name": "sessions_token_uq_idx"})
add_index("sessions", "token", {"name": "sessions_token_idx" })
26 changes: 18 additions & 8 deletions persistence/sql/persister_session.go
Expand Up @@ -14,14 +14,9 @@ var _ session.Persister = new(Persister)

func (p *Persister) GetSession(ctx context.Context, sid uuid.UUID) (*session.Session, error) {
var s session.Session
if err := p.GetConnection(ctx).Find(&s, sid); err != nil {
if err := p.GetConnection(ctx).Eager("Identity").Find(&s, sid); err != nil {
return nil, sqlcon.HandleError(err)
}
i, err := p.GetIdentity(ctx, s.IdentityID)
if err != nil {
return nil, err
}
s.Identity = i
return &s, nil
}

Expand All @@ -33,8 +28,23 @@ func (p *Persister) DeleteSession(ctx context.Context, sid uuid.UUID) error {
return p.GetConnection(ctx).Destroy(&session.Session{ID: sid}) // This must not be eager or identities will be created / updated
}

func (p *Persister) DeleteSessionsFor(ctx context.Context, sid uuid.UUID) error {
if err := p.GetConnection(ctx).RawQuery("DELETE FROM sessions WHERE identity_id =?", sid).Exec(); err != nil {
func (p *Persister) DeleteSessionsFor(ctx context.Context, identityID uuid.UUID) error {
if err := p.GetConnection(ctx).RawQuery("DELETE FROM sessions WHERE identity_id = ?", identityID).Exec(); err != nil {
return sqlcon.HandleError(err)
}
return nil
}

func (p *Persister) GetSessionFromToken(ctx context.Context, token string) (*session.Session, error) {
var s session.Session
if err := p.GetConnection(ctx).Eager("Identity").Where("token = ?", token).First(&s); err != nil {
return nil, sqlcon.HandleError(err)
}
return &s, nil
}

func (p *Persister) DeleteSessionFromToken(ctx context.Context, token string) error {
if err := p.GetConnection(ctx).RawQuery("DELETE FROM sessions WHERE token = ?", token).Exec(); err != nil {
return sqlcon.HandleError(err)
}
return nil
Expand Down
38 changes: 37 additions & 1 deletion selfservice/flow/login/hook.go
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/ory/kratos/driver/configuration"
"github.com/ory/kratos/identity"
"github.com/ory/kratos/selfservice/flow"
"github.com/ory/kratos/session"
"github.com/ory/kratos/x"
)
Expand All @@ -31,6 +32,7 @@ type (
executorDependencies interface {
HooksProvider
session.ManagementProvider
session.PersistenceProvider
x.WriterProvider
x.LoggingProvider
}
Expand All @@ -43,6 +45,31 @@ type (
}
)

// Contains the Session and Session Token for API Based Authentication
//
// swagger:model sessionTokenContainer
type SessionTokenContainer struct {
// The Session Token
//
// A session token is equivalent to a session cookie, but it can be sent in the HTTP Authorization
// Header:
//
// Authorization: bearer <session-token>
//
// The session token is only issued for API flows, not for Browser flows!
//
// required: true
Token string `json:"session_token"`

// The Session
//
// The session contains information about the user, the session device, and so on.
// This is only available for API flows, not for Browser flows!
//
// required: true
Session *session.Session `json:"session"`
}

func NewHookExecutor(d executorDependencies, c configuration.Provider) *HookExecutor {
return &HookExecutor{
d: d,
Expand All @@ -63,7 +90,16 @@ func (e *HookExecutor) PostLoginHook(w http.ResponseWriter, r *http.Request, ct
}
}

if err := e.d.SessionManager().CreateToRequest(r.Context(), w, r, s); err != nil {
if a.Type == flow.TypeAPI {
if err := e.d.SessionPersister().CreateSession(r.Context(), s); err != nil {
return errors.WithStack(err)
}

e.d.Writer().Write(w, r, &SessionTokenContainer{Session: s, Token: s.Token})
return nil
}

if err := e.d.SessionManager().CreateAndIssueCookie(r.Context(), w, r, s); err != nil {
return errors.WithStack(err)
}

Expand Down
4 changes: 2 additions & 2 deletions selfservice/hook/session_issuer.go
Expand Up @@ -36,13 +36,13 @@ func (e *SessionIssuer) ExecutePostRegistrationPostPersistHook(w http.ResponseWr
if err := e.r.SessionPersister().CreateSession(r.Context(), s); err != nil {
return err
}
return e.r.SessionManager().SaveToRequest(r.Context(), w, r, s)
return e.r.SessionManager().IssueCookie(r.Context(), w, r, s)
}

func (e *SessionIssuer) ExecuteLoginPostHook(w http.ResponseWriter, r *http.Request, a *login.Flow, s *session.Session) error {
s.AuthenticatedAt = time.Now().UTC()
if err := e.r.SessionPersister().CreateSession(r.Context(), s); err != nil {
return err
}
return e.r.SessionManager().SaveToRequest(r.Context(), w, r, s)
return e.r.SessionManager().IssueCookie(r.Context(), w, r, s)
}
20 changes: 1 addition & 19 deletions selfservice/strategy/password/login.go
Expand Up @@ -58,24 +58,6 @@ type completeSelfServiceLoginFlowWithPasswordMethod struct {
LoginFormPayload
}

// Completed Login Flow with Username/Password Method for API Clients Response
//
// swagger:response completeSelfServiceLoginFlowWithPasswordMethodResponse
type completeSelfServiceLoginFlowWithPasswordResponse struct {
// The Session Token
//
// A session token is equivalent to a session cookie, but it can be sent in the HTTP Authorization
// Header:
//
// Authorization: bearer <session-token>
//
// The session token is only issued for API flows, not for Browser flows!
//
// in: body
// required: true
SessionToken string `json:"session_token"`
}

// swagger:route GET /self-service/login/methods/password public completeSelfServiceLoginFlowWithPasswordMethod
//
// Complete Login Flow with Username/Email Password Method
Expand Down Expand Up @@ -103,7 +85,7 @@ type completeSelfServiceLoginFlowWithPasswordResponse struct {
// - application/json
//
// Responses:
// 200: completeSelfServiceLoginFlowWithPasswordMethodResponse
// 200: selfserviceLoginSessionContainer
// 302: emptyResponse
// 400: genericError
// 500: genericError
Expand Down
7 changes: 4 additions & 3 deletions selfservice/strategy/password/login_test.go
Expand Up @@ -406,7 +406,7 @@ func TestCompleteLogin(t *testing.T) {
}

lr := nlr(time.Hour, isAPI)
return fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCode(isAPI, http.StatusBadRequest))
return fakeRequest(t, lr, isAPI, payload, nil, nil, expectStatusCode(isAPI, http.StatusOK))
}

t.Run("type=browser", func(t *testing.T) {
Expand All @@ -417,8 +417,9 @@ func TestCompleteLogin(t *testing.T) {

t.Run("type=api", func(t *testing.T) {
res, body := run(t, true)
require.Contains(t, res.Request.URL.Path, "return-ts", "%s", res.Request.URL.String())
assert.Equal(t, identifier, gjson.GetBytes(body, "identity.traits.subject").String(), "%s", body)
require.Contains(t, res.Request.URL.Path, password.RouteLogin, "%s", res.Request.URL.String())
assert.Equal(t, identifier, gjson.GetBytes(body, "session.identity.traits.subject").String(), "%s", body)
assert.NotEmpty(t, gjson.GetBytes(body, "session_token").String(), "%s", body)
})
})

Expand Down
2 changes: 1 addition & 1 deletion selfservice/strategy/recoverytoken/strategy.go
Expand Up @@ -344,7 +344,7 @@ func (s *Strategy) issueSession(w http.ResponseWriter, r *http.Request, req *rec
}

sess := session.NewSession(recovered, s.c, time.Now().UTC())
if err := s.d.SessionManager().CreateToRequest(r.Context(), w, r, sess); err != nil {
if err := s.d.SessionManager().CreateAndIssueCookie(r.Context(), w, r, sess); err != nil {
s.handleError(w, r, req, err)
return
}
Expand Down
13 changes: 9 additions & 4 deletions session/manager.go
Expand Up @@ -17,10 +17,15 @@ var (

// Manager handles identity sessions.
type Manager interface {
CreateToRequest(context.Context, http.ResponseWriter, *http.Request, *Session) error

// SaveToRequest creates an HTTP session using cookies.
SaveToRequest(context.Context, http.ResponseWriter, *http.Request, *Session) error
// CreateAndIssueCookie stores a session in the database and issues a cookie by calling IssueCookie.
//
// Also regenerates CSRF tokens due to assumed principal change.
CreateAndIssueCookie(context.Context, http.ResponseWriter, *http.Request, *Session) error

// IssueCookie issues a cookie for the given session.
//
// Also regenerates CSRF tokens due to assumed principal change.
IssueCookie(context.Context, http.ResponseWriter, *http.Request, *Session) error

// FetchFromRequest creates an HTTP session using cookies.
FetchFromRequest(context.Context, *http.Request) (*Session, error)
Expand Down
6 changes: 3 additions & 3 deletions session/manager_http.go
Expand Up @@ -48,19 +48,19 @@ func NewManagerHTTP(
}
}

func (s *ManagerHTTP) CreateToRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, ss *Session) error {
func (s *ManagerHTTP) CreateAndIssueCookie(ctx context.Context, w http.ResponseWriter, r *http.Request, ss *Session) error {
if err := s.r.SessionPersister().CreateSession(ctx, ss); err != nil {
return err
}

if err := s.SaveToRequest(ctx, w, r, ss); err != nil {
if err := s.IssueCookie(ctx, w, r, ss); err != nil {
return err
}

return nil
}

func (s *ManagerHTTP) SaveToRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, session *Session) error {
func (s *ManagerHTTP) IssueCookie(ctx context.Context, w http.ResponseWriter, r *http.Request, session *Session) error {
_ = s.r.CSRFHandler().RegenerateToken(w, r)
cookie, _ := s.r.CookieManager().Get(r, s.cookieName)
if s.c.SessionDomain() != "" {
Expand Down
4 changes: 2 additions & 2 deletions session/manager_http_test.go
Expand Up @@ -40,7 +40,7 @@ func TestManagerHTTP(t *testing.T) {
mock := new(mockCSRFHandler)
reg.WithCSRFHandler(mock)

require.NoError(t, reg.SessionManager().SaveToRequest(context.Background(), httptest.NewRecorder(), new(http.Request), new(session.Session)))
require.NoError(t, reg.SessionManager().IssueCookie(context.Background(), httptest.NewRecorder(), new(http.Request), new(session.Session)))
assert.Equal(t, 1, mock.c)
})

Expand All @@ -53,7 +53,7 @@ func TestManagerHTTP(t *testing.T) {
var s *session.Session
rp := x.NewRouterPublic()
rp.GET("/session/set", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
require.NoError(t, reg.SessionManager().CreateToRequest(r.Context(), w, r, s))
require.NoError(t, reg.SessionManager().CreateAndIssueCookie(r.Context(), w, r, s))
w.WriteHeader(http.StatusOK)
})

Expand Down

0 comments on commit 08c8c78

Please sign in to comment.