---
title: "Decreases in the Objective Function"
author: "Jason Willwerscheid"
date: "7/14/2018"
output:
  workflowr::wflow_html
---
## Introduction
Here I begin to look into why the FLASH objective function can decrease after an iteration.
## Illustration of problem
I use the "strong" tests from the GTEx dataset analyzed in the MASH paper. The first problem appears when fitting the fourth factor. Notice that in the final iteration, the objective decreases by a very small amount and a warning is displayed.
```{r example}
# devtools::install_github("stephenslab/flashr", ref="trackObj")
devtools::load_all("/Users/willwerscheid/GitHub/flashr")
# devtools::install_github("stephenslab/ebnm")
devtools::load_all("/Users/willwerscheid/GitHub/ebnm")
gtex <- readRDS(gzcon(url("https://github.com/stephenslab/gtexresults/blob/master/data/MatrixEQTLSumStats.Portable.Z.rds?raw=TRUE")))
strong <- gtex$strong.z
res <- flash_add_greedy(strong, Kmax=3, verbose=FALSE)
res <- flash_add_greedy(strong, f_init=res$f, Kmax=1, verbose=TRUE)
```
## Analysis
A more granular tracking of the objective function reveals a larger problem. Recall that there are three steps in each iteration: updating the precision matrix, updating the factors (via the prior $g_f$), and updating the loadings (via $g_l$). Plotting the objective after each step rather than each iteration reveals a sawtooth pattern. (See branch `trackObj`, file `r1_opt.R` for the code used to obtain these results.)
```{r plot}
obj_data <- as.vector(rbind(res$obj[[1]]$after_tau,
                            res$obj[[1]]$after_f,
                            res$obj[[1]]$after_l))
max_obj <- max(obj_data)
obj_data <- obj_data - max_obj
iter <- 1:length(obj_data) / 3
plt_xlab = "Iteration"
plt_ylab = "Diff. from maximum obj."
plot(iter, obj_data, type='l', xlab=plt_xlab, ylab=plt_ylab)
```
Discarding the first 8 iterations (24 objective values, three per iteration) in order to zoom in on the problem area:
```{r plot2}
obj_data <- obj_data[-(1:24)]
iter <- iter[-(1:24)]
plt_colors <- c("indianred1", "indianred3", "indianred4")
plt_pch <- c(16, 17, 15)
plot(iter, obj_data, col=plt_colors, pch=plt_pch,
     xlab=plt_xlab, ylab=plt_ylab)
legend("bottomright", c("after tau", "after f", "after l"),
       col=plt_colors, pch=plt_pch)
```
I backtrack to just before the "bad" update.
```{r slow1}
res2 <- flash_add_greedy(strong, Kmax=4, stopAtObj=-1297147.7)
flash_get_objective(strong, res2$f) - flash_get_objective(strong, res$f)
```
So at this point, the objective is indeed better than for the flash object attained above. The component parts of the objective are:
```{r slow2}
fl <- res2$f
data <- flash_set_data(strong)
k <- 4
KL_l <- fl$KL_l[[k]]
KL_f <- fl$KL_f[[k]]
loglik <- flashr:::e_loglik(data, fl)
list(KL_l = KL_l, KL_f = KL_f, loglik = loglik)
```
First I update the precision, following the code in `r1_opt.R`. Only the "loglik" component is affected by this update:
```{r slow3}
init_fl = fl
init_KL_l = KL_l
init_KL_f = KL_f
init_loglik = loglik
R2 = flashr:::flash_get_R2(data, fl)
fl$tau = flashr:::compute_precision(R2, data$missing,
                                    "by_column", data$S)
flashr:::e_loglik(data, fl) - init_loglik
```
So the overall objective indeed increases. Now I update the loadings (FLASH updates factors first, but the order of updates is not supposed to affect the monotonicity of the objective function).
```{r slow4}
s2 = 1/(fl$EF2[, k] %*% t(fl$tau))
s = sqrt(s2)
Rk = flashr:::flash_get_Rk(data, fl, k)
x = fl$EF[, k] %*% t(Rk * fl$tau) * s2
ebnm_l = flashr:::ebnm_pn(x, s, list())
KL_l = (ebnm_l$penloglik
        - flashr:::NM_posterior_e_loglik(x, s, ebnm_l$postmean,
                                         ebnm_l$postmean2))
fl$EL[, k] = ebnm_l$postmean
fl$EL2[, k] = ebnm_l$postmean2
fl$gl[[k]] = ebnm_l$fitted_g
fl$KL_l[[k]] = KL_l
flash_get_objective(data, fl) - flash_get_objective(data, init_fl)
```
So the objective has in fact gotten worse. Moreover, tightening the control parameters or changing the initialization of the `ebnm` function does not help matters. For example:
```{r slow5}
s2 = 1/(fl$EF2[, k] %*% t(fl$tau))
s = sqrt(s2)
Rk = flashr:::flash_get_Rk(data, fl, k)
x = fl$EF[, k] %*% t(Rk * fl$tau) * s2
ebnm_l = flashr:::ebnm_pn(x, s, list(startpar=c(5, 5),
                                     control=list(factr=100)))
KL_l = (ebnm_l$penloglik
        - flashr:::NM_posterior_e_loglik(x, s, ebnm_l$postmean,
                                         ebnm_l$postmean2))
fl$EL[, k] = ebnm_l$postmean
fl$EL2[, k] = ebnm_l$postmean2
fl$gl[[k]] = ebnm_l$fitted_g
fl$KL_l[[k]] = KL_l
flash_get_objective(data, fl) - flash_get_objective(data, init_fl)
```
## Perturbation analysis
It's possible that numerical error is responsible for the decrease, but it seems unlikely to me that this is the whole story.

Suppose, then, that numerical error does suffice to explain the decrease. Recall that the objective consists of a part calculated from `R2` and `tau`, a part that comes from `KL_l`, and a part that comes from `KL_f`. The first part is coded as `-0.5 * sum(log((2 * pi)/tau) + tau * R2)`, and `R2` is updated as `R2k - 2 * outer(l, f) * Rk + outer(l2, f2)` (where `Rk` denotes the residuals computed using all factors but the `k`th, and similarly for `R2k`). The *updated* parts of the objective have magnitude:
```{r obj1}
sum(fl$tau * outer(fl$EL[, k], fl$EF[, k]) * Rk)
-0.5 * sum(fl$tau * outer(fl$EL2[, k], fl$EF2[, k]))
```
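Incidentally, the `R2` update formula quoted above can be sanity-checked on toy data (hypothetical values, not flashr internals): when the second moments are replaced by point values, `l2 = l^2` and `f2 = f^2`, the update reduces exactly to the squared residuals after removing factor $k$.

```r
# Toy check of the R2 update (hypothetical data, not flashr internals).
set.seed(1)
Rk <- matrix(rnorm(12), nrow=3)  # residuals excluding factor k
l <- rnorm(3)                    # loadings for factor k
f <- rnorm(4)                    # factor values for factor k
# With point values l2 = l^2 and f2 = f^2, the update ...
R2 <- Rk^2 - 2 * outer(l, f) * Rk + outer(l^2, f^2)
# ... equals the squared residuals with factor k included:
all.equal(R2, (Rk - outer(l, f))^2)  # TRUE
```

With genuine second moments, the `outer(l2, f2)` term additionally carries the posterior variances of the loadings and factors.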
So, errors in the sixth digit of either of these components could explain the decrease in the objective function. Suppose there are errors $\epsilon_i$ in the updates to `EL2`, and consider the latter part of the objective:
$$ -\frac{1}{2} \sum_{i, j} \tau_{i, j} \left( \bar{l^2}_i + \epsilon_i \right) \bar{f^2}_j
= -\frac{1}{2} \sum_i \bar{l^2}_i \sum_j \tau_{i, j} \bar{f^2}_j -\frac{1}{2} \sum_i \epsilon_i \sum_j \tau_{i, j} \bar{f^2}_j $$
so we'd need to see errors in (roughly) the sixth digit of `EL2`. A similar calculation shows that errors in the sixth digit of `EL` could suffice to explain the decrease.
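That similar calculation can be sketched as follows: errors $\delta_i$ in the updates to `EL` enter the objective through the cross term `sum(tau * outer(EL, EF) * Rk)` whose magnitude was computed above, since
$$ \sum_{i, j} \tau_{i, j} \left( \bar{l}_i + \delta_i \right) \bar{f}_j R^k_{i, j}
= \sum_i \bar{l}_i \sum_j \tau_{i, j} \bar{f}_j R^k_{i, j} + \sum_i \delta_i \sum_j \tau_{i, j} \bar{f}_j R^k_{i, j}, $$
and the second sum reaches the scale of the observed decrease when the $\delta_i$ lie in the sixth significant digit of $\bar{l}_i$.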
To test this hypothesis, I check to see what happens if only five digits are retained when performing the above calculations.
```{r slow6}
last_obj = flash_get_objective(data, fl)
digits = 5
s2 = 1/(fl$EF2[, k] %*% t(fl$tau))
s = sqrt(s2)
Rk = flashr:::flash_get_Rk(data, fl, k)
x = fl$EF[, k] %*% t(Rk * fl$tau) * s2
ebnm_l = flashr:::ebnm_pn(x, s, list())
KL_l = (ebnm_l$penloglik
        - flashr:::NM_posterior_e_loglik(x, s, ebnm_l$postmean,
                                         ebnm_l$postmean2))
fl$EL[, k] = signif(ebnm_l$postmean, digits=digits)
fl$EL2[, k] = signif(ebnm_l$postmean2, digits=digits)
fl$gl[[k]] = ebnm_l$fitted_g
fl$KL_l[[k]] = KL_l
flash_get_objective(data, fl) - last_obj
```
This produces an overall error that is *roughly* on the scale of the decrease in the objective function.
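The scale is consistent with simple arithmetic (hypothetical magnitudes, not flashr output): a relative error of $10^{-6}$, that is, an error in the sixth significant digit, of a component of magnitude $\sim 10^6$ perturbs the objective by an amount of order one.

```r
# Rough scale check (hypothetical magnitudes, not flashr output):
# a relative error of 1e-6 in a component of size ~1e6 perturbs
# the objective by an amount of order 1.
total <- 2.5e6    # e.g., the magnitude of an updated objective component
rel_err <- 1e-6   # error in the sixth significant digit
total * rel_err   # absolute perturbation: 2.5
```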
## Conclusions and questions
Still, the error is not quite as large as the observed decrease, and it would be very surprising to me if `EL` and `EL2` could only be trusted to five digits. More seriously, the sawtooth pattern discussed above points to a more systematic feature of the optimization: all of the triangles (objectives after updating factors) appear to be biased upwards, while all of the squares (objectives after updating loadings) appear to be biased slightly downwards. Even so, this would not explain the decrease in the objective that occurs over a complete iteration.