Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions internal/indexworker/indexworker.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

// ErrAdvisoryLockAlreadyAcquired is returned when another process already holds the advisory lock
var ErrAdvisoryLockAlreadyAcquired = errors.New("advisory lock already acquired by another process")
var ErrExtensionNotFound = errors.New("extension not found")

// CreateIndexes ensures that the necessary indexes on the users table exist.
// If the indexes already exist and are valid, it skips creation.
Expand Down Expand Up @@ -92,7 +93,14 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo
}
}()

indexes := getUsersIndexes(config.DB.Namespace)
// Look up which schema the pg_trgm extension is installed in
trgmSchema, err := getTrgmExtensionSchema(db)
if err != nil {
le.Errorf("Failed to find pg_trgm extension schema: %+v", err)
return ErrExtensionNotFound
}

indexes := getUsersIndexes(config.DB.Namespace, trgmSchema)
indexNames := make([]string, len(indexes))
for i, idx := range indexes {
indexNames[i] = idx.name
Expand Down Expand Up @@ -162,8 +170,25 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo
return nil
}

// getTrgmExtensionSchema looks up which schema the pg_trgm extension is installed in
func getTrgmExtensionSchema(db *pop.Connection) (string, error) {
var schema string
query := `
SELECT extnamespace::regnamespace::text AS schema_name
FROM pg_extension
WHERE extname = 'pg_trgm'
LIMIT 1
`

if err := db.RawQuery(query).First(&schema); err != nil {
return "", fmt.Errorf("failed to find pg_trgm extension schema: %w", err)
}

return schema, nil
}

