In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

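Least-squares objective in finite-sum form, where $A[i]$ is the $i$-th row of $A$ and $b[i]$ is the $i$-th entry of $b$:

$$f(x) = \frac{1}{n}\sum_{i=1}^n f_i(x) = \frac{1}{n}\Vert Ax-b \Vert_2^2$$
$$f_i(x) = (A[i]x - b[i])^2$$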

In [None]:
# Synthetic least-squares problem: n noisy linear measurements of an m-dimensional vector.
n, m = 1000, 20

A = torch.randn(n, m)        # one sample per row
x_sol = torch.randn(m)       # vector used to generate the targets
noise = torch.empty(n).normal_(mean=0, std=1)   # unit-variance Gaussian noise
b = torch.matmul(A, x_sol) + noise

def f_i(x, i):
    """Squared residual of the i-th sample: (A[i] @ x - b[i]) ** 2.

    Reads the module-level tensors ``A`` and ``b``.
    """
    residual = torch.matmul(A[i], x) - b[i]
    return residual.pow(2)

def F(x):
    """Return the list [f_0(x), ..., f_{n-1}(x)] of per-sample losses."""
    return [f_i(x, idx) for idx in range(n)]


def f(x):
    """Return the average of the per-sample losses produced by ``F(x)``."""
    per_sample = torch.stack(F(x), dim=0)
    return per_sample.mean(dim=0)

In [None]:
# Sanity check: the finite-sum objective next to the closed form ||A x - b||^2 / n.
f(x_sol), torch.norm(A @ x_sol - b).pow(2) / n

In [None]:
# Largest eigenvalue of (2/n) * A^T A, i.e. of the Hessian of f(x) = ||A x - b||^2 / n.
# torch.symeig is deprecated and absent from recent PyTorch releases;
# torch.linalg.eigvalsh is the eigenvalue-only replacement.  UPLO="U" reads the
# upper triangle, which is what symeig did by default (upper=True).
hessian = torch.matmul(2/n*A.T, A)
eigenvalues = torch.linalg.eigvalsh(hessian, UPLO="U")  # ascending order
eig = (eigenvalues,)  # keeps `eig[0]` meaning "the eigenvalues", as with the symeig result
L = torch.max(eig[0])  # max eigenvalue


# NOTE(review): presumably the Hessian-Lipschitz constant used for the step
# delta / L_2 further down; hard-coded to 1 here -- confirm.
L_2 = 1

# Running maximum of the variance estimates returned by var(); updated by the loop below.
nu = 0
def var(x_0, num_samples=50):
    """Estimate the stochastic-gradient variance of ``f`` at ``x_0``.

    Draws ``num_samples`` indices uniformly at random (with replacement) from
    ``range(n)`` and averages ``||grad f(x_0) - grad f_i(x_0)||^2`` over them.

    Args:
        x_0: point at which both gradients are evaluated (called with a
            length-``m`` vector in this notebook).
        num_samples: number of sampled components; defaults to 50.

    Returns:
        0-dim tensor holding the averaged squared gradient deviation.

    Raises:
        ValueError: if ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    x = x_0.detach().clone().requires_grad_()
    # One forward/backward pass over all n terms; autograd.grad returns the
    # gradient directly instead of accumulating it into x.grad.
    full_grad, = torch.autograd.grad(f(x), x)

    sampled = torch.randint(n, (num_samples,))
    deviations = []
    for idx in sampled:
        grad_i, = torch.autograd.grad(f_i(x, idx), x)
        deviations.append(torch.norm(full_grad - grad_i)**2)
    return torch.stack(deviations, dim=0).mean(dim=0)
# Keep the largest variance estimate seen over 50 randomly drawn evaluation points.
for k in range(50):
    estimate = var(torch.randn(m))
    if estimate > nu:
        nu = estimate

L, L_2, nu

In [None]:
from torch.utils.data import TensorDataset, DataLoader


def _floor_int(value, rel_tol=1e-9):
    """Floor a positive float to an int, tolerating round-off just below an integer.

    Plain ``int(0.1**-2)`` yields 99 because ``0.1**-2`` evaluates to
    99.99999999999999 in floating point; nudging by ``rel_tol`` first gives 100.
    """
    return int(value * (1 + rel_tol))


eps = 1e-1
delta = eps**0.25

B = _floor_int(eps**-2)  # batch size
T = _floor_int(delta**-2)  # oja iterations
N_1 = _floor_int(1/delta/eps)  # the main loop stops once k reaches this value

# Wrap (A, b) as a dataset and build one shuffled loader per batch size used below.
ds = TensorDataset(A, b)

dl_1, dl_T, dl_B = (
    DataLoader(ds, batch_size=size, shuffle=True) for size in (1, T, B)
)

# Mean-squared-error between predictions and targets.
loss_fn = torch.nn.functional.mse_loss


# Optimiser state; the starting point is the all-ones vector.
y_0 = torch.ones(m)
y_k = y_0.clone().requires_grad_()  # current iterate, a differentiable copy of y_0

k = 0      # incremented once per pass through the Natasha 1.5 branch
X = []     # re-created inside the loop before it is read
Y = [y_0]  # history of iterates, starting with y_0
def _hessian_vector_product(point, direction, features, targets):
    """Return ``H @ direction`` for the mini-batch loss at ``point``.

    ``H`` is the Hessian, with respect to the parameters, of
    ``loss_fn(features @ point, targets)``.  The product is obtained by
    differentiating ``grad @ direction`` a second time, so ``H`` itself is
    never formed.  ``point`` is copied and detached first, so the call leaves
    no autograd history attached to the caller's tensor.
    """
    probe = point.detach().clone().requires_grad_()
    batch_loss = loss_fn(features @ probe, targets)
    grad, = torch.autograd.grad(batch_loss, probe, create_graph=True)
    hvp, = torch.autograd.grad(grad @ direction, probe)
    return hvp


def reg_k(x):
    """Penalty L * max(0, ||x - y_k|| - delta/L_2)^2, so F^k(x) = f(x) + reg_k(x).

    Reads the module-level ``y_k`` at call time, i.e. the current outer iterate.
    """
    return L*(max(0, torch.norm(x-y_k) - delta/L_2))**2


eta = np.sqrt(T)  # numerator of the Oja step size eta / L; constant across iterations

# Looping on `k < N_1` (instead of `while True` + `if k == N_1: break`) also
# terminates when N_1 == 0.
while k < N_1:

    # Oja's algorithm: 10 independent runs, each producing a unit vector v and
    # the curvature estimate v^T H v on a size-T batch; the run with the
    # smallest estimate is kept.
    Eig = []
    for j in range(10):
        w_1 = torch.empty(m).normal_(mean=0, std=1)
        W = [w_1/torch.norm(w_1)]
        for i in range(1, T):
            w_last = W[-1]
            A_i, b_i = next(iter(dl_1))  # one freshly shuffled sample
            prod = _hessian_vector_product(y_k, w_last, A_i, b_i)
            w = w_last - eta/L*prod
            W.append(w/torch.norm(w))
        v = W[int(torch.randint(len(W), (1,)))]  # uniformly random Oja iterate

        A_i, b_i = next(iter(dl_T))
        prod = _hessian_vector_product(y_k, v, A_i, b_i)
        Eig.append([v, v@prod])
    v, prod = min(Eig, key=lambda pair: pair[1])

    if prod <= -delta/2:
        # Sufficiently negative curvature found: move delta / L_2 along +v or -v
        # (sign chosen by a fair coin).  k is not advanced in this branch.
        sign = 1.0 if int(torch.randint(2, (1,))) == 1 else -1.0
        with torch.no_grad():
            y_k += sign*delta/L_2*v
    else:
        k += 1

        # Natasha 1.5

        alpha = eps*delta

        # Number of sub-epochs (exponent 2*0.3333 is approximately 2/3).
        # Clamped to >= 1 so the division below and the final draw from X
        # cannot fail.
        p = max(1, int(10*(delta/eps/L)**(2*0.3333)))
        _m = B // p  # inner steps per sub-epoch

        X = []

        # Gradient at y_k on a size-B batch; reused in every inner step.
        A_i, b_i = next(iter(dl_B))
        loss = loss_fn(A_i@y_k, b_i)
        mu, = torch.autograd.grad(loss, y_k)

        x_hat = y_k.detach().clone()
        for s in range(p):
            x = [x_hat.clone()]
            X.append(x_hat)
            for t in range(_m):
                A_i, b_i = next(iter(dl_1))
                x_t = x[t].clone().requires_grad_()
                F1 = loss_fn(A_i@x_t, b_i) + reg_k(x_t)
                grad1, = torch.autograd.grad(F1, x_t)

                F2 = loss_fn(A_i@y_k, b_i) + reg_k(y_k)
                grad2, = torch.autograd.grad(F2, y_k)
                nabla = grad1 - grad2 + mu + 2*delta*(x_t-x_hat)

                # Detach so the autograd graph does not grow across steps and
                # epochs; all gradients above are taken w.r.t. fresh leaves.
                x.append((x_t - alpha*nabla).detach())

            x_hat = torch.stack(x, dim=0).mean(dim=0)

        # Next outer iterate: a uniformly random sub-epoch snapshot, as a fresh leaf.
        y_k = X[int(torch.randint(len(X), (1,)))].clone().requires_grad_()
        Y.append(y_k.clone().detach())
        

    

In [None]:
# Loss on the full data set (A, b) at every iterate recorded in Y.
# NOTE: this rebinds `F`, which an earlier cell defines as the per-sample loss function.
F = [loss_fn(torch.matmul(A, iterate), b) for iterate in Y]

In [None]:
# Loss trajectory: one point per entry of F, i.e. per iterate stored in Y.
fig, ax = plt.subplots()
ax.plot(F)
ax.set_xlabel("index of recorded iterate in Y")
ax.set_ylabel("MSE loss on the full data set")
ax.set_title("Loss at the recorded iterates")
plt.show()