# Sparse Coding

Author(s): Raj Magesh Gauthaman (rgautha1@jh.edu)

If you find any bugs in this notebook, please email Raj (rgautha1@jh.edu)!

---

Our goal is to obtain a dictionary of basis vectors that form a sparse encoding of natural images. Note that we're using the term "basis" loosely: the [definition of a basis](https://en.wikipedia.org/wiki/Basis_(linear_algebra)) requires its vectors to be linearly independent and span the entire space; here, we impose no such constraints.

We randomly sample patches $\{P_n\}_{n \in \mathbb{Z}}$ from the [ten $512 \times 512$ natural images](http://www.rctn.org/bruno/sparsenet/) used in [Olshausen & Field (1996)](https://www.nature.com/articles/381607a0). Note that the images are standardized, whitened, and low-pass filtered (similar to the procedure described by Olshausen & Field) and that each patch is independently standardized. Each patch has size $p \times p$, making each basis vector $p^2$-dimensional. If $m > p^2$, the basis is *overcomplete*. Our $m$ $p^2$-dimensional basis vectors form a dictionary $X \in \mathbb{R}^{p^2 \times m}$. Our goal is to represent each flattened image patch $P \in \mathbb{R}^{p^2}$ as a linear combination of these basis vectors (i.e., $P \approx XZ$) with sparse coefficients $Z \in \mathbb{R}^{m}$. These considerations inform the loss function we'd like to minimize:

$$L(X, Z) = \frac{1}{2} \lVert P - XZ \rVert^2_2 + \lambda \lVert Z \rVert_1$$

Here, the first term is half the squared $L_2$-norm of the reconstruction error incurred when the original image patch $P$ is reconstructed as $\hat{P} = XZ$. The second term is the $L_1$-norm of the coefficients $Z$, weighted by the regularization parameter $\lambda$, which controls the tradeoff between sparsity and reconstruction error. For a fixed dictionary $X$, minimizing $L$ with respect to $Z$ produces a sparse representation $Z$; we implement this optimization using the coordinate descent algorithm described in [Gregor & LeCun (2010)](https://icml.cc/Conferences/2010/papers/449.pdf) and introduced by [Li & Osher (2009)](https://doi.org/10.3934/ipi.2009.3.487). However, our goal is to learn the dictionary $X$, which involves minimizing $L$ with respect to $X$. We can derive this gradient analytically:

$$\nabla_X L(X, Z) = -(P - XZ)Z^T$$
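
As a sanity check on the expression above, the following minimal sketch compares the analytic gradient with a central finite-difference estimate on random data. It uses NumPy for brevity; the sizes `p` and `m`, the value of `lam`, and the helper name `loss` are illustrative assumptions and are not part of the notebook's utilities.

```python
import numpy as np

rng = np.random.default_rng(0)

p, m = 4, 20   # illustrative patch side length and dictionary size (m > p**2, so overcomplete)
lam = 0.5      # illustrative regularization parameter

P = rng.standard_normal(p**2)        # flattened patch
X = rng.standard_normal((p**2, m))   # dictionary, one basis vector per column
Z = rng.standard_normal(m)           # coefficients (dense here; only the gradient check matters)


def loss(X, Z, P, lam):
    """L(X, Z) = 1/2 ||P - XZ||_2^2 + lam * ||Z||_1."""
    residual = P - X @ Z
    return 0.5 * residual @ residual + lam * np.abs(Z).sum()


# Analytic gradient from the text: -(P - XZ) Z^T
analytic_grad = -np.outer(P - X @ Z, Z)

# Central finite differences, perturbing one dictionary entry at a time
eps = 1e-6
numeric_grad = np.zeros_like(X)
for i in range(X.shape[0]):
    for j in range(X.shape[1]):
        step = np.zeros_like(X)
        step[i, j] = eps
        numeric_grad[i, j] = (
            loss(X + step, Z, P, lam) - loss(X - step, Z, P, lam)
        ) / (2 * eps)

max_abs_diff = np.abs(analytic_grad - numeric_grad).max()
print(f"max |analytic - numeric| = {max_abs_diff:.2e}")
```

The $L_1$ term does not depend on $X$, so it contributes nothing to the gradient; the two estimates should agree up to floating-point rounding.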

We optimize $X$ iteratively using stochastic gradient descent:

1. We initialize $X_0$ with values drawn from a uniform distribution on $[-0.5, 0.5)$.
2. At each iteration $i$, we
   1. sample a batch of image patches $P_j$, $j \in [b]$, where $b$ is the batch size,
   2. use coordinate descent to find the optimal sparse representation $Z_j$ for each image patch $P_j$ in terms of the basis vectors $X_i$,
   3. compute the gradient of the loss for each patch as $\nabla_X L_j(X_i, Z_j) = -(P_j - X_iZ_j)Z_j^T$,
   4. update $X$ as $X_{i+1} = X_i - \dfrac{\eta}{b} \sum_j \nabla_X L_j(X_i, Z_j) = X_i + \dfrac{\eta}{b} \sum_j (P_j - X_iZ_j)Z_j^T$, where $\eta$ is the learning rate, and
   5. normalize the basis vectors (columns) in $X_{i+1}$ to unit length.
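
The steps above can be sketched on toy data as follows. This is a minimal NumPy illustration under stated assumptions, not the notebook's implementation: ISTA (iterative soft-thresholding) is swapped in for coordinate descent in step 2.2 purely to keep the example short, random Gaussian vectors stand in for image patches, and all function names (`soft_threshold`, `ista`, `normalize_columns`, `sgd_step`) are hypothetical.

```python
import numpy as np


def soft_threshold(v, threshold):
    """Shrink every entry of v toward zero by `threshold`."""
    return np.sign(v) * np.maximum(np.abs(v) - threshold, 0.0)


def ista(X, P, lam, n_steps=200):
    """ISTA stand-in for step 2.2 (the notebook uses coordinate descent instead)."""
    step = 1.0 / np.linalg.norm(X, ord=2) ** 2   # 1 / Lipschitz constant of the quadratic term
    Z = np.zeros((X.shape[1], P.shape[1]))
    for _ in range(n_steps):
        # gradient step on the reconstruction term, then shrinkage for the L1 term
        Z = soft_threshold(Z + step * X.T @ (P - X @ Z), step * lam)
    return Z


def normalize_columns(X):
    """Scale every basis vector (column) to unit length."""
    return X / np.linalg.norm(X, axis=0, keepdims=True)


def sgd_step(X, P, lam, eta):
    """One pass through steps 2.2 to 2.5 for a batch P of shape (n_pixels, b)."""
    Z = ista(X, P, lam)                          # step 2.2: sparse codes for the batch
    b = P.shape[1]
    X_new = X + (eta / b) * (P - X @ Z) @ Z.T    # steps 2.3 and 2.4: averaged gradient step
    return normalize_columns(X_new), Z           # step 2.5


rng = np.random.default_rng(0)
n_pixels, m, b = 16, 32, 50                      # illustrative sizes

# step 1: uniform initialization on [-0.5, 0.5), normalized before first use
X = normalize_columns(rng.uniform(-0.5, 0.5, size=(n_pixels, m)))

for i in range(20):
    P = rng.standard_normal((n_pixels, b))       # step 2.1: random stand-in for a batch of patches
    P = (P - P.mean(axis=0)) / P.std(axis=0)     # standardize each patch independently
    X, Z = sgd_step(X, P, lam=0.5, eta=0.1)

print(X.shape, Z.shape)
```

Summing the outer products $(P_j - X_iZ_j)Z_j^T$ over the batch is the same as the single matrix product `(P - X @ Z) @ Z.T`, which avoids materializing one $p^2 \times m$ matrix per patch.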

Note: Uncomment and run the following cell if you are using Google Colab.

In [None]:
# !pip install torchdata

## Utilities

In [4]:
from __future__ import annotations
from collections.abc import Collection, Iterable
from pathlib import Path
import uuid
import math
import random

from tqdm import tqdm
import requests
import numpy as np
import torch
from torchvision import transforms as tr
from torchdata.datapipes.iter import IterDataPipe, Mapper, IterableWrapper, Shuffler
from scipy.io import loadmat
from scipy.sparse.linalg import cg
from matplotlib import pyplot as plt
from PIL import Image


def download(
    url: str,
    *,
    filepath: Path | None = None,
    stream: bool = True,
    allow_redirects: bool = True,
    chunk_size: int = 1024**2,
    force: bool = True,
) -> Path:
    """Download a file from a URL."""

    if filepath is None:
        filepath = Path("/tmp") / f"{uuid.uuid4()}"
    elif filepath.exists():
        if not force:
            return filepath
        else:
            filepath.unlink()

    r = requests.Session().get(url, stream=stream, allow_redirects=allow_redirects)
    r.raise_for_status()  # fail loudly instead of saving an HTTP error page to disk
    with open(filepath, "wb") as f:
        for chunk in r.iter_content(chunk_size=chunk_size):
            f.write(chunk)

    return filepath


def download_images(
    url: str = "http://www.rctn.org/bruno/sparsenet/IMAGES_RAW.mat",
    directory: Path = Path("data/sparse_coding/images"),
) -> list[Path]:
    """Download the images used in Olshausen & Field (1996)."""

    directory.mkdir(parents=True, exist_ok=True)
    filepath = download(url, filepath=directory / "images.mat", force=False)
    images = loadmat(filepath, simplify_cells=True)["IMAGESr"]

    paths = []

    for i_image in range(10):
        image = images[..., i_image]
        min_ = image.min()
        max_ = image.max()
        image = (image - min_) / (max_ - min_) * 255
        image = Image.fromarray(image).convert("L")

        path = directory / f"image_{i_image:02}.png"
        paths.append(path)
        image.save(path, format="PNG")

    return paths


def load_images(paths: Collection[Path]) -> list[Image.Image]:
    """Load images from disk."""
    return [Image.open(path) for path in paths]


def display_images(images: Collection[Image.Image | torch.Tensor]) -> None:
    """Display images"""
    n = len(images)
    nrows = int(math.sqrt(n))
    ncols = math.ceil(n / nrows)

    # squeeze=False keeps `axes` a 2D array even when there is only one image
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    for image, ax in zip(images, axes.flat):
        if isinstance(image, torch.Tensor):
            image = image.cpu().numpy()
        ax.imshow(image, cmap="gray")
    for ax in axes.flat:
        ax.axis("off")
    fig.show()


def coordinate_descent(
    *,
    w: torch.Tensor,
    x: torch.Tensor,
    alpha: float,
    max_iterations: int = 1000,
    tolerance: float = 1e-4,
) -> torch.Tensor:
    """Coordinate descent algorithm described in Gregor & LeCun (2010).

    Minimizes L(z) = 1/2 ||x - wz||_2^2 + alpha * ||z||_1 with respect to z. Supports batches of data x.

    References:

    * Gregor & LeCun (2010), https://icml.cc/Conferences/2010/papers/449.pdf, Algorithm 2, Coordinate Descent [Pg 3]
    * Li & Osher (2009), https://doi.org/10.3934/ipi.2009.3.487, Algorithm (Coordinate descent with a refined sweep) [Pg 6]

    Args:
        w: (n, m) dictionary of m n-dimensional basis vectors
        x: n-dimensional data or (n, batch_size) batch of data
        alpha: Regularization strength, controls sparsity of the solution z
        max_iterations: Maximum number of iterations. Defaults to 1000.
        tolerance: Tolerance for early stopping. Defaults to 1e-4.

    Returns:
        z: (m, batch_size) coefficients of the data x under the basis w
    """
    device = w.device
    assert w.ndim == 2, "w must be 2-dimensional"
    x = x.unsqueeze(-1) if x.ndim == 1 else x

    s = torch.eye(w.shape[-1], device=device) - w.t() @ w
    b = w.t() @ x
    z = torch.zeros_like(b, device=device)

    slicer = torch.arange(x.shape[-1], device=device)

    for iteration in range(max_iterations):
        z_shrunk = torch.nn.functional.softshrink(b, lambd=alpha)
        diff = z_shrunk - z
        diff_abs = diff.abs()
        k = diff_abs.argmax(dim=0)
        b += s[..., k] * diff[k, slicer]
        z[k, slicer] = z_shrunk[k, slicer]

        if diff_abs[k, slicer].max() < tolerance:
            break

    z = torch.nn.functional.softshrink(b, lambd=alpha)
    return z


def create_sinusoid(
    *,
    size: tuple[int, int] = (32, 32),
    amplitude: float = 1,
    frequency: float = 0.1,
    angle: float = 0,
    phase: float = 0,
) -> torch.Tensor:
    """Creates a 2D sinusoid.

    Args:
        size: Size of the sinusoid, defaults to (32, 32)
        amplitude: Amplitude of the sinusoid, defaults to 1
        frequency: Frequency of the sinusoid, defaults to 0.1
        angle: Angle of the sinusoid, defaults to 0 (horizontal)
        phase: Phase of the sinusoid, defaults to 0 (cosine)

    Returns:
        2D sinusoid
    """
    angle = torch.tensor(angle)
    w_x, w_y = frequency * torch.cos(angle), frequency * torch.sin(angle)
    width, height = size
    # build a grid of exactly `size` points, centered on the origin
    [x, y] = torch.meshgrid(
        torch.arange(width) - width // 2,
        torch.arange(height) - height // 2,
        indexing="ij",
    )

    return amplitude * torch.cos(w_x * x + w_y * y + phase)


def create_sinusoids(
    *,
    frequency: float = 0.1,
    size: tuple[int, int] = (32, 32),
    n: int = 10,
) -> list[torch.Tensor]:
    """Sample sinusoids of a given frequency.

    The amplitudes, phases, and angles of the sinusoids are sampled from [0, 1), [0, 2 * pi), and [0, pi) respectively.

    Args:
        frequency: Frequency of the sinusoids. Defaults to 0.1.
        size: Size of the sinusoids. Defaults to (32, 32).
        n: Number of sinusoids to create. Defaults to 10.

    Returns:
        Sinusoids
    """
    sinusoids = []
    for _ in range(n):
        amplitude = random.random()
        phase = 2 * math.pi * random.random()
        angle = math.pi * random.random()
        sinusoids.append(
            create_sinusoid(
                size=size,
                amplitude=amplitude,
                frequency=frequency,
                angle=angle,
                phase=phase,
            )
        )
    return sinusoids


def extract_patches_from_images(
    images: Iterable[Image.Image | torch.Tensor],
    *,
    patch_size: int = 8,
    n_patches_per_image: int = 10,
    n_images_per_batch: int = 10,
) -> IterDataPipe:
    """Extract square patches from a stream of images.

    Images are standardized and preprocessed (whitened and low-pass filtered) according to Olshausen & Field (1996). Square patches are extracted from each image and individually standardized.

    References:

        * Olshausen & Field (1996), https://doi.org/10.1038/381607a0
        * Olshausen & Field (1997), https://doi.org/10.1016/S0042-6989(97)00169-7

    Args:
        images: Images to extract patches from
        patch_size: Size of each square patch. Defaults to 8.
        n_patches_per_image: Number of patches to extract from each image in the batch. Defaults to 10.
        n_images_per_batch: Number of images to use in each batch. Defaults to 10.

    Returns:
        Datapipe that returns batches of patches of shape (patch_size ** 2, batch_size)
    """

    def _preprocess_image(image: torch.Tensor) -> torch.Tensor:
        assert image.shape[-1] == image.shape[-2], "image must be square"
        size = image.shape[-1]
        f0 = 0.4 * size
        f1 = torch.fft.fftfreq(size) * size
        f2 = torch.fft.rfftfreq(size) * size
        f1, f2 = torch.meshgrid(f1, f2, indexing="ij")
        r = torch.sqrt(f1**2 + f2**2)
        filter_ = r * torch.exp(-((r / f0) ** 4))
        # pass the output size explicitly so odd-sized images keep their shape
        return torch.fft.irfft2(filter_ * torch.fft.rfft2(image), s=image.shape[-2:])

    def _collate_fn(
        images: Collection[Image.Image | torch.Tensor],
        n_patches_per_image: int = n_patches_per_image,
    ) -> torch.Tensor:
        to_tensor = tr.ToTensor()
        crop = tr.RandomCrop(size=patch_size)

        patches = []
        for image in images:
            if isinstance(image, Image.Image):
                image = to_tensor(image)
            # standardize each image to zero mean and unit variance
            image = (image - image.mean()) / image.std()
            image = _preprocess_image(image)
            patches.extend([crop(image).flatten() for _ in range(n_patches_per_image)])
        patches = torch.stack(patches, dim=-1)
        return (patches - patches.mean(dim=0, keepdim=True)) / patches.std(
            dim=0, keepdim=True
        )

    return (
        IterableWrapper(images)
        .cycle()
        .shuffle(buffer_size=n_images_per_batch)
        .batch(batch_size=n_images_per_batch)
        .collate(_collate_fn)
        .in_memory_cache()
    )


def learn_sparse_encoding(
    datapipe: IterDataPipe,
    *,
    regularization_strength: float = 0.5,
    dictionary: torch.Tensor | None = None,
    dictionary_size: int = 64,
    learning_rate: float = 0.1,
    n_batches: int = 100,
    device: torch.device | str | None = None,
    **kwargs,
) -> torch.Tensor:
    """Learn a dictionary of basis functions for sparse encoding for images.

    Args:
        datapipe: Datapipe that returns batches of images of shape (n_pixels, batch_size)
        regularization_strength: Parameter that controls the sparsity of the encoding. Defaults to 0.5.
        dictionary: (Optional) Dictionary to initialize the optimization with (warm start). Defaults to None.
        dictionary_size: Number of basis vectors to include in the dictionary. Ignored if dictionary is provided. Defaults to 64.
        learning_rate: Learning rate for stochastic gradient descent. Defaults to 0.1.
        n_batches: Number of batches to sample from the datapipe. Defaults to 100.
        device: PyTorch device (CPU/GPU) to run the optimization on
        **kwargs: Passed to coordinate_descent().

    Returns:
        (n_pixels, dictionary_size) dictionary containing basis vectors
    """
    if device is not None:
        device = torch.device(device)
    else:
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )

    for batch, patches in tqdm(
        zip(range(n_batches), datapipe), desc="iteration", leave=False, total=n_batches
    ):
        if dictionary is not None:
            assert (
                dictionary.shape[0] == patches.shape[0]
            ), f"dictionary must have {patches.shape[0]}-dimensional basis vectors"
            dictionary = dictionary.to(device)
        else:
            dictionary = (
                torch.rand(size=(patches.shape[0], dictionary_size), device=device)
                - 0.5
            )

        dictionary /= dictionary.norm(dim=0, keepdim=True)
        patches = patches.to(device)
        coefficients = coordinate_descent(
            w=dictionary, x=patches, alpha=regularization_strength, **kwargs
        )
        # batch average of the outer products (P_j - X Z_j) Z_j^T, as one matrix product
        residual = patches - dictionary @ coefficients
        dictionary += (learning_rate / patches.shape[-1]) * residual @ coefficients.t()

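    # normalize once more so the returned basis vectors have unit length
    dictionary /= dictionary.norm(dim=0, keepdim=True)
    return dictionary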


def display_dictionary(dictionary: torch.Tensor) -> None:
    """Display a dictionary of basis vectors."""
    patch_size = int(math.sqrt(dictionary.shape[0]))
    images = [
        dictionary[:, i].clone().reshape(patch_size, -1)
        for i in range(dictionary.shape[-1])
    ]
    display_images(images)

## Q1: Sparse coding of natural images (10 points)

In this section, you will learn sparse encodings of natural images using the images from the original [Olshausen & Field (1996) study](https://www.nature.com/articles/381607a0).

- Plot the basis functions learnt from these images for each of these dictionary sizes: 25, 64, 100. (3 points)
- Describe the basis functions learnt. (3 points)
- Why do they look this way? (4 points)

Feel free to alter the default parameters! If any parameter combination gives significantly better results than the defaults (or speeds up the convergence), please report them in your answer so that we can update the defaults for future semesters.

In [None]:
images = load_images(download_images())
display_images(images)

dictionary = learn_sparse_encoding(
    datapipe=extract_patches_from_images(images=images),
    dictionary_size=64,
    n_batches=2000,
)

display_dictionary(dictionary)

## Q2: Sparse coding of sinusoids (10 points)

In this section, you will learn sparse encodings for sinusoids of different frequencies.

- Learn dictionaries of basis functions for sinusoids of (i) low frequency (e.g. 0.5) and (ii) high frequency (e.g. 3) using different dictionary sizes (4, 36, and 64). Plot these basis functions. (3 points)
- What variations do you observe when changing frequency and dictionary size? (3 points)
- Why do these variations arise? (4 points)

In [None]:
sinusoids = create_sinusoids(frequency=3)
display_images(sinusoids)

dictionary = learn_sparse_encoding(
    datapipe=extract_patches_from_images(images=sinusoids),
    dictionary_size=64,
    n_batches=100,
)

display_dictionary(dictionary)