Skip to content

Commit

Permalink
[performance] add user cache and database (#879)
Browse files Browse the repository at this point in the history
* go fmt

* add + use user cache and database

* fix import

* update tests

* remove unused relation
  • Loading branch information
tsmethurst committed Oct 3, 2022
1 parent f7af7c0 commit 56f53a2
Show file tree
Hide file tree
Showing 21 changed files with 490 additions and 70 deletions.
30 changes: 14 additions & 16 deletions cmd/gotosocial/action/admin/account/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ import (

"github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/validate"
"golang.org/x/crypto/bcrypt"
)
Expand Down Expand Up @@ -92,8 +90,8 @@ var Confirm action.GTSAction = func(ctx context.Context) error {
return err
}

u := &gtsmodel.User{}
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err != nil {
return err
}

Expand Down Expand Up @@ -130,16 +128,16 @@ var Promote action.GTSAction = func(ctx context.Context) error {
return err
}

u := &gtsmodel.User{}
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err != nil {
return err
}

updatingColumns := []string{"admin", "updated_at"}
admin := true
u.Admin = &admin
u.UpdatedAt = time.Now()
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err
}

Expand All @@ -166,16 +164,16 @@ var Demote action.GTSAction = func(ctx context.Context) error {
return err
}

u := &gtsmodel.User{}
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err != nil {
return err
}

updatingColumns := []string{"admin", "updated_at"}
admin := false
u.Admin = &admin
u.UpdatedAt = time.Now()
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err
}

Expand All @@ -202,16 +200,16 @@ var Disable action.GTSAction = func(ctx context.Context) error {
return err
}

u := &gtsmodel.User{}
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err != nil {
return err
}

updatingColumns := []string{"disabled", "updated_at"}
disabled := true
u.Disabled = &disabled
u.UpdatedAt = time.Now()
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err
}

Expand Down Expand Up @@ -252,8 +250,8 @@ var Password action.GTSAction = func(ctx context.Context) error {
return err
}

u := &gtsmodel.User{}
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err != nil {
return err
}

Expand All @@ -265,7 +263,7 @@ var Password action.GTSAction = func(ctx context.Context) error {
updatingColumns := []string{"encrypted_password", "updated_at"}
u.EncryptedPassword = string(pw)
u.UpdatedAt = time.Now()
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err
}

Expand Down
8 changes: 4 additions & 4 deletions internal/api/client/auth/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
return
}

user := &gtsmodel.User{}
if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
user, err := m.db.GetUserByID(c.Request.Context(), userID)
if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode
Expand Down Expand Up @@ -213,8 +213,8 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
return
}

user := &gtsmodel.User{}
if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
user, err := m.db.GetUserByID(c.Request.Context(), userID)
if err != nil {
m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode
Expand Down
10 changes: 6 additions & 4 deletions internal/api/client/auth/authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
doTest := func(testCase authorizeHandlerTestCase) {
ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "")

user := suite.testUsers["unconfirmed_account"]
account := suite.testAccounts["unconfirmed_account"]
user := &gtsmodel.User{}
account := &gtsmodel.Account{}

*user = *suite.testUsers["unconfirmed_account"]
*account = *suite.testAccounts["unconfirmed_account"]

testSession := sessions.Default(ctx)
testSession.Set(sessionUserID, user.ID)
Expand All @@ -91,8 +94,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt)

updatingColumns = append(updatingColumns, "updated_at")
user.UpdatedAt = time.Now()
err := suite.db.UpdateByPrimaryKey(context.Background(), user, updatingColumns...)
_, err := suite.db.UpdateUser(context.Background(), user, updatingColumns...)
suite.NoError(err)
_, err = suite.db.UpdateAccount(context.Background(), account)
suite.NoError(err)
Expand Down
3 changes: 1 addition & 2 deletions internal/api/client/auth/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, i

// see if we already have a user for this email address
// if so, we don't need to continue + create one
user := &gtsmodel.User{}
err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
user, err := m.db.GetUserByEmailAddress(ctx, claims.Email)
if err == nil {
return user, nil
}
Expand Down
6 changes: 2 additions & 4 deletions internal/api/client/auth/signin.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"golang.org/x/crypto/bcrypt"
)

Expand Down Expand Up @@ -119,8 +117,8 @@ func (m *Module) ValidatePassword(ctx context.Context, email string, password st
return incorrectPassword(err)
}

