-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8868f89
commit 22df4c9
Showing
9 changed files
with
309 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
S3method(predict,kcmeans) | ||
export(kcmeans) | ||
export(toy_fun) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#' Ordinary least squares with generalized inverse from the ddml package | ||
#' See ?ddml::ols | ||
ols <- function(y, X, | ||
const = FALSE, | ||
w = NULL) { | ||
# Add constant (optional) | ||
if (const) X <- cbind(1, X) | ||
|
||
# Data parameters | ||
calc_wls <- !is.null(w) | ||
|
||
# Compute OLS coefficient | ||
if (!calc_wls) { | ||
XX_inv <- csolve(as.matrix(Matrix::crossprod(X))) | ||
coef <- XX_inv %*% Matrix::crossprod(X, y) | ||
} else { # Or calculate WLS coefficient whenever weights are specified | ||
Xw <- X * w # weight rows | ||
XX_inv <- csolve(as.matrix(Matrix::crossprod(Xw, X))) | ||
coef <- XX_inv %*% Matrix::crossprod(Xw, y) | ||
}#IFELSE | ||
# Return estimate | ||
coef <- as.matrix(coef) | ||
try(rownames(coef) <- colnames(X)) # assign coefficient names | ||
output <- list(coef = coef, y = y, X = X, | ||
const = const, w = w) | ||
class(output) <- "ols" # define S3 class | ||
return(output) | ||
}#OLS | ||
|
||
# Complementary methods ======================================================== | ||
|
||
# Constructed fitted values | ||
predict.ols <- function(object, newdata = NULL, ...){ | ||
# Obtain datamatrix | ||
if (is.null(newdata)) { | ||
newdata <- object$X | ||
} else if (object$const) { | ||
newdata <- cbind(1, newdata) | ||
}#IFELSE | ||
# Calculate and return fitted values with the OLS coefficient | ||
fitted <- newdata%*%object$coef | ||
return(fitted) | ||
}#PREDICT.OLS | ||
|
||
# help function for generalized inverse ======================================== | ||
|
||
# Simple generalized inverse wrapper. | ||
csolve <- function(X) { | ||
# Attempt inversion | ||
X_inv <- tryCatch(solve(X), error = function(e) NA) | ||
# If inversion failed, calculate generalized inverse | ||
if (any(is.na(X_inv))) { | ||
X_inv <- MASS::ginv(X) | ||
}#IF | ||
# Return (generalized) inverse | ||
return(X_inv) | ||
}#CSOLVE |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,42 @@ | ||
test_that("Fhat computes", { | ||
# Generate data | ||
test_that("kcmenas computes", { | ||
# Get data from the included SimDat data | ||
y <- SimDat$D | ||
X <- SimDat$Z | ||
|
||
# Compute kcmeans | ||
kcmeans_fit <- kcmeans(y, X, K = 3) | ||
|
||
# Check output with expectations | ||
expect_equal(dim(kcmeans_fit), c(60, 5)) | ||
})#TEST_THAT | ||
|
||
test_that("kcmenas computes with additional controls", { | ||
# Get data from the included SimDat data | ||
y <- SimDat$D | ||
X <- cbind(SimDat$Z, SimDat$X) | ||
|
||
|
||
# Compute kcmeans | ||
kcmeans_fit <- kcmeans(y, X, K = 3) | ||
|
||
# Check output with expectations | ||
expect_equal(dim(kcmeans_fit), c(60, 5)) | ||
})#TEST_THAT | ||
|
||
test_that("predict.kcmenas computes w/ unseen categories", { | ||
# Get data from the included SimDat data | ||
y <- SimDat$D | ||
X <- cbind(SimDat$Z, SimDat$X) | ||
|
||
# Compute kcmeans | ||
kcmeans_fit <- kcmeans(y, X, K = 3) | ||
|
||
# Compute predictions w/ unseen categories | ||
newdata <- X | ||
newdata[1:20, 1] <- -22 | ||
|
||
fitted_values <- predict(kcmeans_fit, newdata) | ||
|
||
# Check output with expectations | ||
expect_equal(length(myfun), 1) | ||
expect_equal(dim(kcmeans_fit), c(60, 5)) | ||
})#TEST_THAT |