<h1> Libraries & Environment Setup </h1>

In [None]:
import os  # Environment management
import numpy as np  # Numpy
import random  # Determinism
import torch  # PyTorch
import torch.nn as nn  # Neural network module
import torch.optim as optim  # Optimizers
import matplotlib.pyplot as plt  # Plotting
import matplotlib.animation as animation # Animation
import warnings  # Silence some sklearn warnings
import optuna # Hyperparameter optimization

from collections import Counter  # Counting
from torchvision import datasets, transforms  # Datasets and transformations
from sklearn.cluster import KMeans  # KMeans clustering algorithm
from sklearn.model_selection import train_test_split  # Train-test split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay  # Confusion matrix
from sklearn.manifold import TSNE  # t-SNE
from tqdm.notebook import tqdm  # Progress bars

from models import CNN  # Neural network class
from trainers import KMeansConsistencyTrainer  # Training function
from datasets import get_mnist_loaders, split_dataset  # Dataset functions

# Suppress every warning for the rest of the session (mostly noisy sklearn ones)
warnings.filterwarnings(action="ignore")
# Pin OpenMP to a single thread as a workaround for an sklearn issue.
# NOTE(review): this runs after numpy/sklearn/torch have been imported --
# confirm the OpenMP runtime still picks the value up at this point.
os.environ.update({"OMP_NUM_THREADS": "1"})

# Run on the GPU when CUDA is present, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Seed every random number generator in use so runs can be repeated
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``seed``.

    Also switches cuDNN to deterministic mode and turns its benchmark mode off.
    Returns nothing.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN flags: deterministic on, benchmark off
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

<h1> MNIST Data Preparation </h1>

In [2]:
# Load MNIST dataset: get_mnist_loaders comes from the local `datasets` module and
# its result is unpacked into a train and a test loader. test_loader is reused by
# experiment() for training-time evaluation, the confusion matrix and t-SNE.
train_loader, test_loader = get_mnist_loaders(batch_size=128)

# Keep a handle on the dataset behind the training loader. The labeled/unlabeled
# split is not done here: experiment() calls split_dataset(mnist_train, M) per M.
mnist_train = train_loader.dataset
# Earlier fixed-size split, kept for reference (superseded by the per-M split in experiment()):
# labeled_data, unlabeled_data = split_dataset(mnist_train, num_labeled=250)

# Earlier stand-alone loaders for that fixed split (experiment() now builds its own):
# labeled_loader = torch.utils.data.DataLoader(labeled_data, batch_size=128, shuffle=True, num_workers=0)
# unlabeled_loader = torch.utils.data.DataLoader(unlabeled_data, batch_size=128, shuffle=True, num_workers=0)

<h1> Training Model </h1>

