From 655c191d8ee39a559c140749c6351a3e96a9661a Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Thu, 23 Oct 2025 04:24:05 +0900 Subject: [PATCH 1/4] feat(repository): add SQL repository support with multiple backends (sqlite, postgres, mysql) --- go.mod | 13 ++ go.sum | 25 +++ main.go | 6 + pkg/mcp-proxy/main.go | 45 ++++- pkg/repository/sql.go | 395 +++++++++++++++++++++++++++++++++++++ pkg/repository/sql_test.go | 59 ++++++ 6 files changed, 540 insertions(+), 3 deletions(-) create mode 100644 pkg/repository/sql.go create mode 100644 pkg/repository/sql_test.go diff --git a/go.mod b/go.mod index 164b5de..1177bda 100644 --- a/go.mod +++ b/go.mod @@ -19,11 +19,16 @@ require ( go.uber.org/zap v1.27.0 golang.org/x/crypto v0.40.0 golang.org/x/oauth2 v0.14.0 + gorm.io/driver/mysql v1.6.0 + gorm.io/driver/postgres v1.6.0 + gorm.io/driver/sqlite v1.6.0 + gorm.io/gorm v1.31.0 ) require ( cloud.google.com/go/compute v1.23.3 // indirect cloud.google.com/go/compute/metadata v0.2.3 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/asaskevich/govalidator v0.0.0-20230301143203-a9d515a09cc2 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect @@ -46,6 +51,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.26.0 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/gobuffalo/pop/v6 v6.1.1 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/gogo/protobuf v1.3.2 // indirect @@ -62,12 +68,19 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.6.0 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/mattn/goveralls v0.0.12 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect diff --git a/go.sum b/go.sum index a42034b..da38925 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= @@ -137,6 +139,8 @@ github.com/go-playground/validator/v10 v10.26.0 h1:SP05Nqhjcvz81uJaRfEV0YBSSSGMc github.com/go-playground/validator/v10 v10.26.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gobuffalo/attrs v1.0.3/go.mod h1:KvDJCE0avbufqS0Bw3UV7RQynESY0jjod+572ctX4t8= github.com/gobuffalo/envy v1.10.2/go.mod h1:qGAGwdvDsaEtPhfBzb3o0SfDea8ByGn9j8bKmVft9z8= @@ -277,6 +281,7 @@ github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bY github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= @@ -287,6 +292,8 @@ github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwX github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgproto3/v2 v2.3.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= @@ -297,12 +304,20 @@ github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9 github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= github.com/jackc/pgx/v4 v4.17.2/go.mod h1:lcxIZN44yMIrWI78a5CpucdD14hX0SBDbNRvjDBItsw= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jandelgado/gcov2lcov v1.0.5 h1:rkBt40h0CVK4oCb8Dps950gvfd1rYvQ8+cWa346lVU0= github.com/jandelgado/gcov2lcov v1.0.5/go.mod h1:NnSxK6TMlg1oGDBfGelGbjgorT5/L3cchlbtgFYZSss= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= @@ -366,6 +381,8 @@ github.com/mattn/go-jsonpointer v0.0.1/go.mod h1:1s8vx7JSjlgVRF+LW16MPpWSRZAxyrc github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/goveralls v0.0.12 h1:PEEeF0k1SsTjOBQ8FOmrOAoCu4ytuMaWCnWe94zxbCg= github.com/mattn/goveralls v0.0.12/go.mod h1:44ImGEUfmqH8bBtaMrYKsM65LXfNLWmwaxFGjZwgMSQ= github.com/microcosm-cc/bluemonday v1.0.20/go.mod h1:yfBmMi8mxvaZut3Yytv+jTXRY8mxyjJ0/kQBTElld50= @@ -938,6 +955,14 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.6.0 h1:eNbLmNTpPpTOVZi8MMxCi2aaIm0ZpInbORNXDwyLGvg= +gorm.io/driver/mysql v1.6.0/go.mod h1:D/oCC2GWK3M/dqoLxnOlaNKmXz8WNTfcS9y5ovaSqKo= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.31.0 h1:0VlycGreVhK7RF/Bwt51Fk8v0xLiiiFdbGDPIZQ7mJY= +gorm.io/gorm v1.31.0/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/main.go b/main.go index 6ed74d1..b632984 100644 --- a/main.go +++ b/main.go @@ -72,6 +72,8 @@ func main() { var tlsDirectoryURL string var tlsAcceptTOS bool var dataPath string + var repositoryBackend string + var repositoryDSN string var externalURL string var googleClientID string var googleClientSecret string @@ -179,6 +181,8 @@ func main() { tlsDirectoryURL, tlsAcceptTOS, dataPath, + repositoryBackend, + repositoryDSN, externalURL, googleClientID, googleClientSecret, @@ -216,6 +220,8 @@ func main() { rootCmd.Flags().StringVar(&tlsDirectoryURL, "tls-directory-url", getEnvWithDefault("TLS_DIRECTORY_URL", "https://acme-v02.api.letsencrypt.org/directory"), "ACME directory URL for TLS certificates") rootCmd.Flags().BoolVar(&tlsAcceptTOS, "tls-accept-tos", getEnvBoolWithDefault("TLS_ACCEPT_TOS", false), "Accept TLS terms of service") rootCmd.Flags().StringVarP(&dataPath, "data-path", "d", getEnvWithDefault("DATA_PATH", "./data"), "Path to the data directory") + rootCmd.Flags().StringVar(&repositoryBackend, "repository-backend", getEnvWithDefault("REPOSITORY_BACKEND", "local"), "Repository backend to use: local, sqlite, postgres, or mysql") + rootCmd.Flags().StringVar(&repositoryDSN, "repository-dsn", getEnvWithDefault("REPOSITORY_DSN", ""), "DSN passed directly to the SQL driver (required when repository-backend is sqlite/postgres/mysql)") rootCmd.Flags().StringVarP(&externalURL, "external-url", "e", getEnvWithDefault("EXTERNAL_URL", "http://localhost"), "External URL for the proxy") // Google OAuth configuration diff --git a/pkg/mcp-proxy/main.go b/pkg/mcp-proxy/main.go index 7e4a16f..7fc47bb 100644 --- a/pkg/mcp-proxy/main.go +++ b/pkg/mcp-proxy/main.go @@ -40,6 +40,8 @@ func Run( tlsDirectoryURL string, tlsAcceptTOS bool, dataPath string, + repositoryBackend string, + repositoryDSN string, externalURL string, googleClientID string, googleClientSecret string, @@ -138,10 +140,47 @@ func Run( proxyHeadersMap.Set("Authorization", "Bearer "+proxyBearerToken) } - repo, err := repository.NewKVSRepository(path.Join(dataPath, "db"), "mcp-oauth-proxy") - if err != nil { - return fmt.Errorf("failed to create repository: %w", err) + var repo repository.Repository + switch backend := strings.ToLower(repositoryBackend); backend { + case "", "local": + repo, err = repository.NewKVSRepository(path.Join(dataPath, "db"), "mcp-oauth-proxy") + if err != nil { + return fmt.Errorf("failed to create repository: %w", err) + } + case "sqlite": + if repositoryDSN == "" { + return fmt.Errorf("repository DSN must be provided for sqlite backend") + } + repo, err = repository.NewSQLRepository("sqlite", repositoryDSN) + if err != nil { + return fmt.Errorf("failed to create repository: %w", err) + } + case "postgres", "postgresql": + if repositoryDSN == "" { + return fmt.Errorf("repository DSN must be provided for postgres backend") + } + repo, err = repository.NewSQLRepository("postgres", repositoryDSN) + if err != nil { + return fmt.Errorf("failed to create repository: %w", err) + } + case "mysql": + if repositoryDSN == "" { + return fmt.Errorf("repository DSN must be provided for mysql backend") + } + repo, err = repository.NewSQLRepository("mysql", repositoryDSN) + if err != nil { + return fmt.Errorf("failed to create repository: %w", err) + } + case "sql": + return fmt.Errorf("repository backend 'sql' is deprecated; use sqlite, postgres, or mysql") + default: + return fmt.Errorf("unsupported repository backend: %s", repositoryBackend) } + defer func() { + if err := repo.Close(); err != nil { + logger.Warn("failed to close repository", zap.Error(err)) + } + }() privKey, err := utils.LoadOrGeneratePrivateKey(path.Join(dataPath, "private_key.pem")) if err != nil { diff --git a/pkg/repository/sql.go b/pkg/repository/sql.go new file mode 100644 index 0000000..35f18bd --- /dev/null +++ b/pkg/repository/sql.go @@ -0,0 +1,395 @@ +package repository + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/ory/fosite" + "github.com/sigbit/mcp-auth-proxy/pkg/models" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type sqlRepository struct { + db *gorm.DB +} + +type authorizeCodeSession struct { + Code string `gorm:"primaryKey;size:255"` + Request []byte `gorm:"not null"` + CreatedAt time.Time + UpdatedAt time.Time +} + +type accessTokenSession struct { + Signature string `gorm:"primaryKey;size:255"` + Request []byte `gorm:"not null"` + CreatedAt time.Time + UpdatedAt time.Time +} + +type refreshTokenSession struct { + Signature string `gorm:"primaryKey;size:255"` + AccessSignature string `gorm:"size:255"` + Request []byte `gorm:"not null"` + CreatedAt time.Time + UpdatedAt time.Time +} + +type clientRecord struct { + ID string `gorm:"primaryKey;size:255"` + Client []byte `gorm:"not null"` + CreatedAt time.Time + UpdatedAt time.Time +} + +type pkceRequestSession struct { + Signature string `gorm:"primaryKey;size:255"` + Request []byte `gorm:"not null"` + CreatedAt time.Time + UpdatedAt time.Time +} + +type authorizeRequestRecord struct { + RequestID string `gorm:"primaryKey;size:255"` + Request []byte `gorm:"not null"` + CreatedAt time.Time + UpdatedAt time.Time +} + +func NewSQLRepository(driver string, dsn string) (Repository, error) { + if driver == "" { + return nil, fmt.Errorf("driver must not be empty") + } + if dsn == "" { + return nil, fmt.Errorf("dsn must not be empty") + } + + var dialector gorm.Dialector + switch strings.ToLower(driver) { + case "sqlite": + dialector = sqlite.Open(dsn) + case "postgres", "postgresql": + dialector = postgres.Open(dsn) + case "mysql": + dialector = mysql.Open(dsn) + default: + return nil, fmt.Errorf("unsupported driver: %s", driver) + } + + db, err := gorm.Open(dialector, &gorm.Config{}) + if err != nil { + return nil, fmt.Errorf("failed to connect database: %w", err) + } + + if err := db.AutoMigrate( + &authorizeCodeSession{}, + &accessTokenSession{}, + &refreshTokenSession{}, + &clientRecord{}, + &pkceRequestSession{}, + &authorizeRequestRecord{}, + ); err != nil { + return nil, fmt.Errorf("failed to migrate schema: %w", err) + } + + return &sqlRepository{db: db}, nil +} + +func (r *sqlRepository) CreateAuthorizeCodeSession(ctx context.Context, code string, fositeReq fosite.Requester) error { + data, err := marshalRequest(fositeReq) + if err != nil { + return err + } + + session := authorizeCodeSession{ + Code: code, + Request: data, + } + + return r.db.WithContext(ctx). + Clauses(clause.OnConflict{UpdateAll: true}). + Create(&session).Error +} + +func (r *sqlRepository) GetAuthorizeCodeSession(ctx context.Context, code string, sess fosite.Session) (fosite.Requester, error) { + var session authorizeCodeSession + if err := r.db.WithContext(ctx).First(&session, "code = ?", code).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fosite.ErrNotFound + } + return nil, fmt.Errorf("failed to load authorize code session: %w", err) + } + + return unmarshalRequest(session.Request, sess) +} + +func (r *sqlRepository) InvalidateAuthorizeCodeSession(ctx context.Context, code string) error { + return r.db.WithContext(ctx).Delete(&authorizeCodeSession{}, "code = ?", code).Error +} + +func (r *sqlRepository) CreateAccessTokenSession(ctx context.Context, signature string, fositeReq fosite.Requester) error { + data, err := marshalRequest(fositeReq) + if err != nil { + return err + } + + session := accessTokenSession{ + Signature: signature, + Request: data, + } + + return r.db.WithContext(ctx). + Clauses(clause.OnConflict{UpdateAll: true}). + Create(&session).Error +} + +func (r *sqlRepository) GetAccessTokenSession(ctx context.Context, signature string, sess fosite.Session) (fosite.Requester, error) { + var session accessTokenSession + if err := r.db.WithContext(ctx).First(&session, "signature = ?", signature).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fosite.ErrNotFound + } + return nil, fmt.Errorf("failed to load access token session: %w", err) + } + + return unmarshalRequest(session.Request, sess) +} + +func (r *sqlRepository) DeleteAccessTokenSession(ctx context.Context, signature string) error { + return r.db.WithContext(ctx).Delete(&accessTokenSession{}, "signature = ?", signature).Error +} + +func (r *sqlRepository) CreateRefreshTokenSession(ctx context.Context, signature string, accessSignature string, req fosite.Requester) error { + data, err := marshalRequest(req) + if err != nil { + return err + } + + session := refreshTokenSession{ + Signature: signature, + AccessSignature: accessSignature, + Request: data, + } + + return r.db.WithContext(ctx). + Clauses(clause.OnConflict{UpdateAll: true}). + Create(&session).Error +} + +func (r *sqlRepository) GetRefreshTokenSession(ctx context.Context, signature string, sess fosite.Session) (fosite.Requester, error) { + var session refreshTokenSession + if err := r.db.WithContext(ctx).First(&session, "signature = ?", signature).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fosite.ErrNotFound + } + return nil, fmt.Errorf("failed to load refresh token session: %w", err) + } + + return unmarshalRequest(session.Request, sess) +} + +func (r *sqlRepository) DeleteRefreshTokenSession(ctx context.Context, signature string) error { + return r.db.WithContext(ctx).Delete(&refreshTokenSession{}, "signature = ?", signature).Error +} + +func (r *sqlRepository) RotateRefreshToken(ctx context.Context, requestID string, signature string) error { + var session refreshTokenSession + if err := r.db.WithContext(ctx).First(&session, "signature = ?", signature).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return fosite.ErrNotFound + } + return fmt.Errorf("failed to load refresh token session: %w", err) + } + + var req models.Request + if err := json.Unmarshal(session.Request, &req); err != nil { + return fmt.Errorf("failed to decode refresh token session: %w", err) + } + req.RotatedAt = time.Now().UTC() + + data, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to encode refresh token session: %w", err) + } + + return r.db.WithContext(ctx). + Model(&refreshTokenSession{}). + Where("signature = ?", signature). + Update("request", data).Error +} + +func (r *sqlRepository) RevokeRefreshToken(ctx context.Context, requestID string) error { + return r.db.WithContext(ctx).Delete(&refreshTokenSession{}, "signature = ?", requestID).Error +} + +func (r *sqlRepository) RevokeAccessToken(ctx context.Context, requestID string) error { + return r.db.WithContext(ctx).Delete(&accessTokenSession{}, "signature = ?", requestID).Error +} + +func (r *sqlRepository) RegisterClient(ctx context.Context, fositeClient fosite.Client) error { + data, err := marshalClient(fositeClient) + if err != nil { + return err + } + + record := clientRecord{ + ID: fositeClient.GetID(), + Client: data, + } + + return r.db.WithContext(ctx). + Clauses(clause.OnConflict{UpdateAll: true}). + Create(&record).Error +} + +func (r *sqlRepository) GetClient(ctx context.Context, id string) (fosite.Client, error) { + var record clientRecord + if err := r.db.WithContext(ctx).First(&record, "id = ?", id).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fosite.ErrNotFound + } + return nil, fmt.Errorf("failed to load client: %w", err) + } + + return unmarshalClient(record.Client) +} + +func (r *sqlRepository) ClientAssertionJWTValid(ctx context.Context, jti string) error { + return errors.New("not implemented") +} + +func (r *sqlRepository) SetClientAssertionJWT(ctx context.Context, jti string, exp time.Time) error { + return errors.New("not implemented") +} + +func (r *sqlRepository) CreatePKCERequestSession(ctx context.Context, signature string, req fosite.Requester) error { + data, err := marshalRequest(req) + if err != nil { + return err + } + + session := pkceRequestSession{ + Signature: signature, + Request: data, + } + + return r.db.WithContext(ctx). + Clauses(clause.OnConflict{UpdateAll: true}). + Create(&session).Error +} + +func (r *sqlRepository) GetPKCERequestSession(ctx context.Context, signature string, sess fosite.Session) (fosite.Requester, error) { + var session pkceRequestSession + if err := r.db.WithContext(ctx).First(&session, "signature = ?", signature).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fosite.ErrNotFound + } + return nil, fmt.Errorf("failed to load pkce request session: %w", err) + } + + return unmarshalRequest(session.Request, sess) +} + +func (r *sqlRepository) DeletePKCERequestSession(ctx context.Context, signature string) error { + return r.db.WithContext(ctx).Delete(&pkceRequestSession{}, "signature = ?", signature).Error +} + +func (r *sqlRepository) CreateAuthorizeRequest(ctx context.Context, fositeAR fosite.AuthorizeRequester) error { + data, err := marshalAuthorizeRequest(fositeAR) + if err != nil { + return err + } + + record := authorizeRequestRecord{ + RequestID: fositeAR.GetID(), + Request: data, + } + + return r.db.WithContext(ctx). + Clauses(clause.OnConflict{UpdateAll: true}). + Create(&record).Error +} + +func (r *sqlRepository) GetAuthorizeRequest(ctx context.Context, requestID string) (fosite.AuthorizeRequester, error) { + var record authorizeRequestRecord + if err := r.db.WithContext(ctx).First(&record, "request_id = ?", requestID).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fosite.ErrNotFound + } + return nil, fmt.Errorf("failed to load authorize request: %w", err) + } + + return unmarshalAuthorizeRequest(record.Request) +} + +func (r *sqlRepository) DeleteAuthorizeRequest(ctx context.Context, requestID string) error { + return r.db.WithContext(ctx).Delete(&authorizeRequestRecord{}, "request_id = ?", requestID).Error +} + +func (r *sqlRepository) Close() error { + sqlDB, err := r.db.DB() + if err != nil { + return fmt.Errorf("failed to get sql db: %w", err) + } + return sqlDB.Close() +} + +func marshalRequest(req fosite.Requester) ([]byte, error) { + data, err := json.Marshal(models.FromFositeReq(req)) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + return data, nil +} + +func unmarshalRequest(data []byte, sess fosite.Session) (fosite.Requester, error) { + var req models.Request + if err := json.Unmarshal(data, &req); err != nil { + return nil, fmt.Errorf("failed to unmarshal request: %w", err) + } + fositeReq := req.ToFositeReq() + if sess != nil { + fositeReq.SetSession(sess) + } + return fositeReq, nil +} + +func marshalClient(client fosite.Client) ([]byte, error) { + data, err := json.Marshal(models.FromFositeClient(client)) + if err != nil { + return nil, fmt.Errorf("failed to marshal client: %w", err) + } + return data, nil +} + +func unmarshalClient(data []byte) (fosite.Client, error) { + var client models.Client + if err := json.Unmarshal(data, &client); err != nil { + return nil, fmt.Errorf("failed to unmarshal client: %w", err) + } + return client.ToFositeClient(), nil +} + +func marshalAuthorizeRequest(req fosite.AuthorizeRequester) ([]byte, error) { + data, err := json.Marshal(models.FromFositeAuthorizeRequest(req)) + if err != nil { + return nil, fmt.Errorf("failed to marshal authorize request: %w", err) + } + return data, nil +} + +func unmarshalAuthorizeRequest(data []byte) (fosite.AuthorizeRequester, error) { + var req models.AuthorizeRequest + if err := json.Unmarshal(data, &req); err != nil { + return nil, fmt.Errorf("failed to unmarshal authorize request: %w", err) + } + return req.ToFositeAuthorizeRequest(), nil +} diff --git a/pkg/repository/sql_test.go b/pkg/repository/sql_test.go new file mode 100644 index 0000000..1c1fd41 --- /dev/null +++ b/pkg/repository/sql_test.go @@ -0,0 +1,59 @@ +package repository + +import ( + "context" + "net/url" + "testing" + "time" + + "github.com/ory/fosite" +) + +func TestSQLRepositoryAccessTokenSession(t *testing.T) { + repo, err := NewSQLRepository("sqlite", "file::memory:?cache=shared") + if err != nil { + t.Fatalf("failed to create sql repository: %v", err) + } + defer repo.Close() + + ctx := context.Background() + client := &fosite.DefaultClient{ + ID: "client-1", + Secret: []byte("secret"), + RedirectURIs: []string{"https://example.com/callback"}, + } + + req := &fosite.Request{ + ID: "req-1", + RequestedAt: time.Now().UTC().Round(time.Second), + Client: client, + RequestedScope: []string{"scope.read"}, + Form: url.Values{"code": {"value"}}, + } + + if err := repo.CreateAccessTokenSession(ctx, "sig-1", req); err != nil { + t.Fatalf("CreateAccessTokenSession failed: %v", err) + } + + result, err := repo.GetAccessTokenSession(ctx, "sig-1", &fosite.DefaultSession{}) + if err != nil { + t.Fatalf("GetAccessTokenSession failed: %v", err) + } + + retrievedReq := result.(*fosite.Request) + if retrievedReq.ID != req.ID { + t.Fatalf("expected request ID %s, got %s", req.ID, retrievedReq.ID) + } + if retrievedReq.Client.GetID() != client.GetID() { + t.Fatalf("expected client ID %s, got %s", client.GetID(), retrievedReq.Client.GetID()) + } + if len(retrievedReq.RequestedScope) != 1 || retrievedReq.RequestedScope[0] != "scope.read" { + t.Fatalf("unexpected requested scope: %#v", retrievedReq.RequestedScope) + } +} + +func TestSQLRepositoryUnsupportedDriver(t *testing.T) { + if _, err := NewSQLRepository("unsupported", "dsn"); err == nil { + t.Fatalf("expected error for unsupported driver but got nil") + } +} From 8fca207ae164a0da9995d7885e4c0c3f16e53410 Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Thu, 23 Oct 2025 04:26:53 +0900 Subject: [PATCH 2/4] feat(docs): add repository options for SQL backends in configuration --- docs/docs/configuration.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/docs/docs/configuration.md b/docs/docs/configuration.md index df6cf2d..ab069de 100644 --- a/docs/docs/configuration.md +++ b/docs/docs/configuration.md @@ -93,6 +93,26 @@ You can use both exact matching and glob patterns for OIDC user authorization: | `--tls-listen` | `TLS_LISTEN` | `:443` | Address to listen on for TLS | | `--data-path` | `DATA_PATH` | `./data` | Path to the data directory | +### Repository Options + +| Option | Environment Variable | Default | Description | +| ----------------------- | --------------------- | ------- | --------------------------------------------------------------------------------------------------------------------- | +| `--repository-backend` | `REPOSITORY_BACKEND` | `local` | Storage backend for OAuth state. Supported values: `local` (embedded BoltDB), `sqlite`, `postgres`, or `mysql`. | +| `--repository-dsn` | `REPOSITORY_DSN` | - | Connection string passed directly to the SQL driver. Required when `--repository-backend` is `sqlite/postgres/mysql`. | + +`local` uses an embedded BoltDB file under `--data-path`. SQL backends run migrations automatically via GORM; the DSN must be valid for the chosen driver (examples below). The deprecated value `sql` is no longer accepted—select the concrete driver instead. + +```bash title="DSN examples" +# sqlite +--repository-backend sqlite --repository-dsn "file:data/mcp-auth.db?cache=shared&mode=rwc" + +# postgres +--repository-backend postgres --repository-dsn "postgres://user:pass@hostname:5432/database?sslmode=disable" + +# mysql +--repository-backend mysql --repository-dsn "user:pass@tcp(hostname:3306)/database?parseTime=true" +``` + ### Proxy Options | Option | Environment Variable | Default | Description | From 7b7928896105708fe5759e9c445be64c8c0ab13b Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Thu, 23 Oct 2025 04:27:48 +0900 Subject: [PATCH 3/4] fix(repository): remove deprecated SQL backend support in Run function --- pkg/mcp-proxy/main.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/pkg/mcp-proxy/main.go b/pkg/mcp-proxy/main.go index 7fc47bb..acbc22d 100644 --- a/pkg/mcp-proxy/main.go +++ b/pkg/mcp-proxy/main.go @@ -171,8 +171,6 @@ func Run( if err != nil { return fmt.Errorf("failed to create repository: %w", err) } - case "sql": - return fmt.Errorf("repository backend 'sql' is deprecated; use sqlite, postgres, or mysql") default: return fmt.Errorf("unsupported repository backend: %s", repositoryBackend) } From 9677b9bcd3cd543abe3e03d8a2459af11b91c607 Mon Sep 17 00:00:00 2001 From: Takanori Hirano Date: Thu, 23 Oct 2025 04:31:18 +0900 Subject: [PATCH 4/4] docs: update repository options section for clarity and remove deprecated value --- docs/docs/configuration.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/docs/configuration.md b/docs/docs/configuration.md index ab069de..6170665 100644 --- a/docs/docs/configuration.md +++ b/docs/docs/configuration.md @@ -95,10 +95,10 @@ You can use both exact matching and glob patterns for OIDC user authorization: ### Repository Options -| Option | Environment Variable | Default | Description | -| ----------------------- | --------------------- | ------- | --------------------------------------------------------------------------------------------------------------------- | -| `--repository-backend` | `REPOSITORY_BACKEND` | `local` | Storage backend for OAuth state. Supported values: `local` (embedded BoltDB), `sqlite`, `postgres`, or `mysql`. | -| `--repository-dsn` | `REPOSITORY_DSN` | - | Connection string passed directly to the SQL driver. Required when `--repository-backend` is `sqlite/postgres/mysql`. | +| Option | Environment Variable | Default | Description | +| ---------------------- | -------------------- | ------- | --------------------------------------------------------------------------------------------------------------------- | +| `--repository-backend` | `REPOSITORY_BACKEND` | `local` | Storage backend for OAuth state. Supported values: `local` (embedded BoltDB), `sqlite`, `postgres`, or `mysql`. | +| `--repository-dsn` | `REPOSITORY_DSN` | - | Connection string passed directly to the SQL driver. Required when `--repository-backend` is `sqlite/postgres/mysql`. | `local` uses an embedded BoltDB file under `--data-path`. SQL backends run migrations automatically via GORM; the DSN must be valid for the chosen driver (examples below). The deprecated value `sql` is no longer accepted—select the concrete driver instead.