In [11]:
#pytorch libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm import tqdm

#other libraries
import numpy as np
import matplotlib.pyplot as plt

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
])

batch_size = 4

root_dir = '/mnt/storage/Datasets/CIFAR10' #make sure to change it to your own path

train_set = datasets.CIFAR10(root=root_dir, train=True, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)

test_set = datasets.CIFAR10(root=root_dir, train=False, download=True, transform=transform)

test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
#test the data loader
data_iter = iter(train_loader)
images, labels = next(data_iter)
print(images.shape)

torch.Size([4, 3, 32, 32])


In [9]:
class AutoEncoder(nn.Module):
    def __init__(self, input_shape, hidden_dim=128, latent_dim=64):
        super().__init__()
        self.input_shape = input_shape
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.input_shape**2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.input_shape**2),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        x = x.view(-1, 3, 3, self.input_shape)
        return x

In [12]:
auto_encoder = AutoEncoder(32)

loss_fn = nn.MSELoss()

optimizer = Adam(auto_encoder.parameters(), lr=1e-1)

In [None]:

#training process

epochs = 100
outputs = []

loss_list = []
 

for epoch in epochs:
    for images, _ in tqdm(train_loader):
        
        recon = auto_encoder(images)
        