Skip to content

Commit

Permalink
add naive ordinary least squares regression
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-zobel committed Jun 2, 2024
1 parent 09428dc commit 63f09f2
Show file tree
Hide file tree
Showing 2 changed files with 321 additions and 0 deletions.
216 changes: 216 additions & 0 deletions src/main/java/math/linalg/LSSummary.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
/*
* Copyright 2023 Stefan Zobel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package math.linalg;

import java.util.List;

import math.list.DoubleArrayList;
import math.list.DoubleList;

/**
* Least squares regression summary.
*/
public class LSSummary {

// significance level used
private double alpha;

// parameters (beta)
private DoubleList coefficients;

// designMatrix X
private DMatrix designMatrix;

// regressand y
private DMatrix regressand;

// average value of y
private double yBar;

// predicted values of y
private DoubleList yHat;

// y - yHat = epsilon
private DoubleList residuals;

// coefficient of determination
private double rSquared;

// population variance estimator
private double sigmaHatSquared;

// var-cov matrix of the coefficients
private DMatrix varCovMatrix;

// standard error estimates of the coefficient estimators
private DoubleList coefficientStandardErrors;

// t values of the coefficients
private DoubleList tValues;

// p values of the coefficients
private DoubleList pValues;

// confidence intervals of the coefficients
private List<DoubleList> confidenceIntervals;

// df for t-distribution
private int degreesOfFreedom;

public LSSummary(double alpha, DMatrix designMatrix, DMatrix regressand) {
this.alpha = alpha;
this.designMatrix = designMatrix;
this.regressand = regressand;
}

public void clearTemporaries() {
designMatrix = null;
regressand = null;
yHat = null;
residuals = null;
tValues = null;
}

public DoubleList getBeta() {
return coefficients;
}

void setBeta(DMatrix beta) {
coefficients = new DoubleArrayList(beta.numRows());
for (int i = 0; i < beta.numRows(); ++i) {
coefficients.add(beta.get(i, 0));
}
}

public double getYBar() {
return yBar;
}

void setYBar(double yBar) {
this.yBar = yBar;
}

public DoubleList getYHat() {
return yHat;
}

void setYHat(DMatrix yEst) {
yHat = new DoubleArrayList(yEst.numRows());
for (int i = 0; i < yEst.numRows(); ++i) {
yHat.add(yEst.get(i, 0));
}
}

public DoubleList getResiduals() {
return residuals;
}

void setResiduals(DMatrix epsilonHat) {
residuals = new DoubleArrayList(epsilonHat.numRows());
for (int i = 0; i < epsilonHat.numRows(); ++i) {
residuals.add(epsilonHat.get(i, 0));
}
}

public double getRSquared() {
return rSquared;
}

void setRSquared(double rSquared) {
this.rSquared = rSquared;
}

public double getSigmaHatSquared() {
return sigmaHatSquared;
}

void setSigmaHatSquared(double sigmaHatSquared) {
this.sigmaHatSquared = sigmaHatSquared;
}

public DMatrix getVarianceCovarianceMatrix() {
return varCovMatrix;
}

void setVarianceCovarianceMatrix(DMatrix varianceCovarianceMatrix) {
this.varCovMatrix = varianceCovarianceMatrix;
}

public DoubleList getCoefficientStandardErrors() {
return coefficientStandardErrors;
}

void setCoefficientStandardErrors(DoubleList coefficientStandardErrors) {
this.coefficientStandardErrors = coefficientStandardErrors;
}

public DoubleList getTValues() {
return tValues;
}

void setTValues(DoubleList tValues) {
this.tValues = tValues;
}

public DoubleList getPValues() {
return pValues;
}

void setPValues(DoubleList pValues) {
this.pValues = pValues;
}

public List<DoubleList> getConfidenceIntervals() {
return confidenceIntervals;
}

void setConfidenceIntervals(List<DoubleList> confidenceIntervals) {
this.confidenceIntervals = confidenceIntervals;
}

public int getDegreesOfFreedom() {
return degreesOfFreedom;
}

void setDegreesOfFreedom(int degreesOfFreedom) {
this.degreesOfFreedom = degreesOfFreedom;
}

public double getAlpha() {
return alpha;
}

public DMatrix getXMatrix() {
return designMatrix;
}

public DMatrix getYVector() {
return regressand;
}

public int getCoefficientsCount() {
return coefficients.size();
}

@Override
public String toString() {
return "Summary [alpha=" + alpha + ", numCoefficients=" + getCoefficientsCount() + ",\n coefficients="
+ coefficients + ",\n yBar=" + yBar + ", rSquared=" + rSquared + ", sigmaHatSquared=" + sigmaHatSquared
+ ",\n varCovMatrix=" + varCovMatrix + ", coefficientStandardErrors=" + coefficientStandardErrors
+ ",\n pValues=" + pValues + ",\n confidenceIntervals=" + confidenceIntervals + ",\n degreesOfFreedom="
+ degreesOfFreedom + "]";
}
}
105 changes: 105 additions & 0 deletions src/main/java/math/linalg/OLS.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2023 Stefan Zobel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package math.linalg;

