### This notebook uses a graph-Laplacian rank constraint to recover a block-diagonal matrix from a noisy version of that matrix

Recovers the block-diagonal structure exactly on random seeds 0–4.

In [1]:
import numpy as np
import torch
import scipy.linalg as slinalg
from scipy.sparse import csgraph
# Seed every RNG library the notebook uses so Restart & Run All is reproducible.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

### First, let's make a noisy block-diagonal matrix (note: entries should be non-negative)

In [2]:
# Sizes of the diagonal blocks; their sum is the side length of the matrix.
splits = [4, 3, 2]
dims = sum(splits)
# Standard deviation of the Gaussian noise added to every entry.
noise = 0.05

In [3]:
# Ground-truth support: a square block of ones per split, zeros elsewhere.
components = [np.ones((size, size)) for size in splits]
block_diag_mask = slinalg.block_diag(*components)
# Uniform entries inside the blocks, then Gaussian noise everywhere.
# The uniform draw comes before the normal draw, as in the seeded original.
base_matrix = block_diag_mask * np.random.uniform(size=(dims, dims))
noise_matrix = noise * np.random.randn(dims, dims)
# abs() keeps every entry non-negative (a valid weighted adjacency matrix).
noisy_matrix = np.round(np.abs(base_matrix + noise_matrix), 2)
noisy_matrix

array([[0.98, 0.51, 1.11, 0.74, 0.  , 0.06, 0.  , 0.07, 0.01],
       [0.46, 0.76, 0.15, 0.84, 0.04, 0.01, 0.07, 0.08, 0.07],
       [0.04, 0.94, 0.35, 0.86, 0.04, 0.01, 0.04, 0.06, 0.02],
       [0.22, 0.74, 0.38, 0.54, 0.02, 0.02, 0.07, 0.  , 0.03],
       [0.04, 0.04, 0.03, 0.  , 0.42, 0.73, 0.73, 0.1 , 0.01],
       [0.04, 0.04, 0.08, 0.03, 0.68, 0.9 , 0.38, 0.07, 0.01],
       [0.07, 0.02, 0.02, 0.03, 0.85, 0.44, 0.85, 0.06, 0.1 ],
       [0.04, 0.02, 0.05, 0.02, 0.01, 0.01, 0.04, 0.75, 0.9 ],
       [0.02, 0.03, 0.06, 0.04, 0.04, 0.02, 0.02, 0.76, 0.08]])

### Observe that thresholding at an arbitrary cutoff (here 0.05) isn't very effective:

In [4]:
# Naive baseline: zero out every entry below a fixed cutoff (returns a new array).
thresholded = np.where(noisy_matrix < 0.05, 0., noisy_matrix)
thresholded

array([[0.98, 0.51, 1.11, 0.74, 0.  , 0.06, 0.  , 0.07, 0.  ],
       [0.46, 0.76, 0.15, 0.84, 0.  , 0.  , 0.07, 0.08, 0.07],
       [0.  , 0.94, 0.35, 0.86, 0.  , 0.  , 0.  , 0.06, 0.  ],
       [0.22, 0.74, 0.38, 0.54, 0.  , 0.  , 0.07, 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.42, 0.73, 0.73, 0.1 , 0.  ],
       [0.  , 0.  , 0.08, 0.  , 0.68, 0.9 , 0.38, 0.07, 0.  ],
       [0.07, 0.  , 0.  , 0.  , 0.85, 0.44, 0.85, 0.06, 0.1 ],
       [0.  , 0.  , 0.05, 0.  , 0.  , 0.  , 0.  , 0.75, 0.9 ],
       [0.  , 0.  , 0.06, 0.  , 0.  , 0.  , 0.  , 0.76, 0.08]])

### First, let's compute the Laplacian matrix $L$.

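If $A \in \mathbb{R}_{\ge 0}^{N \times N}$ (non-negative entries) is the symmetric weighted adjacency matrix of a graph with $K$ connected components, then its Laplacian $L = D - A$ has rank $N - K$.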

In [5]:
def laplacian(A):
  """
  Graph Laplacian L = D - A of a square weighted adjacency matrix A.

  My implementation:
    off-diagonal: L_ij = -A_ij               for i != j
    diagonal:     L_jj = sum_{i != j} A_ij   (column sums of A excluding the
                  diagonal, i.e. in-degree; the diagonal of A is ignored)
  This matches the default behaviour of scipy.sparse.csgraph.laplacian.
  A bit slower than SciPy's for numpy arrays.

  Works for both numpy arrays and torch tensors. For tensors, the result is a
  differentiable function of A and is created on A's device.

  Note that every column of L sums to zero, so L has rank at most n-1.
  """
  n = len(A)
  if torch.is_tensor(A):
    # Build the identity on A's device (and in A's float dtype) so CUDA tensors
    # work too; integer inputs fall back to the default float dtype.
    I = torch.eye(n, dtype=A.dtype if A.is_floating_point() else None, device=A.device)
  else:
    I = np.eye(n)
  off_diag = (1 - I) * A  # A with its diagonal zeroed out
  # The 1-D vector of column sums broadcasts across rows, so I * col_sums
  # places col_sums[j] at position (j, j) and zero elsewhere.
  return I * off_diag.sum(0) - off_diag

# Sanity-check against SciPy. Use a tolerance rather than exact float equality:
# the two implementations may accumulate the degree sums in a different order.
np.testing.assert_allclose(laplacian(noisy_matrix), csgraph.laplacian(noisy_matrix), atol=1e-12)
csgraph.laplacian(noisy_matrix)

array([[ 0.93, -0.51, -1.11, -0.74, -0.  , -0.06, -0.  , -0.07, -0.01],
       [-0.46,  2.34, -0.15, -0.84, -0.04, -0.01, -0.07, -0.08, -0.07],
       [-0.04, -0.94,  1.88, -0.86, -0.04, -0.01, -0.04, -0.06, -0.02],
       [-0.22, -0.74, -0.38,  2.56, -0.02, -0.02, -0.07, -0.  , -0.03],
       [-0.04, -0.04, -0.03, -0.  ,  1.68, -0.73, -0.73, -0.1 , -0.01],
       [-0.04, -0.04, -0.08, -0.03, -0.68,  1.3 , -0.38, -0.07, -0.01],
       [-0.07, -0.02, -0.02, -0.03, -0.85, -0.44,  1.35, -0.06, -0.1 ],
       [-0.04, -0.02, -0.05, -0.02, -0.01, -0.01, -0.04,  1.2 , -0.9 ],
       [-0.02, -0.03, -0.06, -0.04, -0.04, -0.02, -0.02, -0.76,  1.15]])

In [6]:
At = torch.from_numpy(noisy_matrix.copy())  # float64 torch copy; does not share memory with noisy_matrix

In [7]:
L = laplacian(A=At)  # Laplacian of the observed (noisy) matrix, as a torch tensor

### Now we approximate $A$ by $\hat A$ and optimize the Laplacian of the symmetrized $\hat A + \hat A^\top$ to have low rank.

We do this with PyTorch and gradient descent instead of following the paper http://openaccess.thecvf.com/content_cvpr_2014/papers/Feng_Robust_Subspace_Segmentation_2014_CVPR_paper.pdf (which uses a quadratic programming solver).

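First, define a function that finds a low-rank approximation via truncated SVD (PCA without mean-centering). At each optimization step we compute it from the current Laplacian and use it as the learning target.
(Another option would be to parameterize the low-rank approximation directly as a matrix product.)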

In [8]:
def low_rank_approx(A, rank):
  """
  Best rank-`rank` approximation of A via truncated SVD (Eckart-Young),
  i.e. PCA without mean-centering.

  Args:
    A: 2-D torch tensor (need not be square).
    rank: number of leading singular triplets to keep; 0 <= rank <= min(A.shape).

  Returns:
    Tensor with the same shape as A and rank at most `rank`. Gradients flow
    back to A through torch.svd (which can be unstable when singular values
    are repeated).
  """
  assert 0 <= rank <= min(A.shape), 'rank must be in [0, min(A.shape)]'
  U, S, V = torch.svd(A)  # A = U diag(S) V^T, singular values in descending order
  # Scaling U's leading columns by S equals U @ diag(S) without building the diagonal matrix.
  return (U[:, :rank] * S[:rank]) @ V[:, :rank].t()

In [9]:
# Verify that this works
for i in range(1, len(L)):
  low_rank_L = low_rank_approx(L, i)
  print('Rank {} approximation error: {}'.format(torch.matrix_rank(low_rank_L), torch.sum((low_rank_L - L)**2)))

Rank 1 approximation error: 24.267596936057103
Rank 2 approximation error: 17.222321487422356
Rank 3 approximation error: 11.323908627841078
Rank 4 approximation error: 7.190660808959906
Rank 5 approximation error: 3.2489762647701905
Rank 6 approximation error: 0.25137874998117804
Rank 7 approximation error: 0.08383584308686802
Rank 8 approximation error: 7.960309488000459e-29


In [10]:
import torch.nn.functional as F

In [11]:
# Parameterise hat_A = hat_A_sqrt ** 2 so every entry stays non-negative,
# initialised so that hat_A ~= A at the first step. Entries that start at
# exactly 0 receive zero gradient through the square and therefore stay 0.
hat_At_sqrt = At.detach().sqrt().requires_grad_(True)
# Number of connected components (= number of diagonal blocks); derived from
# `splits` so it cannot drift out of sync with the generated matrix.
K = len(splits)
optimizer = torch.optim.Adam([hat_At_sqrt], lr=1e-3)

In [12]:
# Fit hat_A to the observed A (MSE term) while pushing the Laplacian of the
# symmetrised hat_A towards rank N - K (L1 distance to its best rank-(N-K)
# approximation). NOTE: re-running this cell continues training from the
# current parameters; re-run the previous cell first for a fresh start.
n_steps = 1000
for step in range(n_steps):
  hat_At = hat_At_sqrt ** 2  # squaring keeps every entry non-negative
  hat_L = laplacian(hat_At + hat_At.T)  # symmetrise so L describes an undirected graph
  # The target is not detached, so gradients also flow through the SVD.
  low_rank_target = low_rank_approx(hat_L, len(hat_L) - K)
  loss = F.l1_loss(hat_L, low_rank_target) + F.mse_loss(hat_At, At)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

# Inside the loop, `hat_At` was computed *before* the final optimizer step;
# recompute it so the following cells show the final parameters.
with torch.no_grad():
  hat_At = hat_At_sqrt ** 2

In [13]:
np.round((hat_At.detach() - At).numpy(), decimals=2)  # recovered minus observed

array([[-0.  ,  0.  ,  0.  ,  0.  ,  0.  , -0.06,  0.  , -0.07, -0.01],
       [ 0.  ,  0.  ,  0.  ,  0.  , -0.04, -0.01, -0.07, -0.08, -0.07],
       [ 0.  ,  0.  ,  0.  ,  0.  , -0.04, -0.01, -0.04, -0.06, -0.02],
       [-0.  , -0.  , -0.  ,  0.  , -0.02, -0.02, -0.07,  0.  , -0.03],
       [-0.04, -0.04, -0.03,  0.  ,  0.  ,  0.  ,  0.  , -0.1 , -0.01],
       [-0.04, -0.04, -0.08, -0.03,  0.  , -0.  , -0.  , -0.07, -0.01],
       [-0.07, -0.02, -0.02, -0.03,  0.  , -0.  , -0.  , -0.06, -0.1 ],
       [-0.04, -0.02, -0.05, -0.02, -0.01, -0.01, -0.04, -0.  , -0.  ],
       [-0.02, -0.03, -0.06, -0.04, -0.04, -0.02, -0.02,  0.  ,  0.  ]])

## Behold the recovered matrix: the off-block entries have been driven to zero:

In [14]:
np.round(hat_At.detach().numpy(), decimals=2)  # recovered matrix hat_A

array([[0.98, 0.51, 1.11, 0.74, 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.46, 0.76, 0.15, 0.84, 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.04, 0.94, 0.35, 0.86, 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.22, 0.74, 0.38, 0.54, 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.42, 0.73, 0.73, 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.68, 0.9 , 0.38, 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.85, 0.44, 0.85, 0.  , 0.  ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.75, 0.9 ],
       [0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.76, 0.08]])

In [15]:
np.round(At.numpy(), decimals=2)  # observed noisy matrix, for comparison

array([[0.98, 0.51, 1.11, 0.74, 0.  , 0.06, 0.  , 0.07, 0.01],
       [0.46, 0.76, 0.15, 0.84, 0.04, 0.01, 0.07, 0.08, 0.07],
       [0.04, 0.94, 0.35, 0.86, 0.04, 0.01, 0.04, 0.06, 0.02],
       [0.22, 0.74, 0.38, 0.54, 0.02, 0.02, 0.07, 0.  , 0.03],
       [0.04, 0.04, 0.03, 0.  , 0.42, 0.73, 0.73, 0.1 , 0.01],
       [0.04, 0.04, 0.08, 0.03, 0.68, 0.9 , 0.38, 0.07, 0.01],
       [0.07, 0.02, 0.02, 0.03, 0.85, 0.44, 0.85, 0.06, 0.1 ],
       [0.04, 0.02, 0.05, 0.02, 0.01, 0.01, 0.04, 0.75, 0.9 ],
       [0.02, 0.03, 0.06, 0.04, 0.04, 0.02, 0.02, 0.76, 0.08]])