---
title: "vamp_01"
author: "Matthew Stephens"
date: "2021-02-02"
output: workflowr::wflow_html
editor_options:
  chunk_output_type: console
---
```{r}
library(ebnm)
library(glmnet)
library(ashr)
```
## Introduction
My goal here is to implement a version of VAMP in R.
I'm using Algorithm 1 from Fletcher and Schniter (which
includes EM steps, but I am ignoring those for now).
I will try to use mostly their notation, where the model is
$$y \sim N(Ax, (1/\theta_2)I).$$
First I simulate some data under this model for testing:
```{r}
M = 100 # number of observations
N = 10  # number of variables
A = matrix(rnorm(M*N, 0, 1), nrow=M)
theta2 = 1 # residual precision
x = rnorm(N)
y = A %*% x + rnorm(M, 0, sd=sqrt(1/theta2))
```
For comparison I'm going to compute the ridge regression estimate.
For prior $x \sim N(0, s_x^2 I)$ the posterior on $x$ is
$x \sim N(\mu_1,\Sigma_1)$ where
$$\mu_1 = \theta_2 \Sigma_1 A'y$$
and
$$\Sigma_1 = (\theta_2 A'A + s_x^{-2} I)^{-1}.$$
```{r}
S = chol2inv(chol(theta2 * t(A) %*% A + diag(N))) # posterior covariance, with s_x = 1
x.rr = theta2 * S %*% t(A) %*% y                  # posterior mean (ridge estimate)
```
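As a quick sanity check (an added snippet, assuming the same `A`, `y` and `theta2` as above), the same estimate can be obtained with a direct solve:
```{r}
# the chol2inv route should agree with solving the normal equations directly
x.rr.check = solve(theta2 * t(A) %*% A + diag(N), theta2 * t(A) %*% y)
all.equal(as.vector(x.rr), as.vector(x.rr.check))
```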
Now here is my initial implementation of VAMP. Note there is no EB for now: the ebnm function has a fixed prior and just does the shrinkage.
This implementation uses the idea of performing an SVD of $A$ to improve
efficiency per iteration. The computationally intensive part without this trick is computing the inverse of $Q$ (equations 8-10 in the EM-VAMP paper).
Here I briefly outline this trick.
Assume $A$ has SVD $A=UDV'$, so $A'A = VD^2V'$. If necessary include 0 eigenvalues in $D$, so $V$ is a square matrix with $VV'=V'V=I$.
Recall that
$$Q:=\theta_2 A'A + \gamma_2 I$$
so
$$Q^{-1} = V (\theta_2 D^2 + \gamma_2 I)^{-1} V'$$
Note that if $d = \mathrm{diag}(D)$ then
$$(\theta_2 d_k^2 + \gamma_2)^{-1} = (1/\gamma_2)(1 - a_k)$$
where
$$a_k := \theta_2 d_k^2/(\theta_2 d_k^2 + \gamma_2).$$
So
$$Q^{-1} = (1/\gamma_2)(I - V \mathrm{diag}(a) V')$$
and this has diagonal elements
$$Q^{-1}_{ii} = (1/\gamma_2)\Big(1 - \sum_k V_{ik}^2 a_k\Big).$$
Note that if $d_k=0$ then $a_k=0$ so there is no need to actually compute the parts of $V$ that correspond to 0 eigenvalues.
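Before relying on this inside the algorithm, here is a quick numerical check of the identity (an added sketch, reusing `A` and `theta2` from above; `gamma2.test` is an arbitrary test value):
```{r}
# compare the SVD-based diagonal of Q^{-1} with the brute-force inverse
gamma2.test = 2
A.svd = svd(A)
a.test = theta2 * A.svd$d^2/(theta2 * A.svd$d^2 + gamma2.test)
diag_svd = (1/gamma2.test) * (1 - colSums(a.test * t(A.svd$v^2)))
Q = theta2 * t(A) %*% A + gamma2.test * diag(N)
all.equal(diag(chol2inv(chol(Q))), diag_svd)
```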
```{r}
#' @param A an M by N matrix of covariates
#' @param y an M vector of outcomes
#' @param ebnm_fn a function (eg from the ebnm package) that takes parameters x and s and returns posterior mean and sd under a normal means model (no EB for now!)
vamp = function(A, y,
                ebnm_fn = function(x, s){ebnm_normal(x = x, s = s, mode = 0, scale = 1)},
                r1.init = rnorm(ncol(A)), gamma1.init = 1, theta2 = 1, niter = 100){
  # initialize
  r1 = r1.init
  gamma1 = gamma1.init
  N = ncol(A)
  A.svd = svd(A)
  v = A.svd$v
  d = A.svd$d
  for(k in 1:niter){
    # denoising step: shrink r1 under the prior via ebnm
    fit = do.call(ebnm_fn, list(x = r1, s = sqrt(1/gamma1)))
    x1 = fit$posterior$mean
    eta1 = 1/(mean(fit$posterior$sd^2))
    gamma2 = eta1 - gamma1
    r2 = (eta1 * x1 - gamma1 * r1)/gamma2
    # linear step
    # brute force approach, superseded by the svd approach:
    # Q = theta2 * t(A) %*% A + gamma2 * diag(N)
    # Qinv = chol2inv(chol(Q))
    # diag_Qinv = diag(Qinv)
    # the following avoids computing Qinv explicitly
    a = theta2 * d^2/(theta2 * d^2 + gamma2)
    # Qinv = (1/gamma2) * (diag(N) - v %*% diag(a) %*% t(v))
    diag_Qinv = (1/gamma2) * (1 - colSums(a * t(v^2)))
    eta2 = 1/mean(diag_Qinv)
    # x2 = Qinv %*% (theta2 * t(A) %*% y + gamma2 * r2)
    temp = (theta2 * t(A) %*% y + gamma2 * r2)    # temp is a vector
    temp2 = (v %*% (diag(a) %*% (t(v) %*% temp))) # computes V diag(a) V' temp efficiently
    x2 = (1/gamma2) * (temp - temp2)
    gamma1 = eta2 - gamma2
    r1 = (eta2 * x2 - gamma2 * r2)/gamma1
  }
  return(fit = list(x1 = x1, x2 = x2, eta1 = eta1, eta2 = eta2))
}
```
Now I try this out with a normal prior, which should give the same answer as ridge regression (and it does):
```{r}
fit = vamp(A,y)
plot(fit$x1,fit$x2, main="x1 vs x2")
abline(a=0,b=1)
plot(fit$x1,x.rr, main="comparison with ridge regression")
abline(a=0,b=1)
```
Note that the $\eta$ values converge to the inverse of the mean of the diagonal of the posterior covariance matrix.
```{r}
fit$eta1 - fit$eta2
1/fit$eta1 - mean(diag(S))
```
## A harder example
Here we try vamp on a problematic case for mean field methods from [here](mr_ash_vs_lasso.html).
The prior is a 50-50 mixture of a point mass at 0 and $N(0,1)$.
I'm going to give vamp both the true prior and the true residual variance.
```{r}
my_g = normalmix(pi=c(0.5,0.5), mean=c(0,0), sd=c(0,1)) # point mass at 0 plus N(0,1)
my_ebnm_fn = function(x,s){ebnm(x, s, g_init=my_g, fix_g=TRUE)}
```
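To illustrate what this prior does (an added example, not part of the original analysis): with $s=1$, an observation near 0 is shrunk almost entirely to 0, while a large observation is shrunk far less, towards the slab posterior mean.
```{r}
# posterior means under the fixed point-normal prior
demo = my_ebnm_fn(x = c(0.1, 5), s = 1)
demo$posterior$mean
```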
```{r}
set.seed(123)
n <- 500
p <- 1000
p_causal <- 500 # number of causal variables (simulated effects N(0,1))
pve <- 0.95
nrep = 10
rmse_vamp = rep(0, nrep)
rmse_glmnet = rep(0, nrep)
for(i in 1:nrep){
  sim = list()
  sim$X = matrix(rnorm(n*p, sd=1), nrow=n)
  B <- rep(0, p)
  causal_variables <- sample(x=(1:p), size=p_causal)
  B[causal_variables] <- rnorm(n=p_causal, mean=0, sd=1)
  sim$B = B
  sim$Y = sim$X %*% sim$B
  sigma2 = ((1-pve)/(pve)) * sd(sim$Y)^2
  E = rnorm(n, sd = sqrt(sigma2))
  sim$Y = sim$Y + E
  fit_glmnet <- cv.glmnet(x=sim$X, y=sim$Y, family="gaussian", alpha=1, standardize=FALSE)
  fit_vamp <- vamp(A=sim$X, y=sim$Y, ebnm_fn = my_ebnm_fn, theta2 = 1/sigma2, niter=10) # theta2 set to the true residual precision
  rmse_glmnet[i] = sqrt(mean((sim$B - coef(fit_glmnet)[-1])^2))
  rmse_vamp[i] = sqrt(mean((sim$B - fit_vamp$x1)^2))
}
plot(rmse_vamp, rmse_glmnet, main="vamp (true prior) vs glmnet")
abline(a=0, b=1)
```