In [None]:
import torch
from lib.optimizer.pcg import conjugate_gradient

# Build a random dense N x N test system A x = b on the CPU.
N = 400
_A = torch.rand((N, N))
# NOTE(review): _A.T @ _A is symmetric positive semidefinite, but torch.sqrt is
# applied element-wise; that keeps A symmetric with non-negative entries and
# does not guarantee that A itself is positive semidefinite.
A = torch.sqrt(_A.T @ _A)  # symmetric (element-wise sqrt of a Gram matrix)
b = torch.rand((N))
# Disabled comparison of the CG residual against the direct-solve residual:
# x_pcg  = conjugate_gradient(A, b, verbose=True, max_iter=100)
# x_optim= torch.linalg.solve(A, b)
# pcg_norm = ((A @ x_pcg) - b).norm()
# optim_norm  = ((A @ x_optim) - b).norm()
# print(pcg_norm, optim_norm)

In [None]:
import timeit
import torch
import torchvision.models as models
from torch.profiler import profile, record_function, ProfilerActivity

t = 100
A = A.to("cuda")
b = b.to("cuda")


def foo():
    """Profiled workload: run CG (at most 5 iterations) on the global A and b."""
    conjugate_gradient(A, b, verbose=True, max_iter=5)
    # torch.linalg.solve(A, b)


# Record CPU and CUDA activity while timeit invokes foo() t times.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=False
) as prof:
    # foo()
    timeit.timeit(foo, number=t) / t  # mean seconds per call; the value is discarded

# Print the 10 entries with the largest total CUDA time.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

In [None]:
# from lib.optimizer.pcg import conjugate_gradient
import torch
import timeit


def conjugate_gradient(
    A: torch.Tensor,  # dim (N,N)
    b: torch.Tensor,  # dim (N)
    x0: torch.Tensor | None = None,  # dim (N)
    max_iter: int = 20,
    verbose: bool = False,
    tol: float = 1e-08,
):
    k = 0
    converged = False

    xk = torch.zeros_like(b) if x0 is None else x0  # (N)
    rk = b - A @ xk  # column vector (N)
    pk = rk

    if torch.norm(rk) < tol:
        converged = True

    while k < max_iter and not converged:

        # compute step size
        ak = (rk[None] @ rk) / (pk[None] @ A @ pk)
        # update unknowns
        xk_1 = xk + ak * pk
        # compute residuals
        rk_1 = rk - ak * A @ pk
        # compute new pk
        bk = (rk_1[None] @ rk_1) / (rk[None] @ rk)
        pk_1 = rk_1 + bk * pk
        # update the next stateprint
        xk = xk_1
        pk = pk_1
        rk = rk_1

        k += 1
        if torch.norm(rk) < tol:
            converged = True

    return xk


# Benchmark setup: t timed repetitions on an N x N system stored on the GPU.
t = 1000
N = 400
_A = torch.rand((N, N))
# NOTE(review): torch.sqrt is element-wise, so A stays symmetric but is not
# guaranteed to inherit the positive semidefiniteness of _A.T @ _A.
A = torch.sqrt(_A.T @ _A).to("cuda")  # symmetric (element-wise sqrt of a Gram matrix)
b = torch.rand((N)).to("cuda")


def foo():
    """Timed workload: direct solve of A x = b using the global A and b."""
    # conjugate_gradient(A, b, max_iter=5)
    torch.linalg.solve(A, b)


print((timeit.timeit(foo, number=t) / t) * 1000)

In [None]:
# Factorize A once (LU with pivoting), then solve A X = b from the factors.
LU, pivots = torch.linalg.lu_factor(A)
# `B` was never defined in this notebook; the right-hand side is the vector `b`.
# lu_solve expects a matrix right-hand side of shape (..., N, K), so b is
# lifted to a single column and squeezed back to shape (N) afterwards.
X = torch.linalg.lu_solve(LU, pivots, b.unsqueeze(-1)).squeeze(-1)

In [None]:
pivots.shape

