In [2]:
# Import necessary libraries
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from torch.optim import Adam
import torch.nn.functional as F

# Set up project paths
def setup_paths():
    """Make the project's ``src`` directory importable from this notebook.

    Assumes the notebook is run from a direct subdirectory of the project
    root, so ``src`` lives at ``<cwd>/../src``.

    The directory is inserted at the *front* of ``sys.path``. The project
    modules have generic names (``model``, ``utils``); appending would let any
    same-named package installed in the environment shadow them.
    """
    project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
    src_path = os.path.join(project_root, "src")
    if src_path not in sys.path:
        sys.path.insert(0, src_path)

setup_paths()

In [3]:
# Import VAE models and utilities.
# m2_loss_unlabeled is required by the semi-supervised training call
# (train_full); presumably src/utils.py defines it next to m2_loss_labeled.
from model import MNISTVariationalAutoEncoder, CIFAR10VariationalAutoEncoder
from utils import m2_loss_labeled, m2_loss_unlabeled, validate

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
## Function to load dataset
def load_dataset(dataset_name, root="../data", latent_dim=128):
    """Build the train/test datasets and a matching, freshly constructed VAE.

    Args:
        dataset_name: exactly ``"MNIST"`` or ``"CIFAR10"`` (case-sensitive,
            the same spelling ``plot_logits_and_images`` checks for).
        root: directory torchvision downloads to / reads from.
        latent_dim: latent size passed to the VAE constructor.

    Returns:
        ``(train_dataset, test_dataset, model)``. The model is already moved
        to ``device``. Images go through ``ToTensor`` and then ``Normalize``
        with mean 0.5 / std 0.5 per channel.

    Raises:
        ValueError: for any other ``dataset_name``. This prevents a typo such
            as ``"mnist"`` or ``"CIFAR-10"`` from silently loading the wrong
            dataset.
    """
    # Validate and pick the per-dataset pieces first, so nothing is
    # downloaded or constructed for an unsupported name.
    if dataset_name == "MNIST":
        dataset_cls = datasets.MNIST
        model_cls = MNISTVariationalAutoEncoder
        norm_mean, norm_std = (0.5,), (0.5,)
    elif dataset_name == "CIFAR10":
        dataset_cls = datasets.CIFAR10
        model_cls = CIFAR10VariationalAutoEncoder
        norm_mean, norm_std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    else:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected 'MNIST' or 'CIFAR10'."
        )

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    dataset = dataset_cls(root=root, train=True, transform=transform, download=True)
    test_dataset = dataset_cls(root=root, train=False, transform=transform, download=True)
    model = model_cls(latent_dim=latent_dim, num_classes=10).to(device)

    return dataset, test_dataset, model

In [5]:
# Function to train the model on labeled data only
def train_labeled(model, loader, optimizer, criterion, num_epochs=15):
    """Train ``model`` on labeled batches only.

    Args:
        model: VAE exposing ``num_classes``. It is called as
            ``model(x, y_onehot=...)`` and must return
            ``(recon, mean, log_var, logits)``.
        loader: iterable of ``(x, y)`` batches with integer class labels ``y``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: called as ``criterion(recon, x, mean, log_var, logits, y)``.
            It must return a single-element tensor (``.backward()`` and
            ``.item()`` are called on it).
        num_epochs: number of passes over ``loader``.

    Prints the mean per-batch loss after each epoch.
    """
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            # Labeled data loss. F.one_hot keeps y's device, so no extra .to().
            y_onehot = F.one_hot(y, model.num_classes).float()
            recon, mean, log_var, logits = model(x, y_onehot=y_onehot)
            loss = criterion(recon, x, mean, log_var, logits, y)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Report the per-batch mean: the raw sum grows with the number of
        # batches and is not comparable across loaders or split sizes.
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Labeled Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")

In [6]:
# Helper + function to train the model on the full dataset (labeled + unlabeled)
def _endless_batches(loader):
    """Yield batches from ``loader`` indefinitely, starting a new pass when one ends.

    Unlike ``itertools.cycle``, which replays cached items, this re-iterates
    ``loader`` on every pass. A ``DataLoader`` with ``shuffle=True`` therefore
    draws a fresh order each time.

    Raises:
        ValueError: if a full pass over ``loader`` yields nothing (otherwise
            this generator would spin forever).
    """
    while True:
        yielded_any = False
        for batch in loader:
            yielded_any = True
            yield batch
        if not yielded_any:
            raise ValueError("labeled_loader yielded no batches")


