In [4]:
import numpy as np
from sklearn.cluster import KMeans
from sklearn import metrics
from sklearn.datasets.samples_generator import make_blobs
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import time
import sys
from IPython.display import display, clear_output
from sklearn.feature_extraction.image import PatchExtractor
from sklearn.datasets import load_sample_image
from sklearn.decomposition import PCA
from sklearn.decomposition import dict_learning_online
from sklearn.decomposition import MiniBatchDictionaryLearning

#Set plots to inline
%matplotlib inline

#Function for plotting arrays of images
def plot_gallery(images, h, w, n_row=3, n_col=4):
    """Plot a grid of grayscale image patches in a new figure.

    Parameters
    ----------
    images : sequence of numpy arrays
        Each entry is either a flattened patch (1-D, reshaped to ``(h, w)``
        before display) or an already 2-D image (displayed as-is).
    h, w : int
        Height and width used to reshape flattened (1-D) entries.
    n_row, n_col : int
        Grid dimensions. At most ``n_row * n_col`` images are drawn; if fewer
        images are supplied, the remaining grid cells are left empty instead
        of raising ``IndexError``.
    """
    plt.figure(figsize=(0.9 * n_col, 1.2 * n_row))
    plt.subplots_adjust(bottom=0, left=.01, right=.99, top=.90, hspace=.35)
    n_show = min(n_row * n_col, len(images))
    for i in range(n_show):
        plt.subplot(n_row, n_col, i + 1)
        image = images[i]
        if image.ndim == 1:
            # Flattened patch: reshape, and pin the colour range to [-2, 2]
            # (suits the standardized patches this notebook passes in).
            plt.imshow(image.reshape((h, w)), cmap=plt.cm.gray,
                       interpolation='nearest')
            plt.clim(-2, 2)
        else:
            # Already 2-D: let imshow autoscale the colour range.
            plt.imshow(image, cmap=plt.cm.gray, interpolation='nearest')
        plt.xticks(())
        plt.yticks(())



Dimensionality Reduction
===

Prepare the Image Patch Data
---

In [5]:
# Load the sample image and convert it to grayscale with a maximum of 1.
# NOTE: load_sample_image needs Pillow (PIL) to decode JPEG files; without it
# this cell raises ImportError.
img  = load_sample_image('flower.jpg')
gray = np.mean(img, axis=2)
gray = gray / gray.max()

# Optional additive Gaussian noise (0.0 disables it). Seeded for reproducibility.
NOISE_STD = 0.0
if NOISE_STD > 0:
    gray = gray + NOISE_STD * np.random.RandomState(0).randn(*gray.shape)

# Extract P random patches of size SxS. A fixed random_state makes the patch
# sample (and therefore everything downstream) reproducible across restarts.
S = 10
P = 8000
pex     = PatchExtractor(patch_size=(S, S), max_patches=P, random_state=0)
patches = pex.transform(gray[np.newaxis, :, :])
X       = patches.reshape(len(patches), S * S)

# Filter out patches that are too uniform (std of pixel values <= 0.1).
X = X[np.std(X, axis=1) > 0.1, :]

# Split: first half is held out for testing, second half is used for training.
# Integer division keeps the slice bounds ints on both Python 2 and Python 3.
n_half = len(X) // 2
Xtest  = X[:n_half]
X      = X[n_half:]

# Standardize each pixel using training-set statistics only. Zero-variance
# pixels are given unit scale so the division cannot produce inf/NaN.
m = np.mean(X, 0)
s = np.std(X, 0)
s[s == 0] = 1.0
X     = (X - m) / s
Xtest = (Xtest - m) / s

# Show the original image.
plt.figure(1)
plt.imshow(gray, cmap='gray')

# Show some held-out patches (plot_gallery opens its own figure).
plot_gallery(Xtest[0:20], S, S, 4, 5)

h, w = S, S
print(X.shape)
print(Xtest.shape)

