diff --git a/src/main/scala/info/schleichardt/math/ValueMatrix.scala b/src/main/scala/info/schleichardt/math/ValueMatrix.scala index a89e56c..c73610d 100644 --- a/src/main/scala/info/schleichardt/math/ValueMatrix.scala +++ b/src/main/scala/info/schleichardt/math/ValueMatrix.scala @@ -1,5 +1,8 @@ package info.schleichardt.math +import collection.{Seq, GenTraversableOnce} +import collection.immutable.IndexedSeq + object ValueMatrix { def apply(input: Seq[Double]*) = new ValueMatrix(input) } @@ -58,4 +61,22 @@ class ValueMatrix(val content: Seq[Seq[Double]]) { } new ValueMatrix(seq) } + + lazy val isNullMatrix: Boolean = { + val allInALine: Seq[Double] = content.flatMap(x => x) + allInALine.forall(_ == 0) + } + + def *(other: ValueMatrix): ValueMatrix = { + require(columnCount == other.content.length, "length matches") + val seq: Seq[Seq[Double]] = + for (lineResult <- 0 until content.length) yield { + for (columnResult <- 0 until content(0).length) yield { + (for (column <- 0 until content(0).length) yield { + content(lineResult)(column) * other.content(column)(columnResult) + }).sum + } + } + new ValueMatrix(seq) + } } \ No newline at end of file diff --git a/src/test/scala/MatrixSpec.scala b/src/test/scala/MatrixSpec.scala index d73dd40..3a3e15b 100644 --- a/src/test/scala/MatrixSpec.scala +++ b/src/test/scala/MatrixSpec.scala @@ -16,5 +16,23 @@ class MatrixSpec extends Specification with JUnit with ScalaTest { val transposedMatrix = matrix.transpose transposedMatrix must_== ValueMatrix(Seq(1, 1), Seq(-3, 2), Seq(2, 7)) } + "be null matrices" in { + val nullMatrix = ValueMatrix(Seq(0, 0, 0), Seq(0, 0, 0), Seq(0, 0, 0)) + val notNullMatrix = ValueMatrix(Seq(1, 0, 2), Seq(1, 2, 7)) + nullMatrix.isNullMatrix must be_==(true) + notNullMatrix.isNullMatrix must be_==(false) + } + "be multiplied" in { + val left = ValueMatrix(Seq(0, 1, 2), Seq(4, 2, 3), Seq(5, 3, 1)) + val right = ValueMatrix(Seq(3, 1, 1), Seq(1, 3, 1), Seq(4, 0, 2)) + left * right must be_==(ValueMatrix(Seq(9, 3, 5), Seq(26, 10, 12), Seq(22, 14, 10))) + } + "only fitting matrices can be multiplied" in { + val matrixWith3Rows = ValueMatrix(Seq(0, 1, 2), Seq(4, 2, 3), Seq(5, 3, 1))//3 rows, 3 columns + val matrixWith2Rows = ValueMatrix(Seq(3, 1, 1), Seq(1, 3, 1))//2 rows, 3 columns + matrixWith3Rows * matrixWith2Rows must throwA[IllegalArgumentException] + matrixWith2Rows * matrixWith3Rows must not(throwA[IllegalArgumentException]) + } + } }