Skip to content

Commit

Permalink
feat: add basic user banning functionality (#343)
Browse files Browse the repository at this point in the history
* test: add token test for happy paths

* chore: update user_instance_id_email_idx

* fix: add ban_until column in users

* fix: update user model with ban_until column

* fix: check if user is banned in /token

* test: update token tests

* fix tests

* check if user is banned on verify

* check if user is banned in external provider callback

* fix: add ban_duration to admin update user endpoint

* docs: update readme

* rename ban_until to banned_until

* use 'unban' instead of '0' to unban user

* change unban to none & update readme
  • Loading branch information
kangmingtay authored Jan 19, 2022
1 parent 467fb94 commit cc94302
Show file tree
Hide file tree
Showing 12 changed files with 359 additions and 13 deletions.
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,30 @@ Returns the publicly available settings for this gotrue instance.
}
```

### **POST, PUT /admin/users/<user_id>**

Creates (POST) or Updates (PUT) the user based on the `user_id` specified. The `ban_duration` field accepts the following time units: "ns", "us", "ms", "s", "m", "h". See [`time.ParseDuration`](https://pkg.go.dev/time#ParseDuration) for more details on the format used.

```js
headers:
{
"Authorization": "Bearer eyJhbGciOiJI...M3A90LCkxxtX9oNP9KZO" // admin role required
}

