In [1]:
import torch
from torch.linalg import svd, svdvals
from jax import random

In [2]:
# Sample every input with JAX's PRNG (seed 0) and copy the values into PyTorch
# tensors through plain Python lists.
key = random.PRNGKey(0)
(
    A_key,
    A_tangent_key,  # NOTE(review): not used by any cell visible here
    U_cotangent_key,
    S_cotangent_key,
    Vh_cotangent_key,
    Vh_partial_cotangent_key,
) = random.split(key, 6)


def _sample_as_torch(sample_key, shape, requires_grad=False):
    """Draw standard-normal values with JAX and return them as a new torch tensor."""
    as_nested_list = random.normal(sample_key, shape).tolist()
    return torch.tensor(as_nested_list, requires_grad=requires_grad)


# Input matrix (3x4); it is the only tensor that tracks gradients.
A = _sample_as_torch(A_key, (3, 4), requires_grad=True)
# Cotangents for U (3x3), S (3,), full Vh (4x4) and reduced Vh (3x4).
U_cotangent = _sample_as_torch(U_cotangent_key, (3, 3))
S_cotangent = _sample_as_torch(S_cotangent_key, (3,))
Vh_cotangent = _sample_as_torch(Vh_cotangent_key, (4, 4))
Vh_partial_cotangent = _sample_as_torch(Vh_partial_cotangent_key, (3, 4))



In [3]:
# Display the input matrix followed by each cotangent used in the backward passes.
for tensor_to_show in (A, U_cotangent, S_cotangent, Vh_cotangent, Vh_partial_cotangent):
    print(tensor_to_show)

tensor([[-0.5948, -0.4035,  0.5537,  0.4994],
        [-0.8024, -1.3047, -1.6165,  0.2732],
        [-0.6873,  0.4105, -2.5171,  1.7582]], requires_grad=True)
tensor([[ 1.5044,  0.8469, -0.9443],
        [-1.2587, -0.2435, -0.4819],
        [ 0.0321, -1.2323, -0.0920]])
tensor([-0.5422,  0.2727,  1.6822])
tensor([[-1.3283,  0.7636,  0.4779, -0.7179],
        [ 0.7886,  0.4387, -2.4607,  0.4564],
        [-0.4634, -0.8441,  0.3505,  0.6166],
        [-0.1893, -1.3674, -0.0442,  1.3954]])
tensor([[-1.7423,  1.2947,  0.4945, -0.0816],
        [-0.0124, -0.3420,  0.9646, -0.6674],
        [ 0.6364, -1.9037, -0.3030,  0.8403]])


In [4]:
# svd with full_matrices=True: for the 3x4 input, U is 3x3 and Vh is 4x4.
# NOTE(review): per the torch.linalg.svd docs, with full_matrices=True the
# gradient ignores the Vh rows beyond min(m, n), i.e. the last row of Vh_cotangent.
# Start from a clean gradient so re-running this cell does not double-count.
A.grad = None
U, S, Vh = svd(A, full_matrices=True)

# The backward calls accumulate into A.grad, so the three printed tensors are the
# gradient from U alone, then U + S, then U + S + Vh.
# retain_graph=True keeps the shared SVD graph alive for the subsequent calls.
U.backward(U_cotangent, retain_graph=True)
print(A.grad)
S.backward(S_cotangent, retain_graph=True)
print(A.grad)
Vh.backward(Vh_cotangent)
print(A.grad)


tensor([[-0.7083, -1.2774, -0.4001,  0.0017],
        [-0.3437,  0.1760,  0.4340,  0.6456],
        [ 0.0674, -0.2085, -0.1513, -0.2763]])
tensor([[-1.5526, -1.4389,  0.4797,  1.0719],
        [-0.1589,  0.0180,  0.4764,  0.2073],
        [ 0.1243, -0.0631,  0.3610, -0.2964]])
tensor([[-0.4935, -0.3076,  0.2198,  1.3187],
        [ 0.8525,  0.2234, -0.3413, -1.1965],
        [-0.6520,  0.7924,  0.8375,  0.5926]])


In [5]:
# Drop the gradient accumulated by the previous cell so the next cell's printed
# gradients start from zero instead of adding onto it.
A.grad = None

In [6]:
# svd with full_matrices=False (reduced SVD): for the 3x4 input, U is 3x3 and Vh is 3x4,
# so the reduced Vh_partial_cotangent is used instead of the 4x4 one.
# Start from a clean gradient so re-running this cell does not double-count.
A.grad = None
U, S, Vh = svd(A, full_matrices=False)

# The backward calls accumulate into A.grad, so the three printed tensors are the
# gradient from U alone, then U + S, then U + S + Vh.
# retain_graph=True keeps the shared SVD graph alive for the subsequent calls.
U.backward(U_cotangent, retain_graph=True)
print(A.grad)
S.backward(S_cotangent, retain_graph=True)
print(A.grad)
Vh.backward(Vh_partial_cotangent)
print(A.grad)

tensor([[-0.7083, -1.2774, -0.4001,  0.0017],
        [-0.3437,  0.1760,  0.4340,  0.6456],
        [ 0.0674, -0.2085, -0.1513, -0.2763]])
tensor([[-1.5526, -1.4389,  0.4797,  1.0719],
        [-0.1589,  0.0180,  0.4764,  0.2073],
        [ 0.1243, -0.0631,  0.3610, -0.2964]])
tensor([[ 0.0635, -1.4800,  0.3604,  1.9311],
        [-0.3005,  0.1542, -0.1228, -0.8199],
        [-0.3708, -0.0620,  0.7534,  0.0474]])


In [7]:
# Drop the gradient accumulated by the previous cell so the svdvals gradient
# below is computed on its own.
A.grad = None

In [8]:
# Gradient of the singular values only, via svdvals. With the same S_cotangent this
# equals the S-only contribution in the svd cells (their second printed gradient
# minus their first).
# Start from a clean gradient so re-running this cell does not double-count.
A.grad = None
S = svdvals(A)
# Single backward pass, so the graph does not need to be retained.
S.backward(S_cotangent)
print(A.grad)

tensor([[-0.8443, -0.1615,  0.8797,  1.0702],
        [ 0.1848, -0.1580,  0.0424, -0.4382],
        [ 0.0569,  0.1454,  0.5123, -0.0201]])
