In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
device = torch.device("cpu")

In [3]:
## Full rank matrix
A = torch.randn(8, 8)
A

tensor([[ 0.0346,  0.7886,  0.6592, -0.8430,  0.0568,  0.0689, -0.4022, -0.2433],
        [ 0.4507,  0.2289, -2.3694, -0.9718,  0.4013, -0.1704, -0.7195, -1.1159],
        [-0.8502, -1.8070,  0.8830,  0.4443,  0.1140,  0.7160, -1.1811, -1.0146],
        [-1.5754, -0.2129, -0.5018,  0.6399, -1.6658, -1.0708,  1.6115,  0.6246],
        [-1.5554,  0.9266, -0.1000,  2.5593, -0.4279,  0.2168, -0.6194,  0.9913],
        [-0.0793,  0.2866,  1.1565, -1.8631,  1.0846,  0.1787, -0.1888, -0.7852],
        [-0.7552, -1.6622, -0.1509, -0.8613, -0.7550, -0.2900,  1.7001,  1.1235],
        [ 0.2355,  1.0721, -0.6708, -1.2783,  0.7991,  0.9393, -0.4271,  2.8781]])

In [4]:
### pair up (1,2), (3,4), ...
diag0 = torch.randn(4, 2, 2)
diag0

tensor([[[ 0.6218,  1.0594],
         [-0.6093,  0.4330]],

        [[-0.9304,  0.3074],
         [-0.6322,  1.8046]],

        [[-1.3292,  1.3891],
         [-0.6256,  2.2485]],

        [[ 0.1929,  0.1220],
         [ 0.6723, -0.4469]]])

In [5]:
torch.block_diag(*diag0)

tensor([[ 0.6218,  1.0594,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.6093,  0.4330,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.9304,  0.3074,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.6322,  1.8046,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -1.3292,  1.3891,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.6256,  2.2485,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.1929,  0.1220],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.6723, -0.4469]])

In [6]:
### Pair up (1, 3), (2,4)
## that means rearrange as.. 1,3,2,4 and apply block diagonal

def get_pair(dim, step=1):
    assert 2**int(np.log2(dim)) == dim , "The dim must be power of 2"
    assert isinstance(step, int), "Step must be integer"
    
    blocks = (2**step)
    range_ = dim//blocks
#     print(range_, blocks)
    adder_ = torch.arange(0, range_)*blocks
