In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from scipy import linalg as la
import scipy.cluster.hierarchy as hier

from sklearn.manifold import TSNE
from collections import Counter
import plotly
import plotly.plotly as py
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
init_notebook_mode(connected=True)

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


In [None]:
# Softmax-regression baseline on MNIST, built with the TensorFlow graph/session API.
sess = tf.InteractiveSession()
# Inputs: flattened images (784 values each) and one-hot labels over 10 classes.
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
# Single linear layer; weights and bias both start at zero.
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
sess.run(tf.global_variables_initializer())
# `y` holds unnormalized logits; the softmax is applied inside the loss op.
y = tf.matmul(x,W) + b
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# 1000 gradient-descent steps on mini-batches of 100 training images.
for _ in range(1000):
  batch = mnist.train.next_batch(100)
  train_step.run(feed_dict={x: batch[0], y_: batch[1]})
# NOTE: `batch` still holds the last mini-batch (images, labels) after the loop;
# the plotting and clustering cells further down read it directly.
# Accuracy = fraction of test images whose arg-max logit matches the arg-max label.
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

In [None]:
# Draw the first nrows*ncols images of the last training batch on a grid.
imshape = (28, 28)
ncols = 12
nrows = 1
fig = plt.figure(figsize=(20, 20))
for j, i in np.ndindex(nrows, ncols):
    # Row-major position of this cell in the grid; also the image index in the batch.
    imid = j * ncols + i
    imdata = batch[0][imid, :].reshape(imshape)
    ax = fig.add_subplot(nrows, ncols, imid + 1)
    ax.imshow(imdata, cmap='gray')

print('Batch samples:')
plt.show()

# Procrustes on an image treated as a set of row vectors

In [None]:
# Align one digit image onto another with an orthogonal Procrustes rotation,
# treating each 28x28 image as a matrix that is multiplied by R on the right.
imid = 2
imid2 = 5
imshape = (28, 28)
imdata = batch[0][imid, :].reshape(imshape)
imlabel = batch[1][imid, :].argmax()
rand_dest = batch[0][imid2, :].reshape(imshape)

# R is the (rotation, scale) pair returned by scipy; R[0] is the orthogonal matrix.
R = la.orthogonal_procrustes(imdata, rand_dest)
normR = la.norm(R[0], 'fro')

panels = [
    ('input', imdata),
    ('target', rand_dest),
    ('procrustes rotation of input to target', np.matmul(imdata, R[0])),
]

fig = plt.figure(figsize=(9, 9))
for col, (title, image) in enumerate(panels, start=1):
    ax = plt.subplot(1, 3, col)
    ax.set_title(title)
    plt.imshow(image, cmap='gray')

plt.show()

# Procrustes on vectorized image

In [None]:
def vec_procrustes(v1, v2):
    """Orthogonal Procrustes alignment of one flattened array onto another.

    Each input is flattened to a single row vector and padded with a row of
    zeros, because ``scipy.linalg.orthogonal_procrustes`` requires 2-D input.
    The inputs may have any size as long as both have the same number of
    elements (the padding is no longer hard-coded to 784).

    Parameters
    ----------
    v1 : array_like
        Array to be rotated; flattened before use.
    v2 : array_like
        Target array with the same number of elements as ``v1``.

    Returns
    -------
    tuple
        ``(R, scale)`` as returned by ``scipy.linalg.orthogonal_procrustes``;
        ``R`` is a square orthogonal matrix with one row/column per element
        of the flattened input.

    Raises
    ------
    ValueError
        If ``v1`` and ``v2`` do not have the same number of elements.
    """
    row1 = np.asarray(v1).ravel()
    row2 = np.asarray(v2).ravel()
    if row1.size != row2.size:
        raise ValueError(
            "v1 and v2 must have the same number of elements, got %d and %d"
            % (row1.size, row2.size)
        )
    # float64 zeros keep the stacked matrix in double precision for float32
    # images, which is what stacking with a Python int list produced before.
    padding = np.zeros(row1.size)
    w1 = np.vstack((row1, padding))
    w2 = np.vstack((row2, padding))
    return la.orthogonal_procrustes(w1, w2)

In [None]:
# Repeat the alignment on flattened images: the whole image is one vector and
# R is the square orthogonal matrix returned by vec_procrustes.
imid = 0
imshape = (28, 28)
imdata = batch[0][imid, :].reshape(imshape)
imlabel = batch[1][imid, :].argmax()
rand_dest = batch[0][1, :].reshape(imshape)

R = vec_procrustes(imdata.flatten(), rand_dest.flatten())
normR = la.norm(R[0], 'fro')

flat_input = imdata.flatten()
full_rotation = np.reshape(np.matmul(flat_input, R[0]), imshape)
# NOTE(review): this divides R by a constant, so the result is the full
# rotation scaled down rather than an intermediate rotation - confirm intent.
partial_rotation = np.reshape(np.matmul(flat_input, R[0] / (10000 * normR)), imshape)

panels = [
    ('input', imdata),
    ('target', rand_dest),
    ('full rotation', full_rotation),
    ('partial rotation', partial_rotation),
]

fig = plt.figure(figsize=(9, 9))
for col, (title, image) in enumerate(panels, start=1):
    ax = plt.subplot(1, 4, col)
    ax.set_title(title)
    plt.imshow(image, cmap='gray')

plt.show()

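As we can see, this result is too ideal: a 784×784 orthogonal matrix has enough freedom to rotate any vectorized image exactly onto the direction of any other, so the fit is perfect (up to scale) no matter how different the two digits are.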

# MNIST procrustes clusters

In [None]:
# Pairwise Procrustes distance between all images of the batch:
#   D[i, j] = || im_i . R_ij - im_j ||_F
# where R_ij is the orthogonal matrix that best aligns image i onto image j.
# The optimal residual is symmetric in (i, j) and zero for i == j, so only the
# upper triangle is computed and then mirrored; this halves the SVD work and
# makes D exactly symmetric with an exactly zero diagonal.
images = np.reshape(batch[0], (-1,) + tuple(imshape))  # reshape once, not per pair
n_images = len(images)
D = np.zeros([n_images, n_images])
for i in range(n_images):
    im1 = images[i]
    for j in range(i + 1, n_images):
        im2 = images[j]
        R = la.orthogonal_procrustes(im1, im2)  # (rotation, scale) pair
        im1fit = np.matmul(im1, R[0])
        D[i, j] = D[j, i] = la.norm(im1fit - im2, 'fro')
print('Done computing pairwise procrustes distance')

In [None]:
# Average-linkage hierarchical clustering of the batch from the Procrustes
# distances in D.
# linkage() interprets a 2-D array as raw observation vectors, not as a
# distance matrix, so D is converted to the condensed form it expects: the
# upper triangle (excluding the diagonal) in row-major order.  D is
# symmetrized first so round-off differences between D[i, j] and D[j, i]
# cannot leak into the result.
D_sym = 0.5 * (D + D.T)
D_condensed = D_sym[np.triu_indices(D_sym.shape[0], k=1)]
Z = hier.linkage(D_condensed, "average")
DG = hier.dendrogram(Z, orientation='right')

# Reorder rows and columns of D by dendrogram leaf order so that clusters
# show up as blocks along the diagonal.
index = DG['leaves']
sD = D[np.ix_(index, index)]
fig = plt.figure()
plt.imshow(sD)
plt.axis('off')
plt.show()

In [None]:
# Cut the dendrogram into at most `n_clusts` flat clusters and show a few
# member images per cluster, one cluster per row.
n_clusts = 10  # single source for the cluster count used below
clusts = hier.fcluster(Z, n_clusts, criterion="maxclust")
# fcluster labels are 1-based; sizes[k] is the size of cluster k + 1.
sizes = [int(np.sum(clusts == cid)) for cid in range(1, n_clusts + 1)]

samples_per_clust = 4
# Only clusters with more than `samples_per_clust` members are shown: the
# first member of each cluster is skipped below, so one extra is needed.
tclusts = [cid + 1 for cid, csize in enumerate(sizes) if csize > samples_per_clust]
fig = plt.figure()
for tcid, cid in enumerate(tclusts):
    samples = np.flatnonzero(clusts == cid)
    for sampleid in range(samples_per_clust):
        imid = samples[sampleid + 1]
        fig.add_subplot(len(tclusts), samples_per_clust, tcid * samples_per_clust + sampleid + 1)
        plt.imshow(np.reshape(batch[0][imid, :], imshape))
        plt.axis('off')
plt.show()

In [None]:
# Embed the batch in 2-D with t-SNE, feeding D as a precomputed distance
# matrix, then draw one plotly scatter trace per digit label.
init_notebook_mode(connected=False)
model = TSNE(
    n_components=2,
    perplexity=8,
    n_iter=10000,
    metric='precomputed',
    random_state=0,
    verbose=1,
)
Y = model.fit_transform(D)
tY = Y.transpose()  # row 0: x coordinates, row 1: y coordinates

labels = batch[1].argmax(axis=1)
label_counter = Counter(labels)
traces = []
for k in label_counter:
    in_class = labels == k
    marker = {'size': 10, 'color': k, 'colorscale': 'Viridis'}
    trace = go.Scatter(
        x=tY[0][in_class],
        y=tY[1][in_class],
        name=str(k),
        text=str(k),
        marker=marker,
        mode='markers',
    )
    traces.append(trace)

layout = go.Layout(hovermode='closest')
fig = go.Figure(data=traces, layout=layout)
plotly.offline.iplot(fig, filename='jupyter/tsne')


In [127]:
# Display the raw per-label scatter traces built for the t-SNE plot.
traces

[{'marker': {'color': 0, 'colorscale': 'Viridis', 'size': 10},
  'mode': 'markers',
  'name': '0',
  'text': '0',
  'type': 'scatter',
  'x': array([ 110.94037026,  105.93727534,  -34.27380897,    2.77196804,
           51.81325675,  100.71843206]),
  'y': array([ 29.04610036,   9.16433949, -11.94727277, -27.13717048,
          52.73696481,   2.03713836])},
 {'marker': {'color': 1, 'colorscale': 'Viridis', 'size': 10},
  'mode': 'markers',
  'name': '1',
  'text': '1',
  'type': 'scatter',
  'x': array([-75.90864145, -29.95021533,  -2.85744427,   6.95619   ,
         -48.35807533, -82.81192246, -47.15320259,  -0.34858212]),
  'y': array([ -55.31248079,  104.49703375,   28.11989417,  -72.92693666,
          -84.46477695,  -45.44396999,  101.93288579, -104.46254296])},
 {'marker': {'color': 2, 'colorscale': 'Viridis', 'size': 10},
  'mode': 'markers',
  'name': '2',
  'text': '2',
  'type': 'scatter',
  'x': array([ 57.7864092 , -36.18104249,  -6.12286558,  43.27978502,
         -11.8817