In [44]:
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import torch.nn as nn
from torch.optim import SGD
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

train_dataset = MNIST(root="./data", train=True, transform=transforms.ToTensor(), download=True)
test_dataset = MNIST(root="./data", train=False, transform=transforms.ToTensor(), download=True)

image, label = train_dataset[0]
image.shape

torch.Size([1, 28, 28])

In [36]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

In [None]:
class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()

        self.encoder =  nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=(3,3), stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(in_features=4*26*26, out_features=16)
        )

        self.decoder = nn.Sequential(
            nn.Linear(16, 4*26*26),
            nn.Unflatten(1, (4, 26, 26)),
            nn.ConvTranspose2d(4, 1, kernel_size=(3,3), stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [14]:
model = AutoEncoder()

In [28]:
criterion = nn.MSELoss()
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

AutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Flatten(start_dim=1, end_dim=-1)
    (3): Linear(in_features=2704, out_features=16, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=16, out_features=2704, bias=True)
    (1): Unflatten(dim=1, unflattened_size=(4, 26, 26))
    (2): ConvTranspose2d(4, 1, kernel_size=(3, 3), stride=(1, 1))
    (3): Sigmoid()
  )
)

In [34]:
def train(model, data_loader, epoch):
    model.train()
    total_loss = []
    for i in tqdm(range(epoch)):
        train_loss = 0
        for image, _ in data_loader:
            image = image.to(device)
            pred = model(image)
            optimizer.zero_grad()
            loss = criterion(image, pred)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        total_loss.append(train_loss)
        avg_loss = train_loss / len(data_loader)
        print(f"Epoch:{i}: Avg_Loss:{avg_loss}")

In [35]:
train(model, train_loader, 10)

 10%|█         | 1/10 [00:08<01:16,  8.55s/it]

Epoch:0: Avg_Loss:0.09454656923413277


 20%|██        | 2/10 [00:16<01:06,  8.31s/it]

Epoch:1: Avg_Loss:0.09177553934057553


 30%|███       | 3/10 [00:24<00:57,  8.25s/it]

Epoch:2: Avg_Loss:0.08355558652877808


 40%|████      | 4/10 [00:32<00:47,  8.00s/it]

Epoch:3: Avg_Loss:0.07069299393693607


 50%|█████     | 5/10 [00:40<00:39,  7.92s/it]

Epoch:4: Avg_Loss:0.06138290915191173


 60%|██████    | 6/10 [00:47<00:30,  7.67s/it]

Epoch:5: Avg_Loss:0.05533997349739075


 70%|███████   | 7/10 [00:54<00:22,  7.56s/it]

Epoch:6: Avg_Loss:0.05123660056690375


 80%|████████  | 8/10 [01:03<00:15,  7.86s/it]

Epoch:7: Avg_Loss:0.04822347264587879


 90%|█████████ | 9/10 [01:11<00:08,  8.04s/it]

Epoch:8: Avg_Loss:0.04583048402766387


100%|██████████| 10/10 [01:20<00:00,  8.03s/it]

Epoch:9: Avg_Loss:0.04382894632021586





In [None]:
image, label = test_dataset[0]
image.shape
pred = model(image.to(device).unsqueeze(0)).cpu().detach().numpy()
pred.shape
# image_transposed = np.transpose(pred, (1, 2, 0))
# plt.imshow(image_transposed)
# plt.show()

ValueError: axes don't match array