Permalink
Browse files

added multiplication matrices

  • Loading branch information...
1 parent f7f99f7 commit 811053b9497fd4e852ee11b6c98a4101a3b972ff @schleichardt committed Apr 12, 2012
Showing with 39 additions and 0 deletions.
  1. +21 −0 src/main/scala/info/schleichardt/math/ValueMatrix.scala
  2. +18 −0 src/test/scala/MatrixSpec.scala
@@ -1,5 +1,8 @@
package info.schleichardt.math
+import collection.{Seq, GenTraversableOnce}
+import collection.immutable.IndexedSeq
+
object ValueMatrix {
def apply(input: Seq[Double]*) = new ValueMatrix(input)
}
@@ -58,4 +61,22 @@ class ValueMatrix(val content: Seq[Seq[Double]]) {
}
new ValueMatrix(seq)
}
+
+ lazy val isNullMatrix: Boolean = {
+ val allInALine: Seq[Double] = content.flatMap(x => x)
+ allInALine.forall(_ == 0)
+ }
+
+ def *(other: ValueMatrix): ValueMatrix = {
+ require(columnCount == other.content.length, "length matches")
+ val seq: Seq[Seq[Double]] =
+ for (lineResult <- 0 until content.length) yield {
+ for (columnResult <- 0 until content(0).length) yield {
+ (for (column <- 0 until content(0).length) yield {
+ content(lineResult)(column) * other.content(column)(columnResult)
+ }).sum
+ }
+ }
+ new ValueMatrix(seq)
+ }
}
@@ -16,5 +16,23 @@ class MatrixSpec extends Specification with JUnit with ScalaTest {
val transposedMatrix = matrix.transpose
transposedMatrix must_== ValueMatrix(Seq(1, 1), Seq(-3, 2), Seq(2, 7))
}
+ "be null matrices" in {
+ val nullMatrix = ValueMatrix(Seq(0, 0, 0), Seq(0, 0, 0), Seq(0, 0, 0))
+ val notNullMatrix = ValueMatrix(Seq(1, 0, 2), Seq(1, 2, 7))
+ nullMatrix.isNullMatrix must be_==(true)
+ notNullMatrix.isNullMatrix must be_==(false)
+ }
+ "be multiplied" in {
+ val left = ValueMatrix(Seq(0, 1, 2), Seq(4, 2, 3), Seq(5, 3, 1))
+ val right = ValueMatrix(Seq(3, 1, 1), Seq(1, 3, 1), Seq(4, 0, 2))
+ left * right must be_==(ValueMatrix(Seq(9, 3, 5), Seq(26, 10, 12), Seq(22, 14, 10)))
+ }
+ "only fitting matrices can be multiplied" in {
+ val matrixWith3Rows = ValueMatrix(Seq(0, 1, 2), Seq(4, 2, 3), Seq(5, 3, 1))//3 rows, 3 columns
+ val matrixWith2Rows = ValueMatrix(Seq(3, 1, 1), Seq(1, 3, 1))//2 rows, 3 columns
+ matrixWith3Rows * matrixWith2Rows must throwA[IllegalArgumentException]
+ matrixWith2Rows * matrixWith3Rows must not(throwA[IllegalArgumentException])
+ }
+
}
}

0 comments on commit 811053b

Please sign in to comment.