Skip to content

Commit

Permalink
kcmeans
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaswiemann committed Nov 7, 2023
1 parent 8868f89 commit 22df4c9
Show file tree
Hide file tree
Showing 9 changed files with 309 additions and 26 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ Depends:
R (>= 3.6)
Imports:
AER,
Ckmeans.1d.dp
Ckmeans.1d.dp,
MASS,
Matrix
Suggests:
testthat (>= 3.0.0),
covr,
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Generated by roxygen2: do not edit by hand

S3method(predict,kcmeans)
export(kcmeans)
export(toy_fun)
121 changes: 116 additions & 5 deletions R/Fhat.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,119 @@
#' @examples
#' res <- toy_fun(rnorm(100))
#' res$fun
toy_fun <- function(y) {
  # Wrap the input together with a constant flag and tag the list with the
  # S3 class "toy_fun" in a single step.
  structure(list(fun = TRUE, y = y), class = "toy_fun")
}#TOY_FUN
kcmeans <- function(y, X, K) {
  # K-conditional-means fit.
  #
  # Arguments:
  #   y  Outcome vector.
  #   X  Either the categorical variable alone, or a matrix whose first column
  #      is the categorical variable and whose remaining columns are
  #      additional features. The two cases are told apart by comparing
  #      length(X) with length(y), so X is expected to be a vector or matrix.
  #   K  Number of clusters passed to Ckmeans.1d.dp.
  #
  # Returns a list of class "kcmeans" with elements
  #   cluster_map  One row per observed category with columns
  #                "x" (category), "EYx" (mean of y in the category),
  #                "Px" (sample share of the category), "gx" (cluster index),
  #                "mx" (center of the assigned cluster).
  #   mean_y       Mean of the (possibly residualized) outcome.
  #   pi           Coefficients on the additional features, or NULL.

  # Data parameters
  nobs <- length(y)

  # Check whether additional features are included, residualize accordingly
  if (length(X) > nobs) {
    Z <- X[, 1] # categorical variable
    X <- X[, -1, drop = FALSE] # additional features
    # Compute \pi and residualize y
    nX <- ncol(X)
    Z_mat <- stats::model.matrix(~ 0 + as.factor(Z))
    ols_fit <- ols(y, cbind(X, Z_mat)) # ols w/ generalized inverse
    # The features come first in the design, so their coefficients do too
    pi <- ols_fit$coef[seq_len(nX)]
    y <- y - X %*% pi
  } else {
    Z <- X # categorical variable
    pi <- NULL
  }#IFELSE

  # Prepare data and prepare the cluster map: one row per observed category
  # holding the category, its conditional mean, and its sample share
  unique_Z <- unique(Z)
  cluster_map <- do.call(rbind, lapply(unique_Z, function(z) {
    in_z <- Z == z
    c(z, mean(y[in_z]), mean(in_z))
  }))#LAPPLY

  # Estimate kmeans on the conditional means, weighted by category shares
  kmeans_fit <- Ckmeans.1d.dp::Ckmeans.1d.dp(x = cluster_map[, 2], k = K,
                                             y = cluster_map[, 3])

  # Amend the cluster map with cluster index and cluster center
  cluster_map <- cbind(cluster_map, kmeans_fit$cluster,
                       kmeans_fit$centers[kmeans_fit$cluster])
  colnames(cluster_map) <- c("x", "EYx", "Px", "gx", "mx")

  # Compute the unconditional mean
  mean_y <- mean(y)

  # Prepare and return the model fit object
  mdl_fit <- list(cluster_map = cluster_map,
                  mean_y = mean_y, pi = pi)
  class(mdl_fit) <- "kcmeans" # define S3 class
  return(mdl_fit)
}#kcmeans


