Skip to content

[SR-16094] [AutoDiff] Incorrect gradients produced from certain generic functions  #58353

@BradLarson

Description

@BradLarson
Previous ID SR-16094
Radar None
Original Reporter @BradLarson
Type Bug
Additional Detail from JIRA
Votes 0
Component/s Compiler
Labels Bug, AutoDiff
Assignee None
Priority Medium

md5: 12d02582042f49af0b93f5ec66d6fb2f

Issue Description:

We've noticed that certain generic functions produce different gradients than the same functions implemented with concrete types. For example, the following file:

import _Differentiation
import Foundation

/// Conforms `Dictionary` to `Differentiable` when its values are differentiable.
/// The tangent of a dictionary is a dictionary of per-key value tangents.
extension Dictionary: Differentiable where Value: Differentiable {
    /// Tangent representation: one `Value.TangentVector` per key.
    public typealias TangentVector = [Key: Value.TangentVector]
    /// Moves the value stored under every key of `direction` by that key's tangent component.
    /// - Parameter direction: Per-key tangents to apply; each key must already exist in `self`.
    public mutating func move(by direction: TangentVector) {
        for (componentKey, componentDirection) in direction {
            // Passed only as the subscript's `default:` argument, so it runs when
            // `componentKey` is absent from `self` and traps instead of inserting a value.
            func fatalMissingComponent() -> Value {
                fatalError("missing component \(componentKey) in moved Dictionary")
            }
            self[componentKey, default: fatalMissingComponent()].move(by: componentDirection)
        }
    }
    
    /// A closure that builds a zero tangent containing an entry for every key of this dictionary.
    public var zeroTangentVectorInitializer: () -> TangentVector {
        let listOfKeys = self.keys // capturing only what's needed, not the entire self, in order to not waste memory
        // NOTE(review): `Dictionary.Keys` is a view type; it may still keep the dictionary's
        // storage alive while captured — confirm before relying on the memory claim above.
        func initializer() -> Self.TangentVector {
            // Map each captured key to the zero tangent of the value type.
            return listOfKeys.reduce(into: [Key: Value.TangentVector]()) { $0[$1] = Value.TangentVector.zero }
        }
        return initializer
    }
}

/// Key-wise additive arithmetic for dictionaries whose values support it.
/// A key that is present in only one operand is treated as if the other operand
/// held a zero value for it.
extension Dictionary: AdditiveArithmetic where Value: AdditiveArithmetic {
    /// Key-wise sum: values under shared keys are added; other entries are carried over.
    public static func + (_ lhs: Self, _ rhs: Self) -> Self {
        return lhs.merging(rhs) { left, right in left + right }
    }

    /// Key-wise difference, computed as `lhs + (0 - rhs)` so that keys found only in
    /// `rhs` end up with the negated value.
    public static func - (_ lhs: Self, _ rhs: Self) -> Self {
        let negated = rhs.mapValues { element in Value.zero - element }
        return lhs.merging(negated) { left, right in left + right }
    }

    /// The additive identity: the empty dictionary.
    public static var zero: Self {
        return [:]
    }
}

/// Custom reverse-mode derivative for reading a dictionary element.
extension Dictionary where Value: Differentiable {
    // get
    /// Derivative registered for `subscript(_:)` through `@derivative(of:)`.
    /// - Parameter key: The key being looked up.
    /// - Returns: `self[key]` together with a pullback that converts the tangent of that
    ///   optional value into a dictionary tangent.
    @usableFromInline
    @derivative(of: subscript(_:))
    func vjpSubscriptGet(key: Key) -> (value: Value?, pullback: (Optional<Value>.TangentVector) -> Dictionary<Key, Value>.TangentVector) {
        // When adding two dictionaries, nil values are equivalent to zeroes, so there is no need to manually zero-out
        // every key's value. Instead, it is faster to create a dictionary with the single non-zero entry.
        return (self[key], { v in
            // `v.value` is the wrapped tangent of the optional; non-nil means there is
            // something to propagate back to `key`.
            if let value = v.value {
                return [key: value]
            }
            else {
                // Nothing to propagate: return the dictionary tangent's `.zero`.
                return .zero
            }
        })
    }
 }

