---
title: "Notes on Computing the FLASH Objective"
author: "Jason Willwerscheid"
date: "7/17/2018"
output:
workflowr::wflow_html
---
## Indirect method
Recall the FLASH model:
$$ Y = LF' + E $$
When updating loading $l_k$, we are optimizing over $g_{l_k}$ and $q_{l_k}$. $g_{l_k} \in \mathcal{G}$ is the prior on the elements of the $k$th column of the loadings matrix:
$$ l_{1k}, \ldots, l_{nk} \sim^{iid} g_{l_k} $$
$q_{l_k}$ is an arbitrary distribution which enters the problem via the variational approach. For convenience, I drop the subscripts in the following.
The part of the objective that depends on $g$ and $q$ is
$$ F(g, q) := E_q \left[ -\frac{1}{2} \sum_i (A_i l_i^2 - 2 B_i l_i) \right]
+ E_q \log \frac{g(\mathbf{l})}{q(\mathbf{l})} $$
with
$$ A_i = \sum_j \tau_{ij} Ef^2_j \text{ and }
B_i = \sum_j \tau_{ij} R_{ij} Ef_j, $$
($R$ is the matrix of residuals excluding factor $k$, and $Ef_j$ and $Ef^2_j$ are the expected values of $f_{jk}$ and $f_{jk}^2$ with respect to the distribution $q_{f_k}$ fitted during the factor update.)
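To make the definitions of $A_i$ and $B_i$ concrete, here is a small base-R sketch with simulated inputs. The names `tau`, `R`, `Ef`, and `Ef2` are hypothetical stand-ins for the precisions $\tau_{ij}$, the residuals $R_{ij}$, and the posterior moments $Ef_j$ and $Ef^2_j$; both quantities reduce to matrix-vector products.

```{r AB-sketch}
set.seed(1)
n <- 5  # number of rows (elements of the loading l_k)
p <- 4  # number of columns (elements of the factor f_k)

# Hypothetical inputs: precisions tau_ij, residuals R_ij (excluding factor k),
# and the first and second posterior moments of f_jk.
tau <- matrix(rexp(n * p), n, p)
R <- matrix(rnorm(n * p), n, p)
Ef <- rnorm(p)
Ef2 <- Ef^2 + rexp(p)  # E f^2 = (E f)^2 + Var f

# A_i = sum_j tau_ij E f_j^2  and  B_i = sum_j tau_ij R_ij E f_j
A <- drop(tau %*% Ef2)
B <- drop((tau * R) %*% Ef)
```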
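As Lemma 2 in the paper shows (see Appendix A.2), maximizing this expression is equivalent to solving an empirical Bayes normal means (EBNM) problem with $s_j^2 = 1/A_j$ and $x_j = B_j s_j^2 = B_j / A_j$ (indexing the elements of $\mathbf{l}$ by $j$ from here on), where the EBNM model is:
$$ \mathbf{x} = \boldsymbol{\theta} + \mathbf{e},\quad e_j \sim N(0, s_j^2) \text{ independently},\quad \theta_1, \ldots, \theta_n \sim^{iid} g $$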
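The equivalence comes from completing the square in each $l_j$: $-\frac{1}{2}(A_j l_j^2 - 2B_j l_j) = -\frac{(x_j - l_j)^2}{2 s_j^2} + \frac{x_j^2}{2 s_j^2}$ with $s_j^2 = 1/A_j$ and $x_j = B_j / A_j$, and the last term does not depend on $l_j$. The sketch below checks this numerically for made-up values of $A_j$, $B_j$, and $l_j$; the variable names are illustrative only.

```{r complete-square}
set.seed(2)
A <- rexp(6)    # hypothetical A_j > 0
B <- rnorm(6)   # hypothetical B_j
l <- rnorm(6)   # arbitrary values of l_j

# EBNM data implied by completing the square
s2 <- 1 / A     # s_j^2 = 1 / A_j
x <- B * s2     # x_j = B_j s_j^2 = B_j / A_j

flash_term <- -0.5 * (A * l^2 - 2 * B * l)
ebnm_term <- -0.5 * (x - l)^2 / s2   # Gaussian log-likelihood kernel in l
const <- 0.5 * x^2 / s2              # does not depend on l

all.equal(flash_term, ebnm_term + const)  # TRUE
```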
Solving the EBNM problem gives
$$\hat{g} = {\arg \max}_g\ p(x \mid g) $$
and
$$ \hat{q} = p(\theta \mid x, \hat{g}) $$
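For the point-normal prior used by `ebnm_pn` (see the next section), both steps have closed forms. The following is a minimal sketch, not the `ebnm` implementation: `pn_loglik`, `pn_posterior`, and `fit_pn` are hypothetical helpers that compute the marginal log-likelihood, the posterior $\hat{q}$, and $\hat{g}$ (by maximizing the likelihood with `optim`), assuming the parametrization $g = \pi_0 \delta_0 + (1 - \pi_0) N(0, 1/a)$.

```{r pn-ebnm-sketch}
# Hypothetical helper: marginal log-likelihood of x under the point-normal
# prior g = pi0 * delta_0 + (1 - pi0) * N(0, 1/a), with x_j ~ N(theta_j, s_j^2).
pn_loglik <- function(x, s, pi0, a) {
  sum(log(pi0 * dnorm(x, 0, s) +
          (1 - pi0) * dnorm(x, 0, sqrt(s^2 + 1 / a))))
}

# Hypothetical helper: the posterior of theta_j is again point-normal.
# wpost is the posterior weight on the normal component (1 - tilde(pi)_0).
pn_posterior <- function(x, s, pi0, a) {
  slab <- (1 - pi0) * dnorm(x, 0, sqrt(s^2 + 1 / a))
  spike <- pi0 * dnorm(x, 0, s)
  pvar <- 1 / (a + 1 / s^2)
  list(wpost = slab / (slab + spike), pmean = pvar * x / s^2, pvar = pvar)
}

# Hypothetical helper: g-hat maximizes the marginal likelihood; optimize over
# logit(pi0) and log(a) so the search is unconstrained.
fit_pn <- function(x, s) {
  obj <- function(par) -pn_loglik(x, s, plogis(par[1]), exp(par[2]))
  opt <- optim(c(0, 0), obj)
  list(pi0 = plogis(opt$par[1]), a = exp(opt$par[2]), loglik = -opt$value)
}

# Simulated data: 70% of the theta_j are exactly zero.
set.seed(3)
theta <- ifelse(runif(200) < 0.7, 0, rnorm(200, sd = 2))
s <- rep(1, 200)
x <- theta + rnorm(200, sd = s)

g_hat <- fit_pn(x, s)
q_hat <- pn_posterior(x, s, g_hat$pi0, g_hat$a)
```

The logit and log transforms keep every point `optim` visits inside the valid parameter space ($0 < \pi_0 < 1$, $a > 0$).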
Finally, to update the overall objective, we need to compute $E_q \log \frac{g(\mathbf{l})}{q(\mathbf{l})}$. FLASH uses a clever trick, noticing that
$$ E_{\hat{q}} \log \frac{\hat{g}(\mathbf{l})}{\hat{q}(\mathbf{l})}
= \log p(\mathbf{x} \mid \hat{g}) + \frac{1}{2}
E_{\hat{q}} \left[ \sum_j \log (2\pi s_j^2) + \frac{(x_j - \theta_j)^2}{s_j^2} \right], $$
where $\log p(\mathbf{x} \mid \hat{g})$ is the log-likelihood attained by the EBNM solver and $\theta_j$ plays the role of $l_j$. (See Appendix A.4.)
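To see why this identity holds, note that $\hat{q}$ is the exact posterior under $\hat{g}$, so the variational lower bound on the EBNM log-likelihood is tight:
$$ \log p(\mathbf{x} \mid \hat{g})
= E_{\hat{q}} \log p(\mathbf{x} \mid \boldsymbol{\theta})
+ E_{\hat{q}} \log \frac{\hat{g}(\boldsymbol{\theta})}{\hat{q}(\boldsymbol{\theta})},
\qquad
\log p(\mathbf{x} \mid \boldsymbol{\theta})
= -\frac{1}{2} \sum_j \left[ \log (2\pi s_j^2) + \frac{(x_j - \theta_j)^2}{s_j^2} \right]. $$
Solving the first equation for $E_{\hat{q}} \log \frac{\hat{g}}{\hat{q}}$ and substituting the second gives the expression above, with the maximized EBNM log-likelihood $\log p(\mathbf{x} \mid \hat{g})$ as its first term.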
## Direct method
When using `ebnm_pn`, however, it is possible to compute $E_q \log \frac{g(\mathbf{l})}{q(\mathbf{l})}$ directly, since both the point-normal prior and its posterior have closed forms. Because the elements $l_1, \ldots, l_n$ are i.i.d. from $g$ (by the FLASH model) and their posterior distributions are mutually independent (by the EBNM model),
$$ E_q \log \frac{g(\mathbf{l})}{q(\mathbf{l})}
= \sum_j E_{q_j} \log \frac{g(l_j)}{q(l_j)} $$
I drop the subscripts $j$. Write
$$ g \sim \pi_0 \delta_0 + (1 - \pi_0) N(0, 1/a) $$
and
$$ q \sim \tilde{\pi}_0 \delta_0 + (1 - \tilde{\pi}_0) N(\tilde{\mu}, \tilde{\sigma}^2) $$
(I parametrize the two normals differently, by the precision $a$ for the prior and by the mean and variance for the posterior, to follow the code more closely.)
Then
$$\begin{aligned}
E_q \log \frac{g(l)}{q(l)}
&= \tilde{\pi}_0 \log \frac{\pi_0}{\tilde{\pi}_0}
+ \int (1 - \tilde{\pi}_0) \text{dnorm}(x; \tilde{\mu}, \tilde{\sigma}^2)
\log \frac{(1 - \pi_0)\text{dnorm}(x; 0, 1/a)}
{(1 - \tilde{\pi}_0)\text{dnorm}(x; \tilde{\mu}, \tilde{\sigma}^2)}\ dx \\
&= \tilde{\pi}_0 \log \frac{\pi_0}{\tilde{\pi}_0}
+ (1 - \tilde{\pi}_0) \log \frac{1 - \pi_0}{1 - \tilde{\pi}_0} \\
&\quad + \int (1 - \tilde{\pi}_0) \text{dnorm}(x; \tilde{\mu}, \tilde{\sigma}^2)
\log \left( \sqrt{a \tilde{\sigma}^2}
\exp \left( -\frac{ax^2}{2} + \frac{(x - \tilde{\mu})^2}{2 \tilde{\sigma}^2} \right) \right)
\ dx \\
&= \tilde{\pi}_0 \log \frac{\pi_0}{\tilde{\pi}_0}
+ (1 - \tilde{\pi}_0) \log \frac{1 - \pi_0}{1 - \tilde{\pi}_0}
+ \frac{1 - \tilde{\pi}_0}{2} \log (a \tilde{\sigma}^2) \\
&\quad - \frac{(1 - \tilde{\pi}_0)a}{2} E_{N(x; \tilde{\mu}, \tilde{\sigma}^2)} x^2
+ \frac{1 - \tilde{\pi}_0}{2 \tilde{\sigma}^2} E_{N(x; \tilde{\mu}, \tilde{\sigma}^2)} (x - \tilde{\mu})^2 \\
&= \tilde{\pi}_0 \log \frac{\pi_0}{\tilde{\pi}_0}
+ (1 - \tilde{\pi}_0) \log \frac{1 - \pi_0}{1 - \tilde{\pi}_0}
+ \frac{1 - \tilde{\pi}_0}{2} \log (a \tilde{\sigma}^2)
- \frac{(1 - \tilde{\pi}_0)a}{2} (\tilde{\mu}^2 + \tilde{\sigma}^2)
+ \frac{1 - \tilde{\pi}_0}{2}
\end{aligned}$$
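As a quick numerical check of the last line, the sketch below evaluates the closed form for arbitrary made-up values of $\pi_0$, $a$, $\tilde{\pi}_0$, $\tilde{\mu}$, $\tilde{\sigma}^2$ and compares it with the point-mass term plus a numerical integral of the normal parts computed with `integrate`. The log-ratio is formed from log densities so the integrand stays finite where the densities underflow. Since $E_q \log \frac{g}{q} = -\text{KL}(q \,\|\, g)$, the result should also be non-positive.

```{r check-closed-form}
# Arbitrary example values (not taken from any fit).
pi0 <- 0.6; a <- 0.5                         # prior g
pi0_t <- 0.3; mu_t <- 1.2; sig2_t <- 0.4     # posterior q

# Closed form from the last line of the derivation.
closed_form <- pi0_t * log(pi0 / pi0_t) +
  (1 - pi0_t) * log((1 - pi0) / (1 - pi0_t)) +
  (1 - pi0_t) / 2 * log(a * sig2_t) -
  (1 - pi0_t) * a / 2 * (mu_t^2 + sig2_t) +
  (1 - pi0_t) / 2

# Continuous part: q(x) * log(g(x) / q(x)) over the normal components,
# with the log-ratio computed on the log scale to avoid 0 / 0 in the tails.
integrand <- function(x) {
  log_q <- log(1 - pi0_t) + dnorm(x, mu_t, sqrt(sig2_t), log = TRUE)
  log_g <- log(1 - pi0) + dnorm(x, 0, sqrt(1 / a), log = TRUE)
  exp(log_q) * (log_g - log_q)
}

# Integrate over +/- 10 posterior standard deviations; the mass outside
# this range is negligible.
half_width <- 10 * sqrt(sig2_t)
numeric_form <- pi0_t * log(pi0 / pi0_t) +
  integrate(integrand, mu_t - half_width, mu_t + half_width,
            rel.tol = 1e-8)$value

all.equal(closed_form, numeric_form, tolerance = 1e-6)
```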
So we can calculate $\sum_j E_{q_j} \log \frac{g(l_j)}{q(l_j)}$ by summing the expression above over $j$:
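```{r KL}
# Direct computation of E_q log(g / q), summed over j, for a point-normal
# prior g = pi0 * delta_0 + (1 - pi0) * N(0, 1/a).
# Uses non-exported ebnm helpers (accessed with :::); their signatures are
# those used in these notes and may differ across ebnm versions.
calc_KL <- function(x, s, g) {
  pi0 <- g$pi0
  w <- 1 - g$pi0
  a <- g$a

  wpost <- ebnm:::wpost_normal(x, s, w, a)         # 1 - \tilde{\pi}_0
  pi0post <- 1 - wpost                             # \tilde{\pi}_0
  pmean_cond <- ebnm:::pmean_cond_normal(x, s, a)  # \tilde{\mu}
  pvar_cond <- ebnm:::pvar_cond_normal(s, a)       # \tilde{\sigma}^2

  # Use the convention 0 * log(c / 0) = 0 so that posterior weights of
  # exactly 0 or 1 do not produce NaN.
  KLa <- ifelse(pi0post > 0, pi0post * log(pi0 / pi0post), 0)
  KLb <- ifelse(wpost > 0, wpost * log(w / wpost), 0)
  KLc <- (wpost / 2) * log(a * pvar_cond)
  KLd <- -(wpost / 2) * a * (pvar_cond + pmean_cond^2)
  KLe <- wpost / 2

  sum(KLa + KLb + KLc + KLd + KLe)
}
```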
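The chunk above depends on non-exported `ebnm` functions. As a cross-check of the direct and indirect methods that needs no package, the sketch below recomputes both from scratch for a fixed point-normal prior $g = \pi_0 \delta_0 + (1 - \pi_0) N(0, 1/a)$. The helpers `calc_KL_direct` and `calc_KL_indirect` are hypothetical: the first sums the closed form derived above; the second uses $\log p(\mathbf{x} \mid g) + \frac{1}{2} \sum_j E_q \left[ \log (2\pi s_j^2) + (x_j - \theta_j)^2 / s_j^2 \right]$ with $E_q (x_j - \theta_j)^2 = x_j^2 - 2 x_j E_q \theta_j + E_q \theta_j^2$. Because $q$ is the exact posterior under $g$, the two should agree up to floating-point error.

```{r direct-vs-indirect}
# Hypothetical helper: direct method (closed form summed over j).
calc_KL_direct <- function(x, s, pi0, a) {
  slab <- (1 - pi0) * dnorm(x, 0, sqrt(s^2 + 1 / a))
  spike <- pi0 * dnorm(x, 0, s)
  wpost <- slab / (slab + spike)       # 1 - tilde(pi)_0
  pi0post <- 1 - wpost                 # tilde(pi)_0
  pvar <- 1 / (a + 1 / s^2)            # tilde(sigma)^2
  pmean <- pvar * x / s^2              # tilde(mu)
  KLa <- ifelse(pi0post > 0, pi0post * log(pi0 / pi0post), 0)
  KLb <- ifelse(wpost > 0, wpost * log((1 - pi0) / wpost), 0)
  sum(KLa + KLb + wpost / 2 * log(a * pvar) -
        wpost / 2 * a * (pvar + pmean^2) + wpost / 2)
}

# Hypothetical helper: indirect method via the EBNM log-likelihood.
calc_KL_indirect <- function(x, s, pi0, a) {
  slab <- (1 - pi0) * dnorm(x, 0, sqrt(s^2 + 1 / a))
  spike <- pi0 * dnorm(x, 0, s)
  loglik <- sum(log(slab + spike))     # log p(x | g)
  wpost <- slab / (slab + spike)
  pvar <- 1 / (a + 1 / s^2)
  pmean <- pvar * x / s^2
  E_theta <- wpost * pmean             # posterior mean of theta_j
  E_theta2 <- wpost * (pmean^2 + pvar) # posterior second moment of theta_j
  E_sq_resid <- x^2 - 2 * x * E_theta + E_theta2  # E_q (x_j - theta_j)^2
  loglik + 0.5 * sum(log(2 * pi * s^2) + E_sq_resid / s^2)
}

set.seed(4)
x <- c(rnorm(50), rnorm(50, sd = 3))
s <- runif(100, 0.5, 1.5)
calc_KL_direct(x, s, pi0 = 0.5, a = 0.2)
calc_KL_indirect(x, s, pi0 = 0.5, a = 0.2)
```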