Found that the algorithm converged quite well towards C.dot(C.T) ~= cov_empir. After about 20 epochs, the loss started to be infinite. This is because np.linalg.det(C.dot(C.T)) = inf. Also 
np.linalg.det(np.linalg.inv(C.dot(C.T))) = 0 which you would think would make it non-invertible, but Python still manages to invert it.

In [23]:
import sys, json
from datetime import datetime as dt
import numpy as np
import scipy
from preconditioners.utils import generate_c, generate_centered_gaussian_data
from preconditioners.impl_cov_approx import *

In [24]:
# parameters of the run
n_epochs = 100
iter_per_epoch = 500
tol = 0.01
n = 50
d = 150
ro = 0.5
regime = 'autoregressive'
regul_lambda = 1
lr_start = 0.05
lr_decay = 0.98

params = {
    'n_epochs' : n_epochs,
    'iter_per_epoch' : iter_per_epoch,
    'tol' : tol,
    'n' : n,
    'd' : d,
    'ro' : ro,
    'regime' : regime,
    'regul_lambda' : regul_lambda,
    'lr_start' : lr_start,
    'lr_decay' : lr_decay
}

In [25]:
# generate data and initialization
c = generate_c(ro=ro,
                regime=regime,
                n=n,
                d=d
                )
w_star = np.random.multivariate_normal(mean=np.zeros(d), cov=np.eye(d))
X, y, xi = generate_centered_gaussian_data(w_star,
                                            c,
                                            n=n,
                                            d=d,
                                            sigma2=1,
                                            fix_norm_of_x=False)

# initialize C (cholesky), cov_inv, regul_lambda and learning rate
# is this a good way to initialize?
cov_empir = X.T.dot(X) / n
cov_inv = np.linalg.inv(cov_empir + 0.1 * np.eye(d))
C = scipy.linalg.cholesky(cov_inv) + 0.1 * generate_c(ro=0.1,
                                                        regime='autoregressive',
                                                        n=n,
                                                        d=d,
                                                        )


In [26]:
# run optimization
for epoch in range(n_epochs):
    if epoch == 0:
        lr = lr_start
    else:
        lr = lr * lr_decay

    for i in range(iter_per_epoch):
        # compute loss and gradient
        cov_inv = C.dot(C.T)
        loss_val = loss(cov_inv, X, cov_empir, regul_lambda)
        grad_loss_val = grad_loss(C, cov_empir, regul_lambda, X)
        error = np.linalg.norm(grad_loss_val)

        # update C (+ because we are maximizing)
        C = C + lr * grad_loss_val
        # update regul_lambda
        # regul_lambda = regul_lambda - lr*np.trace(grad_loss_val.dot(grad_loss_val.T))
        # check if we are done
        if i % 50 == 0:
            print(f"iteration {i}/{iter_per_epoch} of epoch {epoch}/{n_epochs}, loss {loss_val} and error {error}")
        if error < tol:
            break

dtstamp = str(dt.now()).replace(' ', '_')
with open(f'results_{dtstamp}.json', 'w') as f:
    json.dump({'C' : C.tolist(), 'loss' : float(loss_val), 'error' : float(error), 'parans' : params}, f)

iteration 0/500 of epoch 0/100, loss -500.427413328037 and error 113.18336157869588
iteration 50/500 of epoch 0/100, loss 201.04715397552118 and error 4.669286738870482
iteration 100/500 of epoch 0/100, loss 244.48914120931843 and error 3.746022632331737
iteration 150/500 of epoch 0/100, loss 274.57352584720087 and error 3.2200924261752055
iteration 200/500 of epoch 0/100, loss 297.6383946118007 and error 2.8682789863581353
iteration 250/500 of epoch 0/100, loss 316.3547650398627 and error 2.6115275126457522
iteration 300/500 of epoch 0/100, loss 332.10772261528655 and error 2.4134468923451036
iteration 350/500 of epoch 0/100, loss 345.7098258470225 and error 2.254599834696683
iteration 400/500 of epoch 0/100, loss 357.6790598933847 and error 2.123524095373614
iteration 450/500 of epoch 0/100, loss 368.36609991283416 and error 2.012961325631798
iteration 0/500 of epoch 1/100, loss 378.01945730672156 and error 1.9180605912861843
iteration 50/500 of epoch 1/100, loss 386.652967850485 and

