Linear regression #52

Merged
merged 7 commits into sjwhitworth:master

2 participants

@njern
Collaborator

An implementation of linear regression, accompanied by unit tests & benchmarks.

Includes some test data (exams.csv / exam.csv) from here

Tested & compared results using the online calculator at xuru.org, and the results seem to match. Meaning, either my implementation is correct or I have re-implemented the same bugs as xuru.org... :smile:

It's also reasonably fast. Single-threaded on my machine:

BenchmarkLinearRegressionOneRow  5000000           326 ns/op
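The benchmark file itself is not part of the diff below (it landed via the referenced "Add benchmark" PR #53), so purely as a rough sketch, assuming the single-row exam.csv is used as the prediction input, a benchmark with that name could look something like this:

```go
package linear_models

import (
	"testing"

	"github.com/sjwhitworth/golearn/base"
)

// Hypothetical reconstruction: fit on exams.csv once, then time Predict()
// against the single-row exam.csv so each iteration scores one instance.
func BenchmarkLinearRegressionOneRow(b *testing.B) {
	lr := NewLinearRegression()

	trainData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
	if err != nil {
		b.Fatal(err)
	}
	oneRow, err := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
	if err != nil {
		b.Fatal(err)
	}

	if err := lr.Fit(trainData); err != nil {
		b.Fatal(err)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := lr.Predict(oneRow); err != nil {
			b.Fatal(err)
		}
	}
}
```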
@njern
Collaborator

One thing to note here is that I slightly diverged from the previous Fit() / Predict() interface by returning errors from both. I think this would be a great refactoring in other parts of the library as well... but please let me know what you think.
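For concreteness, a minimal sketch of what the error-returning API looks like at a call site, using only calls that appear in this diff (the dataset path and the log-and-exit error handling are illustrative):

```go
package main

import (
	"fmt"
	"log"

	"github.com/sjwhitworth/golearn/base"
	"github.com/sjwhitworth/golearn/linear_models"
)

func main() {
	rawData, err := base.ParseCSVToInstances("examples/datasets/exams.csv", true)
	if err != nil {
		log.Fatal(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.1)

	lr := linear_models.NewLinearRegression()
	if err := lr.Fit(trainData); err != nil {
		// e.g. NotEnoughDataError when there are fewer rows than attributes
		log.Fatal(err)
	}

	predictions, err := lr.Predict(testData)
	if err != nil {
		// e.g. NoTrainingDataError if Fit() was never called
		log.Fatal(err)
	}

	for i := 0; i < predictions.Rows; i++ {
		fmt.Println(predictions.Get(i, predictions.ClassIndex))
	}
}
```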

@sjwhitworth sjwhitworth merged commit 222b0ab into sjwhitworth:master

1 check passed

continuous-integration/travis-ci: The Travis CI build passed
@njern referenced this pull request
Merged

Add benchmark #53

2  base/csv.go
@@ -98,7 +98,7 @@ func ParseCSVSniffAttributeTypes(filepath string, hasHeaders bool) []Attribute {
 	for _, entry := range columns {
 		entry = strings.Trim(entry, " ")
 		matched, err := regexp.MatchString("^[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?$", entry)
-		fmt.Println(entry, matched)
+		//fmt.Println(entry, matched)
 		if err != nil {
 			panic(err)
 		}
2  examples/datasets/exam.csv
@@ -0,0 +1,2 @@
+EXAM1,EXAM2,EXAM3,FINAL
+73,80,75,152
26 examples/datasets/exams.csv
@@ -0,0 +1,26 @@
+EXAM1,EXAM2,EXAM3,FINAL
+73,80,75,152
+93,88,93,185
+89,91,90,180
+96,98,100,196
+73,66,70,142
+53,46,55,101
+69,74,77,149
+47,56,60,115
+87,79,90,175
+79,70,88,164
+69,70,73,141
+70,65,74,141
+93,95,91,184
+79,80,73,152
+70,73,78,148
+93,89,96,192
+78,75,68,147
+81,90,93,183
+88,92,86,177
+78,83,77,159
+82,86,90,177
+86,82,89,175
+78,83,85,175
+76,83,71,149
+96,93,95,192
5 linear_models/doc.go
@@ -0,0 +1,5 @@
+/*
+Package linear_models implements linear
+and logistic regression models.
+*/
+package linear_models
98 linear_models/linear_regression.go
@@ -0,0 +1,98 @@
+package linear_models
+
+import (
+	"errors"
+
+	"github.com/sjwhitworth/golearn/base"
+
+	_ "github.com/gonum/blas"
+	"github.com/gonum/blas/cblas"
+	"github.com/gonum/matrix/mat64"
+)
+
+var (
+	NotEnoughDataError  = errors.New("not enough rows to support this many variables.")
+	NoTrainingDataError = errors.New("you need to Fit() before you can Predict()")
+)
+
+type LinearRegression struct {
+	fitted                 bool
+	disturbance            float64
+	regressionCoefficients []float64
+}
+
+func init() {
+	mat64.Register(cblas.Blas{})
+}
+
+func NewLinearRegression() *LinearRegression {
+	return &LinearRegression{fitted: false}
+}
+
+func (lr *LinearRegression) Fit(inst *base.Instances) error {
+	if inst.Rows < inst.GetAttributeCount() {
+		return NotEnoughDataError
+	}
+
+	// Split into two matrices, observed results (dependent variable y)
+	// and the explanatory variables (X) - see http://en.wikipedia.org/wiki/Linear_regression
+	observed := mat64.NewDense(inst.Rows, 1, nil)
+	explVariables := mat64.NewDense(inst.Rows, inst.GetAttributeCount(), nil)
+
+	for i := 0; i < inst.Rows; i++ {
+		observed.Set(i, 0, inst.Get(i, inst.ClassIndex)) // Set observed data
+
+		for j := 0; j < inst.GetAttributeCount(); j++ {
+			if j == 0 {
+				// Set intercepts to 1.0
+				// Could / should be done better: http://www.theanalysisfactor.com/interpret-the-intercept/
+				explVariables.Set(i, 0, 1.0)
+			} else {
+				explVariables.Set(i, j, inst.Get(i, j-1))
+			}
+		}
+	}
+
+	n := inst.GetAttributeCount()
+	qr := mat64.QR(explVariables)
+	q := qr.Q()
+	reg := qr.R()
+
+	var transposed, qty mat64.Dense
+	transposed.TCopy(q)
+	qty.Mul(&transposed, observed)
+
+	regressionCoefficients := make([]float64, n)
+	for i := n - 1; i >= 0; i-- {
+		regressionCoefficients[i] = qty.At(i, 0)
+		for j := i + 1; j < n; j++ {
+			regressionCoefficients[i] -= regressionCoefficients[j] * reg.At(i, j)
+		}
+		regressionCoefficients[i] /= reg.At(i, i)
+	}
+
+	lr.disturbance = regressionCoefficients[0]
+	lr.regressionCoefficients = regressionCoefficients[1:]
+	lr.fitted = true
+
+	return nil
+}
+
+func (lr *LinearRegression) Predict(X *base.Instances) (*base.Instances, error) {
+	if !lr.fitted {
+		return nil, NoTrainingDataError
+	}
+
+	ret := X.GeneratePredictionVector()
+	for i := 0; i < X.Rows; i++ {
+		var prediction float64 = lr.disturbance
+		for j := 0; j < X.Cols; j++ {
+			if j != X.ClassIndex {
+				prediction += X.Get(i, j) * lr.regressionCoefficients[j]
+			}
+		}
+		ret.Set(i, 0, prediction)
+	}
+
+	return ret, nil
+}
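A note on the numerics in Fit() above: explVariables is the design matrix X (with its leading column of ones for the intercept) and observed is the response vector y. Rather than forming the normal equations, the code solves the ordinary least-squares problem through the QR decomposition:

$$ X = QR, \qquad R\hat{\beta} = Q^{\top}y. $$

Since R is upper triangular, the loop that runs i from n-1 down to 0 is ordinary back-substitution,

$$ \hat{\beta}_i = \frac{(Q^{\top}y)_i - \sum_{j > i} R_{ij}\,\hat{\beta}_j}{R_{ii}}, $$

after which the intercept \(\hat{\beta}_0\) is stored in disturbance and the remaining coefficients in regressionCoefficients.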
60 linear_models/linear_regression_test.go
@@ -0,0 +1,60 @@
+package linear_models
+
+import (
+	"fmt"
+	"testing"
+
+	"github.com/sjwhitworth/golearn/base"
+)
+
+func TestNoTrainingData(t *testing.T) {
+	lr := NewLinearRegression()
+
+	rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	_, err = lr.Predict(rawData)
+	if err != NoTrainingDataError {
+		t.Fatal("failed to error out even if no training data exists")
+	}
+}
+
+func TestNotEnoughTrainingData(t *testing.T) {
+	lr := NewLinearRegression()
+
+	rawData, err := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = lr.Fit(rawData)
+	if err != NotEnoughDataError {
+		t.Fatal("failed to error out even though there was not enough data")
+	}
+}
+
+func TestLinearRegression(t *testing.T) {
+	lr := NewLinearRegression()
+
+	rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.1)
+	err = lr.Fit(trainData)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	predictions, err := lr.Predict(testData)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	for i := 0; i < predictions.Rows; i++ {
+		fmt.Printf("Expected: %f || Predicted: %f\n", testData.Get(i, testData.ClassIndex), predictions.Get(i, predictions.ClassIndex))
+	}
+}
1  lm/linear_regression.go
@@ -1 +0,0 @@
-package lm