In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
import time
from tqdm import tqdm
import os

In [None]:
# Seed NumPy and PyTorch (CPU and the current CUDA device) so runs start from
# the same random state.
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# Ask cuDNN for deterministic algorithms and turn off its benchmark-mode
# autotuner; this trades some speed for run-to-run reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Load the FashionMNIST dataset (downloaded into ./data on first use).
# ToTensor converts each image to a float tensor scaled to [0, 1], which the
# binary cross-entropy reconstruction loss of the VAE below relies on.
transform = transforms.ToTensor()
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Split labeled data evenly across classes
def split_labeled_data(dataset, num_labels, seed=42, num_classes=10):
    """Pick a class-balanced labeled subset and return it with the remaining indices.

    Args:
        dataset: Object exposing ``targets`` (one integer class label per
            sample, convertible with ``np.asarray``) and ``__len__``.
        num_labels: Total number of labeled samples requested. Every class
            receives ``num_labels // num_classes`` samples, so a remainder that
            does not divide evenly is dropped.
        seed: Seed of the private random generator used for sampling.
        num_classes: Number of classes; labels are expected to be the integers
            ``0 .. num_classes - 1``.

    Returns:
        ``(labeled_indices, unlabeled_indices)``: two lists of Python ints.
        ``labeled_indices`` is grouped by class in sampling order;
        ``unlabeled_indices`` holds every other dataset index in ascending order.

    Raises:
        ValueError: Raised by the sampler when a class has fewer samples than
            the per-class quota.
    """
    # A private RandomState draws exactly what `np.random.seed(seed)` followed
    # by `np.random.choice` would draw, but leaves the global NumPy RNG alone.
    rng = np.random.RandomState(seed)
    labels = np.asarray(dataset.targets)
    per_class = num_labels // num_classes

    labeled_indices = []
    for class_id in range(num_classes):
        class_indices = np.where(labels == class_id)[0]
        chosen = rng.choice(class_indices, per_class, replace=False)
        # Store plain ints so the indices work with any indexable dataset.
        labeled_indices.extend(int(idx) for idx in chosen)

    # setdiff1d returns the complement sorted, giving a reproducible order.
    unlabeled_indices = np.setdiff1d(
        np.arange(len(dataset)),
        np.asarray(labeled_indices, dtype=np.int64),
    ).tolist()
    return labeled_indices, unlabeled_indices

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:03<00:00, 7932572.19it/s] 


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 170071.74it/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:03<00:00, 1421193.75it/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 23019485.07it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw






In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder with a diagonal-Gaussian latent.

    The encoder and decoder each use two softplus hidden layers. The decoder
    ends in a sigmoid, so reconstructions lie in (0, 1) and are scored with
    binary cross-entropy.

    Args:
        input_dim: Size of a flattened input sample.
        hidden_dim: Width of every hidden layer.
        latent_dim: Dimensionality of the latent code.
    """

    def __init__(self, input_dim=784, hidden_dim=635, latent_dim=10):
        super().__init__()
        # Remembered so loss_function can flatten targets to the right width.
        self.input_dim = input_dim

        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc3_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc4 = nn.Linear(latent_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, hidden_dim)
        self.fc6 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map already-flattened inputs to the posterior mean and log-variance."""
        h1 = F.softplus(self.fc1(x))
        h2 = F.softplus(self.fc2(h1))
        mu = self.fc3_mu(h2)
        logvar = self.fc3_logvar(h2)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from a standard normal."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to per-feature values in (0, 1)."""
        h3 = F.softplus(self.fc4(z))
        h4 = F.softplus(self.fc5(h3))
        return torch.sigmoid(self.fc6(h4))

    def forward(self, x):
        """Flatten ``x`` per sample, encode, sample, and decode.

        Returns:
            ``(reconstruction, mu, logvar)``.
        """
        x = x.view(x.size(0), -1)  # flatten everything but the batch dimension
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def loss_function(self, recon_x, x, mu, logvar):
        """Return the summed negative ELBO for a batch.

        The loss is the binary cross-entropy between ``recon_x`` and the
        flattened ``x`` plus the closed-form KL divergence between the
        posterior ``N(mu, exp(logvar))`` and a standard normal prior. Both
        terms are summed over the batch, not averaged.
        """
        # Flatten with the configured width instead of a hard-coded 784 so
        # that models built with a different input_dim work too.
        target = x.view(-1, self.input_dim)
        BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def save_vae_weights(vae, num_labels, save_dir='/content/vae_weights'):
    """Save the model's ``state_dict`` to ``<save_dir>/vae_weights_<num_labels>_labels.pth``.

    Args:
        vae: Module whose ``state_dict()`` is written.
        num_labels: Label budget of the run; only used in the file name.
        save_dir: Destination directory, created (with parents) if missing.
            The default keeps the original Colab location.

    Returns:
        The path of the written file.
    """
    # exist_ok avoids the check-then-create race and handles nested paths.
    os.makedirs(save_dir, exist_ok=True)

    save_path = os.path.join(save_dir, f'vae_weights_{num_labels}_labels.pth')
    torch.save(vae.state_dict(), save_path)
    print(f"VAE weights saved to {save_path}")
    return save_path

