In [11]:
import torch
import torch.nn as nn

class Autoencoder(nn.Module):
    """Fully connected autoencoder with an auxiliary classification head.

    The encoder flattens its input (``nn.Flatten`` keeps the batch dim) and
    maps it to a 32-dimensional latent code. The decoder maps that code back
    to a flat vector of length ``input_dim``, and a linear head maps the same
    code to ``output_dim`` class logits.

    Args:
        input_dim: Number of features after flattening one input sample
            (e.g. ``1*28*28`` for single-channel 28x28 images).
        output_dim: Number of classes produced by the classification head.
        activation: Zero-argument callable returning the nonlinearity module
            placed between hidden linear layers. Defaults to ``nn.ReLU``;
            pass ``nn.Identity`` to get a purely linear model.
    """

    def __init__(self, input_dim, output_dim, activation=nn.ReLU):
        super().__init__()
        # Stacked nn.Linear layers without a nonlinearity in between collapse
        # into a single affine map, so each hidden layer is followed by one.
        # The latent code and the reconstruction stay linear (no activation).
        self.encode = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, 256),
            activation(),
            nn.Linear(256, 64),
            activation(),
            nn.Linear(64, 32),
        )
        self.decode = nn.Sequential(
            nn.Linear(32, 64),
            activation(),
            nn.Linear(64, 256),
            activation(),
            nn.Linear(256, input_dim),
        )
        # Head width follows the requested number of classes.
        self.classify = nn.Linear(32, output_dim)

    def forward(self, x):
        """Encode ``x`` once and return ``(reconstruction, class_logits)``.

        Args:
            x: Batch tensor. Every dimension after the first is flattened, so
               their product must equal ``input_dim``.

        Returns:
            A tuple of the flat reconstruction, shape ``(batch, input_dim)``,
            and the unnormalised class scores, shape ``(batch, output_dim)``.
        """
        latent = self.encode(x)
        reconstruction = self.decode(latent)
        logits = self.classify(latent)
        return reconstruction, logits

In [12]:

# Initialize Model, Loss, and Optimizer
# Named constants so the tunable values are visible in one place.
SEED = 42
INPUT_DIM = 1 * 28 * 28   # one flattened 1x28x28 image
NUM_CLASSES = 10
LEARNING_RATE = 1e-3

# Seed torch's global RNG before building the model so weight initialisation
# is reproducible across Restart & Run All.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
autoencoder = Autoencoder(INPUT_DIM, NUM_CLASSES).to(device)
criterion_recon = nn.MSELoss()           # reconstruction term
criterion_class = nn.CrossEntropyLoss()  # classification term (expects raw logits)
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LEARNING_RATE)

In [None]:
# Training Loop
# Jointly optimise the reconstruction (MSE) and classification (cross-entropy)
# losses. The printed value is the sample-weighted mean total loss per epoch.
num_epochs = 10
autoencoder.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_samples = 0
    for images, labels in train_loader:
        # The decoder returns a flat (batch, input_dim) tensor, so the
        # reconstruction target must be flattened to the same shape. Image
        # batches such as (batch, 1, 28, 28) would otherwise fail to broadcast
        # in MSELoss. reshape (unlike view) also accepts non-contiguous input.
        images = images.reshape(images.size(0), -1).to(device)
        labels = labels.to(device)

        outputs, y_pred = autoencoder(images)  # Forward pass
        loss_recon = criterion_recon(outputs, images)  # Reconstruction loss
        loss_class = criterion_class(y_pred, labels)  # Classification loss
        loss = loss_recon + loss_class  # Total loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by batch size so a short final batch doesn't skew the mean.
        batch_size = images.size(0)
        epoch_loss += loss.item() * batch_size
        num_samples += batch_size

    mean_loss = epoch_loss / max(num_samples, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}")

print("Training complete.")

In [None]:
# Testing Loop - Compute Accuracy
# Accuracy of the classification head on the held-out set. The decoder is not
# needed here, so only the encoder and the classifier are run.
autoencoder.eval()  # disable train-only behaviour (e.g. dropout) if any is added later
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        # Flatten to (batch, features); reshape also handles non-contiguous input.
        images = images.reshape(images.size(0), -1).to(device)
        labels = labels.to(device)

        encoded = autoencoder.encode(images)
        y_pred_test = autoencoder.classify(encoded)
        predicted = torch.argmax(y_pred_test, dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Fail with a clear message instead of a bare ZeroDivisionError.
if total == 0:
    raise ValueError("test_loader yielded no samples; cannot compute accuracy")
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")