-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Steve Bentley
committed
Apr 7, 2017
1 parent
0225906
commit be6ed10
Showing
9 changed files
with
864 additions
and
372 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
package lingo | ||
|
||
// import "fmt" | ||
|
||
// // Add matrix m1 and matrix m2, creating matrix m3. | ||
// func ExampleMatrix_Add() { | ||
// m1 := Matrix{ | ||
// {1, 2, 3}, | ||
// {2, 4, 5}, | ||
// } | ||
|
||
// m2 := Matrix{ | ||
// {3, 5, 1}, | ||
// {4, 4, 3}, | ||
// } | ||
|
||
// m3, _ := m1.Add(m2) | ||
|
||
// fmt.Println(m3) | ||
// // Output: [4 7 4] | ||
// // [6 8 8] | ||
// } | ||
|
||
// // Subtract matrix m1 from matrix m2, creating matrix m3. | ||
// func ExampleMatrix_Subtract() { | ||
// m1 := Matrix{ | ||
// {1, 2, 3}, | ||
// {2, 4, 5}, | ||
// } | ||
|
||
// m2 := Matrix{ | ||
// {3, 5, 1}, | ||
// {4, 4, 3}, | ||
// } | ||
|
||
// m3, _ := m2.Subtract(m1) | ||
|
||
// fmt.Println(m3) | ||
// // Output: [2 3 -2] | ||
// // [2 0 -2] | ||
// } | ||
|
||
// // Element wise multiplcation of matrix m1 by matrix m2, creating matrix m3. | ||
// func ExampleMatrix_Multiply() { | ||
// m1 := Matrix{ | ||
// {1, 2, 3}, | ||
// {2, 4, 5}, | ||
// } | ||
|
||
// m2 := Matrix{ | ||
// {3, 5, 1}, | ||
// {4, 4, 3}, | ||
// } | ||
|
||
// m3, _ := m1.Multiply(m2) | ||
|
||
// fmt.Println(m3) | ||
// // Output: [3 10 3] | ||
// // [8 16 15] | ||
// } | ||
|
||
// // Dot product of matrix m1 and matrix m2, creating matrix m3. | ||
// func ExampleMatrix_Dot() { | ||
// m1 := Matrix{ | ||
// {1, 2, 3}, | ||
// {2, 4, 5}, | ||
// } | ||
|
||
// m2 := Matrix{ | ||
// {1, 2}, | ||
// {3, 2}, | ||
// {4, 5}, | ||
// } | ||
|
||
// m3, _ := m1.Dot(m2) | ||
|
||
// fmt.Println(m3) | ||
// // Output: [19 21] | ||
// // [34 37] | ||
// } | ||
|
||
// // Get the value at position 1,2 in matrix m. | ||
// func ExampleMatrix_Value() { | ||
// m := Matrix{ | ||
// {1, 2, 3}, | ||
// {2, 4, 5}, | ||
// } | ||
|
||
// v, _ := m.Value(1, 2) | ||
|
||
// fmt.Println(v) | ||
// // Output: 5 | ||
// } | ||
|
||
// // Rows | ||
// // Columns | ||
// // Dimensions | ||
// // Equal | ||
// // Scale | ||
// // Transpose | ||
// // Reshape | ||
// // ToVector |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,94 @@ | ||
// Package lingo provides matrix operations for use in linear algebra | ||
// Package lingo provides matrix and vector operations | ||
package lingo | ||
|
||
import ( | ||
"errors" | ||
) | ||
|
||
type Tensor interface { | ||
Order() int | ||
Rows() int | ||
Columns() int | ||
Value(position ...int) (float64, error) | ||
SetValue(value float64, position ...int) (Tensor, error) | ||
Reshape(dims ...int) (Tensor, error) | ||
String() string | ||
} | ||
|
||
// Add adds two compatible tensors together. | ||
// An error is returned if the tensors are not compatible sizes. | ||
func Add(m, o Tensor) (Tensor, error) { | ||
if !matchSize(m, o) { | ||
return nil, errors.New("incompatible matrices") | ||
} | ||
|
||
// Scalar | ||
if m.Order() == 0 { | ||
mValue, err := m.Value() | ||
if err != nil { | ||
return nil, err | ||
} | ||
oValue, err := o.Value() | ||
if err != nil { | ||
return nil, err | ||
} | ||
return Scalar(float64(mValue) + float64(oValue)), nil | ||
} | ||
|
||
var r Tensor | ||
var err error | ||
|
||
if m.Rows() == 1 { | ||
r, err = newZeroTensor(m.Columns()) | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
if m.Rows() > 2 { | ||
r, err = newZeroTensor(m.Rows(), m.Columns()) | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
|
||
for mRows := 0; mRows < m.Rows(); mRows++ { | ||
for mCols := 0; mCols < m.Columns(); mCols++ { | ||
mValue, _ := m.Value(mRows, mCols) | ||
oValue, _ := o.Value(mRows, mCols) | ||
r.SetValue(mValue+oValue, mRows, mCols) | ||
} | ||
} | ||
|
||
return nil, nil | ||
} | ||
|
||
func matchSize(m1, m2 Tensor) bool { | ||
if m1.Columns() != m2.Columns() || m1.Rows() != m2.Rows() { | ||
return false | ||
} | ||
return true | ||
} | ||
|
||
func newZeroTensor(dims ...int) (Tensor, error) { | ||
if len(dims) == 0 { | ||
s := Scalar(0) | ||
return s, nil | ||
} | ||
|
||
if len(dims) == 1 { | ||
v := Vector(make([]float64, dims[0])) | ||
return v, nil | ||
} | ||
|
||
if len(dims) == 2 { | ||
m := Matrix{} | ||
for x := 0; x < dims[0]; x++ { | ||
r := make([]float64, dims[1], dims[1]) | ||
m = append(m, r) | ||
} | ||
return m, nil | ||
} | ||
|
||
return nil, errors.New("more than 2 dimensions is not supported") | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,15 @@ | ||
package lingo | ||
|
||
import ( | ||
"github.com/stretchr/testify/assert" | ||
"testing" | ||
) | ||
|
||
func TestAddScalarToScalar(t *testing.T) { | ||
s1 := Scalar(1) | ||
s2 := Scalar(2) | ||
s3, err := Add(s1, s2) | ||
|
||
assert.Equal(t, nil, err, "err is not nil") | ||
assert.Equal(t, Scalar(3), s3, "value of 1 + 2 is not 3") | ||
} |
Oops, something went wrong.