def print_data_sizes(num_labels):
    """Print how many samples land in each split for the given label budget.

    The labeled/unlabeled split of ``train_dataset`` is recomputed with
    ``split_labeled_data``; the test and full training set sizes are read from
    the module-level datasets.
    """
    labeled_indices, unlabeled_indices = split_labeled_data(train_dataset, num_labels=num_labels)

    # Build the whole report first, then emit it one line per print call.
    report = [
        f"\nData sizes for {num_labels} labels:",
        f"  Labeled data size: {len(labeled_indices)}",
        f"  Unlabeled data size: {len(unlabeled_indices)}",
        f"  Test data size: {len(test_dataset)}",
        f"  Total training data size: {len(train_dataset)}",
    ]
    for line in report:
        print(line)

def train_vae(vae, labeled_loader, unlabeled_loader, optimizer, num_labels, num_epochs=500):
    """Train the VAE on paired labeled/unlabeled batches with the VAE loss only.

    Class labels from both loaders are ignored; each step sums the VAE loss of
    one labeled batch and one unlabeled batch and takes a single optimizer
    step. Because the loaders are consumed with ``zip``, an epoch ends as soon
    as the shorter loader is exhausted.

    Args:
        vae: Model providing ``forward`` and ``loss_function``.
        labeled_loader: Loader yielding ``(data, label)`` batches.
        unlabeled_loader: Loader yielding ``(data, label)`` batches.
        optimizer: Optimizer over ``vae``'s parameters.
        num_labels: Label budget, used only in the progress-bar description.
        num_epochs: Number of passes over the zipped loaders.

    Returns:
        A list with one entry per epoch: the summed loss divided by the number
        of samples processed in that epoch (``nan`` if no batch was seen).
    """
    vae.train()
    epoch_losses = []
    for _ in tqdm(range(num_epochs), desc=f'Training VAE with {num_labels} labels'):
        loss_sum = 0.0
        samples_seen = 0
        for (data, _), (unlabeled_data, _) in zip(labeled_loader, unlabeled_loader):
            data = data.to(device)
            unlabeled_data = unlabeled_data.to(device)

            # Flatten the input
            data = data.view(data.size(0), -1)
            unlabeled_data = unlabeled_data.view(unlabeled_data.size(0), -1)

            optimizer.zero_grad()

            # Forward pass for labeled data
            recon_batch, mu, logvar = vae(data)
            loss = vae.loss_function(recon_batch, data, mu, logvar)

            # Forward pass for unlabeled data
            recon_unlabeled, mu_unlabeled, logvar_unlabeled = vae(unlabeled_data)
            loss_unlabeled = vae.loss_function(recon_unlabeled, unlabeled_data, mu_unlabeled, logvar_unlabeled)

            # Both terms are weighted equally.
            total_loss = loss + loss_unlabeled
            total_loss.backward()
            optimizer.step()

            loss_sum += total_loss.item()
            # Normalise by what was actually processed: the loss covers both
            # batches, and zip may stop before either dataset is exhausted.
            samples_seen += data.size(0) + unlabeled_data.size(0)

        epoch_losses.append(loss_sum / samples_seen if samples_seen else float('nan'))

    return epoch_losses


