Skip to content

Commit

Permalink
Enable TLS in test Trino container
Browse files Browse the repository at this point in the history
  • Loading branch information
nineinchnick authored and wendigo committed May 3, 2024
1 parent aa08ec4 commit f1b3d95
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 0 deletions.
8 changes: 8 additions & 0 deletions trino/etc/config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,11 @@ node-scheduler.include-coordinator=true
http-server.http.port=8080
discovery-server.enabled=true
discovery.uri=http://localhost:8080

http-server.authentication.type=JWT
http-server.authentication.jwt.key-file=/etc/trino/secrets/public_key.pem
http-server.https.enabled=true
http-server.https.port=8443
http-server.authentication.allow-insecure-over-http=true
http-server.https.keystore.path=/etc/trino/secrets/certificate_with_key.pem
internal-communication.shared-secret=gotrino
1 change: 1 addition & 0 deletions trino/etc/secrets/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.pem
121 changes: 121 additions & 0 deletions trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,20 @@ package trino

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"database/sql"
"database/sql/driver"
"encoding/pem"
"errors"
"flag"
"fmt"
"io"
"log"
"math/big"
"net/http"
"os"
"strings"
Expand Down Expand Up @@ -55,6 +63,7 @@ var (
false,
"do not delete containers on exit",
)
tlsServer = ""
)

func TestMain(m *testing.M) {
Expand All @@ -79,6 +88,10 @@ func TestMain(m *testing.M) {
resource, ok = pool.ContainerByName(name)

if !ok {
err = generateCerts(wd + "/etc/secrets")
if err != nil {
log.Fatalf("Could not generate TLS certificates: %s", err)
}
if *trinoImageTagFlag == "" {
*trinoImageTagFlag = "latest"
}
Expand All @@ -87,6 +100,10 @@ func TestMain(m *testing.M) {
Repository: "trinodb/trino",
Tag: *trinoImageTagFlag,
Mounts: []string{wd + "/etc:/etc/trino"},
ExposedPorts: []string{
"8080/tcp",
"8443/tcp",
},
})
if err != nil {
log.Fatalf("Could not start resource: %s", err)
Expand All @@ -106,6 +123,12 @@ func TestMain(m *testing.M) {
log.Fatalf("Timed out waiting for container to get ready: %s", err)
}
*integrationServerFlag = "http://test@localhost:" + resource.GetPort("8080/tcp")
tlsServer = "https://test@localhost:" + resource.GetPort("8443/tcp")

http.DefaultTransport.(*http.Transport).TLSClientConfig, err = getTLSConfig(wd + "/etc/secrets")
if err != nil {
log.Fatalf("Failed to set the default TLS config: %s", err)
}
}

code := m.Run()
Expand All @@ -120,6 +143,104 @@ func TestMain(m *testing.M) {
os.Exit(code)
}

func generateCerts(dir string) error {
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return fmt.Errorf("failed to generate private key: %w", err)
}

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return fmt.Errorf("failed to generate serial number: %w", err)
}

template := x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
Organization: []string{"Trino Software Foundation"},
},
DNSNames: []string{"localhost"},
NotBefore: time.Now(),
NotAfter: time.Now().Add(1 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
}

privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
if err != nil {
return fmt.Errorf("unable to marshal private key: %w", err)
}
privBlock := &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}
err = writePEM(dir+"/private_key.pem", privBlock)
if err != nil {
return err
}

pubBytes, err := x509.MarshalPKIXPublicKey(&priv.PublicKey)
if err != nil {
return fmt.Errorf("unable to marshal public key: %w", err)
}
pubBlock := &pem.Block{Type: "PUBLIC KEY", Bytes: pubBytes}
err = writePEM(dir+"/public_key.pem", pubBlock)
if err != nil {
return err
}

certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil {
return fmt.Errorf("failed to create certificate: %w", err)
}
certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}
err = writePEM(dir+"/certificate.pem", certBlock)
if err != nil {
return err
}

err = writePEM(dir+"/certificate_with_key.pem", certBlock, privBlock, pubBlock)
if err != nil {
return err
}

return nil
}

func writePEM(filename string, blocks ...*pem.Block) error {
// all files are world-readable, so they can be read inside the Trino container
out, err := os.Create(filename)
if err != nil {
return fmt.Errorf("failed to open %s for writing: %w", filename, err)
}
for _, block := range blocks {
if err := pem.Encode(out, block); err != nil {
return fmt.Errorf("failed to write %s data to %s: %w", block.Type, filename, err)
}
}
if err := out.Close(); err != nil {
return fmt.Errorf("error closing %s: %w", filename, err)
}
return nil
}

func getTLSConfig(dir string) (*tls.Config, error) {
certPool, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("failed to read the system cert pool: %s", err)
}
caCertPEM, err := os.ReadFile(dir + "/certificate.pem")
if err != nil {
return nil, fmt.Errorf("failed to read the certificate: %s", err)
}
ok := certPool.AppendCertsFromPEM(caCertPEM)
if !ok {
return nil, fmt.Errorf("failed to parse the certificate: %s", err)
}
return &tls.Config{
RootCAs: certPool,
}, nil
}

// integrationOpen opens a connection to the integration test server.
func integrationOpen(t *testing.T, dsn ...string) *sql.DB {
if testing.Short() {
Expand Down

0 comments on commit f1b3d95

Please sign in to comment.