Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions internal/api/hooks_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package api

import (
"encoding/json"
"net/http"
"net/url"
"testing"

"net/http/httptest"
Expand All @@ -13,6 +15,7 @@ import (
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/hooks/hookserrors"
"github.com/supabase/auth/internal/hooks/v0hooks"
mail "github.com/supabase/auth/internal/mailer"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"

Expand Down Expand Up @@ -294,3 +297,173 @@ func (ts *HooksTestSuite) TestInvokeHookIntegration() {
// Ensure that all expected HTTP interactions (mocks) have been called
require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called including retry")
}

func (ts *HooksTestSuite) TestAccountChangesNotificationsHookPayload() {
// Setup hook config for send_email hook
defer gock.OffAll()

testURL := "http://localhost:8888/functions/v1/send-email"
ts.Config.Hook.SendEmail.URI = testURL
ts.Config.Hook.SendEmail.Enabled = true

// Mock the hook endpoint to capture the payload
var capturedPayload *v0hooks.SendEmailInput

gock.New(testURL).
Post("/").
MatchType("json").
SetMatcher(gock.NewMatcher()).
AddMatcher(func(req *http.Request, greq *gock.Request) (bool, error) {
// Capture the payload
payload := &v0hooks.SendEmailInput{}
if err := json.NewDecoder(req.Body).Decode(payload); err != nil {
return false, err
}
capturedPayload = payload
return true, nil
}).
Persist().
Reply(http.StatusOK).
JSON(v0hooks.SendEmailOutput{})

testCases := []struct {
description string
expectedActionType string
expectedProvider string
expectedOldEmail string
expectedOldPhone string
expectedFactorType string
setupFunc func() error
enableNotification func()
}{
{
description: "IdentityLinkedNotification contains provider",
expectedActionType: mail.IdentityLinkedNotification,
expectedProvider: "google",
enableNotification: func() {
ts.Config.Mailer.Notifications.IdentityLinkedEnabled = true
},
setupFunc: func() error {
req := httptest.NewRequest("POST", "/identities", nil)
externalHost, err := url.Parse("http://example.com")
require.NoError(ts.T(), err)
req = req.WithContext(withExternalHost(req.Context(), externalHost))
return ts.API.sendIdentityLinkedNotification(req, ts.API.db, ts.TestUser, "google")
},
},
{
description: "IdentityUnlinkedNotification contains provider",
expectedActionType: mail.IdentityUnlinkedNotification,
expectedProvider: "github",
enableNotification: func() {
ts.Config.Mailer.Notifications.IdentityUnlinkedEnabled = true
},
setupFunc: func() error {
req := httptest.NewRequest("DELETE", "/identities/123", nil)
externalHost, err := url.Parse("http://example.com")
require.NoError(ts.T(), err)
req = req.WithContext(withExternalHost(req.Context(), externalHost))
return ts.API.sendIdentityUnlinkedNotification(req, ts.API.db, ts.TestUser, "github")
},
},
{
description: "EmailChangedNotification contains old_email",
expectedActionType: mail.EmailChangedNotification,
expectedOldEmail: "old@example.com",
enableNotification: func() {
ts.Config.Mailer.Notifications.EmailChangedEnabled = true
},
setupFunc: func() error {
req := httptest.NewRequest("PUT", "/user", nil)
externalHost, err := url.Parse("http://example.com")
require.NoError(ts.T(), err)
req = req.WithContext(withExternalHost(req.Context(), externalHost))
return ts.API.sendEmailChangedNotification(req, ts.API.db, ts.TestUser, "old@example.com")
},
},
{
description: "PhoneChangedNotification contains old_phone",
expectedActionType: mail.PhoneChangedNotification,
expectedOldPhone: "+15551234567",
enableNotification: func() {
ts.Config.Mailer.Notifications.PhoneChangedEnabled = true
},
setupFunc: func() error {
req := httptest.NewRequest("PUT", "/user", nil)
externalHost, err := url.Parse("http://example.com")
require.NoError(ts.T(), err)
req = req.WithContext(withExternalHost(req.Context(), externalHost))
return ts.API.sendPhoneChangedNotification(req, ts.API.db, ts.TestUser, "+15551234567")
},
},
{
description: "MFAFactorEnrolledNotification contains factor_type",
expectedActionType: mail.MFAFactorEnrolledNotification,
expectedFactorType: "totp",
enableNotification: func() {
ts.Config.Mailer.Notifications.MFAFactorEnrolledEnabled = true
},
setupFunc: func() error {
req := httptest.NewRequest("POST", "/factors", nil)
externalHost, err := url.Parse("http://example.com")
require.NoError(ts.T(), err)
req = req.WithContext(withExternalHost(req.Context(), externalHost))
return ts.API.sendMFAFactorEnrolledNotification(req, ts.API.db, ts.TestUser, "totp")
},
},
{
description: "MFAFactorUnenrolledNotification contains factor_type",
expectedActionType: mail.MFAFactorUnenrolledNotification,
expectedFactorType: "phone",
enableNotification: func() {
ts.Config.Mailer.Notifications.MFAFactorUnenrolledEnabled = true
},
setupFunc: func() error {
req := httptest.NewRequest("DELETE", "/factors/123", nil)
externalHost, err := url.Parse("http://example.com")
require.NoError(ts.T(), err)
req = req.WithContext(withExternalHost(req.Context(), externalHost))
return ts.API.sendMFAFactorUnenrolledNotification(req, ts.API.db, ts.TestUser, "phone")
},
},
}

for _, tc := range testCases {
ts.Run(tc.description, func() {
// Reset captured payload
capturedPayload = nil

// Enable the notification
tc.enableNotification()

// Execute the setup function that triggers the notification
err := tc.setupFunc()
require.NoError(ts.T(), err)

// Verify the payload was captured
require.NotNil(ts.T(), capturedPayload, "Hook should have been called")

// Verify email action type
require.Equal(ts.T(), tc.expectedActionType, capturedPayload.EmailData.EmailActionType)

// Verify notification-specific fields
if tc.expectedProvider != "" {
require.Equal(ts.T(), tc.expectedProvider, capturedPayload.EmailData.Provider, "Provider should be set in EmailData")
}
if tc.expectedOldEmail != "" {
require.Equal(ts.T(), tc.expectedOldEmail, capturedPayload.EmailData.OldEmail, "OldEmail should be set in EmailData")
}
if tc.expectedOldPhone != "" {
require.Equal(ts.T(), tc.expectedOldPhone, capturedPayload.EmailData.OldPhone, "OldPhone should be set in EmailData")
}
if tc.expectedFactorType != "" {
require.Equal(ts.T(), tc.expectedFactorType, capturedPayload.EmailData.FactorType, "FactorType should be set in EmailData")
}

// Verify common fields
require.Equal(ts.T(), ts.TestUser.ID, capturedPayload.User.ID, "User ID should match")
require.NotEmpty(ts.T(), capturedPayload.EmailData.SiteURL, "SiteURL should be set")
require.NotEmpty(ts.T(), capturedPayload.EmailData.RedirectTo, "RedirectTo should be set")
})
}
}
13 changes: 13 additions & 0 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,19 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User,
emailData.Token = params.otpNew
}
}

// Augment the email data for the email send hook with notification-specific fields
switch params.emailActionType {
case mail.EmailChangedNotification:
emailData.OldEmail = params.oldEmail
case mail.PhoneChangedNotification:
emailData.OldPhone = params.oldPhone
case mail.IdentityLinkedNotification, mail.IdentityUnlinkedNotification:
emailData.Provider = params.provider
case mail.MFAFactorEnrolledNotification, mail.MFAFactorUnenrolledNotification:
emailData.FactorType = params.factorType
}

input := v0hooks.SendEmailInput{
User: u,
EmailData: emailData,
Expand Down
4 changes: 4 additions & 0 deletions internal/mailer/mailer.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,8 @@ type EmailData struct {
SiteURL string `json:"site_url"`
TokenNew string `json:"token_new"`
TokenHashNew string `json:"token_hash_new"`
OldEmail string `json:"old_email"`
OldPhone string `json:"old_phone"`
Provider string `json:"provider"`
FactorType string `json:"factor_type"`
}