diff --git a/zkplus/zktest/zktest.go b/zkplus/zktest/zktest.go index 7182e36..0e2e4cb 100644 --- a/zkplus/zktest/zktest.go +++ b/zkplus/zktest/zktest.go @@ -2,6 +2,7 @@ package zktest import ( "errors" + "reflect" "strings" "sync" "sync/atomic" @@ -134,15 +135,27 @@ func (z *MemoryZkServer) Conn() (ZkConnSupported, <-chan zk.Event, error) { return z.Connect() } +// based on MemoryZkServer logger type set the logger for ZkConn +// as we got to honor if the logger is set to discard +func (z *ZkConn) setZkConnLogger(logger log.Logger, nextID *int64) { + if reflect.TypeOf(logger) == reflect.TypeOf(log.Discard) { + z.Logger = log.Discard + } else { + z.Logger = log.NewContext(z.Logger).With("id", atomic.AddInt64(nextID, 1)) + } +} + // Connect to this server func (z *MemoryZkServer) Connect() (*ZkConn, <-chan zk.Event, error) { r := &ZkConn{ connectedTo: z, - Logger: log.NewContext(z.Logger).With("id", atomic.AddInt64(&z.nextID, 1)), events: make(chan zk.Event, 1000), pathWatch: make(map[string]chan zk.Event), chanTimeout: z.ChanTimeout, } + + r.setZkConnLogger(z.Logger, &z.nextID) + z.childrenConnectionsLock.Lock() defer z.childrenConnectionsLock.Unlock() z.childrenConnections[r] = struct{}{} diff --git a/zkplus/zktest/zktest_test.go b/zkplus/zktest/zktest_test.go index fbdde5c..f0cb8f6 100644 --- a/zkplus/zktest/zktest_test.go +++ b/zkplus/zktest/zktest_test.go @@ -418,3 +418,12 @@ func testChildrenWNotHere(t *testing.T, z ZkConnSupported, z2 ZkConnSupported, _ case <-time.After(time.Microsecond): } } + +func TestDiscardLoggerSetup(t *testing.T) { + zkConn := &ZkConn{} + nextID := int64(0) + zkConn.setZkConnLogger(log.Discard, &nextID) + assert.Equal(t, zkConn.Logger, log.Discard) + zkConn.setZkConnLogger(log.DefaultLogger, &nextID) + assert.NotEqual(t, zkConn.Logger, log.Discard) +}