---
authors:
- admin
categories:
date: "2023-06-02"
draft: false
featured: true
image:
lastmod: ""
projects: []
subtitle: ""
summary: "An illustrative example of when and how to use MASHR (multivariate adaptive shrinkage) for improved treatment effect estimates."
tags:
title: "Improving treatment effect estimates with multivariate adaptive shrinkage (mashr)"
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(collapse = TRUE)
### load libraries
library(MASS)
library(tidyverse)
library(mashr)
```
Suppose you have run many experiments and observed their treatment effects on multiple outcome metrics. Or perhaps you want to determine whether one metric is a good "proxy" or "surrogate" for a ground truth metric that is more difficult or expensive to measure (see our work on the [experimentation with modeled variables playbook, xMVP](https://ryansritter.com/project/xmvp/)).
Multivariate adaptive shrinkage (mashr) is your friend. It provides improved posterior treatment effect estimates by leveraging patterns of similarity across experiments, while also effectively correcting for multiple comparisons.
## An illustrative example
In this post I illustrate the utility of mashr in a simple case where we have two outcomes of interest (metric `a` and metric `b`) and their corresponding observed treatment effects for many A/A experiments -- that is, experiments where there are no true treatment effects at all. This will help demonstrate the problem of **correlated sampling error** and how mashr helps address it.
First, let's simulate two standard normal metrics -- `a` and `b` -- with a unit-level correlation of 0.40:
```{r}
set.seed(1)
# set sample size
N <- 1000000
# define two standard normal variables with a correlation of r = 0.40
data_matrix <- mvrnorm(
n = N,
mu = c(0, 0),
Sigma = matrix(c(1, .4, .4, 1), nrow = 2),
empirical = TRUE
)
df <- tibble(
id = 1:N,
a = data_matrix[, 1],
b = data_matrix[, 2]
)
# confirm unit level correlation of 0.4
df %>% summarise(correlation = cor(a, b))
```
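The `empirical = TRUE` argument is what makes the confirmed correlation exactly 0.40 rather than approximately 0.40. The sketch below contrasts the two settings on a much smaller, hypothetical sample (`n = 100`); the variable names are illustrative and not part of the analysis above.

```r
library(MASS)

set.seed(1)
sigma <- matrix(c(1, .4, .4, 1), nrow = 2)

# empirical = TRUE: Sigma is treated as the sample covariance,
# so the sample correlation matches 0.40 up to floating point error
exact_draws <- mvrnorm(n = 100, mu = c(0, 0), Sigma = sigma, empirical = TRUE)

# empirical = FALSE (the default): Sigma is the population covariance,
# so the sample correlation is only near 0.40
noisy_draws <- mvrnorm(n = 100, mu = c(0, 0), Sigma = sigma, empirical = FALSE)

exact_cor <- cor(exact_draws[, 1], exact_draws[, 2])
noisy_cor <- cor(noisy_draws[, 1], noisy_draws[, 2])
```

Fixing the sample correlation this way removes one source of noise from the simulation, so any deviation we see later comes from random assignment rather than from the data-generating step.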
Next, let's simulate 1,000 A/A tests by randomly assigning people to treatment or control groups and splitting them into 1,000 experiments of 1,000 people each:
```{r}
df <- df %>%
mutate(
# randomly assign participants to a treatment or control condition (i.e., a/a test)
condition = if_else(
rbinom(N, 1, .5) == 0,
"control",
"treatment"
),
# split into 1000 groups
experiment = id %% 1000
)
```
Finally, we can calculate the 1,000 observed treatment effects for metric `a` and metric `b`, along with their standard errors and 95% confidence intervals:
```{r message=FALSE, warning=FALSE}
# set critical value
alpha_level <- 0.05
confidence_level <- 1 - alpha_level
critical_value <- qnorm(1 - alpha_level / 2)
# calculate observed treatment effects
treatment_effects <- df %>%
pivot_longer(
cols = c(a, b),
names_to = "measure",
values_to = "value"
) %>%
group_by(experiment, condition, measure) %>%
summarise(
mean = mean(value),
se = sd(value) / sqrt(n())
) %>%
ungroup() %>%
pivot_wider(
names_from = c(condition, measure),
values_from = c(mean, se)
) %>%
mutate(
a_delta = mean_treatment_a - mean_control_a,
a_delta_se = sqrt(se_treatment_a^2 + se_control_a^2),
a_delta_moe = a_delta_se * critical_value,
a_delta_conf.low = a_delta - a_delta_moe,
a_delta_conf.high = a_delta + a_delta_moe,
a_delta_is_sig = if_else(
a_delta_conf.low > 0 | a_delta_conf.high < 0,
1,
0
),
b_delta = mean_treatment_b - mean_control_b,
b_delta_se = sqrt(se_treatment_b^2 + se_control_b^2),
b_delta_moe = b_delta_se * critical_value,
b_delta_conf.low = b_delta - b_delta_moe,
b_delta_conf.high = b_delta + b_delta_moe,
b_delta_is_sig = if_else(
b_delta_conf.low > 0 | b_delta_conf.high < 0,
1,
0
),
)
```
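The standard error used above (the square root of the summed squared standard errors of the two group means) is the same unpooled standard error that a Welch two-sample t-test uses. As a minimal sanity check, assuming two hypothetical simulated groups rather than the data in this post:

```r
set.seed(42)

# Hypothetical treatment and control samples with no true difference
treatment <- rnorm(500)
control <- rnorm(500)

# Difference in means and its unpooled standard error, as computed above
delta <- mean(treatment) - mean(control)
delta_se <- sqrt(
  var(treatment) / length(treatment) + var(control) / length(control)
)

# Welch two-sample t-test (the default for t.test) reports the same standard error
welch <- t.test(treatment, control)
se_matches <- isTRUE(all.equal(delta_se, unname(welch$stderr)))
```

The confidence intervals in this post use a normal critical value (`qnorm`) rather than a t quantile, which makes little practical difference with roughly 500 people per arm.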
What do we observe? As expected, the false positive rate is around 5% for each metric. But more importantly, recall that the unit-level correlation between metrics `a` and `b` is r = 0.40. Now, even though there are no "real" treatment effects in any of these experiments (i.e., they are all A/A tests), **we still expect to observe a correlation among the treatment effects around this same value (i.e., r = 0.40) due to correlated sampling error.**
```{r}
# look at correlation of treatment effects and frequency of false positives
treatment_effects %>% summarise(
delta_correlation = cor(a_delta, b_delta),
a_delta_false_positive = mean(a_delta_is_sig),
b_delta_false_positive = mean(b_delta_is_sig)
)
```
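A short derivation suggests why the correlation carries over. It assumes independent units, a common unit-level correlation $\rho$ between the two metrics, standard deviations $\sigma_a$ and $\sigma_b$, and arm sizes $n_T$ and $n_C$; all of these symbols are introduced here for illustration.

$$
\begin{aligned}
\operatorname{Cov}(\hat\Delta_a, \hat\Delta_b)
  &= \operatorname{Cov}(\bar a_T - \bar a_C,\ \bar b_T - \bar b_C)
   = \rho\,\sigma_a\sigma_b\left(\frac{1}{n_T} + \frac{1}{n_C}\right), \\
\operatorname{Var}(\hat\Delta_a)
  &= \sigma_a^2\left(\frac{1}{n_T} + \frac{1}{n_C}\right),
  \qquad
  \operatorname{Var}(\hat\Delta_b)
   = \sigma_b^2\left(\frac{1}{n_T} + \frac{1}{n_C}\right), \\
\operatorname{Corr}(\hat\Delta_a, \hat\Delta_b)
  &= \frac{\operatorname{Cov}(\hat\Delta_a, \hat\Delta_b)}
          {\sqrt{\operatorname{Var}(\hat\Delta_a)\,\operatorname{Var}(\hat\Delta_b)}}
   = \rho.
\end{aligned}
$$

The sample-size terms cancel, so under these assumptions the correlation between the two observed treatment effects equals the unit-level correlation regardless of experiment size.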
To help drive this point home, a scatter plot of these treatment effects (with a regression line) looks like there could be a relationship, but the pattern is entirely a result of sampling error:
```{r echo=FALSE, message=FALSE, warning=FALSE}
treatment_effects %>%
ggplot(aes(x = a_delta, y = b_delta, color = a_delta_is_sig | b_delta_is_sig)) +
# geom_linerange(aes(ymin = b_delta_conf.low, ymax = b_delta_conf.high), alpha = .05) +
# geom_linerange(aes(xmin = a_delta_conf.low, xmax = a_delta_conf.high), alpha = .05) +
geom_point(alpha = .7) +
geom_smooth(method = "lm", se = FALSE, color = "black") +
scale_color_manual(values = c("#a6cee3","#1f78b4")) +
labs(
title = "1000 simulated A/A tests (no real treatment effects)",
subtitle = "Any apparent relationship between the two observed treatment effects is due entirely to\ncorrelated sampling error",
x = "Treatment effect on metric A",
y = "Treatment effect on metric B",
caption = "Experiments in dark blue are false positives on one or both metrics\nConfidence intervals not shown for visual clarity"
) +
theme_minimal() +
theme(legend.position = "none")
```
Given that we know the "real" treatment effect for all of these experiments is actually 0, what can we do to get better estimates?
Apply multivariate adaptive shrinkage (mashr).
## Applying mashr
The [mashr](https://cran.r-project.org/web/packages/mashr/index.html) R package (see [Urbut et al., 2019](https://www.nature.com/articles/s41588-018-0268-8)) implements the multivariate adaptive shrinkage method.
At its core, all we need to provide `mashr` is a matrix of the observed treatment effect estimates (`Bhat`) and a corresponding matrix of the standard errors of those treatment effect estimates (`Shat`):
```{r}
Bhat <- treatment_effects %>%
select(
a_delta,
b_delta
) %>%
as.matrix()
Shat <- treatment_effects %>%
select(
a_delta_se,
b_delta_se
) %>%
as.matrix()
mashr_data <- mash_set_data(Bhat, Shat)
```
In settings where the outcome measurements are correlated with each other (like ours), `mashr` also provides a method to incorporate this information to help reduce false positives. Specifically, we can use an estimate of the correlation in treatment effects among null experiments:
```{r}
null_correlations <- estimate_null_correlation_simple(mashr_data)
null_correlations
```
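As I understand it, `estimate_null_correlation_simple()` keeps only the rows whose z-scores are all small in absolute value (below 2 by default) and computes the correlation of those z-scores. The base R sketch below mimics that idea on hypothetical stand-ins for `Bhat` and `Shat`; it is an approximation of the approach under that assumption, not the package's implementation.

```r
library(MASS)

set.seed(1)

# Hypothetical stand-ins: 1000 null experiments on two metrics whose
# sampling errors are correlated at 0.40, with a constant standard error
Shat_demo <- matrix(0.06, nrow = 1000, ncol = 2)
z_true <- mvrnorm(1000, mu = c(0, 0), Sigma = matrix(c(1, .4, .4, 1), nrow = 2))
Bhat_demo <- z_true * Shat_demo

# Keep rows that look null (all |z| below the threshold), then correlate them
z <- Bhat_demo / Shat_demo
looks_null <- apply(abs(z), 1, max) < 2
null_cor_demo <- cor(z[looks_null, ])
```

Restricting to rows that look null is what lets the same approach work in applied settings, where some experiments do have real effects and would otherwise contaminate the estimate.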
And then we can update our `mashr` data object with this information:
```{r}
mashr_data_w_correlations <- mash_update_data(mashr_data, V = null_correlations)
```
Finally, we create a set of covariance matrices that `mashr` will use to fit the model. The estimated mixture proportions assigned to each of these matrices will help us determine which is most consistent with the data. `mashr` provides methods for both "canonical" and "data driven" covariance matrices, and here we'll just use the canonical matrices:
* Identity matrix: Effects are independent among conditions
* Singleton matrices: Effects are specific to a given metric (`a` or `b`) only
* Equal effects matrix: Effects are equal for both metrics
* Simple het matrices: Effects are correlated to varying degrees
```{r}
# specify the canonical covariance matrices
U.c <- cov_canonical(mashr_data_w_correlations)
```
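For two metrics, the canonical matrices described above are small enough to write out by hand. The list below is a sketch of what they look like, assuming the simple het correlations of 0.25, 0.5, and 0.75 that I believe `cov_canonical()` uses by default; the list names are illustrative and will not match the names `mashr` assigns.

```r
# Hand-written versions of the canonical covariance matrices for two metrics
canonical_sketch <- list(
  identity      = diag(2),                              # independent effects
  singleton_a   = matrix(c(1, 0, 0, 0), nrow = 2),      # effect on metric a only
  singleton_b   = matrix(c(0, 0, 0, 1), nrow = 2),      # effect on metric b only
  equal_effects = matrix(1, nrow = 2, ncol = 2),        # identical effects on both
  simple_het_25 = matrix(c(1, .25, .25, 1), nrow = 2),  # weakly correlated effects
  simple_het_50 = matrix(c(1, .50, .50, 1), nrow = 2),  # moderately correlated effects
  simple_het_75 = matrix(c(1, .75, .75, 1), nrow = 2)   # strongly correlated effects
)
```

Each matrix encodes a hypothesis about how true effects co-vary across metrics, which is separate from the correlation in sampling error captured by `V`.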
Now we can fit the model!
```{r}
mashr_results <- mash(mashr_data_w_correlations, U.c)
```
`mashr` then provides a series of very useful functions to extract relevant information from the model.
For example, we can get the improved posterior treatment effect estimates (posterior means) and their corresponding posterior standard deviations (which we could choose to visualize in a plot):
```{r}
# get posterior means
head(get_pm(mashr_results) %>% as_tibble())
# get posterior SDs
head(get_psd(mashr_results) %>% as_tibble())
```
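To build intuition for what "shrinkage" means here, consider the simplest possible case: a single zero-centered normal prior with covariance `U` and a sampling covariance `V`. This is only a sketch; `mashr` fits a mixture over many such covariance matrices and scalings, and every number below (`U`, `V`, `bhat`) is hypothetical.

```r
# Hypothetical sampling covariance: standard error of 0.06 on each metric,
# with sampling errors correlated at 0.40
V <- 0.06^2 * matrix(c(1, .4, .4, 1), nrow = 2)

# Hypothetical prior covariance: true effects are believed to be tiny
U <- 0.01^2 * diag(2)

# Hypothetical observed treatment effects for one experiment
bhat <- c(0.12, 0.05)

# Standard normal-normal update: precision-weighted combination of prior and data
post_cov <- solve(solve(U) + solve(V))
post_mean <- as.vector(post_cov %*% solve(V) %*% bhat)
post_sd <- sqrt(diag(post_cov))
```

Because the prior is concentrated near zero relative to the sampling noise, the posterior means are pulled strongly toward zero, which is qualitatively what happens to the A/A estimates above.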
Or we can find experiments where the posterior treatment effect estimate is significant in at least one condition (by default, `get_significant_results()` uses a local false sign rate threshold of 0.05).
*Note: we find no "significant" conditions here because mashr has (correctly!) learned that most experiments are null and shrunk the treatment effect estimates toward 0.*
```{r}
# get effects that are significant in at least one condition
get_significant_results(mashr_results)
```
Understanding how often two metrics share treatment effects is particularly useful for evaluating a potential "proxy" or "surrogate" for another metric that is more expensive or difficult to measure. For example: if my proxy candidate moves, how often does my ground truth metric of interest also move in the same direction, and ideally by a similar magnitude?
*Note: we again find no relevant signals here because mashr has correctly identified that there are actually no "real" treatment effects in these experiments... but these functions are very useful in applied use cases.*
```{r}
# get the proportion of (significant) signals shared by some specified magnitude
get_pairwise_sharing(mashr_results, factor = 0.5)
# get the proportion of (significant) signals that at least share the same sign
get_pairwise_sharing(mashr_results, factor = 0)
```
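The `factor` argument is easiest to understand through the ratio of the two effects: as I understand it, a pair counts as "shared" when that ratio falls between `factor` and `1 / factor`. The helper below is a hypothetical base R illustration of that rule on made-up effect sizes; it ignores the significance filtering that `get_pairwise_sharing()` also applies.

```r
# Hypothetical helper: proportion of experiments where two effects
# agree in sign and are within a given factor of each other
share_within_factor <- function(effect_1, effect_2, factor = 0.5) {
  ratio <- effect_1 / effect_2
  mean(ratio > factor & ratio < 1 / factor)
}

# Made-up posterior effects for a proxy metric and a ground truth metric
proxy <- c(0.10, 0.20, -0.05, 0.30)
truth <- c(0.12, 0.05, 0.04, 0.25)

# Ratios are roughly 0.83, 4, -1.25, and 1.2
shared_magnitude <- share_within_factor(proxy, truth, factor = 0.5)  # 0.5
shared_sign <- share_within_factor(proxy, truth, factor = 0)         # 0.75
```

With `factor = 0` the upper bound becomes infinite, so the rule reduces to requiring a positive ratio, that is, effects with the same sign.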
Finally, we can look at the mixture proportions for the different types of covariance matrices used in the modeling. This is very useful for understanding which covariance matrices are most consistent with the observed data. In this case the model correctly assigns the overwhelming majority of the weight to the null covariance matrix, meaning the data are most consistent with entirely null effects:
```{r}
# get the estimated mixture proportions
get_estimated_pi(mashr_results)
```
## Takeaway
Any time we have observed treatment effects on multiple outcomes -- here we just used two outcomes but in practice we could have many, many more! -- multivariate adaptive shrinkage (mashr) is an extremely valuable tool for deriving improved (posterior) treatment effect estimates for all of them by leveraging all the shared information across experiments.
`mashr` is also particularly useful when evaluating the utility of a "proxy" or "surrogate" metric, because it will result in better estimates of how often the proxy vs. ground truth treatment effects actually correspond with each other in direction and/or magnitude.
*Note: the full R Markdown used for this post is available at https://github.com/ryansritter/mashr_intro*
### Recommendations for further reading
* [Urbut et al., 2019](https://www.nature.com/articles/s41588-018-0268-8) - Flexible statistical methods for estimating and testing effects in genomic studies with multiple conditions
* [mashr](https://cran.r-project.org/web/packages/mashr/index.html) R package and its corresponding vignettes
* [Cunningham & Kim, 2020](https://joshkim.org/files/InterpretingExperiments.pdf) - Interpreting experiments with multiple outcomes
* [Don’t be seduced by the allure: A guide for how (not) to use machine learning metrics in experiments](https://ryansritter.com/project/xmvp/)