/
AdaptTrackingDecoder.scala
128 lines (109 loc) · 4.63 KB
/
AdaptTrackingDecoder.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package com.twitter.scrooge.adapt
import com.twitter.scrooge.{ThriftStruct, ThriftStructCodec}
import java.lang.reflect.InvocationTargetException
import java.util.concurrent.atomic.AtomicInteger
import scala.util.control.NonFatal
/**
 * Naming constants used by [[AdaptTrackingDecoder]] when it generates, loads
 * and calls the pruned classes for a codec.
 */
private[adapt] object AdaptTrackingDecoder {

  /** Appended to the codec's class name to form the name of the generated Adapt class. */
  val AdaptSuffix: String = "$Adapt"

  /** Appended to the codec's class name to form the name of the generated AdaptDecoder class. */
  val AdaptDecoderSuffix: String = "$AdaptDecoder"

  /** Name of the method looked up reflectively on the generated AdaptDecoder class. */
  val DecodeMethodName: String = "decode"
}
/**
 * A thrift decoder that adapts itself based on the usage pattern of generated
 * thrift objects. The goal is to minimize costs for unused fields. This is done by
 * skipping unused fields during parsing and setting up a mechanism for them to
 * be decoded later on access. Delayed decoding is typically the regular eager
 * decoding, so it's expensive because we end up decoding twice.
 * The expectation is that fields considered unused will rarely be accessed.
 * A field is considered used when it was accessed at least
 * `settings.useThreshold` times during the learning phase.
 *
 * Lifecycle: the first `settings.trackedReads` reads are decoded with the
 * access-recording decoder. The next read builds the adapted decoder (or picks
 * `fallbackDecoder` if every field turned out to be used). Every read after the
 * build completes uses that decoder.
 *
 * @param codec Codec of the struct. Its metadata supplies the field ids and names,
 *              and its class name is the base name for the generated classes.
 * @param fallbackDecoder Sometimes it may not be worth doing adaptive decoding;
 *                        this decoder is used in those cases.
 * @param accessRecordingDecoderBuilder Builder for the decoder used during the
 *                                      learning phase. Allows injecting an
 *                                      AccessRecorder to learn how fields are accessed.
 * @param settings Settings that govern how adaptation is done.
 * @param classLoader ClassLoader used to load the adapted classes generated
 *                    at runtime.
 */
