In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Emit high-DPI ("retina") PNGs for inline figures.
%config InlineBackend.figure_format = "retina"

In [None]:
# Change into the project checkout, presumably so `ip_is_all_you_need` and any
# relative paths resolve. NOTE(review): hard-coded, user-specific path --
# consider an environment variable or a configurable constant instead.
%cd ~/projects/ip-is-all-you-need

In [None]:
import logging

import numpy as np
from jax import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import vmap, pmap, jit
import matplotlib.pyplot as plt
import torch
torch.set_default_dtype(torch.float64)

from ip_is_all_you_need.simulations import gen_dictionary, generate_measurements_and_coeffs

In [None]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
Phi = gen_dictionary(1, 500, 1000)
y, x = generate_measurements_and_coeffs(Phi, p=0.1)

In [None]:
Phi_tmp = gen_dictionary(10, 500, 1000)
y_tmp, x_tmp = generate_measurements_and_coeffs(Phi_tmp, p=0.1)
Phi_batch = jnp.array(Phi_tmp.numpy())
y_batch, x_batch = jnp.array(y_tmp.numpy()), jnp.array(x_tmp.numpy())

In [None]:
Phi = jnp.array(Phi[0].numpy())
y = jnp.array(y[0].numpy())
x = jnp.array(x[0].numpy())

In [None]:
from collections import defaultdict


def ip(Phi, y, tol, normalize=True, max_iters=101):
    """Greedily select columns of ``Phi`` for a fixed number of steps.

    At each step every not-yet-selected column is projected onto the
    orthogonal complement of span(selected columns) (``a`` below), scored,
    and the highest-scoring column is appended to the selection.

    Scores (``r = P_perp @ y``), both evaluated in log space then exponentiated:
      * ``normalize=True``:  ||a||^2 / (||a||^2 - (a^T y)^2 / ||r||^2)
      * ``normalize=False``: (a^T y)^2 / ||a||^2

    Args:
        Phi: (m, n) dictionary matrix.
        y: a single measurement vector, shape (m,) or (m, 1); it is reshaped
            to a column so the per-column scores line up with the
            ``[:, None]`` column norms.
        tol: accepted for signature compatibility but currently unused -- the
            loop always runs ``min(max_iters, n)`` steps.
        normalize: pick between the two scores above.
        max_iters: number of columns to select (defaults to 101, the value
            previously hard-coded). Capped at ``n`` so ``argmax`` never runs
            on an empty candidate set.

    Returns:
        ``defaultdict(list)`` with ``"col"`` (selected column indices, in
        selection order) and ``"objective"`` (the winning score per step).
    """
    m, n = Phi.shape
    # A 1-D y would broadcast (k, 1) against (k,) into a (k, k) objective.
    y = jnp.reshape(y, (m, -1))
    cols = []
    rem_cols = list(range(n))
    history = defaultdict(list)
    # One column is consumed per step; more than n steps would leave nothing to score.
    for _ in range(min(max_iters, n)):
        Phi_t = Phi[:, jnp.array(cols, dtype=jnp.int32)]
        Phi_rem = Phi[:, jnp.array(rem_cols, dtype=jnp.int32)]
        # Projector onto the orthogonal complement of span(selected columns);
        # with no columns selected the subtracted term is a zero matrix.
        P_perp = jnp.eye(m) - Phi_t @ jnp.linalg.pinv(Phi_t)
        Phi_rem_projected = P_perp @ Phi_rem
        numerator = jnp.linalg.norm(Phi_rem_projected, axis=0)[:, None] ** 2
        if normalize:
            denominator_quotient = jnp.exp(
                2
                * (
                    jnp.log(jnp.abs(Phi_rem_projected.T @ y))
                    - jnp.log(jnp.linalg.norm(P_perp @ y))
                )
            )
            # Rounding can push the difference slightly below zero; clip so
            # the log below yields -inf (objective = inf) rather than NaN.
            denominator = (numerator - denominator_quotient).clip(0)
            objective = jnp.exp(jnp.log(numerator) - jnp.log(denominator))
        else:
            objective = jnp.exp(
                2
                * (
                    jnp.log(jnp.abs(Phi_rem_projected.T @ y))
                    - jnp.log(jnp.linalg.norm(Phi_rem_projected, axis=0)[:, None])
                )
            )

        max_col = rem_cols[jnp.argmax(objective).item()]
        cols.append(max_col)
        rem_cols.remove(max_col)
        history["col"].append(max_col)
        history["objective"].append(objective.max().item())

    return history