// getUsersIndexes returns the list of indexes to create on the users table
func getUsersIndexes(namespace string) []struct {
func getUsersIndexes(namespace, trgmSchema string) []struct {
name string
query string
} {
Expand All @@ -182,7 +207,7 @@ func getUsersIndexes(namespace string) []struct {
{
name: "idx_users_email_trgm",
query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_email_trgm
ON %q.users USING gin (email gin_trgm_ops);`, namespace),
ON %q.users USING gin (email %s.gin_trgm_ops);`, namespace, trgmSchema),
},
// enables exact-match and prefix searches and sorting by phone number
{
Expand All @@ -205,8 +230,8 @@ func getUsersIndexes(namespace string) []struct {
{
name: "idx_users_name_trgm",
query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_name_trgm
ON %q.users USING gin ((raw_user_meta_data->>'name') gin_trgm_ops)
WHERE raw_user_meta_data->>'name' IS NOT NULL;`, namespace),
ON %q.users USING gin ((raw_user_meta_data->>'name') %s.gin_trgm_ops)
WHERE raw_user_meta_data->>'name' IS NOT NULL;`, namespace, trgmSchema),
},
}
}
Expand Down
52 changes: 45 additions & 7 deletions internal/indexworker/indexworker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (ts *IndexWorkerTestSuite) SetupTest() {
}

func (ts *IndexWorkerTestSuite) cleanupIndexes() {
indexes := getUsersIndexes(ts.namespace)
indexes := getUsersIndexes(ts.namespace, ts.namespace)
for _, idx := range indexes {
// Drop any existing indexes (valid or invalid)
dropQuery := fmt.Sprintf("DROP INDEX IF EXISTS %q.%s", ts.namespace, idx.name)
Expand All @@ -91,7 +91,7 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesHappyPath() {
err := CreateIndexes(ctx, ts.config, ts.logger)
require.NoError(ts.T(), err)

indexes := getUsersIndexes(ts.namespace)
indexes := getUsersIndexes(ts.namespace, ts.namespace)
existingIndexes, err := getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes))
require.NoError(ts.T(), err)

Expand Down Expand Up @@ -135,7 +135,7 @@ func (ts *IndexWorkerTestSuite) TestIdempotency() {
require.NoError(ts.T(), err)

// Get the state after first run
indexes := getUsersIndexes(ts.namespace)
indexes := getUsersIndexes(ts.namespace, ts.namespace)
firstRunIndexes, err := getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes))
require.NoError(ts.T(), err)
require.Equal(ts.T(), len(indexes), len(firstRunIndexes))
Expand Down Expand Up @@ -191,7 +191,7 @@ func (ts *IndexWorkerTestSuite) TestOutOfBandIndexRemoval() {
require.NoError(ts.T(), err)

// Verify all indexes exist
indexes := getUsersIndexes(ts.namespace)
indexes := getUsersIndexes(ts.namespace, ts.namespace)
existingIndexes, err := getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes))
require.NoError(ts.T(), err)
assert.Equal(ts.T(), len(indexes), len(existingIndexes))
Expand Down Expand Up @@ -277,7 +277,7 @@ func (ts *IndexWorkerTestSuite) TestConcurrentWorkers() {
assert.Equal(ts.T(), numWorkers-1, lockSkipCount, "Other workers should skip due to lock")

// Verify all indexes were created successfully
indexes := getUsersIndexes(ts.namespace)
indexes := getUsersIndexes(ts.namespace, ts.namespace)
existingIndexes, err := getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes))
require.NoError(ts.T(), err)
assert.Equal(ts.T(), len(indexes), len(existingIndexes), "All indexes should be created")
Expand Down Expand Up @@ -306,7 +306,7 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesWithInvalidIndexes() {
require.NoError(ts.T(), err, "Initial CreateIndexes should succeed")

// Verify all indexes were created and are valid
indexes := getUsersIndexes(ts.namespace)
indexes := getUsersIndexes(ts.namespace, ts.namespace)
initialIndexes, err := getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes))
require.NoError(ts.T(), err)
assert.Equal(ts.T(), len(indexes), len(initialIndexes), "All indexes should be created initially")
Expand Down Expand Up @@ -337,7 +337,7 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesWithInvalidIndexes() {
defer manipulatorDB.Close()

// Select the first 2 indexes to mark as invalid
allIndexes := getUsersIndexes(ts.namespace)
allIndexes := getUsersIndexes(ts.namespace, ts.namespace)
indexesToInvalidate := []string{allIndexes[0].name, allIndexes[1].name}

for _, indexName := range indexesToInvalidate {
Expand Down Expand Up @@ -393,6 +393,44 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesWithInvalidIndexes() {
ts.logger.Infof("Successfully recovered from %d invalid indexes", len(indexesToInvalidate))
}

// TestCreateIndexesWithoutTrgmExtension tests that CreateIndexes fails when pg_trgm extension doesn't exist
// and that no indexes are created when this prerequisite check fails.
func (ts *IndexWorkerTestSuite) TestCreateIndexesWithoutTrgmExtension() {
ctx := context.Background()

// Drop the pg_trgm extension to simulate it not being available
dropExtQuery := "DROP EXTENSION IF EXISTS pg_trgm CASCADE"
err := ts.db.RawQuery(dropExtQuery).Exec()
require.NoError(ts.T(), err, "Should be able to drop pg_trgm extension")

// Verify the extension is dropped
var extensionExists bool
checkExtQuery := "SELECT EXISTS(SELECT 1 FROM pg_extension WHERE extname = 'pg_trgm')"
err = ts.db.RawQuery(checkExtQuery).First(&extensionExists)
require.NoError(ts.T(), err)
assert.False(ts.T(), extensionExists, "pg_trgm extension should not exist")

// Verify no indexes exist initially
indexes := getUsersIndexes(ts.namespace, ts.namespace)
existingIndexes, err := getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes))
require.NoError(ts.T(), err)
assert.Empty(ts.T(), existingIndexes, "No indexes should exist initially")

// Try to create indexes without pg_trgm extension
err = CreateIndexes(ctx, ts.config, ts.logger)
assert.Error(ts.T(), err, "CreateIndexes should fail when pg_trgm extension doesn't exist")
assert.ErrorIs(ts.T(), err, ErrExtensionNotFound)

existingIndexes, err = getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes))
require.NoError(ts.T(), err)
assert.Empty(ts.T(), existingIndexes, "No indexes should have been created when pg_trgm is missing")

// Restore pg_trgm extension for other tests
createExtQuery := "CREATE EXTENSION IF NOT EXISTS pg_trgm"
err = ts.db.RawQuery(createExtQuery).Exec()
require.NoError(ts.T(), err, "Should be able to restore pg_trgm extension")
}

// Run the test suite
func TestIndexWorker(t *testing.T) {
suite.Run(t, new(IndexWorkerTestSuite))
Expand Down

This file was deleted.