---
title: "ridge_admm"
author: "Matthew Stephens"
date: "2020-06-29"
output: workflowr::wflow_html
editor_options:
  chunk_output_type: console
---
## Introduction
Here I want to use ADMM (see [this example](admm_01.html) for my previous code on this)
to fit ridge regression with heterogeneous prior variances,
which can be written as:
$$\min f(x) + g(x)$$
where $f(x) = (1/2\sigma^2) ||y-Ax||_2^2$ and $g(x) = \sum_j x_j^2/(2s_j^2)$.
Here $s_j^2$ is the prior variance for "coefficient" $x_j$. To be consistent
with my previous example I write $\lambda_j = 1/(2s_j^2)$.
```{r}
# note that lambda is a vector here
obj_l2 = function(x,y,A,lambda, residual_variance=1){
  (1/(2*residual_variance)) * sum((y- A %*% x)^2) + sum(x^2*lambda)
}
```
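Because the objective is quadratic, it also has a closed-form minimizer, which is a useful reference point for any iterative scheme. The block below is a minimal sketch rather than part of the original code: `direct_solve` is a hypothetical helper, and the simulated `A`, `y` and `lambda` are made up for illustration. Setting the gradient to zero gives $(A^TA/\sigma^2 + 2\,\text{diag}(\lambda))x = A^Ty/\sigma^2$.

```r
# Objective from the text, with lambda a vector (lambda_j = 1/(2 s_j^2))
obj_l2 = function(x, y, A, lambda, residual_variance = 1){
  (1/(2*residual_variance)) * sum((y - A %*% x)^2) + sum(x^2 * lambda)
}

# Hypothetical helper: solve (A'A / sigma^2 + 2 diag(lambda)) x = A'y / sigma^2
direct_solve = function(y, A, lambda, residual_variance = 1){
  lhs = crossprod(A) / residual_variance + diag(2 * lambda, nrow = length(lambda))
  drop(solve(lhs, crossprod(A, y) / residual_variance))
}

# Made-up data, for illustration only
set.seed(1)
A = matrix(rnorm(30 * 4), 30, 4)
y = rnorm(30)
lambda = c(0.1, 1, 5, 0.5)
xhat = direct_solve(y, A, lambda)

# A small random perturbation should not decrease the objective
obj_l2(xhat, y, A, lambda) <= obj_l2(xhat + 1e-3 * rnorm(4), y, A, lambda)
```

With unit residual variance this is the same linear system that a ridge solver with prior variances $1/(2\lambda_j)$ would solve.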
My motivation is that this could be a good approach in situations where
the heterogeneous prior variances change from iteration to iteration:
one can compute a single SVD of $A$ up front and reuse it to solve the ADMM subproblem, which
has homogeneous variances.
The ADMM steps, also equivalent to Douglas--Rachford, are given
in [these lecture notes](https://www.stat.cmu.edu/~ryantibs/convexopt/lectures/admm.pdf) as:
$$x \leftarrow \text{prox}_{f,1/\rho} (z-w)$$
$$z \leftarrow \text{prox}_{g,1/\rho}(x+w)$$
$$w \leftarrow w + x - z$$
where the proximal operator is defined as
$$\text{prox}_{h,t}(x) := \arg \min_z [ (1/2t) ||x-z||_2^2 + h(z)]$$
For now I'll use the proximal operator for $f$ from my [previous example](admm_01.html),
but later I plan to switch it out for something more efficient.
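To see why a ridge solver can play the role of this proximal operator, it may help to write out the stationarity condition. This is a short derivation from the definitions above, not something taken from the lecture notes:

$$\text{prox}_{f,t}(x) = \arg \min_z \left[ \frac{1}{2t} ||x-z||_2^2 + \frac{1}{2\sigma^2} ||y-Az||_2^2 \right]$$

$$\frac{1}{t}(z-x) - \frac{1}{\sigma^2} A^T(y-Az) = 0 \quad \Longrightarrow \quad \left(A^TA + \frac{\sigma^2}{t} I\right) z = A^Ty + \frac{\sigma^2}{t} x$$

This is the linear system for ridge regression with prior mean $x$, prior variance $t$ and residual variance $\sigma^2$.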
```{r}
# Note that this allows for a non-zero prior mean -- ridge regression usually has prior mean 0
ridge = function(y,A,prior_variance,prior_mean=rep(0,ncol(A)),residual_variance=1){
  n = length(y)
  p = ncol(A)
  # upper-triangular Cholesky factor, with t(L) %*% L equal to the ridge system matrix
  L = chol(t(A) %*% A + (residual_variance/prior_variance)*diag(p))
  # solve t(L) %*% u = rhs, then L %*% b = u
  b = backsolve(L, t(A) %*% y + (residual_variance/prior_variance)*prior_mean, transpose=TRUE)
  b = backsolve(L, b)
  #b = chol2inv(L) %*% (t(A) %*% y + (residual_variance/prior_variance)*prior_mean)
  return(b)
}
prox_regression = function(x, t, y, A, residual_variance=1){
  ridge(y,A,prior_variance = t,prior_mean = x,residual_variance = residual_variance)
}
```
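The introduction mentions computing one SVD up front. A minimal sketch of how that could replace the Cholesky solve in the $x$-update follows. The names `svd_setup` and `prox_regression_svd` are hypothetical and not part of the original code, and the test data are made up. Writing $A = UDV^T$ and $c = \sigma^2/t$, the inverse of $A^TA + cI$ is $V(D^2 + cI)^{-1}V^T + (1/c)(I - VV^T)$, so each prox evaluation needs only matrix-vector products.

```r
# Hypothetical helper: precompute what the x-update needs from the SVD of A
svd_setup = function(y, A){
  s = svd(A)                                  # A = U diag(d) V'
  list(V = s$v, d2 = s$d^2, Aty = drop(crossprod(A, y)))
}

# prox of f(x) = ||y - A x||^2 / (2 * residual_variance) at the point v with step t,
# i.e. the solution of (A'A + c I) b = A'y + c v, where c = residual_variance / t
prox_regression_svd = function(v, t, setup, residual_variance = 1){
  c0 = residual_variance / t
  r = setup$Aty + c0 * v
  Vtr = drop(crossprod(setup$V, r))           # V' r
  # part of r in the row space of A, plus the part orthogonal to it
  drop(setup$V %*% (Vtr / (setup$d2 + c0))) + (r - drop(setup$V %*% Vtr)) / c0
}

# Compare with a direct solve on made-up data (t = 0.5 gives c = 2)
set.seed(1)
A = matrix(rnorm(20 * 5), 20, 5)
y = rnorm(20)
v = rnorm(5)
setup = svd_setup(y, A)
b_svd = prox_regression_svd(v, t = 0.5, setup)
b_direct = drop(solve(crossprod(A) + 2 * diag(5), crossprod(A, y) + 2 * v))
max(abs(b_svd - b_direct))                    # should be near machine precision
```

The second term handles the case where $A$ has fewer rows than columns, so that $V$ does not span the whole coefficient space.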
The proximal operator for $g$ is:
$$\text{prox}_{g,t}(x) := \arg \min_z [ (1/2t) ||x-z||_2^2 + \sum_j \lambda_j z_j^2]$$
which is the posterior mean for the normal means problem with data $x$, prior variances $1/(2\lambda_j)$ and data variance $t$:
```{r}
# lambda is a vector; the implied prior variance of coefficient j is 1/(2*lambda[j])
prox_l2_het = function(x,t,lambda){
  prior_prec = 2*lambda # prior precision
  data_prec = 1/t
  return(x * data_prec/(data_prec+prior_prec))
}
```
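The shrinkage formula in `prox_l2_het` can be checked directly, since the problem separates across coordinates. Setting the derivative with respect to $z_j$ to zero:

$$\frac{1}{t}(z_j - x_j) + 2\lambda_j z_j = 0 \quad \Longrightarrow \quad z_j = x_j \, \frac{1/t}{1/t + 2\lambda_j}$$

Here $1/t$ is the data precision and $2\lambda_j = 1/s_j^2$ is the prior precision.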
This `admm_fn` function is taken from my [previous example](admm_01.html), with defaults changed:
```{r}
admm_fn = function(y,A,rho,lambda,prox_f=prox_regression, prox_g = prox_l2_het, obj_fn = obj_l2, niter=1000, z_init=NULL){
  p = ncol(A)
  # row i+1 of x, z and w holds the iterate after i iterations
  x = matrix(0,nrow=niter+1,ncol=p)
  z = x
  w = x
  if(!is.null(z_init)){
    z[1,] = z_init
  }
  obj_x = rep(0,niter+1)
  obj_z = rep(0,niter+1)
  obj_x[1] = obj_fn(x[1,],y,A,lambda)
  obj_z[1] = obj_fn(z[1,],y,A,lambda)
  for(i in 1:niter){
    x[i+1,] = prox_f(z[i,] - w[i,],1/rho,y,A)
    z[i+1,] = prox_g(x[i+1,] + w[i,],1/rho,lambda)
    w[i+1,] = w[i,] + x[i+1,] - z[i+1,]
    obj_x[i+1] = obj_fn(x[i+1,],y,A,lambda)
    obj_z[i+1] = obj_fn(z[i+1,],y,A,lambda)
  }
  return(list(x=x,z=z,w=w,obj_x=obj_x, obj_z=obj_z))
}
```
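Besides the objective values, a common way to monitor ADMM is through the primal residual $||x - z||_2$ and the dual residual $\rho ||z^{k+1} - z^k||_2$. The sketch below assumes iterate matrices with the layout returned by `admm_fn` (one row per iterate, the first row being the initial value). `admm_residuals` is a hypothetical helper, not part of the original code, and the toy iterates are made up to show the shapes involved.

```r
# Hypothetical helper: residual norms per iteration from the x and z iterate matrices
admm_residuals = function(x, z, rho){
  niter = nrow(x) - 1
  primal = sqrt(rowSums((x - z)^2))[-1]        # ||x_k - z_k|| for k = 1..niter
  dz = z[-1, , drop = FALSE] - z[-(niter + 1), , drop = FALSE]
  dual = rho * sqrt(rowSums(dz^2))             # rho * ||z_k - z_{k-1}||
  list(primal = primal, dual = dual)
}

# Toy iterates: 2 iterations, 2 coordinates
x = rbind(c(0, 0), c(3, 4), c(1, 1))
z = rbind(c(0, 0), c(0, 0), c(1, 1))
r = admm_residuals(x, z, rho = 2)
r$primal
r$dual
```

With output from `admm_fn`, the call would be `admm_residuals(fit$x, fit$z, rho)`, where `fit` is the returned list. Both sequences should shrink towards zero as the iterates converge.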
## Trend filtering example
Here I simulate some data from my trend-filtering example:
```{r}
set.seed(100)
n = 100
p = n
X = matrix(0,nrow=n,ncol=n)
for(i in 1:n){
  X[i:n,i] = 1:(n-i+1)
}
btrue = rep(0,n)
btrue[40] = 8
btrue[41] = -8
Y = X %*% btrue + rnorm(n)
plot(Y)
lines(X %*% btrue)
```
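The columns of `X` are ramps, so `X %*% b` is piecewise linear with a change in slope of `b[i]` at position `i`. The sketch below checks this on a smaller `n` by taking second differences. The coefficient vector `b` is chosen only for illustration and is not the `btrue` used above.

```r
n = 6
X = matrix(0, nrow = n, ncol = n)
for(i in 1:n){
  X[i:n, i] = 1:(n - i + 1)   # column i is a ramp starting at row i
}
b = c(0, 0, 2, -2, 0, 0)      # illustrative slope changes
mu = drop(X %*% b)
# second differences of the curve (padded with two leading zeros) recover b
b_recovered = diff(c(0, 0, mu), differences = 2)
```

A sparse `b` therefore corresponds to a curve with few changes in slope, which is what makes this basis natural for trend filtering.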
Now I apply `admm_fn` and compare with a direct ridge solve that uses the heterogeneous variances:
```{r}
y = Y
A = X
niter = 100
lambda = rexp(n)
y.admm = admm_fn(y,A,rho=1,lambda=lambda,niter= niter)
plot(y.admm$obj_x[-1])
plot(y.admm$obj_z[-1])
plot(y,main="fitted values, admm (red) and direct ridge (green)")
lines(A %*% y.admm$x[niter+1,],col=2)
y.ridge = ridge(y,A,prior_variance = 0.5/lambda)
lines(A %*% y.ridge,col=3)
plot(y.ridge,y.admm$x[niter+1,])
```
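The plots above compare the two solutions visually. As a numeric sanity check, the following self-contained sketch runs the same three updates on a small, well-conditioned problem and compares against the direct solve. Everything here is made up for illustration, not taken from the analysis above: the data, the fixed `lambda`, and the choices `rho = 10` and 300 iterations. Residual variance is taken to be 1.

```r
set.seed(1)
n = 50
p = 5
A = matrix(rnorm(n * p), n, p)
y = rnorm(n)
lambda = c(0.1, 0.5, 1, 2, 5)
rho = 10
AtA = crossprod(A)
Aty = drop(crossprod(A, y))

z = rep(0, p)
w = rep(0, p)
for(i in 1:300){
  x = drop(solve(AtA + rho * diag(p), Aty + rho * (z - w)))  # prox of f at z - w
  z = (x + w) * rho / (rho + 2 * lambda)                     # prox of g at x + w
  w = w + x - z
}

# direct solution of (A'A + 2 diag(lambda)) x = A'y
x_direct = drop(solve(AtA + diag(2 * lambda), Aty))
max(abs(z - x_direct))
```

On the trend filtering design the agreement after 100 iterations may be looser, since that `X` is much less well conditioned than a random Gaussian matrix.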