# Extract latent representations
def extract_latent_representations(vae, data_loader):
    """Encode every batch of ``data_loader`` and collect the posterior means.

    The model is switched to eval mode and gradients are disabled. Only the
    mean returned by ``vae.encode`` is kept; the log-variance is discarded.

    Returns:
        ``(z, y)``: NumPy arrays with the latent means and the matching labels,
        concatenated in the order the loader yields them.
    """
    vae.eval()
    latent_means = []
    targets = []
    with torch.no_grad():
        for batch, batch_labels in data_loader:
            flat_batch = batch.view(batch.size(0), -1).to(device)
            batch_mu, _logvar = vae.encode(flat_batch)
            latent_means.append(batch_mu.cpu())
            targets.append(batch_labels.cpu())

    z = torch.cat(latent_means).numpy()
    y = torch.cat(targets).numpy()
    return z, y

In [None]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def visualize_latent_space(z_train, y_train):
    """Scatter-plot the first two principal components of the latent codes.

    Args:
        z_train: 2-D array of latent vectors, one row per sample.
        y_train: Per-sample class labels, used to colour the points.
    """
    pca = PCA(n_components=2)
    z_pca = pca.fit_transform(z_train)

    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(z_pca[:, 0], z_pca[:, 1], c=y_train, cmap='viridis', alpha=0.5)
    fig.colorbar(scatter, ax=ax, label='Class label')
    ax.set_title('PCA of Latent Space')
    # The axes are PCA projections, not individual latent coordinates.
    ax.set_xlabel('Principal Component 1')
    ax.set_ylabel('Principal Component 2')
    plt.show()

