In [1]:
import os
import sys

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Make the parent of the notebook's working directory importable so the
# project-local modules below resolve. Guarded so that re-running this cell
# does not keep prepending duplicate entries to sys.path.
PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Project-local modules (resolved via PROJECT_ROOT above).
from utils import load_configs
from train import train
from models.autoencoder import AutoEncoder
from preprocessing import GaussianFilter

# Project configuration; cell 2 reads configs["dataset_path"] from it.
configs = load_configs("../configs.json")

In [2]:
# Seed torch before any randomness (weight init, DataLoader shuffling) so that
# Restart & Run All reproduces the same training run.
seed = 0
torch.manual_seed(seed)

# Training hyper-parameters (all passed straight through to train()).
epochs = 30
batch_size = 1000
n_plots = 64  # NOTE(review): presumably the number of images train() plots — confirm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to its device *before* building the optimizer, as recommended
# by the torch.optim docs, so the optimizer is constructed over the final parameters.
model = AutoEncoder()
model.to(device)

optimizer = Adam(model.parameters())
# Binary cross-entropy on logits, summed over all elements (not averaged).
# NOTE(review): assumes AutoEncoder returns raw logits (no final sigmoid) — confirm.
loss_fn = nn.BCEWithLogitsLoss(reduction="sum")

root = configs["dataset_path"]

# Download only when the MNIST folder is not already present under `root`
# (the previous expression was inverted and failed on a fresh machine).
download = not os.path.exists(os.path.join(root, "MNIST"))

# One stateless transform shared by both splits: PIL image -> float tensor in [0, 1],
# which also satisfies BCEWithLogitsLoss's requirement that targets lie in [0, 1].
to_tensor = transforms.ToTensor()
trainset = datasets.MNIST(root=root, download=download, train=True, transform=to_tensor)
testset = datasets.MNIST(root=root, download=download, train=False, transform=to_tensor)

trainloader = DataLoader(trainset, shuffle=True, batch_size=batch_size)
testloader = DataLoader(testset, shuffle=False, batch_size=batch_size)

# NOTE(review): the meaning of GaussianFilter's positional arguments is not visible
# here; consider passing them by keyword (and lifting them to named constants).
preprocessing = GaussianFilter(1, 7, 1, 3)

train(
    model,
    trainloader,
    testloader,
    epochs,
    optimizer,
    loss_fn,
    device,
    n_plots,
    preprocessing=preprocessing,
    save_dir=".."
)

1/30 epochs: 100%|██████████| 60/60 [00:21<00:00,  2.77it/s, train_reconstruction_loss=279, train_mse=0.0996, train_psnr=10.8, train_ssim=0.339, test_reconstruction_loss=190, test_mse=0.0604, test_psnr=12.5, test_ssim=0.482]
2/30 epochs: 100%|██████████| 60/60 [00:22<00:00,  2.65it/s, train_reconstruction_loss=183, train_mse=0.0575, train_psnr=12.7, train_ssim=0.534, test_reconstruction_loss=186, test_mse=0.0592, test_psnr=12.4, test_ssim=0.485]
3/30 epochs: 100%|██████████| 60/60 [00:21<00:00,  2.80it/s, train_reconstruction_loss=156, train_mse=0.0456, train_psnr=13.7, train_ssim=0.652, test_reconstruction_loss=151, test_mse=0.044, test_psnr=13.9, test_ssim=0.62]
4/30 epochs: 100%|██████████| 60/60 [00:21<00:00,  2.80it/s, train_reconstruction_loss=135, train_mse=0.0365, train_psnr=14.7, train_ssim=0.736, test_reconstruction_loss=138, test_mse=0.0387, test_psnr=14.5, test_ssim=0.68]
5/30 epochs: 100%|██████████| 60/60 [00:23<00:00,  2.59it/s, train_reconstruction_loss=122, train_mse=0