In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms as T

In [6]:
class Encoder(nn.Module):
    """Convolutional encoder built from three conv -> batch-norm -> ReLU stages.

    Channel widths go 3 -> 32 -> 16 -> 3. Every convolution uses a 2x2 kernel;
    the first two have stride 1 and the last has stride 2.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 channels, stride 1.
        self.encoder_conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=2, stride=1)
        self.encoder_bn1 = nn.BatchNorm2d(num_features=32)
        # Stage 2: 32 -> 16 channels, stride 1.
        self.encoder_conv2 = nn.Conv2d(in_channels=32, out_channels=16, kernel_size=2, stride=1)
        self.encoder_bn2 = nn.BatchNorm2d(num_features=16)
        # Stage 3: 16 -> 3 channels, stride 2.
        self.encoder_conv3 = nn.Conv2d(in_channels=16, out_channels=3, kernel_size=2, stride=2)
        self.encoder_bn3 = nn.BatchNorm2d(num_features=3)

    def forward(self, x):
        """Run ``x`` through the three stages in order and return the result."""
        stages = (
            (self.encoder_conv1, self.encoder_bn1),
            (self.encoder_conv2, self.encoder_bn2),
            (self.encoder_conv3, self.encoder_bn3),
        )
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        return x


class Decoder(nn.Module):
    """Transposed-convolution decoder mirroring the encoder's channel layout.

    Channel widths go 3 -> 16 -> 32 -> 3. Every transposed convolution uses a
    2x2 kernel; the first has stride 2 and the remaining two have stride 1.
    The output passes through a sigmoid, so values lie in the range 0..1.
    """

    def __init__(self) -> None:
        super().__init__()
        self.decoder_deconv1 = nn.ConvTranspose2d(3, 16, 2, 2)
        self.decoder_bn1 = nn.BatchNorm2d(16)
        self.decoder_deconv2 = nn.ConvTranspose2d(16, 32, 2, 1)
        self.decoder_bn2 = nn.BatchNorm2d(32)
        self.decoder_deconv3 = nn.ConvTranspose2d(32, 3, 2, 1)
        # Kept (together with its parameters) so existing checkpoints still load.
        self.decoder_bn3 = nn.BatchNorm2d(3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Decode ``x`` and return a tensor squashed into 0..1 by a sigmoid."""
        x = F.relu(self.decoder_bn1(self.decoder_deconv1(x)))
        x = F.relu(self.decoder_bn2(self.decoder_deconv2(x)))
        # No ReLU on the last stage: sigmoid of a non-negative value is always
        # >= 0.5, which would make the lower half of the output range
        # unreachable.
        x = self.decoder_bn3(self.decoder_deconv3(x))
        x = self.sigmoid(x)
        return x


class AutoEncoder(nn.Module):
    """Encoder/decoder pair with additive noise injected in the latent space.

    The latent code is squashed with a sigmoid, perturbed with Gaussian noise
    scaled by ``1 / 2 ** quantize_level``, mapped back with a logit, and then
    decoded.

    Args:
        quantize_level: Exponent controlling the noise scale; the noise is
            divided by ``2 ** quantize_level``.
    """

    # Margin used to keep the noisy activations strictly inside (0, 1) before
    # the logit, so the result is always finite.
    _LOGIT_EPS = 1e-6

    def __init__(self, quantize_level) -> None:
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.quantize_level = quantize_level

    def forward(self, x):
        """Encode ``x``, add latent noise, and return the decoded tensor."""
        # encode input
        x = self.encoder(x)

        # add noise
        x = torch.sigmoid(x)
        # Per-element noise drawn from N(mean=0.5, std=0.5), created on the
        # same device and with the same dtype as the latent code.
        # NOTE(review): the original call had no size, so per-element noise
        # is assumed here; confirm that was the intent.
        noise = (torch.randn_like(x) * 0.5 + 0.5) / (2 ** self.quantize_level)
        # Out-of-place add: sigmoid's backward pass needs its own output, so
        # modifying it in place would break autograd.
        x = x + noise
        # logit(x) = log(x / (1 - x)); eps clamps x into [eps, 1 - eps] so that
        # values pushed to or past the boundaries by the noise stay finite.
        x = torch.logit(x, eps=self._LOGIT_EPS)

        # decode
        x = self.decoder(x)

        return x

In [20]:
# Training pipeline: random resized crop to 512x512, then conversion to a tensor.
train_transform = T.Compose([
    T.RandomResizedCrop(size=(512, 512)),
    T.ToTensor(),
])

# Evaluation pipeline: centre crop to 512x512, then conversion to a tensor.
test_transform = T.Compose([
    T.CenterCrop(size=(512, 512)),
    T.ToTensor(),
])

# Plain tensor conversion without any cropping.
transform = T.ToTensor()

In [12]:
# Use separate ImageNet splits so evaluation never sees training images.
# Without an explicit split both datasets would default to the training split.
# NOTE(review): the 'val' split needs the validation data to be present under
# the same root — confirm it is available in ../data.
train_dataset = torchvision.datasets.ImageNet(
    root='../data', split='train', transform=train_transform
)
test_dataset = torchvision.datasets.ImageNet(
    root='../data', split='val', transform=test_transform
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:21<00:00, 7794773.11it/s]


Extracting ../data/cifar-10-python.tar.gz to ../data
Files already downloaded and verified


In [13]:
# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True
)
# Evaluation batches come from the test dataset (not the training one) and
# keep a fixed order.
test_dataloader = DataLoader(
    test_dataset, batch_size=32, shuffle=False, num_workers=4, pin_memory=True
)

In [None]:
# NOTE(review): CosineAnnealingLR takes an optimizer and T_max as required
# arguments, and no optimizer is defined in the visible cells, so this call
# cannot succeed as written — TODO pass both (and keep the returned scheduler)
# or remove this unexecuted cell.
torch.optim.lr_scheduler.CosineAnnealingLR()