In [None]:
import torch

# Larger GPU test system (N = 1000) used by the autograd cells below.
t = 1000
N = 1000
_A = torch.rand((N, N))
# NOTE(review): torch.sqrt is element-wise, so A stays symmetric but is not
# guaranteed to inherit the positive semidefiniteness of _A.T @ _A.
A = torch.sqrt(_A.T @ _A).to("cuda")  # symmetric (element-wise sqrt of a Gram matrix)
b = torch.rand((N)).to("cuda")

In [None]:
# Differentiate through 5 CG iterations: work on leaf copies of A and b so
# that gradients land on _A / _b while A and b stay untouched.
_A = A.clone()
_b = b.clone()
_A.requires_grad = True
_b.requires_grad = True
x = conjugate_gradient(_A, _b, max_iter=5)
# Reference solution from the direct solver (computed on the originals, no grad).
x_gt = torch.linalg.solve(A, b)
# Loss is the 2-norm distance between the CG iterate and the reference.
(x - x_gt).norm().backward()
_A.grad

In [None]:
from lib.optimizer.pcg import ConjugateGradient

# Same experiment as with conjugate_gradient, but through ConjugateGradient.apply.
# NOTE(review): `.apply` suggests a torch.autograd.Function with a custom
# backward — confirm in lib.optimizer.pcg.
_A = A.clone()
_b = b.clone()
_A.requires_grad = True
_b.requires_grad = True
x = ConjugateGradient.apply(_A, _b)
# Reference solution from the direct solver (computed on the originals, no grad).
x_gt = torch.linalg.solve(A, b)
(x - x_gt).norm().backward()
_A.grad

In [None]:
torch.linalg.cond(A)

In [None]:
import torch

# Load the stored linear systems 0000000.pt .. 0000008.pt and print the norm of
# each direct solution.
# NOTE(review): machine-specific absolute path; torch.load deserializes the
# file, so only point this at trusted files.
for i in range(9):
    out = torch.load(
        f"/home/borth/GuidedResearch/logs/2024-07-14_20-16-20_pcg_sampling/linsys/000000{i}.pt"
    )
    A = out["A"]
    b = out["b"]
    x = out["x"]  # read but not used in this loop

    print(torch.linalg.solve(A, b).norm())

In [None]:
from lib.optimizer.pcg import JaccobiConditionNet

j = JaccobiConditionNet()
M = j(A)
torch.linalg.cond(M @ A)

In [None]:
from tqdm import tqdm
from lib.optimizer.pcg import PCGSolver
from lib.optimizer.pcg import conjugate_gradient
from torch.optim import Adam, SGD
import torch

# define the matrixes


def generate_data(N, x_eps=1e-02, a_eps=1e-01):
    """Create a random linear system ``A x = b`` with a known solution.

    ``A`` is the identity plus the Gram matrix of a random perturbation scaled
    by ``a_eps``; the ground-truth solution is a random vector scaled by
    ``x_eps`` and ``b`` is derived from it.

    Returns:
        ``(A, b, x_gt)`` where ``A`` is (N, N) and ``b`` / ``x_gt`` are (N).
        ``A`` and ``b`` are leaf tensors with ``requires_grad`` enabled.
    """
    perturbation = a_eps * torch.rand((N, N))
    system = torch.eye(N) + perturbation.T @ perturbation
    solution = x_eps * torch.rand(N)
    rhs = system @ solution
    # Gradient tracking is switched on only after b has been computed, so both
    # tensors remain leaves.
    system.requires_grad_(True)
    rhs.requires_grad_(True)
    return system, rhs, solution


def eval_loss(x_pcg, x_gt, verbose=True):
    """Return the 2-norm distance between ``x_pcg`` and ``x_gt``.

    When ``verbose`` is true the scalar loss value is also printed.
    """
    distance = torch.norm(x_pcg - x_gt)
    if not verbose:
        return distance
    print("Loss:", distance.item())
    return distance


