diff --git a/ext/nio4r/bytebuffer.c b/ext/nio4r/bytebuffer.c index 825c73e..4fa1dcf 100644 --- a/ext/nio4r/bytebuffer.c +++ b/ext/nio4r/bytebuffer.c @@ -265,6 +265,7 @@ static VALUE NIO_ByteBuffer_put(VALUE self, VALUE string) struct NIO_ByteBuffer *buffer; Data_Get_Struct(self, struct NIO_ByteBuffer, buffer); + StringValue(string); length = RSTRING_LEN(string); if(length > buffer->limit - buffer->position) { diff --git a/ext/nio4r/org/nio4r/ByteBuffer.java b/ext/nio4r/org/nio4r/ByteBuffer.java index 0ab5fa0..72e1f32 100644 --- a/ext/nio4r/org/nio4r/ByteBuffer.java +++ b/ext/nio4r/org/nio4r/ByteBuffer.java @@ -163,10 +163,8 @@ public IRubyObject fetch(ThreadContext context, IRubyObject index) { @JRubyMethod(name = "<<") public IRubyObject put(ThreadContext context, IRubyObject str) { - String string = str.asJavaString(); - try { - this.byteBuffer.put(string.getBytes()); + this.byteBuffer.put(str.convertToString().getByteList().bytes()); } catch(BufferOverflowException e) { throw ByteBuffer.newOverflowError(context, "buffer is full"); } diff --git a/lib/nio/bytebuffer.rb b/lib/nio/bytebuffer.rb index 1f45fc7..c065ef5 100644 --- a/lib/nio/bytebuffer.rb +++ b/lib/nio/bytebuffer.rb @@ -111,15 +111,22 @@ def [](index) # Add a String to the buffer # + # @param str [#to_str] data to add to the buffer + # + # @raise [TypeError] given a non-string type # @raise [NIO::ByteBuffer::OverflowError] buffer is full # # @return [self] - def <<(str) + def put(str) + raise TypeError, "expected String, got #{str.class}" unless str.respond_to?(:to_str) + str = str.to_str + raise OverflowError, "buffer is full" if str.length > @limit - @position @buffer[@position...str.length] = str @position += str.length self end + alias << put # Perform a non-blocking read from the given IO object into the buffer # Reads as much data as is immediately available and returns diff --git a/spec/nio/bytebuffer_spec.rb b/spec/nio/bytebuffer_spec.rb index 465703a..262e484 100644 --- a/spec/nio/bytebuffer_spec.rb +++ b/spec/nio/bytebuffer_spec.rb @@ -180,6 +180,11 @@ expect(bytebuffer.limit).to eq capacity end + it "raises TypeError if given a non-String type" do + expect { bytebuffer << 42 }.to raise_error(TypeError) + expect { bytebuffer << nil }.to raise_error(TypeError) + end + it "raises NIO::ByteBuffer::OverflowError if the buffer is full" do bytebuffer << "X" * (capacity - 1) expect { bytebuffer << "X" }.not_to raise_error