diff --git a/stdlib/public/TensorFlow/Random.swift b/stdlib/public/TensorFlow/Random.swift index 74108ad409be2..224f39c3f190c 100644 --- a/stdlib/public/TensorFlow/Random.swift +++ b/stdlib/public/TensorFlow/Random.swift @@ -171,3 +171,131 @@ public final class NormalDistribution return mean + standardDeviation * normal01 } } + +@_fixed_layout +public final class BetaDistribution { + public let alpha: Float + public let beta: Float + private let uniformDistribution = UniformFloatingPointDistribution() + + public init(alpha: Float = 0, beta: Float = 1) { + self.alpha = alpha + self.beta = beta + } + + public func next(using rng: inout G) -> Float { + // Generate a sample using Cheng's sampling algorithm from: + // R. C. H. Cheng, "Generating beta variates with nonintegral shape + // parameters.". Communications of the ACM, 21, 317-322, 1978. + let a = min(alpha, beta) + let b = max(alpha, beta) + if a > 1 { + return BetaDistribution.chengsAlgorithmBB(alpha, a, b, using: &rng) + } else { + return BetaDistribution.chengsAlgorithmBC(alpha, b, a, using: &rng) + } + } + + /// Returns one sample from a Beta(alpha, beta) distribution using Cheng's BB + /// algorithm, when both alpha and beta are greater than 1. + /// + /// - Parameters: + /// - alpha: First Beta distribution shape parameter. + /// - a: `min(alpha, beta)`. + /// - b: `max(alpha, beta)`. + /// - rng: Random number generator. + /// + /// - Returns: Sample obtained using Cheng's BB algorithm. + private static func chengsAlgorithmBB( + _ alpha0: Float, + _ a: Float, + _ b: Float, + using rng: inout G + ) -> Float { + let alpha = a + b + let beta = sqrt((alpha - 2) / (2 * a * b - alpha)) + let gamma = a + 1 / beta + + var r: Float = 0.0 + var w: Float = 0.0 + var t: Float = 0.0 + + repeat { + let u1 = Float.random(in: 0.0...1.0, using: &rng) + let u2 = Float.random(in: 0.0...1.0, using: &rng) + let v = beta * (log(u1) - log1p(-u1)) + r = gamma * v - 1.3862944 + let z = u1 * u1 * u2 + w = a * exp(v) + + let s = a + r - w + if s + 2.609438 >= 5 * z { + break + } + + t = log(z) + if s >= t { + break + } + } while r + alpha * (log(alpha) - log(b + w)) < t + + w = min(w, Float.greatestFiniteMagnitude) + return a == alpha0 ? w / (b + w) : b / (b + w) + } + + /// Returns one sample from a Beta(alpha, beta) distribution using Cheng's BC + /// algorithm, when at least one of alpha and beta is less than 1. + /// + /// - Parameters: + /// - alpha: First Beta distribution shape parameter. + /// - a: `max(alpha, beta)`. + /// - b: `min(alpha, beta)`. + /// - rng: Random number generator. + /// + /// - Returns: Sample obtained using Cheng's BB algorithm. + private static func chengsAlgorithmBC( + _ alpha0: Float, + _ a: Float, + _ b: Float, + using rng: inout G + ) -> Float { + let alpha = a + b + let beta = 1 / b + let delta = 1 + a - b + let k1 = delta * (0.0138889 + 0.0416667 * b) / (a * beta - 0.777778) + let k2 = 0.25 + (0.5 + 0.25 / delta) * b + + var w: Float = 0.0 + + while true { + let u1 = Float.random(in: 0.0...1.0, using: &rng) + let u2 = Float.random(in: 0.0...1.0, using: &rng) + let y = u1 * u2 + let z = u1 * y + + if u1 < 0.5 { + if 0.25 * u2 + z - y >= k1 { + continue + } + } else { + if z <= 0.25 { + let v = beta * (log(u1) - log1p(-u1)) + w = a * exp(v) + break + } + if z >= k2 { + continue + } + } + + let v = beta * (log(u1) - log1p(-u1)) + w = a * exp(v) + if alpha * (log(alpha) - log(b + 1) + v) - 1.3862944 >= log(z) { + break + } + } + + w = min(w, Float.greatestFiniteMagnitude) + return a == alpha0 ? w / (b + w) : b / (b + w) + } +}