In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cpu")

In [3]:
## Full rank matrix
A = torch.randn(8, 8)
A

tensor([[-0.4328,  1.4276, -0.1760, -0.4712, -2.0191,  0.9317, -0.6690,  1.0245],
        [-0.0123, -0.2105,  1.0330,  1.9057,  0.5513, -0.4306,  1.1316, -1.4846],
        [-1.2255, -0.0883, -0.8999,  0.4882,  0.6080,  1.3124, -1.5459,  0.9929],
        [ 0.2848, -0.4424, -0.1292, -0.6310, -0.5312, -0.8484,  0.3398, -1.3651],
        [-0.7149, -0.7609, -0.4445,  1.7718,  0.2339,  0.8045,  0.0708, -2.8123],
        [-1.8124, -1.3133, -1.2716, -1.4437,  1.3278,  1.5659, -0.2026, -0.0375],
        [ 1.0776,  1.8514,  0.3449,  0.6349, -0.4402, -1.4011,  0.9380,  0.5628],
        [-0.2396,  1.1808,  0.1210,  0.6580, -0.1796,  0.4519, -0.8249, -0.2166]])

In [126]:
### pair up (1,2), (3,4), ...
diag0 = torch.randn(4, 2, 2)
diag0

tensor([[[ 0.1695,  1.1962],
         [-0.0480,  1.7277]],

        [[ 0.2971, -0.0095],
         [ 0.7436,  0.1813]],

        [[ 1.8542,  0.6069],
         [-1.5079,  0.0020]],

        [[-0.9866,  1.1867],
         [ 0.0990, -1.5312]]])

In [127]:
torch.block_diag(*diag0)

tensor([[ 0.1695,  1.1962,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0480,  1.7277,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.2971, -0.0095,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.7436,  0.1813,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  1.8542,  0.6069,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -1.5079,  0.0020,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.9866,  1.1867],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0990, -1.5312]])

In [290]:
### Pair up (1, 3), (2,4)
## that means rearrange as.. 1,3,2,4 and apply block diagonal

def get_pair(dim, step=1):
    assert 2**int(np.log2(dim)) == dim , "The dim must be power of 2"
    assert isinstance(step, int), "Step must be integer"
    
    blocks = (2**step)
    range_ = dim//blocks
#     print(range_, blocks)
    adder_ = torch.arange(0, range_)*blocks
#     print(adder_)
    
    pairs_ = torch.Tensor([0, blocks//2])
    repeat_ = torch.arange(0, blocks//2).reshape(-1,1)
#     print(pairs_)
#     print(repeat_)
    
    block_map = (pairs_+repeat_).reshape(-1)
#     print(block_map)
    
#     print(block_map+adder_.reshape(-1,1))
    reorder_for_pair = (block_map+adder_.reshape(-1,1)).reshape(-1)
    indx = reorder_for_pair.type(torch.long)
    rev_indx = torch.argsort(indx)
    
    return indx, rev_indx

In [291]:
get_pair(8, 1)

(tensor([0, 1, 2, 3, 4, 5, 6, 7]), tensor([0, 1, 2, 3, 4, 5, 6, 7]))

In [292]:
get_pair(8, 2)

(tensor([0, 2, 1, 3, 4, 6, 5, 7]), tensor([0, 2, 1, 3, 4, 6, 5, 7]))

In [293]:
get_pair(8, 3)

(tensor([0, 4, 1, 5, 2, 6, 3, 7]), tensor([0, 2, 4, 6, 1, 3, 5, 7]))

In [207]:
mat = torch.arange(8, dtype=torch.float).reshape(-1, 1) + torch.arange(8, dtype=torch.float).reshape(1, -1)*0.1
mat

tensor([[0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000],
        [1.0000, 1.1000, 1.2000, 1.3000, 1.4000, 1.5000, 1.6000, 1.7000],
        [2.0000, 2.1000, 2.2000, 2.3000, 2.4000, 2.5000, 2.6000, 2.7000],
        [3.0000, 3.1000, 3.2000, 3.3000, 3.4000, 3.5000, 3.6000, 3.7000],
        [4.0000, 4.1000, 4.2000, 4.3000, 4.4000, 4.5000, 4.6000, 4.7000],
        [5.0000, 5.1000, 5.2000, 5.3000, 5.4000, 5.5000, 5.6000, 5.7000],
        [6.0000, 6.1000, 6.2000, 6.3000, 6.4000, 6.5000, 6.6000, 6.7000],
        [7.0000, 7.1000, 7.2000, 7.3000, 7.4000, 7.5000, 7.6000, 7.7000]])

In [265]:
diag0 = torch.randn(4, 2, 2)
diag0.requires_grad=True
bd0 = torch.block_diag(*diag0)
bd0

tensor([[-0.0757,  0.4338,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.3521, -0.5412,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.3217,  0.8534,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.2215, -1.2530,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -1.5971,  2.2851,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  1.4778, -0.3428,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.6734,  0.1554],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.3725, -2.3790]],
       grad_fn=<CopySlices>)

In [230]:
diag1 = torch.randn(4, 2, 2)
bd1 = torch.block_diag(*diag1)
# bd1 = mat
ind1, rind1 = get_pair(8, 2)
bd1 = bd1[ind1]
bd1 = bd1[:,ind1]
bd1

tensor([[-0.8887,  0.0000,  0.6940,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.8879,  0.0000,  0.5084,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0682,  0.0000,  0.4320,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0855,  0.0000, -0.2696,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  1.2088,  0.0000, -0.1409,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.3235,  0.0000, -0.0925],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.8058,  0.0000,  0.4583,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.1498,  0.0000, -1.2522]])

In [232]:
(bd0@bd1)

tensor([[ 1.3188, -0.5817, -1.0299, -0.3331,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.9296, -3.0125, -1.5068, -1.7249,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0414,  0.0884,  0.2626, -0.2788,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.1457, -0.2144,  0.9233,  0.6763,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.8262, -0.0757,  0.0963,  0.0216],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  1.2177,  0.0906, -0.1419, -0.0259],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.3811,  0.7722,  0.2167,  0.8410],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.8800, -0.2050,  0.5005, -0.2232]])