/// A differentiable element setter and its custom reverse-mode derivative.
public extension Dictionary where Value: Differentiable {
    /// Stores `newValue` under `key`; a named, differentiable wrapper around `self[key] = newValue`.
    /// - Parameters:
    ///   - key: The key to write.
    ///   - newValue: The value to store.
    @differentiable(reverse)
    mutating func set(_ key: Key, to newValue: Value) {
        self[key] = newValue
    }

    /// Derivative registered for `set` through `@derivative(of:)`.
    /// - Parameters:
    ///   - key: The key to write.
    ///   - newValue: The value to store.
    /// - Returns: `()` as the value, plus a pullback that receives the dictionary tangent `inout`,
    ///   zeroes its entry for `key`, and returns the tangent that was stored there (the tangent
    ///   with respect to `newValue`).
    @derivative(of: set)
    mutating func vjpUpdated(_ key: Key, to newValue: Value) -> (value: Void, pullback: (inout TangentVector) -> (Value.TangentVector)) {
        // Forward pass: perform the actual mutation.
        self.set(key, to: newValue)
        
        // Captured after the mutation, so both include `key`.
        let forwardCount = self.count
        let forwardKeys = self.keys // may be heavy to capture all of these, not sure how to do without them though
        
        return ((), { v in
            // manual zero tangent initialization
            // NOTE(review): when the incoming tangent has fewer entries than the forward
            // dictionary, it is replaced wholesale by an all-zero tangent, so any entries it
            // already carried are discarded — confirm this is intended.
            if v.count < forwardCount {
                v = Self.TangentVector()
                forwardKeys.forEach { v[$0] = .zero }
            }
            
            // The overwritten slot no longer contributes to the dictionary's tangent, so its
            // entry is reset to zero and the previous entry is handed back for `newValue`.
            if let dElement = v[key] {
                v[key] = .zero
                return dElement
            }
            else { // should this fail?
                v[key] = .zero
                return .zero
            }
        })
    }
}


/// Concrete-typed (`Double?`) variant: for every key of `another` that also appears in
/// `newValues`, overwrites the entry in `another` with the one from `newValues`.
/// - Parameters:
///   - newValues: Source of replacement values.
///   - another: Dictionary updated in place through the differentiable `set(_:to:)`.
func dictionaryOperationD(of newValues: [String: Double?], on another: inout [String: Double?]) {
    // The key collection is wrapped in `withoutDerivative(at:)` before being iterated.
    for key in withoutDerivative(at: another.keys) {
        if newValues.keys.contains(key) {
            // The `contains` check above guarantees the lookup succeeds; `!` strips only the
            // lookup optional, leaving the stored `Double?`.
            let value = newValues[key]!
            another.set(key, to: value)
        }
    }
}
/// Differentiable test function built on the concrete-typed `dictionaryOperationD`.
/// Copies `dict`, overwrites matching keys from `newValues`, and returns
/// `1.0 * s1 + 2.0 * s2 + 3.0 * s3` of the result.
/// Force-unwraps both the lookup and the stored optional for "s1", "s2" and "s3".
@differentiable(reverse)
func testFunctionD(newValues: [String: Double?], dict: [String: Double?]) -> Double {
    var newDict = dict
    dictionaryOperationD(of: newValues, on: &newDict)
    return 1.0 * newDict["s1"]!! + 2.0 * newDict["s2"]!! + 3.0 * newDict["s3"]!!
}


/// Generic variant of `dictionaryOperationD`: the same body, with the value type abstracted
/// as `DataType: Differentiable`. For every key of `another` that also appears in
/// `newValues`, overwrites the entry in `another` with the one from `newValues`.
/// - Parameters:
///   - newValues: Source of replacement values.
///   - another: Dictionary updated in place through the differentiable `set(_:to:)`.
func dictionaryOperation<DataType>(of newValues: [String: DataType?], on another: inout [String: DataType?])
where DataType: Differentiable {
    // The key collection is wrapped in `withoutDerivative(at:)` before being iterated.
    for key in withoutDerivative(at: another.keys) {
        if newValues.keys.contains(key) {
            // The `contains` check above guarantees the lookup succeeds; `!` strips only the
            // lookup optional, leaving the stored `DataType?`.
            let value = newValues[key]!
            another.set(key, to: value)
        }
    }
}

