# DEC Training on Torchvision MNIST Dataset
This notebook walks through an example workflow for training DEC models and for saving and loading them, and shows how to compute clustering performance metrics and visualize the embedding clusters.

## 1. Setup

Package requirements:
- dec_torch
- torchvision

1. Specify the global torch computation device, the dataset, the training schedule, and the output paths

In [None]:
import torch
import torchvision
import os

# Computation device: prefer CUDA, but fall back to the CPU so the notebook
# still runs on machines without a usable GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset source and image properties
torchvision_dataset = torchvision.datasets.MNIST
dataset_root = "dataset"  # Path to download the dataset
height, width, channels = 28, 28, 1
input_dim = height * width * channels  # Length of one flattened image

# Training epochs and stopping criteria
autoencoder_pretraining_epoch = 10000  # Passed as `n_epoch` to greedy_fit
autoencoder_finetuning_epoch = 10000  # Passed as `n_epoch` to fit
dec_reassignment_tolerance = 0.001  # Passed as `tolerance` to DEC fit

# Output paths
output_dir = "output"


def j(x):
    """Return the path of the file named `x` inside `output_dir`."""
    return os.path.join(output_dir, x)


autoencoder_pretrained_output = j("pretrained.stacked.autoencoder.pth")
autoencoder_finetuned_output = j("finetuned.stacked.autoencoder.pth")
dec_encoder_output = j("dec.encoder.pth")
dec_centroids_output = j("dec.centroids.pth")

# DEC K-means initialization
kmeans_trials = 1  # Passed as `n_trials` to dec.init_clusters_trials

In [None]:
import logging

# Module-level logger; the root logger is configured at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# os.mkdir() has no `exist_ok` parameter (passing it raises TypeError).
# os.makedirs() accepts it, is a no-op when the directory already exists,
# and also creates any missing parent directories.
os.makedirs(output_dir, exist_ok=True)

import warnings
from tqdm import TqdmWarning

# Ignore FutureWarning and tqdm's warning category for the rest of the session.
for ignored_category in (FutureWarning, TqdmWarning):
    warnings.simplefilter(action="ignore", category=ignored_category)

## 2. Data Loading


In [None]:
import torchvision.transforms as T

# Convert each image to a tensor, then flatten it into a single dimension.
flattening_steps = [T.ToTensor(), T.Lambda(torch.flatten)]
transform = T.Compose(flattening_steps)

# Both splits share the same transform and are downloaded on first use;
# `train=False` selects the split used here for validation.
shared_dataset_options = {"transform": transform, "download": True}
training_set = torchvision_dataset(dataset_root, train=True, **shared_dataset_options)
validation_set = torchvision_dataset(dataset_root, train=False, **shared_dataset_options)

print("Training   set size:", len(training_set))
print("Validation set size:", len(validation_set))

In [None]:
# Option - Dataset without label (used while training the models)
class UnlabeledDataset(torch.utils.data.Dataset):
    """View of a labeled dataset that yields only the input of each sample.

    Every item of the wrapped dataset must unpack into exactly two values,
    ``(input, label)``; the label is discarded.
    """

    def __init__(self, dataset: torch.utils.data.Dataset):
        # Keep a reference only; the wrapped dataset is not copied.
        self.dataset = dataset

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        wrapped = self.dataset
        return len(wrapped)

    def __getitem__(self, idx):
        """Return the input of sample `idx`, dropping its label."""
        sample, _label = self.dataset[idx]
        return sample

# Wrap both splits so that indexing them yields the image only, without its label.
# The in-memory option in the next cell rebinds these two names when it is run.
training_set_unlabeled = UnlabeledDataset(training_set)
validation_set_unlabeled = UnlabeledDataset(validation_set)

In [None]:
# Option - Extract all unlabeled data to memory (faster alternative)
from torch.utils.data import DataLoader, TensorDataset
from dec_torch.utils.data import extract_all_data

# Read each split through a DataLoader (default batch size) into memory.
# NOTE(review): presumably returns (inputs, labels) tensors that live on the CPU -
# confirm against dec_torch.utils.data.extract_all_data.
training_input_cpu, training_labels = extract_all_data(DataLoader(training_set))
validation_input_cpu, validation_labels = extract_all_data(DataLoader(validation_set))

