-
Notifications
You must be signed in to change notification settings - Fork 10.5k
Closed
Closed
Copy link
Labels
AutoDiff · bug (A deviation from expected or documented behavior. Also: expected but undesirable behavior.) · compiler (The Swift compiler itself)
Description
Previous ID | SR-16094 |
Radar | None |
Original Reporter | @BradLarson |
Type | Bug |
Additional Detail from JIRA
Votes | 0 |
Component/s | Compiler |
Labels | Bug, AutoDiff |
Assignee | None |
Priority | Medium |
md5: 12d02582042f49af0b93f5ec66d6fb2f
Issue Description:
We've noticed that some generic functions, written in a particular way, produce different gradients than the equivalent functions written with concrete types. For example, consider the following file:
import _Differentiation
import Foundation
extension Dictionary: Differentiable where Value: Differentiable {
    /// A tangent is a dictionary with (a subset of) the same keys, holding each value's tangent.
    public typealias TangentVector = [Key: Value.TangentVector]

    /// Moves each value named in `direction` along its tangent component.
    ///
    /// Keys absent from `direction` are left untouched. Traps if `direction` names a key that
    /// this dictionary does not contain.
    public mutating func move(by direction: TangentVector) {
        for (key, delta) in direction {
            guard self[key] != nil else {
                fatalError("missing component \(key) in moved Dictionary")
            }
            self[key]!.move(by: delta)
        }
    }

    /// Returns a closure that builds a tangent holding `.zero` for every key present right now.
    public var zeroTangentVectorInitializer: () -> TangentVector {
        // NOTE(review): `Keys` is a view onto this dictionary, so capturing it likely keeps the
        // whole storage (values included) alive, not only the keys. Confirm if memory matters.
        let capturedKeys = keys
        return {
            var zeros: TangentVector = [:]
            for key in capturedKeys {
                zeros[key] = Value.TangentVector.zero
            }
            return zeros
        }
    }
}
extension Dictionary: AdditiveArithmetic where Value: AdditiveArithmetic {
    /// Key-wise addition. A key present on only one side keeps that side's value unchanged.
    public static func + (_ lhs: Self, _ rhs: Self) -> Self {
        var sum = lhs
        for (key, rhsValue) in rhs {
            if let lhsValue = sum[key] {
                sum[key] = lhsValue + rhsValue
            } else {
                sum[key] = rhsValue
            }
        }
        return sum
    }

    /// Key-wise subtraction, computed as `lhs + (.zero - rhs)`.
    ///
    /// A key present only in `rhs` therefore yields its negation.
    public static func - (_ lhs: Self, _ rhs: Self) -> Self {
        let negated = rhs.mapValues { Value.zero - $0 }
        return lhs + negated
    }

    /// The empty dictionary: `+` and `-` treat missing keys as contributing nothing.
    ///
    /// NOTE(review): `x - x` yields explicit zero entries rather than `[:]`, so `x - x == .zero`
    /// is false for any non-empty `x` under `Dictionary`'s `==`.
    public static var zero: Self { Self() }
}
extension Dictionary where Value: Differentiable {
    /// Custom reverse-mode derivative (VJP) for the key-based getter `self[key]`.
    ///
    /// The pullback yields a sparse tangent that holds only `key`, or `.zero` (empty) when the
    /// incoming optional tangent carries no value. Because key-wise `+` treats missing keys as
    /// zero, the other keys never need explicit zero entries, and building a single-entry
    /// dictionary is cheaper.
    @usableFromInline
    @derivative(of: subscript(_:))
    func vjpSubscriptGet(key: Key) -> (value: Value?, pullback: (Optional<Value>.TangentVector) -> Dictionary<Key, Value>.TangentVector) {
        let result = self[key]
        let pullback: (Optional<Value>.TangentVector) -> Dictionary<Key, Value>.TangentVector = { seed in
            guard let component = seed.value else {
                return .zero
            }
            return [key: component]
        }
        return (result, pullback)
    }
}
public extension Dictionary where Value: Differentiable {
    /// Stores `newValue` under `key`, inserting the key if it is not already present.
    @differentiable(reverse)
    mutating func set(_ key: Key, to newValue: Value) {
        self[key] = newValue
    }

    /// Registered reverse-mode derivative (VJP) of `set(_:to:)`.
    ///
    /// Performs the write, then returns a pullback that receives `v`, the adjoint of `self`
    /// after the write. The pullback returns `v[key]` (or `.zero` if absent) as the adjoint of
    /// `newValue`. It then sets `v[key]` to `.zero`, because the value previously stored under
    /// `key` was overwritten and cannot influence anything downstream.
    @derivative(of: set)
    mutating func vjpUpdated(_ key: Key, to newValue: Value) -> (value: Void, pullback: (inout TangentVector) -> (Value.TangentVector)) {
        self.set(key, to: newValue)
        let forwardCount = self.count
        // NOTE(review): `Keys` is a view onto the dictionary and may keep its whole storage
        // (values included) alive for as long as the pullback lives.
        let forwardKeys = self.keys
        return ((), { v in
            // Give every key that exists after the forward write an entry in the adjoint.
            // Only *missing* entries are filled. `v` may be sparse yet already hold non-zero
            // components (e.g. the single-key tangents produced by the subscript pullback), and
            // replacing it wholesale would silently discard those gradients.
            if v.count < forwardCount {
                for forwardKey in forwardKeys where v[forwardKey] == nil {
                    v[forwardKey] = .zero
                }
            }
            // A missing component is equivalent to a zero tangent, so no failure is needed here.
            let dElement = v[key] ?? Value.TangentVector.zero
            v[key] = .zero
            return dElement
        })
    }
}
/// Concrete-type (`Double?`) variant of the update operation.
///
/// For each key already in `another` that also appears in `newValues`, overwrites the entry with
/// the value from `newValues` using the differentiable `set(_:to:)`. Keys that exist only in
/// `newValues` are ignored.
/// NOTE(review): this body is kept textually parallel to `dictionaryOperation` so that the
/// concrete and generic versions stay directly comparable in the reproduction.
func dictionaryOperationD(of newValues: [String: Double?], on another: inout [String: Double?]) {
    // The key sequence is evaluated once, before the loop, and is excluded from differentiation.
    for key in withoutDerivative(at: another.keys) {
        if newValues.keys.contains(key) {
            // Presence was just checked, so this unwrap cannot trap; `value` is itself a `Double?`.
            let value = newValues[key]!
            another.set(key, to: value)
        }
    }
}
/// Concrete-type test function.
///
/// Copies `dict`, overwrites its entries from `newValues` with `dictionaryOperationD`, and
/// returns `1*s1 + 2*s2 + 3*s3` computed from the updated copy.
/// Traps if "s1", "s2" or "s3" is missing from the result or holds `nil`.
@differentiable(reverse)
func testFunctionD(newValues: [String: Double?], dict: [String: Double?]) -> Double {
    var newDict = dict
    dictionaryOperationD(of: newValues, on: &newDict)
    // `!!` unwraps both the dictionary lookup and the stored `Double?`.
    return 1.0 * newDict["s1"]!! + 2.0 * newDict["s2"]!! + 3.0 * newDict["s3"]!!
}
/// Generic counterpart of `dictionaryOperationD`.
///
/// The logic is identical, but the stored value type is abstracted as `DataType?` instead of
/// `Double?`. For each key already in `another` that also appears in `newValues`, the entry is
/// overwritten through the differentiable `set(_:to:)`.
func dictionaryOperation<DataType>(of newValues: [String: DataType?], on another: inout [String: DataType?])
where DataType: Differentiable {
    // The key sequence is evaluated once, before the loop, and is excluded from differentiation.
    for key in withoutDerivative(at: another.keys) {
        if newValues.keys.contains(key) {
            // Presence was just checked, so this unwrap cannot trap; `value` is a `DataType?`.
            let value = newValues[key]!
            another.set(key, to: value)
        }
    }
}
/// Test function that routes the update through the generic `dictionaryOperation`.
///
/// Otherwise identical to `testFunctionD`: it returns `1*s1 + 2*s2 + 3*s3` computed from the
/// updated copy, and traps if any of those entries is missing or `nil`.
@differentiable(reverse)
func testFunction(newValues: [String: Double?], dict: [String: Double?]) -> Double {
    var newDict = dict
    dictionaryOperation(of: newValues, on: &newDict)
    // `!!` unwraps both the dictionary lookup and the stored `Double?`.
    return 1.0 * newDict["s1"]!! + 2.0 * newDict["s2"]!! + 3.0 * newDict["s3"]!!
}
/// Generic variant of the update that avoids `inout`.
///
/// It takes `another` by value, applies the same per-key overwrite as `dictionaryOperation` to
/// a local copy, and returns that copy.
func dictionaryOperationNoInout<DataType>(of newValues: [String: DataType?], on another: [String: DataType?])-> [String: DataType?]
where DataType: Differentiable {
    // Shadow the parameter with a mutable local copy.
    var another = another
    // The key sequence is evaluated once, before the loop, and is excluded from differentiation.
    for key in withoutDerivative(at: another.keys) {
        if newValues.keys.contains(key) {
            // Presence was just checked, so this unwrap cannot trap; `value` is a `DataType?`.
            let value = newValues[key]!
            another.set(key, to: value)
        }
    }
    return another
}
/// Test function that uses the non-`inout` generic `dictionaryOperationNoInout`.
///
/// Returns `1*s1 + 2*s2 + 3*s3` computed from the updated dictionary, and traps if any of those
/// entries is missing or `nil`.
@differentiable(reverse)
func testFunctionNoInout(newValues: [String: Double?], dict: [String: Double?]) -> Double {
    let newDict = dictionaryOperationNoInout(of: newValues, on: dict)
    // `!!` unwraps both the dictionary lookup and the stored `Double?`.
    return 1.0 * newDict["s1"]!! + 2.0 * newDict["s2"]!! + 3.0 * newDict["s3"]!!
}
// Differentiate each variant at the same point. `.0` selects the gradient with respect to the
// first argument (`newValues`). Every key of `dict` also appears in `newValues`, so each output
// entry comes from `newValues`. The expected gradient is therefore each key's coefficient in the
// return expression: s1 -> 1, s2 -> 2, s3 -> 3.
let answerConcreteType = gradient(
    at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
    ["s1": 0.0, "s2": nil, "s3": nil],
    of: testFunctionD).0
let answerGenericType = gradient(
    at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
    ["s1": 0.0, "s2": nil, "s3": nil],
    of: testFunction).0
let answerGenericTypeNoInout = gradient(
    at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
    ["s1": 0.0, "s2": nil, "s3": nil],
    of: testFunctionNoInout).0
// Dictionary print order is not significant; compare the values per key.
print("Answer expected: [\"s1\": 1.0, \"s2\": 2.0, \"s3\": 3.0]")
print("Answer using concrete type: ", answerConcreteType)
print("Answer using generic type: ", answerGenericType)
print("Answer using generic type without inout: ", answerGenericTypeNoInout)
prints:
Answer expected: ["s1": 1.0, "s2": 2.0, "s3": 3.0]
Answer using concrete type: ["s1": 1.0, "s3": 3.0, "s2": 2.0]
Answer using generic type: ["s3": 4.0, "s2": 6.0, "s1": 1.0]
Answer using generic type without inout: ["s1": 6.0, "s3": 3.0, "s2": 5.0]
The gradient values for "s1", "s2", and "s3" produced by the generic functions do not match those produced by the concrete-type function, even though the differentiable functions implement the same logic.
A similar case fails in a different way:
import _Differentiation
import Foundation
extension Dictionary: Differentiable where Value: Differentiable {
    /// A tangent is a dictionary with (a subset of) the same keys, holding each value's tangent.
    public typealias TangentVector = [Key: Value.TangentVector]

    /// Moves each value named in `direction` along its tangent component.
    ///
    /// Keys absent from `direction` are left untouched. Traps if `direction` names a key that
    /// this dictionary does not contain.
    public mutating func move(by direction: TangentVector) {
        for (key, delta) in direction {
            guard self[key] != nil else {
                fatalError("missing component \(key) in moved Dictionary")
            }
            self[key]!.move(by: delta)
        }
    }

    /// Returns a closure that builds a tangent holding `.zero` for every key present right now.
    public var zeroTangentVectorInitializer: () -> TangentVector {
        // NOTE(review): `Keys` is a view onto this dictionary, so capturing it likely keeps the
        // whole storage (values included) alive, not only the keys. Confirm if memory matters.
        let capturedKeys = keys
        return {
            var zeros: TangentVector = [:]
            for key in capturedKeys {
                zeros[key] = Value.TangentVector.zero
            }
            return zeros
        }
    }
}
extension Dictionary: AdditiveArithmetic where Value: AdditiveArithmetic {
    /// Key-wise addition. A key present on only one side keeps that side's value unchanged.
    public static func + (_ lhs: Self, _ rhs: Self) -> Self {
        var sum = lhs
        for (key, rhsValue) in rhs {
            if let lhsValue = sum[key] {
                sum[key] = lhsValue + rhsValue
            } else {
                sum[key] = rhsValue
            }
        }
        return sum
    }

    /// Key-wise subtraction, computed as `lhs + (.zero - rhs)`.
    ///
    /// A key present only in `rhs` therefore yields its negation.
    public static func - (_ lhs: Self, _ rhs: Self) -> Self {
        let negated = rhs.mapValues { Value.zero - $0 }
        return lhs + negated
    }

    /// The empty dictionary: `+` and `-` treat missing keys as contributing nothing.
    ///
    /// NOTE(review): `x - x` yields explicit zero entries rather than `[:]`, so `x - x == .zero`
    /// is false for any non-empty `x` under `Dictionary`'s `==`.
    public static var zero: Self { Self() }
}
extension Dictionary where Value: Differentiable {
    /// Custom reverse-mode derivative (VJP) for the key-based getter `self[key]`.
    ///
    /// The pullback yields a sparse tangent that holds only `key`, or `.zero` (empty) when the
    /// incoming optional tangent carries no value. Because key-wise `+` treats missing keys as
    /// zero, the other keys never need explicit zero entries, and building a single-entry
    /// dictionary is cheaper.
    @usableFromInline
    @derivative(of: subscript(_:))
    func vjpSubscriptGet(key: Key) -> (value: Value?, pullback: (Optional<Value>.TangentVector) -> Dictionary<Key, Value>.TangentVector) {
        let result = self[key]
        let pullback: (Optional<Value>.TangentVector) -> Dictionary<Key, Value>.TangentVector = { seed in
            guard let component = seed.value else {
                return .zero
            }
            return [key: component]
        }
        return (result, pullback)
    }
}
public extension Dictionary where Value: Differentiable {
    /// Stores `newValue` under `key`, inserting the key if it is not already present.
    @differentiable(reverse)
    mutating func set(_ key: Key, to newValue: Value) {
        self[key] = newValue
    }

    /// Registered reverse-mode derivative (VJP) of `set(_:to:)`.
    ///
    /// Performs the write, then returns a pullback that receives `v`, the adjoint of `self`
    /// after the write. The pullback returns `v[key]` (or `.zero` if absent) as the adjoint of
    /// `newValue`. It then sets `v[key]` to `.zero`, because the value previously stored under
    /// `key` was overwritten and cannot influence anything downstream.
    @derivative(of: set)
    mutating func vjpUpdated(_ key: Key, to newValue: Value) -> (value: Void, pullback: (inout TangentVector) -> (Value.TangentVector)) {
        self.set(key, to: newValue)
        let forwardCount = self.count
        // NOTE(review): `Keys` is a view onto the dictionary and may keep its whole storage
        // (values included) alive for as long as the pullback lives.
        let forwardKeys = self.keys
        return ((), { v in
            // Give every key that exists after the forward write an entry in the adjoint.
            // Only *missing* entries are filled. `v` may be sparse yet already hold non-zero
            // components (e.g. the single-key tangents produced by the subscript pullback), and
            // replacing it wholesale would silently discard those gradients.
            if v.count < forwardCount {
                for forwardKey in forwardKeys where v[forwardKey] == nil {
                    v[forwardKey] = .zero
                }
            }
            // A missing component is equivalent to a zero tangent, so no failure is needed here.
            let dElement = v[key] ?? Value.TangentVector.zero
            v[key] = .zero
            return dElement
        })
    }
}
/// Concrete-type (`Double?`) lookup: returns the value stored under `key` in `newValues`, or
/// `nil` when the key is absent.
///
/// The force-unwrap strips only the outer (lookup) optional, so a key that is present but
/// stores `nil` also yields `nil`.
func getD(from newValues: [String: Double?], at key: String) -> Double? {
    if newValues.keys.contains(key) {
        // Presence was just checked, so this unwrap cannot trap.
        return newValues[key]!
    }
    return nil
}
/// Concrete-type test function: returns `1*s1 + 2*s2 + 3*s3` using `getD` lookups.
///
/// Traps if any of "s1", "s2" or "s3" is missing from `newValues` or holds `nil`.
@differentiable(reverse)
func testFunctionD(newValues: [String: Double?]) -> Double {
    return 1.0 * getD(from: newValues, at: "s1")! +
        2.0 * getD(from: newValues, at: "s2")! +
        3.0 * getD(from: newValues, at: "s3")!
}
/// Generic counterpart of `getD`: returns the value stored under `key` in `newValues`, or `nil`
/// when the key is absent.
///
/// The logic is identical to `getD`, with the value type abstracted as `DataType?`.
func get<DataType>(from newValues: [String: DataType?], at key: String) -> DataType?
where DataType: Differentiable {
    if newValues.keys.contains(key) {
        // Presence was just checked, so this unwrap cannot trap.
        return newValues[key]!
    }
    return nil
}
/// Test function that routes lookups through the generic `get`.
///
/// Otherwise identical to `testFunctionD`: it returns `1*s1 + 2*s2 + 3*s3` and traps if any of
/// those entries is missing or `nil`.
@differentiable(reverse)
func testFunction(newValues: [String: Double?]) -> Double {
    return 1.0 * get(from: newValues, at: "s1")! +
        2.0 * get(from: newValues, at: "s2")! +
        3.0 * get(from: newValues, at: "s3")!
}
// Expected gradient components, ordered by key ("s1", "s2", "s3").
// NOTE(review): `answerExpected` is not used below; the first `print` shows the expected values
// as a string literal instead.
let answerExpected = [1.0, 2.0, 3.0]
// Sort each gradient by key, then take each optional tangent's `value`, dropping absent ones,
// so that the results print as plain `[Double]` arrays in key order.
let answerConcreteType = gradient(
    at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
    of: testFunctionD).sorted(by: { $0.key < $1.key }).compactMap(\.value.value)
let answerGenericType = gradient(
    at: ["s1": 10.0, "s2": 20.0, "s3": 30.0],
    of: testFunction).sorted(by: { $0.key < $1.key }).compactMap(\.value.value)
print("Answer expected: [\"s1\": 1.0, \"s2\": 2.0, \"s3\": 3.0]")
print("Answer using concrete type: ", answerConcreteType)
print("Answer using generic type: ", answerGenericType)
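which prints the output below. Note that in this case it is the concrete-type version that produces the wrong (all-zero) gradient, while the generic version is correct: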
Answer expected: ["s1": 1.0, "s2": 2.0, "s3": 3.0]
Answer using concrete type: [0.0, 0.0, 0.0]
Answer using generic type: [1.0, 2.0, 3.0]
The example above can probably be reduced so that the differentiable `Dictionary` conformance isn't needed to reproduce the problem, but this is one of the cases where we've observed this behavior.
Metadata
Metadata
Assignees
Labels
AutoDiff · bug (A deviation from expected or documented behavior. Also: expected but undesirable behavior.) · compiler (The Swift compiler itself)