import java.util.ArrayList;

import math.distribution.StudentT;
import math.list.DoubleArrayList;
import math.list.DoubleList;

/**
* Poor man's naive ordinary least squares regression.
*/
public class OLS {

public static LSSummary estimate(double alpha, DMatrix X, DMatrix y) {
if (X.numRows() != y.numRows()) {
throw new IllegalArgumentException("X.numRows != y.numRows : " + X.numRows() + " != " + y.numRows());
}
if (X.numRows() - X.numColumns() < 1) {
throw new IllegalArgumentException("degrees of freedom < 1 : " + (X.numRows() - X.numColumns()));
}
if (alpha <= 0.0) {
throw new IllegalArgumentException("alpha <= 0 : " + alpha);
}
if (alpha >= 1.0) {
throw new IllegalArgumentException("alpha >= 1 : " + alpha);
}
LSSummary smmry = new LSSummary(alpha, X, y);
DMatrix Xtrans = X.transpose();
// Note: this may be numerically unstable!
DMatrix beta = Xtrans.mul(X).inverse().mul(Xtrans).mul(y);
smmry.setBeta(beta);
DMatrix yHat = X.mul(beta);
smmry.setYHat(yHat);
DMatrix ones = new DMatrix(1, y.numRows());
for (int i = 0; i < y.numRows(); ++i) {
ones.setUnsafe(0, i, 1.0);
}
double ybar = ones.mul(y).scaleInplace(1.0 / y.numRows()).get(0, 0);
smmry.setYBar(ybar);
ones = new DMatrix(y.numRows(), 1);
for (int i = 0; i < y.numRows(); ++i) {
ones.setUnsafe(i, 0, 1.0);
}
DMatrix yBarMat = ones.scaleInplace(ybar);
DMatrix a = yHat.minus(yBarMat);
DMatrix b = y.minus(yBarMat);
double SQE = a.transpose().mul(a).get(0, 0);
double SQT = b.transpose().mul(b).get(0, 0);
double R_squared = SQE / SQT;
smmry.setRSquared(R_squared > 1.0 ? 1.0 : R_squared);
DMatrix epsHat = y.minus(yHat);
smmry.setResiduals(epsHat);
int df = epsHat.numRows() - X.numColumns();
smmry.setDegreesOfFreedom(df);
double sigmaHatSquared = epsHat.transpose().mul(epsHat).scaleInplace(1.0 / (df)).get(0, 0);
smmry.setSigmaHatSquared(sigmaHatSquared);
DMatrix varCov = X.transpose().mul(X).inverse().scaleInplace(sigmaHatSquared);
smmry.setVarianceCovarianceMatrix(varCov);
DoubleList standardErrors = new DoubleArrayList(varCov.numRows());
for (int i = 0; i < varCov.numRows(); ++i) {
double vari = varCov.get(i, i);
if (vari < 0.0) {
vari = Double.MIN_NORMAL;
varCov.set(i, i, vari);
}
standardErrors.add(Math.sqrt(vari));
}
smmry.setCoefficientStandardErrors(standardErrors);
DoubleList tValues = new DoubleArrayList(varCov.numRows());
DoubleList pValues = new DoubleArrayList(varCov.numRows());
ArrayList<DoubleList> confIntervals = new ArrayList<>();
StudentT tDist = new StudentT(df);
double tval = tDist.inverseCdf(1.0 - (alpha / 2.0));
for (int i = 0; i < varCov.numRows(); ++i) {
double coeff = beta.get(i, 0);
double se = standardErrors.get(i);
double t = coeff / se;
double p = 2.0 * (1.0 - tDist.cdf(Math.abs(t)));
double min = coeff - tval * se;
double max = coeff + tval * se;
tValues.add(t);
pValues.add(p);
confIntervals.add(DoubleList.of(min, max));
}
smmry.setTValues(tValues);
smmry.setPValues(pValues);
smmry.setConfidenceIntervals(confIntervals);
return smmry;
}
}

0 comments on commit 63f09f2

Please sign in to comment.