-
Notifications
You must be signed in to change notification settings - Fork 133
Conversation
@dan-zheng, @Shashi456 I updated PR |
@awav could you add a test for |
@Shashi456, I added tests. |
@dan-zheng, @Shashi456 Hello all, any comments on the code? :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the late reply!
I would suggest conforming TensorShape
to Collection
and RangeReplaceableCollection
(#609) instead of manually implementing TensorShape.+(_:_:)
static methods. Previous comment here.
Would you be interested in exploring this in a separate PR? I'd say avoiding manual definitions for functions that have default implementations is sufficiently desirable that #609 is a blocker for this PR.
No problem!
Yes, I can do it. Although, I need a bit more input on how to do it properly or just a simple example of how it can be done. |
Awesome! You can start by adding an empty Do the same for Please self-assign #609 if you get started. |
@dan-zheng, I cannot assign myself to the task in Github's swift-apis. Guess, I'm not the project member. Also, according to your link, |
Ah, that's right. Please comment on #609 so I can assign you to the issue! Since you're not a maintainer of this GitHub repository, you need to comment on the issue before you can be assigned.
I think you may be confusing a few concepts. In Swift, we don't typically say that types like Types are categorized into value types and reference types. Value types (structs and enums) have value semantics, while reference types (classes) have reference semantics. I'd highly recommend checking out the article linked above to learn more, if you're not familiar with these concepts! Value types like Short example of // A value type.
struct TensorShape {
mutating func mutate() {}
}
var mutable = TensorShape()
mutable.mutate() // okay
let immutable = TensorShape()
immutable.mutate() // not okay $ swift test.swift
test.swift:10:11: error: cannot use mutating member on immutable value: 'immutable' is a 'let' constant
immutable.mutate() // not okay
~~~~~~~~~ ^
test.swift:9:1: note: change 'let' to 'var' to make it mutable
let immutable = TensorShape()
^~~
var |
Co-Authored-By: Dan Zheng <danielzheng@google.com>
Co-Authored-By: Dan Zheng <danielzheng@google.com>
@dan-zheng, I addressed your comments. The PR should be ready for review/merge. Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doc comments could use improvement, but I don't have time to leave detailed comments now. I'll try to address them this weekend.
Co-Authored-By: Dan Zheng <danielzheng@google.com>
) -> Tensor<T> { | ||
precondition(matrix.rank >= 2, "The matrix tensor must have at least rank 2.") | ||
precondition(rhs.rank >= 2, "The rhs tensor must have at least rank 2.") | ||
if matrix.rank < rhs.rank { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if you used some reference implementation and tests for triangularSolve
and _vjpTriangularSolve
in this PR? I'm particularly curious about the extractLeadingDims
helper function.
Does triangularSolve
match tf.linalg.triangular_solve
in behavior, or does it have different/additional functionality?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The trisolve
from TensorFlow doesn't work for batched cases. tensorflow/tensorflow#26204 (comment). I will make same contribution to the TF2.0 soon as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dan-zheng, the trisolve gradient you can find here https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L607
@dan-zheng, okay, the good thing is that |
@dan-zheng, @Shashi456, @asuhan Hello everyone, can someone take a look at this PR? I think it is almost there and needs a little push :) Thanks! |
@dan-zheng Do we still need to check the |
- Revamp doc comments for `triangularSolve`. - Add note regarding custom broadcasting support for `triangularSolve`. - It should be possible to remove custom broadcasting support in Swift after tensorflow/tensorflow@b105944. - Simplify implementation of `triangularSolve` and `_vjpTriangularSolve`. - Simplify implementation of `extractLeadingDimensions` utility function. - Remove `ignoreLast` parameter. Use `TensorShape.dropLast` on arguments instead. - Simplify and clean up tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello everyone, can someone take a look at this PR? I think it is almost there and needs a little push :) Thanks!
So sorry for the delay! It took quite some time to clean up documentation comments and implementations: done in 92da713.
The old doc comments for triangularSolve
had room for improvement. The meaning of parameters and the return value are not clear.
/// Solves optionally batched systems of linear equations with upper or lower triangular
/// matrices by backsubstitution.
/// Returns solution to the system `A x = b`. Shape of return matches `b`.
///
/// - Parameters:
/// - matrix: A batched matrix tensor.
/// - rhs: A batched vector tensor.
/// - lower: Boolean option indicating whether the innermost matrices in
/// matrix are lower or upper triangular. Defaults to true.
/// - adjoint: Boolean option indicating whether to solve with matrix or
/// its (block-wise) adjoint. Defaults to False.
/// - Precondition: `matrix` must be a tensor with shape `[..., M, M]`.
/// - Precondition: `rhs` must be a tensor with shape `[..., M, K]`.
@inlinable
@differentiable
public func triangularSolve<T: TensorFlowFloatingPoint>(
matrix: Tensor<T>,
rhs: Tensor<T>,
lower: Bool = true,
adjoint: Bool = false
) -> Tensor<T> { ... }
Questions raised by the doc comments:
- How does the equation
A x = b
relate to the parameters?- It's not clear that the parameter
matrix
refers toA
, and that the parameterrhs
refers tob
.
- It's not clear that the parameter
- What is the meaning of the return value of
triangularSolve
?- The doc comment says
/// Returns solution to the system `A x = b`
, but it's not clear thatx
is the solution.
- The doc comment says
The new doc comments are more clear:
/// Returns the solution `x` to the system of linear equations represented by `Ax = b`.
///
/// - Parameters:
/// - matrix: The input triangular coefficient matrix, representing `A` in `Ax = b`.
/// - rhs: Right-hand side values, representing `b` in `Ax = b`.
/// - lower: Whether `matrix` is lower triangular (`true`) or upper triangular (`false`). The
/// default value is `true`.
/// - adjoint: If `true`, solve with the adjoint of `matrix` instead of `matrix`. The default
/// value is `false`.
/// - Returns: The solution `x` to the system of linear equations represented by `Ax = b`.
/// `x` has the same shape as `b`.
/// - Precondition: `matrix` must be a tensor with shape `[..., M, M]`.
/// - Precondition: `rhs` must be a tensor with shape `[..., M, K]`.
@inlinable
@differentiable
public func triangularSolve<T: TensorFlowFloatingPoint>(
matrix: Tensor<T>,
rhs: Tensor<T>,
lower: Bool = true,
adjoint: Bool = false
) -> Tensor<T> { ... }
The documentation for tf.linalg.triangular_solve
is quite poor.
I adapted documentation from torch.triangular_solve
and scipy.linalg.solve_triangular
instead.
Writing idiomatic Swift documentation takes time, so attention there would really help accelerate code review. Thanks for your patience!
@awav mentioned that batched input support for I verified that the behavior is consistent with the Swift implementation added in this patch: https://gist.github.com/dan-zheng/6c1404149aa16a4c7082047ce3890ee2 Currently, our version of We should be able to remove the custom Swift broadcasting code in
|
This PR adds
triangularSolve
operation with its gradient. Unfortunately, TF triangularSolve function doesn't do automatic broadcasting, so I had to implement it. For broadcasting case, I needed an implementation for extracting leading dimension to input tensors - in the code, it is calledextractLeadingDims
. This function can be useful in general.PR requires some polishing and open for discussion.
Thanks!