Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Narrow Postgres detector to only look for URIs #2314

Merged
merged 24 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 144 additions & 172 deletions pkg/detectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,63 +3,133 @@ package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
"net"
"regexp"
"strconv"
"strings"
"time"

_ "github.com/lib/pq" // PostgreSQL driver
"github.com/lib/pq"
"github.com/trufflesecurity/trufflehog/v3/pkg/detectors"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb"
)

// Fallback values plus the connection-parameter names and sslmode values used as keys/values of the parameters map.
const (
defaultPort = "5432" // used when a candidate's parameters carry no port
defaultHost = "localhost"

// Connection-parameter keys (and, for the pg_sslmode_* entries, the values the sslmode key may take).
// NOTE(review): Go convention prefers MixedCaps over underscores; renaming would touch every use site.
pg_connect_timeout = "connect_timeout"
pg_dbname = "dbname"
pg_host = "host"
pg_password = "password"
pg_port = "port"
pg_requiressl = "requiressl" // deprecated option; normalized into sslmode before verification
pg_sslmode = "sslmode"
pg_sslmode_allow = "allow"
pg_sslmode_disable = "disable"
pg_sslmode_prefer = "prefer"
pg_sslmode_require = "require"
pg_user = "user"
)

// This detector currently only finds Postgres connection string URIs
// (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING-URIS). When it finds one, it uses
// pq.ParseURL to normalize it into a space-separated key-value pair Postgres connection string, and then uses a regular
// expression to transform this connection string into a parameters map. This parameters map is manipulated prior to
// verification, which operates by transforming the map back into a space-separated kvp connection string. This is kind
// of clunky overall, but it has the benefit of preserving the connection string as a map when it needs to be modified,
// which is much nicer than having to patch a space-separated string of kvps.

// Multi-host connection string URIs are currently not supported because pq.ParseURL doesn't parse them correctly. If we
// happen to run into a case where this matters, we can address it then.
var (
_ detectors.Detector = (*Scanner)(nil)
uriPattern = regexp.MustCompile(`\b(?i)postgresql://[\S]+\b`)
hostnamePattern = regexp.MustCompile(`(?i)(?:host|server|address).{0,40}?(\b[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*\b)`)
portPattern = regexp.MustCompile(`(?i)(?:port|p).{0,40}?(\b[0-9]{1,5}\b)`)
usernamePattern = regexp.MustCompile(`(?im)(?:user|usr)\S{0,40}?[:=\s]{1,3}[ '"=]{0,1}([^:'"\s]{4,40})`)
passwordPattern = regexp.MustCompile(`(?im)(?:pass)\S{0,40}?[:=\s]{1,3}[ '"=]{0,1}([^:'"\s]{4,40})`)
_ detectors.Detector = (*Scanner)(nil)
uriPattern = regexp.MustCompile(`\b(?i)postgres(?:ql)?://\S+\b`)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't find things like "postgresql+asyncpg://postgres:secret@localhost/testdb". Not sure what the + syntax is called.
https://github.com/agronholm/apscheduler/blob/b4ceea0ed300545a27bb8dbbbfb382a46d8ea90f/examples/web/asgi_starlette.py#L59

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I was just going off the official connection URI docs, and I don't see that mentioned anywhere. Do you have any references you can point me to?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears to be a Python-specific quirk:

dialect+driver://username:password@host:port/database

https://www.tutorialspoint.com/sqlalchemy/sqlalchemy_dialects.htm

Copy link
Contributor Author

@rosecodym rosecodym Jan 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, ok. For this pass we're aiming for correctness over completeness, because the previous implementation was yielding a ton of false positives. The eventual plan is to add other types of Postgres secrets (non-URI connection strings, at the very least), so we can add SQLAlchemy-extended URIs to that list.

connStrPartPattern = regexp.MustCompile(`([[:alpha:]]+)='(.+?)' ?`)
)

type Scanner struct{}
// Scanner is the Postgres detector; it is asserted to implement detectors.Detector.
type Scanner struct {
// detectLoopback controls whether candidates whose host is "localhost" or a loopback IP are kept; FromData skips
// them when it is false.
detectLoopback bool // Automated tests run against localhost, but we want to ignore those results in the wild
}

func (s Scanner) Keywords() []string {
return []string{"postgres", "psql", "pghost"}
return []string{"postgres"}
}