In [35]:
-np.log(0)

  -np.log(0)


inf

In [30]:
c

array([[1.00000000e+00, 5.00000000e-01, 2.50000000e-01, ...,
        5.60519386e-45, 2.80259693e-45, 1.40129846e-45],
       [5.00000000e-01, 1.00000000e+00, 5.00000000e-01, ...,
        1.12103877e-44, 5.60519386e-45, 2.80259693e-45],
       [2.50000000e-01, 5.00000000e-01, 1.00000000e+00, ...,
        2.24207754e-44, 1.12103877e-44, 5.60519386e-45],
       ...,
       [5.60519386e-45, 1.12103877e-44, 2.24207754e-44, ...,
        1.00000000e+00, 5.00000000e-01, 2.50000000e-01],
       [2.80259693e-45, 5.60519386e-45, 1.12103877e-44, ...,
        5.00000000e-01, 1.00000000e+00, 5.00000000e-01],
       [1.40129846e-45, 2.80259693e-45, 5.60519386e-45, ...,
        2.50000000e-01, 5.00000000e-01, 1.00000000e+00]])

In [28]:
np.linalg.inv(c)

array([[ 1.33333333e+000, -6.66666667e-001, -5.55111512e-017, ...,
         4.14867685e-061,  2.07433843e-061, -1.03716921e-061],
       [-6.66666667e-001,  1.66666667e+000, -6.66666667e-001, ...,
        -5.76344236e-224, -2.88172118e-224,  5.76344236e-224],
       [ 0.00000000e+000, -6.66666667e-001,  1.66666667e+000, ...,
        -2.07649895e-207, -1.03824947e-207,  2.07649895e-207],
       ...,
       [ 0.00000000e+000,  0.00000000e+000,  0.00000000e+000, ...,
         1.66666667e+000, -6.66666667e-001,  1.85037171e-017],
       [ 0.00000000e+000,  0.00000000e+000,  0.00000000e+000, ...,
        -6.66666667e-001,  1.66666667e+000, -6.66666667e-001],
       [ 0.00000000e+000,  0.00000000e+000,  0.00000000e+000, ...,
         0.00000000e+000, -6.66666667e-001,  1.33333333e+000]])

In [43]:
np.linalg.inv(C.dot(C.T))

array([[ 1.17602343,  0.61935177,  0.13669089, ...,  0.20212151,
         0.19575874,  0.15799879],
       [ 0.61935177,  1.47944831,  0.85037222, ...,  0.05756441,
         0.1788989 , -0.10457303],
       [ 0.13669089,  0.85037222,  1.17798244, ..., -0.09120367,
         0.21656216,  0.06504115],
       ...,
       [ 0.20212151,  0.05756441, -0.09120367, ...,  1.17728955,
         0.65638874,  0.48543651],
       [ 0.19575874,  0.1788989 ,  0.21656216, ...,  0.65638874,
         1.17084957,  0.57068323],
       [ 0.15799879, -0.10457303,  0.06504115, ...,  0.48543651,
         0.57068323,  0.80193775]])

In [40]:
cov_empir

array([[ 1.175888  ,  0.61937691,  0.13667743, ...,  0.20214484,
         0.19576334,  0.15800335],
       [ 0.61937691,  1.47931759,  0.85041812, ...,  0.05757058,
         0.1789092 , -0.10459119],
       [ 0.13667743,  0.85041812,  1.1778379 , ..., -0.09121609,
         0.21657191,  0.06504332],
       ...,
       [ 0.20214484,  0.05757058, -0.09121609, ...,  1.17713254,
         0.65642295,  0.48545778],
       [ 0.19576334,  0.1789092 ,  0.21657191, ...,  0.65642295,
         1.1706932 ,  0.57070927],
       [ 0.15800335, -0.10459119,  0.06504332, ...,  0.48545778,
         0.57070927,  0.80176651]])