# Solving a linear system with a neural network

Our goal is to solve $Ax = b$ with a neural network (nn). A direct method such as LU decomposition already solves this system, so we use it here only to obtain a reference solution for validating the nn result.

In [1]:
import torch
import torch.nn as nn


In [2]:

A = torch.tensor([3, 5, -2, 13], dtype=torch.float).reshape(2,2)
b = torch.ones(2,1)

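# torch.solve is deprecated and removed in recent PyTorch releases;
# torch.linalg.solve(A, b) returns the solution tensor directly
x = torch.linalg.solve(A, b)
print(f'>>>matrix A:{A}, \n vector b: {b}')
print(f'>>>soln from LU factorization is {x}')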

>>>matrix A:tensor([[ 3.,  5.],
        [-2., 13.]]), 
 vector b: tensor([[1.],
        [1.]])
>>>soln from LU factorization is tensor([[0.1633],
        [0.1020]])
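
As a quick sanity check on the reference solution, the system is small enough to solve by hand: $\det A = 49$, so $x = A^{-1}b = (8/49,\ 5/49)^T \approx (0.1633,\ 0.1020)^T$. The sketch below assumes a PyTorch version that provides `torch.linalg.solve` and compares the numerical solution against these fractions. The names `x_ref`, `x_exact`, and `max_residual` are introduced here for illustration only.

```python
import torch

A = torch.tensor([[3., 5.], [-2., 13.]])
b = torch.ones(2, 1)

# Reference solution from a direct (LU-based) solver
x_ref = torch.linalg.solve(A, b)

# Hand-computed solution: A^{-1} = (1/49) * [[13, -5], [2, 3]], so x = [8, 5] / 49
x_exact = torch.tensor([[8.], [5.]]) / 49.

# Largest absolute entry of the residual A x - b
max_residual = (A @ x_ref - b).abs().max().item()

print(f'x_ref = {x_ref.flatten().tolist()}')
print(f'x_exact = {x_exact.flatten().tolist()}')
print(f'max |A x - b| = {max_residual}')
```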


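- Next, we approximate the solution with the following nn. The network is evaluated at the two inputs 0 and 1, and its two outputs are treated as the entries of $x$; training minimizes the squared residual $\|Ax - b\|^2$.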

In [3]:
net = nn.Sequential(
    nn.Linear(1, 10),
    nn.ReLU(),
    nn.Linear(10,1),
)

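def loss_list():
    # Evaluate the net at inputs 0 and 1; the two outputs form the candidate solution x
    ix = torch.tensor([0,1], dtype=torch.float).reshape(2,1)
    # Residual A x - b, one entry per equation
    return A@net(ix)-b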


print_n = 10
n_epoch=500; epoch_per_print= int(n_epoch/print_n)

for epoch in range(n_epoch):
    # forward pass: sum of squared residuals
    loss = loss_list().pow(2).sum()
    # backward propagation
    lr = max(1./(n_epoch+100), 0.001)
    # note: the optimizer is re-created every epoch, which resets its momentum buffer,
    # so each update is effectively a plain SGD step
    optimizer = torch.optim.SGD(net.parameters(), lr, momentum = .8)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (epoch+1) % epoch_per_print == 0:
        x_pre = (net(torch.tensor([0,1], dtype=torch.float).reshape(2,1)))          
        print(f'Epoch {epoch+1}, Loss: {loss.item()}, \n soln: [{x_pre[0].item()}, {x_pre[1].item()}]')

Epoch 50, Loss: 3.957065200665966e-05, 
 soln: [0.16491079330444336, 0.1020437628030777]
Epoch 100, Loss: 9.795542155188741e-08, 
 soln: [0.16334718465805054, 0.10204096138477325]
Epoch 150, Loss: 2.404512144948967e-10, 
 soln: [0.16326940059661865, 0.10204081982374191]
Epoch 200, Loss: 9.237055564881302e-13, 
 soln: [0.16326555609703064, 0.10204081982374191]
Epoch 250, Loss: 2.842170943040401e-14, 
 soln: [0.16326534748077393, 0.10204081237316132]
Epoch 300, Loss: 2.842170943040401e-14, 
 soln: [0.16326534748077393, 0.10204081982374191]
Epoch 350, Loss: 0.0, 
 soln: [0.16326533257961273, 0.10204081982374191]
Epoch 400, Loss: 0.0, 
 soln: [0.16326533257961273, 0.10204081982374191]
Epoch 450, Loss: 0.0, 
 soln: [0.16326533257961273, 0.10204081982374191]
Epoch 500, Loss: 0.0, 
 soln: [0.16326533257961273, 0.10204081982374191]
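
Because the network is only ever evaluated at the two fixed inputs 0 and 1, its two outputs simply play the role of the unknowns $x_1, x_2$. As a minimal sketch of the same idea, one could drop the network and optimize a $2 \times 1$ tensor directly with the same loss and learning rate. This is a simplification for illustration, not the notebook's method. The optimizer is built once here so that the momentum buffer persists across steps, and `x_ref` is a name introduced for the comparison.

```python
import torch

A = torch.tensor([[3., 5.], [-2., 13.]])
b = torch.ones(2, 1)

# The unknown vector itself is the trainable parameter (assumption: start from zero)
x = torch.zeros(2, 1, requires_grad=True)

n_epoch = 500
lr = max(1. / (n_epoch + 100), 0.001)

# Built once, outside the loop, so the momentum buffer accumulates over epochs
optimizer = torch.optim.SGD([x], lr=lr, momentum=0.8)

for epoch in range(n_epoch):
    optimizer.zero_grad()
    loss = (A @ x - b).pow(2).sum()  # squared residual ||A x - b||^2
    loss.backward()
    optimizer.step()

# Compare against the direct solver
x_ref = torch.linalg.solve(A, b)
print(f'final loss: {loss.item()}')
print(f'gradient descent: {x.detach().flatten().tolist()}')
print(f'direct solve:     {x_ref.flatten().tolist()}')
```

With only two unknowns the extra hidden layer adds parameters but no new degrees of freedom in the solution, which is why the direct parameterization is a useful baseline when debugging the loss or the optimizer settings.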


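- If $A$ is ill-conditioned (its condition number is large), this approach breaks down: gradient descent on the squared residual converges very slowly along the directions associated with small singular values of $A$, so the nn typically fails to reach an accurate solution in a comparable number of epochs.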
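
The effect of conditioning can be illustrated without the network by running plain gradient descent on $\|Ax - b\|^2$ for the matrix above and for a nearly singular one. The sketch below is an assumption-laden illustration: `A_bad` is a hypothetical ill-conditioned matrix chosen for this example, `gd_solve` and `results` are names introduced here, and double precision is used so that rounding does not obscure the comparison.

```python
import torch

def gd_solve(A, b, lr, n_steps=500):
    """Plain gradient descent on ||A x - b||^2, starting from x = 0."""
    x = torch.zeros(A.shape[1], 1, dtype=A.dtype, requires_grad=True)
    optimizer = torch.optim.SGD([x], lr=lr)
    for _ in range(n_steps):
        optimizer.zero_grad()
        loss = (A @ x - b).pow(2).sum()
        loss.backward()
        optimizer.step()
    return x.detach()

# Matrix from the example above (well conditioned)
A_good = torch.tensor([[3., 5.], [-2., 13.]], dtype=torch.float64)
# Hypothetical nearly singular matrix, chosen only for illustration
A_bad = torch.tensor([[1., 1.], [1., 1.001]], dtype=torch.float64)
b = torch.ones(2, 1, dtype=torch.float64)

results = {}
for name, A in [('good', A_good), ('bad', A_bad)]:
    x_ref = torch.linalg.solve(A, b)        # direct reference solution
    x_gd = gd_solve(A, b, lr=1. / 600)      # same step size and step count for both
    err = (x_gd - x_ref).abs().max().item()
    cond = torch.linalg.cond(A).item()      # 2-norm condition number
    results[name] = (cond, err)
    print(f'{name}: cond(A) = {cond:.3g}, max error after 500 steps = {err:.3g}')
```

For the well-conditioned matrix the iterate matches the direct solution closely, while for `A_bad` the error along the weakly determined direction barely shrinks in the same number of steps.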