In [1]:
import torch
import torch.nn as nn
import torch.optim as optim # for Adam optimizer

from torchvision import datasets, transforms # datasets to use MNIST dataset, transforms for converting to tensors etc
from torch.utils.data import DataLoader # to batch, shuffle, etc data

In [2]:
# Shared settings for the MNIST pipeline.
MNIST_ROOT = "/data/MNISTcae"  # directory passed to datasets.MNIST as `root`
BATCH_SIZE = 128  # samples per batch for both loaders

# Preprocessing applied to every image, in order.
stackedTransform = transforms.Compose(
    [
        # image -> tensor; pixel values rescaled from 0-255 to 0-1
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps values from [0, 1] to [-1, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Both splits use the same root, transform and download flag.
mnist_kwargs = dict(root=MNIST_ROOT, transform=stackedTransform, download=True)
train_data = datasets.MNIST(train=True, **mnist_kwargs)
test_data = datasets.MNIST(train=False, **mnist_kwargs)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

100%|██████████| 9.91M/9.91M [00:07<00:00, 1.38MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.48MB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.45MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 2.28MB/s]


In [None]:
class CAutoencoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    The encoder stacks three (Conv2d -> ReLU -> MaxPool2d) stages. The shape
    comments below assume inputs of shape (N, 1, 28, 28), which the encoder
    compresses to a latent tensor of shape (N, 64, 4, 4), i.e. 1024 features
    per sample.

    NOTE: the decoder is still an empty ``nn.Sequential``. An empty
    ``nn.Sequential`` returns its input unchanged, so ``forward`` currently
    returns the latent tensor rather than a reconstruction.
    TODO: add the upsampling layers that map (N, 64, 4, 4) back to the
    input shape.
    """

    def __init__(self):
        # Zero-argument super() is required here: super(self) passes an
        # instance where a type is expected and raises TypeError.
        super().__init__()

        self.encoder = nn.Sequential(
            # Each stage follows the (conv, relu, pool) pattern.
            # Shapes are written as (N, C, H, W).
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),  # (N, 1, 28, 28) -> (N, 16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # (N, 16, 28, 28) -> (N, 16, 14, 14)

            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),  # (N, 16, 14, 14) -> (N, 32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # (N, 32, 14, 14) -> (N, 32, 7, 7)

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),  # (N, 32, 7, 7) -> (N, 64, 7, 7)
            nn.ReLU(),
            # ceil_mode=True rounds the odd 7x7 map up to 4x4 instead of
            # down to 3x3, keeping 4*4*64 = 1024 latent features rather
            # than 3*3*64 = 576.
            nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True),  # (N, 64, 7, 7) -> (N, 64, 4, 4)
        )

        # Placeholder: no decoder layers are defined yet (see class docstring).
        self.decoder = nn.Sequential(

        )

    def forward(self, x):
        """Run ``x`` through the encoder and then the decoder.

        Args:
            x: input batch; the first convolution expects 1 channel.

        Returns:
            The decoder output. While the decoder is empty this is the
            encoder's latent tensor.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)

        return decoded