In [257]:
diag2 = torch.randn(4, 2, 2)
bd2 = torch.block_diag(*diag2)
# bd2 = mat
ind2, rind2 = get_pair(8, 3)
# bd2 = bd2[ind2]
# bd2 = bd2[:, ind2]
bd2

tensor([[ 0.4477, -1.3671,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.4485, -0.0302,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  1.5917, -0.0900,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.9070, -1.1521,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  1.0856, -1.8143,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -1.1753, -0.1693,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  1.8989,  0.9534],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.1629,  0.9740]])

In [262]:
get_pair(8, 3)

(tensor([0, 4, 1, 5, 2, 6, 3, 7]), tensor([0, 2, 4, 6, 1, 3, 5, 7]))

In [263]:
torch.scatter

In [264]:
bd2

tensor([[ 0.4477, -1.3671, -1.3671,  0.0000, -1.3671,  0.0000,  0.0000,  0.0000],
        [ 1.4485, -0.0302, -0.0302,  0.0000, -0.0302,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0900, -0.0900,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.1521, -1.1521,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0900, -0.0900,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.1521, -1.1521,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.1521, -1.1521,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.9740]])

In [260]:
bd2[:,rind2] = bd2

In [199]:
ind2, rind2

(tensor([0, 4, 1, 5, 2, 6, 3, 7]), tensor([0, 2, 4, 6, 1, 3, 5, 7]))

In [185]:
(bd0@bd1)@bd2

tensor([[ 0.1105, -0.1467,  0.0000,  0.0000,  0.0478,  0.2200,  0.0000,  0.0000],
        [-0.0663,  0.2073,  0.0000,  0.0000, -0.4844, -0.9049,  0.0000,  0.0000],
        [ 0.1666,  0.1975,  0.0000,  0.0000, -0.2654, -0.2873,  0.0000,  0.0000],
        [ 0.3107, -0.3445,  0.0000,  0.0000, -0.5522, -0.3524,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.6435, -2.7394,  0.0000,  0.0000, -0.3413,  2.5342],
        [ 0.0000,  0.0000,  0.2377,  3.1177,  0.0000,  0.0000,  0.2279, -1.4367],
        [ 0.0000,  0.0000, -0.1139,  0.8770,  0.0000,  0.0000, -0.0407,  2.9807],
        [ 0.0000,  0.0000,  0.2386,  0.4714,  0.0000,  0.0000, -0.0431,  1.5147]])

In [184]:
indices

tensor([0, 4, 1, 5, 2, 6, 3, 7])

In [162]:
torch.argsort(indices) ## this is reverse index

tensor([0, 2, 4, 6, 1, 3, 5, 7])

In [350]:
def get_pair_indices(dim, step=1):
    assert 2**int(np.log2(dim)) == dim , "The dim must be power of 2"
    assert isinstance(step, int), "Step must be integer"
    
    blocks = (2**step)
    range_ = dim//blocks
