In [None]:
import torch
import torch.nn as nn

class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder with an auxiliary classification head.

    The encoder maps a single-channel image to a ``latent_dim`` vector. That
    shared latent vector feeds both a linear classifier and a decoder that
    reconstructs the image. The hard-coded ``32 * 7 * 7`` bottleneck assumes
    28x28 inputs (two stride-2 convolutions: 28 -> 14 -> 7).

    Args:
        latent_dim: Size of the bottleneck vector (default 16).
        num_classes: Number of logits produced by the classifier head (default 10).
    """

    def __init__(self, latent_dim: int = 16, num_classes: int = 10):
        super().__init__()
        self.encode = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),   # (N, 1, 28, 28) -> (N, 16, 14, 14)
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # -> (N, 32, 7, 7)
            nn.ReLU(),
            nn.Flatten(),                               # -> (N, 32*7*7)
            nn.Linear(32 * 7 * 7, latent_dim),
            nn.ReLU(),
        )

        # Projects the latent vector back up to the conv feature-map size.
        self.reshape = nn.Linear(latent_dim, 32 * 7 * 7)

        # Classification head on the shared latent representation; outputs raw logits.
        self.classify = nn.Linear(latent_dim, num_classes)

        self.decode = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),  # -> (N, 16, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),   # -> (N, 1, 28, 28)
            # Sigmoid bounds reconstructed pixels to (0, 1); a ReLU after it
            # would be a no-op, so the decoder ends here.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x``, then return ``(reconstruction, logits)``.

        Args:
            x: Image batch; the layer sizes require shape (N, 1, 28, 28).

        Returns:
            reconstruction: (N, 1, 28, 28) tensor with values in (0, 1).
            logits: (N, num_classes) unnormalized class scores.
        """
        latent = self.encode(x)
        logits = self.classify(latent)
        features = self.reshape(latent)
        features = features.view(features.size(0), 32, 7, 7)
        reconstruction = self.decode(features)
        return reconstruction, logits
    
    
# Smoke test: push a random 28x28 batch through a fresh model and check output shapes.
model = ConvAutoencoder()
with torch.no_grad():  # shape check only; no autograd graph needed
    decode, classify = model(torch.randn(size=(32, 1, 28, 28)))
assert decode.shape == (32, 1, 28, 28), decode.shape
assert classify.shape == (32, 10), classify.shape
decode.shape, classify.shape

(torch.Size([32, 1, 28, 28]), torch.Size([32, 10]))

In [None]:
# Initialize model, losses, and optimizer
SEED = 0
LEARNING_RATE = 1e-3

torch.manual_seed(SEED)  # make weight initialization reproducible across runs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
autoencoder = ConvAutoencoder().to(device)

# Reconstruction loss: decoder output vs. the input images.
criterion = nn.MSELoss()
# Classification loss on the model's raw logits; targets are presumably
# integer class indices from train_loader — confirm.
classification = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LEARNING_RATE)


In [None]:

# Training Loop
# NOTE(review): `train_loader` is not defined in any visible cell; it must exist
# in the kernel before this cell runs. Presumably it yields (images, labels)
# batches of (N, 1, 28, 28) images — confirm.
# NOTE(review): the decoder output is sigmoid-bounded to (0, 1), so the MSE
# target `images` should also be in [0, 1] (e.g. un-normalized ToTensor) — confirm.
num_epochs = 10
autoencoder.train()
for epoch in range(num_epochs):
    running_loss = running_recon = running_class = 0.0
    n_seen = 0
    for images, labels in train_loader:
        images = images.to(device)
        # Labels must be on the same device as the logits; otherwise
        # CrossEntropyLoss raises a device-mismatch error on CUDA.
        labels = labels.to(device)

        outputs, classify = autoencoder(images)
        loss_recon = criterion(outputs, images)        # reconstruction loss
        loss_class = classification(classify, labels)  # classification loss
        loss = loss_recon + loss_class

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch by its size so the epoch average stays exact
        # when the final batch is smaller.
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        running_recon += loss_recon.item() * batch_size
        running_class += loss_class.item() * batch_size
        n_seen += batch_size

    if n_seen == 0:
        raise RuntimeError("train_loader yielded no batches")
    print(
        f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / n_seen:.4f} "
        f"(recon {running_recon / n_seen:.4f}, class {running_class / n_seen:.4f})"
    )

print("Training complete.")