# Lecture 15: Class demo

## Let's cluster images!!

For this demo, I'm going to use two image datasets: 
1. A small subset of [200 Bird Species with 11,788 Images](https://www.kaggle.com/datasets/veeralakrishna/200-bird-species-with-11788-images) (available [here](../data/birds.zip))
2. A tiny subset of [Food-101](https://www.kaggle.com/datasets/kmader/food41?select=food_c101_n10099_r32x32x1.h5)
(available [here](../data/food.zip))


To run the code below, you need to install PyTorch and torchvision in the course conda environment: 

```conda install pytorch torchvision -c pytorch```

In [None]:
import os
import pathlib
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torchvision import datasets, models, transforms, utils
from torchvision.models import vgg16

In [None]:
import torchvision

Let's start with a small subset of the birds dataset. You can experiment with a bigger dataset if you like.

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
def set_seed(seed=42):
    """Seed Python's, NumPy's and PyTorch's global random number generators.

    Parameters
    ----------
    seed : int, default 42
        The same seed value is applied to all three generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

In [None]:
# Seed all global RNGs once, up front, so later random operations repeat across runs.
set_seed(42)

In [None]:
import glob

IMAGE_SIZE = 200


def read_img_dataset(data_dir, batch_size=None, image_size=None):
    """Load an ImageFolder directory and return ONE shuffled batch of images.

    Every image is resized to ``image_size x image_size``, converted to a
    tensor and normalized per channel with mean 0.5 / std 0.5, so pixel values
    end up in [-1, 1] (plotting code undoes this with ``x / 2 + 0.5``).

    Parameters
    ----------
    data_dir : str or pathlib.Path
        Root directory in torchvision ``ImageFolder`` layout (one sub-folder
        per class).
    batch_size : int, optional
        Number of images to return. Defaults to the notebook-level global
        ``BATCH_SIZE`` (the original behaviour). Only the first batch is
        returned, so if this is smaller than the dataset you silently get a
        random subset.
    image_size : int, optional
        Side length images are resized to. Defaults to the global
        ``IMAGE_SIZE``, looked up at call time.

    Returns
    -------
    inputs : torch.Tensor
        Batch of shape ``(batch, C, image_size, image_size)`` (C is 3 for the
        default RGB loader).
    classes : torch.Tensor
        Integer class labels of the batch, indices into the sorted sub-folder
        names.
    """
    if batch_size is None:
        # Backward compatibility: existing calls rely on the global that the
        # notebook sets right before calling this function.
        batch_size = BATCH_SIZE
    if image_size is None:
        image_size = IMAGE_SIZE

    data_transforms = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    image_dataset = datasets.ImageFolder(root=data_dir, transform=data_transforms)
    dataloader = torch.utils.data.DataLoader(
        image_dataset, batch_size=batch_size, shuffle=True, num_workers=0
    )
    inputs, classes = next(iter(dataloader))
    return inputs, classes

In [None]:
def plot_sample_imgs(inputs):
    """Draw a batch of images as one grid titled "Sample Training Images".

    Parameters
    ----------
    inputs : torch.Tensor
        Image batch (B, C, H, W); ``make_grid(normalize=True)`` rescales the
        grid to [0, 1] for display, so any normalization is handled.
    """
    grid = utils.make_grid(inputs, padding=1, normalize=True)
    channels_last = np.transpose(grid, (1, 2, 0))  # imshow wants (H, W, C)

    plt.figure(figsize=(10, 70))
    plt.axis("off")
    plt.title("Sample Training Images")
    plt.imshow(channels_last)

In [None]:
def get_cluster_images(model, Z, inputs, cluster=0, img_shape=(3, 200, 200), n_img=5):
    """Plot one representative image for a cluster followed by ``n_img`` samples.

    First panel:
      * if the features are raw pixels (``Z.shape[1] == prod(img_shape)``),
        the cluster center / component mean itself, reshaped to an image;
      * otherwise (e.g. CNN embeddings, whatever their width) the center
        cannot be drawn, so the image whose feature vector is closest to the
        center is shown instead.

    Remaining panels:
      * KMeans-like models (have ``cluster_centers_``): the ``n_img`` samples
        closest to the center. These are nearest points over ALL of ``Z``,
        not only points assigned to ``cluster``.
      * GaussianMixture-like models (have ``means_`` and ``predict_proba``):
        the ``n_img`` samples with the highest posterior probability for the
        component, in increasing order of probability.

    Parameters
    ----------
    model : fitted clustering model
        KMeans-like or GaussianMixture-like (see above).
    Z : np.ndarray, shape (n_samples, n_features)
        Feature vectors the model was fit on.
    inputs : np.ndarray
        Images in the same row order as ``Z``, each of shape ``img_shape``.
        Assumes they were normalized with mean = std = 0.5 (undone here with
        ``x / 2 + 0.5``).
    cluster : int
        Index of the cluster / mixture component to show.
    img_shape : tuple of int
        (C, H, W) shape of one image.
    n_img : int
        Number of sample images to show after the representative panel.

    Raises
    ------
    ValueError
        If ``model`` is neither KMeans-like nor GaussianMixture-like.
    """
    transpose_axes = (1, 2, 0)

    def _to_rgb(img):
        # Undo Normalize(0.5, 0.5) ([-1, 1] -> [0, 1]) and move channels last,
        # since plt.imshow expects (H, W, C).
        return np.transpose(img / 2 + 0.5, transpose_axes)

    if hasattr(model, "cluster_centers_"):  # KMeans and friends
        center = model.cluster_centers_[cluster]
        dists = np.linalg.norm(Z - center, axis=1)
        inds = np.argsort(dists)[:n_img]
        title = "Cluster center %d" % (cluster)
    elif hasattr(model, "means_") and hasattr(model, "predict_proba"):  # GaussianMixture
        center = model.means_[cluster]
        cluster_probs = model.predict_proba(Z)[:, cluster]
        # Top-n_img by probability, kept in ascending order; unlike
        # argsort(...)[-n_img:] this is empty (not "everything") when n_img == 0.
        inds = np.argsort(cluster_probs)[::-1][:n_img][::-1]
        dists = np.linalg.norm(Z - center, axis=1)
        title = "Cluster %d" % (cluster)
    else:
        raise ValueError(
            "Unsupported model %r: expected a KMeans-like model "
            "(cluster_centers_) or a GaussianMixture-like model "
            "(means_ and predict_proba)" % type(model).__name__
        )

    if Z.shape[1] == int(np.prod(img_shape)):
        # Pixel features: the center is itself an image.
        representative = center.reshape(img_shape)
    else:
        # Embedding features: show the sample nearest to the center instead.
        representative = inputs[np.argmin(dists)].reshape(img_shape)

    # squeeze=False keeps `axes` indexable even when n_img == 0.
    _, axes = plt.subplots(
        1,
        n_img + 1,
        subplot_kw={"xticks": (), "yticks": ()},
        figsize=(10, 10),
        gridspec_kw={"hspace": 0.3},
        squeeze=False,
    )
    axes = axes[0]
    axes[0].imshow(_to_rgb(representative))
    axes[0].set_title(title)

    print("Image indices: ", inds)
    for ax, image in zip(axes[1:], inputs[inds]):
        ax.imshow(_to_rgb(image))
    plt.show()

In [None]:
# Root folder holding the "birds" and "food" ImageFolder directories.
data_dir = pathlib.Path("../../data")
# Count the images (one sub-folder per class) so the whole set loads as one batch.
# NOTE(review): the glob only matches lowercase ".jpg"; if ImageFolder also finds
# other image files, BATCH_SIZE undercounts and only a random subset is loaded.
file_names = list((data_dir / "birds").glob("*/*.jpg"))
n_images = len(file_names)
BATCH_SIZE = n_images  # the dataset is small enough for a single batch
birds_inputs, birds_classes = read_img_dataset(data_dir / "birds")

In [None]:
# NumPy view of the bird batch (shares memory with the tensor), shape
# (n, C, IMAGE_SIZE, IMAGE_SIZE); used later to display cluster members.
X_birds = birds_inputs.numpy()

In [None]:
# Preview the first 24 images of the (shuffled) bird batch.
plot_sample_imgs(birds_inputs[:24])

For clustering we need to calculate distances between points, so we need a vector representation of each data point. The simplest way to create a vector representation of an image is to flatten it into one long vector of pixel values.  

In [None]:
# Same preprocessing as read_img_dataset, plus a final step that flattens each
# (C, H, W) image tensor into one long pixel vector usable as a clustering feature.
flatten_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    transforms.Lambda(torch.flatten),
])
flatten_images = datasets.ImageFolder(data_dir / "birds", flatten_transforms)

In [None]:
# Shuffled loader over the flattened images; BATCH_SIZE was set to the image
# count above, so a single batch should cover the dataset.
flatten_dataloader = torch.utils.data.DataLoader(
    flatten_images, BATCH_SIZE, shuffle=True, num_workers=0
)

In [None]:
# Take the first (and, given BATCH_SIZE, presumably only) batch: flattened
# pixel vectors and their integer class labels.
flatten_train, y_train = next(iter(flatten_dataloader))

In [None]:
# NOTE(review): this rebinds `flatten_images` from the ImageFolder dataset to a
# NumPy array with one flattened image per row, the feature matrix for KMeans.
flatten_images = flatten_train.numpy()

In [None]:
# Reshape one flattened row back to (C, H, W) and display it as a sanity check
# (channels last and the [-1, 1] normalization undone for imshow).
image_shape = [3, 200, 200]
img = flatten_images[20].reshape(image_shape)
plt.imshow(img.transpose(1, 2, 0) / 2 + 0.5);

In [None]:
flatten_images.shape  # (n_images, 3 * 200 * 200): one flattened 3-channel 200x200 image per row

In [None]:
from sklearn.cluster import KMeans

k = 3  # number of clusters
# Cluster directly on raw (flattened) pixel values.
km_flatten = KMeans(n_clusters=k, random_state=123, n_init="auto")
km_flatten.fit(flatten_images);

In [None]:
# (k, n_pixel_features): centers live in the same flattened-pixel space as the
# data, so each one can be reshaped and displayed as an image.
km_flatten.cluster_centers_.shape

In [None]:
# Sanity check: fitting KMeans did not alter the feature matrix's shape.
flatten_images.shape

In [None]:
# Restore the (C, H, W) layout of every flattened row so the clusters can be
# displayed as images. A single vectorized reshape replaces the per-image Python
# loop and full copy, and still gives a 4-D array when there are no images.
unflatten_inputs = flatten_images.reshape((len(flatten_images), *image_shape))

In [None]:
# Still one flattened row per image: building unflatten_inputs left this 2-D.
flatten_images.shape

In [None]:
# One row of plots per cluster: its center (pixel space) plus the 5 nearest images.
for cluster in range(k):
    get_cluster_images(
        km_flatten, flatten_images, unflatten_inputs, cluster=cluster, n_img=5
    )

We see some images grouped into the wrong cluster. 

How about trying a different input representation? Let's use a pre-trained vision model as a feature extractor (transfer learning). We pass each image in our dataset through the pre-trained network and use the output of its last layer before the classification layer as the image's representation. 

![](../../img/cnn-ex.png)

Source: https://cezannec.github.io/Convolutional_Neural_Networks/

In [None]:
def get_features(model, inputs):
    """Return ``model(inputs)`` as a detached tensor, computed in eval mode.

    Intended for feature extraction with a pre-trained network whose
    classification head was replaced by ``torch.nn.Identity()`` (DenseNet-121
    then yields 1024-dim vectors). The feature width is taken from the model
    output instead of being hard-coded, so other backbones work too.

    The model is switched to evaluation mode for the forward pass. BatchNorm
    then uses its stored running statistics instead of per-batch statistics,
    and does not overwrite them. Features therefore do not depend on which
    other images share the batch, and the model is not mutated. The model's
    previous train/eval mode is restored afterwards.

    Parameters
    ----------
    model : torch.nn.Module
        Feature extractor.
    inputs : torch.Tensor
        Batch accepted by ``model`` (e.g. images of shape (N, 3, H, W)).

    Returns
    -------
    torch.Tensor
        One feature row per input, without autograd history.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # inference only: no computational graph needed
            features = model(inputs)
    finally:
        model.train(was_training)
    return features.detach()

In [None]:
# DenseNet-121 with ImageNet weights, used as a fixed feature extractor. .eval()
# makes BatchNorm use the pretrained running statistics instead of statistics of
# whatever batch is fed in (and stops those statistics from being overwritten).
densenet = models.densenet121(weights="DenseNet121_Weights.IMAGENET1K_V1").eval()
# Replace the final classification layer with a pass-through so the network
# returns its 1024-dim penultimate features instead of class scores.
densenet.classifier = torch.nn.Identity()  # remove that last "classification" layer

In [None]:
# DenseNet feature vectors, one row per bird image, in the same row order as X_birds.
Z_birds = get_features(densenet, birds_inputs).numpy()

In [None]:
# (n_images, 1024): one DenseNet-121 feature vector per bird image.
Z_birds.shape

Do we get better clustering with this representation? 

In [None]:
from sklearn.cluster import KMeans

k = 3  # number of clusters
# Same KMeans setup as before, now on DenseNet features instead of raw pixels.
km = KMeans(random_state=123, n_clusters=k, n_init="auto")
km.fit(Z_birds);

In [None]:
# (k, 1024): centers now live in DenseNet feature space, not pixel space, so
# they cannot be viewed directly as images.
km.cluster_centers_.shape

In [None]:
for cluster in range(k):
    # Rows of Z_birds and X_birds describe the same images, so indices line up.
    get_cluster_images(km, Z_birds, X_birds, cluster=cluster, n_img=6)

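KMeans seems to be doing a good job. However, the cluster centers are no longer interpretable: they live in the 1024-dimensional DenseNet feature space rather than in pixel space, so we show the image closest to each center instead.
This dataset seems easy because the birds have very distinct colors. Let's try a more complicated dataset.  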

In [None]:
# Same loading recipe as for the birds, now for the food images.
# NOTE(review): the lowercase "*/*.jpg" glob may undercount what ImageFolder loads.
file_names = list((data_dir / "food").glob("*/*.jpg"))
n_images = len(file_names)
BATCH_SIZE = n_images  # small dataset: load it as a single batch
food_inputs, food_classes = read_img_dataset(data_dir / "food")
n_images

In [None]:
# NumPy view of the food batch (shares memory with the tensor); used to display
# cluster members.
X_food = food_inputs.numpy()

In [None]:
# Preview the first 24 images of the (shuffled) food batch.
plot_sample_imgs(food_inputs[:24])

In [None]:
# DenseNet feature vectors, one row per food image, in the same row order as X_food.
Z_food = get_features(densenet, food_inputs).numpy()

In [None]:
# (n_images, 1024): one DenseNet-121 feature vector per food image.
Z_food.shape

In [None]:
from sklearn.cluster import KMeans

k = 5  # number of clusters to look for in the food features
km = KMeans(random_state=123, n_clusters=k, n_init="auto")
km.fit(Z_food);

In [None]:
# (k, 1024): one center per cluster in DenseNet feature space.
km.cluster_centers_.shape

In [None]:
for cluster in range(k):
    # Rows of Z_food and X_food describe the same images, so indices line up.
    get_cluster_images(km, Z_food, X_food, cluster=cluster, n_img=6)

Some images land in the wrong cluster, but overall the clustering looks pretty good! You can experiment with 
- Different values for number of clusters
- Different pre-trained models
- Other possible representations 
- Different image datasets