/// Differentiable test function built on the generic `dictionaryOperation`.
/// Same body as `testFunctionD` apart from the helper it calls: copies `dict`, overwrites
/// matching keys from `newValues`, and returns `1.0 * s1 + 2.0 * s2 + 3.0 * s3` of the result.
/// Force-unwraps both the lookup and the stored optional for "s1", "s2" and "s3".
@differentiable(reverse)
func testFunction(newValues: [String: Double?], dict: [String: Double?]) -> Double {
    var newDict = dict
    dictionaryOperation(of: newValues, on: &newDict)
    return 1.0 * newDict["s1"]!! + 2.0 * newDict["s2"]!! + 3.0 * newDict["s3"]!!
}

/// Generic variant that returns the updated dictionary instead of mutating an `inout` argument.
/// For every key of `another` that also appears in `newValues`, the returned copy holds the
/// entry from `newValues`.
/// - Parameters:
///   - newValues: Source of replacement values.
///   - another: Dictionary whose local copy is updated and returned.
/// - Returns: The updated copy of `another`.
func dictionaryOperationNoInout<DataType>(of newValues: [String: DataType?], on another: [String: DataType?])-> [String: DataType?]
where DataType: Differentiable {
    // Shadow the parameter with a mutable local copy.
    var another = another
    // The key collection is wrapped in `withoutDerivative(at:)` before being iterated.
    for key in withoutDerivative(at: another.keys) {
        if newValues.keys.contains(key) {
            // The `contains` check above guarantees the lookup succeeds; `!` strips only the
            // lookup optional, leaving the stored `DataType?`.
            let value = newValues[key]!
            another.set(key, to: value)
        }
    }
    return another
}

/// Differentiable test function built on the generic, non-`inout` helper
/// `dictionaryOperationNoInout`. Returns `1.0 * s1 + 2.0 * s2 + 3.0 * s3` of the updated
/// dictionary, force-unwrapping both the lookup and the stored optional for each key.
@differentiable(reverse)
func testFunctionNoInout(newValues: [String: Double?], dict: [String: Double?]) -> Double {
    let newDict = dictionaryOperationNoInout(of: newValues, on: dict)
    return 1.0 * newDict["s1"]!! + 2.0 * newDict["s2"]!! + 3.0 * newDict["s3"]!!
}


// Driver: evaluates the gradient of each test function at the same point. The first
// dictionary is passed as `newValues`, the second as `dict`; `.0` selects the gradient
// with respect to the first argument (`newValues`).
let answerConcreteType = gradient(
    at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
    ["s1": 0.0, "s2": nil, "s3": nil],
    of: testFunctionD).0
// Same inputs, going through the generic `inout` helper.
let answerGenericType = gradient(
    at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
    ["s1": 0.0, "s2": nil, "s3": nil],
    of: testFunction).0
// Same inputs, going through the generic helper that returns its result.
let answerGenericTypeNoInout = gradient(
    at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
    ["s1": 0.0, "s2": nil, "s3": nil],
    of: testFunctionNoInout).0

// The expected gradient is hard-coded in the first string; the other lines print the
// computed values.
print("Answer expected: [\"s1\": 1.0, \"s2\": 2.0, \"s3\": 3.0]")
print("Answer using concrete type: ", answerConcreteType)
print("Answer using generic type: ", answerGenericType)
print("Answer using generic type without inout: ", answerGenericTypeNoInout)

prints out

Answer expected: ["s1": 1.0, "s2": 2.0, "s3": 3.0]
Answer using concrete type:  ["s1": 1.0, "s3": 3.0, "s2": 2.0]
Answer using generic type:  ["s3": 4.0, "s2": 6.0, "s1": 1.0]
Answer using generic type without inout:  ["s1": 6.0, "s3": 3.0, "s2": 5.0]

The values for "s1", "s2", and "s3" differ between the generic functions and the one using concrete types, even though the logic of the differentiable functions is the same.

A similar case fails in a different way:

import _Differentiation
import Foundation

/// Conforms `Dictionary` to `Differentiable` when its values are differentiable.
/// The tangent of a dictionary is a dictionary of per-key value tangents.
extension Dictionary: Differentiable where Value: Differentiable {
    /// Tangent representation: one `Value.TangentVector` per key.
    public typealias TangentVector = [Key: Value.TangentVector]
    /// Moves the value stored under every key of `direction` by that key's tangent component.
    /// - Parameter direction: Per-key tangents to apply; each key must already exist in `self`.
    public mutating func move(by direction: TangentVector) {
        for (componentKey, componentDirection) in direction {
            // Passed only as the subscript's `default:` argument, so it runs when
            // `componentKey` is absent from `self` and traps instead of inserting a value.
            func fatalMissingComponent() -> Value {
                fatalError("missing component \(componentKey) in moved Dictionary")
            }
            self[componentKey, default: fatalMissingComponent()].move(by: componentDirection)
        }
    }
    
