diff --git a/README.md b/README.md index 50e85d12f..5b21c630e 100644 --- a/README.md +++ b/README.md @@ -737,9 +737,12 @@ or show an account confirmed/welcome message in the case of `signup`, or direct One-Time-Password. Will deliver a magiclink or sms otp to the user depending on whether the request body contains an "email" or "phone" key. +If `"create_user": true`, user will not be automatically signed up if the user doesn't exist. + ```js { "phone": "12345678" // follows the E.164 format + "create_user": true } OR @@ -747,6 +750,7 @@ OR // exactly the same as /magiclink { "email": "email@example.com" + "create_user": true } ``` diff --git a/api/errors.go b/api/errors.go index 992321841..f839357de 100644 --- a/api/errors.go +++ b/api/errors.go @@ -181,6 +181,18 @@ func (e *OTPError) Error() string { return fmt.Sprintf("%s: %s", e.Err, e.Description) } +// WithInternalError adds internal error information to the error +func (e *OTPError) WithInternalError(err error) *OTPError { + e.InternalError = err + return e +} + +// WithInternalMessage adds internal message information to the error +func (e *OTPError) WithInternalMessage(fmtString string, args ...interface{}) *OTPError { + e.InternalMessage = fmt.Sprintf(fmtString, args...) + return e +} + // Cause returns the root cause error func (e *OTPError) Cause() error { if e.InternalError != nil { @@ -244,6 +256,11 @@ func handleError(err error, w http.ResponseWriter, r *http.Request) { if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { handleError(jsonErr, w, r) } + case *OTPError: + log.WithError(e.Cause()).Info(e.Error()) + if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { + handleError(jsonErr, w, r) + } case ErrorCause: handleError(e.Cause(), w, r) default: diff --git a/api/otp.go b/api/otp.go index 05bf94e8c..56d1378b2 100644 --- a/api/otp.go +++ b/api/otp.go @@ -14,8 +14,9 @@ import ( // OtpParams contains the request body params for the otp endpoint type OtpParams struct { - Email string `json:"email"` - Phone string `json:"phone"` + Email string `json:"email"` + Phone string `json:"phone"` + CreateUser bool `json:"create_user"` } // SmsParams contains the request body params for sms otp @@ -36,6 +37,11 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { } r.Body = ioutil.NopCloser(strings.NewReader(string(body))) + + if !a.shouldCreateUser(r, params) { + return badRequestError("Signups not allowed for otp") + } + if params.Email != "" { return a.MagicLink(w, r) } else if params.Phone != "" { @@ -108,3 +114,22 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { return sendJSON(w, http.StatusOK, make(map[string]string)) } + +func (a *API) shouldCreateUser(r *http.Request, params *OtpParams) bool { + if !params.CreateUser { + ctx := r.Context() + instanceID := getInstanceID(ctx) + aud := a.requestAud(ctx, r) + var err error + if params.Email != "" { + _, err = models.FindUserByEmailAndAudience(a.db, instanceID, params.Email, aud) + } else if params.Phone != "" { + _, err = models.FindUserByPhoneAndAudience(a.db, instanceID, params.Phone, aud) + } + + if err != nil && models.IsNotFoundError(err) { + return false + } + } + return true +} diff --git a/api/otp_test.go b/api/otp_test.go new file mode 100644 index 000000000..97851b8d0 --- /dev/null +++ b/api/otp_test.go @@ -0,0 +1,134 @@ +package api + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gofrs/uuid" + "github.com/netlify/gotrue/conf" + "github.com/netlify/gotrue/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type OtpTestSuite struct { + suite.Suite + API *API + Config *conf.Configuration + + instanceID uuid.UUID +} + +func TestOtp(t *testing.T) { + api, config, instanceID, err := setupAPIForTestForInstance() + require.NoError(t, err) + + ts := &OtpTestSuite{ + API: api, + Config: config, + instanceID: instanceID, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *OtpTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) +} + +func (ts *OtpTestSuite) TestOtp() { + cases := []struct { + desc string + params OtpParams + expected struct { + code int + response map[string]interface{} + } + }{ + { + "Test Success Magiclink Otp", + OtpParams{ + Email: "test@example.com", + CreateUser: true, + }, + struct { + code int + response map[string]interface{} + }{ + http.StatusOK, + make(map[string]interface{}), + }, + }, + { + "Test Failure Pass Both Email & Phone", + OtpParams{ + Email: "test@example.com", + Phone: "123456789", + CreateUser: true, + }, + struct { + code int + response map[string]interface{} + }{ + http.StatusBadRequest, + map[string]interface{}{ + "code": float64(http.StatusBadRequest), + "msg": "Only an email address or phone number should be provided", + }, + }, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.params)) + + req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), c.expected.code, w.Code) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + // response should be empty + assert.Equal(ts.T(), data, c.expected.response) + }) + } +} + +func (ts *OtpTestSuite) TestNoSignupsForOtp() { + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "email": "newuser@example.com", + "create_user": false, + })) + + req := httptest.NewRequest(http.MethodPost, "/otp", &buffer) + req.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusBadRequest, w.Code) + + data := make(map[string]interface{}) + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + + // response should be empty + assert.Equal(ts.T(), data, map[string]interface{}{ + "code": float64(http.StatusBadRequest), + "msg": "Signups not allowed for otp", + }) +}