def ip2(Phi, y, max_iters, tol=1e-12):
    """Greedily select columns of ``Phi`` until ``y`` no longer helps explain them.

    Let ``Pi_t`` project onto span(selected columns) and ``P_t`` onto
    span([y, selected columns]). Each remaining column ``phi`` is scored by

        ||(I - Pi_t) phi||^2 / ||(I - P_t) phi||^2

    (evaluated in log space). In exact arithmetic this is >= 1 because
    span(Phi_t) is contained in span([y, Phi_t]). The best column is added
    each step.

    Stopping: before selecting, stop if every score is within ``tol`` of 1;
    after recording a selection, stop if its score was infinite. At most
    ``min(max_iters, n)`` steps run, since each step consumes one column.

    Args:
        Phi: (m, n) dictionary matrix.
        y: a single measurement vector, shape (m,) or (m, 1); reshaped to a column.
        max_iters: upper bound on the number of selected columns.
        tol: early-stopping threshold on ``max(score) - 1``.

    Returns:
        ``defaultdict(list)`` with one entry per selected column:
          ``"col"``: column index; ``"objective"``: winning score;
          ``"x_hat"``: length-n estimate (least squares on the selected
          columns, zeros elsewhere); ``"y_hat"``: ``Phi @ x_hat``.
    """
    m, n = Phi.shape
    # hstack below needs y as a column.
    y = jnp.reshape(y, (m, -1))
    cols = []
    rem_cols = list(range(n))
    history = defaultdict(list)
    eye_m = jnp.eye(m)
    Phi_t = Phi[:, jnp.array(cols, dtype=jnp.int32)]
    Phi_rem = Phi[:, jnp.array(rem_cols, dtype=jnp.int32)]
    # Capping at n keeps Phi_rem non-empty, so .max()/argmax never see an empty array.
    for _ in range(min(max_iters, n)):
        Psi_t = jnp.hstack((y, Phi_t))
        Pi_t_perp = eye_m - Phi_t @ jnp.linalg.pinv(Phi_t)
        P_t_perp = eye_m - Psi_t @ jnp.linalg.pinv(Psi_t)
        # diag(A.T @ M @ A) computed column-wise, without materializing the
        # full (k, k) product for k remaining columns.
        numerator = jnp.sum(Phi_rem * (Pi_t_perp @ Phi_rem), axis=0)
        # Clip rounding-induced negatives so log() gives -inf, not NaN.
        denominator = jnp.sum(Phi_rem * (P_t_perp @ Phi_rem), axis=0).clip(0)
        objective = jnp.exp(jnp.log(numerator) - jnp.log(denominator))

        if (objective - 1).max() < tol:
            break

        max_col = rem_cols[jnp.argmax(objective).item()]
        cols.append(max_col)
        rem_cols.remove(max_col)

        Phi_t = Phi[:, jnp.array(cols, dtype=jnp.int32)]
        Phi_rem = Phi[:, jnp.array(rem_cols, dtype=jnp.int32)]

        history["col"].append(max_col)
        history["objective"].append(objective.max().item())
        history["x_hat"].append(
            jnp.zeros(n).at[jnp.array(cols, jnp.int32)].set(jnp.linalg.lstsq(Phi_t, y)[0].ravel())
        )
        history["y_hat"].append(Phi @ history["x_hat"][-1])

        # An infinite score means (I - P_t) annihilated the chosen column.
        if jnp.isinf(objective.max()).item():
            break

    return history


In [None]:
# Time ip with the default normalized score. Note: ip does not read its tol
# argument, so 1e-6 has no effect on this run.
%time hist = ip(Phi, y, 1e-6)

In [None]:
# Time ip2; 10_000 is only an upper bound -- ip2 can stop earlier via its
# tol / infinite-score checks.
%time hist2 = ip2(Phi, y, 10_000)

In [None]:
# Time ip with the unnormalized score for comparison.
%time hist3 = ip(Phi, y, 1e-6, normalize=False)

In [None]:
# Do the normalized and unnormalized scores select the same columns in the same order?
hist3["col"] == hist["col"]

In [None]:
# Compare ip2's selection with ip's unnormalized selection minus its last pick
# (presumably because ip2 stops one step earlier -- confirm).
hist2["col"] == hist3["col"][:-1]

In [None]:
# True coefficients vs. ip2's final estimate. Open circles keep both visible where they overlap.
fig, ax = plt.subplots()
ax.plot(x, label="true x")
ax.plot(hist2["x_hat"][-1], "o", fillstyle="none", label="ip2 estimate (last step)")
ax.set(xlabel="coefficient index", ylabel="coefficient value", title="True vs. recovered coefficients")
ax.legend();