-
Notifications
You must be signed in to change notification settings - Fork 693
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] [BREEZE-590] DenseTensor Implementation #695
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
package breeze.linalg | ||
|
||
import scala.collection.mutable | ||
import scala.reflect.ClassTag | ||
import scala.{specialized => spec} | ||
|
||
/** | ||
* A DenseTensorN is an N-dimensional tensor with all elements flattened into an array. It is | ||
* column major unless isTranspose is true. | ||
* | ||
* @author sujithjay | ||
* @param data The underlying data. | ||
* Column-major unless isTranspose is true. | ||
* Mutate at your own risk. | ||
* Note that this tensor may be a view of the data. | ||
* @param shape | ||
* @param offset | ||
* @param stride | ||
* @param isTranspose | ||
*/ | ||
@SerialVersionUID(1L) | ||
final class DenseTensorN[@spec(Double, Int, Float, Long) V]( | ||
val data: Array[V], | ||
val shape: IndexedSeq[Int], | ||
val offset: Int, | ||
val stride: Int, | ||
val isTranspose: Boolean = false) | ||
extends TensorN[V] | ||
with TensorNLike[V, DenseTensorN[V]] | ||
with Serializable { | ||
|
||
def this(data: Array[V]) = this(data, IndexedSeq[Int](data.length), 0, 1) | ||
def this(shape: IndexedSeq[Int])(implicit man: ClassTag[V]) = this(new Array[V](shape.product), shape, 0, 1) | ||
|
||
def apply(index: IndexedSeq[Int]): V = { | ||
val isValidIndex = index | ||
.zipWithIndex | ||
.forall{ idx => | ||
idx._2 > -shape(idx._1) && idx._2 < shape(idx._1) | ||
} | ||
if(!isValidIndex) throw new IndexOutOfBoundsException("") | ||
val trueIndex: IndexedSeq[Int] = index | ||
.zipWithIndex | ||
.map{ idx => | ||
if(idx._2 < 0) idx._2 + shape(idx._1) else idx._2 | ||
} | ||
data(linearIndex(trueIndex)) | ||
} | ||
|
||
def update(index: IndexedSeq[Int], v: V): Unit = { | ||
val isValidIndex = index | ||
.zipWithIndex | ||
.forall{ idx => | ||
idx._2 > -shape(idx._1) && idx._2 < shape(idx._1) | ||
} | ||
if(!isValidIndex) throw new IndexOutOfBoundsException("") | ||
val trueIndex: IndexedSeq[Int] = index | ||
.zipWithIndex | ||
.map{ idx => | ||
if(idx._2 < 0) idx._2 + shape(idx._1) else idx._2 | ||
} | ||
data(linearIndex(trueIndex)) = v | ||
} | ||
|
||
def size: Int = shape.product | ||
|
||
def activeSize: Int = data.length | ||
|
||
def iterator: Iterator[(IndexedSeq[Int], V)] = for(i <- Iterator.range(0, size)) yield ndIndex(i) -> data(i) | ||
|
||
def activeIterator: Iterator[(IndexedSeq[Int], V)] = iterator | ||
|
||
def valuesIterator: Iterator[V] = for(i <- Iterator.range(0, size)) yield data(i) | ||
|
||
def activeValuesIterator: Iterator[V] = valuesIterator | ||
|
||
def keysIterator: Iterator[IndexedSeq[Int]] = for(i <- Iterator.range(0, size)) yield ndIndex(i) | ||
|
||
def activeKeysIterator: Iterator[IndexedSeq[Int]] = keysIterator | ||
|
||
def repr: DenseTensorN[V] = this | ||
|
||
/** | ||
* Calculates the linear index from its equivalent n-dimensional index | ||
* @param ndIndex | ||
* @return The linear index | ||
*/ | ||
def linearIndex(ndIndex: IndexedSeq[Int]): Int = { | ||
val logicalIndex = if(isTranspose){ | ||
ndIndex.zipWithIndex.foldLeft(0){(prev, idx) => | ||
idx._2 + shape(idx._1) * prev | ||
} | ||
} | ||
else { | ||
ndIndex.zipWithIndex.foldRight(0){(idx, prev) => | ||
idx._2 + shape(idx._1) * prev | ||
} | ||
} | ||
|
||
logicalIndex * (stride + 1) + offset | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is this In any event, I don't think that an element-by-element stride isn't really that useful for higher-order tensors. In practice, I think you'd you want a per-dimension stride. (Or, alternatively, the stride only to apply to the "major" axis, as in DenseMatrix) |
||
} | ||
|
||
/** | ||
* Calculates the n-dimensional index from its equivalent linear index | ||
* @param linearIndex | ||
* @return The n-dimension index | ||
*/ | ||
def ndIndex(linearIndex: Int): IndexedSeq[Int] = { | ||
val logicalIndex = (linearIndex - offset) / (stride + 1) | ||
val ret = mutable.ArrayBuffer[Int]() | ||
if(isTranspose){ | ||
shape.foldRight((this.size, logicalIndex)) { (dim, tup) => | ||
val sz = tup._1 / dim | ||
val rem = tup._2 % sz | ||
val idx = tup._2 / sz | ||
ret += idx | ||
(sz, rem) | ||
} | ||
} | ||
else { | ||
shape.foldLeft((this.size, logicalIndex)) { (tup, dim) => | ||
val sz = tup._1 / dim | ||
val rem = tup._2 % sz | ||
val idx = tup._2 / sz | ||
ret += idx | ||
(sz, rem) | ||
} | ||
} | ||
ret | ||
} | ||
|
||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package breeze.linalg | ||
|
||
import breeze.linalg.support.CanMapValues | ||
|
||
import scala.annotation.unchecked.uncheckedVariance | ||
import scala.{specialized => spec} | ||
|
||
|
||
|
||
trait TensorNLike[@spec(Double, Int, Float, Long) V, +Self <: TensorN[V]] | ||
extends Tensor[IndexedSeq[Int], V] | ||
with TensorLike[IndexedSeq[Int], V, Self]{ | ||
def map[V2, That](fn: V => V2)(implicit canMapValues: CanMapValues[Self @uncheckedVariance, V, V2, That]): That = | ||
values.map(fn) | ||
|
||
} | ||
|
||
/** | ||
* @author sujithjay | ||
* @tparam V | ||
*/ | ||
trait TensorN[@spec(Int, Long, Double, Float) V] | ||
extends TensorNLike[V, TensorN[V]]{ | ||
|
||
def foreach[U](fn: V => U): Unit = { values.foreach(fn) } | ||
|
||
def shape: IndexedSeq[Int] | ||
|
||
def keySet: Set[IndexedSeq[Int]] = new Set[IndexedSeq[Int]]{ | ||
def contains(elem: IndexedSeq[Int]): Boolean = elem.zipWithIndex.forall{ idx => { | ||
idx._2 >= 0 && idx._2 < shape(idx._1) | ||
} | ||
} | ||
|
||
def +(elem: IndexedSeq[Int]): Set[IndexedSeq[Int]] = throw new UnsupportedOperationException(" member '+' is not supported in TensorN.keySet() ") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this can just coerce to a set with |
||
def -(elem: IndexedSeq[Int]): Set[IndexedSeq[Int]] = throw new UnsupportedOperationException(" member '-' is not supported in TensorN.keySet() ") | ||
def iterator: Iterator[IndexedSeq[Int]] = throw new UnsupportedOperationException(" member 'iterator' is not supported in TensorN.keySet() ") | ||
} | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package breeze.linalg | ||
|
||
import org.scalatest.FunSuite | ||
|
||
import scala.collection.mutable.ArrayBuffer | ||
|
||
class DenseTensorNTest extends FunSuite { | ||
|
||
test("testNdIndex") { | ||
val linearIndex = 10 | ||
val shape = IndexedSeq(3, 3) | ||
val stride = 1 | ||
val offset = 4 | ||
val denseTensorN = new DenseTensorN[Long](new Array[Long](shape.product),shape, offset, stride) | ||
|
||
/* Wrote an imperative-style method to check correctness of the original ndIndex method. Need to write some proper test cases. */ | ||
def imperativeNDIndex(linearIndex: Int, size: Int): IndexedSeq[Int] = { | ||
var denom = size | ||
var index = (linearIndex - offset) / (stride + 1) | ||
val ret = new ArrayBuffer[Int](shape.length) | ||
var i = 0 | ||
for(i <- shape.indices){ | ||
denom /= shape(i) | ||
ret.prepend( index / denom ) | ||
index %= denom | ||
|
||
} | ||
|
||
ret.reverse | ||
} | ||
|
||
print(imperativeNDIndex(linearIndex, shape.product)) | ||
assert(denseTensorN.ndIndex(linearIndex) == imperativeNDIndex(linearIndex, shape.product)) | ||
} | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isValidIndex should be extracted out as a method