From bb4669c66e1ed80138a6326b3f5185f063f07d46 Mon Sep 17 00:00:00 2001
From: Anton Malinskiy
Date: Sun, 3 Oct 2021 19:56:06 +1100
Subject: [PATCH] feat(nut): implement socket timeouts

Connect now takes a dial timeout and a per-operation timeout. The
per-operation timeout is applied as a read/write deadline on the
connection before each command write and each response-line read; a
zero or negative value disables it. Connect no longer queries VER and
NETVER, and the Version, ProtocolVersion and Hostname fields are
removed from Client.
---
 example_test.go |  8 +++---
 go.mod          |  3 +++
 nut.go          | 71 ++++++++++++++++++++++++++++++++++---------------
 3 files changed, 57 insertions(+), 25 deletions(-)
 create mode 100644 go.mod

diff --git a/example_test.go b/example_test.go
index 6a23689..31936a1 100644
--- a/example_test.go
+++ b/example_test.go
@@ -2,17 +2,17 @@ package nut
 
 import (
 	"fmt"
-
-	nut "github.com/robbiet480/go.nut"
+	"time"
 )
 
 // This example connects to NUT, authenticates and returns the first UPS listed.
 func ExampleGetUPSList() {
-	client, connectErr := nut.Connect("127.0.0.1")
+	client, connectErr := Connect("127.0.0.1", 10*time.Second, 30*time.Second)
 	if connectErr != nil {
 		fmt.Print(connectErr)
+		return
 	}
-	_, authenticationError = client.Authenticate("username", "password")
+	_, authenticationError := client.Authenticate("username", "password")
 	if authenticationError != nil {
 		fmt.Print(authenticationError)
 	}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..cba4e3e
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,3 @@
+module github.com/Malinskiy/go.nut
+
+go 1.17
diff --git a/nut.go b/nut.go
index 34fd406..9705867 100644
--- a/nut.go
+++ b/nut.go
@@ -8,32 +8,49 @@ import (
 	"fmt"
 	"net"
 	"strings"
+	"time"
 )
 
 // Client contains information about the NUT server as well as the connection.
 type Client struct {
-	Version         string
-	ProtocolVersion string
-	Hostname        net.Addr
-	conn            *net.TCPConn
+	opTimeout time.Duration // per-operation read/write timeout; <= 0 disables it
+	conn      net.Conn      // connection to the NUT server
 }
 
-// Connect accepts a hostname/IP string and creates a connection to NUT, returning a Client.
+// opDeadline returns the deadline to set before the next socket read or
+// write: now plus opTimeout, or the zero time (which net.Conn treats as "no
+// deadline") when opTimeout is zero or negative.
+func (c *Client) opDeadline() time.Time {
+	if c.opTimeout <= 0 {
+		return time.Time{}
+	}
+	return time.Now().Add(c.opTimeout)
+}
+
+// Connect dials the NUT server at hostname and returns a Client for it.
+//
+// hostname may be a bare host/IP or a "host:port" pair; port 3493 is used
+// when no port is given. connectTimeout bounds the TCP dial (zero means no
+// timeout). SendCommand and ReadResponse apply opTimeout as a fresh deadline
+// before each write and each response-line read; a zero or negative
+// opTimeout disables those deadlines.
-func Connect(hostname string) (Client, error) {
-	tcpAddr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:3493", hostname))
+func Connect(hostname string, connectTimeout time.Duration, opTimeout time.Duration) (*Client, error) {
+	_, _, err := net.SplitHostPort(hostname)
 	if err != nil {
-		return Client{}, err
+		hostname = net.JoinHostPort(hostname, "3493")
 	}
-	conn, err := net.DialTCP("tcp", nil, tcpAddr)
+	d := net.Dialer{
+		Timeout: connectTimeout,
+	}
+	conn, err := d.Dial("tcp", hostname)
 	if err != nil {
-		return Client{}, err
+		return nil, err
 	}
-	client := Client{
-		Hostname: conn.RemoteAddr(),
-		conn:     conn,
+
+	client := &Client{
+		opTimeout: opTimeout,
+		conn:      conn,
 	}
-	client.GetVersion()
-	client.GetNetworkProtocolVersion()
 	return client, nil
 }
 
@@ -55,6 +72,10 @@ func (c *Client) ReadResponse(endLine string, multiLineResponse bool) (resp []st
 	response := []string{}
 
 	for {
+		err = c.conn.SetReadDeadline(c.opDeadline())
+		if err != nil {
+			return nil, err
+		}
 		line, err := connbuff.ReadString('\n')
 		if err != nil {
 			return nil, fmt.Errorf("error reading response: %v", err)
@@ -79,18 +100,22 @@ func (c *Client) SendCommand(cmd string) (resp []string, err error) {
 	if strings.HasPrefix(cmd, "USERNAME ") || strings.HasPrefix(cmd, "PASSWORD ") || strings.HasPrefix(cmd, "SET ") || strings.HasPrefix(cmd, "HELP ") || strings.HasPrefix(cmd, "VER ") || strings.HasPrefix(cmd, "NETVER ") {
 		endLine = "OK\n"
 	}
-	_, err = fmt.Fprint(c.conn, cmd)
+	err = c.conn.SetWriteDeadline(c.opDeadline())
+	if err != nil {
+		return nil, err
+	}
+	_, err = c.conn.Write([]byte(cmd))
 	if err != nil {
-		return []string{}, err
+		return nil, err
 	}
 
 	resp, err = c.ReadResponse(endLine, strings.HasPrefix(cmd, "LIST "))
 	if err != nil {
-		return []string{}, err
+		return nil, err
 	}
 
 	if strings.HasPrefix(resp[0], "ERR ") {
-		return []string{}, errorForMessage(strings.Split(resp[0], " ")[1])
+		return nil, errorForMessage(strings.Split(resp[0], " ")[1])
 	}
 
 	return resp, nil
@@ -141,13 +166,17 @@ func (c *Client) Help() (string, error) {
 // GetVersion returns the the version of the server currently in use.
 func (c *Client) GetVersion() (string, error) {
 	versionResponse, err := c.SendCommand("VER")
-	c.Version = versionResponse[0]
+	if err != nil {
+		return "", err
+	}
 	return versionResponse[0], err
 }
 
 // GetNetworkProtocolVersion returns the version of the network protocol currently in use.
 func (c *Client) GetNetworkProtocolVersion() (string, error) {
 	versionResponse, err := c.SendCommand("NETVER")
-	c.ProtocolVersion = versionResponse[0]
+	if err != nil {
+		return "", err
+	}
 	return versionResponse[0], err
 }