user := &gtsmodel.User{}
if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, user); err != nil {
user, err := m.db.GetUserByEmailAddress(ctx, email)
if err != nil {
err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
return incorrectPassword(err)
}
Expand Down
23 changes: 13 additions & 10 deletions internal/api/security/tokencheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ func (m *Module) TokenCheck(c *gin.Context) {
log.Tracef("authenticated user %s with bearer token, scope is %s", userID, ti.GetScope())

// fetch user for this token
user := &gtsmodel.User{}
if err := m.db.GetByID(ctx, userID, user); err != nil {
user, err := m.db.GetUserByID(ctx, userID)
if err != nil {
if err != db.ErrNoEntries {
log.Errorf("database error looking for user with id %s: %s", userID, err)
return
Expand All @@ -80,22 +80,25 @@ func (m *Module) TokenCheck(c *gin.Context) {
c.Set(oauth.SessionAuthorizedUser, user)

// fetch account for this token
acct, err := m.db.GetAccountByID(ctx, user.AccountID)
if err != nil {
if err != db.ErrNoEntries {
log.Errorf("database error looking for account with id %s: %s", user.AccountID, err)
if user.Account == nil {
acct, err := m.db.GetAccountByID(ctx, user.AccountID)
if err != nil {
if err != db.ErrNoEntries {
log.Errorf("database error looking for account with id %s: %s", user.AccountID, err)
return
}
log.Warnf("no account found for userID %s", userID)
return
}
log.Warnf("no account found for userID %s", userID)
return
user.Account = acct
}

if !acct.SuspendedAt.IsZero() {
if !user.Account.SuspendedAt.IsZero() {
log.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID)
return
}

c.Set(oauth.SessionAuthorizedAccount, acct)
c.Set(oauth.SessionAuthorizedAccount, user.Account)
}

// check for application token
Expand Down
141 changes: 141 additions & 0 deletions internal/cache/user.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package cache

import (
"time"

"codeberg.org/gruf/go-cache/v2"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)

// UserCache is a cache wrapper to provide lookups for gtsmodel.User
type UserCache struct {
cache cache.LookupCache[string, string, *gtsmodel.User]
}

// NewUserCache returns a new instantiated UserCache object
func NewUserCache() *UserCache {
c := &UserCache{}
c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.User]{
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
lm.RegisterLookup("accountid")
lm.RegisterLookup("email")
lm.RegisterLookup("unconfirmedemail")
lm.RegisterLookup("confirmationtoken")
},

AddLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) {
lm.Set("accountid", user.AccountID, user.ID)
if email := user.Email; email != "" {
lm.Set("email", email, user.ID)
}
if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" {
lm.Set("unconfirmedemail", unconfirmedEmail, user.ID)
}
if confirmationToken := user.ConfirmationToken; confirmationToken != "" {
lm.Set("confirmationtoken", confirmationToken, user.ID)
}
},

DeleteLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) {
lm.Delete("accountid", user.AccountID)
if email := user.Email; email != "" {
lm.Delete("email", email)
}
if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" {
lm.Delete("unconfirmedemail", unconfirmedEmail)
}
if confirmationToken := user.ConfirmationToken; confirmationToken != "" {
lm.Delete("confirmationtoken", confirmationToken)
}
},
})
c.cache.SetTTL(time.Minute*5, false)
c.cache.Start(time.Second * 10)
return c
}

// GetByID attempts to fetch a user from the cache by its ID, you will receive a copy for thread-safety
func (c *UserCache) GetByID(id string) (*gtsmodel.User, bool) {
return c.cache.Get(id)
}

// GetByAccountID attempts to fetch a user from the cache by its account ID, you will receive a copy for thread-safety
func (c *UserCache) GetByAccountID(accountID string) (*gtsmodel.User, bool) {
return c.cache.GetBy("accountid", accountID)
}

// GetByEmail attempts to fetch a user from the cache by its email address, you will receive a copy for thread-safety
func (c *UserCache) GetByEmail(email string) (*gtsmodel.User, bool) {
return c.cache.GetBy("email", email)
}

// GetByUnconfirmedEmail attempts to fetch a user from the cache by its confirmation token, you will receive a copy for thread-safety
func (c *UserCache) GetByConfirmationToken(token string) (*gtsmodel.User, bool) {
return c.cache.GetBy("confirmationtoken", token)
}

// Put places a user in the cache, ensuring that the object place is a copy for thread-safety
func (c *UserCache) Put(user *gtsmodel.User) {
if user == nil || user.ID == "" {
panic("invalid user")
}
c.cache.Set(user.ID, copyUser(user))
}

// Invalidate invalidates one user from the cache using the ID of the user as key.
func (c *UserCache) Invalidate(userID string) {
c.cache.Invalidate(userID)
}

func copyUser(user *gtsmodel.User) *gtsmodel.User {
return &gtsmodel.User{
ID: user.ID,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
Email: user.Email,
AccountID: user.AccountID,
Account: nil,
EncryptedPassword: user.EncryptedPassword,
SignUpIP: user.SignUpIP,
CurrentSignInAt: user.CurrentSignInAt,
CurrentSignInIP: user.CurrentSignInIP,
LastSignInAt: user.LastSignInAt,
LastSignInIP: user.LastSignInIP,
SignInCount: user.SignInCount,
InviteID: user.InviteID,
ChosenLanguages: user.ChosenLanguages,
FilteredLanguages: user.FilteredLanguages,
Locale: user.Locale,
CreatedByApplicationID: user.CreatedByApplicationID,
CreatedByApplication: nil,
LastEmailedAt: user.LastEmailedAt,
ConfirmationToken: user.ConfirmationToken,
ConfirmationSentAt: user.ConfirmationSentAt,
ConfirmedAt: user.ConfirmedAt,
UnconfirmedEmail: user.UnconfirmedEmail,
Moderator: copyBoolPtr(user.Moderator),
Admin: copyBoolPtr(user.Admin),
Disabled: copyBoolPtr(user.Disabled),
Approved: copyBoolPtr(user.Approved),
ResetPasswordToken: user.ResetPasswordToken,
ResetPasswordSentAt: user.ResetPasswordSentAt,
}
}
5 changes: 4 additions & 1 deletion internal/db/bundb/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"time"

"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
Expand All @@ -40,7 +41,8 @@ import (
)

type adminDB struct {
conn *DBConn
conn *DBConn
userCache *cache.UserCache
}

func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
Expand Down Expand Up @@ -175,6 +177,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
Exec(ctx); err != nil {
return nil, a.conn.ProcessError(err)
}
a.userCache.Put(u)

return u, nil
}
Expand Down

0 comments on commit 56f53a2

Please sign in to comment.