Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add default ports and fallback behavior for SSL and TLS #170

Merged
merged 4 commits into from Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
96 changes: 92 additions & 4 deletions client.go
Expand Up @@ -20,9 +20,15 @@ import (

// Defaults
const (
// DefaultPort is the default connection port cto the SMTP server
// DefaultPort is the default connection port to the SMTP server
DefaultPort = 25

// DefaultPortSSL is the default connection port for SSL/TLS to the SMTP server
DefaultPortSSL = 465

// DefaultPortTLS is the default connection port for STARTTLS to the SMTP server
DefaultPortTLS = 587

// DefaultTimeout is the default connection timeout
DefaultTimeout = time.Second * 15

Expand Down Expand Up @@ -105,14 +111,15 @@ type Client struct {
// HELO/EHLO string for the greeting the target SMTP server
helo string

// Hostname of the target SMTP server cto connect cto
// Hostname of the target SMTP server to connect to
host string

// pass is the corresponding SMTP AUTH password
pass string

// Port of the SMTP server cto connect cto
port int
// Port of the SMTP server to connect to
port int
fallbackPort int

// sa is a pointer to smtp.Auth
sa smtp.Auth
Expand Down Expand Up @@ -246,13 +253,27 @@ func WithTimeout(t time.Duration) Option {
}

// WithSSL tells the client to use a SSL/TLS connection
//
// Deprecated: use WithSSLPort instead.
func WithSSL() Option {
return func(c *Client) error {
c.ssl = true
return nil
}
}

// WithSSLPort tells the client to use a SSL/TLS connection.
// It automatically sets the port to 465.
//
// When the SSL connection fails and fallback is set to true,
// the client will attempt to connect on port 25 using plaintext.
func WithSSLPort(fb bool) Option {
return func(c *Client) error {
c.SetSSLPort(true, fb)
return nil
}
}

// WithDebugLog tells the client to log incoming and outgoing messages of the SMTP client
// to StdErr
func WithDebugLog() Option {
Expand Down Expand Up @@ -282,13 +303,29 @@ func WithHELO(h string) Option {
}

// WithTLSPolicy tells the client to use the provided TLSPolicy
//
// Deprecated: use WithTLSPortPolicy instead.
func WithTLSPolicy(p TLSPolicy) Option {
return func(c *Client) error {
c.tlspolicy = p
return nil
}
}

// WithTLSPortPolicy tells the client to use the provided TLSPolicy,
// The correct port is automatically set.
//
// Port 587 is used for TLSMandatory and TLSOpportunistic.
// If the connection fails with TLSOpportunistic,
// a plaintext connection is attempted on port 25 as a fallback.
// NoTLS will allways use port 25.
func WithTLSPortPolicy(p TLSPolicy) Option {
return func(c *Client) error {
c.SetTLSPortPolicy(p)
return nil
}
}

// WithTLSConfig tells the client to use the provided *tls.Config
func WithTLSConfig(co *tls.Config) Option {
return func(c *Client) error {
Expand Down Expand Up @@ -430,11 +467,52 @@ func (c *Client) SetTLSPolicy(p TLSPolicy) {
c.tlspolicy = p
}

// SetTLSPortPolicy overrides the current TLSPolicy with the given TLSPolicy
// value. The correct port is automatically set.
//
// Port 587 is used for TLSMandatory and TLSOpportunistic.
// If the connection fails with TLSOpportunistic, a plaintext connection is
// attempted on port 25 as a fallback.
// NoTLS will allways use port 25.
func (c *Client) SetTLSPortPolicy(p TLSPolicy) {
c.port = DefaultPortTLS

if p == TLSOpportunistic {
c.fallbackPort = DefaultPort
}
if p == NoTLS {
c.port = DefaultPort
}

c.tlspolicy = p
}

// SetSSL tells the Client wether to use SSL or not
func (c *Client) SetSSL(s bool) {
c.ssl = s
}

// SetSSLPort tells the Client wether or not to use SSL and fallback.
// The correct port is automatically set.
//
// Port 465 is used when SSL set (true).
// Port 25 is used when SSL is unset (false).
// When the SSL connection fails and fb is set to true,
// the client will attempt to connect on port 25 using plaintext.
func (c *Client) SetSSLPort(ssl bool, fb bool) {
c.port = DefaultPort
if ssl {
c.port = DefaultPortSSL
}

c.fallbackPort = 0
if fb {
c.fallbackPort = DefaultPort
}

c.ssl = ssl
}

// SetDebugLog tells the Client whether debug logging is enabled or not
func (c *Client) SetDebugLog(v bool) {
c.dl = v
Expand Down Expand Up @@ -507,6 +585,10 @@ func (c *Client) DialWithContext(pc context.Context) error {
}
var err error
c.co, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error?
c.co, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
}
if err != nil {
return err
}
Expand Down Expand Up @@ -606,6 +688,12 @@ func (c *Client) checkConn() error {
return nil
}

// serverFallbackAddr returns the currently set combination of hostname
// and fallback port.
func (c *Client) serverFallbackAddr() string {
return fmt.Sprintf("%s:%d", c.host, c.fallbackPort)
}

// tls tries to make sure that the STARTTLS requirements are satisfied
func (c *Client) tls() error {
if c.co == nil {
Expand Down
107 changes: 107 additions & 0 deletions client_test.go
Expand Up @@ -89,9 +89,12 @@ func TestNewClientWithOptions(t *testing.T) {
{"WithTimeout()", WithTimeout(time.Second * 5), false},
{"WithTimeout()", WithTimeout(-10), true},
{"WithSSL()", WithSSL(), false},
{"WithSSLPort(false)", WithSSLPort(false), false},
{"WithSSLPort(true)", WithSSLPort(true), false},
{"WithHELO()", WithHELO(host), false},
{"WithHELO(); helo is empty", WithHELO(""), true},
{"WithTLSPolicy()", WithTLSPolicy(TLSOpportunistic), false},
{"WithTLSPortPolicy()", WithTLSPortPolicy(TLSOpportunistic), false},
{"WithTLSConfig()", WithTLSConfig(&tls.Config{}), false},
{"WithTLSConfig(); config is nil", WithTLSConfig(nil), true},
{"WithSMTPAuth()", WithSMTPAuth(SMTPAuthLogin), false},
Expand Down Expand Up @@ -235,6 +238,42 @@ func TestWithTLSPolicy(t *testing.T) {
}
}

// TestWithTLSPortPolicy tests the WithTLSPortPolicy() option for the NewClient() method
func TestWithTLSPortPolicy(t *testing.T) {
tests := []struct {
name string
value TLSPolicy
want string
wantPort int
fbPort int
sf bool
}{
{"Policy: TLSMandatory", TLSMandatory, TLSMandatory.String(), 587, 0, false},
{"Policy: TLSOpportunistic", TLSOpportunistic, TLSOpportunistic.String(), 587, 25, false},
{"Policy: NoTLS", NoTLS, NoTLS.String(), 25, 0, false},
{"Policy: Invalid", -1, "UnknownPolicy", 587, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(DefaultHost, WithTLSPortPolicy(tt.value))
if err != nil && !tt.sf {
t.Errorf("failed to create new client: %s", err)
return
}
if c.tlspolicy.String() != tt.want {
t.Errorf("failed to set TLSPortPolicy. Want: %s, got: %s", tt.want, c.tlspolicy)
}
if c.port != tt.wantPort {
t.Errorf("failed to set TLSPortPolicy, wanted port: %d, got: %d", tt.wantPort, c.port)
}
if c.fallbackPort != tt.fbPort {
t.Errorf("failed to set TLSPortPolicy, wanted fallbakc port: %d, got: %d", tt.fbPort,
c.fallbackPort)
}
})
}
}

// TestSetTLSPolicy tests the SetTLSPolicy() method for the Client object
func TestSetTLSPolicy(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -312,6 +351,42 @@ func TestSetSSL(t *testing.T) {
}
}

// TestSetSSLPort tests the Client.SetSSLPort method
func TestClient_SetSSLPort(t *testing.T) {
tests := []struct {
name string
value bool
fb bool
port int
fbPort int
}{
{"SSL: on, fb: off", true, false, 465, 0},
{"SSL: on, fb: on", true, true, 465, 25},
{"SSL: off, fb: off", false, false, 25, 0},
{"SSL: off, fb: on", false, true, 25, 25},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, err := NewClient(DefaultHost)
if err != nil {
t.Errorf("failed to create new client: %s", err)
return
}
c.SetSSLPort(tt.value, tt.fb)
if c.ssl != tt.value {
t.Errorf("failed to set SSL setting. Got: %t, want: %t", c.ssl, tt.value)
}
if c.port != tt.port {
t.Errorf("failed to set SSLPort, wanted port: %d, got: %d", c.port, tt.port)
}
if c.fallbackPort != tt.fbPort {
t.Errorf("failed to set SSLPort, wanted fallback port: %d, got: %d", c.fallbackPort,
tt.fbPort)
}
})
}
}

// TestSetUsername tests the SetUsername method for the Client object
func TestSetUsername(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -550,6 +625,38 @@ func TestClient_DialWithContext(t *testing.T) {
}
}

// TestClient_DialWithContext_Fallback tests the Client.DialWithContext method with the fallback
// port functionality
func TestClient_DialWithContext_Fallback(t *testing.T) {
c, err := getTestConnection(true)
if err != nil {
t.Skipf("failed to create test client: %s. Skipping tests", err)
}
c.SetTLSPortPolicy(TLSOpportunistic)
c.port = 999
ctx := context.Background()
if err := c.DialWithContext(ctx); err != nil {
t.Errorf("failed to dial with context: %s", err)
return
}
if c.co == nil {
t.Errorf("DialWithContext didn't fail but no connection found.")
}
if c.sc == nil {
t.Errorf("DialWithContext didn't fail but no SMTP client found.")
}
if err := c.Close(); err != nil {
t.Errorf("failed to close connection: %s", err)
}

c.port = 999
c.fallbackPort = 999
if err = c.DialWithContext(ctx); err == nil {
t.Error("dial with context was supposed to fail, but didn't")
return
}
}

// TestClient_DialWithContext_Debug tests the DialWithContext method for the Client object with debug
// logging enabled on the SMTP client
func TestClient_DialWithContext_Debug(t *testing.T) {
Expand Down