# Hyper-parameters of the preconditioner training experiment.
torch.manual_seed(42)
N = 6
# A, b, x_gt = generate_data(N)
max_steps = 1000
lr = 1e-03
max_iter = 1
verbose = True
tol = 1e-08

# Stored system and its direct solution.
# NOTE(review): the training loop below draws a fresh system with
# generate_data(N) on every step, so these three values are overwritten before
# they are used. The path is machine-specific.
out = torch.load(
    "/home/borth/GuidedResearch/logs/2024-07-04_12-17-20_pcg_sampling/linsys/0000000.pt"
)
A = out["A"]
b = out["b"]
# A.requires_grad = True
# b.requires_grad = True
x_gt = torch.linalg.solve(A, b)

pcg = PCGSolver(
    dim=N,
    max_iter=max_iter,
    verbose=verbose,
    tol=tol,
    mode="diagonal_offset",
    gradients="backprop",
)
# optimizer = SGD(pcg.parameters(), lr=lr, momentum=0.90)
optimizer = Adam(pcg.parameters(), lr=lr)

init_loss = None
for step in (pbar := tqdm(range(max_steps), total=max_steps)):
    pbar.set_description(f"{step}/{max_steps}")

    optimizer.zero_grad()

    # Fresh random system per step; the loss is the distance between the
    # solver output and the known solution.
    A, b, x_gt = generate_data(N)
    M = pcg.condition_net(A)
    # Condition numbers with (C_m) and without (C_a) the learned preconditioner.
    C_m = torch.linalg.cond(M @ A).item()
    C_a = torch.linalg.cond(A).item()
    x = pcg(A, b)
    loss = eval_loss(x, x_gt, verbose=False)
    loss.backward()
    optimizer.step()

    if init_loss is None:
        init_loss = loss.item()

    pbar.set_postfix(
        {"init_loss": init_loss, "loss": loss.item(), "C_m": C_m, "C_a": C_a}
    )
    # No manual pbar.update(1): iterating over the tqdm object already advances
    # the bar once per step, so an extra update counted every step twice.

In [None]:
import torch
from lib.data.datamodule import PCGDataModule
from lib.data.dataset import PCGDataset

path = "/home/borth/GuidedResearch/data/linsys_pose"
dataset = PCGDataset(data_dir=path)
out = torch.load("/home/borth/GuidedResearch/data/linsys_pose/0000000.pt")
A = out["A"]
b = out["b"]
x_gt = torch.linalg.solve(A, b)
x_gt.norm(), x_gt, x_gt[:3].norm(), x_gt[3:6].norm(), x_gt[:3].norm() / x_gt[3:6].norm()

In [None]:
A.norm(), A.inverse().norm(), torch.linalg.cond(A) / A.norm()

In [None]:
x_gt = out["x_gt"]
eps = torch.rand_like(x_gt)
eps /= torch.linalg.vector_norm(eps)
eps *= 1e-02
l1_solution = torch.abs(eps).mean()
residual = torch.linalg.vector_norm(A @ (x_gt + eps) - b)
# residual = torch.linalg.vector_norm(A @ x_gt - b)
l1_solution, residual

In [None]:
from lib.optimizer.pcg import (
    JaccobiConditionNet,
    PCGSolver,
    IdentityConditionNet,
    ConditionNet,
    DenseConditionNet,
)
from functools import partial


class JaccobiConditionNet1(ConditionNet):
    """Jacobi (diagonal) preconditioner built from the inverse diagonal of A."""

    name: str = "JaccobiConditionNet"

    def forward(self, A: torch.Tensor):
        """Return ``diag(1 / diag(A))`` over the last two dimensions of ``A``."""
        # Alternative that was tried: softmax-normalised inverse diagonal.
        # diagonals = torch.nn.functional.softmax(1 / diagonals, dim=-1)
        inverse_diagonal = 1 / A.diagonal(dim1=-2, dim2=-1)
        return torch.diag_embed(inverse_diagonal)


