Skip to content

Commit

Permalink
[bugfix] Fix potential dereference of accounts on own instance (#757)
Browse files Browse the repository at this point in the history
* add GetAccountByUsernameDomain

* simplify search

* add escape to not deref accounts on own domain

* check if local + we have account by ap uri
  • Loading branch information
tsmethurst committed Aug 20, 2022
1 parent 2ca234f commit 570fa7c
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 92 deletions.
15 changes: 15 additions & 0 deletions internal/cache/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func NewAccountCache() *AccountCache {
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
lm.RegisterLookup("uri")
lm.RegisterLookup("url")
lm.RegisterLookup("usernamedomain")
},

AddLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) {
Expand All @@ -46,6 +47,7 @@ func NewAccountCache() *AccountCache {
if url := acc.URL; url != "" {
lm.Set("url", url, acc.ID)
}
lm.Set("usernamedomain", usernameDomainKey(acc.Username, acc.Domain), acc.ID)
},

DeleteLookups: func(lm *cache.LookupMap[string, string], acc *gtsmodel.Account) {
Expand All @@ -55,6 +57,7 @@ func NewAccountCache() *AccountCache {
if url := acc.URL; url != "" {
lm.Delete("url", url)
}
lm.Delete("usernamedomain", usernameDomainKey(acc.Username, acc.Domain))
},
})
c.cache.SetTTL(time.Minute*5, false)
Expand All @@ -77,6 +80,10 @@ func (c *AccountCache) GetByURI(uri string) (*gtsmodel.Account, bool) {
return c.cache.GetBy("uri", uri)
}

func (c *AccountCache) GetByUsernameDomain(username string, domain string) (*gtsmodel.Account, bool) {
return c.cache.GetBy("usernamedomain", usernameDomainKey(username, domain))
}

// Put places a account in the cache, ensuring that the object place is a copy for thread-safety
func (c *AccountCache) Put(account *gtsmodel.Account) {
if account == nil || account.ID == "" {
Expand Down Expand Up @@ -135,3 +142,11 @@ func copyAccount(account *gtsmodel.Account) *gtsmodel.Account {
SuspensionOrigin: account.SuspensionOrigin,
}
}

func usernameDomainKey(username string, domain string) string {
u := "@" + username
if domain != "" {
return u + "@" + domain
}
return u
}
4 changes: 4 additions & 0 deletions internal/cache/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ func (suite *AccountCacheTestSuite) TestAccountCache() {
if account.URL != "" && !ok && !accountIs(account, check) {
suite.Fail("Failed to fetch expected account with URL: %s", account.URL)
}
check, ok = suite.cache.GetByUsernameDomain(account.Username, account.Domain)
if !ok && !accountIs(account, check) {
suite.Fail("Failed to fetch expected account with username/domain: %s/%s", account.Username, account.Domain)
}
}
}

Expand Down
3 changes: 3 additions & 0 deletions internal/db/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ type Account interface {
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
GetAccountByURL(ctx context.Context, uri string) (*gtsmodel.Account, Error)

// GetAccountByUsernameDomain returns one account with the given username and domain, or an error if something goes wrong.
GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, Error)

// UpdateAccount updates one account by ID.
UpdateAccount(ctx context.Context, account *gtsmodel.Account) (*gtsmodel.Account, Error)

Expand Down
20 changes: 20 additions & 0 deletions internal/db/bundb/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,26 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
)
}

func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username string, domain string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
func() (*gtsmodel.Account, bool) {
return a.cache.GetByUsernameDomain(username, domain)
},
func(account *gtsmodel.Account) error {
q := a.newAccountQ(account).Where("account.username = ?", username)

if domain != "" {
q = q.Where("account.domain = ?", domain)
} else {
q = q.Where("account.domain IS NULL")
}

return q.Scan(ctx)
},
)
}

func (a *accountDB) getAccount(ctx context.Context, cacheGet func() (*gtsmodel.Account, bool), dbQuery func(*gtsmodel.Account) error) (*gtsmodel.Account, db.Error) {
// Attempt to fetch cached account
account, cached := cacheGet()
Expand Down
12 changes: 12 additions & 0 deletions internal/db/bundb/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
suite.NotEmpty(account.HeaderMediaAttachment.URL)
}

func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() {
testAccount1 := suite.testAccounts["local_account_1"]
account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain)
suite.NoError(err)
suite.NotNil(account1)

testAccount2 := suite.testAccounts["remote_account_1"]
account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain)
suite.NoError(err)
suite.NotNil(account2)
}