    /// A closure that builds a zero tangent containing an entry for every key of this dictionary.
    public var zeroTangentVectorInitializer: () -> TangentVector {
        let listOfKeys = self.keys // capturing only what's needed, not the entire self, in order to not waste memory
        // NOTE(review): `Dictionary.Keys` is a view type; it may still keep the dictionary's
        // storage alive while captured — confirm before relying on the memory claim above.
        func initializer() -> Self.TangentVector {
            // Map each captured key to the zero tangent of the value type.
            return listOfKeys.reduce(into: [Key: Value.TangentVector]()) { $0[$1] = Value.TangentVector.zero }
        }
        return initializer
    }
}

/// Key-wise additive arithmetic for dictionaries whose values support it.
/// A key that is present in only one operand is treated as if the other operand
/// held a zero value for it.
extension Dictionary: AdditiveArithmetic where Value: AdditiveArithmetic {
    /// Key-wise sum: values under shared keys are added; other entries are carried over.
    public static func + (_ lhs: Self, _ rhs: Self) -> Self {
        return lhs.merging(rhs) { left, right in left + right }
    }

    /// Key-wise difference, computed as `lhs + (0 - rhs)` so that keys found only in
    /// `rhs` end up with the negated value.
    public static func - (_ lhs: Self, _ rhs: Self) -> Self {
        let negated = rhs.mapValues { element in Value.zero - element }
        return lhs.merging(negated) { left, right in left + right }
    }

    /// The additive identity: the empty dictionary.
    public static var zero: Self {
        return [:]
    }
}

/// Custom reverse-mode derivative for reading a dictionary element.
extension Dictionary where Value: Differentiable {
    // get
    /// Derivative registered for `subscript(_:)` through `@derivative(of:)`.
    /// - Parameter key: The key being looked up.
    /// - Returns: `self[key]` together with a pullback that converts the tangent of that
    ///   optional value into a dictionary tangent.
    @usableFromInline
    @derivative(of: subscript(_:))
    func vjpSubscriptGet(key: Key) -> (value: Value?, pullback: (Optional<Value>.TangentVector) -> Dictionary<Key, Value>.TangentVector) {
        // When adding two dictionaries, nil values are equivalent to zeroes, so there is no need to manually zero-out
        // every key's value. Instead, it is faster to create a dictionary with the single non-zero entry.
        return (self[key], { v in
            // `v.value` is the wrapped tangent of the optional; non-nil means there is
            // something to propagate back to `key`.
            if let value = v.value {
                return [key: value]
            }
            else {
                // Nothing to propagate: return the dictionary tangent's `.zero`.
                return .zero
            }
        })
    }
 }

/// A differentiable element setter and its custom reverse-mode derivative.
public extension Dictionary where Value: Differentiable {
    /// Stores `newValue` under `key`; a named, differentiable wrapper around `self[key] = newValue`.
    /// - Parameters:
    ///   - key: The key to write.
    ///   - newValue: The value to store.
    @differentiable(reverse)
    mutating func set(_ key: Key, to newValue: Value) {
        self[key] = newValue
    }

    /// Derivative registered for `set` through `@derivative(of:)`.
    /// - Parameters:
    ///   - key: The key to write.
    ///   - newValue: The value to store.
    /// - Returns: `()` as the value, plus a pullback that receives the dictionary tangent `inout`,
    ///   zeroes its entry for `key`, and returns the tangent that was stored there (the tangent
    ///   with respect to `newValue`).
    @derivative(of: set)
    mutating func vjpUpdated(_ key: Key, to newValue: Value) -> (value: Void, pullback: (inout TangentVector) -> (Value.TangentVector)) {
        // Forward pass: perform the actual mutation.
        self.set(key, to: newValue)
        
        // Captured after the mutation, so both include `key`.
        let forwardCount = self.count
        let forwardKeys = self.keys // may be heavy to capture all of these, not sure how to do without them though
        
        return ((), { v in
            // manual zero tangent initialization
            // NOTE(review): when the incoming tangent has fewer entries than the forward
            // dictionary, it is replaced wholesale by an all-zero tangent, so any entries it
            // already carried are discarded — confirm this is intended.
            if v.count < forwardCount {
                v = Self.TangentVector()
                forwardKeys.forEach { v[$0] = .zero }
            }
            
            // The overwritten slot no longer contributes to the dictionary's tangent, so its
            // entry is reset to zero and the previous entry is handed back for `newValue`.
            if let dElement = v[key] {
                v[key] = .zero
                return dElement
            }
            else { // should this fail?
                v[key] = .zero
                return .zero
            }
        })
    }
}


