Skip to content

Commit

Permalink
cosmit
Browse files Browse the repository at this point in the history
  • Loading branch information
ogrisel committed Sep 16, 2011
1 parent 9a0db43 commit 6ec4e1b
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions examples/decomposition/plot_img_denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@
import numpy as np

from sklearn.decomposition import DictionaryLearningOnline
from sklearn.feature_extraction.image import extract_patches_2d, \
reconstruct_from_patches_2d
from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.feature_extraction.image import reconstruct_from_patches_2d

###############################################################################
# Load Lena image and extract patches

lena = sp.lena() / 256.0

# downsample for higher speed
Expand All @@ -56,20 +57,23 @@

# Extract all clean patches from the left half of the image
print 'Extracting clean patches...'
t0 = time()
patch_size = (7, 7)
data = extract_patches_2d(distorted[:, :height / 2], patch_size)
data = data.reshape(data.shape[0], -1)
data -= np.mean(data, axis=0)
data /= np.std(data, axis=0)
print 'done in %.2fs.' % (time() - t0)

###############################################################################
# Learn the dictionary from clean patches
print 'Learning the dictionary... ',

print 'Learning the dictionary... '
t0 = time()
dico = DictionaryLearningOnline(n_atoms=100, alpha=1e-2, n_iter=500)
V = dico.fit(data).components_
dt = time() - t0
print 'done in %.2f.' % dt
print 'done in %.2fs.' % dt

pl.figure(figsize=(4.2, 4))
for i, comp in enumerate(V[:100]):
Expand All @@ -83,6 +87,8 @@
fontsize=16)
pl.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)

###############################################################################
# Display the distorted image

def show_with_diff(image, reference, title):
"""Helper function to display denoising"""
Expand All @@ -103,29 +109,33 @@ def show_with_diff(image, reference, title):
pl.suptitle(title, size=16)
pl.subplots_adjust(0.02, 0.02, 0.98, 0.79, 0.02, 0.2)

###############################################################################
# Display the distorted image
show_with_diff(distorted, lena, 'Distorted image')

###############################################################################
# Extract noisy patches and reconstruct them using the dictionary

print 'Extracting noisy patches... '
t0 = time()
data = extract_patches_2d(distorted[:, height / 2:], patch_size)
data = data.reshape(data.shape[0], -1)
intercept = np.mean(data, axis=0)
data -= intercept
print 'done in %.2fs.' % (time() - t0)

transform_algorithms = [
('Orthogonal Matching Pursuit\n1 atom', 'omp',
{'transform_n_nonzero_coefs': 1}),
('Orthogonal Matching Pursuit\n2 atoms', 'omp',
{'transform_n_nonzero_coefs': 2}),
('Least-angle regression\n5 atoms', 'lars', {'transform_n_nonzero_coefs': 5}),
('Thresholding\n alpha=0.1', 'threshold', {'transform_alpha': .1})]
('Orthogonal Matching Pursuit\n1 atom',
'omp', {'transform_n_nonzero_coefs': 1}),
('Orthogonal Matching Pursuit\n2 atoms',
'omp', {'transform_n_nonzero_coefs': 2}),
('Least-angle regression\n5 atoms',
'lars', {'transform_n_nonzero_coefs': 5}),
('Thresholding\n alpha=0.1', 'threshold',
{'transform_alpha': .1}),
]

reconstructions = {}
for title, transform_algorithm, kwargs in transform_algorithms:
print title, '... ',
print title, '... '
reconstructions[title] = lena.copy()
t0 = time()
dico.set_params(transform_algorithm=transform_algorithm, **kwargs)
Expand All @@ -144,7 +154,7 @@ def show_with_diff(image, reference, title):
reconstructions[title][:, height / 2:] = reconstruct_from_patches_2d(
patches, (width, height / 2))
dt = time() - t0
print 'done in %.2f.' % dt
print 'done in %.2fs.' % dt
show_with_diff(reconstructions[title], lena,
title + ' (time: %.1fs)' % dt)

Expand Down

0 comments on commit 6ec4e1b

Please sign in to comment.