# This can be skipped if you want to use swap memory or all GPU memory.
# If you skip pinning the whole tensor, then consider setting `pin_memory=True` in DataLoader instead.
# training_input_cpu = training_input_cpu.pin_memory()
# validation_input_cpu = validation_input_cpu.pin_memory()

# Load all data to the computation device (`device`) if you can afford it.
# The `*_cpu` tensors are kept because later cells use them for CPU-side inference.
training_input = training_input_cpu.to(device)
validation_input = validation_input_cpu.to(device)

# Rebind the unlabeled datasets to the in-memory tensors (each item is a 1-tuple).
training_set_unlabeled = TensorDataset(training_input)
validation_set_unlabeled = TensorDataset(validation_input)

In [None]:
# Preview the dataset
import matplotlib.pyplot as plt

def preview_images(dataset, indices: list[int], channels, height, width, cmap="gray"):
    """Show the dataset samples at `indices` side by side in a single row.

    Args:
        dataset: Indexable dataset. Each item is either an image tensor or a
            tuple/list whose first element is the image tensor.
        indices: Positions of the samples to display, in plotting order.
        channels: Number of channels each image is viewed with.
        height: Image height each image is viewed with.
        width: Image width each image is viewed with.
        cmap: Matplotlib colormap name forwarded to ``imshow``. The default is
            the canonical name ``"gray"``; the ``"grey"`` spelling is only
            registered as an alias in recent Matplotlib releases.
    """
    images = []
    for idx in indices:
        item = dataset[idx]
        # Datasets such as TensorDataset return tuples; the image comes first.
        tensor = item[0] if isinstance(item, (tuple, list)) else item
        # Move to CPU for plotting and go from (C, H, W) to (H, W, C).
        images.append(tensor.cpu().view(channels, height, width).permute(1, 2, 0))

    plt.figure(figsize=(2 * len(indices), 2))
    for position, img in enumerate(images, start=1):
        plt.subplot(1, len(indices), position)
        plt.imshow(img, cmap=cmap)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Preview ten samples whose indices are drawn at random from [0, 5000).
preview_indices = list(torch.randint(0, 5000, (10,)))
preview_images(training_set_unlabeled, preview_indices, channels, height, width)

## 3. Stacked Autoencoder Training

### 3.1 Greedy Layer-wise Training

In [None]:
# Run after "2. Data Loading"
from torch.utils.data import DataLoader

# Set num_workers if data is not already loaded on GPU
loader_options = {"batch_size": 256}
training_loader = DataLoader(training_set_unlabeled, **loader_options)
validation_loader = DataLoader(validation_set_unlabeled, **loader_options)


from dec_torch.autoencoder import StackedAutoEncoder, CoderConfig, StackedAutoEncoderConfig

# Layer widths handed to the config builder.
# NOTE(review): presumably the last entry (10) is the embedding size used for clustering - confirm.
latent_dims = [500, 500, 2000, 10]
sae_config = StackedAutoEncoderConfig.build(
    input_dim=input_dim,
    latent_dims=latent_dims,
    input_dropout=.2,
)

# Build the model on the computation device; the bare `model` expression
# makes the notebook display its representation.
model = StackedAutoEncoder(sae_config).to(device)
model

In [None]:
# Train stacked autoencoder
from torch import nn
from datetime import datetime, timedelta

# Mean-squared reconstruction error, optimized with SGD plus momentum.
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Wall-clock timing of the whole pre-training run.
start_time = datetime.now()

# NOTE(review): presumably trains the stack one autoencoder at a time and
# returns one loss history per autoencoder (the loss-plot cell iterates over
# `history`) - confirm against StackedAutoEncoder.greedy_fit.
history = model.greedy_fit(
    training_loader,
    optimizer,
    criterion,
    val_loader=validation_loader,
    n_epoch=autoencoder_pretraining_epoch,
    max_verbose=10
)

end_time = datetime.now()
print("Elapsed time", end_time - start_time)

# Persist the pre-trained model to `autoencoder_pretrained_output`.
model.save(autoencoder_pretrained_output)

In [None]:
# Training-Validation loss history
from dec_torch.utils.visualization import loss_plot
from matplotlib import pyplot as plt

