Skip to content

Commit

Permalink
Switch to interface implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Steve Bentley committed Apr 7, 2017
1 parent 0225906 commit be6ed10
Show file tree
Hide file tree
Showing 9 changed files with 864 additions and 372 deletions.
102 changes: 102 additions & 0 deletions doc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package lingo

// import "fmt"

// // Add matrix m1 and matrix m2, creating matrix m3.
// func ExampleMatrix_Add() {
// m1 := Matrix{
// {1, 2, 3},
// {2, 4, 5},
// }

// m2 := Matrix{
// {3, 5, 1},
// {4, 4, 3},
// }

// m3, _ := m1.Add(m2)

// fmt.Println(m3)
// // Output: [4 7 4]
// // [6 8 8]
// }

// // Subtract matrix m1 from matrix m2, creating matrix m3.
// func ExampleMatrix_Subtract() {
// m1 := Matrix{
// {1, 2, 3},
// {2, 4, 5},
// }

// m2 := Matrix{
// {3, 5, 1},
// {4, 4, 3},
// }

// m3, _ := m2.Subtract(m1)

// fmt.Println(m3)
// // Output: [2 3 -2]
// // [2 0 -2]
// }

// // Element wise multiplcation of matrix m1 by matrix m2, creating matrix m3.
// func ExampleMatrix_Multiply() {
// m1 := Matrix{
// {1, 2, 3},
// {2, 4, 5},
// }

// m2 := Matrix{
// {3, 5, 1},
// {4, 4, 3},
// }

// m3, _ := m1.Multiply(m2)

// fmt.Println(m3)
// // Output: [3 10 3]
// // [8 16 15]
// }

// // Dot product of matrix m1 and matrix m2, creating matrix m3.
// func ExampleMatrix_Dot() {
// m1 := Matrix{
// {1, 2, 3},
// {2, 4, 5},
// }

// m2 := Matrix{
// {1, 2},
// {3, 2},
// {4, 5},
// }

// m3, _ := m1.Dot(m2)

// fmt.Println(m3)
// // Output: [19 21]
// // [34 37]
// }

// // Get the value at position 1,2 in matrix m.
// func ExampleMatrix_Value() {
// m := Matrix{
// {1, 2, 3},
// {2, 4, 5},
// }

// v, _ := m.Value(1, 2)

// fmt.Println(v)
// // Output: 5
// }

// // Rows
// // Columns
// // Dimensions
// // Equal
// // Scale
// // Transpose
// // Reshape
// // ToVector
94 changes: 93 additions & 1 deletion lingo.go
Original file line number Diff line number Diff line change
@@ -1,2 +1,94 @@
// Package lingo provides matrix operations for use in linear algebra
// Package lingo provides matrix and vector operations
package lingo

import (
"errors"
)

type Tensor interface {
Order() int
Rows() int
Columns() int
Value(position ...int) (float64, error)
SetValue(value float64, position ...int) (Tensor, error)
Reshape(dims ...int) (Tensor, error)
String() string
}

// Add adds two compatible tensors together.
// An error is returned if the tensors are not compatible sizes.
func Add(m, o Tensor) (Tensor, error) {
if !matchSize(m, o) {
return nil, errors.New("incompatible matrices")
}

// Scalar
if m.Order() == 0 {
mValue, err := m.Value()
if err != nil {
return nil, err
}
oValue, err := o.Value()
if err != nil {
return nil, err
}
return Scalar(float64(mValue) + float64(oValue)), nil
}

var r Tensor
var err error

if m.Rows() == 1 {
r, err = newZeroTensor(m.Columns())
if err != nil {
return nil, err
}
}

if m.Rows() > 2 {
r, err = newZeroTensor(m.Rows(), m.Columns())
if err != nil {
return nil, err
}
}

for mRows := 0; mRows < m.Rows(); mRows++ {
for mCols := 0; mCols < m.Columns(); mCols++ {
mValue, _ := m.Value(mRows, mCols)
oValue, _ := o.Value(mRows, mCols)
r.SetValue(mValue+oValue, mRows, mCols)
}
}

return nil, nil
}

func matchSize(m1, m2 Tensor) bool {
if m1.Columns() != m2.Columns() || m1.Rows() != m2.Rows() {
return false
}
return true
}

func newZeroTensor(dims ...int) (Tensor, error) {
if len(dims) == 0 {
s := Scalar(0)
return s, nil
}

if len(dims) == 1 {
v := Vector(make([]float64, dims[0]))
return v, nil
}

if len(dims) == 2 {
m := Matrix{}
for x := 0; x < dims[0]; x++ {
r := make([]float64, dims[1], dims[1])
m = append(m, r)
}
return m, nil
}

return nil, errors.New("more than 2 dimensions is not supported")
}
14 changes: 14 additions & 0 deletions lingo_test.go
Original file line number Diff line number Diff line change
@@ -1 +1,15 @@
package lingo

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestAddScalarToScalar(t *testing.T) {
s1 := Scalar(1)
s2 := Scalar(2)
s3, err := Add(s1, s2)

assert.Equal(t, nil, err, "err is not nil")
assert.Equal(t, Scalar(3), s3, "value of 1 + 2 is not 3")
}
Loading

0 comments on commit be6ed10

Please sign in to comment.