Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions stdlib/public/TensorFlow/Random.swift
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,131 @@ public final class NormalDistribution<T : BinaryFloatingPoint>
return mean + standardDeviation * normal01
}
}

@_fixed_layout
public final class BetaDistribution {
public let alpha: Float
public let beta: Float
private let uniformDistribution = UniformFloatingPointDistribution<Float>()

public init(alpha: Float = 0, beta: Float = 1) {
self.alpha = alpha
self.beta = beta
}

public func next<G: RandomNumberGenerator>(using rng: inout G) -> Float {
// Generate a sample using Cheng's sampling algorithm from:
// R. C. H. Cheng, "Generating beta variates with nonintegral shape
// parameters.". Communications of the ACM, 21, 317-322, 1978.
let a = min(alpha, beta)
let b = max(alpha, beta)
if a > 1 {
return BetaDistribution.chengsAlgorithmBB(alpha, a, b, using: &rng)
} else {
return BetaDistribution.chengsAlgorithmBC(alpha, b, a, using: &rng)
}
}

/// Returns one sample from a Beta(alpha, beta) distribution using Cheng's BB
/// algorithm, when both alpha and beta are greater than 1.
///
/// - Parameters:
/// - alpha: First Beta distribution shape parameter.
/// - a: `min(alpha, beta)`.
/// - b: `max(alpha, beta)`.
/// - rng: Random number generator.
///
/// - Returns: Sample obtained using Cheng's BB algorithm.
private static func chengsAlgorithmBB<G: RandomNumberGenerator>(
_ alpha0: Float,
_ a: Float,
_ b: Float,
using rng: inout G
) -> Float {
let alpha = a + b
let beta = sqrt((alpha - 2) / (2 * a * b - alpha))
let gamma = a + 1 / beta

var r: Float = 0.0
var w: Float = 0.0
var t: Float = 0.0

repeat {
let u1 = Float.random(in: 0.0...1.0, using: &rng)
let u2 = Float.random(in: 0.0...1.0, using: &rng)
let v = beta * (log(u1) - log1p(-u1))
r = gamma * v - 1.3862944
let z = u1 * u1 * u2
w = a * exp(v)

let s = a + r - w
if s + 2.609438 >= 5 * z {
break
}

t = log(z)
if s >= t {
break
}
} while r + alpha * (log(alpha) - log(b + w)) < t

w = min(w, Float.greatestFiniteMagnitude)
return a == alpha0 ? w / (b + w) : b / (b + w)
}

/// Returns one sample from a Beta(alpha, beta) distribution using Cheng's BC
/// algorithm, when at least one of alpha and beta is less than 1.
///
/// - Parameters:
/// - alpha: First Beta distribution shape parameter.
/// - a: `max(alpha, beta)`.
/// - b: `min(alpha, beta)`.
/// - rng: Random number generator.
///
/// - Returns: Sample obtained using Cheng's BB algorithm.
private static func chengsAlgorithmBC<G: RandomNumberGenerator>(
_ alpha0: Float,
_ a: Float,
_ b: Float,
using rng: inout G
) -> Float {
let alpha = a + b
let beta = 1 / b
let delta = 1 + a - b
let k1 = delta * (0.0138889 + 0.0416667 * b) / (a * beta - 0.777778)
let k2 = 0.25 + (0.5 + 0.25 / delta) * b

var w: Float = 0.0

while true {
let u1 = Float.random(in: 0.0...1.0, using: &rng)
let u2 = Float.random(in: 0.0...1.0, using: &rng)
let y = u1 * u2
let z = u1 * y

if u1 < 0.5 {
if 0.25 * u2 + z - y >= k1 {
continue
}
} else {
if z <= 0.25 {
let v = beta * (log(u1) - log1p(-u1))
w = a * exp(v)
break
}
if z >= k2 {
continue
}
}

let v = beta * (log(u1) - log1p(-u1))
w = a * exp(v)
if alpha * (log(alpha) - log(b + 1) + v) - 1.3862944 >= log(z) {
break
}
}

w = min(w, Float.greatestFiniteMagnitude)
return a == alpha0 ? w / (b + w) : b / (b + w)
}
}