---
title: "ridge_conjugate_gradient"
author: "Matthew Stephens"
date: "2019-10-21"
output: workflowr::wflow_html
editor_options:
  chunk_output_type: console
---
## Introduction
The goal here is to experiment with using the conjugate gradient method to solve ridge regression,
mostly as a learning exercise for me.
## Conjugate gradient code
This code was modified from the MATLAB code given on [Wikipedia](https://en.wikipedia.org/wiki/Conjugate_gradient_method).
```{r}
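# A is an n by n symmetric positive definite matrix;
# b and x (the starting guess) are column vectors of length n
conjgrad = function(A, b, x){
  r = b - A %*% x                           # initial residual
  p = r                                     # initial search direction
  rsold = as.numeric(t(r) %*% r)            # squared residual norm
  for(i in 1:length(b)){
    Ap = A %*% p
    alpha = rsold / as.numeric(t(p) %*% Ap) # step length along p
    x = x + alpha * p
    r = r - alpha * Ap
    rsnew = as.numeric(t(r) %*% r)
    if(sqrt(rsnew) < 1e-10){                # stop once the residual norm is tiny
      break
    }
    p = r + (rsnew / rsold) * p             # next search direction
    rsold = rsnew
  }
  return(list(x = x, niter = i))
}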
```
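As a quick sanity check of the iteration, the sketch below traces two conjugate gradient steps by hand on the 2 x 2 example system from the Wikipedia article. The names `A_s`, `b_s`, `x_s`, `r_s`, `p_s` and `dirs` are illustrative only and are not used elsewhere in this analysis. In exact arithmetic the method solves an n by n system in at most n steps, and successive search directions are conjugate with respect to the matrix, so both properties can be checked against `solve()`.

```{r}
# Illustrative 2 x 2 symmetric positive definite system (names are hypothetical)
A_s = matrix(c(4, 1, 1, 3), nrow = 2)
b_s = c(1, 2)
x_s = c(2, 1)                          # starting guess

r_s = b_s - drop(A_s %*% x_s)          # initial residual
p_s = r_s                              # initial search direction
dirs = list()                          # store search directions to check conjugacy
for (k in 1:2) {                       # n = 2, so two steps suffice in exact arithmetic
  dirs[[k]] = p_s
  Ap = drop(A_s %*% p_s)
  alpha = sum(r_s^2) / sum(p_s * Ap)
  x_s = x_s + alpha * p_s
  r_new = r_s - alpha * Ap
  p_s = r_new + (sum(r_new^2) / sum(r_s^2)) * p_s
  r_s = r_new
}

x_s                                       # CG solution after two steps
solve(A_s, b_s)                           # direct solution for comparison
sum(dirs[[1]] * drop(A_s %*% dirs[[2]]))  # p_1' A p_2, zero up to rounding
```

The two printed solutions should agree to rounding error, and the last quantity should be numerically zero.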
## Apply to changepoint
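Here we simulate some data from a changepoint model. Column $i$ of $X$ is a step function that switches on at position $i$, so a coefficient vector with a single nonzero entry produces a single jump in the mean.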
```{r}
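set.seed(100)
n = 100
p = n
X = matrix(0,nrow=n,ncol=n)
for(i in 1:n){
  X[i:n,i] = 1        # column i is 0 before position i and 1 from i onwards
}
btrue = rep(0,n)
btrue[40] = 8         # single changepoint: a jump of 8 at position 40
Y = X %*% btrue + rnorm(n)
plot(Y)
lines(X %*% btrue)    # true mean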
```
Now apply the conjugate gradient method to solve the ridge regression problem,
with prior variance and residual variance both equal to 1, so $A = X'X + I$ and $b=X'Y$.
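For reference, here is a short sketch of where this linear system comes from. The symbols $\beta$ (coefficient vector), $\sigma^2$ (residual variance) and $\sigma_b^2$ (prior variance) are notation introduced for this sketch and do not appear in the code. Assuming $Y = X\beta + e$ with $e \sim N(0, \sigma^2 I)$ and prior $\beta \sim N(0, \sigma_b^2 I)$, the posterior mean minimizes a penalized least-squares criterion, and setting its gradient to zero gives the linear system:

$$
\hat\beta = \arg\min_\beta \left\{ \frac{1}{\sigma^2}\|Y - X\beta\|^2 + \frac{1}{\sigma_b^2}\|\beta\|^2 \right\}
\quad\Longrightarrow\quad
\left(X'X + \frac{\sigma^2}{\sigma_b^2} I\right)\hat\beta = X'Y.
$$

With $\sigma^2 = \sigma_b^2 = 1$ this is $A\hat\beta = b$ with $A = X'X + I$ and $b = X'Y$. The matrix $A$ is symmetric positive definite, which is what conjugate gradient requires.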
```{r}
A = t(X) %*% X + diag(n)
b = t(X) %*% Y
res = conjgrad(A, b, x = cbind(rep(0,n)))  # start from the zero vector
res$niter
plot(Y)
lines(X %*% btrue)
lines(X %*% res$x,col=2)                   # fitted values from the CG solution
```
## Preconditioning
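Now we try preconditioning by $T=(X'X)^{-1}$, and solving $TAx = Tb$.
Here $TA = I + (X'X)^{-1}$ is still symmetric positive definite, so conjugate gradient can be applied to it directly.
Obviously we would not want to compute $T$ explicitly in practice; I'm just
doing it here to see how it affects the number of iterations. It roughly halves the iterations here.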
```{r}
T = solve(t(X) %*% X)                      # note: this masks R's shorthand T for TRUE
A2 = T %*% A
b2 = T %*% b
res2 = conjgrad(A2, b2, x = cbind(rep(0,n)))
res2$niter
plot(Y)
lines(X %*% btrue)
lines(X %*% res2$x,col=2)                  # fitted values from the preconditioned solve
```
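One way to see why preconditioning could help is to compare condition numbers, since conjugate gradient typically converges faster on better-conditioned systems. The sketch below rebuilds the same design matrix so that it runs on its own. `cond2` is a hypothetical helper written for this illustration, and the names `n_c`, `X_c`, `A_c`, `T_c` and `A2_c` are chosen to avoid overwriting the objects above. It also checks numerically that $TA$ is symmetric, which conjugate gradient assumes.

```{r}
# Rebuild the changepoint design so this chunk is self-contained
n_c = 100
X_c = matrix(0, nrow = n_c, ncol = n_c)
for (i in 1:n_c) {
  X_c[i:n_c, i] = 1
}
A_c  = t(X_c) %*% X_c + diag(n_c)      # ridge system matrix
T_c  = solve(t(X_c) %*% X_c)           # explicit preconditioner (illustration only)
A2_c = T_c %*% A_c                     # preconditioned matrix, I + (X'X)^{-1}

# Hypothetical helper: 2-norm condition number of a symmetric positive definite matrix
cond2 = function(M) {
  M = (M + t(M)) / 2                   # remove rounding-level asymmetry
  ev = eigen(M, symmetric = TRUE, only.values = TRUE)$values
  max(ev) / min(ev)
}

max(abs(A2_c - t(A2_c)))               # asymmetry of TA, should be at rounding level
cond2(A_c)                             # condition number before preconditioning
cond2(A2_c)                            # condition number after preconditioning
```

A smaller condition number does not by itself determine the iteration count, which also depends on how the eigenvalues cluster and on the fixed residual tolerance used in `conjgrad`.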