def train_full(model, labeled_loader, unlabeled_loader, optimizer, criterion_labeled, criterion_unlabeled, num_epochs=15):
    """Semi-supervised training on labeled and unlabeled batches jointly.

    Each optimizer step pairs one labeled batch with one unlabeled batch and
    minimizes ``loss_labeled + loss_unlabeled``.

    One epoch is one full pass over ``unlabeled_loader``. Labeled batches are
    drawn from ``labeled_loader`` and restart whenever it runs out; the
    position carries over between epochs. Pairing the loaders with ``zip()``
    would stop at the shorter (labeled) loader. With a 10% labeled split, most
    of the unlabeled data would then never be used in an epoch.

    Args:
        model: VAE exposing ``num_classes``. It is called as
            ``model(x, y_onehot=...)`` for labeled data and ``model(x)`` for
            unlabeled data, each returning ``(recon, mean, log_var, logits)``.
        labeled_loader: iterable of ``(x, y)`` batches with integer labels.
        unlabeled_loader: iterable of ``(x, _)`` batches; the labels are ignored.
        optimizer: optimizer over ``model``'s parameters.
        criterion_labeled: ``(recon, x, mean, log_var, logits, y) -> loss``.
        criterion_unlabeled: ``(recon, x, mean, log_var, logits) -> loss``.
        num_epochs: number of passes over ``unlabeled_loader``.

    Prints the mean per-step loss after each epoch.

    Raises:
        ValueError: if ``labeled_loader`` yields no batches during training.
    """
    model.train()
    labeled_batches = _endless_batches(labeled_loader)
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_steps = 0
        for x_unlabeled, _ in unlabeled_loader:
            x_labeled, y_labeled = next(labeled_batches)
            x_labeled, y_labeled = x_labeled.to(device), y_labeled.to(device)
            x_unlabeled = x_unlabeled.to(device)

            optimizer.zero_grad()

            # Labeled data loss. F.one_hot keeps y_labeled's device.
            y_onehot = F.one_hot(y_labeled, model.num_classes).float()
            recon_labeled, mean_l, log_var_l, logits_l = model(x_labeled, y_onehot=y_onehot)
            loss_labeled = criterion_labeled(recon_labeled, x_labeled, mean_l, log_var_l, logits_l, y_labeled)

            # Unlabeled data loss: the model is called without labels.
            recon_unlabeled, mean_u, log_var_u, logits_u = model(x_unlabeled)
            loss_unlabeled = criterion_unlabeled(recon_unlabeled, x_unlabeled, mean_u, log_var_u, logits_u)

            # Combine losses with equal weight.
            loss = loss_labeled + loss_unlabeled
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_steps += 1
        # Per-step mean, so the value does not scale with the epoch length.
        avg_loss = total_loss / max(num_steps, 1)
        print(f"Full Dataset Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")

In [7]:
# Function to save the trained model
def save_model(model, path):
    """Save ``model.state_dict()`` to ``path``, creating parent directories.

    A bare filename (no directory component) is written into the current
    working directory.
    """
    parent_dir = os.path.dirname(path)
    # os.path.dirname("file.pth") is "" and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), path)
    print(f"Model saved to: {path}")

