In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.models as models
import torchvision.utils as vutils

import os
import random
import numpy as np
import math
from IPython.display import clear_output
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import trange, tqdm


In [4]:
# ---- Notebook configuration: tune these values in one place ----

# Where the dataset is stored / downloaded to
root = "../datasets"

# Number of samples per mini-batch
batch_size = 64

# Number of training epochs
nepoch = 10

# Learning rate
lr = 1e-4

# Scale for the added image noise
noise_scale = 0.3


In [5]:
# Choose the compute backend in priority order: CUDA, then Apple MPS, then CPU.
# The conditional chain short-circuits, so MPS is only queried when CUDA is absent.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

print("Using device:", device)


Using device: mps


In [6]:
# Preprocessing applied to every image:
#   1. upsample to 32x32, as it's easier to construct our network for that size
#   2. convert to a tensor
#   3. normalise with mean 0.5 and std 0.5
transform = transforms.Compose(
    [
        transforms.Resize(size=32),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Training split: reshuffled by the loader
train_set = Datasets.MNIST(root, train=True, transform=transform, download=True)
train_loader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Test split: kept in a fixed order
test_set = Datasets.MNIST(root, train=False, transform=transform, download=True)
test_loader = DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)


100%|██████████| 9.91M/9.91M [00:19<00:00, 508kB/s] 
100%|██████████| 28.9k/28.9k [00:00<00:00, 112kB/s]
100%|██████████| 1.65M/1.65M [00:09<00:00, 178kB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 1.11MB/s]


In [None]:
# We split up our network into two parts, the Encoder and the Decoder
class Encoder(nn.Module):
    """Convolutional encoder that compresses an image into a latent feature map.

    Three 3x3 / stride-2 / padding-1 convolutions (each followed by batch-norm
    and ReLU) widen the channels ch -> 2*ch -> 4*ch, then an unpadded 4x4
    convolution projects to ``z`` channels. For a 32x32 input the output has
    shape ``(N, z, 1, 1)``.

    Args:
        channels: number of channels in the input image.
        ch: channel width of the first convolution; later stages use multiples of it.
        z: number of channels in the returned latent map.
    """

    def __init__(self, channels, ch=32, z=32):
        super().__init__()

        # Stage 1: channels -> ch
        self.conv1 = nn.Conv2d(
            in_channels=channels, out_channels=ch, kernel_size=3, stride=2, padding=1
        )
        self.bn1 = nn.BatchNorm2d(num_features=ch)

        # Stage 2: ch -> 2*ch
        self.conv2 = nn.Conv2d(
            in_channels=ch, out_channels=2 * ch, kernel_size=3, stride=2, padding=1
        )
        self.bn2 = nn.BatchNorm2d(num_features=2 * ch)

        # Stage 3: 2*ch -> 4*ch
        self.conv3 = nn.Conv2d(
            in_channels=2 * ch, out_channels=4 * ch, kernel_size=3, stride=2, padding=1
        )
        self.bn3 = nn.BatchNorm2d(num_features=4 * ch)

        # Latent head: unpadded 4x4 kernel, no normalisation or activation after it
        self.conv_out = nn.Conv2d(
            in_channels=4 * ch, out_channels=z, kernel_size=4, stride=1
        )

    def forward(self, x):
        """Return the latent map for the image batch ``x``."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))

        return self.conv_out(x)
    
class Decoder(nn.Module):
    """Transposed-convolution decoder that expands a latent map back into an image.

    An unpadded 4x4 transposed convolution lifts the latent to 4*ch channels,
    then three 3x3 / stride-2 transposed convolutions (padding 1, output
    padding 1) narrow the channels 4*ch -> 2*ch -> ch -> ch. Every transposed
    convolution is followed by batch-norm and ReLU. A final 3x3 convolution
    maps to ``channels`` and the result is passed through tanh, so outputs lie
    in [-1, 1]. For a ``(N, z, 1, 1)`` latent the output is ``(N, channels, 32, 32)``.

    Args:
        channels: number of channels in the reconstructed image.
        ch: base channel width; earlier stages use multiples of it.
        z: number of channels in the incoming latent map.
    """

    def __init__(self, channels, ch=32, z=32):
        super().__init__()

        # Latent -> 4*ch feature map
        self.conv1 = nn.ConvTranspose2d(
            in_channels=z, out_channels=4 * ch, kernel_size=4, stride=1
        )
        self.bn1 = nn.BatchNorm2d(num_features=4 * ch)

        # Three stride-2 upsampling stages
        self.conv2 = nn.ConvTranspose2d(
            in_channels=4 * ch, out_channels=2 * ch,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        self.bn2 = nn.BatchNorm2d(num_features=2 * ch)

        self.conv3 = nn.ConvTranspose2d(
            in_channels=2 * ch, out_channels=ch,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        self.bn3 = nn.BatchNorm2d(num_features=ch)

        self.conv4 = nn.ConvTranspose2d(
            in_channels=ch, out_channels=ch,
            kernel_size=3, stride=2, padding=1, output_padding=1,
        )
        self.bn4 = nn.BatchNorm2d(num_features=ch)

        # Size-preserving projection back to image channels
        self.conv_out = nn.Conv2d(
            in_channels=ch, out_channels=channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        """Return the tanh-squashed reconstruction for the latent batch ``x``."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for deconv, norm in stages:
            x = F.relu(norm(deconv(x)))

        return torch.tanh(self.conv_out(x))
    
class AE(nn.Module):
    """Autoencoder: an ``Encoder`` followed by a ``Decoder`` built with matching sizes.

    Args:
        channel_in: number of image channels, used for both the encoder input
            and the decoder output.
        ch: base channel width forwarded to both halves.
        z: number of latent channels forwarded to both halves.
    """

    def __init__(self, channel_in, ch=16, z=32):
        super().__init__()
        # Both halves take the same keyword arguments so their shapes line up
        shared = dict(channels=channel_in, ch=ch, z=z)
        self.encoder = Encoder(**shared)
        self.decoder = Decoder(**shared)

    def forward(self, x):
        """Return ``(reconstruction, encoding)`` for the input batch ``x``."""
        latent = self.encoder(x)
        return self.decoder(latent), latent
