diff --git a/README.md b/README.md index 6927576..ab9f4b5 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ That's it! Your HTTP endpoint is now available at `https://{your-domain}/mcp`. - stdio (when a command is specified): MCP endpoint is https://{your-domain}/mcp. - SSE/HTTP (when a URL is specified): MCP endpoint uses the backend’s original path (no conversion). -> Don't want the proxy to manage TLS? Add `--no-auto-tls` so you can terminate TLS elsewhere or keep the backend on plain HTTP. +> Already have certificates? Pass `--tls-cert-file` and `--tls-key-file` instead of `--tls-accept-tos`. ## Why not MCP Gateway? diff --git a/docs/docs/configuration.md b/docs/docs/configuration.md index 6170665..005c032 100644 --- a/docs/docs/configuration.md +++ b/docs/docs/configuration.md @@ -16,12 +16,14 @@ Complete reference for all MCP Auth Proxy configuration options. ### TLS Options -| Option | Environment Variable | Default | Description | -| --------------------- | -------------------- | ------------------------------------------------ | ----------------------------------------------------- | -| `--no-auto-tls` | `NO_AUTO_TLS` | `false` | Disable automatic TLS host detection from externalURL | -| `--tls-accept-tos` | `TLS_ACCEPT_TOS` | `false` | Accept TLS terms of service | -| `--tls-directory-url` | `TLS_DIRECTORY_URL` | `https://acme-v02.api.letsencrypt.org/directory` | ACME directory URL for TLS certificates | -| `--tls-host` | `TLS_HOST` | - | Host name for TLS | +| Option | Environment Variable | Default | Description | +| --------------------- | -------------------- | ------------------------------------------------ | -------------------------------------------------------------------------------------------------- | +| `--no-auto-tls` | `NO_AUTO_TLS` | `false` | Disable automatic TLS host detection from externalURL (ignored when `--tls-cert-file` is provided) | +| `--tls-accept-tos` | `TLS_ACCEPT_TOS` | `false` | Accept TLS terms of service | +| `--tls-directory-url` | `TLS_DIRECTORY_URL` | `https://acme-v02.api.letsencrypt.org/directory` | ACME directory URL for TLS certificates | +| `--tls-host` | `TLS_HOST` | - | Host name used for automatic TLS certificate provisioning | +| `--tls-cert-file` | `TLS_CERT_FILE` | - | Path to PEM-encoded TLS certificate served directly by the proxy (auto-reloads on file changes) | +| `--tls-key-file` | `TLS_KEY_FILE` | - | Path to PEM-encoded TLS private key (requires `--tls-cert-file`, auto-reloads on file changes) | ### Authentication Options diff --git a/docs/docs/quickstart.md b/docs/docs/quickstart.md index 4a21ab9..d08128a 100644 --- a/docs/docs/quickstart.md +++ b/docs/docs/quickstart.md @@ -79,10 +79,11 @@ For URL-based MCP servers: ### TLS Configuration -MCP Auth Proxy automatically handles HTTPS certificates: +MCP Auth Proxy can automatically issue certificates or serve an existing pair: - `--tls-accept-tos`: Accept Let's Encrypt terms of service -- `--no-auto-tls`: Disable automatic TLS (use with TLS reverse proxy) +- `--tls-cert-file` / `--tls-key-file`: Serve the provided PEM certificate and key with automatic reload when files change (overrides `--no-auto-tls`) +- `--no-auto-tls`: Disable automatic TLS (use with TLS reverse proxy or custom certificate) ## Accessing Your Server diff --git a/main.go b/main.go index b632984..c009784 100644 --- a/main.go +++ b/main.go @@ -71,6 +71,8 @@ func main() { var tlsHost string var tlsDirectoryURL string var tlsAcceptTOS bool + var tlsCertFile string + var tlsKeyFile string var dataPath string var repositoryBackend string var repositoryDSN string @@ -176,10 +178,12 @@ func main() { if err := mcpproxy.Run( listen, tlsListen, - !noAutoTLS, + (!noAutoTLS) || tlsCertFile != "" || tlsKeyFile != "", tlsHost, tlsDirectoryURL, tlsAcceptTOS, + tlsCertFile, + tlsKeyFile, dataPath, repositoryBackend, repositoryDSN, @@ -216,9 +220,11 @@ func main() { rootCmd.Flags().StringVar(&listen, "listen", getEnvWithDefault("LISTEN", ":80"), "Address to listen on") rootCmd.Flags().StringVar(&tlsListen, "tls-listen", getEnvWithDefault("TLS_LISTEN", ":443"), "Address to listen on for TLS") rootCmd.Flags().BoolVar(&noAutoTLS, "no-auto-tls", getEnvBoolWithDefault("NO_AUTO_TLS", false), "Disable automatic TLS host detection from externalURL") - rootCmd.Flags().StringVarP(&tlsHost, "tls-host", "H", getEnvWithDefault("TLS_HOST", ""), "Host name for TLS") + rootCmd.Flags().StringVarP(&tlsHost, "tls-host", "H", getEnvWithDefault("TLS_HOST", ""), "Host name for automatic TLS certificate provisioning") rootCmd.Flags().StringVar(&tlsDirectoryURL, "tls-directory-url", getEnvWithDefault("TLS_DIRECTORY_URL", "https://acme-v02.api.letsencrypt.org/directory"), "ACME directory URL for TLS certificates") rootCmd.Flags().BoolVar(&tlsAcceptTOS, "tls-accept-tos", getEnvBoolWithDefault("TLS_ACCEPT_TOS", false), "Accept TLS terms of service") + rootCmd.Flags().StringVar(&tlsCertFile, "tls-cert-file", getEnvWithDefault("TLS_CERT_FILE", ""), "Path to TLS certificate file (PEM). Requires --tls-key-file") + rootCmd.Flags().StringVar(&tlsKeyFile, "tls-key-file", getEnvWithDefault("TLS_KEY_FILE", ""), "Path to TLS private key file (PEM). Requires --tls-cert-file") rootCmd.Flags().StringVarP(&dataPath, "data-path", "d", getEnvWithDefault("DATA_PATH", "./data"), "Path to the data directory") rootCmd.Flags().StringVar(&repositoryBackend, "repository-backend", getEnvWithDefault("REPOSITORY_BACKEND", "local"), "Repository backend to use: local, sqlite, postgres, or mysql") rootCmd.Flags().StringVar(&repositoryDSN, "repository-dsn", getEnvWithDefault("REPOSITORY_DSN", ""), "DSN passed directly to the SQL driver (required when repository-backend is sqlite/postgres/mysql)") diff --git a/pkg/mcp-proxy/main.go b/pkg/mcp-proxy/main.go index acbc22d..cb4d4ed 100644 --- a/pkg/mcp-proxy/main.go +++ b/pkg/mcp-proxy/main.go @@ -2,6 +2,7 @@ package mcpproxy import ( "context" + "crypto/tls" "errors" "fmt" "net/http" @@ -23,6 +24,7 @@ import ( "github.com/sigbit/mcp-auth-proxy/pkg/idp" "github.com/sigbit/mcp-auth-proxy/pkg/proxy" "github.com/sigbit/mcp-auth-proxy/pkg/repository" + "github.com/sigbit/mcp-auth-proxy/pkg/tlsreload" "github.com/sigbit/mcp-auth-proxy/pkg/utils" "go.uber.org/zap" "golang.org/x/crypto/acme" @@ -39,6 +41,8 @@ func Run( tlsHost string, tlsDirectoryURL string, tlsAcceptTOS bool, + tlsCertFile string, + tlsKeyFile string, dataPath string, repositoryBackend string, repositoryDSN string, @@ -78,6 +82,20 @@ func Run( return fmt.Errorf("external URL must not have a path, got: %s", parsedExternalURL.Path) } + if (tlsCertFile == "") != (tlsKeyFile == "") { + return fmt.Errorf("both TLS certificate and key files must be provided together") + } + var manualTLS bool + if tlsCertFile != "" && tlsKeyFile != "" { + manualTLS = true + } + if manualTLS && tlsHost != "" { + return fmt.Errorf("tlsHost cannot be used when TLS certificate and key files are provided") + } + if !manualTLS && !autoTLS && tlsHost != "" { + return fmt.Errorf("tlsHost requires automatic TLS; remove noAutoTLS or provide certificate files instead") + } + secret, err := utils.LoadOrGenerateSecret(path.Join(dataPath, "secret")) if err != nil { return fmt.Errorf("failed to load or generate secret: %w", err) @@ -271,7 +289,7 @@ func Run( proxyRouter.SetupRoutes(router) var tlsHostDetected bool - if autoTLS && + if autoTLS && !manualTLS && tlsHost == "" && parsedExternalURL.Scheme == "https" && parsedExternalURL.Host != "localhost" { @@ -284,13 +302,75 @@ func Run( errs := []error{} lock := sync.Mutex{} - if tlsHost != "" { + if manualTLS { + certReloader, err := tlsreload.NewFileReloader(tlsCertFile, tlsKeyFile, logger) + if err != nil { + return fmt.Errorf("failed to prepare TLS certificate reloader: %w", err) + } + + logger.Info("Starting server with provided TLS certificate") + httpServer := &http.Server{ + Addr: listen, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host := r.Host + if host == "" { + host = r.URL.Host + } + target := "https://" + host + r.RequestURI + http.Redirect(w, r, target, http.StatusMovedPermanently) + }), + } + httpsServer := &http.Server{ + Addr: tlsListen, + Handler: router, + TLSConfig: &tls.Config{GetCertificate: certReloader.GetCertificate}, + } + wg.Add(1) + go func() { + defer wg.Done() + err := httpServer.ListenAndServe() + if err != nil && !errors.Is(err, http.ErrServerClosed) { + lock.Lock() + errs = append(errs, err) + lock.Unlock() + } + logger.Debug("HTTP server closed") + exit <- struct{}{} + }() + go func() { + <-ctx.Done() + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), ServerShutdownTimeout) + defer shutdownCancel() + if shutdownErr := httpServer.Shutdown(shutdownCtx); shutdownErr != nil { + logger.Warn("HTTP server shutdown error", zap.Error(shutdownErr)) + } + }() + wg.Add(1) + go func() { + defer wg.Done() + err := httpsServer.ListenAndServeTLS("", "") + if err != nil && !errors.Is(err, http.ErrServerClosed) { + lock.Lock() + errs = append(errs, err) + lock.Unlock() + } + logger.Debug("HTTPS server closed") + exit <- struct{}{} + }() + go func() { + <-ctx.Done() + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), ServerShutdownTimeout) + defer shutdownCancel() + if shutdownErr := httpsServer.Shutdown(shutdownCtx); shutdownErr != nil { + logger.Warn("HTTPS server shutdown error", zap.Error(shutdownErr)) + } + }() + } else if tlsHost != "" { if !tlsAcceptTOS { if tlsHostDetected { return errors.New("TLS host is auto-detected, but tlsAcceptTOS is not set to true. Please agree to the TOS or set noAutoTLS to true") - } else { - return errors.New("TLS is enabled, but tlsAcceptTOS is not set to true. Please explicitly agree to the TOS") } + return errors.New("TLS is enabled, but tlsAcceptTOS is not set to true. Please explicitly agree to the TOS") } m := autocert.Manager{ @@ -401,7 +481,7 @@ func Run( }() } - if tlsHost != "" { + if manualTLS || tlsHost != "" { logger.Info("Starting server", zap.Strings("listen", []string{listen, tlsListen})) } else { logger.Info("Starting server", zap.Strings("listen", []string{listen})) diff --git a/pkg/tlsreload/file_reloader.go b/pkg/tlsreload/file_reloader.go new file mode 100644 index 0000000..c92f831 --- /dev/null +++ b/pkg/tlsreload/file_reloader.go @@ -0,0 +1,118 @@ +package tlsreload + +import ( + "crypto/tls" + "fmt" + "os" + "sync" + "sync/atomic" + "time" + + "go.uber.org/zap" +) + +type fileState struct { + modTime time.Time + size int64 +} + +// FileReloader watches certificate and key files and reloads them when they change. +type FileReloader struct { + certPath string + keyPath string + logger *zap.Logger + + cert atomic.Value // *tls.Certificate + mu sync.Mutex + certState fileState + keyState fileState +} + +// NewFileReloader loads the initial certificate/key pair and prepares for reloads. +func NewFileReloader(certPath, keyPath string, logger *zap.Logger) (*FileReloader, error) { + certInfo, err := os.Stat(certPath) + if err != nil { + return nil, fmt.Errorf("stat cert file: %w", err) + } + keyInfo, err := os.Stat(keyPath) + if err != nil { + return nil, fmt.Errorf("stat key file: %w", err) + } + + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, fmt.Errorf("load key pair: %w", err) + } + + reloader := &FileReloader{ + certPath: certPath, + keyPath: keyPath, + logger: logger, + certState: fileState{ + modTime: certInfo.ModTime(), + size: certInfo.Size(), + }, + keyState: fileState{ + modTime: keyInfo.ModTime(), + size: keyInfo.Size(), + }, + } + reloader.cert.Store(&cert) + return reloader, nil +} + +func (r *FileReloader) maybeReload() error { + r.mu.Lock() + defer r.mu.Unlock() + + certInfo, err := os.Stat(r.certPath) + if err != nil { + return fmt.Errorf("stat cert file: %w", err) + } + keyInfo, err := os.Stat(r.keyPath) + if err != nil { + return fmt.Errorf("stat key file: %w", err) + } + + newCertState := fileState{modTime: certInfo.ModTime(), size: certInfo.Size()} + newKeyState := fileState{modTime: keyInfo.ModTime(), size: keyInfo.Size()} + + if newCertState.modTime.Equal(r.certState.modTime) && newCertState.size == r.certState.size && + newKeyState.modTime.Equal(r.keyState.modTime) && newKeyState.size == r.keyState.size { + return nil + } + + cert, err := tls.LoadX509KeyPair(r.certPath, r.keyPath) + if err != nil { + return fmt.Errorf("load key pair: %w", err) + } + + r.cert.Store(&cert) + r.certState = newCertState + r.keyState = newKeyState + + if r.logger != nil { + r.logger.Info("Reloaded TLS certificate files", + zap.String("certFile", r.certPath), + zap.String("keyFile", r.keyPath), + zap.Time("certModTime", newCertState.modTime), + zap.Time("keyModTime", newKeyState.modTime), + ) + } + return nil +} + +// GetCertificate reloads certificate/key as needed and returns the current pair. +func (r *FileReloader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + if err := r.maybeReload(); err != nil { + if r.logger != nil { + r.logger.Warn("Failed to reload TLS certificate", zap.Error(err)) + } + } + + value := r.cert.Load() + if value == nil { + return nil, fmt.Errorf("no TLS certificate loaded") + } + return value.(*tls.Certificate), nil +}