# Build a solver whose preconditioner network is DenseConditionNet(unknowns=6);
# partial defers construction of the network to PCGSolver.
condition_net = partial(DenseConditionNet, unknowns=6)
solver = PCGSolver(check_convergence=True, condition_net=condition_net)
# The solver call is unpacked as (solution, info dict); info["k"] is displayed
# next to the solution — presumably the iteration count, TODO confirm.
x, info = solver(A, b)
x, info["k"]

In [None]:
A.inverse().diag() / A.inverse().diag().sum(), (1 / A.diag()) / (1 / A.diag()).sum()

In [None]:
from lib.optimizer.pcg import (
    JaccobiConditionNet,
    PCGSolver,
    IdentityConditionNet,
    ConditionNet,
    DenseConditionNet,
)

cond = DenseConditionNet(unknowns=6)
cond(A)

In [None]:
# def foo0():
#     batched_conjugate_gradient(batch["A"], batch["b"], max_iter=max_iter)


# def foo1():
#     for i in range(batch_size):
#         conjugate_gradient(batch["A"][i], batch["b"][i], max_iter=max_iter)


# print(timeit.timeit(foo0, number=n) / n)
# print(timeit.timeit(foo1, number=n) / n)

In [None]:
import logging

logging.basicConfig(level=logging.INFO)

from lib.optimizer.pcg import JaccobiConditionNet, preconditioned_conjugate_gradient
import timeit
import torch
from torch.utils.data import DataLoader
from lib.data.dataset import PCGDataset
from lib.optimizer.pcg import PCGSolver

# Compare PCG convergence on one test system with three preconditioners:
# Jacobi, the trained condition net from a checkpoint, and none.
# NOTE(review): checkpoint and data paths are machine-specific.
ckpt_path = (
    "/home/borth/GuidedResearch/logs/2024-07-11_12-23-43_pcg/checkpoints/last.ckpt"
)
solver = PCGSolver.load_from_checkpoint(ckpt_path).to("cpu")

batch_size = 2
path = "/home/borth/GuidedResearch/data/linsys_pose"
dataset = PCGDataset(data_dir=path, split="test")
loader = DataLoader(dataset, batch_size=batch_size)
batch = next(iter(loader))

jaccobi = JaccobiConditionNet()

# Only the first system of the batch is used below.
A = batch["A"][0]
b = batch["b"][0]

# 1) Jacobi preconditioner. M_A and M_b are computed but not used later in
#    this cell.
M = jaccobi(A)
M_A = torch.matmul(M, A)  # (B, N, N) or (N, N)
M_b = torch.matmul(M, b.unsqueeze(-1)).squeeze(-1)  # (B, N) or (N)
x, info_jaccobi = preconditioned_conjugate_gradient(
    A=A,
    b=b,
    M=M,
    max_iter=20,
    rel_tol=1e-08,
    verbose=True,
    check_convergence=False,
)

# 2) Preconditioner predicted by the checkpointed condition net (no gradients).
with torch.no_grad():
    M = solver.condition_net(A)
M_A = torch.matmul(M, A)  # (B, N, N) or (N, N)
M_b = torch.matmul(M, b.unsqueeze(-1)).squeeze(-1)  # (B, N) or (N)
x, info_pcg = preconditioned_conjugate_gradient(
    A=A,
    b=b,
    M=M,
    max_iter=20,
    rel_tol=1e-08,
    verbose=True,
    check_convergence=False,
)

# 3) No preconditioner (M omitted).
x, info_wo = preconditioned_conjugate_gradient(
    A=A,
    b=b,
    max_iter=20,
    rel_tol=1e-08,
    verbose=True,
    check_convergence=False,
)

In [None]:
from lightning import Trainer

trainer = Trainer()
solver.max_iter = 20
out = trainer.predict(solver, loader)

In [None]:
from collections import defaultdict

