-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #136 from schalkdaniel/general_updates
update some prototype stuff for splines
- Loading branch information
Showing
8 changed files
with
528 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# ============================================================================ # | ||
# # | ||
# Prototype for creating B(asic)-Spline Basis # | ||
# Code is based on the Nurbs Book # | ||
# # | ||
# ============================================================================ # | ||
|
||
# For an application and comparison with splines::splineDesign see | ||
# `compare_bspline_prototype_mboost_bbs.R` | ||
|
||
## Binary search (Site 68): | ||
|
||
# function to find first position where min {i : u <= U[i]} | ||
findSpan = function (u, U) | ||
{ | ||
m = length(U) | ||
|
||
# Special cases: | ||
if (u < U[2]) { return (1) } | ||
|
||
low = 1 | ||
high = m | ||
mid = round((low + high) / 2) | ||
|
||
while (u < U[mid] || u >= U[mid + 1]) | ||
{ | ||
if (u < U[mid]) { | ||
high = mid | ||
} else { | ||
low = mid | ||
} | ||
mid = round((low + high) / 2) | ||
} | ||
# smallest possible number in R is 1, in C++ we have to switch everything | ||
# by -1! | ||
return (mid) | ||
} | ||
|
||
# # Example: | ||
# p = 2 | ||
# U = c(0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 5) | ||
# u = 5/2 | ||
# | ||
# findSpan(p, u, U) | ||
|
||
## Create Base (Site 70): | ||
|
||
# Important: | ||
# - Check a priori if u is within the range of the originally used data | ||
# - One preprocessing step is used here (expanding the knots), see | ||
# `compare_bspline_prototype_mboost_bbs.R` | ||
|
||
basisFuns = function (i, u, p, U) | ||
{ | ||
# full base is length of knots minus the number of coefficients | ||
full.base = rep(0, length(U) - (p + 1)) | ||
if (i > length(full.base)) { i = length(full.base) } | ||
# if (i <= p) { | ||
# i = p + 1 | ||
# } | ||
|
||
# if (u <= U[1]) { | ||
# full.base[1] = 1 | ||
# return (full.base) | ||
# } | ||
# if (u >= U[length(U)]) { | ||
# full.base[length(full.base)] = 1 | ||
# return (full.base) | ||
# } | ||
|
||
# Output for basis functions: | ||
N = numeric(length = p + 1) | ||
right = left = numeric(length = p) | ||
|
||
# In C++ initialization with N[0] | ||
N[1] = 1.0 | ||
|
||
for (j in 1:p) { | ||
|
||
left[j] = u - U[i + 1 - j] | ||
right[j] = U[i + j] - u | ||
|
||
saved = 0 | ||
|
||
for (r in 0:(j - 1)) { | ||
temp = N[r + (1)] / (right[r + 1] + left[j - r]) | ||
N[r + (1)] = saved + right[r + 1] * temp | ||
saved = left[j - r] * temp | ||
} | ||
N[j + (1)] = saved | ||
} | ||
full.base[((i - p):i) ] = N | ||
|
||
return (full.base) | ||
} | ||
|
||
# - Efficient | ||
# - Guarantees no division by 0 | ||
|
||
# # Example: | ||
# i = findSpan(p, u, U) | ||
# basisFuns(i, u, p, U) | ||
# | ||
# | ||
# | ||
# p = 3 | ||
# U = seq(0, 10, length.out = 11) | ||
# u = 4.2 | ||
# | ||
# basisFuns(findSpan(p, u, U), u, p, U) | ||
# | ||
# | ||
# x = runif(100, -5, 7) | ||
# y = 2 * x + 1/5 * x^2 - 1/10 * x^3 + rnorm(100, 0, 2) | ||
# | ||
# plot(x, y) | ||
# | ||
# U = seq(min(x), max(x), length.out = 10) | ||
# p = 2 | ||
# | ||
# abline(v = U) | ||
# | ||
# u = x[10] | ||
# basisFuns(findSpan(p, u, U), u, p, U) | ||
# | ||
# spline.base = as.matrix(lapply(x, function (u) { | ||
# basisFuns(findSpan(p, u, U), u, p, U) | ||
# })) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# ============================================================================ # | ||
# # | ||
# Prototype for creating B(asic)-Spline Basis # | ||
# Functions are from `bspline_basis_prototype` # | ||
# # | ||
# ============================================================================ # | ||
|
||
|
||
x <- c(1, seq(0, 11, length.out = 5), 10.9) | ||
|
||
n = 100 | ||
x1 = x | ||
|
||
n.knots = 30 | ||
degree = 3 # mboost default | ||
|
||
spline1 = mboost::bbs(x1, knots = n.knots, df = 4, boundary.knots = range(x), | ||
degree = degree) | ||
bb.mboost = mboost::extract(spline1, "design") | ||
|
||
knots = attr(bb.mboost, "knots") | ||
|
||
# This is what mboost does to get every | ||
myknots = seq(min(x), max(x), length.out = n.knots + 2) | ||
|
||
knot.range = diff(myknots)[1] | ||
myknots = c(min(x) - 3:1 * knot.range, myknots, max(x) + 1:3 * knot.range) | ||
|
||
# Knot the same as mboost (degree has to be degree + 1 since splineDesign uses | ||
# the number of coefficent as ord): | ||
bb = splines::splineDesign(myknots, x = x1, outer.ok = TRUE, ord = degree + 1) | ||
bb.m = mboost:::bsplines(x1, knots, boundary.knots = range(x1), degree = degree) | ||
|
||
bb | ||
bb.m | ||
bb.mboost | ||
|
||
attributes(bb.m) = NULL | ||
attr(bb.m, "dim") = dim(bb) | ||
|
||
all.equal(bb, bb.m) | ||
|
||
|
||
|
||
|
||
# My stupid but easy algorithm (prototype): | ||
|
||
idx.test = 6 | ||
|
||
u = x[idx.test] | ||
|
||
idx = findSpan(u = u, U = myknots) | ||
basisFuns(idx, u = u, p = degree, U = myknots) | ||
|
||
bb[idx.test, ] | ||
|
||
|
||
mybb = matrix(0, nrow = length(x), ncol = length(myknots) - (degree + 1)) | ||
|
||
for (i in seq_len(nrow(mybb))) { | ||
idx.mybb = findSpan(u = x[i], U = myknots) | ||
mybb[i, ] = basisFuns(idx.mybb, u = x[i], p = degree, U = myknots) | ||
} | ||
|
||
all.equal(mybb, bb) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
# ============================================================================ # | ||
# # | ||
# Prototype to create Penalty Matrix # | ||
# # | ||
# ============================================================================ # | ||
|
||
cpp.fun1 = " | ||
arma::vec test1 (const arma::vec& myvec, const arma::mat& mymat) | ||
{ | ||
return arma::solve(mymat, myvec); | ||
} | ||
" | ||
|
||
cpp.fun2 = " | ||
arma::vec test2 (const arma::vec& myvec, const arma::mat& mymat) | ||
{ | ||
return arma::inv(mymat.t() * mymat) * mymat.t() * myvec; | ||
} | ||
" | ||
|
||
Rcpp::cppFunction(cpp.fun1, depends = "RcppArmadillo", plugins = "cpp11") | ||
Rcpp::cppFunction(cpp.fun2, depends = "RcppArmadillo", plugins = "cpp11") | ||
|
||
test3 = function (y, X) | ||
{ | ||
return (solve(t(X) %*% X) %*% t(X) %*% y) | ||
} | ||
|
||
mydata = na.omit(hflights::hflights) | ||
|
||
y = mydata$ArrDelay | ||
|
||
X = cbind(mydata$DepDelay, mydata$AirTime, mydata$Distance, mydata$TaxiIn) | ||
|
||
|
||
|
||
test1(y, X) | ||
test2(y, X) | ||
test3(y, X) | ||
|
||
microbenchmark::microbenchmark( | ||
"c++1" = test1(y,X), | ||
"c++2" = test2(y,X), | ||
"r" = test3(y,X) | ||
) | ||
|
||
|
||
|
||
|
||
|
||
|
||
cpp.fun.sparse = " | ||
arma::sp_mat testSparse (const arma::mat& mymat, const arma::vec& myvec, | ||
const unsigned int& differences) | ||
{ | ||
arma::sp_mat X(mymat); | ||
arma::sp_mat newX = arma::join_cols(X, X); | ||
newX = X.t() * X; | ||
unsigned int d = myvec.size(); | ||
arma::sp_mat diffs(0, d); | ||
for (unsigned int i = 0; i < d-1; i++) { | ||
arma::sp_mat insert(1, d); | ||
insert[i] = -1; | ||
insert[i + 1] = 1; | ||
diffs = join_cols(diffs, insert); | ||
} | ||
for (unsigned int k = 0; k < differences - 1; k++) { | ||
arma::sp_mat sparse_temp = diffs(arma::span(1, diffs.n_rows - 1), arma::span(1, diffs.n_cols - 1)); | ||
diffs = sparse_temp * diffs; | ||
} | ||
arma::mat out = arma::spsolve(X, myvec, \"lapack\"); | ||
return diffs; | ||
} | ||
" | ||
|
||
src.get.K = " | ||
arma::sp_mat penaltyMat (const unsigned int& n, const unsigned int& differences) | ||
{ | ||
// Create frame for sparse difference matrix: | ||
arma::sp_mat diffs(0, n); | ||
for (unsigned int i = 0; i < n-1; i++) { | ||
arma::sp_mat insert(1, n); | ||
insert[i] = -1; | ||
insert[i + 1] = 1; | ||
diffs = join_cols(diffs, insert); | ||
} | ||
// Calculate the difference matrix for higher orders: | ||
if (differences > 1) { | ||
arma::sp_mat diffs_reduced = diffs; | ||
for (unsigned int k = 0; k < differences - 1; k++) { | ||
diffs_reduced = diffs_reduced(arma::span(1, diffs_reduced.n_rows - 1), arma::span(1, diffs_reduced.n_cols - 1)); | ||
diffs = diffs_reduced * diffs; | ||
} | ||
} | ||
arma::sp_mat K = diffs.t() * diffs; | ||
return K; | ||
} | ||
" | ||
|
||
getDiffK = function (n, d) | ||
{ | ||
D = diff(diag(n), differences = d) | ||
|
||
return (t(D) %*% D) | ||
} | ||
|
||
Rcpp::cppFunction(cpp.fun.sparse, depends = "RcppArmadillo", plugins = "cpp11") | ||
Rcpp::cppFunction(src.get.K, depends = "RcppArmadillo") | ||
|
||
# a = testSparse(diag(10), rnorm(10)) | ||
# t(a) %*% a | ||
# | ||
|
||
|
||
getD1 = function (n) | ||
{ | ||
D = diag(x = - 1, ncol = n, nrow = n - 1) | ||
D[, -1] = D[, -1] + diag(x = 1, ncol = n - 1, nrow = n - 1) | ||
return(D) | ||
} | ||
getD = function (n, d) | ||
{ | ||
D = getD1(n = n) | ||
for (i in seq_len(d - 1)) { | ||
D = getD1(n = n - i) %*% D | ||
} | ||
return(D) | ||
} | ||
getK = function (n, d) | ||
{ | ||
D = getD(n = n, d = d) | ||
return(t(D) %*% D) | ||
} | ||
|
||
getDiffK = function (n, d) | ||
{ | ||
D = diff(diag(n), differences = d) | ||
return (t(D) %*% D) | ||
} | ||
|
||
getK(10, 3) | ||
penaltyMat(10, 3) | ||
|
||
microbenchmark::microbenchmark( | ||
"C++" = penaltyMat(100, 4), | ||
# "R" = getK(1000, 4), | ||
"R fast" = getDiffK(100, 4), | ||
times = 10L | ||
) | ||
|
||
pryr::mem_change(penaltyMat(1000, 4)) | ||
pryr::mem_change(getK(1000, 4)) | ||
|
||
a = penaltyMat(10000, 4) | ||
b = getK(1000, 4) | ||
c = getDiffK(10000, 3) | ||
|
||
pryr::object_size(a) | ||
pryr::object_size(b) |
Oops, something went wrong.