Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 107 additions & 8 deletions internal/indexworker/indexworker.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package indexworker

import (
"context"
"database/sql"
"errors"
"fmt"
"log"
Expand All @@ -10,6 +11,7 @@ import (
"time"

"github.com/gobuffalo/pop/v6"
pkgerrors "github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/conf"
)
Expand Down Expand Up @@ -93,10 +95,17 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo
}
}()

// Look up which schema the pg_trgm extension is installed in
trgmSchema, err := getTrgmExtensionSchema(db)
// Ensure either auth_trgm or pg_trgm extension is installed
extName, err := ensureTrgmExtension(db, config.DB.Namespace, le)
if err != nil {
le.Errorf("Failed to find pg_trgm extension schema: %+v", err)
le.Errorf("Failed to ensure trgm extension is available: %+v", err)
return err
}

// Look up which schema the trgm extension is installed in
trgmSchema, err := getTrgmExtensionSchema(db, extName)
if err != nil {
le.Errorf("Failed to find %s extension schema: %+v", extName, err)
return ErrExtensionNotFound
}

Expand Down Expand Up @@ -170,23 +179,113 @@ func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *lo
return nil
}

// getTrgmExtensionSchema looks up which schema the pg_trgm extension is installed in
func getTrgmExtensionSchema(db *pop.Connection) (string, error) {
// getTrgmExtensionSchema looks up which schema the specified trgm extension is installed in
func getTrgmExtensionSchema(db *pop.Connection, extName string) (string, error) {
var schema string
query := `
SELECT extnamespace::regnamespace::text AS schema_name
FROM pg_extension
WHERE extname = 'pg_trgm'
WHERE extname = $1
LIMIT 1
`

if err := db.RawQuery(query).First(&schema); err != nil {
return "", fmt.Errorf("failed to find pg_trgm extension schema: %w", err)
if err := db.RawQuery(query, extName).First(&schema); err != nil {
return "", fmt.Errorf("failed to find %s extension schema: %w", extName, err)
}

return schema, nil
}

// extensionStatus represents the status of an extension from pg_available_extensions
type extensionStatus struct {
Available bool
Installed bool
}

// getExtensionStatus checks if an extension is available and/or installed
func getExtensionStatus(db *pop.Connection, extName string) (extensionStatus, error) {
var result struct {
Name *string `db:"name"`
InstalledVersion *string `db:"installed_version"`
}

query := `
SELECT name, installed_version
FROM pg_available_extensions
WHERE name = $1
`

if err := db.RawQuery(query, extName).First(&result); err != nil {
// If no rows returned, extension is not available
if pkgerrors.Cause(err) == sql.ErrNoRows {
return extensionStatus{Available: false, Installed: false}, nil
}
return extensionStatus{}, fmt.Errorf("failed to check extension status for %s: %w", extName, err)
}

return extensionStatus{
Available: result.Name != nil,
Installed: result.InstalledVersion != nil,
}, nil
}

// installExtension installs the specified extension in the provided schema
func installExtension(db *pop.Connection, extName string, schema string) error {
query := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s SCHEMA %s", extName, schema)
if err := db.RawQuery(query).Exec(); err != nil {
return fmt.Errorf("failed to install extension %s in schema %s: %w", extName, schema, err)
}
return nil
}

// ensureTrgmExtension ensures that either auth_trgm or pg_trgm extension is installed
// It prefers auth_trgm if available, otherwise falls back to pg_trgm
// Returns the name of the installed extension
func ensureTrgmExtension(db *pop.Connection, authSchema string, le *logrus.Entry) (string, error) {
authTrgmStatus, err := getExtensionStatus(db, "auth_trgm")
if err != nil {
return "", fmt.Errorf("failed to check auth_trgm extension status: %w", err)
}

if authTrgmStatus.Available {
if !authTrgmStatus.Installed {
le.Infof("auth_trgm extension is available but not installed. Installing...")
if err := installExtension(db, "auth_trgm", authSchema); err != nil {
le.Errorf("Failed to install auth_trgm extension: %v", err)
return "", fmt.Errorf("auth_trgm extension is available but failed to install: %w", err)
}
le.Infof("Successfully installed auth_trgm extension")
} else {
le.Infof("auth_trgm extension is already installed")
}
return "auth_trgm", nil
}

le.Infof("auth_trgm extension is not available, checking pg_trgm...")

pgTrgmStatus, err := getExtensionStatus(db, "pg_trgm")
if err != nil {
return "", fmt.Errorf("failed to check pg_trgm extension status: %w", err)
}

if !pgTrgmStatus.Available {
return "", fmt.Errorf("neither auth_trgm nor pg_trgm extensions are available")
}

if !pgTrgmStatus.Installed {
le.Infof("pg_trgm extension is available but not installed. Installing...")
if err := installExtension(db, "pg_trgm", "pg_catalog"); err != nil {
le.Errorf("Failed to install pg_trgm extension: %v", err)
return "", fmt.Errorf("pg_trgm extension is available but failed to install: %w", err)
}
le.Infof("Successfully installed pg_trgm extension")
} else {
le.Infof("pg_trgm extension is already installed")
}

return "pg_trgm", nil
}

// getUsersIndexes returns the list of indexes to create on the users table
func getUsersIndexes(namespace, trgmSchema string) []struct {
name string
Expand Down
24 changes: 17 additions & 7 deletions internal/indexworker/indexworker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,12 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesWithInvalidIndexes() {
ts.logger.Infof("Successfully recovered from %d invalid indexes", len(indexesToInvalidate))
}

// TestCreateIndexesWithoutTrgmExtension tests that CreateIndexes fails when pg_trgm extension doesn't exist
// and that no indexes are created when this prerequisite check fails.
// TestCreateIndexesWithoutTrgmExtension tests that CreateIndexes installs pg_trgm extension
// when it's available but not installed, and then successfully creates indexes.
func (ts *IndexWorkerTestSuite) TestCreateIndexesWithoutTrgmExtension() {
ctx := context.Background()

// Drop the pg_trgm extension to simulate it not being available
// Drop the pg_trgm extension to simulate it not being installed
dropExtQuery := "DROP EXTENSION IF EXISTS pg_trgm CASCADE"
err := ts.db.RawQuery(dropExtQuery).Exec()
require.NoError(ts.T(), err, "Should be able to drop pg_trgm extension")
Expand All @@ -416,14 +416,24 @@ func (ts *IndexWorkerTestSuite) TestCreateIndexesWithoutTrgmExtension() {
require.NoError(ts.T(), err)
assert.Empty(ts.T(), existingIndexes, "No indexes should exist initially")

// Try to create indexes without pg_trgm extension
// Run CreateIndexes - it should install the pg_trgm extension and create indexes
err = CreateIndexes(ctx, ts.config, ts.logger)
assert.Error(ts.T(), err, "CreateIndexes should fail when pg_trgm extension doesn't exist")
assert.ErrorIs(ts.T(), err, ErrExtensionNotFound)
require.NoError(ts.T(), err, "CreateIndexes should succeed by installing the pg_trgm extension")

// Verify that pg_trgm is now installed
err = ts.db.RawQuery(checkExtQuery).First(&extensionExists)
require.NoError(ts.T(), err)
assert.True(ts.T(), extensionExists, "pg_trgm extension should have been installed")

// Verify all indexes were created successfully
existingIndexes, err = getIndexStatuses(ts.popDB, ts.namespace, getIndexNames(indexes))
require.NoError(ts.T(), err)
assert.Empty(ts.T(), existingIndexes, "No indexes should have been created when pg_trgm is missing")
assert.Equal(ts.T(), len(indexes), len(existingIndexes), "All indexes should have been created")

for _, idx := range existingIndexes {
assert.True(ts.T(), idx.IsValid, "Index %s should be valid", idx.IndexName)
assert.True(ts.T(), idx.IsReady, "Index %s should be ready", idx.IndexName)
}

// Restore pg_trgm extension for other tests
createExtQuery := "CREATE EXTENSION IF NOT EXISTS pg_trgm"
Expand Down