In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from tqdm import tqdm

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=128, num_workers=8, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128, num_workers=8, shuffle=False)

class ContractiveAutoencoder(nn.Module):
    def __init__(self):
        super(ContractiveAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 16)
        )
        self.decoder = nn.Sequential(
            nn.Linear(16, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 28*28),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        decoded = decoded.view(x.size(0), 1, 28, 28)  # Reshape to image
        return decoded

    def encode(self, x):
        x = x.view(x.size(0), -1)
        return self.encoder(x)

def jacobian_penalty(model, inputs):
    inputs.requires_grad = True
    outputs = model.encode(inputs)
    jacobian = torch.autograd.functional.jacobian(model.encode, inputs, create_graph=True)
    penalty = torch.sum(jacobian**2)
    return penalty


In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ContractiveAutoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
lambda_reg = 1e-3  # Regularization strength

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for images, _ in tqdm(train_loader):
        images = images.to(device)
        outputs = model(images)
        loss = criterion(outputs, images)
        penalty = jacobian_penalty(model, images)
        total_loss = loss + lambda_reg * penalty

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss.item():.4f}')


100%|██████████| 469/469 [05:03<00:00,  1.54it/s]


Epoch [1/10], Loss: 0.9233


 17%|█▋        | 81/469 [00:57<04:34,  1.42it/s]


KeyboardInterrupt: 

In [None]:
model.eval()
with torch.no_grad():
    train_latents = []
    train_labels = []
    for images, labels in train_loader:
        images = images.to(device)
        latents = model.encode(images).cpu().numpy()
        train_latents.append(latents)
        train_labels.append(labels.cpu().numpy())
    train_latents = np.concatenate(train_latents)
    train_labels = np.concatenate(train_labels)

    test_latents = []
    test_labels = []
    for images, labels in test_loader:
        images = images.to(device)
        latents = model.encode(images).cpu().numpy()
        test_latents.append(latents)
        test_labels.append(labels.cpu().numpy())
    test_latents = np.concatenate(test_latents)
    test_labels = np.concatenate(test_labels)


In [None]:
log_reg = LogisticRegression(max_iter=1000)
log_reg.fit(train_latents, train_labels)
preds = log_reg.predict(test_latents)
accuracy = accuracy_score(test_labels, preds)
print(f'Classification accuracy using latent space: {accuracy:.4f}')


In [None]:
# Using t-SNE for dimensionality reduction
tsne = TSNE(n_components=2, random_state=42)
latent_2d_tsne = tsne.fit_transform(test_latents)

plt.figure(figsize=(10, 8))
scatter = plt.scatter(latent_2d_tsne[:, 0], latent_2d_tsne[:, 1], c=test_labels, cmap='tab10', alpha=0.5)
plt.colorbar(scatter)
plt.xlabel('t-SNE Component 1')
plt.ylabel('t-SNE Component 2')
plt.title('2D t-SNE Projection of Latent Space')
plt.show()
