diff --git a/internal/indexworker/indexworker.go b/internal/indexworker/indexworker.go index 6292fc930..64e8f0415 100644 --- a/internal/indexworker/indexworker.go +++ b/internal/indexworker/indexworker.go @@ -2,6 +2,7 @@ package indexworker import ( "context" + "database/sql" "errors" "fmt" "log" @@ -10,6 +11,7 @@ import ( "time" "github.com/gobuffalo/pop/v6" + pkgerrors "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/conf" ) @@ -93,10 +95,17 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo } }() - // Look up which schema the pg_trgm extension is installed in - trgmSchema, err := getTrgmExtensionSchema(db) + // Ensure either auth_trgm or pg_trgm extension is installed + extName, err := ensureTrgmExtension(db, config.DB.Namespace, le) if err != nil { - le.Errorf("Failed to find pg_trgm extension schema: %+v", err) + le.Errorf("Failed to ensure trgm extension is available: %+v", err) + return err + } + + // Look up which schema the trgm extension is installed in + trgmSchema, err := getTrgmExtensionSchema(db, extName) + if err != nil { + le.Errorf("Failed to find %s extension schema: %+v", extName, err) return ErrExtensionNotFound } @@ -170,23 +179,113 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo return nil } -// getTrgmExtensionSchema looks up which schema the pg_trgm extension is installed in -func getTrgmExtensionSchema(db *pop.Connection) (string, error) { +// getTrgmExtensionSchema looks up which schema the specified trgm extension is installed in +func getTrgmExtensionSchema(db *pop.Connection, extName string) (string, error) { var schema string query := ` SELECT extnamespace::regnamespace::text AS schema_name FROM pg_extension - WHERE extname = 'pg_trgm' + WHERE extname = $1 LIMIT 1 ` - if err := db.RawQuery(query).First(&schema); err != nil { - return "", fmt.Errorf("failed to find pg_trgm extension schema: %w", err) + if err := db.RawQuery(query, extName).First(&schema); err != nil { + return "", fmt.Errorf("failed to find %s extension schema: %w", extName, err) } return schema, nil } +// extensionStatus represents the status of an extension from pg_available_extensions +type extensionStatus struct { + Available bool + Installed bool +} + +// getExtensionStatus checks if an extension is available and/or installed +func getExtensionStatus(db *pop.Connection, extName string) (extensionStatus, error) { + var result struct { + Name *string `db:"name"` + InstalledVersion *string `db:"installed_version"` + } + + query := ` + SELECT name, installed_version + FROM pg_available_extensions + WHERE name = $1 + ` + + if err := db.RawQuery(query, extName).First(&result); err != nil { + // If no rows returned, extension is not available + if pkgerrors.Cause(err) == sql.ErrNoRows { + return extensionStatus{Available: false, Installed: false}, nil + } + return extensionStatus{}, fmt.Errorf("failed to check extension status for %s: %w", extName, err) + } + + return extensionStatus{ + Available: result.Name != nil, + Installed: result.InstalledVersion != nil, + }, nil +} + +// installExtension installs the specified extension in the provided schema +func installExtension(db *pop.Connection, extName string, schema string) error { + query := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s SCHEMA %s", extName, schema) + if err := db.RawQuery(query).Exec(); err != nil { + return fmt.Errorf("failed to install extension %s in schema %s: %w", extName, schema, err) + } + return nil +} + +// ensureTrgmExtension ensures that either auth_trgm or pg_trgm extension is installed +// It prefers auth_trgm if available, otherwise falls back to pg_trgm +// Returns the name of the installed extension +func ensureTrgmExtension(db *pop.Connection, authSchema string, le *logrus.Entry) (string, error) { + authTrgmStatus, err := getExtensionStatus(db, "auth_trgm") + if err != nil { + return "", fmt.Errorf("failed to check auth_trgm extension status: %w", err) + } + + if authTrgmStatus.Available { + if !authTrgmStatus.Installed { + le.Infof("auth_trgm extension is available but not installed. Installing...") + if err := installExtension(db, "auth_trgm", authSchema); err != nil { + le.Errorf("Failed to install auth_trgm extension: %v", err) + return "", fmt.Errorf("auth_trgm extension is available but failed to install: %w", err) + } + le.Infof("Successfully installed auth_trgm extension") + } else { + le.Infof("auth_trgm extension is already installed") + } + return "auth_trgm", nil + } + + le.Infof("auth_trgm extension is not available, checking pg_trgm...") + + pgTrgmStatus, err := getExtensionStatus(db, "pg_trgm") + if err != nil { + return "", fmt.Errorf("failed to check pg_trgm extension status: %w", err) + } + + if !pgTrgmStatus.Available { + return "", fmt.Errorf("neither auth_trgm nor pg_trgm extensions are available") + } + + if !pgTrgmStatus.Installed { + le.Infof("pg_trgm extension is available but not installed. Installing...") + if err := installExtension(db, "pg_trgm", "pg_catalog"); err != nil { + le.Errorf("Failed to install pg_trgm extension: %v", err) + return "", fmt.Errorf("pg_trgm extension is available but failed to install: %w", err) + } + le.Infof("Successfully installed pg_trgm extension") + } else { + le.Infof("pg_trgm extension is already installed") + } + + return "pg_trgm", nil +} + // getUsersIndexes returns the list of indexes to create on the users table func getUsersIndexes(namespace, trgmSchema string) []struct { name string diff --git a/internal/indexworker/indexworker_test.go b/internal/indexworker/indexworker_test.go index 1f4390d76..5bd4779de 100644 --- a/internal/indexworker/indexworker_test.go +++ b/internal/indexworker/indexworker_test.go @@ -393,12 +393,12 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesWithInvalidIndexes() { ts.logger.Infof("Successfully recovered from %d invalid indexes", len(indexesToInvalidate)) } -// TestCreateIndexesWithoutTrgmExtension tests that CreateIndexes fails when pg_trgm extension doesn't exist -// and that no indexes are created when this prerequisite check fails. +// TestCreateIndexesWithoutTrgmExtension tests that CreateIndexes installs pg_trgm extension +// when it's available but not installed, and then successfully creates indexes. func (ts *IndexWorkerTestSuite) TestCreateIndexesWithoutTrgmExtension() { ctx := context.Background() - // Drop the pg_trgm extension to simulate it not being available + // Drop the pg_trgm extension to simulate it not being installed dropExtQuery := "DROP EXTENSION IF EXISTS pg_trgm CASCADE" err := ts.db.RawQuery(dropExtQuery).Exec() require.NoError(ts.T(), err, "Should be able to drop pg_trgm extension") @@ -416,14 +416,24 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesWithoutTrgmExtension() { require.NoError(ts.T(), err) assert.Empty(ts.T(), existingIndexes, "No indexes should exist initially") - // Try to create indexes without pg_trgm extension + // Run CreateIndexes - it should install the pg_trgm extension and create indexes err = CreateIndexes(ctx, ts.config, ts.logger) - assert.Error(ts.T(), err, "CreateIndexes should fail when pg_trgm extension doesn't exist") - assert.ErrorIs(ts.T(), err, ErrExtensionNotFound) + require.NoError(ts.T(), err, "CreateIndexes should succeed by installing the pg_trgm extension") + // Verify that pg_trgm is now installed + err = ts.db.RawQuery(checkExtQuery).First(&extensionExists) + require.NoError(ts.T(), err) + assert.True(ts.T(), extensionExists, "pg_trgm extension should have been installed") + + // Verify all indexes were created successfully existingIndexes, err = getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes)) require.NoError(ts.T(), err) - assert.Empty(ts.T(), existingIndexes, "No indexes should have been created when pg_trgm is missing") + assert.Equal(ts.T(), len(indexes), len(existingIndexes), "All indexes should have been created") + + for _, idx := range existingIndexes { + assert.True(ts.T(), idx.IsValid, "Index %s should be valid", idx.IndexName) + assert.True(ts.T(), idx.IsReady, "Index %s should be ready", idx.IndexName) + } // Restore pg_trgm extension for other tests createExtQuery := "CREATE EXTENSION IF NOT EXISTS pg_trgm"