#     print(adder_)
    
    pairs_ = torch.Tensor([0, blocks//2])
    repeat_ = torch.arange(0, blocks//2).reshape(-1,1)
#     print(pairs_)
#     print(repeat_)
    
    block_map = (pairs_+repeat_).reshape(-1)
#     print(block_map)
    
#     print(block_map+adder_.reshape(-1,1))
    reorder_for_pair = (block_map+adder_.reshape(-1,1)).reshape(-1)
    indx = reorder_for_pair.type(torch.long)
    rev_indx = torch.argsort(indx)
    
    return indx, rev_indx

In [7]:
get_pair(8, 1)

(tensor([0, 1, 2, 3, 4, 5, 6, 7]), tensor([0, 1, 2, 3, 4, 5, 6, 7]))

In [8]:
get_pair(8, 2)

(tensor([0, 2, 1, 3, 4, 6, 5, 7]), tensor([0, 2, 1, 3, 4, 6, 5, 7]))

In [9]:
get_pair(8, 3)

(tensor([0, 4, 1, 5, 2, 6, 3, 7]), tensor([0, 2, 4, 6, 1, 3, 5, 7]))

In [10]:
mat = torch.arange(8, dtype=torch.float).reshape(-1, 1) + torch.arange(8, dtype=torch.float).reshape(1, -1)*0.1
mat

tensor([[0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000],
        [1.0000, 1.1000, 1.2000, 1.3000, 1.4000, 1.5000, 1.6000, 1.7000],
        [2.0000, 2.1000, 2.2000, 2.3000, 2.4000, 2.5000, 2.6000, 2.7000],
        [3.0000, 3.1000, 3.2000, 3.3000, 3.4000, 3.5000, 3.6000, 3.7000],
        [4.0000, 4.1000, 4.2000, 4.3000, 4.4000, 4.5000, 4.6000, 4.7000],
        [5.0000, 5.1000, 5.2000, 5.3000, 5.4000, 5.5000, 5.6000, 5.7000],
        [6.0000, 6.1000, 6.2000, 6.3000, 6.4000, 6.5000, 6.6000, 6.7000],
        [7.0000, 7.1000, 7.2000, 7.3000, 7.4000, 7.5000, 7.6000, 7.7000]])

In [11]:
diag0 = torch.randn(4, 2, 2)
diag0.requires_grad=True
bd0 = torch.block_diag(*diag0)
bd0

tensor([[-1.0913, -0.3770,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-1.4331,  0.7864,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.2266, -0.9781,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.7293, -0.6892,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  1.2213,  0.9440,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.1742,  0.1825,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.9854,  0.0382],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.1571, -1.6586]],
       grad_fn=<CopySlices>)

In [12]:
diag1 = torch.randn(4, 2, 2)
bd1 = torch.block_diag(*diag1)
# bd1 = mat
ind1, rind1 = get_pair(8, 2)
bd1 = bd1[ind1]
bd1 = bd1[:,ind1]
bd1

tensor([[-7.0098e-01,  0.0000e+00,  3.7360e-01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  1.7099e+00,  0.0000e+00,  4.6721e-02,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.4616e-03,  0.0000e+00,  6.4878e-01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -1.2576e+00,  0.0000e+00, -6.3018e-01,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.5795e+00,
          0.0000e+00, -4.3898e-01,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.0406e+00,  0.0000e+00,  3.1480e-02],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  2.7742e-01,
          0.0000e+00,  9.8855e-01,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -8.1392e-02,  0.0000e+00,  1.9115e-01]])

In [13]:
(bd0@bd1)

tensor([[ 7.6498e-01, -6.4473e-01, -4.0771e-01, -1.7616e-02,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.0046e+00,  1.3447e+00, -5.3541e-01,  3.6743e-02,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [-3.3113e-04,  1.2301e+00, -1.4698e-01,  6.1638e-01,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 1.0659e-03,  8.6676e-01,  4.7314e-01,  4.3433e-01,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.9290e+00,
          9.8233e-01, -5.3612e-01,  2.9718e-02],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  2.7514e-01,
          1.8989e-01,  7.6469e-02,  5.7445e-03],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  2.7338e-01,
         -3.1122e-03,  9.7417e-01,  7.3093e-03],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  4.3588e-02,
          1.3500e-01,  1.5532e-01, -3.1705e-01]], grad_fn=<MmBack

In [14]:
diag2 = torch.randn(4, 2, 2)
bd2 = torch.block_diag(*diag2)
# bd2 = mat
ind2, rind2 = get_pair(8, 3)
# bd2 = bd2[ind2]
# bd2 = bd2[:, ind2]
bd2

tensor([[ 0.3735, -1.4528,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 2.1253, -1.0149,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.6922, -0.6927,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.3796,  0.2391,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  1.3294,  0.0854,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.5977,  0.4162,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.5768, -0.0057],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.9011, -0.0681]])

In [15]:
get_pair(8, 3)

(tensor([0, 4, 1, 5, 2, 6, 3, 7]), tensor([0, 2, 4, 6, 1, 3, 5, 7]))

In [16]:
torch.scatter

<function _VariableFunctionsClass.scatter>

In [17]:
bd2

tensor([[ 0.3735, -1.4528,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 2.1253, -1.0149,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.6922, -0.6927,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.3796,  0.2391,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  1.3294,  0.0854,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.5977,  0.4162,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.5768, -0.0057],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -1.9011, -0.0681]])

In [20]:
bd2[:,rind2] = bd2.clone()

In [21]:
ind2, rind2

(tensor([0, 4, 1, 5, 2, 6, 3, 7]), tensor([0, 2, 4, 6, 1, 3, 5, 7]))

In [22]:
(bd0@bd1)@bd2

