In [1]:
from jax import jvp, vjp, random
import jax.numpy as jnp

In [2]:
# Seed a root PRNG key and split it into one subkey per random array below.
key = random.PRNGKey(0)
(
    A_key,
    A_tangent_key,
    U_cotangent_key,
    S_cotangent_key,
    Vh_cotangent_key,
    Vh_partial_cotangent_key,
) = random.split(key, 6)

# Primal input matrix and a tangent (perturbation direction) of the same shape.
A = random.normal(A_key, shape=(3, 4))
A_tangent = random.normal(A_tangent_key, shape=A.shape)

# Cotangents fed to the VJP functions in the cells below; each one is shaped
# like the SVD output it is paired with there.
U_cotangent = random.normal(U_cotangent_key, shape=(3, 3))
S_cotangent = random.normal(S_cotangent_key, shape=(3,))
Vh_cotangent = random.normal(Vh_cotangent_key, shape=(4, 4))  # used with full_matrices=True
Vh_partial_cotangent = random.normal(Vh_partial_cotangent_key, shape=(3, 4))  # used with full_matrices=False



In [3]:
# Dump every randomly generated input, one array after another, so the
# derivative results further down can be related to concrete values.
print(
    A,
    A_tangent,
    U_cotangent,
    S_cotangent,
    Vh_cotangent,
    Vh_partial_cotangent,
    sep='\n',
)

[[-0.5948281  -0.40346822  0.55371755  0.49942812]
 [-0.80238426 -1.3047416  -1.61652     0.27323487]
 [-0.68734825  0.41053954 -2.517128    1.7581573 ]]
[[-1.4027991   0.14532632  1.9614642  -1.2714893 ]
 [-1.1409343  -1.7149451   0.25800338 -1.1909169 ]
 [ 0.48682007  0.08552241 -2.8392208  -1.9545186 ]]
[[ 1.5043879   0.84688544 -0.9443146 ]
 [-1.2586974  -0.24351463 -0.48190218]
 [ 0.03214358 -1.2323356  -0.09203631]]
[-0.54219013  0.2726896   1.6822213 ]
[[-1.3283378   0.7635682   0.47792315 -0.71792597]
 [ 0.7886467   0.4386686  -2.4607036   0.4564042 ]
 [-0.463435   -0.8441072   0.3504501   0.61656743]
 [-0.18929946 -1.3674479  -0.04416634  1.3953555 ]]
[[-1.7422994   1.2946855   0.49448735 -0.08164447]
 [-0.01237171 -0.34202918  0.964581   -0.6673679 ]
 [ 0.63641965 -1.9036711  -0.30302268  0.8402989 ]]


In [4]:
# svd full
# Probe whether the full SVD (full_matrices=True) can be differentiated.
# The output recorded below this cell shows both probes reporting
# NotImplementedError.
def svd_full(A):
    """Return the SVD of ``A`` computed with ``full_matrices=True``."""
    return jnp.linalg.svd(A, full_matrices=True)


# Forward mode: push A_tangent through the decomposition.
try:
    jvp(svd_full, (A,), (A_tangent,))
except NotImplementedError as jvp_error:
    print(jvp_error)

# Reverse mode: build the VJP function and pull back one cotangent per output.
try:
    out, vjp_fun = vjp(svd_full, A)
    vjp_fun((U_cotangent, S_cotangent, Vh_cotangent))
except NotImplementedError as vjp_error:
    print(vjp_error)

Singular value decomposition JVP not implemented for full matrices
Singular value decomposition JVP not implemented for full matrices


In [5]:
# svd partial
# Differentiate the reduced SVD (full_matrices=False) in forward and reverse mode.
def svd_partial(A):
    """Return the SVD of ``A`` computed with ``full_matrices=False``."""
    return jnp.linalg.svd(A, full_matrices=False)


# Forward mode: primal outputs together with their tangents along A_tangent.
(U, S, Vh), (U_tangent, S_tangent, Vh_tangent) = jvp(
    svd_partial, (A,), (A_tangent,)
)

print('svd partial jvp', '----------', sep='\n')
print(f'U {U}', f'S {S}', f'Vh {Vh}', sep='\n')
print(
    f'U tangent {U_tangent}',
    f'S tangent {S_tangent}',
    f'Vh tangent {Vh_tangent}',
    sep='\n',
)

# Reverse mode: recompute the primal outputs and pull the (U, S, Vh)
# cotangents back onto A.
(U, S, Vh), vjp_fun = vjp(svd_partial, A)
A_cotangent = vjp_fun((U_cotangent, S_cotangent, Vh_partial_cotangent))

print()
print('svd partial vjp', '----------', sep='\n')
print(f'U {U}', f'S {S}', f'Vh {Vh}', sep='\n')
print(f'A cotangent {A_cotangent}')


svd partial jvp
----------
U [[-0.00941738  0.24955875  0.96831393]
 [ 0.5082209   0.83515656 -0.21029788]
 [ 0.8611754  -0.49013692  0.1346958 ]]
S [3.5728412  1.5581862  0.99071205]
Vh [[-0.278242   -0.0855764  -0.83811533  0.4613254 ]
 [-0.30912    -0.8930728   0.0140388  -0.32660246]
 [-0.5045093  -0.06157321  0.54211146  0.66917413]]
U tangent [[-0.7349052   1.4758277  -0.38750508]
 [-0.06396279 -0.19327536 -0.922132  ]
 [ 0.02971101  0.42210707  1.3460248 ]]
S tangent [1.0140587 1.9056993 0.460028 ]
Vh tangent [[ 0.16871683 -0.08966165 -0.52082837 -0.861091  ]
 [-1.2613225   0.06035677  1.3714367   1.0877167 ]
 [-0.78280795  2.4763026  -0.90722877  0.37263733]]

svd partial vjp
----------
U [[-0.00941738  0.24955875  0.96831393]
 [ 0.5082209   0.83515656 -0.21029788]
 [ 0.8611754  -0.49013692  0.1346958 ]]
S [3.5728412  1.5581862  0.99071205]
Vh [[-0.278242   -0.0855764  -0.83811533  0.4613254 ]
 [-0.30912    -0.8930728   0.0140388  -0.32660246]
 [-0.5045093  -0.06157321  0.54211

In [6]:
# svd singular values
# Differentiate the singular values alone (compute_uv=False) in both modes.
def svd_singular_values(A):
    """Return the result of ``jnp.linalg.svd`` on ``A`` with ``compute_uv=False``."""
    return jnp.linalg.svd(A, compute_uv=False)


# Forward mode: singular values and their tangent along A_tangent.
S, S_tangent = jvp(svd_singular_values, (A,), (A_tangent,))

print('svd singular values jvp', f'S {S}', f'S_tangent {S_tangent}', sep='\n')

# Reverse mode: pull S_cotangent back onto A. The recorded output shows that
# vjp_fun returns a 1-tuple, which is printed as-is.
S, vjp_fun = vjp(svd_singular_values, A)
A_cotangent = vjp_fun(S_cotangent)

print()
print('svd singular values vjp', f'S {S}', f'A cotangent {A_cotangent}', sep='\n')

svd singular values jvp
S [3.5728412  1.5581862  0.99071205]
S_tangent [1.0140587 1.9056993 0.460028 ]

svd singular values vjp
S [3.5728412  1.5581862  0.99071205]
A cotangent (DeviceArray([[-0.84426147, -0.16151014,  0.87973124,  1.0701596 ],
             [ 0.18475075, -0.15802369,  0.04236039, -0.43823138],
             [ 0.05691664,  0.1453695 ,  0.51229316, -0.02012339]],            dtype=float32),)