# One loss panel per entry of `history`, stacked vertically on a shared x axis.
# squeeze=False keeps the result a 2-D array even when `history` has a single
# entry; otherwise plt.subplots returns a bare Axes and zip() below fails.
fig, axes = plt.subplots(len(history), 1, figsize=(8, 10), sharex=True, squeeze=False)
axes = axes.ravel()
for i, (ax, h) in enumerate(zip(axes, history)):
    # Only rows with epoch > 10 are plotted.
    loss_plot(h[h["epoch"] > 10], ax=ax).set_title(f"Autoencoder {i}")
plt.show()

### 3.2 Fine-tuning Stacked Autoencoder

In [None]:
# Setup model, output path, dataloader
# Run after "2. Data Loading"
from torch.utils.data import DataLoader

# Set num_workers if data is not already loaded on GPU
loader_options = {"batch_size": 256}
training_loader = DataLoader(training_set_unlabeled, **loader_options)
validation_loader = DataLoader(validation_set_unlabeled, **loader_options)


from dec_torch.autoencoder import StackedAutoEncoder, CoderConfig, StackedAutoEncoderConfig

# Load the pre-trained model from the path it was saved to,
# `autoencoder_pretrained_output` (the only pre-training path defined in the setup cell).
pretrained_sae = StackedAutoEncoder.load(autoencoder_pretrained_output, map_location="cpu")

# Rebuild the same configuration with input dropout disabled for fine-tuning.
finetune_config = pretrained_sae.config.replace_input_dropout(None)
model = StackedAutoEncoder(finetune_config)

# Copy the pre-trained weights into the new model, then move it to the device.
model.load_state_dict(pretrained_sae.state_dict())
model = model.to(device)

# The CPU copy is no longer needed once its weights have been transferred.
del pretrained_sae

model

In [None]:
# Train stacked autoencoder
from torch import nn
from datetime import datetime, timedelta

# Mean-squared reconstruction error, optimized with SGD plus momentum.
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Wall-clock timing of the whole fine-tuning run.
start_time = datetime.now()

# Train the whole stacked autoencoder; `history` is consumed by the loss-plot
# cell, which filters it on its "epoch" column.
history = model.fit(
    training_loader,
    optimizer,
    criterion,
    val_loader=validation_loader,
    n_epoch=autoencoder_finetuning_epoch,
    max_verbose=10
)

end_time = datetime.now()
print("Elapsed time", end_time - start_time)

# Persist the fine-tuned model to `autoencoder_finetuned_output`.
model.save(autoencoder_finetuned_output)

In [None]:
# Training-Validation loss history
from dec_torch.utils.visualization import loss_plot

# Plot only the rows with epoch > 10 and title the returned plot object.
loss_plot(history[history["epoch"] > 10]).set_title("SAE Fine-tuning Loss Graph")

## 4. DEC Training

In [None]:
# Run after "2. Data Loading"
# Extract embeddings and true labels
from dec_torch.autoencoder import StackedAutoEncoder
from dec_torch.utils.data import extract_all_data

# Ensure autoencoder is loaded to "cpu" for k-means initialization.
autoencoder = StackedAutoEncoder.load(autoencoder_finetuned_output, map_location="cpu")

# Use torch.utils.data.Subset if the amount of training samples is too much.
# embeddings, labels_true = extract_all_data(TensorDataset(training_input), transform=autoencoder.encoder)
# Encode the CPU copy of the training inputs in a single forward pass;
# no_grad() keeps autograd from tracking this inference-only call.
embeddings = None
with torch.no_grad():
    embeddings = autoencoder.encoder(training_input_cpu)

In [None]:
# K-means centroid initialization
from dec_torch import dec

# Run `kmeans_trials` centroid initializations on the embeddings.
# NOTE(review): presumably `clusters_scores` is a DataFrame ranked best-first, so
# its first row identifies the best trial - confirm against dec.init_clusters_trials.
clusters_list, clusters_scores = dec.init_clusters_trials(embeddings, n_clusters=10, n_trials=kmeans_trials)

# The row label of the first score row indexes the matching entry of `clusters_list`.
selected_index = clusters_scores.iloc[0].name
centroids = clusters_list[selected_index]

print(f"Selected clusters #{selected_index}")
clusters_scores

