In [None]:
# Change into the repository checkout; presumably required so that the
# `ip_is_all_you_need` package imported in the next cell resolves.
# NOTE(review): hard-coded home-relative path — adjust for your machine.
%cd ~/projects/ip-is-all-you-need

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm.notebook import tqdm

from ip_is_all_you_need.simulations import gen_dictionary, mutual_coherence

In [None]:
from jax import random
import cr.sparse.dict as crdict
from random import shuffle

# Sanity check of the coherence computation on an m x n partial Fourier matrix
# (m distinct rows of the n x n Fourier basis) and on crdict's Gaussian matrix.
m = 3200
n = 4000

# One explicit seed for the whole cell, split so the row subset and the
# Gaussian draw use independent keys and the cell is reproducible on re-run.
key = random.PRNGKey(0)
perm_key, gauss_key = random.split(key)
subset = random.permutation(perm_key, n)[:m]  # m distinct row indices in [0, n)

Phi_fourier = crdict.fourier_basis(n)[subset, :]
print("partial Fourier coherence:", crdict.coherence(Phi_fourier))
Phi_gauss = crdict.gaussian_mtx(gauss_key, m, n)
print("Gaussian coherence:", crdict.coherence(Phi_gauss))

In [None]:
# Grid of m x n dictionary sizes to sweep. Ranges are half-open, as in `range`.
m_min = 500
m_max = 3100
m_step = 100
n_min = 500
n_max = 3100
n_step = 100
batch_size = 100  # dictionaries averaged per (m, n) grid point
# Derive the grid shape from the exact ranges that are swept, so it stays
# correct even when (max - min) is not a multiple of step.
grid_size = (len(range(m_min, m_max, m_step)), len(range(n_min, n_max, n_step)))
# Rows index m, columns index n. NaN marks grid points that are never computed
# (the sweep only visits m <= n), so they stay visibly empty instead of
# looking like a coherence of 0.
coherences = np.full(grid_size, np.nan)

def batch_mutual_coherence(Phi: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Mutual coherence of each dictionary in a (possibly batched) tensor.

    The mutual coherence of a dictionary is the largest absolute normalized
    inner product between two distinct columns (atoms):
        mu(Phi) = max_{i != j} |<phi_i, phi_j>| / (||phi_i|| * ||phi_j||).

    Args:
        Phi: tensor of shape (..., rows, cols); the last two dims hold each
            dictionary and its columns are the atoms. Any number of leading
            batch dims (including none) is accepted; real or complex dtype.
        normalize: if True (default), scale every column to unit l2 norm
            first, so the result is the coherence even when atoms are not
            already unit-norm. Pass False to get the raw largest off-diagonal
            Gram entry.

    Returns:
        Real tensor of shape (...) with one coherence value per dictionary.
    """
    if normalize:
        # F.normalize clamps the norm with an eps, so all-zero columns do not
        # produce NaNs.
        Phi = torch.nn.functional.normalize(Phi, dim=-2)
    # Gram matrix of the atoms. conj() is a no-op for real tensors and gives
    # the Hermitian inner product for complex ones (e.g. Fourier dictionaries).
    gram = Phi.transpose(-2, -1).conj() @ Phi
    # Strict upper triangle only: the diagonal holds each atom's self inner
    # product and the lower triangle mirrors the upper one.
    return gram.triu(diagonal=1).abs().amax(dim=(-2, -1))


# Fixed seeds so the sweep is reproducible under Restart & Run All.
# NOTE(review): assumes gen_dictionary draws from torch's or numpy's global
# RNG — confirm against its implementation.
torch.manual_seed(0)
np.random.seed(0)
for i, n in tqdm(enumerate(range(n_min, n_max, n_step)), total=grid_size[1]):
    for j, m in enumerate(range(m_min, m_max, m_step)):
        if m > n:
            # Only m <= n dictionaries are swept; later cells in this column
            # are left as initialised.
            break
        Phi = gen_dictionary(batch_size, m, n)
        # Row index is the m position and column index the n position,
        # matching grid_size = (#m values, #n values) and the heatmap labels.
        coherences[j, i] = batch_mutual_coherence(Phi).mean().item()

In [None]:
# Spot-check one entry of the coherence grid (row 5, column 3).
coherences[(5, 3)]

In [None]:
import seaborn as sns

# Mean mutual coherence over the (m, n) sweep: columns are n values and rows
# are m values, as given by the tick labels.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    coherences,
    # Uncomputed grid points are masked so they do not read as coherence 0.
    # NaN cells are masked by seaborn automatically; an exact 0 is also taken
    # to mean "not computed" (a zero-filled grid).
    mask=(coherences == 0),
    xticklabels=np.arange(n_min, n_max, n_step),
    yticklabels=np.arange(m_min, m_max, m_step),
    ax=ax,
)
ax.set(xlabel="n", ylabel="m", title=f"Mean mutual coherence ({batch_size} dictionaries per cell)");