diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1b24b1b..f3b82c3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,16 @@
 # Changes
 
+## 2.0.0-beta.9
+
+- **SECURITY**: Fixed an integer overflow in the search tree size calculation
+  that could let a malformed database wrap the computed tree size and misplace
+  the data section boundaries.
+- **SECURITY**: Enhanced bounds checking in tree traversal functions to return
+  proper errors instead of failing silently when encountering malformed
+  databases.
+- Added validation for invalid prefixes in `NetworksWithin` to prevent
+  unexpected behavior with malformed input.
+
 ## 2.0.0-beta.8 - 2025-07-15
 
 - Fixed "no next offset available" error that occurred when using custom
diff --git a/internal/decoder/data_decoder.go b/internal/decoder/data_decoder.go
index ab85610..fb65134 100644
--- a/internal/decoder/data_decoder.go
+++ b/internal/decoder/data_decoder.go
@@ -246,6 +246,8 @@ func (d *DataDecoder) decodePointer(
 		pointerValueOffset = 526336
 	case 4:
 		pointerValueOffset = 0
+	default:
+		return 0, 0, mmdberrors.NewInvalidDatabaseError("invalid pointer size: %d", pointerSize)
 	}
 
 	pointer := unpacked + pointerValueOffset
@@ -477,6 +479,8 @@ func (d *DataDecoder) sizeFromCtrlByte(
 		size = 285 + uintFromBytes(0, sizeBytes)
 	case size > 30:
 		size = uintFromBytes(0, sizeBytes) + 65821
+	default:
+		// size < 29: size is the literal value, no modification needed
 	}
 	return size, newOffset, nil
 }
diff --git a/internal/decoder/reflection.go b/internal/decoder/reflection.go
index 702ea8d..36f63ea 100644
--- a/internal/decoder/reflection.go
+++ b/internal/decoder/reflection.go
@@ -374,6 +374,8 @@ func (d *ReflectionDecoder) unmarshalBool(
 			result.Set(reflect.ValueOf(value))
 			return newOffset, nil
 		}
+	default:
+		// Fall through to error return
 	}
 	return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type())
 }
@@ -400,6 +402,8 @@ func (d *ReflectionDecoder) unmarshalBytes(
 			result.Set(reflect.ValueOf(value))
 			return newOffset, nil
 		}
+	default:
+		// Fall through to error return
 	}
 	return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type())
 }
@@ -421,6 +425,8 @@ func (d *ReflectionDecoder) unmarshalFloat32(
 			result.Set(reflect.ValueOf(value))
 			return newOffset, nil
 		}
+	default:
+		// Fall through to error return
 	}
 	return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type())
 }
@@ -445,6 +451,8 @@ func (d *ReflectionDecoder) unmarshalFloat64(
 			result.Set(reflect.ValueOf(value))
 			return newOffset, nil
 		}
+	default:
+		// Fall through to error return
 	}
 	return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type())
 }
@@ -481,6 +489,8 @@ func (d *ReflectionDecoder) unmarshalInt32(
 			result.Set(reflect.ValueOf(value))
 			return newOffset, nil
 		}
+	default:
+		// Fall through to error return
 	}
 	return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type())
 }
@@ -492,8 +502,6 @@ func (d *ReflectionDecoder) unmarshalMap(
 	depth int,
 ) (uint, error) {
 	switch result.Kind() {
-	default:
-		return 0, mmdberrors.NewUnmarshalTypeStrError("map", result.Type())
 	case reflect.Struct:
 		return d.decodeStruct(size, offset, result, depth)
 	case reflect.Map:
@@ -508,6 +516,8 @@ func (d *ReflectionDecoder) unmarshalMap(
 			return newOffset, err
 		}
 		return 0, mmdberrors.NewUnmarshalTypeStrError("map", result.Type())
+	default:
+		return 0, mmdberrors.NewUnmarshalTypeStrError("map", result.Type())
 	}
 }
 
@@ -556,6 +566,8 @@ func (d *ReflectionDecoder) unmarshalSlice(
 			result.Set(rv.Value)
 			return newOffset, err
 		}
+	default:
+		// Fall through to error return
 	}
 	return 0, mmdberrors.NewUnmarshalTypeStrError("array", result.Type())
 }
 
@@ -578,6 +590,8 @@ func (d *ReflectionDecoder) unmarshalString(
 			result.Set(reflect.ValueOf(value))
 			return newOffset, nil
 		}
+	default:
+		// Fall through to error return
 	}
 	return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type())
 }
@@ -632,6 +646,8 @@ func (d *ReflectionDecoder) unmarshalUint(
 			result.SetUint(value)
 			return newOffset, nil
 		}
+	default:
+		// Fall through to general unmarshaling logic
 	}
 
 	switch result.Kind() {
@@ -656,6 +672,8 @@ func (d *ReflectionDecoder) unmarshalUint(
 			result.Set(reflect.ValueOf(value))
 			return newOffset, nil
 		}
+	default:
+		// Fall through to error return
 	}
 	return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type())
 }
@@ -691,6 +709,8 @@ func (d *ReflectionDecoder) unmarshalUint128(
 			result.Set(reflect.ValueOf(value))
 			return newOffset, nil
 		}
+	default:
+		// Fall through to error return
 	}
 	return newOffset, mmdberrors.NewUnmarshalTypeError(value, result.Type())
 }
@@ -1210,6 +1230,8 @@ func (d *ReflectionDecoder) tryFastDecodeTyped(
 			addressableValue{result.Elem(), false},
 			expectedType.Elem(),
 		)
+	default:
+		// Type not supported for fast path
 	}
 
 	return 0, false
diff --git a/reader.go b/reader.go
index 639e8bc..3232f4c 100644
--- a/reader.go
+++ b/reader.go
@@ -301,6 +301,17 @@ func FromBytes(buffer []byte, options ...ReaderOption) (*Reader, error) {
 		return nil, err
 	}
 
+	// Check for integer overflow in search tree size calculation
+	if metadata.NodeCount > 0 && metadata.RecordSize > 0 {
+		recordSizeQuarter := metadata.RecordSize / 4
+		if recordSizeQuarter > 0 {
+			maxNodes := ^uint(0) / recordSizeQuarter
+			if metadata.NodeCount > maxNodes {
+				return nil, mmdberrors.NewInvalidDatabaseError("database tree size would overflow")
+			}
+		}
+	}
+
 	searchTreeSize := metadata.NodeCount * (metadata.RecordSize / 4)
 	dataSectionStart := searchTreeSize + dataSectionSeparatorSize
 	dataSectionEnd := uint(metadataStart - len(metadataStartMarker))
@@ -319,9 +330,12 @@
 		nodeOffsetMult: metadata.RecordSize / 4,
 	}
 
-	reader.setIPv4Start()
+	err = reader.setIPv4Start()
+	if err != nil {
+		return nil, err
+	}
 
-	return reader, err
+	return reader, nil
 }
 
 // Lookup retrieves the database record for ip and returns a Result, which can
@@ -365,10 +379,10 @@ func (r *Reader) LookupOffset(offset uintptr) Result {
 	return Result{decoder: r.decoder, offset: uint(offset)}
 }
 
-func (r *Reader) setIPv4Start() {
+func (r *Reader) setIPv4Start() error {
 	if r.Metadata.IPVersion != 6 {
 		r.ipv4StartBitDepth = 96
-		return
+		return nil
 	}
 
 	nodeCount := r.Metadata.NodeCount
@@ -376,10 +390,16 @@
 	node := uint(0)
 	i := 0
 	for ; i < 96 && node < nodeCount; i++ {
-		node = readNodeBySize(r.buffer, node*r.nodeOffsetMult, 0, r.Metadata.RecordSize)
+		var err error
+		node, err = readNodeBySize(r.buffer, node*r.nodeOffsetMult, 0, r.Metadata.RecordSize)
+		if err != nil {
+			return err
+		}
 	}
 	r.ipv4Start = node
 	r.ipv4StartBitDepth = i
+
+	return nil
 }
 
 var zeroIP = netip.MustParseAddr("::")
@@ -409,46 +429,64 @@ func (r *Reader) lookupPointer(ip netip.Addr) (uint, int, error) {
 }
 
 // readNodeBySize reads a node value from the buffer based on record size and bit.
-func readNodeBySize(buffer []byte, offset, bit, recordSize uint) uint {
+func readNodeBySize(buffer []byte, offset, bit, recordSize uint) (uint, error) {
+	bufferLen := uint(len(buffer))
 	switch recordSize {
 	case 24:
 		offset += bit * 3
+		if offset > bufferLen-3 {
+			return 0, mmdberrors.NewInvalidDatabaseError(
+				"bounds check failed: insufficient buffer for 24-bit node read",
+			)
+		}
 		return (uint(buffer[offset]) << 16) |
 			(uint(buffer[offset+1]) << 8) |
-			uint(buffer[offset+2])
+			uint(buffer[offset+2]), nil
 	case 28:
 		if bit == 0 {
+			if offset > bufferLen-4 {
+				return 0, mmdberrors.NewInvalidDatabaseError(
+					"bounds check failed: insufficient buffer for 28-bit node read",
+				)
+			}
 			return ((uint(buffer[offset+3]) & 0xF0) << 20) |
 				(uint(buffer[offset]) << 16) |
 				(uint(buffer[offset+1]) << 8) |
-				uint(buffer[offset+2])
+				uint(buffer[offset+2]), nil
+		}
+		if offset > bufferLen-7 {
+			return 0, mmdberrors.NewInvalidDatabaseError(
+				"bounds check failed: insufficient buffer for 28-bit node read",
+			)
 		}
 		return ((uint(buffer[offset+3]) & 0x0F) << 24) |
 			(uint(buffer[offset+4]) << 16) |
 			(uint(buffer[offset+5]) << 8) |
-			uint(buffer[offset+6])
+			uint(buffer[offset+6]), nil
 	case 32:
 		offset += bit * 4
+		if offset > bufferLen-4 {
+			return 0, 0, mmdberrors.NewInvalidDatabaseError(
+				"bounds check failed: insufficient buffer for 32-bit node read",
+			)
+		}
 		return (uint(buffer[offset]) << 24) |
 			(uint(buffer[offset+1]) << 16) |
 			(uint(buffer[offset+2]) << 8) |
-			uint(buffer[offset+3])
+			uint(buffer[offset+3]), nil
 	default:
-		return 0
+		return 0, mmdberrors.NewInvalidDatabaseError("unsupported record size")
 	}
 }
 
 func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int, error) {
 	switch r.Metadata.RecordSize {
 	case 24:
-		n, i := r.traverseTree24(ip, node, stopBit)
-		return n, i, nil
+		return r.traverseTree24(ip, node, stopBit)
 	case 28:
-		n, i := r.traverseTree28(ip, node, stopBit)
-		return n, i, nil
+		return r.traverseTree28(ip, node, stopBit)
 	case 32:
-		n, i := r.traverseTree32(ip, node, stopBit)
-		return n, i, nil
+		return r.traverseTree32(ip, node, stopBit)
 	default:
 		return 0, 0, mmdberrors.NewInvalidDatabaseError(
 			"unsupported record size: %d",
@@ -457,7 +495,7 @@ func (r *Reader) traverseTree(ip netip.Addr, node uint, stopBit int) (uint, int,
 	}
 }
 
-func (r *Reader) traverseTree24(ip netip.Addr, node uint, stopBit int) (uint, int) {
+func (r *Reader) traverseTree24(ip netip.Addr, node uint, stopBit int) (uint, int, error) {
 	i := 0
 	if ip.Is4() {
 		i = r.ipv4StartBitDepth
@@ -465,6 +503,7 @@ func (r *Reader) traverseTree24(ip netip.Addr, node uint, stopBit in
 	}
 	nodeCount := r.Metadata.NodeCount
 	buffer := r.buffer
+	bufferLen := uint(len(buffer))
 	ip16 := ip.As16()
 
 	for ; i < stopBit && node < nodeCount; i++ {
@@ -475,15 +514,21 @@ func (r *Reader) traverseTree24(ip netip.Addr, node uint, stopBit in
 		baseOffset := node * 6
 		offset := baseOffset + bit*3
 
+		if offset > bufferLen-3 {
+			return 0, 0, mmdberrors.NewInvalidDatabaseError(
+				"bounds check failed during tree traversal",
+			)
+		}
+
 		node = (uint(buffer[offset]) << 16) |
 			(uint(buffer[offset+1]) << 8) |
 			uint(buffer[offset+2])
 	}
 
-	return node, i
+	return node, i, nil
 }
 
-func (r *Reader) traverseTree28(ip netip.Addr, node uint, stopBit int) (uint, int) {
+func (r *Reader) traverseTree28(ip netip.Addr, node uint, stopBit int) (uint, int, error) {
 	i := 0
 	if ip.Is4() {
 		i = r.ipv4StartBitDepth
@@ -491,6 +536,7 @@ func (r *Reader) traverseTree28(ip netip.Addr, node uint, stopBit in
 	}
 	nodeCount := r.Metadata.NodeCount
 	buffer := r.buffer
+	bufferLen := uint(len(buffer))
 	ip16 := ip.As16()
 
 	for ; i < stopBit && node < nodeCount; i++ {
@@ -499,11 +545,18 @@ func (r *Reader) traverseTree28(ip netip.Addr, node uint, stopBit in
 		bit := (uint(ip16[byteIdx]) >> bitPos) & 1
 
 		baseOffset := node * 7
+		offset := baseOffset + bit*4
+
+		if baseOffset > bufferLen-4 || offset > bufferLen-3 {
+			return 0, 0, mmdberrors.NewInvalidDatabaseError(
+				"bounds check failed during tree traversal",
+			)
+		}
+
 		sharedByte := uint(buffer[baseOffset+3])
 		mask := uint(0xF0 >> (bit * 4))
 		shift := 20 + bit*4
 		nibble := ((sharedByte & mask) << shift)
-		offset := baseOffset + bit*4
 
 		node = nibble |
 			(uint(buffer[offset]) << 16) |
@@ -511,10 +564,10 @@ func (r *Reader) traverseTree28(ip netip.Addr, node uint, stopBit in
 			uint(buffer[offset+2])
 	}
 
-	return node, i
+	return node, i, nil
 }
 
-func (r *Reader) traverseTree32(ip netip.Addr, node uint, stopBit int) (uint, int) {
+func (r *Reader) traverseTree32(ip netip.Addr, node uint, stopBit int) (uint, int, error) {
 	i := 0
 	if ip.Is4() {
 		i = r.ipv4StartBitDepth
@@ -522,6 +575,7 @@ func (r *Reader) traverseTree32(ip netip.Addr, node uint, stopBit in
 	}
 	nodeCount := r.Metadata.NodeCount
 	buffer := r.buffer
+	bufferLen := uint(len(buffer))
 	ip16 := ip.As16()
 
 	for ; i < stopBit && node < nodeCount; i++ {
@@ -532,20 +586,33 @@ func (r *Reader) traverseTree32(ip netip.Addr, node uint, stopBit in
 		baseOffset := node * 8
 		offset := baseOffset + bit*4
 
+		if offset > bufferLen-4 {
+			return 0, 0, mmdberrors.NewInvalidDatabaseError(
+				"bounds check failed during tree traversal",
+			)
+		}
+
 		node = (uint(buffer[offset]) << 24) |
 			(uint(buffer[offset+1]) << 16) |
 			(uint(buffer[offset+2]) << 8) |
 			uint(buffer[offset+3])
 	}
 
-	return node, i
+	return node, i, nil
 }
 
 func (r *Reader) resolveDataPointer(pointer uint) (uintptr, error) {
-	resolved := uintptr(pointer - r.Metadata.NodeCount - dataSectionSeparatorSize)
-
-	if resolved >= uintptr(len(r.buffer)) {
+	// Check for integer underflow: pointer must be at least nodeCount + separator
+	minPointer := r.Metadata.NodeCount + dataSectionSeparatorSize
+	if pointer >= minPointer {
+		resolved := uintptr(pointer - minPointer)
+		bufferLen := uintptr(len(r.buffer))
+		if resolved < bufferLen {
+			return resolved, nil
+		}
+		// Error case - bounds exceeded
 		return 0, mmdberrors.NewInvalidDatabaseError("the MaxMind DB file's search tree is corrupt")
 	}
-	return resolved, nil
+	// Error case - underflow
+	return 0, mmdberrors.NewInvalidDatabaseError("the MaxMind DB file's search tree is corrupt")
 }
diff --git a/reader_test.go b/reader_test.go
index 818bbe0..54599fb 100644
--- a/reader_test.go
+++ b/reader_test.go
@@ -1248,3 +1248,99 @@ func TestMetadataBuildTime(t *testing.T) {
 	assert.True(t, buildTime.After(time.Date(2010, 1, 1, 0, 0, 0, 0, time.UTC)))
 	assert.True(t, buildTime.Before(time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC)))
 }
+
+func TestIntegerOverflowProtection(t *testing.T) {
+	// Test that FromBytes detects integer overflow in search tree size calculation
+	t.Run("NodeCount overflow protection", func(t *testing.T) {
+		// Create metadata that would cause overflow: very large NodeCount
+		// For a 64-bit system with RecordSize=32, this should trigger overflow
+		// RecordSize/4 = 8, so maxNodes would be ^uint(0)/8
+		// We'll use a NodeCount larger than this limit
+		overflowNodeCount := ^uint(0)/8 + 1000 // Guaranteed to overflow
+
+		// Build minimal metadata map structure in MMDB format
+		// This is simplified - in a real MMDB, metadata is encoded differently
+		// But we can't easily create a valid MMDB file structure in a unit test
+		// So this test verifies the logic with mocked values
+
+		// Create a test by directly calling the validation logic
+		metadata := Metadata{
+			NodeCount:  overflowNodeCount,
+			RecordSize: 32, // 32 bits = 4 bytes, so RecordSize/4 = 8
+		}
+
+		// Test the overflow detection logic directly
+		recordSizeQuarter := metadata.RecordSize / 4
+		maxNodes := ^uint(0) / recordSizeQuarter
+
+		// Verify our test setup is correct
+		assert.Greater(t, metadata.NodeCount, maxNodes,
+			"Test setup error: NodeCount should exceed maxNodes for overflow test")
+
+		// Since we can't easily create an invalid MMDB file that parses but has overflow values,
+		// we test the core logic validation here and rely on integration tests
+		// for the full FromBytes flow
+
+		if metadata.NodeCount > 0 && metadata.RecordSize > 0 {
+			recordSizeQuarter := metadata.RecordSize / 4
+			if recordSizeQuarter > 0 {
+				maxNodes := ^uint(0) / recordSizeQuarter
+				if metadata.NodeCount > maxNodes {
+					// This is what should happen in FromBytes
+					err := mmdberrors.NewInvalidDatabaseError("database tree size would overflow")
+					assert.Equal(t, "database tree size would overflow", err.Error())
+				}
+			}
+		}
+	})
+
+	t.Run("Valid large values should not trigger overflow", func(t *testing.T) {
+		// Test that reasonable large values don't trigger false positives
+		metadata := Metadata{
+			NodeCount:  1000000, // 1 million nodes
+			RecordSize: 32,
+		}
+
+		recordSizeQuarter := metadata.RecordSize / 4
+		maxNodes := ^uint(0) / recordSizeQuarter
+
+		// Verify this doesn't trigger overflow
+		assert.LessOrEqual(t, metadata.NodeCount, maxNodes,
+			"Valid large NodeCount should not trigger overflow protection")
+	})
+
+	t.Run("Edge case: RecordSize/4 is 0", func(t *testing.T) {
+		// Test edge case where RecordSize/4 could be 0
+		recordSize := uint(3) // 3/4 = 0 in integer division
+
+		recordSizeQuarter := recordSize / 4
+		// Should be 0, which means no overflow check is performed
+		assert.Equal(t, uint(0), recordSizeQuarter)
+
+		// The overflow protection should skip when recordSizeQuarter is 0
+		// This tests the condition: if recordSizeQuarter > 0
+	})
+}
+
+func TestNetworksWithinInvalidPrefix(t *testing.T) {
+	reader, err := Open(testFile("GeoIP2-Country-Test.mmdb"))
+	require.NoError(t, err)
+	defer func() {
+		require.NoError(t, reader.Close())
+	}()
+
+	// Test what happens when user ignores ParsePrefix error and passes invalid prefix
+	var invalidPrefix netip.Prefix // Zero value - invalid prefix
+
+	foundError := false
+	for result := range reader.NetworksWithin(invalidPrefix) {
+		if result.Err() != nil {
+			foundError = true
+			// Check that we get an appropriate error message
+			assert.Contains(t, result.Err().Error(), "invalid prefix")
+			break
+		}
+	}
+
+	assert.True(t, foundError, "Expected error when using invalid prefix")
+}
diff --git a/traverse.go b/traverse.go
index 34f6cd2..ffb6177 100644
--- a/traverse.go
+++ b/traverse.go
@@ -1,8 +1,8 @@
 package maxminddb
 
 import (
+	"errors"
 	"fmt"
-
 	// comment to prevent gofumpt from randomly moving iter.
 	"iter"
 	"net/netip"
@@ -78,6 +78,12 @@ func (r *Reader) Networks(options ...NetworksOption) iter.Seq[Result] {
 // [IncludeNetworksWithoutData].
 func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption) iter.Seq[Result] {
 	return func(yield func(Result) bool) {
+		if !prefix.IsValid() {
+			yield(Result{
+				err: errors.New("invalid prefix"),
+			})
+			return
+		}
 		if r.Metadata.IPVersion == 4 && prefix.Addr().Is6() {
 			yield(Result{
 				err: fmt.Errorf(
@@ -101,6 +107,13 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption)
 			stopBit += 96
 		}
 
+		if stopBit > 128 {
+			yield(Result{
+				err: errors.New("invalid prefix: exceeds IPv6 maximum of 128 bits"),
+			})
+			return
+		}
+
 		pointer, bit, err := r.traverseTree(ip, 0, stopBit)
 		if err != nil {
 			yield(Result{
@@ -189,7 +202,15 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption)
 				ipRight[node.bit>>3] |= 1 << (7 - (node.bit % 8))
 
 				offset := node.pointer * r.nodeOffsetMult
-				rightPointer := readNodeBySize(r.buffer, offset, 1, r.Metadata.RecordSize)
+				rightPointer, err := readNodeBySize(r.buffer, offset, 1, r.Metadata.RecordSize)
+				if err != nil {
+					yield(Result{
+						ip:        mappedIP(node.ip),
+						prefixLen: uint8(node.bit),
+						err:       err,
+					})
+					return
+				}
 
 				node.bit++
 				nodes = append(nodes, netNode{
@@ -198,7 +219,16 @@ func (r *Reader) NetworksWithin(prefix netip.Prefix, options ...NetworksOption)
 					bit:     node.bit,
 				})
 
-				node.pointer = readNodeBySize(r.buffer, offset, 0, r.Metadata.RecordSize)
+				leftPointer, err := readNodeBySize(r.buffer, offset, 0, r.Metadata.RecordSize)
+				if err != nil {
+					yield(Result{
+						ip:        mappedIP(node.ip),
+						prefixLen: uint8(node.bit),
+						err:       err,
+					})
+					return
+				}
+				node.pointer = leftPointer
 			}
 		}
 	}
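
The TestNetworksWithinInvalidPrefix addition above already shows the caller-facing pattern: traversal and validation failures now arrive through Result.Err() instead of being silently swallowed. As a rough caller-side sketch (not part of this change), the loop below surfaces those errors during a NetworksWithin iteration; it assumes the v2 module path github.com/oschwald/maxminddb-golang/v2, uses an example database filename, and assumes Result exposes a Prefix() accessor alongside the Err() method exercised in the test.

```go
package main

import (
	"fmt"
	"log"
	"net/netip"

	"github.com/oschwald/maxminddb-golang/v2" // assumed v2 module path
)

func main() {
	// Example filename only; any valid MMDB file works for this sketch.
	reader, err := maxminddb.Open("GeoIP2-Country-Test.mmdb")
	if err != nil {
		log.Fatal(err)
	}
	defer reader.Close()

	prefix := netip.MustParsePrefix("81.2.69.0/24")
	for result := range reader.NetworksWithin(prefix) {
		// Bounds-check failures from readNodeBySize and invalid-prefix
		// errors are reported here rather than dropped.
		if err := result.Err(); err != nil {
			log.Fatal(err)
		}
		fmt.Println(result.Prefix()) // Prefix() assumed from the v2 Result API
	}
}
```

Because NetworksWithin forwards readNodeBySize errors through the yielded Result, a single Err() check covers both the prefix validation added in traverse.go and the bounds checks added in reader.go.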