func (suite *AccountTestSuite) TestUpdateAccount() {
testAccount := suite.testAccounts["local_account_1"]

Expand Down
96 changes: 59 additions & 37 deletions internal/federation/dereferencing/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
Expand Down Expand Up @@ -78,7 +79,10 @@ type GetRemoteAccountParams struct {

// GetRemoteAccount completely dereferences a remote account, converts it to a GtS model account,
// puts or updates it in the database (if necessary), and returns it to a caller.
func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountParams) (remoteAccount *gtsmodel.Account, err error) {
//
// If a local account is passed into this function for whatever reason (hey, it happens!), then it
// will be returned from the database without making any remote calls.
func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountParams) (foundAccount *gtsmodel.Account, err error) {
/*
In this function we want to retrieve a gtsmodel representation of a remote account, with its proper
accountDomain set, while making as few calls to remote instances as possible to save time and bandwidth.
Expand All @@ -99,23 +103,40 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
from that.
*/

// first check if we can retrieve the account locally just with what we've been given
skipResolve := params.SkipResolve

// this first step checks if we have the
// account in the database somewhere already
switch {
case params.RemoteAccountID != nil:
// try with uri
if a, dbErr := d.db.GetAccountByURI(ctx, params.RemoteAccountID.String()); dbErr == nil {
remoteAccount = a
uri := params.RemoteAccountID
host := uri.Host
if host == config.GetHost() || host == config.GetAccountDomain() {
// this is actually a local account,
// make sure we don't try to resolve
skipResolve = true
}

if a, dbErr := d.db.GetAccountByURI(ctx, uri.String()); dbErr == nil {
foundAccount = a
} else if dbErr != db.ErrNoEntries {
err = fmt.Errorf("GetRemoteAccount: database error looking for account with uri %s: %s", uri, err)
}
case params.RemoteAccountUsername != "" && (params.RemoteAccountHost == "" || params.RemoteAccountHost == config.GetHost() || params.RemoteAccountHost == config.GetAccountDomain()):
// either no domain is provided or this seems
// to be a local account, so don't resolve
skipResolve = true

if a, dbErr := d.db.GetLocalAccountByUsername(ctx, params.RemoteAccountUsername); dbErr == nil {
foundAccount = a
} else if dbErr != db.ErrNoEntries {
err = fmt.Errorf("GetRemoteAccount: database error looking for account %s: %s", params.RemoteAccountID, err)
err = fmt.Errorf("GetRemoteAccount: database error looking for local account with username %s: %s", params.RemoteAccountUsername, err)
}
case params.RemoteAccountUsername != "" && params.RemoteAccountHost != "":
// try with username/host
a := &gtsmodel.Account{}
where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: params.RemoteAccountHost}}
if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil {
remoteAccount = a
if a, dbErr := d.db.GetAccountByUsernameDomain(ctx, params.RemoteAccountUsername, params.RemoteAccountHost); dbErr == nil {
foundAccount = a
} else if dbErr != db.ErrNoEntries {
err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)
err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and domain %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)
}
default:
err = errors.New("GetRemoteAccount: no identifying parameters were set so we cannot get account")
Expand All @@ -125,10 +146,11 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
return
}

