### Тест субградиентного метода

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import time

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
from oracles import LassoNonsmoothOracle
from optimization import subgradient_method

In [4]:
SEED = 42  # fixed seed so the random problem drawn via np.random below is reproducible
np.random.seed(SEED)

In [5]:
def generate_data(m, n, rng=None):
    """Generate a random noiseless linear system ``b = A @ x_true``.

    Parameters
    ----------
    m : int
        Number of rows of ``A``.
    n : int
        Number of columns of ``A`` (length of ``x_true``).
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to the global ``np.random`` state, so
        results depend on ``np.random.seed`` exactly as before.

    Returns
    -------
    A : ndarray, shape (m, n)
        Matrix with i.i.d. standard normal entries.
    b : ndarray, shape (m,)
        Right-hand side ``A.dot(x_true)`` (no noise is added).
    x_true : ndarray, shape (n,)
        Ground-truth vector with entries uniform on [-0.5, 0.5).
    """
    if rng is None:
        # The global state exposes standard_normal/random, which consume it
        # exactly like randn/rand, so default draws are unchanged.
        rng = np.random
    A = rng.standard_normal((m, n))
    x_true = rng.random(n) - 0.5
    b = A.dot(x_true)
    return A, b, x_true

In [6]:
m = 100  # number of rows of A
n = 50   # number of unknowns (columns of A, length of x)

In [7]:
# Draw the random test problem; its values depend on the seed set above.
A, b, x_true = generate_data(m=m, n=n)

In [8]:
# Regularization coefficient passed to LassoNonsmoothOracle
# (presumably the weight on the l1 term — confirm in oracles.py).
regcoef = 1e-1

In [9]:
# Candidate initial step sizes: 1e-3, 1e-2, ..., 1e4.
alpha_0 = [1e-3, 1e-2, 1e-1, 1e0,
           1e1, 1e2, 1e3, 1e4]
# Candidate starting points: constant vectors with entries 10.0, 10.2, ..., 11.0.
x_0 = [np.full(n, 10 + shift) for shift in (0., 0.2, 0.4, 0.6, 0.8, 1.0)]

In [10]:
oracle = LassoNonsmoothOracle(A, b, regcoef)
# Start from x_0[1] (all entries 10.2) with initial step alpha_0[4] (= 10.0);
# trace=True presumably is what populates `history`, inspected below.
x_star, message, history = subgradient_method(
    oracle,
    x_0=x_0[1],
    tolerance=1e-2,
    max_iter=10000,
    alpha_0=alpha_0[4],
    display=False,
    trace=True,
)

In [11]:
# Termination status returned by subgradient_method; 'iterations_exceeded'
# presumably means max_iter was reached before the tolerance was met.
message

'iterations_exceeded'

In [12]:
# Point returned by subgradient_method, to be compared with x_true.
x_star

array([-0.33177242, -0.31087538, -0.03763972, -0.2101219 , -0.24886062,
        0.14471751,  0.15099724,  0.3261392 , -0.07670759, -0.42240751,
       -0.22558826,  0.47816156,  0.09399293,  0.20620073, -0.49323832,
        0.01671281, -0.46073508,  0.06650097,  0.41979232,  0.46617221,
        0.04232614, -0.04154048,  0.44526145,  0.17099869, -0.43632437,
       -0.03282187, -0.25662262, -0.02952387, -0.0470185 ,  0.0776546 ,
       -0.13870012, -0.33221947,  0.00326225,  0.32334386,  0.1672456 ,
       -0.1606394 ,  0.43167001,  0.41061069, -0.28921174, -0.45947086,
        0.23879639,  0.32421828,  0.46439516,  0.38552608,  0.0986093 ,
        0.02621163, -0.4676485 , -0.24890889, -0.03700039,  0.12456722])

In [13]:
# Ground-truth vector produced by generate_data.
x_true

array([-0.33176342, -0.31006526, -0.03928783, -0.21372014, -0.25253702,
        0.14523764,  0.15097175,  0.32451707, -0.08229143, -0.42691491,
       -0.22633362,  0.48000546,  0.09625858,  0.20973708, -0.49264217,
        0.01820573, -0.46247134,  0.06923021,  0.42077085,  0.46841913,
        0.04055268, -0.04179519,  0.44501709,  0.17360983, -0.43510072,
       -0.03427442, -0.25975498, -0.03019568, -0.04453443,  0.0821448 ,
       -0.14322951, -0.33572955,  0.00371943,  0.33020288,  0.16401223,
       -0.16163213,  0.43552302,  0.41220757, -0.2906516 , -0.46058241,
        0.24130871,  0.32761026,  0.46786457,  0.38983639,  0.10053477,
        0.03005234, -0.47215895, -0.25334355, -0.04038525,  0.1285039 ])

In [14]:
# Euclidean distance between the recovered and the true vector.
recovery_error_vec = x_star - x_true
np.linalg.norm(recovery_error_vec, ord=2)

0.020524751563215746

In [15]:
# Duality-gap trace recorded by subgradient_method (run with trace=True).
# NOTE(review): the recorded values alternate between small and large and go
# negative (e.g. about -1.13, -10.5); a duality gap for a feasible primal/dual
# pair should be >= 0, so the gap computation in the oracle/optimizer may be
# wrong — investigate. Also consider plotting or slicing instead of dumping
# the whole list.
history['duality_gap']

[366871.43400910666,
 277014.6070091684,
 215552.16374937282,
 170405.02303615733,
 135932.17800260615,
 108963.65284049386,
 87513.46086299204,
 70248.42183165326,
 56232.36613535353,
 44787.969646068974,
 35414.45667382604,
 27734.93844217058,
 21459.991146692817,
 16361.78166711364,
 12255.907747313975,
 8988.05669927665,
 6425.099937157019,
 4449.211666053992,
 2955.2173286157686,
 1850.8807838167156,
 1059.8586750570469,
 522.062675333658,
 193.58597395295718,
 28.560901316673842,
 674.8523270245248,
 20.962099061816403,
 679.0008973468499,
 13.503888163444856,
 654.0338427839797,
 8.219587923188527,
 627.6153786151631,
 4.550136238522185,
 602.9978316772554,
 1.4827460188954333,
 580.9759584900675,
 -1.1319083900660942,
 561.3293908476656,
 -3.3944675734855707,
 543.7419968739906,
 -5.3062261368839145,
 527.7913078677632,
 -6.985673766611733,
 513.4385717039864,
 -8.368692979485914,
 500.276456158579,
 -9.55104995801355,
 488.1772387359387,
 -10.538638420724908,
 476.993711997865