# Collect the per-batch prediction dicts returned by trainer.predict into one
# list per key, then stack each list along a new trailing dimension.
# NOTE: the loop rebinds the notebook-level name `batch` to the last
# prediction dict.
stats = defaultdict(list)
for batch in out:
    for key, value in batch.items():
        stats[key].append(value)
for key, value in stats.items():
    stats[key] = torch.stack(value, dim=-1)
cond = stats["cond"].mean()
# Average the relative residual norms over everything except the first
# dimension (assumed to index the solver iteration — TODO confirm).
iters = stats["relres_norms"].size(0)
relres_norms = stats["relres_norms"].view(iters, -1).mean(dim=-1)

In [None]:
relres_norms

In [None]:
import matplotlib.pyplot as plt

# Extract the relative residual norms
# (the PCG list was previously bound to `relres` while every later line read
# `relres_pcg`, which raised a NameError)
relres_pcg = [v.detach() for v in info_pcg["relres_norms"]]
relres_jaccobi = info_jaccobi["relres_norms"]
relres_wo = info_wo["relres_norms"]

# Create a range for the x-axis based on the length of the data
iterations_pcg = range(len(relres_pcg))
iterations_jaccobi = range(len(relres_jaccobi))
iterations_wo = range(len(relres_wo))

# Plotting
plt.figure(figsize=(10, 6))
plt.plot(iterations_pcg, relres_pcg, label="PCG", marker="o")
plt.plot(iterations_jaccobi, relres_jaccobi, label="Jaccobi", marker="o")
plt.plot(iterations_wo, relres_wo, label="Without Optimization", marker="s")

# Adding titles and labels
plt.title("Relative Residual Norms Comparison")
plt.xlabel("Iterations")
plt.ylabel("Relative Residual Norms")

# Set y-axis to log scale
plt.yscale("log")

# Set y-axis limit to start from 10^-7
plt.ylim(bottom=1e-7)

# Add a horizontal red line at 10^-6
plt.axhline(y=1e-6, color="red", linestyle="--", label="Convergence at $10^{-6}$")

# Show legend
plt.legend()

# Show grid
plt.grid(True)

# Show plot
plt.show()

In [None]:
torch.linalg.solve(A, b)

In [None]:
torch.matmul(batch["A"][0], batch["b"][0])

In [None]:
from lib.optimizer.pcg import conjugate_gradient, log
import logging

log.setLevel(logging.INFO)
conjugate_gradient(batch["A"][0], batch["b"][0], verbose=True)

In [None]:
torch.diag_embed(1 / batch["A"][0].diagonal(dim1=-2, dim2=-1))

In [None]:
A = batch["A"]
ones = torch.ones((A.shape[0], A.shape[1]), device=A.device)
torch.diag_embed(ones).shape

In [None]:
torch.linalg.solve(batch["A"], batch["b"])

In [None]:
x = torch.ones_like(A)[0]
x.expand(A[0].shape).shape

In [None]:
torch.linalg.solve(A, b)

In [None]:
torch.bmm(A[None], b[None, ..., None]).shape

In [None]:
import timeit


def foo():
    """Micro-benchmark body: dot product of the global vector b with itself."""
    b[None] @ b


n = 100000
timeit.timeit(foo, number=n) / n

In [None]:
def foo():
    """Micro-benchmark body: cost of branching on a tensor norm comparison."""
    # The `if` converts the comparison result to a Python bool.
    if torch.norm(b) > 0.1:
        pass


n = 100000
timeit.timeit(foo, number=n) / n

In [None]:
pcg = PCGSolver(dim=N, max_iter=max_iter, verbose=verbose, tol=tol, mode="dense")
A, b, x_gt = generate_data(N)
M = pcg.condition_net(A)
M

In [None]:
torch.linalg.cond(pcg.condition_net(A) @ A)

In [None]:
from lib.optimizer.pcg import preconditioned_conjugate_gradient

