In [1]:
import torch
from rlaopt.models import LinSys
from rlaopt.utils import LinOp, SymmetricLinOp
from rlaopt.preconditioners import NewtonConfig, NystromConfig, IdentityConfig
from rlaopt.solvers import PCGConfig, SAPConfig, SAPAccelConfig

In [2]:
torch.set_default_dtype(torch.float64)
torch.manual_seed(0)

<torch._C.Generator at 0x7ff8d83618d0>

In [3]:
device = torch.device("cuda:0")
reg = 0.0001

m = 35000
n=1000

singular_values = torch.arange(1, n + 1, device=device) ** -1.0
U = torch.randn(m, n, device=device)
U, _ = torch.linalg.qr(U)
V = torch.randn(n, n, device=device)
V, _ = torch.linalg.qr(V)
A = U @ torch.diag(singular_values) @ V.T
wStar = torch.randn(n, device=device)/(n**0.5)
b = A @ wStar + 0.1 * torch.randn(m, device=device) / (m ** 0.5)
Atr = A[0:30000,:]
Atst = A[30000:-1,:]
btr = b[0:30000]
btst = b[30000:-1]
#b = torch.randn(m, device=device)

In [4]:
linop = SymmetricLinOp(device=device, shape=(n, n), matvec=lambda x: A.T @ (A @ x))

def A_blk_oracle(blk):
    shape = (blk.shape[0], blk.shape[0])
    return SymmetricLinOp(device=device, shape=shape, matvec=lambda x: A[:, blk].T @ (A[:, blk] @ x))

def A_row_oracle(blk):
    shape = (blk.shape[0], A.shape[1])
    return LinOp(device=device, shape=shape, matvec=lambda x: A[:, blk].T @ (A @ x))

system = LinSys(linop, b=A.T @ b, reg=reg, A_blk_oracle=A_blk_oracle, A_row_oracle=A_row_oracle)

In [None]:
def callback_fn(w, linsys, Atst, btst):
    res = torch.linalg.norm(linsys.b-(linsys.A @ w + linsys.reg * w))
    ntst = btst.shape[0]
    test_loss = 1/ntst * torch.linalg.norm(Atst @ w - btst) ** 2.0
    return {"res": res.item(), "test_loss": test_loss.item()}

In [6]:
# Precond Configs
# nystrom_config = NystromConfig(rank=200, rho=reg)
nystrom_config = NystromConfig(rank=100, rho=reg)
newton_config = NewtonConfig(rho=reg)
identity_config = IdentityConfig()

# Solver Configs
# solver_config = PCGConfig(precond_config=nystrom_config, max_iters=500, atol=1e-6, rtol=1e-6, device=device)
blk_sz = 100
accel_config = SAPAccelConfig(mu=reg, nu=n/blk_sz)
# solver_config = SAPConfig(precond_config=nystrom_config, max_iters=1000, atol=1e-6, rtol=1e-6, device=device,
#                           blk_sz=blk_sz, accel=False, accel_config=accel_config)

In [7]:
for r in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:
    nystrom_config = NystromConfig(rank=r, rho=reg)
    solver_config = SAPConfig(precond_config=nystrom_config, max_iters=1000, atol=1e-6, rtol=1e-6, device=device,
                          blk_sz=blk_sz, accel=True, accel_config=accel_config)

    system.solve(solver_config=solver_config, w_init=torch.zeros(n, device=device), 
                    callback_fn=callback_fn, callback_kwargs={"Atst": Atst, "btst": btst}, log_in_wandb=True, wandb_init_kwargs={"project": "test_sap"})

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mpratikrathore8[0m ([33msketchy-opts[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


0,1
cum_time,▁▁▁▁▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇█████
iter_time,▁███████████████████████████████████████

0,1
cum_time,12.30738
iter_time,0.1208


0,1
cum_time,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
iter_time,▂█▅▁▂▃▃▄▂▃▄▂▅▃▄▄▃▄▄▁▄▂▅▄▄▂▂▂▆▁▅▂▁▂▆▃▃█▄▆

0,1
cum_time,10.73506
iter_time,0.13509


0,1
cum_time,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇██
iter_time,▁███████████████████████████████████████

0,1
cum_time,9.91412
iter_time,0.15553


0,1
cum_time,▁▁▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇███
iter_time,▁███████████████████████████████████████

0,1
cum_time,13.12031
iter_time,0.21175


0,1
cum_time,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
iter_time,▁███████████████████████████████████████

0,1
cum_time,16.43679
iter_time,0.23947


0,1
cum_time,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇█
iter_time,▄▄▃▃▃▄▄▃▄▃▃▃▃▃▄▆▄▄▃▄▄▃▃▄▃▄▂▃▄▇▃▃▁▂▃█▅▃▄▂

0,1
cum_time,19.34464
iter_time,0.24973


0,1
cum_time,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇█
iter_time,▁███████████████████████████████████████

0,1
cum_time,27.75351
iter_time,0.35348


0,1
cum_time,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇█
iter_time,▅▅▃▅▃▃▇▆▃▇▇▄▆█▇▄▇▅▃▁▅▄▇▄▅▅▄▅▅▅▇▇▄▆▄▃▂▆▃▆

0,1
cum_time,35.00491
iter_time,0.3629


0,1
cum_time,▁▁▁▂▂▂▂▂▂▂▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇██
iter_time,▁███████████████████████████████████████

0,1
cum_time,37.72069
iter_time,0.37717


0,1
cum_time,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇███
iter_time,█▆▆▁▃▃▅▇▄▃▇▄▃▂▂▁▆▃▇▂▂▂▁▁▅▅▅▅▁▄▄▄▃▇▄▅▇▅▁▃

0,1
cum_time,54.24105
iter_time,0.54469
