Permalink
Browse files

Add ReadValue and fix framed io

  • Loading branch information...
1 parent c763dcd commit 3b2542a7a22dcad9c40e5c774d2b85b38b445288 @samuel committed Aug 14, 2012
Showing with 130 additions and 8 deletions.
  1. +7 −3 framed.go
  2. +16 −0 framed_test.go
  3. +107 −5 thrift.go
View
@@ -80,9 +80,13 @@ func (f *FramedReadWriteCloser) Close() error {
func (f *FramedReadWriteCloser) Flush() error {
frameSize := uint32(f.wbuf.Len())
- if err := binary.Write(f.wrapped, binary.BigEndian, frameSize); err != nil {
+ if frameSize > 0 {
+ if err := binary.Write(f.wrapped, binary.BigEndian, frameSize); err != nil {
+ return err
+ }
+ _, err := io.Copy(f.wrapped, f.wbuf)
+ f.wbuf.Reset()
return err
}
- _, err := io.Copy(f.wrapped, f.wbuf)
- return err
+ return nil
}
View
@@ -9,6 +9,10 @@ type ClosingBuffer struct {
Buffer *bytes.Buffer
}
+func (c *ClosingBuffer) Len() int {
+ return c.Buffer.Len()
+}
+
func (c *ClosingBuffer) Bytes() []byte {
return c.Buffer.Bytes()
}
@@ -32,9 +36,21 @@ func TestFramed(t *testing.T) {
if _, err := framed.Write([]byte{1, 2, 3, 4}); err != nil {
t.Fatalf("Framed: error on Write %s", err)
}
+ if buf.Len() != 0 {
+ t.Fatalf("Framed: wrote %d bytes before flush", buf.Len())
+ }
+ if err := framed.Flush(); err != nil {
+ t.Fatalf("Framed: error on Flush %s", err)
+ }
+ if buf.Len() != 8 {
+ t.Fatalf("Framed: wrote (%d) other than 8 bytes after flush", buf.Len())
+ }
if err := framed.Flush(); err != nil {
t.Fatalf("Framed: error on Flush %s", err)
}
+ if buf.Len() != 8 {
+ t.Fatalf("Framed: flush didn't clear write buffer")
+ }
out := buf.Bytes()
expected := []byte{0, 0, 0, 4, 1, 2, 3, 4}
View
112 thrift.go
@@ -1,6 +1,7 @@
package thrift
import (
+ "errors"
"fmt"
"io"
"reflect"
@@ -268,7 +269,9 @@ func SkipValue(r io.Reader, p Protocol, thriftType byte) error {
if ftype == typeStop {
break
}
- SkipValue(r, p, ftype)
+ if err = SkipValue(r, p, ftype); err != nil {
+ return err
+ }
if err = p.ReadFieldEnd(r); err != nil {
return err
}
@@ -281,8 +284,12 @@ func SkipValue(r io.Reader, p Protocol, thriftType byte) error {
}
for i := 0; i < n; i++ {
- SkipValue(r, p, keyType)
- SkipValue(r, p, valueType)
+ if err = SkipValue(r, p, keyType); err != nil {
+ return err
+ }
+ if err = SkipValue(r, p, valueType); err != nil {
+ return err
+ }
}
return p.ReadMapEnd(r)
@@ -292,7 +299,9 @@ func SkipValue(r io.Reader, p Protocol, thriftType byte) error {
return err
}
for i := 0; i < n; i++ {
- SkipValue(r, p, valueType)
+ if err = SkipValue(r, p, valueType); err != nil {
+ return err
+ }
}
return p.ReadListEnd(r)
case typeSet:
@@ -301,9 +310,102 @@ func SkipValue(r io.Reader, p Protocol, thriftType byte) error {
return err
}
for i := 0; i < n; i++ {
- SkipValue(r, p, valueType)
+ if err = SkipValue(r, p, valueType); err != nil {
+ return err
+ }
}
return p.ReadSetEnd(r)
}
return err
}
+
+func ReadValue(r io.Reader, p Protocol, thriftType byte) (interface{}, error) {
+ switch thriftType {
+ case typeBool:
+ return p.ReadBool(r)
+ case typeByte:
+ return p.ReadByte(r)
+ case typeI16:
+ return p.ReadI16(r)
+ case typeI32:
+ return p.ReadI32(r)
+ case typeI64:
+ return p.ReadI64(r)
+ case typeDouble:
+ return p.ReadDouble(r)
+ case typeString:
+ return p.ReadString(r)
+ case typeStruct:
+ if err := p.ReadStructBegin(r); err != nil {
+ return nil, err
+ }
+ st := make(map[int]interface{})
+ for {
+ ftype, id, err := p.ReadFieldBegin(r)
+ if err != nil {
+ return st, err
+ }
+ if ftype == typeStop {
+ break
+ }
+ v, err := ReadValue(r, p, ftype)
+ if err != nil {
+ return st, err
+ }
+ st[int(id)] = v
+ if err = p.ReadFieldEnd(r); err != nil {
+ return st, err
+ }
+ }
+ return st, p.ReadStructEnd(r)
+ case typeMap:
+ keyType, valueType, n, err := p.ReadMapBegin(r)
+ if err != nil {
+ return nil, err
+ }
+
+ mp := make(map[interface{}]interface{})
+ for i := 0; i < n; i++ {
+ k, err := ReadValue(r, p, keyType)
+ if err != nil {
+ return mp, err
+ }
+ v, err := ReadValue(r, p, valueType)
+ if err != nil {
+ return mp, err
+ }
+ mp[k] = v
+ }
+
+ return mp, p.ReadMapEnd(r)
+ case typeList:
+ valueType, n, err := p.ReadListBegin(r)
+ if err != nil {
+ return nil, err
+ }
+ lst := make([]interface{}, 0)
+ for i := 0; i < n; i++ {
+ v, err := ReadValue(r, p, valueType)
+ if err != nil {
+ return lst, err
+ }
+ lst = append(lst, v)
+ }
+ return lst, p.ReadListEnd(r)
+ case typeSet:
+ valueType, n, err := p.ReadSetBegin(r)
+ if err != nil {
+ return nil, err
+ }
+ set := make([]interface{}, 0)
+ for i := 0; i < n; i++ {
+ v, err := ReadValue(r, p, valueType)
+ if err != nil {
+ return set, err
+ }
+ set = append(set, v)
+ }
+ return set, p.ReadSetEnd(r)
+ }
+ return nil, errors.New("unknown type")
+}

0 comments on commit 3b2542a

Please sign in to comment.