In [1]:
import os
import sys

sys.path.insert(0, os.path.abspath(".."))

import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import load_configs
from train import train
from models.rednet import REDNet

from preprocessing import GaussianFilter

# Load project-wide settings from the JSON file one directory up; the training
# cell reads configs["dataset_path"] from it. The path is resolved against the
# current working directory, consistent with the ".." inserted into sys.path above.
configs = load_configs("../configs.json")

In [None]:
# Seed torch's RNGs (CPU and CUDA) so weight initialisation and the shuffled
# training order are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to its device *before* building the optimizer, as the
# torch.optim documentation recommends.
model = REDNet()
model.to(device)

optimizer = Adam(model.parameters())
# Summed (not averaged) binary cross-entropy applied to raw model outputs.
# NOTE(review): assumes REDNet returns logits (no final sigmoid) and that the
# targets are the [0, 1] pixel intensities produced by ToTensor — confirm in train.py.
loss_fn = nn.BCEWithLogitsLoss(reduction="sum")

epochs = 30
batch_size = 1000
n_plots = 64  # forwarded to train(); presumably how many images to plot — verify in train.py

root = configs["dataset_path"]

# Ask torchvision to download MNIST only when its folder is not already under
# `root`. The condition must be negated: download when the data is *missing*.
download = not os.path.exists(os.path.join(root, "MNIST"))

to_tensor = transforms.ToTensor()
trainset = datasets.MNIST(root=root, download=download, train=True, transform=to_tensor)
testset = datasets.MNIST(root=root, download=download, train=False, transform=to_tensor)

trainloader = DataLoader(trainset, shuffle=True, batch_size=batch_size)
testloader = DataLoader(testset, shuffle=False, batch_size=batch_size)

# Input corruption applied by train() before the model sees each batch.
# NOTE(review): the meaning of these positional arguments is defined in
# preprocessing.GaussianFilter — consider passing them by keyword.
preprocessing = GaussianFilter(1, 7, 1, 3)

train(
    model,
    trainloader,
    testloader,
    epochs,
    optimizer,
    loss_fn,
    device,
    n_plots,
    preprocessing=preprocessing,
    save_dir=".."
)

1/30 epochs: 100%|██████████| 60/60 [00:23<00:00,  2.57it/s, train_reconstruction_loss=101, train_mse=0.02, train_psnr=20.2, train_ssim=0.886, test_reconstruction_loss=58.8, test_mse=0.00452, test_psnr=23.9, test_ssim=0.968]
2/30 epochs: 100%|██████████| 60/60 [00:21<00:00,  2.78it/s, train_reconstruction_loss=55.8, train_mse=0.00321, train_psnr=25.5, train_ssim=0.98, test_reconstruction_loss=53, test_mse=0.0026, test_psnr=26.3, test_ssim=0.982]
3/30 epochs: 100%|██████████| 60/60 [00:22<00:00,  2.67it/s, train_reconstruction_loss=53.1, train_mse=0.00221, train_psnr=27, train_ssim=0.987, test_reconstruction_loss=51.3, test_mse=0.00199, test_psnr=27.5, test_ssim=0.987]
4/30 epochs: 100%|██████████| 60/60 [00:21<00:00,  2.77it/s, train_reconstruction_loss=52, train_mse=0.00185, train_psnr=27.8, train_ssim=0.989, test_reconstruction_loss=50.5, test_mse=0.00172, test_psnr=28.1, test_ssim=0.989]
5/30 epochs: 100%|██████████| 60/60 [00:21<00:00,  2.76it/s, train_reconstruction_loss=51.4, tra