In [None]:
# Extract predicted labels: each sample gets the index of its largest soft assignment
labels_pred = None
with torch.no_grad():
    soft_assignments = dec.DEC.soft_assignment(embeddings, centroids, alpha=1)
    labels_pred = torch.argmax(soft_assignments, dim=1)

In [None]:
from sklearn.metrics import (
    adjusted_rand_score,
    normalized_mutual_info_score,
    silhouette_score,
    calinski_harabasz_score
)

# Label-based scores: predictions are compared against the training labels.
ari = adjusted_rand_score(training_labels, labels_pred)
print("ARI", ari)

nmi = normalized_mutual_info_score(training_labels, labels_pred)
print("NMI", nmi)

# Embedding-based scores are computed only when the predicted labels do not
# collapse into a single cluster.
n_predicted_clusters = len(labels_pred.unique())
if n_predicted_clusters == 1:
    print("Cannot compute silhouette (SIL) and calinski-harabasz (CH) scores.")
    print("There is only one cluster in the labels.")
else:
    sil = silhouette_score(embeddings, labels_pred)
    print("SIL", sil)
    ch = calinski_harabasz_score(embeddings, labels_pred)
    print("CH ", ch)

In [None]:
from matplotlib import pyplot as plt
from dec_torch.utils.visualization import cluster_plot

# Two panels: the same embeddings colored by true labels and by predicted labels.
fig, axes = plt.subplots(1, 2, figsize=(12,5))
labelings = [("true label", training_labels), ("pred. label", labels_pred)]

for ax, (label_type, labels) in zip(axes, labelings):
    # A fresh options dict is built for every call.
    cluster_plot(
        embeddings,
        labels,
        reduction="umap",
        centroids=centroids,
        centroids_options=dict(marker="s", color="blue", s=50),
        ax=ax,
    )
    ax.set_title(label_type)

plt.show()

In [None]:
from torch.utils.data import DataLoader

# Fresh loaders over the unlabeled datasets for DEC training.
training_loader = DataLoader(training_set_unlabeled, batch_size=256)
validation_loader = DataLoader(validation_set_unlabeled, batch_size=256)

# Build the DEC model from the fine-tuned encoder and the selected centroids,
# then move it to the computation device.
model = dec.DEC(autoencoder.encoder, centroids).to(device)

# Drop the initialization artifacts that are no longer needed.
del clusters_list, clusters_scores, embeddings, labels_pred, centroids
model

In [None]:
# Train DEC
from torch import nn
from datetime import datetime, timedelta

from dec_torch import dec
# NOTE(review): other cells reach the package as `dec.DEC`, `dec.io`, ...; this is
# the only use of `dec.dec.` - confirm that `KLDivLoss` really lives in a `dec.dec` submodule.
criterion = dec.dec.KLDivLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# Wall-clock timing of the whole DEC training run.
start_time = datetime.now()

# `tolerance` receives the reassignment tolerance from the setup cell;
# `max_epoch` caps the number of epochs.
history = model.fit(
    training_loader,
    optimizer,
    criterion,
    val_loader=validation_loader,
    tolerance=dec_reassignment_tolerance,
    max_verbose=1000,
    max_epoch=10000
)

end_time = datetime.now()
print("Elapsed time", end_time - start_time)

# Persist the encoder and the centroids to their two separate output files.
dec.io.save(model, dec_encoder_output, dec_centroids_output)

In [None]:
# Training-Validation loss history
from dec_torch.utils.visualization import loss_plot

# Plot the full DEC training history (no epoch filtering here).
loss_plot(history)

## 5. Cluster Visualization \& Metrics

In [None]:
# Load DEC model
# Run after "2. Data Loading"
from dec_torch import dec

# Reload the saved encoder and centroids onto the CPU and switch the model
# to evaluation mode before running inference.
dec_model = dec.io.load(dec_encoder_output, dec_centroids_output, sequential_encoder=True, map_location="cpu")
dec_model.eval()

print("DEC centroids shape:", dec_model.centroids.shape)
dec_model

In [None]:
# Load training inputs, embeddings, true and predicted labels
import torch
from torch.utils.data import DataLoader

# from dec_torch.utils.data import extract_all_data
# inputs, labels_true = extract_all_data(DataLoader(training_set))

