-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
TTwitterClientFilterTest.scala
141 lines (112 loc) · 5.66 KB
/
TTwitterClientFilterTest.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package com.twitter.finagle.thrift
import com.twitter.finagle.Service
import com.twitter.finagle.tracing.ClientTracingFilter.TracingFilter
import com.twitter.finagle.tracing._
import com.twitter.io.Buf
import com.twitter.util.Future
import org.apache.thrift.protocol.TMessageType
import org.apache.thrift.protocol.TMessage
import org.apache.thrift.protocol.TBinaryProtocol
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito._
import org.scalatestplus.mockito.MockitoSugar
import scala.collection.JavaConverters._
import org.scalatest.funsuite.AnyFunSuite
class TTwitterClientFilterTest extends AnyFunSuite with MockitoSugar {
  val protocolFactory = new TBinaryProtocol.Factory()

  /**
   * Encodes a Thrift CALL message named `methodName` whose body is a freshly constructed
   * `thrift.ConnectionOptions` struct, and wraps the bytes in a `ThriftClientRequest`.
   *
   * @param methodName the Thrift method name written into the message header
   * @return the request to feed through the filter under test
   */
  private def encodeRequest(methodName: String): ThriftClientRequest = {
    val buffer = new OutputBuffer(protocolFactory)
    buffer().writeMessageBegin(new TMessage(methodName, TMessageType.CALL, 0))
    val options = new thrift.ConnectionOptions
    options.write(buffer())
    buffer().writeMessageEnd()
    new ThriftClientRequest(buffer.toArray, false)
  }

  /**
   * Runs `dispatch` against a mock downstream service, captures the request that reaches the
   * mock, and decodes the `thrift.RequestHeader` peeled off the front of its message.
   *
   * The captor is read right after `dispatch` returns, so this relies on the filter chain
   * invoking the service synchronously (the mock answers with an already-satisfied Future).
   *
   * @param dispatch sends a request through the filter(s) under test to the given service
   * @return the header the filter prepended to the outgoing message
   */
  private def capturedHeader(
    dispatch: Service[ThriftClientRequest, Array[Byte]] => Future[Array[Byte]]
  ): thrift.RequestHeader = {
    val service = mock[Service[ThriftClientRequest, Array[Byte]]]
    val captor = ArgumentCaptor.forClass(classOf[ThriftClientRequest])
    when(service(captor.capture)).thenReturn(Future.value(Array.emptyByteArray))
    dispatch(service)
    val header = new thrift.RequestHeader
    InputBuffer.peelMessage(captor.getValue.message, header, protocolFactory)
    header
  }

  /**
   * Collects the values of every context entry in `header` whose key equals the marshal id of
   * `ClientId.clientIdCtx`. Fails the test (instead of throwing an NPE) when the header carries
   * no context list at all.
   */
  private def clientIdContextValues(header: thrift.RequestHeader): Seq[Buf] = {
    assert(header.getContexts != null, "RequestHeader has no contexts")
    header.getContexts.asScala.toSeq.collect {
      case ctx if Buf.ByteArray.Owned(ctx.getKey()) == ClientId.clientIdCtx.marshalId =>
        Buf.ByteArray.Owned(ctx.getValue())
    }
  }

  test("TTwitterClientFilter should set sampled boolean correctly") {
    // Stub the tracer so every trace id is sampled; the header must then report sampled.
    val tracer = mock[Tracer]
    when(tracer.sampleTrace(any(classOf[TraceId]))).thenReturn(Some(true))

    val filter = new TTwitterClientFilter("service", true, None, protocolFactory)
    val tracing = new TraceInitializerFilter[ThriftClientRequest, Array[Byte]](tracer, true)
      .andThen(new TracingFilter[ThriftClientRequest, Array[Byte]]("TTwitterClientFilterTest"))
    val stack = tracing.andThen(filter)

    val request = encodeRequest(ThriftTracing.CanTraceMethodName)
    val header = capturedHeader(service => stack(request, service))
    assert(header.isSampled)
  }

  test("TTwitterClientFilter should create header correctly") {
    val traceId = TraceId(Some(SpanId(1L)), None, SpanId(2L), Some(true), Flags().setDebug)
    Trace.letId(traceId) {
      val filter = new TTwitterClientFilter("service", true, None, protocolFactory)
      val request = encodeRequest("method")
      val header = capturedHeader(service => filter(request, service))

      assert(header.getTrace_id == 1L)
      assert(header.getSpan_id == 2L)
      assert(!header.isSetParent_span_id)
      assert(header.isSampled)
      assert(header.isSetFlags)
      // Flags().setDebug is expected to be serialized as the value 1.
      assert(header.getFlags == 1L)
    }
  }

  test("TTwitterClientFilter should set ClientId in both header and context") {
    val clientId = ClientId("foo.bar")
    val filter = new TTwitterClientFilter("service", true, Some(clientId), protocolFactory)
    val request = encodeRequest("method")
    val header = capturedHeader(service => filter(request, service))

    val contextValues = clientIdContextValues(header)
    assert(header.getClient_id.getName == clientId.name)
    assert(contextValues.contains(Buf.Utf8(clientId.name)))
  }

  test("TTwitterClientFilter should not be overrideable with externally-set ClientIds") {
    val clientId = ClientId("foo.bar")
    val otherClientId = ClientId("other.bar")
    val filter = new TTwitterClientFilter("service", true, Some(clientId), protocolFactory)
    val request = encodeRequest("method")
    // Dispatch while a different ClientId is current; the filter's configured id must still win.
    val header = capturedHeader { service =>
      otherClientId.asCurrent {
        filter(request, service)
      }
    }

    val contextValues = clientIdContextValues(header)
    assert(header.getClient_id.getName == clientId.name)
    assert(
      contextValues.contains(Buf.Utf8(clientId.name)),
      // Report what the ClientId context actually held, not the header's client_id field.
      "expected ClientId was not set in the ClientIdContext: expected: %s, actual: %s"
        .format(clientId.name, contextValues.map { case Buf.Utf8(s) => s }.mkString(", "))
    )
  }
}