In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
def calc_unobserved_rmse(U, V, A_hat, mask):
  pred = np.multiply(A_hat, (1 - mask))
  truth = np.multiply(np.dot(U, V.T), (1 - mask))
  cnt = np.sum(1 - mask)
  return (np.linalg.norm(pred - truth, "fro") ** 2 / cnt) ** 0.5

In [3]:
def make_plots(original,R_hat,mask,noise):
    fig, ((ax1, on1, vi1, ax2),
          (ax3, no1, ax4, er2)) = plt.subplots(2, 4, sharex='col', sharey='row', figsize=(18, 6))

    p = ax1.imshow(original, cmap=plt.get_cmap('winter'), interpolation='nearest')
    ax1.set_title('Original')
    fig.colorbar(p, ax=ax1)
    
    p = on1.imshow(original + noise, cmap=plt.get_cmap('winter'), interpolation='nearest')
    on1.set_title('Original with noise')
    fig.colorbar(p, ax=on1)
    
    p = vi1.imshow(original + noise, cmap=plt.get_cmap('winter'), interpolation='nearest')
    vi1.set_title('Input data')
    fig.colorbar(p, ax=vi1)
    p = vi1.imshow(1-mask, 'jet', interpolation='none', alpha=0.8)
    
    p = ax2.imshow(R_hat, cmap=plt.get_cmap('winter'), interpolation='nearest')
    ax2.set_title('Recovered')
    fig.colorbar(p, ax=ax2)
    
    p = ax3.imshow(mask, cmap=plt.get_cmap('summer'), interpolation='nearest')
    fig.colorbar(p, ax=ax3)
    ax3.set_title('Mask')
    
    p = no1.imshow(noise, cmap=plt.get_cmap('winter'), interpolation='nearest')
    no1.set_title('Noise')
    fig.colorbar(p, ax=no1)
    
    R = original
    p = ax4.imshow(np.abs(R-R_hat)/R, cmap=plt.get_cmap('hot'), interpolation='nearest', vmin=0, vmax=1)
    ax4.set_title('% error wrt clean data')
    fig.colorbar(p, ax=ax4)
    
    R = original + noise
    p = er2.imshow(np.abs(R-R_hat)/R, cmap=plt.get_cmap('hot'), interpolation='nearest', vmin=0, vmax=1)
    er2.set_title('% error wrt noisy observation')
    fig.colorbar(p, ax=er2)

    print("RMSE:", calc_unobserved_rmse(U, V, R_hat, mask))

In [4]:
#### VERSION THAT PREFERS SMALLER STEP SIZE #####

max_stepsize_searches = 10
max_direction_searches = 1000

def find_next_point(current_point, step_size):
    norm = np.linalg.norm(current_point,'nuc')
    
    for j in range(max_stepsize_searches):
        step_size = step_size / 2
        for i in range(max_direction_searches):
            direction = np.random.randn(n, n)
            shift = direction * step_size
            shift[1 - mask] = 0
            next_point = current_point + shift

            if np.linalg.norm(next_point,'nuc') < norm:
                break
        if np.linalg.norm(next_point,'nuc') < norm:
                break
            
    if i == max_direction_searches - 1:
        print('fail')
    
    nex_step_size = step_size
    
    return next_point, nex_step_size
    

In [5]:
#### VERSION THAT PREFERS LARGER STEP SIZE #####

max_stepsize_searches = 100
max_direction_searches = 1000

def find_next_point(current_point, step_size):
    norm = np.linalg.norm(current_point,'nuc')
    
    for i in range(max_direction_searches):
        direction = np.random.randn(n, n)
        for j in range(max_stepsize_searches):
            shift = direction * step_size
            shift[1 - mask] = 0
            next_point = current_point + shift
            step_size = step_size / 2

            if np.linalg.norm(next_point,'nuc') < norm:
                break
        if np.linalg.norm(next_point,'nuc') < norm:
                break
            
    if i == max_direction_searches - 1:
        print('fail')
    
    nex_step_size = step_size
    
    return next_point, nex_step_size
    

In [6]:
# Define a matrix, mask and noise
p = 0.5
rank = 1
sigma = 0#1e-3

n = 7
constant_step_size = 0.05

U = np.random.randn(n, rank)
V = np.random.randn(n, rank)
original = np.dot(U, V.T)
noise = sigma * np.random.randn(n, n)
R = noise + original

lam = 5*sigma*np.sqrt(n*p)        
mask = np.random.choice([0, 1], size=(n,n), p=[1-p, p])

In [10]:
step_size = 1e-3

##########################

corrupted_matrix = R
corrupted_matrix[mask] = 0

starting_point = corrupted_matrix
starting_point[mask] = np.random.randn(n, n)

current_point = starting_point

prev_norm = np.linalg.norm(starting_point,'nuc')

norms = []

for i in range(300):
    cur_norm = np.linalg.norm(current_point,'nuc')
    norms.append(cur_norm)
    diff = cur_norm - prev_norm
    if diff > 0.1:
        break
    
    print(i, cur_norm, diff, step_size)
    prev_norm = cur_norm
    
    step_size = constant_step_size
    
    current_point, step_size = find_next_point(current_point, step_size)

plt.plot(norms)

0 7.882116986352521 0.0 0.001


KeyboardInterrupt: 

In [None]:
make_plots(original,current_point,mask,noise)

Note: above error is computed with respect to the original with noise