# Embeddings come from the encoder alone; predicted labels are the argmax of
# the full DEC model's output along dim 1. no_grad() disables autograd tracking.
embeddings = None
labels_pred = None
with torch.no_grad():
    embeddings = dec_model.encoder(training_input_cpu)  # TODO: Consider using hook instead of calling the model twice
    labels_pred = torch.argmax(dec_model(training_input_cpu), dim=1)

In [None]:
from sklearn.metrics import (
    adjusted_rand_score,
    normalized_mutual_info_score,
    silhouette_score,
    calinski_harabasz_score
)

# Label-based scores: predictions are compared against the training labels.
ari = adjusted_rand_score(training_labels, labels_pred)
print("ARI", ari)

nmi = normalized_mutual_info_score(training_labels, labels_pred)
print("NMI", nmi)

# Embedding-based scores are computed only when the predicted labels do not
# collapse into a single cluster.
n_predicted_clusters = len(labels_pred.unique())
if n_predicted_clusters == 1:
    print("Cannot compute silhouette (SIL) and calinski-harabasz (CH) scores.")
    print("There is only one cluster in the labels.")
else:
    sil = silhouette_score(embeddings, labels_pred)
    print("SIL", sil)
    ch = calinski_harabasz_score(embeddings, labels_pred)
    print("CH ", ch)

In [None]:
from dec_torch.utils.visualization import cluster_plot
from matplotlib import pyplot as plt

# Two panels: the same embeddings colored by true labels and by predicted labels.
fig, axes = plt.subplots(1, 2, figsize=(12,5))
labelings = [("true label", training_labels), ("pred. label", labels_pred)]

for ax, (label_type, labels) in zip(axes, labelings):
    # Centroids are detached and moved to the CPU on every call;
    # a fresh options dict is built for every call as well.
    cluster_plot(
        embeddings,
        labels,
        reduction="umap",
        centroids=dec_model.centroids.detach().cpu(),
        centroids_options=dict(marker="s", color="blue", s=50),
        ax=ax,
    )
    ax.set_title(label_type)

# Release the training-split results before the figure is shown.
del embeddings, labels_pred
plt.show()

In [None]:
# Load Validation Data

#from dec_torch.utils.data import extract_all_data
# inputs, labels_true = extract_all_data(DataLoader(validation_set))

# Embeddings come from the encoder alone; predicted labels are the argmax of
# the full DEC model's output along dim 1, here for the validation inputs.
embeddings = None
labels_pred = None
with torch.no_grad():
    embeddings = dec_model.encoder(validation_input_cpu)
    labels_pred = torch.argmax(dec_model(validation_input_cpu), dim=1)

In [None]:
from sklearn.metrics import (
    adjusted_rand_score,
    normalized_mutual_info_score,
    silhouette_score,
    calinski_harabasz_score
)

# `embeddings` and `labels_pred` were computed from the validation inputs, so
# they are scored against `validation_labels` (the labels extracted together
# with `validation_input_cpu`).
ari = adjusted_rand_score(validation_labels, labels_pred)
print("ARI", ari)

nmi = normalized_mutual_info_score(validation_labels, labels_pred)
print("NMI", nmi)

# Embedding-based scores are computed only when the predicted labels do not
# collapse into a single cluster.
if len(labels_pred.unique()) != 1:
    sil = silhouette_score(embeddings, labels_pred)
    print("SIL", sil)
    ch = calinski_harabasz_score(embeddings, labels_pred)
    print("CH ", ch)
else:
    print("Cannot compute silhouette (SIL) and calinski-harabasz (CH) scores.")
    print("There is only one cluster in the labels.")

In [None]:
from dec_torch.utils.visualization import cluster_plot
from matplotlib import pyplot as plt

# Two panels: the validation embeddings colored by true and by predicted labels.
fig, axes = plt.subplots(1, 2, figsize=(12,5))

# The true labels must come from the same split as `embeddings`: the
# validation split, not the training split.
for ax, (label_type, labels) in zip(axes, [("true label", validation_labels), ("pred. label", labels_pred)]):
    cluster_plot(
        embeddings,
        labels,
        reduction="umap",
        centroids=dec_model.centroids.detach().cpu(),
        centroids_options={"marker": "s", "color": "blue", "s": 50},
        ax=ax
    )
    ax.set_title(label_type)

# Release the validation-split results before the figure is shown.
del embeddings, labels_pred
plt.show()