#     print(range_, blocks)
    adder_ = torch.arange(0, range_)*blocks
#     print(adder_)
    
    pairs_ = torch.Tensor([0, blocks//2])
    repeat_ = torch.arange(0, blocks//2).reshape(-1,1)
#     print(pairs_)
#     print(repeat_)
    
    block_map = (pairs_+repeat_).reshape(-1)
#     print(pairs_+repeat_)
    
#     print(block_map+adder_.reshape(-1,1))
    reorder_for_pair = (block_map+adder_.reshape(-1,1)).reshape(-1)
    indx = reorder_for_pair.type(torch.long)
#     rev_indx = torch.argsort(indx)
    
#     print(indx.reshape(-1, 2))
    
#     return indx, rev_indx
    indx = indx.reshape(-1, 2)
    
    map_idx = []
    for idx in indx:
        map_idx.append((idx[0], idx[0]))
        map_idx.append((idx[0], idx[1]))        
        map_idx.append((idx[1], idx[0]))        
        map_idx.append((idx[1], idx[1]))        
    map_idx = torch.LongTensor(map_idx)
#     print(map_idx)
    
    return map_idx

In [341]:
map_idx = get_pair_indices(8, 2)

tensor([0., 2.])
tensor([[0],
        [1]])
tensor([[0., 2.],
        [1., 3.]])
tensor([[0., 2., 1., 3.],
        [4., 6., 5., 7.]])
tensor([[0, 2],
        [1, 3],
        [4, 6],
        [5, 7]])
tensor([[0, 0],
        [0, 2],
        [2, 0],
        [2, 2],
        [1, 1],
        [1, 3],
        [3, 1],
        [3, 3],
        [4, 4],
        [4, 6],
        [6, 4],
        [6, 6],
        [5, 5],
        [5, 7],
        [7, 5],
        [7, 7]])


In [342]:
map_idx.shape

torch.Size([16, 2])

In [343]:
w = torch.randn(map_idx.shape[0],1)

In [344]:
z = torch.zeros(8,8)
z[map_idx.split(1, dim=1)] = w
z

tensor([[-1.7107,  0.0000,  0.3462,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.3085,  0.0000,  0.4741,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.2916,  0.0000,  0.3679,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.5811,  0.0000,  0.5204,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.2372,  0.0000, -0.6719,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.0452,  0.0000, -1.2915],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.4518,  0.0000,  0.9587,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.1578,  0.0000,  0.1742]])

In [336]:
torch.nonzero(mat, as_tuple=True)

(tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3,
         3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6,
         6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7]),
 tensor([1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0,
         1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0,
         1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]))

In [345]:
### create 3 matrices with the factorization and multiply to approximate given matrix

In [431]:
A = torch.randn(8, 8)*5
# A.abs_()

In [432]:
w0, w1, w2 = [torch.randn(16, 1, requires_grad=True) for _ in range(3)]

In [433]:
i0, i1, i2 = [get_pair_indices(8, i+1) for i in range(3)]

In [434]:
# optimizer = torch.optim.Adam([w0, w1, w2], lr=0.001)
optimizer = torch.optim.SGD([w0, w1, w2], lr=0.1)

mse = nn.MSELoss()
def mae(A, B):
    return torch.abs(A-B).mean()

In [435]:
### forward propagation
for i in range(10000):
    W0, W1, W2 = [torch.zeros(8, 8) for _ in range(3)]
    W0[i0.split(1, dim=1)] = w0
    W1[i1.split(1, dim=1)] = w1
    W2[i2.split(1, dim=1)] = w2

    W = (W0@W1)@W2
    loss = mse(W,A)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i%1000 == 0:
        print(f"The MSE loss is : {float(mse(W,A))}")
#         print(f"The MAE loss is : {float(mae(W,A))}")

The MSE loss is : 25.884742736816406
The MSE loss is : 5.331180095672607
The MSE loss is : 5.331180095672607
The MSE loss is : 5.331180095672607
The MSE loss is : 5.331180095672607
The MSE loss is : 5.331180095672607
The MSE loss is : 5.331180095672607
The MSE loss is : 5.331180095672607
The MSE loss is : 5.331180095672607
The MSE loss is : 5.331180095672607


In [436]:
W