#' Prediction Method for kcmeans Fits.
#'
#' @description Computes fitted values or estimated cluster assignments for
#'     new observations from the cluster map stored in a \code{kcmeans} fit.
#'
#' @param object An object of class \code{kcmeans} as fitted by [kcmeans()].
#' @param newdata Values of the categorical variable. If the model was fitted
#'     with additional features (i.e., \code{object$pi} is not \code{NULL}), a
#'     matrix whose first column is the categorical variable and whose
#'     remaining columns are the additional features.
#' @param clusters If \code{TRUE}, the estimated cluster assignment is
#'     returned instead of fitted values.
#' @param ... Currently unused.
#'
#' @return If \code{clusters = FALSE}, the fitted values in the row order of
#'     \code{newdata}; categories not contained in the cluster map are
#'     assigned the unconditional mean stored in \code{object$mean_y}. If
#'     \code{clusters = TRUE}, the cluster index of each observation, with
#'     \code{NA} for categories not contained in the cluster map.
#'
#' @export
predict.kcmeans <- function(object, newdata, clusters = FALSE, ...) {

  # Check whether additional features are included
  has_controls <- !is.null(object$pi)
  if (has_controls) {
    Z <- newdata[, 1]
    X <- newdata[, -1, drop = FALSE]
  } else {
    Z <- newdata
  }#IFELSE

  # Look up the cluster-map row of every observation. match() returns exactly
  # one index per element of Z in the order of newdata (NA for unseen
  # categories), whereas merge() would return the rows sorted by key.
  cluster_map <- object$cluster_map
  map_row <- match(Z, cluster_map[, "x"])

  # Construct predictions
  if (clusters) {
    # Return estimated cluster assignment
    return(unname(cluster_map[map_row, "gx"]))
  }#IF

  # Replace unseen categories with unconditional mean of y - X\pi
  fitted <- unname(cluster_map[map_row, "mx"])
  fitted[is.na(fitted)] <- object$mean_y

  # Add X\pi if additional features are included, return fitted values
  Xpi <- if (has_controls) X %*% object$pi else 0
  fitted <- fitted + Xpi
  return(fitted)

}#PREDICT.KCMEANS
57 changes: 57 additions & 0 deletions R/ols.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#' Ordinary least squares with generalized inverse from the ddml package
#' See ?ddml::ols
ols <- function(y, X,
                const = FALSE,
                w = NULL) {
  # Least-squares fit of y on X. A column of ones is prepended to X when
  # `const` is TRUE; weighted least squares is computed when `w` is supplied.
  # Returns a list of class "ols" holding the coefficient matrix together
  # with y, the (possibly augmented) X, `const`, and `w`.

  # Add constant (optional)
  if (const) X <- cbind(1, X)

  if (is.null(w)) {
    # Unweighted case: (X'X)^{-1} X'y
    gram_inv <- csolve(as.matrix(Matrix::crossprod(X)))
    beta <- gram_inv %*% Matrix::crossprod(X, y)
  } else {
    # Weighted case: (X'WX)^{-1} X'Wy with the rows of X scaled by w
    X_weighted <- X * w
    gram_inv <- csolve(as.matrix(Matrix::crossprod(X_weighted, X)))
    beta <- gram_inv %*% Matrix::crossprod(X_weighted, y)
  }#IFELSE

  # Label coefficients by the columns of X and assemble the fit object
  beta <- as.matrix(beta)
  try(rownames(beta) <- colnames(X)) # assign coefficient names
  structure(list(coef = beta, y = y, X = X,
                 const = const, w = w),
            class = "ols")
}#OLS

# Complementary methods ========================================================

# Constructed fitted values
predict.ols <- function(object, newdata = NULL, ...){
  # Fitted values X %*% coef. Without `newdata` the stored design matrix is
  # used as is; otherwise a column of ones is prepended to `newdata` when the
  # model was fitted with a constant.
  design <- if (is.null(newdata)) {
    object$X
  } else if (object$const) {
    cbind(1, newdata)
  } else {
    newdata
  }
  design %*% object$coef
}#PREDICT.OLS

# help function for generalized inverse ========================================

# Simple generalized inverse wrapper.
csolve <- function(X) {
  # Try a regular inverse first; a failed solve() is signalled by NA
  inverse <- tryCatch(solve(X), error = function(e) NA)
  # Regular inverse is usable: return it directly
  if (!any(is.na(inverse))) {
    return(inverse)
  }#IF
  # Otherwise fall back to the Moore-Penrose generalized inverse
  return(MASS::ginv(X))
}#CSOLVE
28 changes: 28 additions & 0 deletions man/kcmeans.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions man/ols.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions man/predict.kcmeans.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 2 additions & 17 deletions man/toy_fun.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 37 additions & 3 deletions tests/testthat/test-Fhat.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,42 @@
test_that("Fhat computes", {
# Generate data
test_that("kcmeans computes", {
  # Get data from the included SimDat data
  y <- SimDat$D
  X <- SimDat$Z

  # Compute kcmeans
  kcmeans_fit <- kcmeans(y, X, K = 3)

  # Check output with expectations: the fit is a list, so the dimensions are
  # checked on the cluster map (one row per category, five columns)
  expect_s3_class(kcmeans_fit, "kcmeans")
  expect_equal(dim(kcmeans_fit$cluster_map),
               c(length(unique(SimDat$Z)), 5))
})#TEST_THAT

test_that("kcmeans computes with additional controls", {
  # Get data from the included SimDat data
  y <- SimDat$D
  X <- cbind(SimDat$Z, SimDat$X)

  # Compute kcmeans
  kcmeans_fit <- kcmeans(y, X, K = 3)

  # Check output with expectations: the fit is a list, so the dimensions are
  # checked on the cluster map (one row per category, five columns)
  expect_s3_class(kcmeans_fit, "kcmeans")
  expect_equal(dim(kcmeans_fit$cluster_map),
               c(length(unique(SimDat$Z)), 5))
})#TEST_THAT

test_that("predict.kcmeans computes w/ unseen categories", {
  # Get data from the included SimDat data
  y <- SimDat$D
  X <- cbind(SimDat$Z, SimDat$X)

  # Compute kcmeans
  kcmeans_fit <- kcmeans(y, X, K = 3)

  # Compute predictions w/ unseen categories
  newdata <- X
  newdata[1:20, 1] <- -22

  fitted_values <- predict(kcmeans_fit, newdata)

  # Check output with expectations: one prediction per row of newdata, and
  # unseen categories are filled in rather than left missing
  expect_equal(length(fitted_values), nrow(newdata))
  expect_false(anyNA(fitted_values))
})#TEST_THAT

0 comments on commit 22df4c9

Please sign in to comment.