/
metrics_for_slice.proto
404 lines (349 loc) · 14 KB
/
metrics_for_slice.proto
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package tensorflow_model_analysis;
import "google/protobuf/wrappers.proto";
// Sync with PerformanceStatistics because of b/110954446.
// LINT.IfChange
//
// A sub key identifies specialized sub-types of metrics and plots.
message SubKey {
  // At most one of the selectors below may be set (oneof); the selected one
  // narrows a metric or plot down to a particular class or prediction.
  oneof type {
    // Multi-class metrics: the specific class ID being reported on.
    google.protobuf.Int32Value class_id = 1;

    // Multi-class metrics: selects the kth predicted value.
    google.protobuf.Int32Value k = 2;

    // Multi-class and ranking metrics: selects the top-k predicted values.
    google.protobuf.Int32Value top_k = 3;
  }
}
// LINT.ThenChange(
// ../metrics/metric_types.py,
// ../../../../intelligence/lantern/proto/stats/performance_statistics.proto)
// Sync with PerformanceStatistics because of b/110954446.
// LINT.IfChange
//
// A metric key uniquely identifies a metric.
message MetricKey {
  // Metric name, e.g. 'auc'.
  string name = 1;

  // Model the metric belongs to; optional, set for multi-model evaluations.
  string model_name = 4;

  // Output the metric belongs to; optional, set for multi-output models.
  string output_name = 2;

  // Optional sub key further qualifying the metric.
  SubKey sub_key = 3;

  // When true, this is a diff metric derived from a comparison with the
  // baseline.
  bool is_diff = 5;
}
// LINT.ThenChange(
// ../metrics/metric_types.py,
// ../../../../intelligence/lantern/proto/stats/performance_statistics.proto)
// Sync the MetricValue types with PerformanceStatistics because of b/110954446.
// LINT.IfChange
// Holds a metric value whose type is not recognized. Such a value is converted
// into an error message rather than being interpreted.
message UnknownType {
  // Error text for the unrecognized value (presumably the converted message
  // described above — confirm with the writer).
  string error = 1;

  // Raw bytes of the unrecognized value; encoding is not specified in this
  // file — TODO confirm.
  bytes value = 2;
}
// Represents a real value which could be a pointwise estimate, optionally with
// approximate bounds of some sort. For instance, for AUC, these bounds could be
// the upper and lower Riemann sums of the integral.
message BoundedValue {
  // The lower bound of the range.
  google.protobuf.DoubleValue lower_bound = 1;

  // The upper bound of the range.
  google.protobuf.DoubleValue upper_bound = 2;

  // Represents an exact value if lower_bound and upper_bound are unset;
  // otherwise it is an approximate value that should lie within the range
  // [lower_bound, upper_bound].
  google.protobuf.DoubleValue value = 3;

  // How the bounds were calculated.
  enum Methodology {
    // Methodology not specified.
    UNKNOWN = 0;
    // Used to calculate AUC: the upper and lower Riemann sums of an integral.
    RIEMANN_SUM = 1;
    // Used to calculate confidence intervals using Poisson bootstrapping.
    // For more details, please see:
    // http://www.unofficialgoogledatascience.com/2015/08/an-introduction-to-poisson-bootstrap26.html
    POISSON_BOOTSTRAP = 2;
  }

  // Optionally describes the methodology used to calculate the bounds.
  Methodology methodology = 4;
}
// Represents a t-distribution: the sample mean, sample standard deviation and
// degrees of freedom of a set of samples. It is calculated when evaluation
// runs on multiple samples, which by default are generated by the Poisson
// bootstrapping method:
// http://www.unofficialgoogledatascience.com/2015/08/an-introduction-to-poisson-bootstrap26.html
message TDistributionValue {
  // Sample mean.
  google.protobuf.DoubleValue sample_mean = 1;

  // Sample standard deviation.
  google.protobuf.DoubleValue sample_standard_deviation = 2;

  // Number of degrees of freedom.
  google.protobuf.Int64Value sample_degrees_of_freedom = 3;

  // The value of the data calculated without bootstrapping.
  // Deprecated: going forward, TDistributionValue will be removed from the
  // oneof in MetricValue, and the unsampled value will be populated in
  // MetricValue.double_value instead.
  google.protobuf.DoubleValue unsampled_value = 4 [deprecated = true];
}
// Metric values at a series of cutoffs, e.g. precision@K and recall@K.
message ValueAtCutoffs {
  // A single (cutoff, value) data point.
  message ValueCutoffPair {
    // The cutoff (e.g. K) at which the value applies.
    int32 cutoff = 1;

    // Plain double form of the value.
    double value = 2;

    // BoundedValue form of the same data; writers should populate both this
    // and `value`. Migration details are described in the comments of
    // ConfusionMatrixAtThresholds.
    BoundedValue bounded_value = 3;

    // TDistributionValue form of the same data.
    TDistributionValue t_distribution_value = 4;
  }

  // One data point per cutoff.
  repeated ValueCutoffPair values = 1;
}
// Confusion matrices at a series of thresholds.
message ConfusionMatrixAtThresholds {
  // Confusion matrix counts and derived precision/recall at one threshold.
  message ConfusionMatrixAtThreshold {
    double threshold = 1;
    double false_negatives = 2;
    double true_negatives = 3;
    double false_positives = 4;
    double true_positives = 5;
    double precision = 6;
    double recall = 7;

    // Bounded values. These are provided if
    // 1. the values (fn, tn, ...) are calculated with sampling AND
    // 2. the confidence level is provided (currently hard coded as 0.95).
    // We plan to deprecate BoundedValue in favor of TDistributionValue.
    BoundedValue bounded_false_negatives = 8;
    BoundedValue bounded_true_negatives = 9;
    BoundedValue bounded_false_positives = 10;
    BoundedValue bounded_true_positives = 11;
    BoundedValue bounded_precision = 12;
    BoundedValue bounded_recall = 13;

    // T-distribution values. These are provided if the values (fn, tn, ...)
    // are calculated with sampling.
    TDistributionValue t_distribution_false_negatives = 14;
    TDistributionValue t_distribution_true_negatives = 15;
    TDistributionValue t_distribution_false_positives = 16;
    TDistributionValue t_distribution_true_positives = 17;
    TDistributionValue t_distribution_precision = 18;
    TDistributionValue t_distribution_recall = 19;
  }

  // Matrices have different types of value representations: bounded,
  // t-distribution and double.
  // 1. Bounded values are provided if the metrics are calculated using
  //    bootstrapping (note: the confidence level is set to 95%).
  // 2. T-distribution values are provided if the metrics are calculated using
  //    bootstrapping and the confidence level isn't set. The user then
  //    configures the confidence level through the frontend to get the final
  //    confidence intervals. Both TDistributionValue and BoundedValue are
  //    supported for now, but BoundedValue will eventually be deprecated.
  // 3. Double values are being deprecated.
  repeated ConfusionMatrixAtThreshold matrices = 1;
}
// Array-valued result, for metrics that produce an array of values.
message ArrayValue {
  // Element type of the array; determines which *_values field holds data.
  enum DataType {
    UNKNOWN = 0;
    BYTES = 1;
    INT32 = 2;
    INT64 = 3;
    FLOAT32 = 4;
    FLOAT64 = 5;
  }

  // Element type of the stored values.
  DataType data_type = 1;

  // Array dimensions (inferred from the field name — confirm element layout
  // with the writer).
  repeated int32 shape = 2;

  // Element storage: exactly one of the following fields, the one matching
  // data_type, should be set.
  repeated bytes bytes_values = 3;
  repeated int32 int32_values = 4;
  repeated int64 int64_values = 5;
  repeated float float32_values = 6;
  repeated double float64_values = 7;
}
// Stores metric values in different types, so that the frontend knows how to
// visualize the values based on their type.
message MetricValue {
  oneof type {
    // Since BoundedValue can represent a point estimate without intervals, it
    // makes sense to use it in all non-plot cases in the future.
    // The migration plan is to proceed as follows:
    // - modify callers to always populate both fields
    // - update the frontend to preferentially look at the new bounded fields
    // - stop populating the old fields
    google.protobuf.DoubleValue double_value = 1;
    BoundedValue bounded_value = 2;
    TDistributionValue t_distribution_value = 9 [deprecated = true];
    ValueAtCutoffs value_at_cutoffs = 4;
    ConfusionMatrixAtThresholds confusion_matrix_at_thresholds = 5;
    MultiClassConfusionMatrixAtThresholds
        multi_class_confusion_matrix_at_thresholds = 11;
    UnknownType unknown_type = 3;
    bytes bytes_value = 6;
    ArrayValue array_value = 7;
    // This field contains a generic message used to communicate any extra
    // information, such as when no data is aggregated for a small data slice
    // due to privacy concerns.
    string debug_message = 10;
  }

  // A confidence interval for the metric value.
  message ConfidenceInterval {
    // The lower bound of the range.
    google.protobuf.DoubleValue lower_bound = 11;
    // The upper bound of the range.
    google.protobuf.DoubleValue upper_bound = 12;
    // The t-distribution value used to compute the confidence interval.
    TDistributionValue t_distribution_value = 13;
  }

  // Going forward, TFMA will populate this when confidence intervals are
  // enabled.
  ConfidenceInterval confidence_interval = 14;

  reserved 8;

  // Next tag = 16;
  // NOTE(review): tags 12, 13 and 15 are not declared in this message (the
  // 11-13 above belong to the nested ConfidenceInterval). If 15 was retired,
  // consider adding it to `reserved` — confirm before reusing any of them.
}
// LINT.ThenChange(
// ../../../../intelligence/lantern/proto/stats/performance_statistics.proto)
// One component of a slice key: a column paired with a typed value.
message SingleSliceKey {
  // Name of the column being sliced on.
  string column = 1;

  // Value of the column for this slice; at most one form is set (oneof).
  oneof kind {
    bytes bytes_value = 2;
    float float_value = 3;
    int64 int64_value = 4;
  }
}
// A slice key, built from a list of single slice keys.
message SliceKey {
  // The components that together make up this slice key.
  repeated SingleSliceKey single_slice_keys = 1;
}
// A pair of slices that are compared with each other.
message CrossSliceKey {
  // The slice serving as the baseline of the comparison.
  SliceKey baseline_slice_key = 1;

  // The slice being compared against the baseline.
  SliceKey comparison_slice_key = 2;
}
// Metrics computed for a single slice.
message MetricsForSlice {
  // A metric key paired with its value.
  message MetricKeyAndValue {
    MetricKey key = 1;
    MetricValue value = 2;
  }

  // The slice key for the metrics.
  SliceKey slice_key = 1;

  // Metric keys and values.
  repeated MetricKeyAndValue metric_keys_and_values = 51;

  // DEPRECATED: use metric_keys_and_values instead.
  // A map to store metrics. Currently we convert the post_export_metric
  // provided by TFMA to its appropriate type for better visualization, and map
  // all other metrics to the DoubleValue type.
  map<string, MetricValue> metrics = 2 [deprecated = true];
}
// Sync the Plot types with PerformanceStatistics because of b/110954446.
// LINT.IfChange
// Histogram buckets backing calibration plots and prediction distributions.
message CalibrationHistogramBuckets {
  // One bucket spanning [lower_threshold_inclusive, upper_threshold_exclusive).
  message Bucket {
    double lower_threshold_inclusive = 1;
    double upper_threshold_exclusive = 2;
    // Weighted aggregates for the examples in this bucket (presumably;
    // confirm the weighting with the writer).
    google.protobuf.DoubleValue num_weighted_examples = 3;
    google.protobuf.DoubleValue total_weighted_label = 4;
    google.protobuf.DoubleValue total_weighted_refined_prediction = 5;
  }

  // The histogram's buckets.
  repeated Bucket buckets = 1;
}
// Multi-class confusion matrices at a series of thresholds.
message MultiClassConfusionMatrixAtThresholds {
  // One cell of a multi-class confusion matrix.
  message MultiClassConfusionMatrixEntry {
    int32 actual_class_id = 1;
    int32 predicted_class_id = 2;
    double num_weighted_examples = 3;
  }

  // A multi-class confusion matrix at a single threshold.
  message MultiClassConfusionMatrix {
    double threshold = 1;
    // Only entries with non-zero num_weighted_examples are included. If the top
    // prediction was less than the threshold, then predicted_class_id is set
    // to -1. Entries are sorted by actual_class_id, then predicted_class_id.
    repeated MultiClassConfusionMatrixEntry entries = 2;
  }

  // Matrices are sorted in order of threshold.
  repeated MultiClassConfusionMatrix matrices = 1;
}
// Multi-label confusion matrices at a series of thresholds.
message MultiLabelConfusionMatrixAtThresholds {
  // Binary confusion counts for one (actual_class_id, predicted_class_id) pair.
  message MultiLabelConfusionMatrixEntry {
    int32 actual_class_id = 1;
    int32 predicted_class_id = 2;
    double false_negatives = 3;
    double true_negatives = 4;
    double false_positives = 5;
    double true_positives = 6;
  }

  // A multi-label confusion matrix at a single threshold.
  message MultiLabelConfusionMatrix {
    double threshold = 1;
    // Only entries with non-zero values are included. Entries are sorted by
    // actual_class_id, then predicted_class_id.
    repeated MultiLabelConfusionMatrixEntry entries = 2;
  }

  // Matrices are sorted in order of threshold.
  repeated MultiLabelConfusionMatrix matrices = 1;
}
// Data for one or more plots; each field below backs a different plot type.
message PlotData {
  // Backs the calibration plot and the prediction distribution.
  CalibrationHistogramBuckets calibration_histogram_buckets = 1;

  // Backs the AUC and AUPRC curves.
  ConfusionMatrixAtThresholds confusion_matrix_at_thresholds = 2;

  // Backs the multi-class confusion matrix.
  MultiClassConfusionMatrixAtThresholds
      multi_class_confusion_matrix_at_thresholds = 4;

  // Backs the multi-label confusion matrix.
  MultiLabelConfusionMatrixAtThresholds
      multi_label_confusion_matrix_at_thresholds = 5;

  // Free-form message for any extra information, e.g. when no data was
  // aggregated for a small slice because of privacy concerns.
  string debug_message = 3;
}
// LINT.ThenChange(
// ../../../../intelligence/lantern/proto/stats/performance_statistics.proto)
// Sync with PerformanceStatistics because of b/110954446.
// LINT.IfChange
//
// A plot key uniquely identifies a set of PlotData.
message PlotKey {
  // Optional model name associated with the plot (if multi-model evaluation).
  string model_name = 4;
  // Optional output name associated with the plot (for multi-output models).
  string output_name = 2;
  // Optional sub key associated with the plot.
  SubKey sub_key = 3;
  // NOTE(review): field number 1 is not used in this message (MetricKey uses
  // it for `name`). If it was ever assigned here, add `reserved 1;` — confirm.
}
// LINT.ThenChange(
// ../metrics/metric_types.py,
// ../../../../intelligence/lantern/proto/stats/performance_statistics.proto)
// Plots computed for a single slice.
message PlotsForSlice {
  // A plot key paired with its plot data.
  message PlotKeyAndValue {
    PlotKey key = 1;
    PlotData value = 2;
  }

  // The slice key for the plots.
  SliceKey slice_key = 1;

  // Plot keys and values.
  repeated PlotKeyAndValue plot_keys_and_values = 8;

  // DEPRECATED: use plot_keys_and_values instead.
  // The plot data.
  PlotData plot_data = 2 [deprecated = true];

  // DEPRECATED: use plot_keys_and_values instead.
  // Supports multiple plot evaluations in a single evaluator run. Note that
  // each PlotData entry should contain all plots for the same grouping, e.g.
  // for the same head of a multi-head model or for the same class in the case
  // of multi-class. For example, the key can be of the form
  // 'post_export_metrics/head_name' for a multi-head model.
  map<string, PlotData> plots = 3 [deprecated = true];
}