In [31]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split


# Load Data

In [32]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)


train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # Batch size 64 (adjust as needed)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False) 

# AE

In [33]:
class ConvolutionalAutoencoder(nn.Module):
    def __init__(self, latent_dim):
        super(ConvolutionalAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # Example: 1 input channel (grayscale), 16 output channels
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim , 32 * 7 * 7),
            nn.Unflatten(1, (32, 7, 7)),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1), # Transpose convolution for upsampling
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()  # Output pixel values between 0 and 1
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded,decoded

In [34]:
ae_model = ConvolutionalAutoencoder(latent_dim=64)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(ae_model.parameters(), lr=0.001)

epoch = 1
for i in range(epoch):
    ite = 0
    for x,y in train_loader:
        
        optimizer.zero_grad()
        
        _, decoded = ae_model(x)
        loss = criterion(decoded, x)
        
        loss.backward()
        optimizer.step()
        
        ite += 1
        
        if ite % 10 == 0:
            print("iteration: ", ite , "Loss: ", float(loss))

iteration:  10 Loss:  1.0243024826049805
iteration:  20 Loss:  0.8595171570777893
iteration:  30 Loss:  0.802959144115448
iteration:  40 Loss:  0.7886083722114563
iteration:  50 Loss:  0.8005200624465942
iteration:  60 Loss:  0.7741352915763855
iteration:  70 Loss:  0.7678083181381226
iteration:  80 Loss:  0.71996009349823
iteration:  90 Loss:  0.7678654193878174
iteration:  100 Loss:  0.7205885052680969
iteration:  110 Loss:  0.7502965927124023
iteration:  120 Loss:  0.6717173457145691
iteration:  130 Loss:  0.6911721229553223
iteration:  140 Loss:  0.6552340984344482
iteration:  150 Loss:  0.6661316156387329
iteration:  160 Loss:  0.6115263104438782
iteration:  170 Loss:  0.6020353436470032
iteration:  180 Loss:  0.5943698287010193
iteration:  190 Loss:  0.5878819227218628
iteration:  200 Loss:  0.5584214329719543
iteration:  210 Loss:  0.5546462535858154
iteration:  220 Loss:  0.5818150043487549
iteration:  230 Loss:  0.5553558468818665
iteration:  240 Loss:  0.5387664437294006
iter

In [35]:
ae_model.eval

with torch.no_grad():
    out = []
    for x,y in test_loader:
        
        encoded, _ = ae_model(x)
        out.append(encoded)
        
out = torch.stack(out)
print(out.shape)
out

torch.Size([120, 100, 64])


tensor([[[  2.1064,   9.5924,  29.5344,  ...,  -1.6584, -11.8681,   0.8876],
         [ 10.4090,   4.9620,   5.0505,  ..., -29.4574, -11.1802,  16.3597],
         [ 17.4587,  -0.5445,  13.4837,  ..., -12.1434, -20.9418,  13.2182],
         ...,
         [  9.9296,  -2.0457,   2.3567,  ..., -10.0997,  -2.9495,   3.0960],
         [ -2.8685,   6.9513,  16.9549,  ..., -11.4245,   1.6067,  16.2619],
         [ -4.4723,  -2.3318,  21.9335,  ...,   6.8050,  -5.7151,   1.4429]],

        [[ 12.0254,  13.1583, -33.6256,  ..., -18.7318,   2.7727,  21.2906],
         [ -3.4671,  11.5326,   4.2385,  ..., -29.5447, -12.9275,  14.6879],
         [  7.2121, -12.4372,  13.8982,  ..., -21.2287,  -0.7825,  10.6011],
         ...,
         [-10.7555,   9.0414,   3.0773,  ..., -22.2989,   4.6603,   1.2493],
         [ -3.7232,   5.4407,   8.8208,  ...,   3.7009, -12.4770,   1.8025],
         [  9.3688,  21.1523,  13.4924,  ...,  -7.0173,   4.3205,  10.1362]],

        [[ -0.2827,  11.2343, -14.8151,  ...