In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from sklearn.mixture import GaussianMixture
from sklearn.metrics import mean_squared_error, accuracy_score, adjusted_rand_score, normalized_mutual_info_score
import numpy as np
import json

# 1. Load CIFAR-10 Dataset
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 per
# channel then shifts them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Both splits share the same root directory and preprocessing; the training
# split is created first.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(
        root='./data', train=is_train, download=True, transform=transform
    )
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=128, shuffle=do_shuffle, num_workers=2)
    for dataset, do_shuffle in ((train_dataset, True), (test_dataset, False))
)

# 2. Define CNN Autoencoder
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a 512-dimensional bottleneck.

    The encoder applies two stride-2 convolutions (3 -> 64 -> 128 channels),
    flattens the result to 8192 values and projects it to a 512-dimensional
    latent vector. The layer sizes (8192 = 128 * 8 * 8) imply 3x32x32 inputs.
    The decoder mirrors the encoder and ends in ``Tanh`` so reconstructions
    lie in [-1, 1], the same range as inputs normalised with mean 0.5 and
    std 0.5 per channel.
    """

    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            nn.Flatten(),
            nn.Linear(8192, 512), nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(512, 8192), nn.ReLU(),
            nn.Unflatten(1, (128, 8, 8)),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(),
            # Tanh, not Sigmoid: a Sigmoid output is confined to [0, 1] and
            # can never reproduce the negative half of a [-1, 1] input, which
            # puts a floor under the reconstruction loss.
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1), nn.Tanh()
        )

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Args:
            x: Batch of images; the layer sizes imply shape (N, 3, 32, 32).

        Returns:
            A ``(reconstruction, latent)`` tuple: the reconstruction has the
            same shape as ``x`` and the latent has shape (N, 512).
        """
        latent = self.encoder(x)
        return self.decoder(latent), latent

autoencoder = Autoencoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)

# 3. Train Autoencoder
# Ten passes over the training set. Labels are discarded: the target of the
# reconstruction loss is the input batch itself.
for epoch in range(10):
    for images, _ in train_loader:
        optimizer.zero_grad()
        loss = criterion(autoencoder(images)[0], images)
        loss.backward()
        optimizer.step()

# Save reconstructed test images
# Collected one tensor per image, in test-loader order, without tracking
# gradients.
reconstructed_images = []
with torch.no_grad():
    for images, _ in test_loader:
        reconstructed, _ = autoencoder(images)
        reconstructed_images += list(reconstructed)

# 4. Define and Train NN Classifier
class Classifier(nn.Module):
    """Fully connected network mapping 512 input features to 10 output scores."""

    def __init__(self):
        super().__init__()
        # Widths shrink 512 -> 256 -> 128 before the final 10-way layer; no
        # activation follows the last Linear, so the outputs are raw scores.
        layers = [
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 10 raw output scores for each row of ``x``."""
        scores = self.fc(x)
        return scores

classifier = Classifier()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=0.001)

# Freeze autoencoder's encoder and use it as a feature extractor
autoencoder.encoder.eval()
for param in autoencoder.encoder.parameters():
    param.requires_grad = False

# Train classifier
for epoch in range(10):
    for images, labels in train_loader:
        # Only the latent features are needed, so call the encoder directly
        # instead of also running the decoder on every batch.
        with torch.no_grad():
            features = autoencoder.encoder(images)
        optimizer.zero_grad()
        outputs = classifier(features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Evaluate classifier accuracy
correct, total = 0, 0
predictions = {}
with torch.no_grad():
    for images, labels in test_loader:
        features = autoencoder.encoder(images)
        outputs = classifier(features)
        _, predicted = torch.max(outputs, 1)
        # Index of this batch's first image within the whole test set. Keying
        # by the per-batch index would make every batch overwrite the previous
        # one and leave only one batch worth of entries in the saved file.
        batch_start = total
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        for idx, label in enumerate(predicted.tolist(), start=batch_start):
            predictions[f"image_{idx:03}"] = label

accuracy = 100 * correct / total
print(f'Classification Accuracy: {accuracy:.2f}%')

# Save predictions
# One entry per test image, keyed by its position in the (unshuffled) test set.
with open("classification_predictions.json", "w") as f:
    json.dump(predictions, f)

# 5. GMM Clustering
# A fixed random_state makes the fitted mixture, and therefore the saved
# cluster assignments, reproducible from run to run.
gmm = GaussianMixture(n_components=10, random_state=0)
features_list, true_labels = [], []

with torch.no_grad():
    for images, labels in test_loader:
        # Only the latent features are clustered, so skip the decoder.
        features = autoencoder.encoder(images)
        features_list.append(features)
        true_labels.extend(labels.numpy())

# Stack the per-batch features into a single (num_images, 512) array.
features_list = torch.cat(features_list).cpu().numpy()
gmm.fit(features_list)
clusters = gmm.predict(features_list)

# Evaluate GMM Clustering
# ARI and NMI compare the partitions only, so cluster ids do not need to be
# matched to class ids.
ari = adjusted_rand_score(true_labels, clusters)
nmi = normalized_mutual_info_score(true_labels, clusters)
print(f'Adjusted Rand Index (ARI): {ari:.4f}')
print(f'Normalized Mutual Information (NMI): {nmi:.4f}')

# Save clusters
# Keys follow test-set order because the test loader does not shuffle.
clusters_dict = {f"image_{i:03}": int(cluster) for i, cluster in enumerate(clusters)}
with open("gmm_clusters.json", "w") as f:
    json.dump(clusters_dict, f)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:02<00:00, 67.8MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Classification Accuracy: 16.52%
Adjusted Rand Index (ARI): 0.0304
Normalized Mutual Information (NMI): 0.0593
