diff --git a/Sources/DeepLearning/Initializers.swift b/Sources/DeepLearning/Initializers.swift index c3c9d14e7..4539b18ea 100644 --- a/Sources/DeepLearning/Initializers.swift +++ b/Sources/DeepLearning/Initializers.swift @@ -129,22 +129,27 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint, Scalar.RawSignificand: FixedWidthInteger { /// Performs Glorot uniform initialization for the specified shape, creating a tensor by /// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`, - /// where limit is `sqrt(6 / (fanIn + fanOut))`. + /// where limit is `sqrt(6 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of + /// input and output features multiplied by the receptive field if present. /// /// - Parameters: /// - shape: The dimensions of the tensor. /// - generator: Random number generator to use. /// init(glorotUniform shape: TensorShape, generator: inout G) { - let fanIn = shape[shape.count - 2] - let fanOut = shape[shape.count - 1] + let spatialDimCount = shape.count - 2 + let receptiveField = shape[0..