ImportError: The Python Imaging Library (PIL) is required to load data from jpeg files

Run PCA and Show Learned Basis
---

In [None]:
# Fit a K-component PCA basis on the standardized training patches.
K = 25
pca        = PCA(n_components=K, whiten=False).fit(X)
components = pca.components_.reshape((K, h, w))

# Project the held-out patches and reconstruct them. inverse_transform adds
# back pca.mean_, which a bare Z.dot(pca.components_) would leave out.
Z    = pca.transform(Xtest)
Xhat = pca.inverse_transform(Z)

# Mean squared error over all pixels of all test patches.
err = np.mean((Xhat - Xtest) ** 2)
print("Mean Reconstruction Error Per Pixel: %.4f" % (err,))

Xhat   = Xhat.reshape((Xhat.shape[0], h, w))
Xtest2 = Xtest.reshape((Xtest.shape[0], h, w))

# Plot basis
plot_gallery(components, h, w, 5, 5)
print("PCA Basis")
plt.show()

# Plot representation in a fresh figure. Reusing figure number 2 would draw
# onto any still-open figure 2 and silently ignore figsize.
plt.figure(figsize=(4, 3))
plt.imshow(Z[:25, :], interpolation="nearest", cmap="bwr")
plt.colorbar()
plt.clim(-5, 5)
plt.title("PCA Representation")
plt.xlabel("PCA Coefficient")
plt.ylabel("Data Instance")

# Plot reconstructions: np.hstack on (n, h, w) arrays joins along axis 1, so
# each tile shows the true patch above its PCA approximation.
plot_gallery(np.hstack((Xtest2[:25], Xhat[:25])), h, w, 5, 5)
plt.show()
print("True and Approximated Patches")





Run Sparse Coding and Show the Basis
---

In [None]:
# Learn a K-atom sparse-coding dictionary on the standardized training patches.
K = 50
falpha = 1  # value of alpha used to fit the model
talpha = 1  # value of alpha used to transform data cases

# random_state pins the mini-batch order/initialisation for reproducibility.
# NOTE(review): newer scikit-learn releases deprecate n_iter in favour of
# max_iter; verify against the installed version.
sc = MiniBatchDictionaryLearning(n_components=K, alpha=falpha,
                                 transform_alpha=talpha, batch_size=100,
                                 fit_algorithm="cd",
                                 transform_algorithm="lasso_cd",
                                 verbose=True, n_iter=100,
                                 random_state=0).fit(X)
components = sc.components_.reshape((K, h, w))

# Sparse codes for the held-out patches and their linear reconstruction.
Z    = sc.transform(Xtest)
Xhat = Z.dot(sc.components_)

# Mean squared error over all pixels of all test patches.
err = np.mean((Xhat - Xtest) ** 2)
print("\nMean Reconstruction Error Per Pixel: %.4f" % (err,))

Xhat   = Xhat.reshape((Xhat.shape[0], h, w))
Xtest2 = Xtest.reshape((Xtest.shape[0], h, w))

# Plot basis: all K atoms in 5 rows (a fixed 5x5 grid would hide half of them).
plot_gallery(components, h, w, 5, K // 5)
print("Sparse Coding Basis")
plt.show()

# Plot representation in a fresh figure. Reusing figure number 2 would draw
# onto any still-open figure 2 and silently ignore figsize.
plt.figure(figsize=(4, 3))
plt.imshow(Z[:25, :], interpolation="nearest", cmap="bwr")
plt.colorbar()
plt.clim(-6, 6)
plt.title("SC Representation")
plt.xlabel("SC Coefficient")
plt.ylabel("Data Instance")

# Plot reconstructions: np.hstack on (n, h, w) arrays joins along axis 1, so
# each tile shows the true patch above its sparse-code approximation.
plot_gallery(np.hstack((Xtest2[:25], Xhat[:25])), h, w, 5, 5)
plt.show()
print("True and Approximated Patches")