# Compare two iterations of plain CG with two iterations of the learned PCG
# solver on a fresh random system.
A, b, x_gt = generate_data(N)
# NOTE(review): another cell in this notebook unpacks
# preconditioned_conjugate_gradient(...) as `x, info`; confirm that x_cg is a
# tensor here, since eval_loss subtracts x_gt from it.
x_cg = preconditioned_conjugate_gradient(A, b, max_iter=2)
pcg.max_iter = 2
x = pcg(A, b)
# First printed loss: learned PCG; second: plain CG. `loss` keeps the second.
loss = eval_loss(x, x_gt, verbose=True)
loss = eval_loss(x_cg, x_gt, verbose=True)

In [None]:
len(list(p.size() for p in pcg.condition_net.parameters()))

In [None]:
from lib.optimizer.pcg import conjugate_gradient

# Well-conditioned test system: identity plus a small Gram-matrix perturbation,
# with b built from a known solution x.
torch.manual_seed(42)
N = 100
E = torch.rand((N, N)) * 1e-02
A = torch.eye(N) + (E.T @ E)
x = torch.rand((N)) * 1e-02
b = A @ x

# Residual after a single CG iteration.
x_pcg = conjugate_gradient(A, b, max_iter=1)
# NOTE(review): this is the mean of the signed residual entries, so positive
# and negative entries can cancel; it is not a residual norm.
r_pcg = (A @ x_pcg - b).mean()
r_pcg