tensor([[-1.0845e+00,  0.0000e+00, -4.5706e-01,  0.0000e+00, -2.8891e-01,
          0.0000e+00,  2.7822e-01,  0.0000e+00],
        [ 3.2331e+00,  0.0000e+00, -2.8243e+00,  0.0000e+00, -3.5667e-01,
          0.0000e+00,  3.7968e-01,  0.0000e+00],
        [ 2.6141e+00,  0.0000e+00, -1.2479e+00,  0.0000e+00,  1.3223e-01,
          0.0000e+00,  2.4923e-01,  0.0000e+00],
        [ 1.8425e+00,  0.0000e+00, -8.8121e-01,  0.0000e+00,  4.9238e-01,
          0.0000e+00, -2.2389e-01,  0.0000e+00],
        [ 0.0000e+00, -3.1516e+00,  0.0000e+00,  2.4415e-01,  0.0000e+00,
         -3.6575e-01,  0.0000e+00,  1.0321e-03],
        [ 0.0000e+00,  2.5227e-01,  0.0000e+00,  1.0251e-01,  0.0000e+00,
          3.3190e-02,  0.0000e+00, -8.2742e-04],
        [ 0.0000e+00,  3.6529e-01,  0.0000e+00,  2.2041e-02,  0.0000e+00,
          5.4804e-01,  0.0000e+00, -6.0527e-03],
        [ 0.0000e+00, -2.2748e-02,  0.0000e+00,  5.9903e-02,  0.0000e+00,
          6.9235e-01,  0.0000e+00,  2.0717e-02]], grad_fn=<MmBack

In [23]:
# indices

NameError: name 'indices' is not defined

In [25]:
# torch.argsort(indices) ## this is reverse index

In [26]:
def get_pair_indices(dim, step=1):
    assert 2**int(np.log2(dim)) == dim , "The dim must be power of 2"
    assert isinstance(step, int), "Step must be integer"
    
    blocks = (2**step)
    range_ = dim//blocks
#     print(range_, blocks)
    adder_ = torch.arange(0, range_)*blocks
#     print(adder_)
    
    pairs_ = torch.Tensor([0, blocks//2])
    repeat_ = torch.arange(0, blocks//2).reshape(-1,1)
#     print(pairs_)
#     print(repeat_)
    
    block_map = (pairs_+repeat_).reshape(-1)
#     print(pairs_+repeat_)
    
#     print(block_map+adder_.reshape(-1,1))
    reorder_for_pair = (block_map+adder_.reshape(-1,1)).reshape(-1)
    indx = reorder_for_pair.type(torch.long)
#     rev_indx = torch.argsort(indx)
    
#     print(indx.reshape(-1, 2))
    
#     return indx, rev_indx
    indx = indx.reshape(-1, 2)
    
    map_idx = []
    for idx in indx:
        map_idx.append((idx[0], idx[0]))
        map_idx.append((idx[0], idx[1]))        
        map_idx.append((idx[1], idx[0]))        
        map_idx.append((idx[1], idx[1]))        
    map_idx = torch.LongTensor(map_idx)
#     print(map_idx)
    
    return map_idx

In [27]:
map_idx = get_pair_indices(8, 2)

In [28]:
map_idx.shape

torch.Size([16, 2])

In [29]:
w = torch.randn(map_idx.shape[0],1)

In [30]:
z = torch.zeros(8,8)
z[map_idx.split(1, dim=1)] = w
z

tensor([[-6.4154e-01,  0.0000e+00, -4.5817e-01,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  2.9299e-02,  0.0000e+00,  1.4902e-01,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 5.9531e-01,  0.0000e+00,  1.0453e-03,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -2.4374e-01,  0.0000e+00, -3.4759e-01,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -5.2312e-01,
          0.0000e+00,  2.4168e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          2.9902e-01,  0.0000e+00, -2.6994e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  3.0292e-01,
          0.0000e+00, -2.9500e-02,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          6.3270e-01,  0.0000e+00, -1.6289e-01]])

In [31]:
torch.nonzero(mat, as_tuple=True)

(tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3,
         3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6,
         6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7]),
 tensor([1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0,
         1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0,
         1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]))

### create 3 matrices with the factorization and multiply to approximate given matrix

In [66]:
A = torch.randn(8, 8)#*5
# A.abs_()

In [67]:
w0, w1, w2 = [torch.randn(16, 1, requires_grad=True) for _ in range(3)]

In [68]:
i0, i1, i2 = [get_pair_indices(8, i+1) for i in range(3)]

In [82]:
# optimizer = torch.optim.Adam([w0, w1, w2], lr=0.001)
optimizer = torch.optim.SGD([w0, w1, w2], lr=0.1)

mse = nn.MSELoss()
def mae(A, B):
    return torch.abs(A-B).mean()

In [83]:
### forward propagation
for i in range(10000):
    W0, W1, W2 = [torch.zeros(8, 8) for _ in range(3)]
    W0[i0.split(1, dim=1)] = w0
    W1[i1.split(1, dim=1)] = w1
    W2[i2.split(1, dim=1)] = w2

    W = (W0@W1)@W2
    loss = mse(W,A)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i%1000 == 0:
        print(f"The MSE loss is : {float(mse(W,A))}")
#         print(f"The MAE loss is : {float(mae(W,A))}")

The MSE loss is : 95.1694107055664
The MSE loss is : 0.3550967276096344
The MSE loss is : 0.25570911169052124
The MSE loss is : 0.24937587976455688
The MSE loss is : 0.2486097514629364
The MSE loss is : 0.248269721865654
The MSE loss is : 0.2479921281337738
The MSE loss is : 0.24772588908672333
The MSE loss is : 0.2474612593650818
The MSE loss is : 0.2471964955329895


In [84]:
W

tensor([[-0.6135, -0.2089,  1.2606,  0.0628, -0.0886,  0.2485, -1.5808,  0.0597],
        [-0.5339, -0.9755,  1.0970,  0.2931, -0.0771,  1.1603, -1.3756,  0.2785],
        [-1.3683, -0.0443,  0.7879,  0.3918, -0.1977,  0.0528, -0.9881,  0.3723],
        [-0.9276,  0.1209,  0.5341, -1.0684, -0.1340, -0.1438, -0.6698, -1.0151],
        [-2.6020, -1.9169, -0.1303,  0.6355,  0.2276, -0.6601,  0.6747, -0.9441],
        [-0.2067, -1.2088, -0.0103,  0.4007,  0.0181, -0.4163,  0.0536, -0.5953],
        [ 1.4727,  0.4776, -0.1237, -1.4758, -0.1288,  0.1645,  0.6410,  2.1924],
        [ 0.3295,  0.0402, -0.0277, -0.1241, -0.0288,  0.0138,  0.1434,  0.1844]],
       grad_fn=<MmBackward>)

In [85]:
A

tensor([[-0.1554,  0.0617,  1.2094,  1.2875, -0.7553,  0.3535, -1.7621, -0.7190],
        [-0.9996, -1.1121,  1.4841,  0.2772,  0.2682,  1.0717, -0.9055,  0.1860],
        [-0.8472,  1.0142,  1.1018,  0.8449,  0.7419,  0.1684, -1.6469,  0.0050],
        [-1.7313, -0.1257, -0.6033, -0.8346, -1.2780, -0.6352, -0.2357, -1.2210],
        [-2.6038, -2.1522,  0.2725,  0.8369,  0.3256, -0.5613,  0.7126, -0.3999],
        [-0.0182, -0.8411, -0.3104,  0.2436,  0.6893, -0.5572,  0.4961, -1.3491],
        [ 1.4629,  0.4064, -0.7068, -1.4039, -0.0166,  0.1767,  0.5719,  2.2554],
        [ 0.4780,  0.7236, -1.1250, -0.4562,  0.6670,  0.3411, -0.2624, -0.2126]])

In [86]:
W-A

tensor([[-0.4582, -0.2706,  0.0512, -1.2248,  0.6666, -0.1050,  0.1813,  0.7786],
        [ 0.4657,  0.1367, -0.3872,  0.0159, -0.3453,  0.0886, -0.4701,  0.0926],
        [-0.5211, -1.0586, -0.3138, -0.4531, -0.9396, -0.1157,  0.6588,  0.3673],
        [ 0.8038,  0.2467,  1.1374, -0.2338,  1.1440,  0.4913, -0.4341,  0.2059],
        [ 0.0017,  0.2353, -0.4027, -0.2014, -0.0980, -0.0988, -0.0379, -0.5442],
        [-0.1885, -0.3677,  0.3000,  0.1571, -0.6712,  0.1409, -0.4425,  0.7538],
        [ 0.0098,  0.0712,  0.5830, -0.0720, -0.1123, -0.0122,  0.0691, -0.0630],
        [-0.1486, -0.6834,  1.0973,  0.3321, -0.6958, -0.3273,  0.4058,  0.3970]],
       grad_fn=<SubBackward0>)

### using double matrix

In [87]:
# A = torch.randn(8, 8)
# A.abs_()

In [88]:
w0, w1, w2 = [torch.randn(16, 1, requires_grad=True) for _ in range(3)]

w3, w4, w5 = [torch.randn(16, 1, requires_grad=True) for _ in range(3)]

In [89]:
i0, i1, i2 = [get_pair_indices(8, i+1) for i in range(3)]

In [90]:
# optimizer = torch.optim.Adam([w0, w1, w2, w3, w4, w5], lr=0.001)
optimizer = torch.optim.SGD([w0, w1, w2, w3, w4, w5], lr=0.1)

mse = nn.MSELoss()
def mae(A, B):
    return torch.abs(A-B).mean()

In [91]:
### forward propagation
for i in range(10000):
    W0, W1, W2, W3, W4, W5 = [torch.zeros(8, 8) for _ in range(6)]
    W0[i0.split(1, dim=1)] = w0
    W1[i1.split(1, dim=1)] = w1
    W2[i2.split(1, dim=1)] = w2
    
    W3[i0.split(1, dim=1)] = w3
    W4[i1.split(1, dim=1)] = w4
    W5[i2.split(1, dim=1)] = w5

    W = ((((W0@W1)@W2)@W3)@W4)@W5
    loss = mse(W,A)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i%1000 == 0:
        print(f"The MSE loss is : {float(mse(W,A))}")
#         print(f"The MAE loss is : {float(mae(W,A))}")

The MSE loss is : 1.6350189447402954
The MSE loss is : 0.10431253165006638
The MSE loss is : 0.07465889304876328
The MSE loss is : 0.06737516820430756
The MSE loss is : 0.06733483076095581
The MSE loss is : 0.0655464380979538
The MSE loss is : 0.06363171339035034
The MSE loss is : 0.06146462261676788
The MSE loss is : 0.05824487656354904
The MSE loss is : 0.053767140954732895


In [92]:
W

tensor([[-0.7124,  0.2008,  1.2148,  1.1886, -0.0414,  0.1873, -1.6268, -0.7539],
        [-0.9017, -1.1337,  1.5159,  0.3996,  0.0411,  1.0895, -1.0415,  0.0470],
        [-0.3593,  0.9518,  1.1445,  0.9539,  0.1548,  0.0188, -1.7421,  0.0740],
        [-1.6321, -0.1508, -0.5788, -0.8717, -1.4465, -0.5183, -0.2687, -1.1869],
        [-2.4294, -2.2024,  0.4334,  0.7805,  0.2845, -0.7702,  0.8017, -0.2546],
        [-0.1337, -0.8390, -0.5693,  0.2881,  0.4621, -0.4359,  0.3310, -1.2743],
        [ 1.4105,  0.3863, -0.8166, -1.3830,  0.1231,  0.1715,  0.4521,  2.2348],
        [ 0.9759,  0.5231, -0.5804, -0.3438,  0.0701,  0.1113,  0.0951, -0.2563]],
       grad_fn=<MmBackward>)

In [93]:
A

tensor([[-0.1554,  0.0617,  1.2094,  1.2875, -0.7553,  0.3535, -1.7621, -0.7190],
        [-0.9996, -1.1121,  1.4841,  0.2772,  0.2682,  1.0717, -0.9055,  0.1860],
        [-0.8472,  1.0142,  1.1018,  0.8449,  0.7419,  0.1684, -1.6469,  0.0050],
        [-1.7313, -0.1257, -0.6033, -0.8346, -1.2780, -0.6352, -0.2357, -1.2210],
        [-2.6038, -2.1522,  0.2725,  0.8369,  0.3256, -0.5613,  0.7126, -0.3999],
        [-0.0182, -0.8411, -0.3104,  0.2436,  0.6893, -0.5572,  0.4961, -1.3491],
        [ 1.4629,  0.4064, -0.7068, -1.4039, -0.0166,  0.1767,  0.5719,  2.2554],
        [ 0.4780,  0.7236, -1.1250, -0.4562,  0.6670,  0.3411, -0.2624, -0.2126]])

In [94]:
W-A

tensor([[-0.5571,  0.1391,  0.0054, -0.0989,  0.7138, -0.1661,  0.1352, -0.0350],
        [ 0.0978, -0.0215,  0.0318,  0.1224, -0.2272,  0.0178, -0.1360, -0.1390],
        [ 0.4880, -0.0625,  0.0427,  0.1090, -0.5871, -0.1496, -0.0952,  0.0690],
        [ 0.0992, -0.0250,  0.0244, -0.0371, -0.1685,  0.1168, -0.0330,  0.0342],
        [ 0.1744, -0.0502,  0.1610, -0.0564, -0.0411, -0.2088,  0.0891,  0.1453],
        [-0.1155,  0.0021, -0.2589,  0.0445, -0.2271,  0.1212, -0.1651,  0.0748],
        [-0.0524, -0.0201, -0.1098,  0.0209,  0.1396, -0.0052, -0.1198, -0.0206],
        [ 0.4978, -0.2005,  0.5446,  0.1124, -0.5968, -0.2298,  0.3575, -0.0437]],
       grad_fn=<SubBackward0>)