diff --git a/pkg/detectors/sqlserver/sqlserver.go b/pkg/detectors/sqlserver/sqlserver.go index a3d377de982a..5d7a4ef592d5 100644 --- a/pkg/detectors/sqlserver/sqlserver.go +++ b/pkg/detectors/sqlserver/sqlserver.go @@ -73,6 +73,7 @@ func (s Scanner) FromData(ctx context.Context, verify bool, data []byte) (result var ping = func(config msdsn.Config) (bool, error) { cleanConfig := msdsn.Config{} cleanConfig.Host = config.Host + cleanConfig.Port = config.Port cleanConfig.User = config.User cleanConfig.Password = config.Password cleanConfig.Database = config.Database diff --git a/pkg/detectors/sqlserver/sqlserver_integration_test.go b/pkg/detectors/sqlserver/sqlserver_integration_test.go index 6e7fcac3cdee..43ebc952d2ed 100644 --- a/pkg/detectors/sqlserver/sqlserver_integration_test.go +++ b/pkg/detectors/sqlserver/sqlserver_integration_test.go @@ -4,22 +4,42 @@ package sqlserver import ( - "bytes" "context" - "errors" - "os/exec" - "strings" + "fmt" + "net/url" "testing" - "time" + "github.com/brianvoe/gofakeit/v7" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/mssql" "github.com/trufflesecurity/trufflehog/v3/pkg/detectors" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb" ) func TestSQLServerIntegration_FromChunk(t *testing.T) { + ctx := context.Background() + + password := gofakeit.Password(true, true, true, false, false, 10) + + container, err := mssql.RunContainer( + ctx, + testcontainers.WithImage("mcr.microsoft.com/azure-sql-edge"), + mssql.WithAcceptEULA(), + mssql.WithPassword(password)) + if err != nil { + t.Fatalf("could not start container: %v", err) + } + + defer container.Terminate(ctx) + + port, err := container.MappedPort(ctx, "1433") + if err != nil { + t.Fatalf("could get mapped port: %v", err) + } + type args struct { ctx context.Context data []byte @@ -37,17 +57,22 @@ func TestSQLServerIntegration_FromChunk(t *testing.T) { name: "found, verified", s: Scanner{}, args: args{ - ctx: context.Background(), - data: []byte("Server=localhost;Initial Catalog=master;User ID=sa;Password=P@ssw0rd!;Persist Security Info=true;MultipleActiveResultSets=true;"), + ctx: context.Background(), + data: []byte(fmt.Sprintf("Server=localhost;Port=%s;Initial Catalog=master;User ID=sa;Password=%s;Persist Security Info=true;MultipleActiveResultSets=true;", + port.Port(), + password)), verify: true, }, want: []detectors.Result{ { DetectorType: detectorspb.DetectorType_SQLServer, - Raw: []byte("P@ssw0rd!"), - RawV2: []byte("sqlserver://sa:P%40ssw0rd%21@localhost?database=master&disableRetry=false"), - Redacted: "sqlserver://sa:********@localhost?database=master&disableRetry=false", - Verified: true, + Raw: []byte(password), + RawV2: []byte(urlEncode(fmt.Sprintf("sqlserver://sa:%s@localhost:%s?database=master&dial+timeout=15&disableretry=false", + password, + port.Port()))), + Redacted: fmt.Sprintf("sqlserver://sa:********@localhost:%s?database=master&dial+timeout=15&disableretry=false", + port.Port()), + Verified: true, }, }, wantErr: false, @@ -57,16 +82,18 @@ func TestSQLServerIntegration_FromChunk(t *testing.T) { s: Scanner{}, args: args{ ctx: context.Background(), - data: []byte("Server=localhost;User ID=sa;Password=123"), + data: []byte(fmt.Sprintf("Server=localhost;Port=%s;User ID=sa;Password=123", port.Port())), verify: true, }, want: []detectors.Result{ { DetectorType: detectorspb.DetectorType_SQLServer, Raw: []byte("123"), - RawV2: []byte("sqlserver://sa:123@localhost?disableRetry=false"), - Redacted: "sqlserver://sa:********@localhost?disableRetry=false", - Verified: false, + RawV2: []byte(fmt.Sprintf("sqlserver://sa:123@localhost:%s?dial+timeout=15&disableretry=false", + port.Port())), + Redacted: fmt.Sprintf("sqlserver://sa:********@localhost:%s?dial+timeout=15&disableretry=false", + port.Port()), + Verified: false, }, }, wantErr: false, @@ -76,7 +103,7 @@ func TestSQLServerIntegration_FromChunk(t *testing.T) { s: Scanner{}, args: args{ ctx: context.Background(), - data: []byte(``), + data: []byte(``), verify: true, }, want: nil, @@ -86,17 +113,22 @@ func TestSQLServerIntegration_FromChunk(t *testing.T) { name: "found, verified, in XML", s: Scanner{}, args: args{ - ctx: context.Background(), - data: []byte(``), + ctx: context.Background(), + data: []byte(fmt.Sprintf(``, + port.Port(), + password)), verify: true, }, want: []detectors.Result{ { DetectorType: detectorspb.DetectorType_SQLServer, - Redacted: "sqlserver://sa:********@localhost?database=master&disableRetry=false", - Raw: []byte("P@ssw0rd!"), - RawV2: []byte("sqlserver://sa:P%40ssw0rd%21@localhost?database=master&disableRetry=false"), - Verified: true, + Redacted: fmt.Sprintf("sqlserver://sa:********@localhost:%s?database=master&dial+timeout=15&disableretry=false", + port.Port()), + Raw: []byte(password), + RawV2: []byte(urlEncode(fmt.Sprintf("sqlserver://sa:%s@localhost:%s?database=master&dial+timeout=15&disableretry=false", + password, + port.Port()))), + Verified: true, }, }, wantErr: false, @@ -113,8 +145,8 @@ func TestSQLServerIntegration_FromChunk(t *testing.T) { { DetectorType: detectorspb.DetectorType_SQLServer, Raw: []byte("P@ssw0rd!"), - RawV2: []byte("sqlserver://sa:P%40ssw0rd%21@unreachablehost?database=master&disableRetry=false"), - Redacted: "sqlserver://sa:********@unreachablehost?database=master&disableRetry=false", + RawV2: []byte("sqlserver://sa:P%40ssw0rd%21@unreachablehost?database=master&dial+timeout=15&disableretry=false"), + Redacted: "sqlserver://sa:********@unreachablehost?database=master&dial+timeout=15&disableretry=false", Verified: false, }, }, @@ -134,11 +166,6 @@ func TestSQLServerIntegration_FromChunk(t *testing.T) { }, } - if err := startSqlServer(); err != nil { - t.Fatalf("could not start sql server for integration testing: %v", err) - } - defer stopSqlServer() - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := Scanner{} @@ -163,47 +190,7 @@ func TestSQLServerIntegration_FromChunk(t *testing.T) { } } -var sqlServerDockerHash string - -func dockerLogLine(hash string, needle string) chan struct{} { - ch := make(chan struct{}, 1) - go func() { - for { - out, err := exec.Command("docker", "logs", hash).CombinedOutput() - if err != nil { - panic(err) - } - if strings.Contains(string(out), needle) { - ch <- struct{}{} - return - } - time.Sleep(1 * time.Second) - } - }() - return ch -} - -func startSqlServer() error { - cmd := exec.Command( - "docker", "run", "--rm", "-p", "1433:1433", - "-e", "ACCEPT_EULA=1", - "-e", "MSSQL_SA_PASSWORD=P@ssw0rd!", - "-d", "mcr.microsoft.com/azure-sql-edge", - ) - out, err := cmd.Output() - if err != nil { - return err - } - sqlServerDockerHash = string(bytes.TrimSpace(out)) - select { - case <-dockerLogLine(sqlServerDockerHash, "EdgeTelemetry starting up"): - return nil - case <-time.After(30 * time.Second): - stopSqlServer() - return errors.New("timeout waiting for sql server database to be ready") - } -} - -func stopSqlServer() { - exec.Command("docker", "kill", sqlServerDockerHash).Run() +func urlEncode(s string) string { + parsed, _ := url.Parse(s) + return parsed.String() }