func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]detectors.Result, error) {
var results []detectors.Result
var pgURLs []url.URL
pgURLs = append(pgURLs, findUriMatches(string(data)))
pgURLs = append(pgURLs, findComponentMatches(verify, string(data))...)
candidateParamSets := findUriMatches(data)

for _, params := range candidateParamSets {
user, ok := params[pg_user]
if !ok {
continue
}

for _, pgURL := range pgURLs {
if pgURL.User == nil {
password, ok := params[pg_password]
if !ok {
continue
}
username := pgURL.User.Username()
password, _ := pgURL.User.Password()
hostport := pgURL.Host

host, ok := params[pg_host]
if !ok {
continue
}
if !s.detectLoopback {
if host == "localhost" {
continue
}
if ip := net.ParseIP(host); ip != nil && ip.IsLoopback() {
continue
}
}

port, ok := params[pg_port]
if !ok {
port = defaultPort
params[pg_port] = port
}

raw := []byte(fmt.Sprintf("postgresql://%s:%s@%s:%s", user, password, host, port))

result := detectors.Result{
DetectorType: detectorspb.DetectorType_Postgres,
Raw: []byte(hostport + username + password),
RawV2: []byte(hostport + username + password),
Raw: raw,
RawV2: raw,
}

// We don't need to normalize the (deprecated) requiressl option into the (up-to-date) sslmode option - pq can
// do it for us - but we will do it anyway here so that when we later capture sslmode into ExtraData we will
// capture it post-normalization. (The detector's behavior is undefined for candidate secrets that have both
// requiressl and sslmode set.)
if requiressl := params[pg_requiressl]; requiressl == "0" {
params[pg_sslmode] = pg_sslmode_prefer
} else if requiressl == "1" {
params[pg_sslmode] = pg_sslmode_require
}

if verify {
timeoutInSeconds := getDeadlineInSeconds(ctx)
isVerified, verificationErr := verifyPostgres(&pgURL, timeoutInSeconds)
// pq appears to ignore the context deadline, so we copy any timeout that's been set into the connection
// parameters themselves.
if timeout := getDeadlineInSeconds(ctx); timeout != 0 {
params[pg_connect_timeout] = strconv.Itoa(timeout)
}

isVerified, verificationErr := verifyPostgres(params)
result.Verified = isVerified
result.SetVerificationError(verificationErr, password)
}

// We gather SSL information into ExtraData in case it's useful for later reporting.
sslmode := params[pg_sslmode]
if sslmode == "" {
sslmode = "<unset>"
}
result.ExtraData = map[string]string{
pg_sslmode: sslmode,
}

if !result.Verified && detectors.IsKnownFalsePositive(password, detectors.DefaultFalsePositives, true) {
continue
}
Expand All @@ -69,184 +139,86 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) ([]dete
return results, nil
}

// getDeadlineInSeconds returns the time remaining until ctx's deadline, truncated to whole seconds.
// It returns 0 when ctx carries no deadline. The result is not clamped, so a deadline less than one
// second away also yields 0 and an already-expired deadline yields a non-positive value.
func getDeadlineInSeconds(ctx context.Context) int {
	if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
		remaining := time.Until(deadline)
		return int(remaining.Seconds())
	}

	// No deadline set on the context.
	return 0
}

func findUriMatches(dataStr string) url.URL {
var pgURL url.URL
for _, uri := range uriPattern.FindAllString(dataStr, -1) {
pgURL, err := url.Parse(uri)
func findUriMatches(data []byte) []map[string]string {
var matches []map[string]string
for _, uri := range uriPattern.FindAll(data, -1) {
connStr, err := pq.ParseURL(string(uri))
if err != nil {
continue
}
if pgURL.User != nil {
return *pgURL
}
}
return pgURL
}

// check if postgres is running
func postgresRunning(hostname, port string) bool {
connStr := fmt.Sprintf("host=%s port=%s sslmode=disable", hostname, port)
db, err := sql.Open("postgres", connStr)
if err != nil {
return false
}
defer db.Close()
return true
}

func findComponentMatches(verify bool, dataStr string) []url.URL {
usernameMatches := usernamePattern.FindAllStringSubmatch(dataStr, -1)
passwordMatches := passwordPattern.FindAllStringSubmatch(dataStr, -1)
hostnameMatches := hostnamePattern.FindAllStringSubmatch(dataStr, -1)
portMatches := portPattern.FindAllStringSubmatch(dataStr, -1)

var pgURLs []url.URL

hosts := findHosts(verify, hostnameMatches, portMatches)

for _, username := range dedupMatches(usernameMatches) {
for _, password := range dedupMatches(passwordMatches) {
for _, host := range hosts {
hostname, port := strings.Split(host, ":")[0], strings.Split(host, ":")[1]
if combinedLength := len(username) + len(password) + len(hostname); combinedLength > 255 {
continue
}
postgresURL := url.URL{
Scheme: "postgresql",
User: url.UserPassword(username, password),
Host: fmt.Sprintf("%s:%s", hostname, port),
}
pgURLs = append(pgURLs, postgresURL)
}
parts := connStrPartPattern.FindAllStringSubmatch(connStr, -1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: If we construct params after parts, we can pre-alloc the map. it'll make it go SOOOO much faster :)

params := make(map[string]string, len(parts))

params := make(map[string]string, len(parts))
for _, part := range parts {
params[part[1]] = part[2]
}
}
return pgURLs
}

// if verification is turned on, and we can confirm that postgres is running on at least one host,
// return only hosts where it's running. otherwise return all hosts.
func findHosts(verify bool, hostnameMatches, portMatches [][]string) []string {
hostnames := dedupMatches(hostnameMatches)
ports := dedupMatches(portMatches)
var hosts []string

if len(hostnames) < 1 {
hostnames = append(hostnames, defaultHost)
}

if len(ports) < 1 {
ports = append(ports, defaultPort)
}

for _, hostname := range hostnames {
for _, port := range ports {
hosts = append(hosts, fmt.Sprintf("%s:%s", hostname, port))
}
matches = append(matches, params)
}
return matches
}

if verify {
var verifiedHosts []string
for _, host := range hosts {
parts := strings.Split(host, ":")
hostname, port := parts[0], parts[1]
if postgresRunning(hostname, port) {
verifiedHosts = append(verifiedHosts, host)
}
}
if len(verifiedHosts) > 0 {
return verifiedHosts
}
func getDeadlineInSeconds(ctx context.Context) int {
deadline, ok := ctx.Deadline()
if !ok {
// Context does not have a deadline
return 0
}

return hosts
duration := time.Until(deadline)
return int(duration.Seconds())
}

// deduplicate matches in order to reduce the number of verification requests
func dedupMatches(matches [][]string) []string {
setOfMatches := make(map[string]struct{})
for _, match := range matches {
if len(match) > 1 {
setOfMatches[match[1]] = struct{}{}
}
func isErrorDatabaseNotFound(err error, dbName string) bool {
if dbName == "" {
dbName = "postgres"
}
var results []string
for match := range setOfMatches {
results = append(results, match)
}
return results
missingDbErrorText := fmt.Sprintf("database \"%s\" does not exist", dbName)

return strings.Contains(err.Error(), missingDbErrorText)
}

func verifyPostgres(pgURL *url.URL, timeoutInSeconds int) (bool, error) {
if pgURL.User == nil {
return false, nil
}
username := pgURL.User.Username()
password, _ := pgURL.User.Password()
func verifyPostgres(params map[string]string) (bool, error) {
if sslmode := params[pg_sslmode]; sslmode == pg_sslmode_allow || sslmode == pg_sslmode_prefer {
// pq doesn't support 'allow' or 'prefer'. If we find either of them, we'll just ignore it. This will trigger
// the same logic that is run if no sslmode is set at all (which mimics 'prefer', which is the default).
delete(params, pg_sslmode)

hostname, port := pgURL.Hostname(), pgURL.Port()
if hostname == "" {
hostname = defaultHost
}
if port == "" {
port = defaultPort
// We still want to save the original sslmode in ExtraData, so we'll re-add it before returning.
defer func() {
params[pg_sslmode] = sslmode
}()
}

sslmode := determineSSLMode(pgURL)

connStr := fmt.Sprintf("user=%s password=%s host=%s port=%s sslmode=%s", username, password, hostname, port, sslmode)
if timeoutInSeconds > 0 {
connStr = fmt.Sprintf("%s connect_timeout=%d", connStr, timeoutInSeconds)
var connStr string
for key, value := range params {
connStr += fmt.Sprintf("%s='%s'", key, value)
}

db, err := sql.Open("postgres", connStr)
if err != nil {
if strings.Contains(err.Error(), "connection refused") {
// inactive host
return false, nil
}
return false, err
}
defer db.Close()

ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()

err = db.PingContext(ctx)
if err == nil {
err = db.Ping()
switch {
case err == nil:
return true, nil
} else if strings.Contains(err.Error(), "password authentication failed") || // incorrect username or password
strings.Contains(err.Error(), "connection refused") { // inactive host
case strings.Contains(err.Error(), "password authentication failed"):
return false, nil
case errors.Is(err, pq.ErrSSLNotSupported) && params[pg_sslmode] == "":
// If the sslmode is unset, then either it was unset in the candidate secret, or we've intentionally unset it
// because it was specified as 'allow' or 'prefer', neither of which pq supports. In all of these cases, non-SSL
// connections are acceptable, so now we try a connection without SSL.
params[pg_sslmode] = pg_sslmode_disable
defer delete(params, pg_sslmode) // We want to return with the original params map intact (for ExtraData)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nifty!

return verifyPostgres(params)
case isErrorDatabaseNotFound(err, params[pg_dbname]):
return true, nil // If we know this, we were able to authenticate
default:
return false, err
}

// if ssl is not enabled, manually fall-back to sslmode=disable
if strings.Contains(err.Error(), "SSL is not enabled on the server") {
pgURL.RawQuery = fmt.Sprintf("sslmode=%s", "disable")
return verifyPostgres(pgURL, timeoutInSeconds)
}
return false, err
}

// determineSSLMode returns the sslmode to connect with: the first "sslmode" value found in pgURL's
// query string when one is present, and "require" otherwise.
//
// libpq's documented default is "prefer" (https://www.postgresql.org/docs/current/libpq-ssl.html),
// but the driver does not implement it (https://github.com/lib/pq/issues/1006) and itself defaults
// to "require". "allow" would be the ideal fallback, but the driver does not support that either.
func determineSSLMode(pgURL *url.URL) string {
	modes := pgURL.Query()["sslmode"]
	if len(modes) == 0 {
		return "require"
	}
	return modes[0]
}

func (s Scanner) Type() detectorspb.DetectorType {
Expand Down
Loading
Loading