This repository has been archived by the owner on Jul 1, 2023. It is now read-only.

[Linear Algebra] Triangular solve #599

Merged 16 commits on Feb 6, 2020


awav
Contributor

@awav awav commented Dec 27, 2019

This PR adds a triangularSolve operation with its gradient. Unfortunately, the TF triangularSolve function doesn't do automatic broadcasting, so I had to implement it. For the broadcasting case, I needed a helper that extracts the leading dimensions of the input tensors; in the code, it is called extractLeadingDims. This function can be useful in general.

The PR requires some polishing and is open for discussion.

Thanks!
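
As a rough illustration of the broadcasting problem described above, here is a minimal sketch in plain Swift with no TensorFlow dependency. It computes the broadcast batch shape of two operands shaped `[..., M, M]` and `[..., M, K]`. The name `broadcastBatchDimensions` is hypothetical, and this is not the PR's `extractLeadingDims` implementation; it only assumes NumPy-style broadcasting rules for the leading dimensions.

```swift
/// Hypothetical helper (not from the PR): returns the broadcast of the leading
/// (batch) dimensions of two shapes, ignoring the last two (matrix) dimensions.
/// Returns `nil` if the batch dimensions are incompatible.
func broadcastBatchDimensions(_ lhs: [Int], _ rhs: [Int]) -> [Int]? {
    precondition(lhs.count >= 2 && rhs.count >= 2, "Both shapes must have rank >= 2.")
    // Drop the trailing matrix dimensions, keeping only batch dimensions.
    var a = Array(lhs.dropLast(2))
    var b = Array(rhs.dropLast(2))
    // Left-pad the shorter batch shape with 1s so both have equal rank.
    let rank = max(a.count, b.count)
    a = Array(repeating: 1, count: rank - a.count) + a
    b = Array(repeating: 1, count: rank - b.count) + b
    var result: [Int] = []
    for (x, y) in zip(a, b) {
        if x == y || y == 1 {
            result.append(x)
        } else if x == 1 {
            result.append(y)
        } else {
            return nil  // Sizes differ and neither is 1.
        }
    }
    return result
}

// matrix: [5, 1, 3, 3], rhs: [4, 3, 2] -> batch dimensions [5, 4]
let batchDimensions = broadcastBatchDimensions([5, 1, 3, 3], [4, 3, 2])
```

Both operands would then be broadcast to `batchDimensions` plus their own trailing matrix dimensions before calling a solver that expects matching batch shapes.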

@awav
Contributor Author

awav commented Jan 1, 2020

@dan-zheng, @Shashi456 I updated the PR.

@Shashi456
Contributor

@awav could you add a test for extractLeadingDims?

@awav
Contributor Author

awav commented Jan 4, 2020

@Shashi456, I added tests.

@awav
Contributor Author

awav commented Jan 6, 2020

@dan-zheng, @Shashi456 Hello all, any comments on the code? :)

@ematejska ematejska requested a review from asuhan January 8, 2020 18:09
@asuhan asuhan self-assigned this Jan 8, 2020
Member

@dan-zheng dan-zheng left a comment

Sorry for the late reply!

I would suggest conforming TensorShape to Collection and RangeReplaceableCollection (#609) instead of manually implementing TensorShape.+(_:_:) static methods. Previous comment here.

Would you be interested in exploring this in a separate PR? I'd say avoiding manual definitions for functions that have default implementations is sufficiently desirable that #609 is a blocker for this PR.

@awav
Contributor Author

awav commented Jan 10, 2020

@dan-zheng,

> Sorry for the late reply!

No problem!

> Would you be interested in exploring this in a separate PR? I'd say avoiding manual definitions for functions that have default implementations is sufficiently desirable that #609 is a blocker for this PR.

Yes, I can do it. Although, I need a bit more input on how to do it properly or just a simple example of how it can be done.

@dan-zheng
Member

dan-zheng commented Jan 10, 2020

> Yes, I can do it. Although, I need a bit more input on how to do it properly or just a simple example of how it can be done.

Awesome!

You can start by adding an empty extension TensorShape: Collection {} conformance to Sources/TensorFlow/Core/TensorShape.swift. If you try to compile it, the compiler will produce error messages prompting you to implement the missing Collection protocol requirements. Xcode can even fill in protocol requirement skeletons for you.

Do the same for RangeReplaceableCollection: extension TensorShape: RangeReplaceableCollection { ... }. When that conformance is done, you can use the RangeReplaceableCollection.+(_:_:) default implementation!
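
For illustration, here is a minimal sketch of the two conformances on a stand-in type. `Shape` and its `dimensions` property are hypothetical; the real `TensorShape` in Sources/TensorFlow/Core/TensorShape.swift may be laid out differently, so treat this as an outline of the approach rather than the actual change for #609.

```swift
/// Hypothetical stand-in for `TensorShape`; the real type may differ.
struct Shape: Equatable {
    var dimensions: [Int]
    init(_ dimensions: [Int]) { self.dimensions = dimensions }
}

// `Collection` requirements, forwarded to the underlying array.
extension Shape: Collection {
    var startIndex: Int { dimensions.startIndex }
    var endIndex: Int { dimensions.endIndex }
    func index(after i: Int) -> Int { dimensions.index(after: i) }
    subscript(position: Int) -> Int { dimensions[position] }
}

// `RangeReplaceableCollection` only requires `init()` and
// `replaceSubrange(_:with:)`; `append(_:)`, `+`, and others are
// provided by default implementations.
extension Shape: RangeReplaceableCollection {
    init() { self.dimensions = [] }
    mutating func replaceSubrange<C: Collection>(
        _ subrange: Range<Int>, with newElements: C
    ) where C.Element == Int {
        dimensions.replaceSubrange(subrange, with: newElements)
    }
}

let batchShape = Shape([5, 4])
let matrixShape = Shape([3, 3])
let fullShape = batchShape + matrixShape  // Uses the default `+` implementation.
```

Note that `+` returns a new value, so both operands can be `let` constants even though the default implementation relies on mutating methods internally.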

Please self-assign #609 if you get started.

@awav
Contributor Author

awav commented Jan 10, 2020

@dan-zheng, I cannot assign myself to the task in Github's swift-apis. Guess, I'm not the project member.

Also, according to your link, + uses the append method, that means that TensorShape should be modifiable. Right now, TensorShape is immutable, which is not a bad thing.

@dan-zheng
Member

dan-zheng commented Jan 11, 2020

> @dan-zheng, I cannot assign myself to the task in Github's swift-apis. Guess, I'm not the project member.

Ah, that's right. Please comment on #609 so I can assign you to the issue! Since you're not a maintainer of this GitHub repository, you need to comment on the issue before you can be assigned.

> Also, according to your link, + uses the append method, that means that TensorShape should be modifiable. Right now, TensorShape is immutable, which is not a bad thing.

I think you may be confusing a few concepts. In Swift, we don't typically say that types like TensorShape are mutable or immutable - instead, values are mutable (e.g. var declarations) or immutable (e.g. let declarations).

Types are categorized into value types and reference types. Value types (structs and enums) have value semantics, while reference types (classes) have reference semantics. I'd highly recommend checking out the article linked above to learn more, if you're not familiar with these concepts!

Value types like TensorShape can define mutating methods like RangeReplaceableCollection.append(_:). These methods can only be applied to mutable values (e.g. values declared with var).


Short example of mutating methods:

// A value type.
struct TensorShape {
  mutating func mutate() {}
}

var mutable = TensorShape()
mutable.mutate() // okay

let immutable = TensorShape()
immutable.mutate() // not okay

$ swift test.swift
test.swift:10:11: error: cannot use mutating member on immutable value: 'immutable' is a 'let' constant
immutable.mutate() // not okay
~~~~~~~~~ ^
test.swift:9:1: note: change 'let' to 'var' to make it mutable
let immutable = TensorShape()
^~~
var
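
To make the value-semantics point concrete, here is a small sketch using the standard library's `Array`, which is a value type like the `TensorShape` struct discussed above. It is only an illustration, not code from the PR: copying a value and mutating the copy leaves the original untouched, so mutating methods such as `append(_:)` never affect a `let` constant.

```swift
// `Array` is a value type: assignment copies the value.
let original = [2, 3]
var copy = original
copy.append(4)  // Mutates only `copy`; `original` is unchanged.

// `+` does not mutate either operand; it returns a new value,
// so it works on `let` constants.
let extended = original + [4]
```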

Co-Authored-By: Dan Zheng <danielzheng@google.com>
awav and others added 2 commits January 15, 2020 11:33
@awav awav requested a review from dan-zheng January 16, 2020 16:09
@awav
Contributor Author

awav commented Jan 17, 2020

@dan-zheng, I addressed your comments. The PR should be ready for review/merge. Thanks!

Member

@dan-zheng dan-zheng left a comment

Doc comments could use improvement, but I don't have time to leave detailed comments now. I'll try to address them this weekend.

@awav awav requested a review from dan-zheng January 21, 2020 21:56
) -> Tensor<T> {
precondition(matrix.rank >= 2, "The matrix tensor must have at least rank 2.")
precondition(rhs.rank >= 2, "The rhs tensor must have at least rank 2.")
if matrix.rank < rhs.rank {
Member

I wonder if you used some reference implementation and tests for triangularSolve and _vjpTriangularSolve in this PR? I'm particularly curious about the extractLeadingDims helper function.

Does triangularSolve match tf.linalg.triangular_solve in behavior, or does it have different/additional functionality?

Contributor Author

The trisolve from TensorFlow doesn't work for batched cases; see tensorflow/tensorflow#26204 (comment). I will make the same contribution to TF 2.0 soon as well.

@awav awav requested a review from dan-zheng January 24, 2020 21:23
@awav
Contributor Author

awav commented Jan 28, 2020

@dan-zheng, okay, the good thing is that trisolve supports broadcasting in TensorFlow master now (since yesterday).

@awav
Contributor Author

awav commented Jan 31, 2020

@dan-zheng, @Shashi456, @asuhan Hello everyone, can someone take a look at this PR? I think it is almost there and needs a little push :) Thanks!

@asuhan
Contributor

asuhan commented Jan 31, 2020

@dan-zheng Do we still need to check the trisolve broadcasting semantics?

- Revamp doc comments for `triangularSolve`.
- Add note regarding custom broadcasting support for `triangularSolve`.
  - It should be possible to remove custom broadcasting support in Swift after
    tensorflow/tensorflow@b105944.
- Simplify implementation of `triangularSolve` and `_vjpTriangularSolve`.
- Simplify implementation of `extractLeadingDimensions` utility function.
  - Remove `ignoreLast` parameter. Use `TensorShape.dropLast` on arguments
    instead.
- Simplify and clean up tests.
Member

@dan-zheng dan-zheng left a comment

> Hello everyone, can someone take a look at this PR? I think it is almost there and needs a little push :) Thanks!

So sorry for the delay! It took quite some time to clean up documentation comments and implementations: done in 92da713.

The old doc comments for triangularSolve had room for improvement: the meanings of the parameters and the return value were not clear.

/// Solves optionally batched systems of linear equations with upper or lower triangular
/// matrices by backsubstitution.
/// Returns solution to the system `A x = b`. Shape of return matches `b`.
///
/// - Parameters:
///     - matrix: A batched matrix tensor.
///     - rhs: A batched vector tensor.
///     - lower: Boolean option indicating whether the innermost matrices in
///         matrix are lower or upper triangular. Defaults to true.
///     - adjoint: Boolean option indicating whether to solve with matrix or
///         its (block-wise) adjoint. Defaults to False.
/// - Precondition: `matrix` must be a tensor with shape `[..., M, M]`.
/// - Precondition: `rhs` must be a tensor with shape `[..., M, K]`.
@inlinable
@differentiable
public func triangularSolve<T: TensorFlowFloatingPoint>(
    matrix: Tensor<T>,
    rhs: Tensor<T>,
    lower: Bool = true,
    adjoint: Bool = false
) -> Tensor<T> { ... }

Questions raised by the doc comments:

  • How does the equation A x = b relate to the parameters?
    • It's not clear that the parameter matrix refers to A, and that the parameter rhs refers to b.
  • What is the meaning of the return value of triangularSolve?
    • The doc comment says /// Returns solution to the system `A x = b`, but it's not clear that x is the solution.

The new doc comments are clearer:

/// Returns the solution `x` to the system of linear equations represented by `Ax = b`.
///
/// - Parameters:
///   - matrix: The input triangular coefficient matrix, representing `A` in `Ax = b`.
///   - rhs: Right-hand side values, representing `b` in `Ax = b`.
///   - lower: Whether `matrix` is lower triangular (`true`) or upper triangular (`false`). The
///     default value is `true`.
///   - adjoint: If `true`, solve with the adjoint of `matrix` instead of `matrix`. The default
///     value is `false`.
/// - Returns: The solution `x` to the system of linear equations represented by `Ax = b`.
///   `x` has the same shape as `b`.
/// - Precondition: `matrix` must be a tensor with shape `[..., M, M]`.
/// - Precondition: `rhs` must be a tensor with shape `[..., M, K]`.
@inlinable
@differentiable
public func triangularSolve<T: TensorFlowFloatingPoint>(
    matrix: Tensor<T>,
    rhs: Tensor<T>,
    lower: Bool = true,
    adjoint: Bool = false
) -> Tensor<T> { ... }
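
As a hedged companion to the doc comments above, the following plain-Swift sketch shows what solving `Ax = b` by substitution looks like for a single, unbatched triangular matrix. The function `solveTriangular` and its nested-array representation are assumptions made for illustration; the PR's `triangularSolve` operates on `Tensor` values and delegates to TensorFlow rather than looping in Swift, and this sketch omits the `adjoint` option.

```swift
/// Illustrative reference solver (not the PR's implementation): solves `A x = b`
/// for one triangular `M x M` matrix `a` and an `M x K` right-hand side `b`.
func solveTriangular(_ a: [[Double]], _ b: [[Double]], lower: Bool = true) -> [[Double]] {
    let m = a.count
    precondition(a.allSatisfy { $0.count == m }, "`a` must be square.")
    precondition(b.count == m, "`b` must have M rows.")
    var x = b
    // Forward substitution visits rows top-down; back substitution bottom-up.
    let rows = lower ? Array(0..<m) : Array((0..<m).reversed())
    for i in rows {
        // Columns of row `i` whose unknowns have already been solved.
        let solved = lower ? 0..<i : (i + 1)..<m
        for k in 0..<x[i].count {
            var sum = b[i][k]
            for j in solved { sum -= a[i][j] * x[j][k] }
            x[i][k] = sum / a[i][i]
        }
    }
    return x
}

// Lower-triangular example: A = [[2, 0], [1, 3]], b = [[4], [11]].
let solution = solveTriangular([[2, 0], [1, 3]], [[4], [11]])
// solution[0][0] = 4 / 2 = 2; solution[1][0] = (11 - 1 * 2) / 3 = 3
```

The result has the same shape as `b`, matching the `Returns` note in the doc comment.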

The documentation for tf.linalg.triangular_solve is quite poor.

I adapted documentation from torch.triangular_solve and scipy.linalg.solve_triangular instead.

Writing idiomatic Swift documentation takes time, so attention there would really help accelerate code review. Thanks for your patience!

@dan-zheng
Member

> @dan-zheng Do we still need to check the trisolve broadcasting semantics?

@awav mentioned that batched input support for _Raw.matrixTriangularSolve was added in tensorflow/tensorflow@b105944.

I verified that the behavior is consistent with the Swift implementation added in this patch: https://gist.github.com/dan-zheng/6c1404149aa16a4c7082047ce3890ee2


Currently, our version of tensorflow is pinned at v2.1.0-rc1.

We should be able to remove the custom Swift broadcasting code in func triangularSolve after we:

func triangularSolve should then directly forward to _Raw.matrixTriangularSolve.

@dan-zheng dan-zheng merged commit 94e0765 into tensorflow:master Feb 6, 2020
4 participants