body:
{
"role": "test-user",
"email": "email@example.com",
"phone": "12345678",
"password": "secret", // only if type = signup
"email_confirm": true,
"phone_confirm": true,
"user_metadata": {},
"app_metadata": {},
"ban_duration": "24h" or "none" // to unban a user
}
```

### **POST /admin/generate_link**

Returns the corresponding email action link based on the type specified.
Expand Down
34 changes: 34 additions & 0 deletions api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"encoding/json"
"errors"
"net/http"
"strings"
"time"

"github.com/go-chi/chi"
"github.com/gofrs/uuid"
Expand All @@ -23,6 +25,7 @@ type adminUserParams struct {
PhoneConfirm bool `json:"phone_confirm"`
UserMetaData map[string]interface{} `json:"user_metadata"`
AppMetaData map[string]interface{} `json:"app_metadata"`
BanDuration string `json:"ban_duration"`
}

func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, error) {
Expand Down Expand Up @@ -156,6 +159,22 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
}
}

if params.BanDuration != "" {
if params.BanDuration == "none" {
user.BannedUntil = nil
} else {
duration, terr := time.ParseDuration(params.BanDuration)
if terr != nil {
return badRequestError("Invalid format for ban_duration: %v", terr)
}
t := time.Now().Add(duration)
user.BannedUntil = &t
}
if terr := user.UpdateBannedUntil(tx); terr != nil {
return terr
}
}

if terr := models.NewAuditLogEntry(tx, instanceID, adminUser, models.UserModifiedAction, map[string]interface{}{
"user_id": user.ID,
"user_email": user.Email,
Expand All @@ -170,6 +189,9 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
if errors.Is(err, invalidPasswordLengthError(config)) {
return err
}
if strings.Contains(err.Error(), "Invalid format for ban_duration") {
return err
}
return internalServerError("Error updating user").WithInternalError(err)
}

Expand Down Expand Up @@ -241,6 +263,15 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
user.AppMetaData["provider"] = "email"
user.AppMetaData["providers"] = []string{"email"}

if params.BanDuration != "" {
duration, terr := time.ParseDuration(params.BanDuration)
if terr != nil {
return badRequestError("Invalid format for ban_duration: %v", terr)
}
t := time.Now().Add(duration)
user.BannedUntil = &t
}

err = a.db.Transaction(func(tx *storage.Connection) error {
if terr := models.NewAuditLogEntry(tx, instanceID, adminUser, models.UserSignedUpAction, map[string]interface{}{
"user_id": user.ID,
Expand Down Expand Up @@ -278,6 +309,9 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
})

if err != nil {
if strings.Contains(err.Error(), "Invalid format for ban_duration") {
return err
}
return internalServerError("Database error creating new user").WithInternalError(err)
}

Expand Down
39 changes: 39 additions & 0 deletions api/admin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,20 @@ func (ts *AdminTestSuite) TestAdminUserCreate() {
"isAuthenticated": false,
},
},
{
desc: "Ban created user",
params: map[string]interface{}{
"email": "test4@example.com",
"phone": "",
"password": "test1",
"ban_duration": "24h",
},
expected: map[string]interface{}{
"email": "test4@example.com",
"phone": "",
"isAuthenticated": true,
},
},
}

for _, c := range cases {
Expand Down Expand Up @@ -379,6 +393,7 @@ func (ts *AdminTestSuite) TestAdminUserUpdate() {
"user_metadata": map[string]interface{}{
"name": "David",
},
"ban_duration": "24h",
}))

// Setup request
Expand All @@ -401,6 +416,7 @@ func (ts *AdminTestSuite) TestAdminUserUpdate() {
assert.Len(ts.T(), data.AppMetaData["roles"], 2)
assert.Contains(ts.T(), data.AppMetaData["roles"], "writer")
assert.Contains(ts.T(), data.AppMetaData["roles"], "editor")
assert.NotNil(ts.T(), data.BannedUntil)
}

// TestAdminUserUpdate tests API /admin/user route (UPDATE) as system user
Expand Down Expand Up @@ -472,6 +488,29 @@ func (ts *AdminTestSuite) TestAdminUserUpdatePasswordFailed() {
})
}

func (ts *AdminTestSuite) TestAdminUserUpdateBannedUntilFailed() {
u, err := models.NewUser(ts.instanceID, "test1@example.com", "test", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error making new user")
require.NoError(ts.T(), ts.API.db.Create(u), "Error creating user")

var updateEndpoint = fmt.Sprintf("/admin/users/%s", u.ID)
ts.Config.PasswordMinLength = 6
ts.Run("Incorrect format for ban_duration", func() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"ban_duration": "24",
}))

// Setup request
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPut, updateEndpoint, &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", ts.token))

ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusBadRequest, w.Code)
})
}

// TestAdminUserDelete tests API /admin/user route (DELETE)
func (ts *AdminTestSuite) TestAdminUserDelete() {
u, err := models.NewUser(ts.instanceID, "test-delete@example.com", "test", ts.Config.JWT.Aud, nil)
Expand Down
4 changes: 4 additions & 0 deletions api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
}
}

if user.IsBanned() {
return unauthorizedError("User is unauthorized")
}

if !user.IsConfirmed() {
if !emailData.Verified && !config.Mailer.Autoconfirm {
mailer := a.Mailer(ctx)
Expand Down
23 changes: 23 additions & 0 deletions api/external_github_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"time"

jwt "github.com/golang-jwt/jwt"
"github.com/netlify/gotrue/models"
"github.com/stretchr/testify/require"
)

func (ts *ExternalTestSuite) TestSignupExternalGithub() {
Expand Down Expand Up @@ -201,3 +204,23 @@ func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenVerifiedFalse() {
ts.Equal("Unverified email with github", v.Get("error_description"))
assertAuthorizationFailure(ts, u, "", "", "")
}

func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenUserBanned() {
tokenCount, userCount := 0, 0
code := "authcode"
emails := `[{"email":"github@example.com", "primary": true, "verified": true}]`
server := GitHubTestSignupSetup(ts, &tokenCount, &userCount, code, emails)
defer server.Close()

u := performAuthorization(ts, "github", code, "")
assertAuthorizationSuccess(ts, u, tokenCount, userCount, "github@example.com", "GitHub Test", "123", "http://example.com/avatar")

user, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "github@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
t := time.Now().Add(24 * time.Hour)
user.BannedUntil = &t
require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until"))

u = performAuthorization(ts, "github", code, "")
assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "")
}
9 changes: 8 additions & 1 deletion api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
return internalServerError("Database error querying schema").WithInternalError(err)
}

if !user.Authenticate(params.Password) {
if user.IsBanned() || !user.Authenticate(params.Password) {
return oauthError("invalid_grant", InvalidLoginMessage)
}

Expand Down Expand Up @@ -235,6 +235,10 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h
return internalServerError(err.Error())
}

if user.IsBanned() {
return oauthError("invalid_grant", "Invalid Refresh Token")
}

if !(config.External.Email.Enabled && config.External.Phone.Enabled) {
providers, err := models.FindProvidersByUser(a.db, user)
if err != nil {
Expand Down Expand Up @@ -397,6 +401,9 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
if email != "" {
identity.IdentityData["email"] = email
}
if user.IsBanned() {
return oauthError("invalid_grant", "invalid id token grant")
}
if terr = tx.UpdateOnly(identity, "identity_data", "last_sign_in_at"); terr != nil {
return terr
}
Expand Down
95 changes: 93 additions & 2 deletions api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package api

import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

"github.com/gofrs/uuid"
"github.com/netlify/gotrue/conf"
Expand All @@ -20,8 +22,8 @@ type TokenTestSuite struct {
API *API
Config *conf.Configuration

token string
instanceID uuid.UUID
RefreshToken *models.RefreshToken
instanceID uuid.UUID
}

func TestToken(t *testing.T) {
Expand All @@ -40,7 +42,19 @@ func TestToken(t *testing.T) {
}

func (ts *TokenTestSuite) SetupTest() {
ts.RefreshToken = nil
models.TruncateAll(ts.API.db)

// Create user & refresh token
u, err := models.NewUser(ts.instanceID, "test@example.com", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
t := time.Now()
u.EmailConfirmedAt = &t
u.BannedUntil = nil
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")

ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, u)
require.NoError(ts.T(), err, "Error creating refresh token")
}

func (ts *TokenTestSuite) TestRateLimitToken() {
Expand Down Expand Up @@ -73,3 +87,80 @@ func (ts *TokenTestSuite) TestRateLimitToken() {
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}

func (ts *TokenTestSuite) TestTokenPasswordGrantSuccess() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"email": "test@example.com",
"password": "password",
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusOK, w.Code)
}

func (ts *TokenTestSuite) TestTokenRefreshTokenGrantSuccess() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusOK, w.Code)
}

func (ts *TokenTestSuite) TestTokenPasswordGrantFailure() {
u := ts.createBannedUser()

var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"email": u.GetEmail(),
"password": "password",
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}

func (ts *TokenTestSuite) TestTokenRefreshTokenGrantFailure() {
_ = ts.createBannedUser()

var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": ts.RefreshToken.Token,
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
assert.Equal(ts.T(), http.StatusBadRequest, w.Code)
}

func (ts *TokenTestSuite) createBannedUser() *models.User {
u, err := models.NewUser(ts.instanceID, "banned@example.com", "password", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
t := time.Now()
u.EmailConfirmedAt = &t
t = t.Add(24 * time.Hour)
u.BannedUntil = &t
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test banned user")

ts.RefreshToken, err = models.GrantAuthenticatedUser(ts.API.db, u)
require.NoError(ts.T(), err, "Error creating refresh token")

return u
}
Loading

0 comments on commit cc94302

Please sign in to comment.