In [None]:
# Fill the lower triangle (diagonal included) of an N x N matrix with random
# values; tri_N is the number of entries in that triangle, N * (N + 1) / 2.
N = 5
tri_N = ((N * N - N) // 2) + N
L = torch.zeros((N, N))
tril_indices = torch.tril_indices(row=N, col=N, offset=0)
L[tril_indices[0], tril_indices[1]] = torch.rand(tri_N)
# Zero out the bottom-left corner entry.
L[4, 0] = 0

In [None]:
L.inverse()

In [None]:
import torch
import time


def _replace_zeros(values, eps=1e-8):
    """Return ``values`` with exact zeros replaced by ``eps`` (out of place)."""
    return torch.where(values == 0, torch.full_like(values, eps), values)


def cg_batch(
    A_bmm, B, M_bmm=None, X0=None, rtol=1e-3, atol=0.0, maxiter=None, verbose=False
):
    """Solves a batch of PD matrix linear systems using the preconditioned CG algorithm.

    This function solves a batch of matrix linear systems of the form

        A_i X_i = B_i,  i=1,...,K,

    where A_i is a n x n positive definite matrix and B_i is a n x m matrix,
    and X_i is the n x m matrix representing the solution for the ith system.

    Args:
        A_bmm: A callable that performs a batch matrix multiply of A and a K x n x m matrix.
        B: A K x n x m matrix representing the right hand sides.
        M_bmm: (optional) A callable that performs a batch matrix multiply of the preconditioning
            matrices M and a K x n x m matrix. (default=identity matrix)
        X0: (optional) Initial guess for X, defaults to M_bmm(B). (default=None)
        rtol: (optional) Relative tolerance for norm of residual. (default=1e-3)
        atol: (optional) Absolute tolerance for norm of residual. (default=0)
        maxiter: (optional) Maximum number of iterations to perform. (default=5*n)
        verbose: (optional) Whether or not to print status messages. (default=False)

    Returns:
        A tuple ``(X, info)``: ``X`` is the K x n x m iterate reached when the
        loop stopped and ``info`` is a dict with ``"niter"`` (iterations
        performed) and ``"optimal"`` (True when every residual norm met
        ``max(rtol * ||B||, atol)`` before ``maxiter`` was exhausted).
    """
    K, n, m = B.shape

    if M_bmm is None:
        M_bmm = lambda x: x
    if X0 is None:
        X0 = M_bmm(B)
    if maxiter is None:
        maxiter = 5 * n

    assert B.shape == (K, n, m)
    assert X0.shape == (K, n, m)
    assert rtol > 0 or atol > 0
    assert isinstance(maxiter, int)

    X_k = X0
    R_k = B - A_bmm(X_k)

    # Per-system, per-column stopping threshold on the residual norm.
    B_norm = torch.norm(B, dim=1)
    stopping_matrix = torch.max(rtol * B_norm, atol * torch.ones_like(B_norm))

    if verbose:
        print("%03s | %010s %06s" % ("it", "dist", "it/s"))

    optimal = False
    k = 0  # stays 0 when maxiter == 0, so info["niter"] is always defined
    P_k = None
    rz_prev = None
    start = time.perf_counter()
    for k in range(1, maxiter + 1):
        start_iter = time.perf_counter()
        Z_k = M_bmm(R_k)
        rz = (R_k * Z_k).sum(1)

        if k == 1:
            # The first search direction is the preconditioned residual.
            P_k = Z_k
        else:
            beta = rz / _replace_zeros(rz_prev)
            P_k = Z_k + beta.unsqueeze(1) * P_k

        # A_bmm(P_k) feeds both the step size and the residual update, so it is
        # evaluated once per iteration.
        AP_k = A_bmm(P_k)
        alpha = rz / _replace_zeros((P_k * AP_k).sum(1))
        X_k = X_k + alpha.unsqueeze(1) * P_k
        R_k = R_k - alpha.unsqueeze(1) * AP_k
        rz_prev = rz
        end_iter = time.perf_counter()

        # Convergence is judged on the true residual, not the recurrence.
        residual_norm = torch.norm(A_bmm(X_k) - B, dim=1)

        if verbose:
            print(
                "%03d | %8.4e %4.2f"
                % (
                    k,
                    torch.max(residual_norm - stopping_matrix).item(),
                    # guard against a zero-length interval on coarse timers
                    1.0 / max(end_iter - start_iter, 1e-12),
                )
            )

        if (residual_norm <= stopping_matrix).all():
            optimal = True
            break

    end = time.perf_counter()

    if verbose:
        # The two messages used to be swapped relative to `optimal`.
        if optimal:
            print(
                "Terminated in %d steps (optimal). Took %.3f ms."
                % (k, (end - start) * 1000)
            )
        else:
            print(
                "Terminated in %d steps (reached maxiter). Took %.3f ms."
                % (k, (end - start) * 1000)
            )

    info = {"niter": k, "optimal": optimal}

    return X_k, info


A = torch.tensor([[4.0, 1], [1, 3]])[None, ...]


def A_bmm(X):
    """Batched product of the global matrix stack ``A`` with ``X`` (K x n x m).

    ``A`` has a leading batch dimension of size 1 and is broadcast over the
    batch dimension of ``X``, so every system in the batch is multiplied. The
    previous hard-coded ``range(1)`` returned only the first one.
    """
    return torch.matmul(A, X)


b = torch.tensor([1.0, 2])[None, ...][..., None]
cg_batch(A_bmm=A_bmm, B=b)

In [None]:
# Build a block-diagonal test system: z masks _A down to 10 x 10 blocks on the
# diagonal.
N = 100
_A = torch.rand((N, N))
z = torch.zeros_like(_A)
# NOTE: only idx 0..9 fall inside the 100 x 100 matrix; for idx >= 10 the
# slices are empty and the assignment is a no-op.
for idx in range(70):
    i = idx * 10
    j = (idx + 1) * 10
    z[i:j, i:j] = 1.0
_A *= z
# NOTE(review): torch.sqrt is element-wise; A is symmetric with non-negative
# entries, but positive (semi)definiteness is not guaranteed by this step.
A = torch.sqrt(_A.T @ _A)  # symmetric, non-negative entries
# A = torch.diag(torch.diag(A))
b = torch.rand((N)) + 2


def A_bmm(X):
    """Multiply the global (N, N) matrix ``A`` with every system in ``X`` (K x N x m).

    ``A`` is broadcast over the batch dimension of ``X``. The previous
    hard-coded ``range(1)`` returned only the product for the first batch
    element.
    """
    return torch.matmul(A, X)

In [None]:
import time

# Wall-clock time of 10 CG iterations (b reshaped to 1 x N x 1 for cg_batch) ...
s = time.time()
out = cg_batch(A_bmm=A_bmm, B=b[None, ...][..., None], maxiter=10)
print(time.time() - s)
x_pcg = out[0].squeeze()  # out is (X, info); drop the batch and column dims
# ... versus the direct solver.
s = time.time()
x_optim = torch.linalg.solve(A, b)
print(time.time() - s)
# Residual norms of both solutions, shown next to ||b|| for scale.
pcg_norm = ((A @ x_pcg) - b).norm()
optim_norm = ((A @ x_optim) - b).norm()
pcg_norm, optim_norm, b.norm()

In [None]:
class SequentialTrainer:
    """Generates sliding windows of frame indices over a frame range."""

    def __init__(
        self,
        kernel_size: int = 1,
        stride: int = 1,
        dilation: int = 1,
        start_frame: int = 0,
        end_frame: int = 126,
        **kwargs,
    ):
        """Store the window configuration.

        Args:
            kernel_size: Number of frame indices per window.
            stride: Step between the start positions of consecutive windows.
            dilation: Step between consecutive frame indices.
            start_frame: First frame index.
            end_frame: Last frame index; inclusive in ``frame_idxs_iter`` but
                exclusive in ``self.frames``.
            **kwargs: Forwarded to the parent initializer.
        """
        super().__init__(**kwargs)
        self.start_frame, self.end_frame = start_frame, end_frame
        self.frames = [*range(start_frame, end_frame)]
        self.mode = "sequential"
        self.prev_last_frame_idx = start_frame
        # Attribute keeps its original spelling ("kernal").
        self.kernal_size, self.stride, self.dilation = kernel_size, stride, dilation

    def frame_idxs_iter(self):
        """Yield each complete window of ``kernal_size`` dilated frame indices."""
        candidates = list(range(self.start_frame, self.end_frame + 1, self.dilation))
        starts = range(0, len(candidates), self.stride)
        windows = (candidates[s : s + self.kernal_size] for s in starts)
        # Trailing windows that would be shorter than kernal_size are dropped.
        yield from (w for w in windows if len(w) == self.kernal_size)


# Demo: windows of 3 frame indices, every second frame of 0..10, with
# non-overlapping windows (stride equals kernel_size).
trainer = SequentialTrainer(
    kernel_size=3,
    stride=3,
    dilation=2,
    start_frame=0,
    end_frame=10,
)

for frame_idxs in trainer.frame_idxs_iter():
    print(frame_idxs)

In [None]:
import torch

# Tiny 2 x 2 system for inspecting gradients by hand.
N = 2
_A = torch.rand((N, N))
A = _A.T @ _A  # positive semidefinite and symmetric
b = torch.rand((N))
x = torch.linalg.solve(A, b)  # reference solution used by the following cells

In [None]:
# Gradient of the squared residual ||A x0 - b||^2 with respect to x0, at x0 = 0.
x0 = torch.zeros(N, requires_grad=True)
residual = ((A @ x0 - b) ** 2).sum()
residual.backward()
x0.grad

In [None]:
# Gradient of the squared error ||x0 - x||^2 with respect to x0, at x0 = 0.
# A fresh leaf tensor is created so the gradient does not accumulate with any
# earlier backward pass.
x0 = torch.zeros(N, requires_grad=True)
error = ((x0 - x) ** 2).sum()
error.backward()
x0.grad

In [None]:
A = torch.tensor([[1.0, -1], [-1, 1]])
A @ x.unsqueeze(-1)

In [None]:
x.unsqueeze(-1)

In [None]:
# Gradient of the signed error sum(x0 - x) with respect to x0.
# Start from a fresh leaf tensor: backward() accumulates into .grad, so reusing
# an x0 that an earlier cell already back-propagated through would display the
# sum of both gradients.
x0 = torch.zeros(N, requires_grad=True)
error = (x0 - x).sum()
error.backward()
x0.grad