In [3]:
# Experiment setup
def experiment(values_of_M, epochs=10, evaluate_every=1, lambda_kmeans=0.1, lambda_consistency=0.1, use_consistency=False, use_unlabeled=True, save_dir="tsne_images", output_file="tsne_animation.gif", generate_tsne=True, generate_cm=True):
    """Train one fresh model per labeled-set size M and collect its accuracy curves.

    For every M in ``values_of_M`` a new CNN and KMeansConsistencyTrainer are
    created, ``mnist_train`` is split with ``split_dataset(mnist_train, M)``, and
    the trainer is run against the notebook-level ``test_loader``.

    Args:
        values_of_M: iterable of labeled-set sizes; each is passed to split_dataset.
        epochs: forwarded to ``trainer.train``.
        evaluate_every: forwarded to ``trainer.train``.
        lambda_kmeans, lambda_consistency, use_consistency: accepted for interface
            compatibility. NOTE(review): nothing in this function forwards them to
            the model or the trainer, so they currently have no effect here.
            Confirm how KMeansConsistencyTrainer expects to receive them.
        use_unlabeled: when False, ``unlabeled_loader=None`` is given to the trainer.
        save_dir: directory used for the per-M t-SNE images and the animation.
        output_file: file name of the animation written inside ``save_dir``.
        generate_tsne: draw a t-SNE plot per M and build the animation at the end.
        generate_cm: show a confusion matrix on the test set per M.

    Returns:
        dict mapping each M to the ``(train_accs, test_accs)`` pair returned by
        ``trainer.train``.
    """
    results = {}
    for M in values_of_M:
        # Fresh model and trainer for every M so runs do not share weights
        model = CNN(use_dropout=True, dropout_rate=0.3)
        trainer = KMeansConsistencyTrainer(model, device=device)

        # Split the training set and wrap both parts in loaders
        labeled_data, unlabeled_data = split_dataset(mnist_train, M)
        labeled_loader = torch.utils.data.DataLoader(labeled_data, batch_size=256, shuffle=True, num_workers=0)
        unlabeled_loader = (
            torch.utils.data.DataLoader(unlabeled_data, batch_size=256, shuffle=True, num_workers=0)
            if use_unlabeled
            else None
        )

        # Train and log results
        train_accs, test_accs = trainer.train(
            labeled_loader=labeled_loader,
            unlabeled_loader=unlabeled_loader,
            test_loader=test_loader,
            epochs=epochs,
            evaluate_every=evaluate_every,
        )
        results[M] = (train_accs, test_accs)

        # Final evaluation; only the predictions and labels are used below
        _, preds, labels = trainer.evaluate(test_loader)
        if generate_cm:
            cm = confusion_matrix(labels, preds)
            disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=range(10))
            disp.plot(cmap="viridis")
            plt.title(f"Confusion Matrix for M={M}")
            plt.show()

        if generate_tsne:
            # Write the frame into the same directory create_animation() reads
            # from; without save_dir the default directory would be used instead.
            plot_tsne(model, M, test_loader, save_dir=save_dir)

    if generate_tsne:
        create_animation(save_dir, output_file)
    return results

# Plot training curves
def plot_training_curves(results, values_of_M):
    """Draw one test-accuracy line per M on a single figure and show it.

    ``results`` maps each M to a ``(train_accs, test_accs)`` pair; only the test
    curve is plotted, against its index in the list.
    """
    plt.figure(figsize=(12, 6))
    axes = plt.gca()

    # One labelled line per requested M
    for M in values_of_M:
        _, test_curve = results[M]
        axes.plot(test_curve, label=f"M={M} (Test Accuracy)")

    axes.set_title("Test Accuracy vs. Epochs for Different Values of M")
    axes.set_xlabel("Epoch")
    axes.set_ylabel("Test Accuracy (%)")
    axes.set_ylim(0, 100)
    axes.legend()
    axes.grid()
    plt.show()

# Plot t-SNE visualization
def plot_tsne(model, M, data_loader, save_dir="tsne_images"):
    """Project the model's features for ``data_loader`` to 2-D with t-SNE and plot them.

    The scatter plot is coloured by label, written to
    ``<save_dir>/tsne_M_<M>.png`` (the directory is created when missing), shown,
    and then closed. The model is put in eval mode and run under ``torch.no_grad()``.
    Relies on ``model(images)`` returning a pair whose second item is the feature
    batch; the first item is ignored.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    model.eval()
    feature_batches, label_batches = [], []
    with torch.no_grad():
        for batch_images, batch_labels in data_loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            _, batch_features = model(batch_images)
            feature_batches.append(batch_features.cpu().numpy())
            label_batches.append(batch_labels.cpu().numpy())

    # Stack the per-batch arrays into one array each
    all_features = np.concatenate(feature_batches)
    all_labels = np.concatenate(label_batches)

    embedding = TSNE(n_components=2, random_state=0).fit_transform(all_features)
    xs, ys = embedding[:, 0], embedding[:, 1]

    plt.figure(figsize=(6, 6))
    points = plt.scatter(xs, ys, c=all_labels, cmap='viridis', alpha=0.5)
    plt.colorbar(points)
    plt.title(f't-SNE Visualization for M={M}')
    image_path = os.path.join(save_dir, f'tsne_M_{M}.png')
    plt.savefig(image_path)
    plt.show()  # Display the plot in the output
    plt.close()

def create_animation(save_dir="tsne_images", output_file="tsne_animation.gif"):
    """Stitch the ``tsne_M_<M>.png`` frames in ``save_dir`` into one animation.

    Frames are ordered by the integer M parsed from the file name, and the
    result is saved as ``<save_dir>/<output_file>`` with the 'imagemagick' writer.

    Args:
        save_dir: directory holding the frames; created if it does not exist.
        output_file: file name of the animation inside ``save_dir``.

    Returns:
        None. When ``save_dir`` holds no matching frames, nothing is written.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Extract numeric values of M from filenames and sort them
    frame_files = sorted(
        (f for f in os.listdir(save_dir) if f.startswith('tsne_M_') and f.endswith('.png')),
        key=lambda name: int(name.split('_')[2].split('.')[0]),
    )
    if not frame_files:
        # Nothing to animate: skip before a figure is created or a writer is invoked
        print(f"No t-SNE frames found in '{save_dir}'; skipping animation.")
        return

    fig = plt.figure(figsize=(6, 6))
    try:
        frames = [
            [plt.imshow(plt.imread(os.path.join(save_dir, name)), animated=True)]
            for name in frame_files
        ]
        ani = animation.ArtistAnimation(fig, frames, interval=500, blit=True, repeat_delay=1000)
        ani.save(os.path.join(save_dir, output_file), writer='imagemagick')
    finally:
        # Close this specific figure even if saving fails, so it is not leaked
        plt.close(fig)


