Added fused Batch Norm #99

Merged: 2 commits, Nov 20, 2023

Changes from all commits
73 changes: 38 additions & 35 deletions src/main/scala/scanet/core/Shape.scala
@@ -30,11 +30,11 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] {

def rank: Int = dims.size

def axis: List[Int] = dims.indices.toList
def axes: List[Int] = dims.indices.toList

def axisExcept(other: Int*): List[Int] = {
val indexedAxis = indexAxis(other)
(dims.indices.toSet -- indexedAxis.toSet).toList.sorted
def axesExcept(other: Int*): List[Int] = {
val indexedAxes = indexAxes(other)
(dims.indices.toSet -- indexedAxes.toSet).toList.sorted
}

def isScalar: Boolean = rank == 0
@@ -158,8 +158,8 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] {
def broadcastableAny(other: Shape): Boolean =
broadcastableBy(other) || other.broadcastableBy(this)

def broadcastableAxis(other: Shape): Seq[Int] = {
require(broadcastableAny(other), s"cannot find broadcastable axis for $this and $other")
def broadcastableAxes(other: Shape): Seq[Int] = {
require(broadcastableAny(other), s"cannot find broadcastable axes for $this and $other")
if (rank < other.rank) {
Seq()
} else {
@@ -179,62 +179,65 @@ case class Shape(dims: List[Int]) extends Ordered[Shape] {
Shape(dimsResult)
}

def permute(axis: Int*): Shape = {
val indexedAxis = indexAxis(axis)
def permute(axes: Int*): Shape = {
val indexedAxes = indexAxes(axes)
require(
rank == indexedAxis.size,
rank == indexedAxes.size,
"the number of permutation indexes " +
s"should be equal to rank $rank, but was (${axis.mkString(", ")})")
Shape(indexedAxis.foldLeft(List[Int]())((permDims, index) => dims(index) :: permDims).reverse)
s"should be equal to rank $rank, but was (${axes.mkString(", ")})")
Shape(indexedAxes.foldLeft(List[Int]())((permDims, index) => dims(index) :: permDims).reverse)
}

def select(axis: Int*): Shape = {
val indexedAxis = indexAxis(axis)
def select(axes: Int*): Shape = {
val indexedAxes = indexAxes(axes)
require(
indexedAxis.forall(i => i < rank && i >= 0),
s"the number of selected axis " +
s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})")
Shape(indexedAxis.map(get).toList)
indexedAxes.forall(i => i < rank && i >= 0),
s"the selected axes " +
s"should all be less than rank $rank, but got (${axes.mkString(", ")})")
Shape(indexedAxes.map(get).toList)
}

def remove(axis: Int*): Shape = {
val indexedAxis = indexAxis(axis)
def remove(axes: Int*): Shape = {
val indexedAxes = indexAxes(axes)
require(
indexedAxis.forall(i => i < rank && i >= 0),
s"the number of removed axis " +
s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})")
indexedAxes.forall(i => i < rank && i >= 0),
s"the removed axes " +
s"should all be less than rank $rank, but got (${axes.mkString(", ")})")
val filteredDims = dims.zipWithIndex
.filter {
case (_, i) =>
!indexedAxis.contains(i)
!indexedAxes.contains(i)
}
.map { case (dim, _) => dim }
Shape(filteredDims)
}

def updated(axis: Int, value: Int): Shape = updateAll(value)(axis)

def updateAll(value: Int)(axis: Int*): Shape = {
val indexedAxis = indexAxis(axis)
def updateAll(value: Int)(axes: Int*): Shape = {
val indexedAxes = indexAxes(axes)
require(
indexedAxis.forall(i => i < rank && i >= 0),
s"the number of updated axis " +
s"should be less or equal to rank $rank, but was (${axis.mkString(", ")})")
indexedAxes.forall(i => i < rank && i >= 0),
s"the updated axes " +
s"should all be less than rank $rank, but got (${axes.mkString(", ")})")
val updatedDims = dims.zipWithIndex.map {
case (dim, i) =>
if (indexedAxis.contains(i)) value else dim
if (indexedAxes.contains(i)) value else dim
}
Shape(updatedDims)
}

def updateAllExcept(value: Int)(axis: Int*): Shape = {
val indexedAxis = indexAxis(axis)
val axisToUpdate = dims.indices.toSet -- indexedAxis.toSet
updateAll(value)(axisToUpdate.toList: _*)
def updateAllExcept(value: Int)(axes: Int*): Shape = {
val indexedAxes = indexAxes(axes)
val axesToUpdate = dims.indices.toSet -- indexedAxes.toSet
updateAll(value)(axesToUpdate.toList: _*)
}

private def indexAxis(axis: Seq[Int]): Seq[Int] =
axis.map(a => if (a == -1) dims.size - 1 else a)
def indexAxes(axes: Seq[Int]): Seq[Int] =
axes.map(indexAxis)

def indexAxis(axis: Int): Int =
if (axis == -1) dims.size - 1 else axis
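
For reference, a usage sketch of the renamed axes API (not part of the diff; the expected values follow from the definitions above):

    val s = Shape(List(2, 3, 4))
    s.axes                  // List(0, 1, 2)
    s.axesExcept(1)         // List(0, 2)
    s.permute(2, 0, 1)      // Shape(4, 2, 3)
    s.select(0, -1)         // Shape(2, 4), -1 resolves to the last axis via indexAxis
    s.remove(1)             // Shape(2, 4)
    s.updateAll(1)(0, 2)    // Shape(1, 3, 1)
    s.updateAllExcept(1)(0) // Shape(2, 1, 1)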

def minus(other: Shape): Shape = {
require(broadcastableAny(other), s"cannot $this - $other")
95 changes: 84 additions & 11 deletions src/main/scala/scanet/math/alg/AllKernels.scala
@@ -23,8 +23,8 @@ case class Plus[A: Numeric](left: Expr[A], right: Expr[A]) extends Expr[A] {
current: Expr[A],
parentGrad: Expr[R]): Seq[Expr[R]] = {
val parentShape = parentGrad.shape
val shrinkRightAxis = parentShape.broadcastableAxis(right.shape).toList
val shrinkLeftAxis = parentShape.broadcastableAxis(left.shape).toList
val shrinkRightAxis = parentShape.broadcastableAxes(right.shape).toList
val shrinkLeftAxis = parentShape.broadcastableAxes(left.shape).toList
List(
parentGrad.sum(shrinkLeftAxis).reshape(left.shape),
parentGrad.sum(shrinkRightAxis).reshape(right.shape))
@@ -74,8 +74,8 @@ case class Minus[A: Numeric](left: Expr[A], right: Expr[A]) extends Expr[A] {
current: Expr[A],
parentGrad: Expr[R]): Seq[Expr[R]] = {
val parentShape = parentGrad.shape
val shrinkLeftAxis = parentShape.broadcastableAxis(left.shape).toList
val shrinkRightAxis = parentShape.broadcastableAxis(right.shape).toList
val shrinkLeftAxis = parentShape.broadcastableAxes(left.shape).toList
val shrinkRightAxis = parentShape.broadcastableAxes(right.shape).toList
List(
parentGrad.sum(shrinkLeftAxis).reshape(left.shape),
-parentGrad.sum(shrinkRightAxis).reshape(right.shape))
@@ -111,8 +111,8 @@ case class Multiply[A: Numeric] private (left: Expr[A], right: Expr[A]) extends
current: Expr[A],
parentGrad: Expr[R]): Seq[Expr[R]] = {
val parentShape = parentGrad.shape
val shrinkRightAxis = parentShape.broadcastableAxis(right.shape).toList
val shrinkLeftAxis = parentShape.broadcastableAxis(left.shape).toList
val shrinkRightAxis = parentShape.broadcastableAxes(right.shape).toList
val shrinkLeftAxis = parentShape.broadcastableAxes(left.shape).toList
List(
(right.cast[R] * parentGrad).sum(shrinkLeftAxis).reshape(left.shape),
(left.cast[R] * parentGrad).sum(shrinkRightAxis).reshape(right.shape))
@@ -137,7 +137,7 @@ case class Pow[A: Numeric](expr: Expr[A], exponent: Expr[Float]) extends Expr[A]
}
}

case class Sqrt[A: Numeric](expr: Expr[A]) extends Expr[A] {
case class Sqrt[A: Numeric](expr: Expr[A]) extends Expr[A] { self =>
override def name: String = "Sqrt"
override def tpe: Option[TensorType[A]] = Some(TensorType[A])
override def shape: Shape = expr.shape
@@ -147,12 +147,46 @@ case class Sqrt[A: Numeric](expr: Expr[A]) extends Expr[A] {
override def calc[R: Floating](
current: Expr[A],
parentGrad: Expr[R]): Seq[Expr[R]] = {
val local = (expr.cast[R] ^ -0.5f) * 0.5f.const.cast[R]
List(local * parentGrad)
// val local = (expr.cast[R] ^ -0.5f) * 0.5f.const.cast[R]
// List(local * parentGrad)
List(SqrtGrad(self.cast[R], parentGrad))
}
}
}

case class SqrtGrad[A: Numeric](sqrt: Expr[A], parentGrad: Expr[A]) extends Expr[A] {
override def name: String = "SqrtGrad"
override def tpe: Option[TensorType[A]] = Some(TensorType[A])
override def shape: Shape = sqrt.shape
override def inputs: Seq[Expr[_]] = Seq(sqrt, parentGrad)
override def compiler: Compiler[A] = DefaultCompiler[A]()
}

case class Rsqrt[A: Numeric](expr: Expr[A]) extends Expr[A] { self =>
override def name: String = "Rsqrt"
override def tpe: Option[TensorType[A]] = Some(TensorType[A])
override def shape: Shape = expr.shape
override def inputs: Seq[Expr[_]] = Seq(expr)
override def compiler: Compiler[A] = DefaultCompiler[A]()
override def localGrad: Grad[A] = new Grad[A] {
override def calc[R: Floating](
current: Expr[A],
parentGrad: Expr[R]): Seq[Expr[R]] = {
// val local = (expr.cast[R] ^ -1.5f) * -0.5f.const.cast[R]
// List(local * parentGrad)
List(RsqrtGrad(self.cast[R], parentGrad))
}
}
}

case class RsqrtGrad[A: Numeric](rsqrt: Expr[A], parentGrad: Expr[A]) extends Expr[A] {
override def name: String = "RsqrtGrad"
override def tpe: Option[TensorType[A]] = Some(TensorType[A])
override def shape: Shape = rsqrt.shape
override def inputs: Seq[Expr[_]] = Seq(rsqrt, parentGrad)
override def compiler: Compiler[A] = DefaultCompiler[A]()
}
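
Rsqrt follows the same pattern: with y = rsqrt(x) = x^(-0.5), the derivative is -0.5 * x^(-1.5) = -0.5 * y^3, so the gradient again needs only the forward output (TensorFlow's RsqrtGrad(y, dy) computes -0.5 * dy * y^3). A sketch under the same assumptions:

    val y = x.rsqrt
    val byPow     = (x ^ -1.5f) * (-0.5f).const // old inline gradient (commented out above)
    val byForward = (-0.5f).const * y * y * y   // equivalent, from the forward result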

case class Exp[A: Numeric](expr: Expr[A]) extends Expr[A] {
override def name: String = "Exp"
override def tpe: Option[TensorType[A]] = Some(TensorType[A])
@@ -182,8 +216,8 @@ case class Div[A: Numeric](left: Expr[A], right: Expr[A]) extends Expr[A] {
current: Expr[A],
parentGrad: Expr[R]): Seq[Expr[R]] = {
val parentShape = parentGrad.shape
val shrinkRightAxis = parentShape.broadcastableAxis(right.shape).toList
val shrinkLeftAxis = parentShape.broadcastableAxis(left.shape).toList
val shrinkRightAxis = parentShape.broadcastableAxes(right.shape).toList
val shrinkLeftAxis = parentShape.broadcastableAxes(left.shape).toList
List(
(parentGrad / right.cast[R]).sum(shrinkLeftAxis).reshape(left.shape),
(-left.cast[R] * parentGrad / right.sqr.cast[R])
@@ -452,6 +486,8 @@ trait AllKernels {

def sqrt[A: Numeric](expr: Expr[A]): Expr[A] = Sqrt(expr)

def rsqrt[A: Numeric](expr: Expr[A]): Expr[A] = Rsqrt(expr)

def sqrtZeroSafe[A: Numeric](out: Expr[A], epsilon: Expr[A]): Expr[A] =
sqrt(plus(out, epsilon))

@@ -466,6 +502,20 @@ trait AllKernels {
keepDims: Boolean = false): Expr[A] = Mean(expr, axis, keepDims)
def mean[A: Numeric](expr: Expr[A]): Expr[A] = mean(expr, 0 until expr.rank)

def moments[A: Numeric](
expr: Expr[A],
axis: Seq[Int],
keepDims: Boolean = false): (Expr[A], Expr[A]) = {
val m = mean(expr, axis, keepDims)
// try squared_difference, which has an optimized kernel op
val v = mean((expr - m).sqr, axis, keepDims)
(m, v)
}

def moments[A: Numeric](
expr: Expr[A]): (Expr[A], Expr[A]) =
moments(expr, 0 until expr.rank)
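
A usage sketch (hypothetical shapes, not from the diff): reducing a (batch, features) input over axis 0 yields the per-feature statistics batch norm needs:

    // given x: Expr[Float] of shape (batch, features)
    val (m, v)   = moments(x, axis = Seq(0)) // m and v have shape (features)
    val (gm, gv) = moments(x)                // scalar statistics over all axes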

def max[A: TensorType, C](left: Expr[A], right: C)(implicit c: Convertible[C, Expr[A]]): Expr[A] =
Max(left, c.convert(right))

@@ -597,6 +647,12 @@ object kernels extends AllKernels {
*/
def sqr: Expr[A] = pow(2.0f)

/** Computes the reciprocal (inverse) of the square root of x element-wise: `1 / sqrt(x)`
*
* @return `tensor ^ -0.5`
*/
def rsqrt: Expr[A] = f.rsqrt(expr)

/** Returns square root of the given tensor
*
* {{{Tensor.vector(1.0f, 4.0f, 9.0f).const.sqrt.eval should be(Tensor.vector(1.0f, 2.0f, 3.0f))}}}
@@ -676,6 +732,23 @@ object kernels extends AllKernels {
*/
def mean: Expr[A] = f.mean(expr)

/** Computes the mean and variance across dimensions of a tensor.
*
* Reduces `(mean, variance)` along the dimensions given in `axis`.
* Unless `keepDims` is true, the rank of the tensor is reduced by 1 for each entry in `axis`.
*
* @param axis axes to reduce along
* @param keepDims whether to retain reduced dimensions with size 1
* @return tensors `(mean, variance)`
*/
def moments(axis: Seq[Int], keepDims: Boolean = false): (Expr[A], Expr[A]) =
f.moments(expr, axis, keepDims)

/** Computes the mean and variance across all dimensions of a tensor.
*
* @return tensors `(mean, variance)`
*/
def moments: (Expr[A], Expr[A]) = f.moments(expr)

/** Shuffle dimensions of `out` according to a permutation.
*
* {{{
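
Taken together, moments and rsqrt are the building blocks behind the fused batch norm this PR is named for. The batch-norm code itself is not shown in this excerpt, but the normalization it fuses can be sketched from the kernels above (x, gamma, beta and epsilon are assumed placeholder expressions):

    // y = gamma * (x - mean) / sqrt(variance + epsilon) + beta
    val (mean, variance) = x.moments(axis = Seq(0), keepDims = true)
    val y = gamma * (x - mean) * (variance + epsilon).rsqrt + beta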