In [1]:
import inspect
import os
from time import perf_counter, time

import torch

# A.T @ A case

## Simple case

In [25]:
torch.manual_seed(0)  # fix the RNG so the displayed outputs below are reproducible on a fresh kernel
b, m, n = 1, 10, 5  # batch size, model dim, sequence length
A = torch.rand(b, m, n)
A.shape, A.transpose(1, 2).shape

(torch.Size([1, 10, 5]), torch.Size([1, 5, 10]))

In [26]:
def ATA(A):
    """Return the (batched) Gram matrix ``A^T @ A``.

    Only the last two dimensions are transposed, so ``A`` may be a single
    matrix of shape (m, n) or a batch of shape (..., m, n). For the 3-D
    (b, m, n) inputs used in this notebook this matches
    ``torch.bmm(A.transpose(1, 2), A)``.

    Args:
        A: tensor with at least two dimensions, shape (..., m, n).

    Returns:
        Tensor of shape (..., n, n).
    """
    # matmul batches over all leading dims (bmm only accepts exactly 3-D inputs).
    return torch.matmul(A.transpose(-2, -1), A)

In [27]:
ATA(A).size()  # same torch.Size as .shape

torch.Size([1, 5, 5])

In [28]:
jac = torch.autograd.functional.jacobian(ATA, inputs=A)  # dims: output shape followed by input shape
jac.size()

torch.Size([1, 5, 5, 1, 10, 5])

In [29]:
B = torch.randn((b, n, n))  # cotangent, same shape as ATA(A)

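### Verifying that the three ways yield the same result

For $f(A) = A^\top A$ and a cotangent $B$ with the same shape as $f(A)$, the vector–Jacobian product is $A\,(B + B^\top)$ (per batch element). The three cells below compute it by multiplying with the flattened Jacobian, with `torch.autograd.functional.vjp`, and with this closed form.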

In [30]:
# VJP the brute-force way: flatten the cotangent into a row vector, flatten the
# Jacobian into an (outputs x inputs) matrix, multiply, then restore A's shape.
(
    B.view(-1)
    @ jac.view(b * n * n, b * m * n)
).view(b, m, n)

tensor([[[-1.9905,  0.6171, -3.3195,  0.4265, -2.5245],
         [-4.0342,  0.1170, -3.2399, -0.5155, -2.7586],
         [-4.3139, -0.3710, -0.0815, -2.5825,  0.4590],
         [-7.0037, -1.0059, -4.6378, -2.9415, -2.6776],
         [-5.5451, -1.4243, -1.9190, -2.6157, -1.0399],
         [-4.8951, -0.5864, -1.9706, -3.1078, -0.7024],
         [-5.9275, -1.6060, -2.1951, -2.0454, -1.1969],
         [-5.0533, -0.2408, -0.3195, -2.9782, -0.0384],
         [-2.8512, -0.1231, -1.2354, -0.5520, -1.0963],
         [-3.1871, -0.8985, -1.5581, -1.2299, -0.2537]]])

In [31]:
torch.autograd.functional.vjp(ATA, inputs=A, v=B)[1]  # vjp returns (ATA(A), VJP); keep the VJP from autograd

tensor([[[-1.9905,  0.6171, -3.3195,  0.4265, -2.5245],
         [-4.0342,  0.1170, -3.2399, -0.5155, -2.7586],
         [-4.3139, -0.3710, -0.0815, -2.5825,  0.4590],
         [-7.0037, -1.0059, -4.6378, -2.9415, -2.6776],
         [-5.5451, -1.4243, -1.9190, -2.6157, -1.0399],
         [-4.8951, -0.5864, -1.9706, -3.1078, -0.7024],
         [-5.9275, -1.6060, -2.1951, -2.0454, -1.1969],
         [-5.0533, -0.2408, -0.3195, -2.9782, -0.0384],
         [-2.8512, -0.1231, -1.2354, -0.5520, -1.0963],
         [-3.1871, -0.8985, -1.5581, -1.2299, -0.2537]]])

In [32]:
# Closed form: for f(A) = A^T A with cotangent B, the VJP is A (B + B^T).
vjp_closed_form = torch.bmm(A, B + B.transpose(1, 2))
# Verify programmatically (not just by eye) against autograd's VJP.
torch.testing.assert_close(vjp_closed_form, torch.autograd.functional.vjp(ATA, A, v=B)[1])
vjp_closed_form

tensor([[[-1.9905,  0.6171, -3.3195,  0.4265, -2.5245],
         [-4.0342,  0.1170, -3.2399, -0.5155, -2.7586],
         [-4.3139, -0.3710, -0.0815, -2.5825,  0.4590],
         [-7.0037, -1.0059, -4.6378, -2.9415, -2.6776],
         [-5.5451, -1.4243, -1.9190, -2.6157, -1.0399],
         [-4.8951, -0.5864, -1.9706, -3.1078, -0.7024],
         [-5.9275, -1.6060, -2.1951, -2.0454, -1.1969],
         [-5.0533, -0.2408, -0.3195, -2.9782, -0.0384],
         [-2.8512, -0.1231, -1.2354, -0.5520, -1.0963],
         [-3.1871, -0.8985, -1.5581, -1.2299, -0.2537]]])

## Larger case

In [110]:
b = 256   # batch size
m = 1024  # model dim
n = 100   # sequence length
A = torch.rand((b, m, n))
(A.size(), A.transpose(1, 2).size())

(torch.Size([256, 1024, 100]), torch.Size([256, 100, 1024]))

In [111]:
ATA(A).size()  # same torch.Size as .shape

torch.Size([256, 100, 100])

In [112]:
B = torch.randn((b, n, n))  # cotangent, same shape as ATA(A)

### Time comparison

In [121]:
# Time autograd's VJP. Note: vjp also evaluates the forward pass ATA(A), which the
# closed-form cell does not, so the comparison is not strictly like-for-like.
# perf_counter is the high-resolution clock intended for measuring short durations
# (time() is wall-clock time). A single run is noisy; treat the number as indicative.
t1 = perf_counter()
torch.autograd.functional.vjp(ATA, A, v=B)
t2 = perf_counter()
print(t2 - t1)

0.14750933647155762


In [122]:
# Time the closed-form VJP A (B + B^T). perf_counter is the high-resolution clock
# intended for short durations; a single run is noisy, so treat it as indicative.
t1 = perf_counter()
torch.bmm(A, B + B.transpose(1, 2))
t2 = perf_counter()
print(t2 - t1)

0.05492901802062988


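Note: I couldn't compute the full Jacobian in the larger case because my system (32 GB of RAM) ran out of memory. This is expected: the Jacobian has shape $(b, n, n, b, m, n)$, i.e. $256 \cdot 100 \cdot 100 \cdot 256 \cdot 1024 \cdot 100 \approx 6.7 \times 10^{13}$ float32 entries (about 268 TB). Neither the autograd VJP nor the closed form ever materializes it.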

## Even larger

In [106]:
b = 512   # batch size
m = 2048  # model dim
n = 200   # sequence length
A = torch.rand((b, m, n))
(A.size(), A.transpose(1, 2).size())

(torch.Size([512, 2048, 200]), torch.Size([512, 200, 2048]))

In [107]:
B = torch.randn((b, n, n))  # cotangent, same shape as ATA(A)

### Time comparison

In [108]:
# Time autograd's VJP. Note: vjp also evaluates the forward pass ATA(A), which the
# closed-form cell does not, so the comparison is not strictly like-for-like.
# perf_counter is the high-resolution clock intended for measuring short durations
# (time() is wall-clock time). A single run is noisy; treat the number as indicative.
t1 = perf_counter()
torch.autograd.functional.vjp(ATA, A, v=B)
t2 = perf_counter()
print(t2 - t1)

1.5000905990600586


In [109]:
# Time the closed-form VJP A (B + B^T). perf_counter is the high-resolution clock
# intended for short durations; a single run is noisy, so treat it as indicative.
t1 = perf_counter()
torch.bmm(A, B + B.transpose(1, 2))
t2 = perf_counter()
print(t2 - t1)

0.43678879737854004


# General A@B

## Simple case

In [72]:
torch.manual_seed(0)  # fix the RNG so the displayed outputs below are reproducible on a fresh kernel
b, m, n, p = 1, 8, 6, 4  # batch size; A is (m x n), B is (n x p), so A @ B is (m x p)
A = torch.rand(b, m, n)
B = torch.rand(b, n, p)

In [73]:
torch.bmm(A, B).size()  # same torch.Size as .shape

torch.Size([1, 8, 4])

In [74]:
jac = torch.autograd.functional.jacobian(torch.bmm, inputs=(A, B))  # one Jacobian per input
tuple(j.shape for j in jac)

(torch.Size([1, 8, 4, 1, 8, 6]), torch.Size([1, 8, 4, 1, 6, 4]))

With two inputs, `jacobian` returns a tuple of two Jacobians: one with respect to A (`jac[0]`) and one with respect to B (`jac[1]`).

In [75]:
C = torch.randn((b, m, p))  # cotangent, same shape as A @ B

In [76]:
# VJP w.r.t. A: flatten C and the Jacobian w.r.t. A, multiply, restore A's shape.
(
    C.view(-1)
    @ jac[0].view(b * m * p, b * m * n)
).view(b, m, n)

tensor([[[-0.6731, -0.0340, -0.1018,  0.4445, -0.6468,  0.2126],
         [-0.7640, -0.6238, -0.8895, -0.8088, -0.6120, -1.2891],
         [-1.3828, -2.5941, -3.3796, -1.5188, -1.0796, -2.4011],
         [-0.7566, -1.3104, -1.9524, -0.6647, -0.5805, -0.3342],
         [-0.0798,  0.2833,  0.4701, -0.3305, -0.0651, -0.9224],
         [-0.0098, -0.1235,  0.0656,  0.1881, -0.0454, -0.4510],
         [ 0.6521, -0.6137, -0.4850,  0.1626,  0.5680, -0.0510],
         [-1.2129, -0.7720, -0.7436, -0.5206, -1.0802, -2.3379]]])

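### Verifying the results

For $Y = AB$ with cotangent $C$, the VJPs have closed forms $C B^\top$ (with respect to $A$) and $A^\top C$ (with respect to $B$). Below they are compared with the flattened-Jacobian products and with `torch.autograd.functional.vjp`.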

In [77]:
# VJP w.r.t. B: flatten C and the Jacobian w.r.t. B, multiply, restore B's shape.
(
    C.view(-1)
    @ jac[1].view(b * m * p, b * n * p)
).view(b, n, p)

tensor([[[-3.9037, -0.6041, -0.4369, -1.6985],
         [-3.2482, -0.1006,  0.3106, -3.2911],
         [-3.1003, -0.9423, -0.6526, -1.4464],
         [-3.7342, -1.8974,  0.5171, -2.9065],
         [-3.2203, -1.8391,  0.1787, -1.4652],
         [-4.3489,  0.0344, -0.4280, -3.4612]]])

In [78]:
torch.autograd.functional.vjp(torch.bmm, inputs=(A, B), v=C)[1]  # one VJP per input: (w.r.t. A, w.r.t. B)

(tensor([[[-0.6731, -0.0340, -0.1018,  0.4445, -0.6468,  0.2126],
          [-0.7640, -0.6238, -0.8895, -0.8088, -0.6120, -1.2891],
          [-1.3828, -2.5941, -3.3796, -1.5188, -1.0796, -2.4011],
          [-0.7566, -1.3104, -1.9524, -0.6647, -0.5805, -0.3342],
          [-0.0798,  0.2833,  0.4701, -0.3305, -0.0651, -0.9224],
          [-0.0098, -0.1235,  0.0656,  0.1881, -0.0454, -0.4510],
          [ 0.6521, -0.6137, -0.4850,  0.1626,  0.5680, -0.0510],
          [-1.2129, -0.7720, -0.7436, -0.5206, -1.0802, -2.3379]]]),
 tensor([[[-3.9037, -0.6041, -0.4369, -1.6985],
          [-3.2482, -0.1006,  0.3106, -3.2911],
          [-3.1003, -0.9423, -0.6526, -1.4464],
          [-3.7342, -1.8974,  0.5171, -2.9065],
          [-3.2203, -1.8391,  0.1787, -1.4652],
          [-4.3489,  0.0344, -0.4280, -3.4612]]]))

In [82]:
# Closed form for Y = A @ B with cotangent C: the VJP w.r.t. A is C @ B^T.
vjp_A = torch.bmm(C, B.transpose(1, 2))
# Verify programmatically (not just by eye) against autograd's VJP w.r.t. A.
torch.testing.assert_close(vjp_A, torch.autograd.functional.vjp(torch.bmm, (A, B), v=C)[1][0])
vjp_A

tensor([[[-0.6731, -0.0340, -0.1018,  0.4445, -0.6468,  0.2126],
         [-0.7640, -0.6238, -0.8895, -0.8088, -0.6120, -1.2891],
         [-1.3828, -2.5941, -3.3796, -1.5188, -1.0796, -2.4011],
         [-0.7566, -1.3104, -1.9524, -0.6647, -0.5805, -0.3342],
         [-0.0798,  0.2833,  0.4701, -0.3305, -0.0651, -0.9224],
         [-0.0098, -0.1235,  0.0656,  0.1881, -0.0454, -0.4510],
         [ 0.6521, -0.6137, -0.4850,  0.1626,  0.5680, -0.0510],
         [-1.2129, -0.7720, -0.7436, -0.5206, -1.0802, -2.3379]]])

In [80]:
# Closed form for Y = A @ B with cotangent C: the VJP w.r.t. B is A^T @ C.
vjp_B = torch.bmm(A.transpose(1, 2), C)
# Verify programmatically (not just by eye) against autograd's VJP w.r.t. B.
torch.testing.assert_close(vjp_B, torch.autograd.functional.vjp(torch.bmm, (A, B), v=C)[1][1])
vjp_B

tensor([[[-3.9037, -0.6041, -0.4369, -1.6985],
         [-3.2482, -0.1006,  0.3106, -3.2911],
         [-3.1003, -0.9423, -0.6526, -1.4464],
         [-3.7342, -1.8974,  0.5171, -2.9065],
         [-3.2203, -1.8391,  0.1787, -1.4652],
         [-4.3489,  0.0344, -0.4280, -3.4612]]])

## Large case

In [100]:
b, m, n, p = 512, 200, 2048, 200  # batch size; A is (m x n), B is (n x p), so A @ B is (m x p)
A = torch.rand(b, m, n)
B = torch.rand(b, n, p)

In [101]:
torch.bmm(A, B).size()  # same torch.Size as .shape

torch.Size([512, 200, 200])

In [102]:
C = torch.randn((b, m, p))  # cotangent, same shape as A @ B

In [103]:
# Time autograd's VJP for both inputs. Note: vjp also evaluates the forward product
# A @ B, which the closed-form cell does not, so this is not strictly like-for-like.
# perf_counter is the high-resolution clock intended for measuring short durations
# (time() is wall-clock time). A single run is noisy; treat the number as indicative.
t1 = perf_counter()
torch.autograd.functional.vjp(torch.bmm, (A, B), v=C)
t2 = perf_counter()
print(t2 - t1)

1.3459370136260986


In [105]:
# Time the two closed-form VJPs. perf_counter is the high-resolution clock intended
# for short durations; a single run is noisy, so treat the number as indicative.
t1 = perf_counter()
torch.bmm(C, B.transpose(1, 2))  # VJP w.r.t. A: C @ B^T
torch.bmm(A.transpose(1, 2), C)  # VJP w.r.t. B: A^T @ C
t2 = perf_counter()
print(t2 - t1)

0.9066958427429199