In [None]:
def run_experiment(num_labels, batch_size=64, num_epochs=500, lr=1e-3, visualize=False):
    """Train a VAE, classify its latent codes with an SVM, and report test accuracy.

    Steps: split the training set into a class-balanced labeled subset and an
    unlabeled remainder, train a fresh VAE on both, encode the labeled and test
    samples, fit an RBF-kernel SVM on the labeled latent means, print the test
    accuracy and classification report, and save the VAE weights.

    Args:
        num_labels: Total number of labeled training samples.
        batch_size: Batch size for all three data loaders.
        num_epochs: Number of VAE training epochs.
        lr: Adam learning rate (kept constant for the whole run).
        visualize: When true, plot a PCA projection of the labeled latent codes.

    Returns:
        Test accuracy as a fraction in ``[0, 1]``.
    """
    print_data_sizes(num_labels)
    labeled_indices, unlabeled_indices = split_labeled_data(train_dataset, num_labels=num_labels)

    labeled_subset = Subset(train_dataset, labeled_indices)
    unlabeled_subset = Subset(train_dataset, unlabeled_indices)

    labeled_loader = DataLoader(labeled_subset, batch_size=batch_size, shuffle=True)
    unlabeled_loader = DataLoader(unlabeled_subset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize and train VAE
    vae = VAE().to(device)
    # Adam runs at a constant learning rate; no scheduler is stepped during
    # training, so none is constructed here.
    optimizer = torch.optim.Adam(vae.parameters(), lr=lr)

    train_vae(vae, labeled_loader, unlabeled_loader, optimizer, num_labels, num_epochs=num_epochs)

    # Extract latent representations
    z_train, y_train = extract_latent_representations(vae, labeled_loader)
    z_test, y_test = extract_latent_representations(vae, test_loader)

    if visualize:
        visualize_latent_space(z_train, y_train)

    clf = svm.SVC(kernel='rbf', random_state=42)  # sigmoid yields bad accuracy
    clf.fit(z_train, y_train)

    # Test the classifier
    y_pred = clf.predict(z_test)
    accuracy = accuracy_score(y_test, y_pred)
    print(f'Test Accuracy with {num_labels} labels: {accuracy * 100:.2f}%')

    print(classification_report(y_test, y_pred))

    # Persist the trained encoder/decoder for this label budget.
    save_vae_weights(vae, num_labels)
    return accuracy


# Run experiments for different label sizes.
# Each call trains a fresh VAE from scratch, so this loop is the expensive part
# of the notebook (see the progress-bar timings in the output below).
for label_size in [100, 600, 1000, 3000]:
    run_experiment(label_size)


Data sizes for 100 labels:
  Labeled data size: 100
  Unlabeled data size: 59900
  Test data size: 10000
  Total training data size: 60000


Training VAE with 100 labels: 100%|██████████| 500/500 [00:19<00:00, 25.18it/s]


Test Accuracy with 100 labels: 64.82%
              precision    recall  f1-score   support

           0       0.70      0.74      0.72      1000
           1       0.98      0.81      0.88      1000
           2       0.39      0.48      0.43      1000
           3       0.64      0.62      0.63      1000
           4       0.49      0.61      0.54      1000
           5       0.83      0.41      0.55      1000
           6       0.34      0.26      0.29      1000
           7       0.62      0.90      0.73      1000
           8       0.89      0.80      0.84      1000
           9       0.81      0.86      0.83      1000

    accuracy                           0.65     10000
   macro avg       0.67      0.65      0.65     10000
weighted avg       0.67      0.65      0.65     10000

VAE weights saved to /content/vae_weights/vae_weights_100_labels.pth

Data sizes for 600 labels:
  Labeled data size: 600
  Unlabeled data size: 59400
  Test data size: 10000
  Total training data size: 60000

Training VAE with 600 labels: 100%|██████████| 500/500 [01:32<00:00,  5.38it/s]


Test Accuracy with 600 labels: 76.32%
              precision    recall  f1-score   support

           0       0.73      0.80      0.76      1000
           1       0.99      0.90      0.94      1000
           2       0.65      0.57      0.61      1000
           3       0.71      0.87      0.78      1000
           4       0.58      0.61      0.59      1000
           5       0.88      0.79      0.83      1000
           6       0.46      0.39      0.43      1000
           7       0.81      0.90      0.85      1000
           8       0.93      0.90      0.91      1000
           9       0.87      0.91      0.89      1000

    accuracy                           0.76     10000
   macro avg       0.76      0.76      0.76     10000
weighted avg       0.76      0.76      0.76     10000

VAE weights saved to /content/vae_weights/vae_weights_600_labels.pth

Data sizes for 1000 labels:
  Labeled data size: 1000
  Unlabeled data size: 59000
  Test data size: 10000
  Total training data size: 60000

Training VAE with 1000 labels: 100%|██████████| 500/500 [02:27<00:00,  3.40it/s]


Test Accuracy with 1000 labels: 76.47%
              precision    recall  f1-score   support

           0       0.72      0.81      0.76      1000
           1       0.99      0.92      0.95      1000
           2       0.67      0.55      0.60      1000
           3       0.77      0.81      0.79      1000
           4       0.56      0.70      0.62      1000
           5       0.87      0.79      0.83      1000
           6       0.45      0.37      0.41      1000
           7       0.81      0.90      0.85      1000
           8       0.94      0.90      0.92      1000
           9       0.88      0.90      0.89      1000

    accuracy                           0.76     10000
   macro avg       0.76      0.76      0.76     10000
weighted avg       0.76      0.76      0.76     10000

VAE weights saved to /content/vae_weights/vae_weights_1000_labels.pth

Data sizes for 3000 labels:
  Labeled data size: 3000
  Unlabeled data size: 57000
  Test data size: 10000
  Total training data size: 60000

Training VAE with 3000 labels: 100%|██████████| 500/500 [07:00<00:00,  1.19it/s]


Test Accuracy with 3000 labels: 79.90%
              precision    recall  f1-score   support

           0       0.76      0.75      0.75      1000
           1       0.98      0.94      0.96      1000
           2       0.69      0.66      0.67      1000
           3       0.79      0.84      0.82      1000
           4       0.63      0.72      0.67      1000
           5       0.94      0.82      0.88      1000
           6       0.52      0.46      0.49      1000
           7       0.86      0.91      0.89      1000
           8       0.95      0.95      0.95      1000
           9       0.86      0.95      0.90      1000

    accuracy                           0.80     10000
   macro avg       0.80      0.80      0.80     10000
weighted avg       0.80      0.80      0.80     10000

VAE weights saved to /content/vae_weights/vae_weights_3000_labels.pth


In [None]:
from google.colab import files
import shutil

# Zip the directory the weights were saved to and hand the archive to the
# browser for download. `files` comes from google.colab, so this cell only
# works inside Colab.
shutil.make_archive('vae_weights', 'zip', '/content/vae_weights')
files.download('vae_weights.zip')


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>