tensor([[ 7.7474e+00, -7.2777e+00,  5.0163e+00,  3.2163e+00, -2.4806e+00,
          2.2372e+00,  2.9761e+00, -2.4492e+00],
        [ 6.4035e+00,  2.8651e+00,  4.1461e+00, -1.2662e+00, -2.0503e+00,
         -8.8076e-01,  2.4599e+00,  9.6423e-01],
        [ 1.4621e+00, -2.8083e+00, -3.6816e+00,  6.1329e+00, -4.6812e-01,
          8.6329e-01, -2.1843e+00, -4.6703e+00],
        [ 2.1042e+00,  1.0639e+00, -5.2987e+00, -2.3234e+00, -6.7373e-01,
         -3.2705e-01, -3.1437e+00,  1.7693e+00],
        [-5.0817e+00, -3.0245e+00,  7.7529e+00,  1.1138e+00, -1.0387e+00,
         -2.2539e+00,  7.8916e+00,  2.6609e-02],
        [-2.1521e+00, -9.1208e-01,  3.2833e+00,  3.3587e-01, -4.3990e-01,
         -6.7969e-01,  3.3420e+00,  8.0245e-03],
        [ 1.3431e+01,  1.1496e+00, -3.4477e-01, -2.1678e+00,  2.7454e+00,
          8.5669e-01, -3.5093e-01, -5.1792e-02],
        [-6.2830e+00, -7.5975e+00,  1.6128e-01,  1.4327e+01, -1.2843e+00,
         -5.6617e+00,  1.6416e-01,  3.4229e-01]], grad_fn=<MmBack

In [437]:
A

tensor([[ 8.1538, -7.7868,  1.9195,  3.4968, -4.0525,  0.8201,  5.8274, -1.8625],
        [ 6.0330,  2.5539,  8.0274, -2.3764,  0.2304, -1.2859, -1.2168,  0.0611],
        [ 1.7211, -0.1381, -3.5464,  6.2581,  3.4480,  4.4723, -3.0781, -5.4443],
        [ 1.5552,  5.4676, -5.2873, -0.9995, -4.5474,  0.5965, -2.7003,  1.0306],
        [-6.2331, -3.2357,  7.8637,  1.3368,  7.3891, -1.8760,  8.1506, -1.3125],
        [-3.5689,  0.1516,  2.9712, -0.2827, -0.1083, -2.4201,  2.7800, -0.6187],
        [13.0813,  3.5712,  0.5989, -2.6353,  5.0878, -3.4195,  3.6591,  2.5340],
        [-5.6142, -7.2746,  3.2072, 14.2534, -3.2069, -6.2503,  7.7258,  0.8523]])

In [438]:
W-A

tensor([[-0.4065,  0.5091,  3.0967, -0.2805,  1.5719,  1.4171, -2.8513, -0.5868],
        [ 0.3705,  0.3112, -3.8813,  1.1102, -2.2806,  0.4051,  3.6766,  0.9031],
        [-0.2591, -2.6702, -0.1352, -0.1252, -3.9162, -3.6090,  0.8938,  0.7740],
        [ 0.5491, -4.4037, -0.0114, -1.3239,  3.8737, -0.9236, -0.4434,  0.7387],
        [ 1.1514,  0.2113, -0.1108, -0.2230, -8.4278, -0.3779, -0.2591,  1.3391],
        [ 1.4168, -1.0637,  0.3121,  0.6185, -0.3316,  1.7404,  0.5621,  0.6267],
        [ 0.3498, -2.4216, -0.9437,  0.4675, -2.3425,  4.2762, -4.0101, -2.5858],
        [-0.6688, -0.3228, -3.0460,  0.0736,  1.9226,  0.5886, -7.5617, -0.5101]],
       grad_fn=<SubBackward0>)

### using double matrix

In [455]:
A = torch.randn(8, 8)
# A.abs_()

In [456]:
w0, w1, w2 = [torch.randn(16, 1, requires_grad=True) for _ in range(3)]

w3, w4, w5 = [torch.randn(16, 1, requires_grad=True) for _ in range(3)]

In [457]:
i0, i1, i2 = [get_pair_indices(8, i+1) for i in range(3)]

In [458]:
optimizer = torch.optim.Adam([w0, w1, w2], lr=0.001)
# optimizer = torch.optim.SGD([w0, w1, w2, w3, w4, w5], lr=0.1)

mse = nn.MSELoss()
def mae(A, B):
    return torch.abs(A-B).mean()

In [460]:
### forward propagation
for i in range(10000):
    W0, W1, W2, W3, W4, W5 = [torch.zeros(8, 8) for _ in range(6)]
    W0[i0.split(1, dim=1)] = w0
    W1[i1.split(1, dim=1)] = w1
    W2[i2.split(1, dim=1)] = w2
    
    W3[i0.split(1, dim=1)] = w3
    W4[i1.split(1, dim=1)] = w4
    W5[i2.split(1, dim=1)] = w5

    W = ((((W0@W1)@W2)@W3)@W4)@W5
    loss = mse(W,A)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i%1000 == 0:
        print(f"The MSE loss is : {float(mse(W,A))}")
#         print(f"The MAE loss is : {float(mae(W,A))}")

The MSE loss is : 0.354512095451355
The MSE loss is : 0.35442817211151123
The MSE loss is : 0.3366084098815918
The MSE loss is : 0.3312680125236511
The MSE loss is : 0.3312670886516571
The MSE loss is : 0.3312670886516571
The MSE loss is : 0.3312671482563019
The MSE loss is : 0.33126726746559143
The MSE loss is : 0.33126717805862427
The MSE loss is : 0.3312671482563019


In [461]:
W

tensor([[-0.4585,  0.1179, -0.0854, -0.1305,  0.0665, -0.0224,  0.1827, -0.2260],
        [-0.0794, -0.2044,  0.6706, -0.9132,  0.0160, -0.1102, -1.8848, -0.5498],
        [-0.4361, -0.2286, -0.9940, -0.6881,  0.0653,  0.5506, -0.5307, -1.1167],
        [ 1.1640, -1.0013, -1.4148, -0.1753, -0.1677, -0.0133, -0.9505, -0.2233],
        [-0.3384,  1.5510,  1.4303,  0.3768,  0.0448, -0.0644,  0.1318,  0.4218],
        [ 0.1250, -0.2910,  1.0406,  0.8284, -0.0269, -0.3640, -0.4468, -1.2545],
        [ 2.0066, -0.1582, -0.6990, -0.7972, -0.2819,  0.9129, -1.5735,  1.5679],
        [-0.1782, -0.5537, -0.1705,  0.1694,  0.0241, -0.2428,  0.1207, -0.6610]],
       grad_fn=<MmBackward>)

In [462]:
A

tensor([[-0.6138,  1.2326, -0.2185,  1.1179, -0.8497, -0.3937, -0.4247, -0.7240],
        [-0.1371, -0.4254,  0.7151, -1.3689, -0.0725,  0.0180, -1.7280, -0.2137],
        [-0.3015, -0.2813, -0.9942, -0.9061, -0.9991,  0.3932, -0.0107, -1.4110],
        [ 1.3580, -1.2162, -1.2476, -1.1149,  1.3802, -0.0763, -0.9408,  0.2300],
        [ 0.2501,  1.6509,  0.8738,  1.8718,  0.4026, -0.1844, -0.2245,  1.1366],
        [-0.2321, -0.4036,  1.1845,  0.5598, -0.6597, -0.4030, -0.1734, -1.3939],
        [ 2.0134, -0.2117, -0.6994, -0.4672,  1.1502,  1.1341, -1.7988,  1.6240],
        [-0.0178, -1.5268,  1.2774, -0.1184, -0.9108,  0.2248,  0.5637, -0.4600]])

In [463]:
W-A

tensor([[ 1.5521e-01, -1.1148e+00,  1.3314e-01, -1.2484e+00,  9.1620e-01,
          3.7130e-01,  6.0740e-01,  4.9797e-01],
        [ 5.7690e-02,  2.2091e-01, -4.4549e-02,  4.5573e-01,  8.8576e-02,
         -1.2820e-01, -1.5676e-01, -3.3607e-01],
        [-1.3456e-01,  5.2741e-02,  2.0671e-04,  2.1797e-01,  1.0645e+00,
          1.5746e-01, -5.1999e-01,  2.9428e-01],
        [-1.9399e-01,  2.1491e-01, -1.6725e-01,  9.3962e-01, -1.5479e+00,
          6.3003e-02, -9.6475e-03, -4.5329e-01],
        [-5.8846e-01, -9.9926e-02,  5.5659e-01, -1.4950e+00, -3.5784e-01,
          1.2007e-01,  3.5631e-01, -7.1472e-01],
        [ 3.5716e-01,  1.1257e-01, -1.4399e-01,  2.6857e-01,  6.3276e-01,
          3.8957e-02, -2.7345e-01,  1.3936e-01],
        [-6.8257e-03,  5.3530e-02,  4.2868e-04, -3.2995e-01, -1.4320e+00,
         -2.2117e-01,  2.2531e-01, -5.6055e-02],
        [-1.6039e-01,  9.7311e-01, -1.4479e+00,  2.8773e-01,  9.3484e-01,
         -4.6751e-01, -4.4301e-01, -2.0103e-01]], grad_fn=<SubBac