# Function to load the trained model
def load_model(model, path):
    """Load the state dict stored at ``path`` into ``model``.

    Tensors are mapped onto ``device`` while loading. ``torch.load`` unpickles
    the file, so only point this at trusted checkpoints.

    Returns:
        The same ``model`` object, now holding the loaded weights.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint not found at {path}. Ensure the model is trained and saved.")
    checkpoint_state = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint_state)
    print(f"Model loaded successfully from {path}")
    return model

In [8]:
# Function to extract logits and images
def extract_logits_and_images(model, data_loader):
    """Run ``model`` over ``data_loader`` without gradients and gather its outputs.

    The model is called without labels, and only the fourth element of its
    four-element output (the logits) is kept.

    Returns:
        ``(logits, labels, images)`` as NumPy arrays. Each is concatenated
        along the first axis in loader order. ``images`` are the inputs exactly
        as fed to the model, i.e. after the loader's transforms.

    Note: ``model`` is left in eval mode.
    """
    model.eval()
    logit_chunks, label_chunks, image_chunks = [], [], []
    with torch.no_grad():
        for batch_x, batch_y in data_loader:
            batch_x = batch_x.to(device)
            _, _, _, batch_logits = model(batch_x)
            logit_chunks.append(batch_logits.cpu().numpy())
            label_chunks.append(batch_y.numpy())
            image_chunks.append(batch_x.cpu().numpy())
    return tuple(np.concatenate(chunks) for chunks in (logit_chunks, label_chunks, image_chunks))

In [9]:
# Visualization functions
def plot_logits_and_images(images, logits, labels, dataset_name, num_samples=5):
    """Show random samples next to the softmax of their classifier logits.

    The top row shows each sampled image, titled with its label. The bottom
    row shows a bar chart of ``softmax(logits)`` for the same sample.

    Args:
        images: array of images in channel-first layout (as returned by
            ``extract_logits_and_images``).
        logits: array of shape ``(N, num_classes)``.
        labels: sequence of ``N`` labels.
        dataset_name: ``"MNIST"`` renders single-channel grayscale images.
            Any other value is treated as channel-first RGB (CIFAR-10).
        num_samples: number of random samples to show; capped at ``N``. Nothing
            is drawn when this ends up as 0.
    """
    # np.random.choice(..., replace=False) raises if asked for more samples
    # than exist, so cap at the population size.
    num_samples = min(num_samples, len(logits))
    if num_samples <= 0:
        return
    probabilities = torch.softmax(torch.tensor(logits), dim=1).numpy()
    sample_indices = np.random.choice(len(logits), num_samples, replace=False)

    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(2, num_samples, figsize=(15, 8), squeeze=False)
    for col, idx in enumerate(sample_indices):
        image_ax, prob_ax = axes[0, col], axes[1, col]
        img = images[idx]
        if dataset_name == "MNIST":
            image_ax.imshow(img.squeeze(), cmap="gray")
        else:
            img = np.transpose(img, (1, 2, 0))
            # Min-max rescale into [0, 1] for imshow. A constant image would
            # otherwise divide by zero and render as NaNs.
            value_range = img.max() - img.min()
            img = (img - img.min()) / value_range if value_range > 0 else np.zeros_like(img)
            image_ax.imshow(img)
        image_ax.axis("off")
        image_ax.set_title(f"Label: {labels[idx]}")
        prob_ax.bar(range(probabilities.shape[1]), probabilities[idx])
        prob_ax.set_title("Softmax of Logits")
        prob_ax.set_xlabel("Class")
        prob_ax.set_ylabel("Probability")
    fig.tight_layout()
    plt.show()

In [10]:
def visualize_latent_space(logits, labels, method="PCA"):
    """Project ``logits`` to 2-D and scatter-plot the points colored by label.

    Args:
        logits: 2-D array-like of per-sample vectors to embed. NOTE(review):
            despite the function name, this notebook passes classifier logits
            here, not the VAE latent codes.
        labels: per-sample class labels, used as colors (``tab10`` colormap).
        method: ``"PCA"`` or ``"TSNE"`` (t-SNE with a fixed ``random_state=42``).

    Raises:
        ValueError: if ``method`` is not one of the two supported names.
    """
    if method not in ("PCA", "TSNE"):
        raise ValueError("Invalid method. Choose 'PCA' or 'TSNE'.")
    reducer = PCA(n_components=2) if method == "PCA" else TSNE(n_components=2, random_state=42)
    embedded = reducer.fit_transform(logits)

    fig, ax = plt.subplots(figsize=(8, 8))
    points = ax.scatter(embedded[:, 0], embedded[:, 1], c=labels, cmap='tab10', alpha=0.7)
    fig.colorbar(points, ax=ax, label='Class Label')
    ax.set_title(f"Latent Space Visualization ({method})")
    ax.set_xlabel(f"{method} Component 1")
    ax.set_ylabel(f"{method} Component 2")
    plt.show()

In [12]:
# Reproducibility: seed torch's global RNG (model init, DataLoader shuffling)
# and give random_split its own seeded generator. The labeled subset is then
# identical on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

dataset_name = "MNIST"  # Change to "CIFAR10" for CIFAR-10
dataset, test_dataset, vae_model = load_dataset(dataset_name)

# Split the training set: 10% keeps its labels, the other 90% is used unlabeled.
labeled_size = int(len(dataset) * 0.1)
unlabeled_size = len(dataset) - labeled_size
labeled_data, unlabeled_data = random_split(
    dataset, [labeled_size, unlabeled_size], generator=torch.Generator().manual_seed(SEED)
)
labeled_loader = DataLoader(labeled_data, batch_size=128, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# Phase 1: train on labeled data only and checkpoint the result.
optimizer = Adam(vae_model.parameters(), lr=1e-3)
train_labeled(vae_model, labeled_loader, optimizer, m2_loss_labeled, num_epochs=15)
save_model(vae_model, "../trained_models/vae_labeled_only.pth")

# Phase 2: continue from the same in-memory weights (and optimizer state) on
# labeled + unlabeled data. vae_model already holds the labeled-only weights
# just saved, so reloading them from disk is unnecessary.
# NOTE: m2_loss_unlabeled must be imported from utils in an earlier cell.
train_full(vae_model, labeled_loader, unlabeled_loader, optimizer, m2_loss_labeled, m2_loss_unlabeled, num_epochs=15)
save_model(vae_model, "../trained_models/vae_full_trained.pth")

ValueError: too many values to unpack (expected 3)

In [None]:
# Analyze the checkpoint trained on labeled data only: load it, collect the
# classifier logits on the test set, then plot samples and 2-D projections.
vae_model = load_model(model=vae_model, path="../trained_models/vae_labeled_only.pth")
logits, labels, images = extract_logits_and_images(model=vae_model, data_loader=test_loader)
plot_logits_and_images(images=images, logits=logits, labels=labels, dataset_name=dataset_name)
visualize_latent_space(logits=logits, labels=labels, method="PCA")
visualize_latent_space(logits=logits, labels=labels, method="TSNE")

In [None]:
# Analyze the checkpoint after semi-supervised (labeled + unlabeled) training:
# load it, collect the classifier logits on the test set, then plot samples
# and 2-D projections.
vae_model = load_model(vae_model, "../trained_models/vae_full_trained.pth")
logits, labels, images = extract_logits_and_images(vae_model, test_loader)
plot_logits_and_images(images, logits, labels, dataset_name)
visualize_latent_space(logits, labels, method="PCA")
visualize_latent_space(logits, labels, method="TSNE")