Skip to content

Commit

Permalink
added multiplication matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
schleichardt committed Apr 12, 2012
1 parent f7f99f7 commit 811053b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/main/scala/info/schleichardt/math/ValueMatrix.scala
@@ -1,5 +1,8 @@
package info.schleichardt.math package info.schleichardt.math


import collection.{Seq, GenTraversableOnce}
import collection.immutable.IndexedSeq

object ValueMatrix { object ValueMatrix {
def apply(input: Seq[Double]*) = new ValueMatrix(input) def apply(input: Seq[Double]*) = new ValueMatrix(input)
} }
Expand Down Expand Up @@ -58,4 +61,22 @@ class ValueMatrix(val content: Seq[Seq[Double]]) {
} }
new ValueMatrix(seq) new ValueMatrix(seq)
} }

lazy val isNullMatrix: Boolean = {
val allInALine: Seq[Double] = content.flatMap(x => x)
allInALine.forall(_ == 0)
}

def *(other: ValueMatrix): ValueMatrix = {
require(columnCount == other.content.length, "length matches")
val seq: Seq[Seq[Double]] =
for (lineResult <- 0 until content.length) yield {
for (columnResult <- 0 until content(0).length) yield {
(for (column <- 0 until content(0).length) yield {
content(lineResult)(column) * other.content(column)(columnResult)
}).sum
}
}
new ValueMatrix(seq)
}
} }
18 changes: 18 additions & 0 deletions src/test/scala/MatrixSpec.scala
Expand Up @@ -16,5 +16,23 @@ class MatrixSpec extends Specification with JUnit with ScalaTest {
val transposedMatrix = matrix.transpose val transposedMatrix = matrix.transpose
transposedMatrix must_== ValueMatrix(Seq(1, 1), Seq(-3, 2), Seq(2, 7)) transposedMatrix must_== ValueMatrix(Seq(1, 1), Seq(-3, 2), Seq(2, 7))
} }
"be null matrices" in {
val nullMatrix = ValueMatrix(Seq(0, 0, 0), Seq(0, 0, 0), Seq(0, 0, 0))
val notNullMatrix = ValueMatrix(Seq(1, 0, 2), Seq(1, 2, 7))
nullMatrix.isNullMatrix must be_==(true)
notNullMatrix.isNullMatrix must be_==(false)
}
"be multiplied" in {
val left = ValueMatrix(Seq(0, 1, 2), Seq(4, 2, 3), Seq(5, 3, 1))
val right = ValueMatrix(Seq(3, 1, 1), Seq(1, 3, 1), Seq(4, 0, 2))
left * right must be_==(ValueMatrix(Seq(9, 3, 5), Seq(26, 10, 12), Seq(22, 14, 10)))
}
"only fitting matrices can be multiplied" in {
val matrixWith3Rows = ValueMatrix(Seq(0, 1, 2), Seq(4, 2, 3), Seq(5, 3, 1))//3 rows, 3 columns
val matrixWith2Rows = ValueMatrix(Seq(3, 1, 1), Seq(1, 3, 1))//2 rows, 3 columns
matrixWith3Rows * matrixWith2Rows must throwA[IllegalArgumentException]
matrixWith2Rows * matrixWith3Rows must not(throwA[IllegalArgumentException])
}

} }
} }

0 comments on commit 811053b

Please sign in to comment.