-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add naive ordinary least squares regression
- Loading branch information
1 parent
09428dc
commit 63f09f2
Showing
2 changed files
with
321 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
/* | ||
* Copyright 2023 Stefan Zobel | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package math.linalg; | ||
|
||
import java.util.List; | ||
|
||
import math.list.DoubleArrayList; | ||
import math.list.DoubleList; | ||
|
||
/** | ||
* Least squares regression summary. | ||
*/ | ||
public class LSSummary { | ||
|
||
// significance level used | ||
private double alpha; | ||
|
||
// parameters (beta) | ||
private DoubleList coefficients; | ||
|
||
// designMatrix X | ||
private DMatrix designMatrix; | ||
|
||
// regressand y | ||
private DMatrix regressand; | ||
|
||
// average value of y | ||
private double yBar; | ||
|
||
// predicted values of y | ||
private DoubleList yHat; | ||
|
||
// y - yHat = epsilon | ||
private DoubleList residuals; | ||
|
||
// coefficient of determination | ||
private double rSquared; | ||
|
||
// population variance estimator | ||
private double sigmaHatSquared; | ||
|
||
// var-cov matrix of the coefficients | ||
private DMatrix varCovMatrix; | ||
|
||
// standard error estimates of the coefficient estimators | ||
private DoubleList coefficientStandardErrors; | ||
|
||
// t values of the coefficients | ||
private DoubleList tValues; | ||
|
||
// p values of the coefficients | ||
private DoubleList pValues; | ||
|
||
// confidence intervals of the coefficients | ||
private List<DoubleList> confidenceIntervals; | ||
|
||
// df for t-distribution | ||
private int degreesOfFreedom; | ||
|
||
public LSSummary(double alpha, DMatrix designMatrix, DMatrix regressand) { | ||
this.alpha = alpha; | ||
this.designMatrix = designMatrix; | ||
this.regressand = regressand; | ||
} | ||
|
||
public void clearTemporaries() { | ||
designMatrix = null; | ||
regressand = null; | ||
yHat = null; | ||
residuals = null; | ||
tValues = null; | ||
} | ||
|
||
public DoubleList getBeta() { | ||
return coefficients; | ||
} | ||
|
||
void setBeta(DMatrix beta) { | ||
coefficients = new DoubleArrayList(beta.numRows()); | ||
for (int i = 0; i < beta.numRows(); ++i) { | ||
coefficients.add(beta.get(i, 0)); | ||
} | ||
} | ||
|
||
public double getYBar() { | ||
return yBar; | ||
} | ||
|
||
void setYBar(double yBar) { | ||
this.yBar = yBar; | ||
} | ||
|
||
public DoubleList getYHat() { | ||
return yHat; | ||
} | ||
|
||
void setYHat(DMatrix yEst) { | ||
yHat = new DoubleArrayList(yEst.numRows()); | ||
for (int i = 0; i < yEst.numRows(); ++i) { | ||
yHat.add(yEst.get(i, 0)); | ||
} | ||
} | ||
|
||
public DoubleList getResiduals() { | ||
return residuals; | ||
} | ||
|
||
void setResiduals(DMatrix epsilonHat) { | ||
residuals = new DoubleArrayList(epsilonHat.numRows()); | ||
for (int i = 0; i < epsilonHat.numRows(); ++i) { | ||
residuals.add(epsilonHat.get(i, 0)); | ||
} | ||
} | ||
|
||
public double getRSquared() { | ||
return rSquared; | ||
} | ||
|
||
void setRSquared(double rSquared) { | ||
this.rSquared = rSquared; | ||
} | ||
|
||
public double getSigmaHatSquared() { | ||
return sigmaHatSquared; | ||
} | ||
|
||
void setSigmaHatSquared(double sigmaHatSquared) { | ||
this.sigmaHatSquared = sigmaHatSquared; | ||
} | ||
|
||
public DMatrix getVarianceCovarianceMatrix() { | ||
return varCovMatrix; | ||
} | ||
|
||
void setVarianceCovarianceMatrix(DMatrix varianceCovarianceMatrix) { | ||
this.varCovMatrix = varianceCovarianceMatrix; | ||
} | ||
|
||
public DoubleList getCoefficientStandardErrors() { | ||
return coefficientStandardErrors; | ||
} | ||
|
||
void setCoefficientStandardErrors(DoubleList coefficientStandardErrors) { | ||
this.coefficientStandardErrors = coefficientStandardErrors; | ||
} | ||
|
||
public DoubleList getTValues() { | ||
return tValues; | ||
} | ||
|
||
void setTValues(DoubleList tValues) { | ||
this.tValues = tValues; | ||
} | ||
|
||
public DoubleList getPValues() { | ||
return pValues; | ||
} | ||
|
||
void setPValues(DoubleList pValues) { | ||
this.pValues = pValues; | ||
} | ||
|
||
public List<DoubleList> getConfidenceIntervals() { | ||
return confidenceIntervals; | ||
} | ||
|
||
void setConfidenceIntervals(List<DoubleList> confidenceIntervals) { | ||
this.confidenceIntervals = confidenceIntervals; | ||
} | ||
|
||
public int getDegreesOfFreedom() { | ||
return degreesOfFreedom; | ||
} | ||
|
||
void setDegreesOfFreedom(int degreesOfFreedom) { | ||
this.degreesOfFreedom = degreesOfFreedom; | ||
} | ||
|
||
public double getAlpha() { | ||
return alpha; | ||
} | ||
|
||
public DMatrix getXMatrix() { | ||
return designMatrix; | ||
} | ||
|
||
public DMatrix getYVector() { | ||
return regressand; | ||
} | ||
|
||
public int getCoefficientsCount() { | ||
return coefficients.size(); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return "Summary [alpha=" + alpha + ", numCoefficients=" + getCoefficientsCount() + ",\n coefficients=" | ||
+ coefficients + ",\n yBar=" + yBar + ", rSquared=" + rSquared + ", sigmaHatSquared=" + sigmaHatSquared | ||
+ ",\n varCovMatrix=" + varCovMatrix + ", coefficientStandardErrors=" + coefficientStandardErrors | ||
+ ",\n pValues=" + pValues + ",\n confidenceIntervals=" + confidenceIntervals + ",\n degreesOfFreedom=" | ||
+ degreesOfFreedom + "]"; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
/* | ||
* Copyright 2023 Stefan Zobel | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package math.linalg; | ||
|
||
import java.util.ArrayList; | ||
|
||
import math.distribution.StudentT; | ||
import math.list.DoubleArrayList; | ||
import math.list.DoubleList; | ||
|
||
/** | ||
* Poor man's naive ordinary least squares regression. | ||
*/ | ||
public class OLS { | ||
|
||
public static LSSummary estimate(double alpha, DMatrix X, DMatrix y) { | ||
if (X.numRows() != y.numRows()) { | ||
throw new IllegalArgumentException("X.numRows != y.numRows : " + X.numRows() + " != " + y.numRows()); | ||
} | ||
if (X.numRows() - X.numColumns() < 1) { | ||
throw new IllegalArgumentException("degrees of freedom < 1 : " + (X.numRows() - X.numColumns())); | ||
} | ||
if (alpha <= 0.0) { | ||
throw new IllegalArgumentException("alpha <= 0 : " + alpha); | ||
} | ||
if (alpha >= 1.0) { | ||
throw new IllegalArgumentException("alpha >= 1 : " + alpha); | ||
} | ||
LSSummary smmry = new LSSummary(alpha, X, y); | ||
DMatrix Xtrans = X.transpose(); | ||
// Note: this may be numerically unstable! | ||
DMatrix beta = Xtrans.mul(X).inverse().mul(Xtrans).mul(y); | ||
smmry.setBeta(beta); | ||
DMatrix yHat = X.mul(beta); | ||
smmry.setYHat(yHat); | ||
DMatrix ones = new DMatrix(1, y.numRows()); | ||
for (int i = 0; i < y.numRows(); ++i) { | ||
ones.setUnsafe(0, i, 1.0); | ||
} | ||
double ybar = ones.mul(y).scaleInplace(1.0 / y.numRows()).get(0, 0); | ||
smmry.setYBar(ybar); | ||
ones = new DMatrix(y.numRows(), 1); | ||
for (int i = 0; i < y.numRows(); ++i) { | ||
ones.setUnsafe(i, 0, 1.0); | ||
} | ||
DMatrix yBarMat = ones.scaleInplace(ybar); | ||
DMatrix a = yHat.minus(yBarMat); | ||
DMatrix b = y.minus(yBarMat); | ||
double SQE = a.transpose().mul(a).get(0, 0); | ||
double SQT = b.transpose().mul(b).get(0, 0); | ||
double R_squared = SQE / SQT; | ||
smmry.setRSquared(R_squared > 1.0 ? 1.0 : R_squared); | ||
DMatrix epsHat = y.minus(yHat); | ||
smmry.setResiduals(epsHat); | ||
int df = epsHat.numRows() - X.numColumns(); | ||
smmry.setDegreesOfFreedom(df); | ||
double sigmaHatSquared = epsHat.transpose().mul(epsHat).scaleInplace(1.0 / (df)).get(0, 0); | ||
smmry.setSigmaHatSquared(sigmaHatSquared); | ||
DMatrix varCov = X.transpose().mul(X).inverse().scaleInplace(sigmaHatSquared); | ||
smmry.setVarianceCovarianceMatrix(varCov); | ||
DoubleList standardErrors = new DoubleArrayList(varCov.numRows()); | ||
for (int i = 0; i < varCov.numRows(); ++i) { | ||
double vari = varCov.get(i, i); | ||
if (vari < 0.0) { | ||
vari = Double.MIN_NORMAL; | ||
varCov.set(i, i, vari); | ||
} | ||
standardErrors.add(Math.sqrt(vari)); | ||
} | ||
smmry.setCoefficientStandardErrors(standardErrors); | ||
DoubleList tValues = new DoubleArrayList(varCov.numRows()); | ||
DoubleList pValues = new DoubleArrayList(varCov.numRows()); | ||
ArrayList<DoubleList> confIntervals = new ArrayList<>(); | ||
StudentT tDist = new StudentT(df); | ||
double tval = tDist.inverseCdf(1.0 - (alpha / 2.0)); | ||
for (int i = 0; i < varCov.numRows(); ++i) { | ||
double coeff = beta.get(i, 0); | ||
double se = standardErrors.get(i); | ||
double t = coeff / se; | ||
double p = 2.0 * (1.0 - tDist.cdf(Math.abs(t))); | ||
double min = coeff - tval * se; | ||
double max = coeff + tval * se; | ||
tValues.add(t); | ||
pValues.add(p); | ||
confIntervals.add(DoubleList.of(min, max)); | ||
} | ||
smmry.setTValues(tValues); | ||
smmry.setPValues(pValues); | ||
smmry.setConfidenceIntervals(confIntervals); | ||
return smmry; | ||
} | ||
} |