/// Concrete-typed (`Double?`) lookup: returns the value stored under `key`, or `nil` when
/// `newValues` has no such key.
/// - Parameters:
///   - newValues: Dictionary to read from.
///   - key: Key to look up.
/// - Returns: The stored `Double?` (which may itself be `nil`), or `nil` if the key is absent.
func getD(from newValues: [String: Double?], at key: String) -> Double? {
    if newValues.keys.contains(key) {
        // The `contains` check guarantees the lookup succeeds; `!` strips only the lookup
        // optional, leaving the stored `Double?`.
        return newValues[key]!
    }
    return nil
}
/// Differentiable test function built on the concrete-typed `getD`.
/// Returns `1.0 * s1 + 2.0 * s2 + 3.0 * s3`, force-unwrapping each looked-up value.
@differentiable(reverse)
func testFunctionD(newValues: [String: Double?]) -> Double {
    return 1.0 * getD(from: newValues, at: "s1")! +
    2.0 * getD(from: newValues, at: "s2")! +
    3.0 * getD(from: newValues, at: "s3")!
}

/// Generic variant of `getD`: the same body, with the value type abstracted as
/// `DataType: Differentiable`. Returns the value stored under `key`, or `nil` when
/// `newValues` has no such key.
/// - Parameters:
///   - newValues: Dictionary to read from.
///   - key: Key to look up.
/// - Returns: The stored `DataType?` (which may itself be `nil`), or `nil` if the key is absent.
func get<DataType>(from newValues: [String: DataType?], at key: String) -> DataType?
where DataType: Differentiable {
    if newValues.keys.contains(key) {
        // The `contains` check guarantees the lookup succeeds; `!` strips only the lookup
        // optional, leaving the stored `DataType?`.
        return newValues[key]!
    }
    return nil
}
/// Differentiable test function built on the generic `get`.
/// Same body as `testFunctionD` apart from the helper it calls: returns
/// `1.0 * s1 + 2.0 * s2 + 3.0 * s3`, force-unwrapping each looked-up value.
@differentiable(reverse)
func testFunction(newValues: [String: Double?]) -> Double {
    return 1.0 * get(from: newValues, at: "s1")! +
    2.0 * get(from: newValues, at: "s2")! +
    3.0 * get(from: newValues, at: "s3")!
}

// Driver: evaluates the gradient of both test functions at the same point and flattens each
// gradient dictionary into an array ordered by key.
// NOTE(review): `answerExpected` is not referenced by any of the statements below; the
// expected value is printed from a hard-coded, dictionary-shaped string instead, while the
// computed answers are printed as arrays.
let answerExpected = [1.0, 2.0, 3.0]
// Sort the (key, tangent) pairs by key, then take each tangent's wrapped `value`,
// dropping entries where it is nil.
let answerConcreteType = gradient(
    at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
    of: testFunctionD).sorted(by: { $0.key < $1.key }).compactMap(\.value.value)
// Same input and post-processing, going through the generic `get` helper.
let answerGenericType = gradient(
    at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
    of: testFunction).sorted(by: { $0.key < $1.key }).compactMap(\.value.value)

print("Answer expected: [\"s1\": 1.0, \"s2\": 2.0, \"s3\": 3.0]")
print("Answer using concrete type: ", answerConcreteType)
print("Answer using generic type: ", answerGenericType)

with a result of

Answer expected: ["s1": 1.0, "s2": 2.0, "s3": 3.0]
Answer using concrete type:  [0.0, 0.0, 0.0]
Answer using generic type:  [1.0, 2.0, 3.0]

The above can probably be simplified down so that the Differentiable Dictionary type isn't necessary to reproduce, but this is one of the cases where we've observed this behavior.

Metadata

Metadata

Assignees

No one assigned

    Labels

    AutoDiff · bug (A deviation from expected or documented behavior. Also: expected but undesirable behavior.) · compiler (The Swift compiler itself)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions