<h1> Semi-supervision for Image Classification </h1>
Grant Sinha, gsinha@uwaterloo.ca

<h2> Abstract </h2>

The approach below combines supervised learning (a cross-entropy loss on labeled data) with unsupervised learning (a K-means clustering loss on unlabeled data). The two terms are combined as a weighted linear combination to form the final loss function. The intent is that the unlabeled data guides feature extraction and refines cluster boundaries in the feature space beyond what training on the labeled portion alone would achieve.
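A minimal sketch of such a combined loss follows. It assumes the model returns `(logits, features)` and that the K-means term is the mean squared distance from each unlabeled feature vector to its nearest centroid. The helpers `kmeans_loss` and `combined_loss` and the fixed centroid tensor are hypothetical illustrations, not the trainer's actual implementation; `lambda_kmeans` mirrors the weight name used later in this notebook.

```python
import torch
import torch.nn.functional as F


def kmeans_loss(features, centroids):
    """Mean squared distance from each feature vector to its nearest centroid."""
    # (N, K) matrix of squared Euclidean distances via broadcasting
    sq_dists = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(dim=-1)
    # Hard assignment: keep only the distance to the closest centroid
    return sq_dists.min(dim=1).values.mean()


def combined_loss(labeled_logits, labels, unlabeled_features, centroids, lambda_kmeans=0.1):
    """Cross-entropy on labeled data plus a weighted K-means term on unlabeled data."""
    supervised = F.cross_entropy(labeled_logits, labels)
    unsupervised = kmeans_loss(unlabeled_features, centroids)
    return supervised + lambda_kmeans * unsupervised


# Toy example: 4 labeled samples over 10 classes, 6 unlabeled 2-D features, 2 fixed centroids
torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)
labels = torch.tensor([0, 3, 7, 9])
feats = torch.randn(6, 2, requires_grad=True)
centroids = torch.tensor([[0.0, 0.0], [1.0, 1.0]])

loss = combined_loss(logits, labels, feats, centroids, lambda_kmeans=0.1)
loss.backward()  # gradients reach both the labeled logits and the unlabeled features
```

Here the centroids are held fixed; in practice they would need periodic updating, for example by fitting scikit-learn's `KMeans` (already imported in this notebook) on the current unlabeled features. How `KMeansConsistencyTrainer` actually maintains its centroids is not shown in this notebook.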

We use the MNIST dataset for its simplicity and ease of use. We assume that only M samples of the training set are labeled and treat the rest of the training set as unlabeled. We define a convolutional network that outputs a fairly small feature map of shape (64, 7, 7), which is flattened and fed through a linear layer before classification.
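One way to get a (64, 7, 7) feature map from 28×28 MNIST inputs is two conv/pool stages that each halve the spatial size (28 → 14 → 7). The class below, `SketchCNN`, is a hypothetical stand-in for the notebook's `CNN` (defined in `models.py`, which is not shown); the channel widths, the 128-unit embedding size, and the dropout placement are assumptions. It returns `(logits, features)`, matching how the notebook unpacks the model's output.

```python
import torch
import torch.nn as nn


class SketchCNN(nn.Module):
    """Hypothetical CNN: two conv blocks -> (64, 7, 7) map -> linear embedding -> classifier."""

    def __init__(self, embed_dim=128, num_classes=10, dropout_rate=0.3):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),   # (1, 28, 28) -> (32, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> (32, 14, 14)
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # -> (64, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2),                              # -> (64, 7, 7)
        )
        self.embed = nn.Sequential(
            nn.Flatten(),                                 # -> (3136,) = 64 * 7 * 7
            nn.Linear(64 * 7 * 7, embed_dim),
            nn.ReLU(),
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # Embedding used by the clustering/consistency losses and for t-SNE
        features = self.embed(self.conv(x))
        logits = self.classifier(self.dropout(features))
        return logits, features


model = SketchCNN()
logits, features = model(torch.randn(8, 1, 28, 28))
```

Keeping dropout between the embedding and the classifier leaves the features themselves untouched, which is one reasonable choice when those features also feed an unsupervised loss.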

We also investigate the effect of an additional loss term for consistency regularization. Specifically, we apply a small, random transformation to the training images and encourage the model's features for each augmented image to remain close to the features of the unaugmented original.
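A minimal sketch of the consistency term follows. It assumes the "small, random transformation" is a per-sample random affine warp (a rotation of up to about 10° and a shift of up to about 2 pixels) and that closeness is measured with mean squared error between feature vectors. `random_affine`, `consistency_loss`, and `TinyNet` are hypothetical helpers, the magnitudes are illustrative, and the clean-image features are detached so they act as a fixed target; the notebook's trainer may make different choices.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def random_affine(x, max_degrees=10.0, max_shift=0.14):
    """Apply an independent small rotation and translation to each image in a batch.

    max_shift is in normalized grid coordinates ([-1, 1] spans the image),
    so 0.14 is roughly 2 pixels on a 28x28 digit.
    """
    n = x.size(0)
    angles = (torch.rand(n) * 2 - 1) * math.radians(max_degrees)
    shifts = (torch.rand(n, 2) * 2 - 1) * max_shift
    cos, sin = torch.cos(angles), torch.sin(angles)
    # One 2x3 affine matrix per sample: [[cos, -sin, tx], [sin, cos, ty]]
    theta = torch.stack(
        [
            torch.stack([cos, -sin, shifts[:, 0]], dim=1),
            torch.stack([sin, cos, shifts[:, 1]], dim=1),
        ],
        dim=1,
    ).to(x)
    grid = F.affine_grid(theta, list(x.shape), align_corners=False)
    # "border" padding repeats edge pixels, which on MNIST are background
    return F.grid_sample(x, grid, padding_mode="border", align_corners=False)


def consistency_loss(model, images):
    """MSE between features of augmented images and (detached) features of the originals."""
    _, clean_features = model(images)
    _, aug_features = model(random_affine(images))
    return F.mse_loss(aug_features, clean_features.detach())


class TinyNet(nn.Module):
    """Stand-in model returning (logits, features), like the notebook's CNN."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 16), nn.ReLU())
        self.head = nn.Linear(16, 10)

    def forward(self, x):
        f = self.features(x)
        return self.head(f), f


torch.manual_seed(0)
net = TinyNet()
batch = torch.rand(4, 1, 28, 28)
loss = consistency_loss(net, batch)
loss.backward()
```

Detaching the clean features is a design choice that makes the clean view a target and only pulls the augmented view toward it; dropping `.detach()` would push both views toward each other.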

To benchmark performance, we measure accuracy on the MNIST test split. We use confusion matrices to probe these results further and identify which digits the model confuses. Finally, to investigate the feature space, we create t-SNE plots of the learned features. Comparing how these vary across experiments shows the effects of the K-means and consistency terms; our hypothesis is that both lead to higher accuracy and more discriminative features.
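As a small illustration of using a confusion matrix to identify model errors, the sketch below computes per-class recall and finds the most frequent off-diagonal (true, predicted) pair. The label arrays are made-up toy values, not results from this notebook; in the notebook the inputs would be the `labels` and `preds` returned by `trainer.evaluate`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy labels/predictions for illustration only (not notebook results)
y_true = np.array([4, 4, 4, 9, 9, 9, 7, 7, 1, 1])
y_pred = np.array([4, 9, 9, 9, 9, 4, 7, 1, 1, 1])

# Rows: true class, columns: predicted class
cm = confusion_matrix(y_true, y_pred, labels=np.arange(10))

# Per-class recall: correct predictions / number of true samples (NaN for absent classes)
support = cm.sum(axis=1)
recall = np.divide(np.diag(cm), support, out=np.full(10, np.nan), where=support > 0)

# Most common mistake: the largest entry once the diagonal (correct predictions) is zeroed
errors = cm.copy()
np.fill_diagonal(errors, 0)
true_cls, pred_cls = np.unravel_index(errors.argmax(), errors.shape)
print(f"Most common error: true {true_cls} predicted as {pred_cls} ({errors[true_cls, pred_cls]} times)")
```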

<h1> Libraries & Environment Setup </h1>

In [1]:
import os  # Environment management
os.environ["OMP_NUM_THREADS"] = "1"  # Works around an sklearn/OpenMP issue; only takes effect if set before NumPy/sklearn are imported

import numpy as np  # Numerical arrays
import random  # Determinism
import torch  # PyTorch
import torch.nn as nn  # Neural network module
import torch.optim as optim  # Optimizers
import matplotlib.pyplot as plt  # Plotting
import matplotlib.animation as animation  # Animation
import warnings  # Silence some sklearn warnings
import optuna  # Hyperparameter optimization

from collections import Counter  # Counting
from torchvision import datasets, transforms  # Datasets and transformations
from sklearn.cluster import KMeans  # KMeans clustering algorithm
from sklearn.model_selection import train_test_split  # Train-test split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay  # Confusion matrix
from sklearn.manifold import TSNE  # t-SNE
from tqdm.notebook import tqdm  # Progress bars

from models import CNN  # Neural network class
from trainers import KMeansConsistencyTrainer  # Training loop
from datasets import get_mnist_loaders, split_dataset  # Helpers from the local datasets.py (not torchvision.datasets)

warnings.filterwarnings("ignore")  # Silence some noisy sklearn warnings

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set seeds for determinism
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

Using device: cuda


<h1> MNIST Data Preparation </h1>

In [2]:
# Load MNIST dataset
train_loader, test_loader = get_mnist_loaders(batch_size=128)

# Keep a handle on the full training set; experiment() below splits it into
# labeled and unlabeled subsets (and builds their loaders) once per value of M
mnist_train = train_loader.dataset
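`split_dataset` comes from the local `datasets.py`, which is not shown. A plausible sketch, assuming a class-balanced draw of M labeled indices (M/10 per digit) with all remaining indices treated as unlabeled, is below. `split_labeled_unlabeled` is a hypothetical name, and the real helper may sample differently (for example, uniformly at random without balancing classes).

```python
import numpy as np
import torch
from torch.utils.data import Subset, TensorDataset


def split_labeled_unlabeled(dataset, targets, num_labeled, num_classes=10, seed=0):
    """Pick num_labeled indices spread evenly over classes; everything else is unlabeled.

    Any remainder of num_labeled // num_classes is dropped.
    """
    rng = np.random.default_rng(seed)
    targets = np.asarray(targets)
    per_class = num_labeled // num_classes
    labeled_idx = []
    for c in range(num_classes):
        class_idx = np.flatnonzero(targets == c)
        labeled_idx.extend(rng.choice(class_idx, size=per_class, replace=False).tolist())
    labeled_set = set(labeled_idx)
    unlabeled_idx = [i for i in range(len(targets)) if i not in labeled_set]
    return Subset(dataset, labeled_idx), Subset(dataset, unlabeled_idx)


# Toy dataset: 200 "images" with 20 samples of each of 10 classes
toy_targets = torch.arange(10).repeat(20)
toy_data = TensorDataset(torch.randn(200, 1, 28, 28), toy_targets)
labeled, unlabeled = split_labeled_unlabeled(toy_data, toy_targets.numpy(), num_labeled=50)
```

If `mnist_train` is a torchvision `MNIST` instance, its labels are available as `mnist_train.targets`, so the call would be `split_labeled_unlabeled(mnist_train, mnist_train.targets, M)`. Balancing classes matters most for small M, where a uniform draw can leave some digits with very few labels.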

<h1> Training Model </h1>

In [3]:
# Experiment setup
def experiment(
    values_of_M,
    epochs=10,
    evaluate_every=1,
    lambda_kmeans=0.1,
    lambda_consistency=0.1,
    use_consistency=False,
    use_unlabeled=True,
    save_dir="tsne_images",
    output_file="tsne_animation.gif",
    generate_tsne=True,
    generate_cm=True,
):
    # lambda_kmeans, lambda_consistency and use_consistency are accepted here but are
    # not passed to KMeansConsistencyTrainer below, so as written they do not change training.
    results = {}
    for M in values_of_M:
        # Create a fresh model and trainer for each value of M
        model = CNN(use_dropout=True, dropout_rate=0.3)
        trainer = KMeansConsistencyTrainer(model, device=device)

        # Split the training set: M labeled samples, the rest unlabeled
        labeled_data, unlabeled_data = split_dataset(mnist_train, M)

        labeled_loader = torch.utils.data.DataLoader(labeled_data, batch_size=256, shuffle=True, num_workers=0)
        unlabeled_loader = (
            torch.utils.data.DataLoader(unlabeled_data, batch_size=256, shuffle=True, num_workers=0)
            if use_unlabeled
            else None
        )

        # Train and log results
        train_accs, test_accs = trainer.train(
            labeled_loader=labeled_loader,
            unlabeled_loader=unlabeled_loader,
            test_loader=test_loader,
            epochs=epochs,
            evaluate_every=evaluate_every,
        )
        results[M] = (train_accs, test_accs)

        # Evaluate on the test split and compute the confusion matrix
        _, preds, labels = trainer.evaluate(test_loader)
        if generate_cm:
            cm = confusion_matrix(labels, preds)
            disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=range(10))
            disp.plot(cmap="viridis")
            plt.title(f"Confusion Matrix for M={M}")
            plt.show()

        if generate_tsne:
            # Save the t-SNE plot to the same directory the animation reads from
            plot_tsne(model, M, test_loader, save_dir=save_dir)

    if generate_tsne:
        create_animation(save_dir, output_file)
    return results

# Plot training curves
def plot_training_curves(results, values_of_M):
    plt.figure(figsize=(12, 6))
    
    # Plot test accuracy for each M
    for M in values_of_M:
        train_accs, test_accs = results[M]
        plt.plot(test_accs, label=f"M={M} (Test Accuracy)")

    plt.title("Test Accuracy vs. Epochs for Different Values of M")
    plt.xlabel("Epoch")
    plt.ylabel("Test Accuracy (%)")
    plt.ylim(0, 100)
    plt.legend()
    plt.grid()
    plt.show()

# Plot t-SNE visualization
def plot_tsne(model, M, data_loader, save_dir="tsne_images"):
    os.makedirs(save_dir, exist_ok=True)

    # Collect feature embeddings and labels for the whole loader
    model.eval()
    features_list = []
    labels_list = []
    with torch.no_grad():
        for images, labels in data_loader:
            _, features = model(images.to(device))  # model returns (logits, features)
            features_list.append(features.cpu().numpy())
            labels_list.append(labels.numpy())
    features = np.concatenate(features_list)
    labels = np.concatenate(labels_list)

    # Project features to 2-D with t-SNE; a qualitative colormap keeps the 10 digits distinct
    tsne = TSNE(n_components=2, random_state=0)
    tsne_results = tsne.fit_transform(features)
    plt.figure(figsize=(6, 6))
    scatter = plt.scatter(tsne_results[:, 0], tsne_results[:, 1], c=labels, cmap="tab10", alpha=0.5)
    plt.colorbar(scatter)
    plt.title(f"t-SNE Visualization for M={M}")
    plt.savefig(os.path.join(save_dir, f"tsne_M_{M}.png"))
    plt.show()  # Display the plot in the output
    plt.close()

def create_animation(save_dir="tsne_images", output_file="tsne_animation.gif"):
    os.makedirs(save_dir, exist_ok=True)

    fig = plt.figure(figsize=(6, 6))
    images = []

    # Collect the saved t-SNE frames and sort them numerically by M
    file_names = sorted(
        [f for f in os.listdir(save_dir) if f.startswith("tsne_M_") and f.endswith(".png")],
        key=lambda x: int(x.split("_")[2].split(".")[0]),
    )

    for file_name in file_names:
        img = plt.imread(os.path.join(save_dir, file_name))
        images.append([plt.imshow(img, animated=True)])

    # The Pillow writer produces GIFs without requiring an ImageMagick install
    ani = animation.ArtistAnimation(fig, images, interval=500, blit=True, repeat_delay=1000)
    ani.save(os.path.join(save_dir, output_file), writer="pillow")
    plt.close()
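The t-SNE plots are a qualitative view, and distances in a t-SNE embedding can be misleading. A complementary numeric check, not part of the original notebook, is scikit-learn's `silhouette_score` computed on the raw feature vectors with the true labels as cluster assignments; values closer to 1 indicate tighter, better-separated classes. The sketch below runs it on synthetic features; in the notebook one would pass the same feature and label arrays that `plot_tsne` collects.

```python
import numpy as np
from sklearn.metrics import silhouette_score

rng = np.random.default_rng(0)

# Synthetic 16-D "features" for 3 classes: well-separated vs. heavily overlapping
centers = rng.normal(scale=10.0, size=(3, 16))
labels = np.repeat(np.arange(3), 100)
separated = centers[labels] + rng.normal(scale=1.0, size=(300, 16))
overlapping = centers[labels] * 0.05 + rng.normal(scale=1.0, size=(300, 16))

# Silhouette uses distances in the original feature space, not in the t-SNE embedding
sep_score = silhouette_score(separated, labels)
ovl_score = silhouette_score(overlapping, labels)
print(f"separated:   {sep_score:.3f}")
print(f"overlapping: {ovl_score:.3f}")
```

Computing this score per value of M (and with or without the K-means and consistency terms) would give a single number to compare alongside each t-SNE frame.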


<h1> Hyperparameter Optimization with Optuna </h1>

In [None]:
def objective(trial):
    # Define the hyperparameter search space (sampled log-uniformly)
    lambda_kmeans = trial.suggest_float("lambda_kmeans", 1e-5, 1e-1, log=True)
    lambda_consistency = trial.suggest_float("lambda_consistency", 1e-6, 1e-2, log=True)
    # dropout = trial.suggest_float("dropout", 0.1, 0.5)

    # Use a fixed number of labeled samples M for the optimization
    M = 1000
    values_of_M = [M]
    
    # Run the experiment with the current hyperparameters
    results = experiment(
        values_of_M=values_of_M, 
        epochs=10, 
        evaluate_every=2,
        lambda_kmeans=lambda_kmeans, 
        lambda_consistency=lambda_consistency, 
        use_consistency=True, 
        use_unlabeled=True, 
        generate_tsne=False, 
        generate_cm=False
    )
    
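    # Score the trial by its best test accuracy across evaluations. This tunes the
    # hyperparameters (and effectively the stopping epoch) on the test split, so the
    # resulting accuracy is an optimistic estimate of generalization.
    test_accs = results[M][1]
    best_test_acc = max(test_accs)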
    
    return best_test_acc

# Create an Optuna study and optimize the objective function
set_seed(7 * 5 * 3 * 2 * 2)

study_name = "Optimize_KMeans_and_Consistency_Losses"
study = optuna.create_study(study_name=study_name, direction='maximize')
study.optimize(objective, n_trials=100)

# Print the best hyperparameters
print("Best hyperparameters: ", study.best_params)

[I 2024-12-23 16:52:59,068] A new study created in memory with name: Optimize_KMeans_and_Consistency_Losses


Epoch [1/10], Loss: 2.3870286345481873, Train Accuracy: 15.70%, Test Accuracy: 16.89%
Epoch [2/10], Loss: 2.0041126012802124, Train Accuracy: 32.20%, Test Accuracy: 68.43%
Epoch [3/10], Loss: 1.5176301002502441, Train Accuracy: 66.80%, Test Accuracy: 68.43%
Epoch [4/10], Loss: 0.9843920767307281, Train Accuracy: 77.70%, Test Accuracy: 80.67%
Epoch [5/10], Loss: 0.6774343550205231, Train Accuracy: 80.80%, Test Accuracy: 80.67%
Epoch [6/10], Loss: 0.5431683361530304, Train Accuracy: 83.50%, Test Accuracy: 85.77%
Epoch [7/10], Loss: 0.5054378882050514, Train Accuracy: 84.70%, Test Accuracy: 85.77%
Epoch [8/10], Loss: 0.49727731943130493, Train Accuracy: 84.40%, Test Accuracy: 86.18%
Epoch [9/10], Loss: 0.4702412709593773, Train Accuracy: 85.00%, Test Accuracy: 86.18%
Epoch [10/10], Loss: 0.4555510953068733, Train Accuracy: 85.80%, Test Accuracy: 87.69%


[I 2024-12-23 16:54:47,091] Trial 0 finished with value: 87.69 and parameters: {'lambda_kmeans': 0.0010337871850909183, 'lambda_consistency': 8.34959100047629e-06}. Best is trial 0 with value: 87.69.


Epoch [1/10], Loss: 2.210574448108673, Train Accuracy: 25.20%, Test Accuracy: 39.83%
Epoch [2/10], Loss: 1.7021483480930328, Train Accuracy: 55.60%, Test Accuracy: 80.71%
Epoch [3/10], Loss: 1.061434656381607, Train Accuracy: 74.30%, Test Accuracy: 80.71%
Epoch [4/10], Loss: 0.6843858659267426, Train Accuracy: 79.70%, Test Accuracy: 84.37%
Epoch [5/10], Loss: 0.5040683075785637, Train Accuracy: 83.80%, Test Accuracy: 84.37%
Epoch [6/10], Loss: 0.4279017448425293, Train Accuracy: 86.50%, Test Accuracy: 88.20%
Epoch [7/10], Loss: 0.4005499705672264, Train Accuracy: 87.50%, Test Accuracy: 88.20%
Epoch [8/10], Loss: 0.3867444470524788, Train Accuracy: 87.80%, Test Accuracy: 88.41%
Epoch [9/10], Loss: 0.3868577480316162, Train Accuracy: 87.70%, Test Accuracy: 88.41%
Epoch [10/10], Loss: 0.3716476932168007, Train Accuracy: 88.70%, Test Accuracy: 89.34%


[I 2024-12-23 16:56:28,086] Trial 1 finished with value: 89.34 and parameters: {'lambda_kmeans': 1.758538400114545e-05, 'lambda_consistency': 0.00013983372460331177}. Best is trial 1 with value: 89.34.


Epoch [1/10], Loss: 2.2314908504486084, Train Accuracy: 13.90%, Test Accuracy: 60.80%
Epoch [2/10], Loss: 1.729572206735611, Train Accuracy: 59.60%, Test Accuracy: 73.15%
Epoch [3/10], Loss: 1.0963037461042404, Train Accuracy: 75.30%, Test Accuracy: 73.15%
Epoch [4/10], Loss: 0.6987478137016296, Train Accuracy: 80.60%, Test Accuracy: 83.22%
Epoch [5/10], Loss: 0.519818589091301, Train Accuracy: 82.60%, Test Accuracy: 83.22%
Epoch [6/10], Loss: 0.41649414598941803, Train Accuracy: 86.70%, Test Accuracy: 88.21%
Epoch [7/10], Loss: 0.39179813116788864, Train Accuracy: 87.20%, Test Accuracy: 88.21%
Epoch [8/10], Loss: 0.39184504747390747, Train Accuracy: 86.60%, Test Accuracy: 88.57%
Epoch [9/10], Loss: 0.3639788553118706, Train Accuracy: 88.80%, Test Accuracy: 88.57%
Epoch [10/10], Loss: 0.3747972697019577, Train Accuracy: 88.80%, Test Accuracy: 89.46%


[I 2024-12-23 16:58:08,109] Trial 2 finished with value: 89.46 and parameters: {'lambda_kmeans': 1.1602777659091768e-05, 'lambda_consistency': 0.0036811660066349105}. Best is trial 2 with value: 89.46.


Epoch [1/10], Loss: 2.232264280319214, Train Accuracy: 21.70%, Test Accuracy: 45.28%
Epoch [2/10], Loss: 1.7366627752780914, Train Accuracy: 55.00%, Test Accuracy: 74.08%
Epoch [3/10], Loss: 1.113545373082161, Train Accuracy: 73.00%, Test Accuracy: 74.08%
Epoch [4/10], Loss: 0.7254999876022339, Train Accuracy: 80.60%, Test Accuracy: 83.27%
Epoch [5/10], Loss: 0.5456205978989601, Train Accuracy: 81.60%, Test Accuracy: 83.27%
Epoch [6/10], Loss: 0.44455868005752563, Train Accuracy: 85.10%, Test Accuracy: 87.32%
Epoch [7/10], Loss: 0.42366183549165726, Train Accuracy: 86.20%, Test Accuracy: 87.32%
Epoch [8/10], Loss: 0.4118703156709671, Train Accuracy: 87.00%, Test Accuracy: 88.35%
Epoch [9/10], Loss: 0.3888649120926857, Train Accuracy: 87.10%, Test Accuracy: 88.35%
Epoch [10/10], Loss: 0.38151922821998596, Train Accuracy: 87.30%, Test Accuracy: 88.66%


[I 2024-12-23 16:59:49,075] Trial 3 finished with value: 88.66 and parameters: {'lambda_kmeans': 0.0008378883008109513, 'lambda_consistency': 0.0026401803200526825}. Best is trial 2 with value: 89.46.


Epoch [1/10], Loss: 2.2831152081489563, Train Accuracy: 19.70%, Test Accuracy: 30.16%
Epoch [2/10], Loss: 1.875045359134674, Train Accuracy: 47.20%, Test Accuracy: 72.95%
Epoch [3/10], Loss: 1.305981159210205, Train Accuracy: 75.40%, Test Accuracy: 72.95%
Epoch [4/10], Loss: 0.86561319231987, Train Accuracy: 75.80%, Test Accuracy: 81.49%
Epoch [5/10], Loss: 0.6176142245531082, Train Accuracy: 80.20%, Test Accuracy: 81.49%
Epoch [6/10], Loss: 0.5113637521862984, Train Accuracy: 83.10%, Test Accuracy: 85.43%
Epoch [7/10], Loss: 0.49655669927597046, Train Accuracy: 84.40%, Test Accuracy: 85.43%
Epoch [8/10], Loss: 0.4935138002038002, Train Accuracy: 84.10%, Test Accuracy: 86.67%
Epoch [9/10], Loss: 0.45682963728904724, Train Accuracy: 85.60%, Test Accuracy: 86.67%
Epoch [10/10], Loss: 0.45345984399318695, Train Accuracy: 85.50%, Test Accuracy: 87.82%


[I 2024-12-23 17:01:29,761] Trial 4 finished with value: 87.82 and parameters: {'lambda_kmeans': 0.018819718238073685, 'lambda_consistency': 6.692366450568696e-05}. Best is trial 2 with value: 89.46.


Epoch [1/10], Loss: 2.297993779182434, Train Accuracy: 20.20%, Test Accuracy: 47.11%
Epoch [2/10], Loss: 1.8772067725658417, Train Accuracy: 46.20%, Test Accuracy: 70.81%
Epoch [3/10], Loss: 1.3026941120624542, Train Accuracy: 73.90%, Test Accuracy: 70.81%
Epoch [4/10], Loss: 0.8098151981830597, Train Accuracy: 78.30%, Test Accuracy: 83.58%
Epoch [5/10], Loss: 0.5619940608739853, Train Accuracy: 82.90%, Test Accuracy: 83.58%
Epoch [6/10], Loss: 0.4932525083422661, Train Accuracy: 84.00%, Test Accuracy: 86.27%
Epoch [7/10], Loss: 0.4561535269021988, Train Accuracy: 85.60%, Test Accuracy: 86.27%
Epoch [8/10], Loss: 0.44601666927337646, Train Accuracy: 85.30%, Test Accuracy: 87.76%
Epoch [9/10], Loss: 0.4311358481645584, Train Accuracy: 85.60%, Test Accuracy: 87.76%
Epoch [10/10], Loss: 0.432711124420166, Train Accuracy: 86.40%, Test Accuracy: 88.56%


[I 2024-12-23 17:03:10,463] Trial 5 finished with value: 88.56 and parameters: {'lambda_kmeans': 5.539315622786286e-05, 'lambda_consistency': 0.00013725499218143717}. Best is trial 2 with value: 89.46.


Epoch [1/10], Loss: 2.2072219848632812, Train Accuracy: 22.10%, Test Accuracy: 50.36%
Epoch [2/10], Loss: 1.7091401517391205, Train Accuracy: 60.50%, Test Accuracy: 78.29%
Epoch [3/10], Loss: 1.0706259906291962, Train Accuracy: 75.30%, Test Accuracy: 78.29%
Epoch [4/10], Loss: 0.710592195391655, Train Accuracy: 77.80%, Test Accuracy: 84.79%
Epoch [5/10], Loss: 0.5260713547468185, Train Accuracy: 83.50%, Test Accuracy: 84.79%
Epoch [6/10], Loss: 0.4239533469080925, Train Accuracy: 85.60%, Test Accuracy: 87.02%
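The objective above scores each trial on the test split, so the chosen hyperparameters and the reported best accuracy are optimistically biased. One way to avoid this, sketched below under the assumption that some labeled samples can be spared, is to hold out part of the labeled subset as a validation set with `torch.utils.data.random_split` and return validation accuracy from the objective instead. `make_validation_split` and the 20% fraction are illustrative choices, not part of the original notebook.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def make_validation_split(labeled_data, val_fraction=0.2, seed=0):
    """Split a labeled dataset into (train, validation) subsets reproducibly."""
    n_val = max(1, int(round(len(labeled_data) * val_fraction)))
    n_train = len(labeled_data) - n_val
    generator = torch.Generator().manual_seed(seed)  # fixed seed -> same split for every trial
    return random_split(labeled_data, [n_train, n_val], generator=generator)


# Toy labeled set standing in for the M = 1000 labeled MNIST samples
labeled_data = TensorDataset(torch.randn(1000, 1, 28, 28), torch.randint(0, 10, (1000,)))
train_part, val_part = make_validation_split(labeled_data)
print(len(train_part), len(val_part))  # 800 200
```

The trade-off is that 200 of the 1,000 labels no longer contribute to the supervised loss during the search; the test split is then used only once, to report the accuracy of the final configuration.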


<h1> Visualize Hyperparameter Optimization Results </h1>

In [None]:
# Optimization History Plot
history_plot = optuna.visualization.plot_optimization_history(study)
history_plot.show()

# Hyperparameter Importance Plot
importance_plot = optuna.visualization.plot_param_importances(study)
importance_plot.show()

# Parallel Coordinate Plot
parallel_plot = optuna.visualization.plot_parallel_coordinate(study)
parallel_plot.show()

# Slice Plot
slice_plot = optuna.visualization.plot_slice(study)
slice_plot.show()

# Contour Plot
contour_plot = optuna.visualization.plot_contour(study)
contour_plot.show()

# EDF Plot
edf_plot = optuna.visualization.plot_edf(study)
edf_plot.show()