diff --git a/gosnmp.go b/gosnmp.go index 9676034c..497cdb53 100644 --- a/gosnmp.go +++ b/gosnmp.go @@ -90,6 +90,8 @@ type GoSNMP struct { requestID uint32 random *rand.Rand + rxBuf *[rxBufSize]byte // has to be pointer due to https://github.com/golang/go/issues/11728 + // Internal - used to sync requests to responses - snmpv3 msgID uint32 } @@ -196,6 +198,8 @@ func (x *GoSNMP) Connect() error { // RequestID is Integer32 from SNMPV2-SMI and uses all 32 bits x.requestID = x.random.Uint32() + x.rxBuf = new([rxBufSize]byte) + if x.Version == Version3 { x.MsgFlags |= Reportable // tell the snmp server that a report PDU MUST be sent if x.SecurityModel == UserSecurityModel { diff --git a/marshal.go b/marshal.go index 339525e6..7d350ad8 100644 --- a/marshal.go +++ b/marshal.go @@ -12,9 +12,7 @@ import ( "encoding/asn1" "encoding/binary" "fmt" - "net" "sync/atomic" - "syscall" "time" ) @@ -157,10 +155,7 @@ const ( Report PDUType = 0xa8 ) -const ( - rxBufSizeMin = 1024 // Minimal buffer size to handle 1 OID (see dispatch()) - rxBufSizeMax = 131072 // 2 x max MTU size (65507) -) +const rxBufSize = 65535 // max size of IPv4 & IPv6 packet // Logger is an interface used for debugging. Both Print and // Printf have the same interfaces as Package Log in the std library. The @@ -257,46 +252,56 @@ func (x *GoSNMP) sendOneRequest(pdus []SnmpPDU, packetOut *SnmpPacket) (result * break } - var expected int - if packetOut.PDUType == GetBulkRequest { - expected = int(packetOut.MaxRepetitions) - } else { - expected = len(pdus) - } - - var resp []byte - resp, err = x.dispatch(x.Conn, outBuf, expected) - if err != nil { - continue - } - result = new(SnmpPacket) - result.MsgFlags = packetOut.MsgFlags - if packetOut.SecurityParameters != nil { - result.SecurityParameters = packetOut.SecurityParameters.Copy() - } - err = x.unmarshal(resp, result) + _, err = x.Conn.Write(outBuf) if err != nil { - err = fmt.Errorf("Unable to decode packet: %s", err.Error()) - continue - } - if result == nil || len(result.Variables) < 1 { - err = fmt.Errorf("Unable to decode packet: nil") continue } - validID := false - for _, id := range allReqIDs { - if id == result.RequestID { + for { + // Receive response and try receiving again on any decoding error. + // Let the deadline abort us if we don't receive a valid response. + + var resp []byte + resp, err = x.receive() + if err != nil { + // receive error. retrying won't help. abort + break + } + result = new(SnmpPacket) + result.MsgFlags = packetOut.MsgFlags + if packetOut.SecurityParameters != nil { + result.SecurityParameters = packetOut.SecurityParameters.Copy() + } + err = x.unmarshal(resp, result) + if err != nil { + err = fmt.Errorf("Unable to decode packet: %s", err.Error()) + continue + } + if result == nil || len(result.Variables) < 1 { + err = fmt.Errorf("Unable to decode packet: nil") + continue + } + + validID := false + for _, id := range allReqIDs { + if id == result.RequestID { + validID = true + } + } + if result.RequestID == 0 { validID = true } + if !validID { + err = fmt.Errorf("Out of order response") + continue + } + + break } - if result.RequestID == 0 { - validID = true - } - if !validID { - err = fmt.Errorf("Out of order response") + if err != nil { continue } + // Success! return result, nil } @@ -498,7 +503,7 @@ func (packet *SnmpPacket) marshalSnmpV3Header(msgid uint32) ([]byte, error) { } // maximum response msg size - maxmsgsize := marshalUvarInt(rxBufSizeMax) + maxmsgsize := marshalUvarInt(rxBufSize) buf.Write([]byte{byte(Integer), byte(len(maxmsgsize))}) buf.Write(maxmsgsize) @@ -1279,51 +1284,19 @@ func (x *GoSNMP) unmarshalVBL(packet []byte, response *SnmpPacket, return response, nil } -// dispatch request on network, and read the results into a byte array -// -// Previously, resp was allocated rxBufSize (65536) bytes ie a fixed size for -// all responses. To decrease memory usage, resp is dynamically sized, at the -// cost of possible additional network round trips. -func (x *GoSNMP) dispatch(c net.Conn, outBuf []byte, expected int) ([]byte, error) { - if expected <= 0 { - expected = 1 +// receive response from network and read into a byte array +func (x *GoSNMP) receive() ([]byte, error) { + n, err := x.Conn.Read(x.rxBuf[:]) + if err != nil { + return nil, fmt.Errorf("Error reading from UDP: %s", err.Error()) } - var resp []byte - for bufSize := rxBufSizeMin * expected; bufSize < rxBufSizeMax; bufSize *= 2 { - resp = make([]byte, bufSize) - _, err := c.Write(outBuf) - if err != nil { - return resp, fmt.Errorf("Error writing to socket: %s", err.Error()) - } - n, err := c.Read(resp) - if err != nil { - // On Windows we don't get a partial read and truncation. Instead - // we get an error if buff is too small - WSAEMSGSIZE 10040. - const WSAEMSGSIZE syscall.Errno = 10040 - if opErr, ok := err.(*net.OpError); ok { - if opErr.Err == WSAEMSGSIZE { - continue - } - } - return resp, fmt.Errorf("Error reading from UDP: %s", err.Error()) - } - if n < bufSize { - // Memory usage optimization. Help the runtime to release as much memory as - // possible. - // - // See: http://blog.golang.org/go-slices-usage-and-internals, - // section: A possible "gotcha" - // ...As mentioned earlier, re-slicing a slice doesn't make a copy of the - // underlying array. The full array will be kept in memory until it is no - // longer referenced. Occasionally this can cause the program to hold all - // the data in memory when only a small piece of it is needed. - resp = resp[:n] - resp2 := make([]byte, len(resp)) - copy(resp2, resp) - return resp2, nil - } - x.Logger.Printf("Retrying. Buffer size was too small. (size %d)", bufSize) + if n == rxBufSize { + // This should never happen unless we're using something like a unix domain socket. + return nil, fmt.Errorf("response buffer too small") } - return resp, fmt.Errorf("Response bufSize exceeded rxBufSizeMax (%d)", rxBufSizeMax) + + resp := make([]byte, n) + copy(resp, x.rxBuf[:n]) + return resp, nil } diff --git a/marshal_test.go b/marshal_test.go index 846ab965..ac273f5b 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -9,8 +9,10 @@ import ( "fmt" "io/ioutil" "log" + "net" "os" "testing" + "time" ) var _ = fmt.Sprintf("dummy") // dummy @@ -1158,3 +1160,119 @@ func counter64Response() []byte { 0x87, } } + +func TestSendOneRequest_dups(t *testing.T) { + srvr, err := net.ListenUDP("udp4", &net.UDPAddr{}) + defer srvr.Close() + + x := &GoSNMP{ + Version: Version2c, + Target: srvr.LocalAddr().(*net.UDPAddr).IP.String(), + Port: uint16(srvr.LocalAddr().(*net.UDPAddr).Port), + Timeout: time.Millisecond * 100, + Retries: 2, + } + if err := x.Connect(); err != nil { + t.Fatalf("Error connecting: %s", err) + } + + go func() { + buf := make([]byte, 256) + for { + n, addr, err := srvr.ReadFrom(buf) + if err != nil { + return + } + buf := buf[:n] + + var reqPkt SnmpPacket + err = x.unmarshal(buf, &reqPkt) + if err != nil { + t.Errorf("Error: %s", err) + } + rspPkt := x.mkSnmpPacket(GetResponse, 0, 0) + rspPkt.RequestID = reqPkt.RequestID + rspPkt.Variables = []SnmpPDU{ + { + Name: ".1.2", + Type: Integer, + Value: 123, + }, + } + outBuf, err := rspPkt.marshalMsg(rspPkt.Variables, rspPkt.PDUType, rspPkt.MsgID, rspPkt.RequestID) + if err != nil { + t.Errorf("ERR: %s", err) + } + srvr.WriteTo(outBuf, addr) + for i := 0; i <= x.Retries; i++ { + srvr.WriteTo(outBuf, addr) + } + } + }() + + reqPkt := x.mkSnmpPacket(GetResponse, 0, 0) //not actually a GetResponse, but we need something our test server can unmarshal + reqPDU := SnmpPDU{Name: ".1.2", Type: Null} + + _, err = x.sendOneRequest([]SnmpPDU{reqPDU}, reqPkt) + if err != nil { + t.Errorf("Error: %s", err) + return + } + + _, err = x.sendOneRequest([]SnmpPDU{reqPDU}, reqPkt) + if err != nil { + t.Errorf("Error: %s", err) + return + } +} + +func BenchmarkSendOneRequest(b *testing.B) { + b.StopTimer() + + srvr, err := net.ListenUDP("udp4", &net.UDPAddr{}) + defer srvr.Close() + + x := &GoSNMP{ + Version: Version2c, + Target: srvr.LocalAddr().(*net.UDPAddr).IP.String(), + Port: uint16(srvr.LocalAddr().(*net.UDPAddr).Port), + Timeout: time.Millisecond * 100, + Retries: 2, + } + if err := x.Connect(); err != nil { + b.Fatalf("Error connecting: %s", err) + } + + go func() { + buf := make([]byte, 256) + outBuf := counter64Response() + for { + _, addr, err := srvr.ReadFrom(buf) + if err != nil { + return + } + + copy(outBuf[17:21], buf[11:15]) // evil: copy request ID + srvr.WriteTo(outBuf, addr) + } + }() + + reqPkt := x.mkSnmpPacket(GetRequest, 0, 0) + reqPDU := SnmpPDU{Name: ".1.3.6.1.2.1.31.1.1.1.10.1", Type: Null} + + // make sure everything works before starting the test + _, err = x.sendOneRequest([]SnmpPDU{reqPDU}, reqPkt) + if err != nil { + b.Fatalf("Precheck failed: %s", err) + } + + b.StartTimer() + + for n := 0; n < b.N; n++ { + _, err = x.sendOneRequest([]SnmpPDU{reqPDU}, reqPkt) + if err != nil { + b.Fatalf("Error: %s", err) + return + } + } +}