if params.SkipResolve {
// if we can't resolve, return already since there's nothing more we can do
if remoteAccount == nil {
err = errors.New("GetRemoteAccount: error retrieving account with skipResolve set true")
if skipResolve {
// if we can't resolve, return already
// since there's nothing more we can do
if foundAccount == nil {
err = errors.New("GetRemoteAccount: couldn't retrieve account locally and won't try to resolve it")
}
return
}
Expand All @@ -141,8 +163,8 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
// ... but we still need the username so we can do a finger for the accountDomain

// check if we had the account stored already and got it earlier
if remoteAccount != nil {
params.RemoteAccountUsername = remoteAccount.Username
if foundAccount != nil {
params.RemoteAccountUsername = foundAccount.Username
} else {
// if we didn't already have it, we have dereference it from remote and just...
accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID)
Expand All @@ -167,8 +189,8 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
// already about what the account domain might be; this var will be overwritten later if necessary
var accountDomain string
switch {
case remoteAccount != nil:
accountDomain = remoteAccount.Domain
case foundAccount != nil:
accountDomain = foundAccount.Domain
case params.RemoteAccountID != nil:
accountDomain = params.RemoteAccountID.Host
default:
Expand All @@ -178,7 +200,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
// to save on remote calls: only webfinger if we don't have a remoteAccount yet, or if we haven't
// fingered the remote account for at least 2 days; don't finger instance accounts
var fingered time.Time
if remoteAccount == nil || (remoteAccount.LastWebfingeredAt.Before(time.Now().Add(webfingerInterval)) && !instanceAccount(remoteAccount)) {
if foundAccount == nil || (foundAccount.LastWebfingeredAt.Before(time.Now().Add(webfingerInterval)) && !instanceAccount(foundAccount)) {
accountDomain, params.RemoteAccountID, err = d.fingerRemoteAccount(ctx, params.RequestingUsername, params.RemoteAccountUsername, params.RemoteAccountHost)
if err != nil {
err = fmt.Errorf("GetRemoteAccount: error while fingering: %s", err)
Expand All @@ -187,14 +209,14 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
fingered = time.Now()
}

if !fingered.IsZero() && remoteAccount == nil {
if !fingered.IsZero() && foundAccount == nil {
// if we just fingered and now have a discovered account domain but still no account,
// we should do a final lookup in the database with the discovered username + accountDomain
// to make absolutely sure we don't already have this account
a := &gtsmodel.Account{}
where := []db.Where{{Key: "username", Value: params.RemoteAccountUsername}, {Key: "domain", Value: accountDomain}}
if dbErr := d.db.GetWhere(ctx, where, a); dbErr == nil {
remoteAccount = a
foundAccount = a
} else if dbErr != db.ErrNoEntries {
err = fmt.Errorf("GetRemoteAccount: database error looking for account with username %s and host %s: %s", params.RemoteAccountUsername, params.RemoteAccountHost, err)
return
Expand All @@ -203,7 +225,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar

// we may also have some extra information already, like the account we had in the db, or the
// accountable representation that we dereferenced from remote
if remoteAccount == nil {
if foundAccount == nil {
// we still don't have the account, so deference it if we didn't earlier
if accountable == nil {
accountable, err = d.dereferenceAccountable(ctx, params.RequestingUsername, params.RemoteAccountID)
Expand All @@ -214,7 +236,7 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
}

// then convert
remoteAccount, err = d.typeConverter.ASRepresentationToAccount(ctx, accountable, accountDomain, false)
foundAccount, err = d.typeConverter.ASRepresentationToAccount(ctx, accountable, accountDomain, false)
if err != nil {
err = fmt.Errorf("GetRemoteAccount: error converting accountable to account: %s", err)
return
Expand All @@ -227,18 +249,18 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
err = fmt.Errorf("GetRemoteAccount: error generating new id for account: %s", err)
return
}
remoteAccount.ID = ulid
foundAccount.ID = ulid

_, err = d.populateAccountFields(ctx, remoteAccount, params.RequestingUsername, params.Blocking)
_, err = d.populateAccountFields(ctx, foundAccount, params.RequestingUsername, params.Blocking)
if err != nil {
err = fmt.Errorf("GetRemoteAccount: error populating further account fields: %s", err)
return
}

remoteAccount.LastWebfingeredAt = fingered
remoteAccount.UpdatedAt = time.Now()
foundAccount.LastWebfingeredAt = fingered
foundAccount.UpdatedAt = time.Now()

err = d.db.Put(ctx, remoteAccount)
err = d.db.Put(ctx, foundAccount)
if err != nil {
err = fmt.Errorf("GetRemoteAccount: error putting new account: %s", err)
return
Expand All @@ -248,9 +270,9 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
}

// we had the account already, but now we know the account domain, so update it if it's different
if !strings.EqualFold(remoteAccount.Domain, accountDomain) {
remoteAccount.Domain = accountDomain
remoteAccount, err = d.db.UpdateAccount(ctx, remoteAccount)
if !strings.EqualFold(foundAccount.Domain, accountDomain) {
foundAccount.Domain = accountDomain
foundAccount, err = d.db.UpdateAccount(ctx, foundAccount)
if err != nil {
err = fmt.Errorf("GetRemoteAccount: error updating account: %s", err)
return
Expand All @@ -260,20 +282,20 @@ func (d *deref) GetRemoteAccount(ctx context.Context, params GetRemoteAccountPar
// make sure the account fields are populated before returning:
// the caller might want to block until everything is loaded
var fieldsChanged bool
fieldsChanged, err = d.populateAccountFields(ctx, remoteAccount, params.RequestingUsername, params.Blocking)
fieldsChanged, err = d.populateAccountFields(ctx, foundAccount, params.RequestingUsername, params.Blocking)
if err != nil {
return nil, fmt.Errorf("GetRemoteAccount: error populating remoteAccount fields: %s", err)
}

var fingeredChanged bool
if !fingered.IsZero() {
fingeredChanged = true
remoteAccount.LastWebfingeredAt = fingered
foundAccount.LastWebfingeredAt = fingered
}

if fieldsChanged || fingeredChanged {
remoteAccount.UpdatedAt = time.Now()
remoteAccount, err = d.db.UpdateAccount(ctx, remoteAccount)
foundAccount.UpdatedAt = time.Now()
foundAccount, err = d.db.UpdateAccount(ctx, foundAccount)
if err != nil {
return nil, fmt.Errorf("GetRemoteAccount: error updating remoteAccount: %s", err)
}
Expand Down

0 comments on commit 570fa7c

Please sign in to comment.