private[adapt] class AdaptTrackingDecoder[T <: ThriftStruct](
  codec: ThriftStructCodec[T],
  fallbackDecoder: Decoder[T],
  accessRecordingDecoderBuilder: AccessRecorder => Decoder[T],
  settings: AdaptSettings,
  classLoader: AdaptClassLoader
) extends AccessRecorder
    with Decoder[T] {
  import AdaptTrackingDecoder._

  // Number of reads seen while no adaptive decoder was installed.
  private[this] val trackedCount = new AtomicInteger()

  // Access counter per thrift field id, one entry per field in the codec metadata.
  private[this] val fieldAccessCounts: Map[Short, AtomicInteger] =
    codec.metaData.fields.map { f =>
      (f.id, new AtomicInteger(0))
    }.toMap

  /**
   * Records one access of the field with id `fieldId`. This instance is handed
   * to `accessRecordingDecoderBuilder` as its AccessRecorder.
   *
   * @throws NoSuchElementException if `fieldId` is not a field of `codec`.
   */
  def fieldAccessed(fieldId: Short): Unit =
    fieldAccessCounts(fieldId).getAndIncrement()

  // Decoder used once learning is over; null until it has been built.
  @volatile private[this] var adaptiveDecoder: Decoder[T] = _

  /** True when every field met the use threshold (also true for a struct with no fields). */
  private[this] def allFieldsUsed(useMap: Map[Short, Boolean]): Boolean =
    useMap.values.forall(identity)

  /**
   * Decides, from the recorded access counts, which fields are used. Returns
   * `fallbackDecoder` when all of them are; otherwise generates an adaptive decoder.
   */
  private[this] def buildDecoder(): Decoder[T] = {
    val useMapByField = codec.metaData.fields.map { f =>
      (f, fieldAccessCounts(f.id).get >= settings.useThreshold)
    }.toMap
    // The Adapt template is pruned by (camel-cased) field name...
    val useMapByName = useMapByField.map {
      case (f, v) =>
        val normalizedName = CaseConverter.toCamelCase(f.name)
        (normalizedName, v)
    }
    // ...while the AdaptDecoder template is pruned by field id.
    val useMapById = useMapByField.map { case (f, v) => (f.id, v) }
    if (allFieldsUsed(useMapById)) {
      fallbackDecoder
    } else {
      buildAdaptiveDecoder(useMapByName, useMapById)
    }
  }

  /**
   * Generates and loads the pruned Adapt and AdaptDecoder classes for `codec`,
   * and wraps the latter's `decode` method in a [[Decoder]].
   */
  private[this] def buildAdaptiveDecoder(
    useMapByName: Map[String, Boolean],
    useMapById: Map[Short, Boolean]
  ): Decoder[T] = {
    val codecClassName = codec.getClass.getName
    val adaptFqdn = codecClassName + AdaptSuffix
    val adaptDecoderFqdn = codecClassName + AdaptDecoderSuffix

    // Prune AdaptTemplate to create Adapt and load it
    val adaptClassBytes = AdaptAsmPruner.pruneAdapt(adaptFqdn, useMapByName)
    classLoader.defineClass(adaptFqdn, adaptClassBytes)

    // Prune AdaptDecoderTemplate to create AdaptDecoder and load it
    val adaptDecoderClassBytes =
      AdaptAsmPruner.pruneAdaptDecoder(adaptDecoderFqdn, useMapById)
    val decoderClass = classLoader.defineClass(adaptDecoderFqdn, adaptDecoderClassBytes)

    // Class.newInstance is deprecated (Java 9+) because it rethrows the constructor's
    // checked exceptions unwrapped; the nullary Constructor wraps them instead.
    val prunedDecoder = decoderClass.getDeclaredConstructor().newInstance()
    val decodeMethod = decoderClass.getMethod(DecodeMethodName, classOf[AdaptTProtocol])

    new Decoder[T] {
      def apply(prot: AdaptTProtocol): T = {
        try {
          decodeMethod.invoke(prunedDecoder, prot).asInstanceOf[T]
        } catch {
          case e: InvocationTargetException if e.getCause != null =>
            // Throw the original exception if present
            throw e.getCause
        }
      }
    }
  }

  /**
   * Runs `buildDecoder()`. If it fails, `fallbackDecoder` is installed before
   * the error is rethrown. `trackedCount` has already moved past its trigger
   * value, so no later read would retry the build. Without this, every later read
   * would stay on the access-recording path while `trackedCount` kept growing.
   */
  private[this] def buildDecoderOrDisableAdaptation(): Decoder[T] =
    try buildDecoder()
    catch {
      // LinkageError (e.g. VerifyError, ClassFormatError) is not matched by NonFatal,
      // but can plausibly come from defining/linking the generated classes.
      case e @ (NonFatal(_) | _: LinkageError) =>
        adaptiveDecoder = fallbackDecoder
        throw e
    }

  def apply(prot: AdaptTProtocol): T = {
    val installed = adaptiveDecoder
    if (installed != null) {
      installed(prot)
    } else if (trackedCount.incrementAndGet() == settings.trackedReads + 1) {
      /**
       * Exactly one read builds the decoder: the one that moves trackedCount to
       * settings.trackedReads + 1, i.e. the first read after trackedReads
       * recorded reads. Only that read blocks on the build. Reads racing with
       * it continue to use accessRecordingDecoderBuilder until adaptiveDecoder
       * is published, at which point adaptiveDecoder takes over.
       */
      val built = buildDecoderOrDisableAdaptation()
      adaptiveDecoder = built
      built(prot)
    } else {
      accessRecordingDecoderBuilder(this)(prot)
    }
  }
}