<h1> Hyperparameter Optimization with Optuna </h1>

In [None]:
def objective(trial):
    """Optuna objective: best test accuracy reached for one sampled pair of loss weights.

    Samples ``lambda_kmeans`` in [1e-5, 1e-1] and ``lambda_consistency`` in
    [1e-6, 1e-2] on a log scale, runs ``experiment`` for a single labeled-set
    size (M = 100) with plotting disabled, and returns the maximum of the test
    accuracies it reports.

    NOTE(review): ``experiment`` as defined in this notebook accepts the lambda
    values but does not pass them on to the trainer -- confirm that the sampled
    values actually influence training.
    """
    # Define the hyperparameter search space.
    # suggest_float(..., log=True) replaces the deprecated suggest_loguniform.
    lambda_kmeans = trial.suggest_float('lambda_kmeans', 1e-5, 1e-1, log=True)
    lambda_consistency = trial.suggest_float('lambda_consistency', 1e-6, 1e-2, log=True)

    # Use a fixed value of M for the optimization
    M = 100  # You can change this value as needed

    # Run the experiment with the current hyperparameters
    results = experiment(
        values_of_M=[M],
        epochs=10,
        evaluate_every=2,
        lambda_kmeans=lambda_kmeans,
        lambda_consistency=lambda_consistency,
        use_consistency=True,
        use_unlabeled=True,
        generate_tsne=False,
        generate_cm=False,
    )

    # results[M] is (train_accs, test_accs); score the trial by its best test accuracy
    test_accs = results[M][1]
    return max(test_accs)

# Create an Optuna study and optimize the objective function
SEED = 7 * 5 * 3 * 2 * 2  # = 420
set_seed(SEED)

study_name = "Optimize_KMeans_and_Consistency_Losses"
# Optuna's sampler keeps its own random state, so the global seeds set by
# set_seed() do not make the search repeatable; seed the sampler explicitly.
study = optuna.create_study(
    study_name=study_name,
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=SEED),
)
study.optimize(objective, n_trials=200)

# Print the best hyperparameters
print("Best hyperparameters: ", study.best_params)

<h1> Visualize Hyperparameter Optimization Results </h1>

In [None]:
# Render Optuna's built-in diagnostic figures for the `study` produced by the
# optimization cell. Each section builds one figure from `study`, binds it to a
# notebook-level name, and displays it with .show().
# NOTE(review): the optuna.visualization helpers presumably need plotly to be
# installed -- confirm it is available in this environment.

# Optimization History Plot
history_plot = optuna.visualization.plot_optimization_history(study)
history_plot.show()

# Hyperparameter Importance Plot
importance_plot = optuna.visualization.plot_param_importances(study)
importance_plot.show()

# Parallel Coordinate Plot
parallel_plot = optuna.visualization.plot_parallel_coordinate(study)
parallel_plot.show()

# Slice Plot
slice_plot = optuna.visualization.plot_slice(study)
slice_plot.show()

# Contour Plot
contour_plot = optuna.visualization.plot_contour(study)
contour_plot.show()

# EDF Plot (empirical distribution function of the objective values)
edf_plot = optuna.visualization.plot_edf(study)
edf_plot.show()