# Binary Ground Metric Learning on MNIST

### The purpose of this notebook is to learn the ground metric of the Wasserstein distance using two true classes as input, i.e.:
- For MNIST we use 30 examples of each class (ideally with balanced classes) and turn the binary classification problem into finding two clusters whose members are close to each other in Wasserstein distance.
- The same applies to Caltech 256.

## Usage
* mml.gml - Wasserstein distances
* mml.gml - OT matrices
* mml.gml - Similarity matrix
* mml.datasets - Datasets loading

**Algorithm 1 and Algorithm 2 are implemented.**

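Using the matrix of off-diagonal ones (the 0/1 metric) as the ground metric reduces the Wasserstein distance to the total variation distance, so all histograms lie within distance 1 of each other. Cuturi's paper notes two drawbacks of this choice:
* uninformative gradients in the first iterations
* a starting point far from the optimum, so convergence is slow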
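A small check of the total-variation claim above, as a NumPy-only sketch (no POT; `zero_one_coupling` is a hypothetical helper, not part of `mml`). With the 0/1 cost, the Wasserstein distance is $1 - \sum_i \min(r_i, c_i) = \tfrac{1}{2}\lVert r - c\rVert_1$. The coupling below keeps the shared mass on the diagonal and spreads the excess proportionally; every coupling must move at least that excess, so this one is optimal.

```python
import numpy as np

def zero_one_coupling(r, c):
    """Optimal coupling for the 0/1 ground metric (hypothetical helper)."""
    shared = np.minimum(r, c)          # mass that can stay in place at zero cost
    supply = r - shared                # excess mass of r that must move
    demand = c - shared                # deficit of c that must be filled
    moved = supply.sum()               # total mass moved = TV distance
    T = np.diag(shared)
    if moved > 0:
        # supply and demand are never both positive on the same bin, so this
        # outer product has a zero diagonal: each unit moved costs exactly 1
        T = T + np.outer(supply, demand) / moved
    return T

rng = np.random.default_rng(0)
r_toy = rng.random(64); r_toy /= r_toy.sum()
c_toy = rng.random(64); c_toy /= c_toy.sum()
M_01 = 1.0 - np.eye(64)                # the "off-diagonal ones" ground metric

T_01 = zero_one_coupling(r_toy, c_toy)
cost_01 = np.sum(T_01 * M_01)
tv = 0.5 * np.abs(r_toy - c_toy).sum()
print(np.allclose(T_01.sum(axis=1), r_toy), np.allclose(T_01.sum(axis=0), c_toy))  # True True
print(np.isclose(cost_01, tv), cost_01 <= 1.0)                                     # True True
```

The distance reaches 1 only when the two histograms have disjoint supports.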

In [1]:
from __future__ import print_function, division
%matplotlib inline

#import warnings
#warnings.filterwarnings("ignore")
import time
import tqdm
from tqdm import tqdm_notebook as tqdm_notebook
from tqdm import trange
import logging
from collections import OrderedDict
logger = logging.getLogger(__name__)

import os
import subprocess
from pathlib import Path
import numpy as np
from numpy import testing
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.pylab as pl
import seaborn as sns

from sklearn.metrics import adjusted_rand_score, confusion_matrix
from sklearn.preprocessing import StandardScaler, Normalizer

import metric_learn
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import DistanceMetric
from sklearn.neighbors import BallTree
from mpl_toolkits.mplot3d import Axes3D
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss

# Shogun - Metric Learning
from shogun import LMNN as shogun_LMNN
from shogun import RealFeatures, MulticlassLabels
from sklearn.utils.validation import check_X_y, check_array

# POT imports
import ot
from ot.datasets import get_1D_gauss as gauss

# MML import 
from mml import wasserstein, transform, gml, ot_testing, datasets, helper

data_path = str(Path(os.getcwd())) + "/data/"
results_path = str(Path(os.getcwd())) + "/results/binary"

def write_to_pickle(dataframe, name):
    dataframe.to_pickle(data_path + name + ".pickle")
def read_from_pickle(name): 
    return pd.read_pickle(data_path + name + ".pickle")

In [None]:
import importlib
importlib.reload(wasserstein)
importlib.reload(gml)
importlib.reload(helper)
# Load Caltech Data from disk 
#X_caltech = np.load()

## Select pairs of labels to construct multiple training sets for GML
* Below we set up binary classification between digit 1 and digit 2 (the same setup applies to the Caltech dataset).

In [2]:
# Load Hellinger representation of the data 
# TODO: Change method in load_mnist to first look on disk before anything else 
X,Y = datasets.load_mnist("Hellinger")
train_size=0.25
test_size=0.75
X_train,X_test,Y_train,Y_test = train_test_split(X,Y,train_size=train_size,test_size=test_size,random_state=123)
np.bincount(Y_train)

array([32, 45, 42, 51, 38, 51, 43, 44, 51, 52])

In [3]:
# There are at most 45 unique class pairs for MNIST (10 choose 2); pick two of them
# For Caltech, classes are balanced
data_dict = datasets.dictionary(X,Y)

class1 = 1
class2 = 2 
[x12,y12] = datasets.data_pairs(data_dict,class1,class2)

X_train,X_test,Y_train,Y_test = train_test_split(x12,y12,train_size=train_size,test_size=test_size,random_state=123)

# Transform into histograms
# Neither POT's Wasserstein solvers nor LMNN normalize the inputs, so we do it ourselves
#[X_train_normalized, X_test_normalized] = transform.normalize(X_train,X_test,'l1')

X_train_normalized = X_train/X_train.sum(axis=1).reshape((-1,1))
X_test_normalized = X_test/X_test.sum(axis=1).reshape((-1,1))
n = X_train_normalized.shape[0]
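A NumPy sketch of the two steps above on toy data. `select_pair` is a hypothetical stand-in for `datasets.data_pairs`, assuming it simply keeps the rows whose label is one of the two classes; we have not checked the `mml` implementation. `l1_normalize` does the same row division as the cell above, and will divide by zero for an all-zero row.

```python
import itertools
import numpy as np

# Number of unordered class pairs for 10 digit classes: C(10, 2)
pairs = list(itertools.combinations(range(10), 2))
print(len(pairs))  # 45

def select_pair(X, Y, class1, class2):
    """Hypothetical stand-in for datasets.data_pairs: keep two classes only."""
    mask = (Y == class1) | (Y == class2)
    return X[mask], Y[mask]

def l1_normalize(X):
    """Turn each non-negative row into a histogram (rows sum to 1)."""
    return X / X.sum(axis=1, keepdims=True)

rng = np.random.default_rng(123)
X_toy = rng.random((20, 64))           # toy non-negative features, 8x8 = 64 bins
Y_toy = np.arange(20) % 10             # toy labels 0..9, each appearing twice
x12_toy, y12_toy = select_pair(X_toy, Y_toy, 1, 2)
H_toy = l1_normalize(x12_toy)
print(np.unique(y12_toy), H_toy.shape, np.allclose(H_toy.sum(axis=1), 1.0))
```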

In [6]:
# These depend only on the dimensionality of a data point (here 64), not on which data we choose
d = X_train.shape[1]
x = np.arange(d,dtype=np.float64)
x1 = x.reshape((d,1))
# By default metric='sqeuclidean' in ot.dist
M_sqeuclidean = ot.dist(x1,x1,metric='sqeuclidean')
# 'hamming' on scalar points gives the 0/1 metric: ones off the diagonal, zeros on it
M_eye = ot.dist(x1,x1,metric='hamming')

xx,yy = np.meshgrid(np.arange(np.sqrt(d)),np.arange(np.sqrt(d)))
xy = np.hstack((xx.reshape(-1,1),yy.reshape(-1,1)))
M_mesh = ot.dist(xy, xy)
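For reference, a NumPy-only sketch of the three ground metrics built above, so their structure is explicit. It is meant to mirror what `ot.dist` returns for these inputs, but we have not compared the outputs against POT here. Note the difference between the first and third matrices: `M_sqeuclidean` treats the 64 bins as points on a line (bin index), while `M_mesh` uses the 2-D pixel geometry of an 8x8 image.

```python
import numpy as np

d = 64                    # 8x8 images flattened into 64 bins
side = int(np.sqrt(d))

# 1-D bin indices, as in x1 above
idx = np.arange(d, dtype=np.float64)
M_sq_1d = (idx[:, None] - idx[None, :]) ** 2                # squared distance between bin indices
M_01 = (idx[:, None] != idx[None, :]).astype(np.float64)    # 0/1 metric

# 2-D pixel coordinates (column, row) of each flattened pixel, as in xy above
xx, yy = np.meshgrid(np.arange(side), np.arange(side))
coords = np.stack([xx.ravel(), yy.ravel()], axis=1).astype(np.float64)
offsets = coords[:, None, :] - coords[None, :, :]
M_grid = (offsets ** 2).sum(axis=-1)                        # squared Euclidean between pixels

print(M_grid[0, 9])   # pixel (0,0) vs pixel (1,1): 1 + 1 = 2.0
```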

## GML $M^*$ and $W_d$ via Cuturi 
Questions to answer:
- Why the seesaw objective?
- Are digits normalized? No, so we need to do that. Does POT do the normalization step itself? No, we need to do it ourselves.
- We need to build $M_1$, the initial ground metric for our algorithm. Follow the algorithm for doing this.
- Try Wasserstein with the Euclidean ground metric to get an idea of the accuracy. Errors are important too, though.
- Randomize the whole training procedure to get error bars.

- Step 1: Find $D(\text{digit}_1, \text{digit}_2)$ using the Euclidean ground metric as input to the Wasserstein algorithm.
- Step 2: Have a way of computing the optimal coupling $X^*$ for two digits; the OT distance is then $\langle X^*, M \rangle$. (Have another implementation using Sinkhorn.)
- Step 3: $M$ and $X$ are both 64x64 matrices: each entry of $M$ compares two bins in data (pixel) space, and each entry of $X$ is the mass moved between them in probability space.
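For the Sinkhorn alternative in Step 2, POT provides `ot.sinkhorn` (and `ot.emd` / `ot.emd2` for exact OT). The sketch below is a plain NumPy version of the textbook Sinkhorn scaling iterations, not POT's implementation, so the structure of the approximate plan $X^*$ is visible. The values of `reg` and `n_iter` are illustrative, not tuned.

```python
import numpy as np

def sinkhorn(r, c, M, reg=0.1, n_iter=3000):
    """Entropic OT via Sinkhorn scaling (textbook sketch, not ot.sinkhorn)."""
    K = np.exp(-M / reg)                 # Gibbs kernel
    u = np.ones_like(r)
    v = np.ones_like(c)
    for _ in range(n_iter):
        u = r / (K @ v)                  # rescale to match the row marginals
        v = c / (K.T @ u)                # rescale to match the column marginals
    X = u[:, None] * K * v[None, :]      # approximate transport plan X*
    return X, np.sum(X * M)              # plan and <X, M>

rng = np.random.default_rng(1)
r_toy = rng.random(16); r_toy /= r_toy.sum()
c_toy = rng.random(16); c_toy /= c_toy.sum()
grid = np.arange(16, dtype=np.float64)
M_toy = (grid[:, None] - grid[None, :]) ** 2
M_toy = M_toy / M_toy.max()              # rescale so exp(-M/reg) does not underflow

X_sk, cost_sk = sinkhorn(r_toy, c_toy, M_toy)
print(np.abs(X_sk.sum(axis=1) - r_toy).max(), np.abs(X_sk.sum(axis=0) - c_toy).max())
```

Smaller `reg` brings the plan closer to the exact one but needs more iterations and a rescaled cost to avoid underflow.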

Learning ground metrics:
- Where does the similarity matrix come from? For us it is multiclass. Build $w_{ij}$ for similar histograms using 1/nk.
- For the feasible set of ground metrics, $M$ must have an L1 norm of 1.
- Learn $M$ so that $W(r_i,r_j)$ is small for similar images and large for dissimilar ones. The problem is that all dissimilar pairs are treated the same, which is not entirely accurate. Can we do better by engineering that?
- Remember, the Wasserstein distance is $G_{ij}(M) = \langle X^*_{ij}, M \rangle$, and $X^*_{ij}$ is a supergradient of it with respect to $M$ (the gradient wherever $X^*_{ij}$ is unique). We need to use these.
- There is only one $M$ shared by all pairs; that doesn't sound right...
- How do you compute the similar/dissimilar neighbours? You need to keep track of the 1437x1437 Wasserstein distances and then find the k closest neighbours, or store only half of that matrix since it is symmetric. Each row holds the distances from point $i$ to all other points. But when $i = 4$ and only $j > 4$ is stored, the entries for $j < 4$ have to be read from the other triangle.
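A short derivation behind the gradient bullet above, written out in this notebook's notation (standard LP reasoning): $G_{ij}$ is a pointwise minimum of functions that are linear in $M$, so it is concave, and the optimal plan gives a linear upper bound that touches $G_{ij}$ at $M$.

$$
G_{ij}(M) = \min_{X \in U(r_i, r_j)} \langle X, M \rangle,
\qquad
U(r, c) = \{ X \in \mathbb{R}_{+}^{d \times d} : X\mathbf{1}_d = r,\; X^{\top}\mathbf{1}_d = c \}
$$

$$
G_{ij}(M') \le \langle X^*_{ij}, M' \rangle = G_{ij}(M) + \langle X^*_{ij}, M' - M \rangle \quad \text{for every ground metric } M'
$$

So $X^*_{ij}$ is a supergradient of $G_{ij}$ at $M$, and the gradient wherever the optimal plan is unique. The linear term $\langle X^*_{ij}, M' - M \rangle$ has the same form as `res = gamma_plus.T.dot(diff)` in the debugging code, assuming `gamma_plus` collects the (weighted) plans of the similar pairs.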


**Build similarity matrix:**
* k_neigh = 3 
* Can you build $w$ differently, e.g. by encoding prior information about how close $r_i$ and $r_j$ are?
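A NumPy sketch of the neighbour question raised in the notes: recover the full symmetric distance matrix when only pairs $i < j$ were computed, then pick, for each point, its `k` nearest same-class and other-class points. This mirrors the idea behind `gml.similar_neighbours` / `gml.dissimilar_neighbours`, but `split_neighbours` and its plain "k nearest by label" rule are our assumptions; we have not checked the `mml` internals. It also assumes every class has more than `k` members.

```python
import numpy as np

def split_neighbours(G, y, k):
    """For each point, indices of its k nearest same-class and k nearest
    other-class points under a symmetric distance matrix G (hypothetical)."""
    n = G.shape[0]
    similar = np.empty((n, k), dtype=int)
    dissimilar = np.empty((n, k), dtype=int)
    for i in range(n):
        order = np.argsort(G[i], kind="mergesort")   # stable: ties keep index order
        order = order[order != i]                    # drop the point itself
        similar[i] = order[y[order] == y[i]][:k]
        dissimilar[i] = order[y[order] != y[i]][:k]
    return similar, dissimilar

rng = np.random.default_rng(2)
pts = rng.random((12, 2))
y_toy = np.repeat([1, 2], 6)
G_full = np.sqrt(((pts[:, None] - pts[None]) ** 2).sum(-1))
G_upper = np.triu(G_full, k=1)          # only pairs i < j were computed
G_sym = G_upper + G_upper.T             # row i now sees every j, including j < i
S_sim, S_dis = split_neighbours(G_sym, y_toy, k=3)
print(S_sim[0], S_dis[0])
```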

In [None]:
k_neigh = 3
w = gml.similarity_matrix(X_train_normalized,Y_train,k_neigh)
plt.imshow(w);

In [None]:
matrix_file = "matrix12"
gml_logger = None
M_learned, objective_tracker, product_tracker, interval_tracker = gml.projected_subgradient(
                                X_train_normalized, M_mesh ,w, matrix_file, gml_logger, p_max = 2, q_max = 3, early_stopping = True)

In [None]:
helper.save_ndarray(os.path.join(results_path, "gml12_mesh_vectorized"), M_learned)
plt.imshow(M_learned);

In [None]:
helper.save_json(os.path.join(results_path, "gml12_objective"), objective_tracker)
plt.plot(objective_tracker);

In [None]:
data = helper.load_json(os.path.join(results_path, "gml12_objective"))

In [None]:
plt.plot(interval_tracker);
plt.plot(product_tracker);

## Below is not used; it is kept only for debugging purposes
**Cuturi's approach** 
* q_max is set to 80 and the inner loop runs at least 24 times
* Another stopping criterion: stop when the objective changes by no more than 0.75% over 8 steps
    - I need to implement this and plot the learning curves for z_plus and z_threshold.
    - First do z_threshold (or z_in)
    - Where did they get these numbers from?
* A maximum of 20 outer loops, so p_max = 20
* What does the t step do? 
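The stopping rule described in the list above, isolated as a small function so it can be tested on its own. This is a sketch of the criterion as stated here (at least `q_min` inner iterations, then stop if the objective moved by less than 0.75% relative to its value `interval` steps earlier); `should_stop` is a hypothetical name, `q` counts completed inner iterations, and the defaults are the values quoted above.

```python
def should_stop(history, q, q_min=24, interval=8, tol_percent=0.75):
    """Inner-loop early stopping (sketch): compare the latest objective with
    the one recorded `interval` steps earlier."""
    if q <= q_min or len(history) <= interval:
        return False
    current, previous = history[-1], history[-1 - interval]
    if previous == 0:
        return current == 0
    change = abs(current - previous) / abs(previous) * 100.0
    return change < tol_percent

flat = [10.0] * 30
falling = [100.0 * 0.9 ** t for t in range(30)]
print(should_stop(flat, q=30))     # True: no change over 8 steps
print(should_stop(falling, q=30))  # False: about a 57% drop over 8 steps
print(should_stop(flat, q=10))     # False: fewer than q_min iterations
```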

In [None]:
def projected_subgradient(X_train_normalized, M_original_eye, w, p_max, q_max, early_stopping=False):
    X = X_train_normalized
    # M_0 = total_variation()
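    # Scale by d(d-1), the number of off-diagonal entries of the d x d ground
    # metric, so the 0/1 metric has unit L1 norm (X.shape[0] is the sample count)
    d = M_original_eye.shape[0]
    M_0 = M_original_eye / (d * (d - 1))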
    t = 1 
    p = 0
    learning_rate = 0.1
    M_p_outer = M_0    

    #with tqdm.tqdm(total=p_max) as outer_progress_bar:
    #outer_progress_bar.set_description("Outer Loop Progress")
    pbar = tqdm_notebook(total = p_max)
    pbar.set_description("Outer Loop progress")
    objective_tracker = []
    product_tracker = []
    interval_tracker = []
    change_M = []
    z_results = []
    gamma = []
    outer_M = []
    
    q_min = 24 
    interval = 8
    
    while p < p_max: 
        #print("Outer Loop: Iteration {0}".format(p))
        G, T = gml.compute_distances(X_train_normalized, M_p_outer, showProgress=False)
        #outer_M.append(M_p_outer)
        z_plus, gamma_plus = gml.algorithm1_similar(M_p_outer, w, G, T, 3)
        q = 0 
        M_q_inner = M_p_outer
        #while (z_threshold < 0.001) - capture the change in the objective
        #with tqdm.tqdm(total=q_max) as inner_progress_bar:
        #inner_progress_bar.set_description("Inner Loop Progress")
        qbar = tqdm_notebook(total=q_max)
        qbar.set_description("Inner Loop progress")
        while q < q_max:
            #print("Inner Loop: Iteration {0}".format(q))
            G, T = gml.compute_distances(X_train_normalized, M_q_inner, showProgress=False)
            z_minus, gamma_minus = gml.algorithm1_dissimilar(M_q_inner, w, G, T, 3)
            # res should be a scalar: the linear term <gamma_plus, vec(M_q) - vec(M_p)>
            # of the seesaw objective. If gamma_plus is a 1-D vector, .T is a no-op
            # and .dot() returns the inner product; check against the paper that
            # no transpose of gamma_plus is actually needed.
            diff = helper.tril_vector(M_q_inner) - helper.tril_vector(M_p_outer)
            # TODO: track how diff evolves and find out why the objective spikes.
            # At q = 0, diff is zero because M_q_inner starts from M_p_outer.
            res = gamma_plus.T.dot(diff)
            product_tracker.append(res)
            z_results.append(z_minus + z_plus)
            
            # This is the final objective list
            z_threshold = z_minus + z_plus + res
            objective_tracker.append(z_threshold)
            
            # Gradient descent 
            #wasserstein.lower_triangular(M_q_inner)
            #gamma.append(gamma_plus+gamma_minus)
            gamma.append(G)
            M_q_inner_lower = helper.tril_vector(M_q_inner) - (learning_rate/np.sqrt(q+1)) * (gamma_plus + gamma_minus)
            change_M.append(M_q_inner_lower)
            
            # Rebuild the full symmetric ground metric from its lower-triangular vector
            M_q_inner = helper.symmetrize_from_vector(M_q_inner_lower,M_0.shape[0])

            # Alternative that keeps only the lower triangle; it is unclear how the
            # metric projection below could work on such a matrix
            #M_q_inner = wasserstein.matrix_from_vector(M_q_inner_lower,M_0.shape[0])

            # The gradient step may leave the set of metrics, so project back
            # onto it (metric nearness) to restore the metric properties
            M_q_inner = gml.project_metric(M_q_inner)
        
            q += 1
            t +=1
            #inner_progress_bar.update(1)    
            qbar.update(1)
            
            # Early stopping: stop the inner loop if the objective changed by less
            # than 0.75% relative to its value `interval` steps earlier.
            # objective_tracker[-1] is the current value, so [-interval - 1] is
            # exactly `interval` steps back.
            if early_stopping and q > q_min:
                previous_thres = objective_tracker[-interval - 1]
                val = (z_threshold - previous_thres) / previous_thres * 100
                interval_tracker.append(np.abs(val))
                if np.abs(val) < 0.75:
                    break

            
        print("assign M inner to outer")
        M_p_outer = M_q_inner
        p +=1
        #outer_progress_bar.update(1)
        pbar.update(1)
    return M_p_outer, product_tracker, objective_tracker, interval_tracker, change_M, z_results, gamma, outer_M
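
The inner update in the function above works on the strictly lower-triangular entries of $M$ as a vector, takes a step of size `learning_rate / sqrt(q + 1)`, and rebuilds a symmetric matrix. A NumPy sketch of those pieces, assuming (unverified) that `helper.tril_vector` and `helper.symmetrize_from_vector` use the row-major order of `np.tril_indices(d, -1)`; the names below are hypothetical re-implementations, and the metric-nearness projection (`gml.project_metric`) is not reproduced, so entries can still go negative after a step.

```python
import numpy as np

def tril_vector(M):
    """Strictly lower-triangular entries of M, row-major (hypothetical helper)."""
    return M[np.tril_indices(M.shape[0], k=-1)]

def symmetrize_from_vector(v, d):
    """Rebuild a symmetric d x d matrix with zero diagonal from tril_vector output."""
    M = np.zeros((d, d))
    M[np.tril_indices(d, k=-1)] = v
    return M + M.T

def subgradient_step(M, gamma, q, learning_rate=0.1):
    """One step M <- M - lr / sqrt(q + 1) * gamma on the lower triangle."""
    v = tril_vector(M) - (learning_rate / np.sqrt(q + 1)) * gamma
    return symmetrize_from_vector(v, M.shape[0])

d_toy = 4
M_start = 1.0 - np.eye(d_toy)
gamma_toy = np.ones(d_toy * (d_toy - 1) // 2)   # toy (sub)gradient, one entry per pair
M_next = subgradient_step(M_start, gamma_toy, q=0)
print(M_next[1, 0], np.allclose(M_next, M_next.T))  # 0.9 True
```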


_____

___

___

In [None]:
# This shows that there are numerical differences when computing the couplings/distances.
# In the algorithms, I assume the coupling for (j, i) is the transpose of the one for (i, j), i.e. T2 == T1.T
r = X_train_normalized[1]
c = X_train_normalized[3]
m_hist = X_train_normalized[2]
n_hist = X_train_normalized[4]   # renamed so the sample count n is not overwritten

G1 = wasserstein.distance(r,c, **{'ground metric': M_eye})
[T1,M1] = wasserstein.coupling(r,c,**{'ground metric': M_eye})

[T2,M2] = wasserstein.coupling(c,r,**{'ground metric': M_eye})


print('Distance is: {0}'.format(G1))

# Check the marginals of T1: row sums should recover r, column sums c
rec_r = np.sum(T1, axis=1)
rec_c = np.sum(T1, axis=0)
print('Max marginal errors: {0}, {1}'.format(np.abs(rec_r - r).max(), np.abs(rec_c - c).max()))

# Check that the distance equals the Frobenius product <T1, M1>
testing.assert_allclose(np.sum(T1 * M1), G1)

# T1.T and T2 need not be identical: the optimal coupling is not unique in general
# (especially for the 0/1 metric), but both plans must have the same cost
print((np.transpose(T1) != T2).sum(), np.abs(np.transpose(T1) - T2).max())
testing.assert_allclose(np.sum(T2 * M2), G1)

**Calculate all pairwise Wasserstein distances $G$ and couplings $T$ for the dataset**
Can this be optimized further?

In [None]:
G_all = wasserstein.all_distances(X_train_normalized,M_eye)
T_all = wasserstein.all_couplings(X_train_normalized,M_eye,True)
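On whether the all-pairs computation can be optimized: when the ground metric is `M_eye` (ones off the diagonal, zeros on it), every pairwise Wasserstein distance has the closed form $\tfrac{1}{2}\lVert r_i - r_j\rVert_1$, so the full distance matrix needs no OT solver. This only holds for that 0/1 metric, and it gives distances only, not the couplings in `T_all` (which are generally not unique for this metric). A NumPy sketch; `all_tv_distances` is our own name.

```python
import numpy as np

def all_tv_distances(H):
    """Pairwise W distances under the 0/1 ground metric = 0.5 * L1 (sketch)."""
    n = H.shape[0]
    G = np.zeros((n, n))
    for i in range(n):
        # one row at a time keeps memory at O(n * d) instead of O(n^2 * d)
        G[i] = 0.5 * np.abs(H - H[i]).sum(axis=1)
    return G

rng = np.random.default_rng(3)
H = rng.random((5, 64))
H /= H.sum(axis=1, keepdims=True)      # rows are histograms
G_tv = all_tv_distances(H)
print(G_tv.shape, np.allclose(G_tv, G_tv.T), G_tv.max() <= 1.0)  # (5, 5) True True
```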

In [7]:
import multiprocessing
n_proc = multiprocessing.cpu_count()
n_proc

4

**Sanity check for some properties**

In [None]:
# Should equal T2, since only the lower-triangular couplings are computed directly
(T_all[1,3] != T2).sum()

In [None]:
# Symmetry Test
ot_testing.is_metric(G_all)

# Triangle inequality
print(G_all[1][2] + G_all[2][3] >= G_all[1][3])
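The cell above checks the triangle inequality for a single triple. A NumPy sketch that checks all metric axioms for every triple of a square distance matrix (`check_metric` is a hypothetical helper, not `ot_testing.is_metric`). It loops over the intermediate point `k`, so memory stays at O(n²) even for the 1437-point matrix. The usage lines also show that a squared Euclidean cost such as `M_sqeuclidean` is not itself a metric.

```python
import numpy as np

def check_metric(D, atol=1e-10):
    """Return which metric axioms a square distance matrix satisfies (sketch)."""
    result = {
        "nonnegative": bool(np.all(D >= -atol)),
        "zero_diagonal": bool(np.allclose(np.diag(D), 0.0, atol=atol)),
        "symmetric": bool(np.allclose(D, D.T, atol=atol)),
        "triangle": True,
    }
    for k in range(D.shape[0]):
        # D[i, j] <= D[i, k] + D[k, j] for all i, j, with k fixed
        if np.any(D > D[:, [k]] + D[[k], :] + atol):
            result["triangle"] = False
            break
    return result

line = np.arange(3.0)
D_abs = np.abs(line[:, None] - line[None, :])
print(check_metric(D_abs))        # every entry True
print(check_metric(D_abs ** 2))   # 'triangle': False, since 4 > 1 + 1
```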

**Compose the subsets of similar and dissimilar neighbours (everything below is also done inside the Algorithm 1 methods)**

In [None]:
# Find neighbours
S_similar = gml.similar_neighbours(3,w,G_all)
S_dissimilar = gml.dissimilar_neighbours(3,w,G_all)

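### More debugging code, written before the testing framework existed
These cells reference names that are not defined in this notebook (e.g. `S_10`, `G_10`, `w_10`, `T_10`, `grad_objective_plus`, `M_original`), so they will not run as-is.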

In [None]:
S_10_sum =  gml.sum_neighbours(S_10)

In [None]:
val = np.argsort(G_10[3,:],kind='mergesort')
val1  = np.argsort(G_10[5,:],kind='mergesort')

In [None]:
S_10 = gml.similar_neighbours(3,w_10,G_10)
grad_10 = gml.gradient_neighbours(w_10,T_10,S_10)
grad_10[1][0][0].shape
# --- In the method -- 
# helper.triu_vector(gamma_plus) + helper.tril_vector(gamma_plus)
plt.imshow(helper.triu(grad_objective_plus))

In [None]:
a1 = helper.triu(grad_objective_plus)
a = helper.tril(grad_objective_plus)

b = helper.tril(grad_objective_minus)
b1 = helper.triu(grad_objective_minus)

In [None]:
# Same argument order as the calls inside projected_subgradient above
z_plus, gamma_plus = gml.algorithm1_similar(M_eye, w, G_all, T_all, 3)
z_minus, gamma_minus = gml.algorithm1_dissimilar(M_eye, w, G_all, T_all, 3)

In [None]:
val = wasserstein.matrix_from_vector(gamma_plus,64)
diff_zero = wasserstein.tril_vector(M_original_eye) - wasserstein.tril_vector(M_original)

res1 = gamma_plus.T.dot(diff_zero)
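# The index into gamma is missing (gamma[].T is a syntax error);
# fill it in before uncommenting these two lines
# result = np.dot(gamma[].T, diff_zero)
# np.trace(result)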

In [None]:
res = np.trace(gamma_plus.dot(diff_zero))

upper = grad_objective_plus[np.triu_indices(64,1)]
upper.shape
lower = np.tril(grad_objective_plus,-1)
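A caution about combining upper- and lower-triangle vectors, as in the `helper.triu_vector(gamma_plus) + helper.tril_vector(gamma_plus)` comment earlier. If those helpers index with `np.triu_indices` / `np.tril_indices` (an assumption; we have not checked `helper`), adding their outputs does not pair entry $(i, j)$ with $(j, i)$, because both index sets enumerate row by row. A small NumPy check, with a correct fold for comparison:

```python
import numpy as np

d_toy = 4
A = np.arange(d_toy * d_toy, dtype=float).reshape(d_toy, d_toy)   # deliberately non-symmetric
idx_lower = np.tril_indices(d_toy, k=-1)   # (1,0), (2,0), (2,1), (3,0), (3,1), (3,2)
idx_upper = np.triu_indices(d_toy, k=1)    # (0,1), (0,2), (0,3), (1,2), (1,3), (2,3)

# Element-wise sum of the two vectors pairs entries from unrelated positions:
naive = A[idx_upper] + A[idx_lower]
# Folding A[i, j] + A[j, i] for i > j needs the transpose at the same indices:
folded = A[idx_lower] + A.T[idx_lower]
print(naive.tolist())    # [5.0, 10.0, 12.0, 18.0, 20.0, 25.0]
print(folded.tolist())   # [5.0, 10.0, 15.0, 15.0, 20.0, 25.0]
```

The two results agree only where the row-major orders happen to line up, so the mismatch can go unnoticed on small examples.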