In [None]:
# import libraries
import os

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from networks import ConvUpscaleDenoiser, ConvTransposeDenoiser

In [None]:
# Configuration: seed torch's global RNG so weight initialisation, DataLoader
# shuffling and the Gaussian noise added below are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)
# NOTE: some cuDNN kernels are still nondeterministic on GPU; this only fixes the RNG streams.

# set available device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Fetch both MNIST splits (downloaded into ./data/ on first use). Every sample
# goes through ToTensor() so the models receive tensors rather than PIL images.
mnist_common = {'root': './data/', 'download': True}
train_set = torchvision.datasets.MNIST(train=True, transform=transforms.ToTensor(), **mnist_common)
test_set = torchvision.datasets.MNIST(train=False, transform=transforms.ToTensor(), **mnist_common)

# Shuffled, large batches for optimisation; a fixed-order, smaller-batch
# loader for inspecting test reconstructions.
loader_workers = 1
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=256,
    shuffle=True,
    num_workers=loader_workers,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=32,
    num_workers=loader_workers,
)

In [None]:
# Build the ConvUpscaleDenoiser imported from networks.py; leaving the bare
# name as the cell's last expression makes Jupyter display its module summary.
conv_upscale_net = ConvUpscaleDenoiser()
conv_upscale_net

In [None]:
# Move the model to the selected device *before* building the optimizer over its parameters.
conv_upscale_net.to(device)
# Pixel-wise mean-squared error between reconstruction and clean image.
criterion = nn.MSELoss()
optimizer = optim.Adam(conv_upscale_net.parameters(), lr=0.01)
# Cut the LR to a third when the monitored loss has not improved for 3 epochs.
# `verbose=True` is deprecated since PyTorch 2.2; read the current LR with
# scheduler.get_last_lr() or optimizer.param_groups instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=1/3, patience=3)

In [None]:
noise_factor = 0.5  # scale of the Gaussian noise added to the input images
num_epochs = 30
# Re-enable training-mode behaviour in case an evaluation cell ran before this one.
conv_upscale_net.train()
for epoch in range(num_epochs):
    train_loss = 0.0  # sum over samples of the per-batch mean loss
    num_seen = 0
    ###################
    # train the model #
    ###################
    loop = tqdm(train_loader, total=len(train_loader), desc=f'Epoch [{epoch+1:2d}/{num_epochs}]')
    for images, _ in loop:
        images = images.to(device)
        # add random noise on the same device, then clip the pixels to be between 0 and 1
        noisy_imgs = torch.clamp(images + noise_factor * torch.randn_like(images), 0., 1.)

        outputs = conv_upscale_net(noisy_imgs)
        loss = criterion(outputs, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # the loss is a batch mean, so weight it by batch size to accumulate a per-sample total
        train_loss += loss.item() * images.size(0)
        num_seen += images.size(0)
        # show the running *mean* loss (not the ever-growing sum) and the current LR
        loop.set_postfix(loss=train_loss / num_seen, lr=optimizer.param_groups[0]['lr'])

    # ReduceLROnPlateau's default relative threshold is scale-invariant, so stepping on the
    # mean epoch loss triggers LR cuts at exactly the same epochs as stepping on the sum.
    epoch_loss = train_loss / max(num_seen, 1)
    scheduler.step(epoch_loss)

In [None]:
# save the model; create the output directory first so torch.save does not
# raise FileNotFoundError when models/ does not exist yet
os.makedirs('models', exist_ok=True)
torch.save(conv_upscale_net.state_dict(), 'models/model-upscale_net.pth')

In [None]:
# obtain one batch of test images
images, labels = next(iter(test_loader))

# add noise to the test images and clip back into [0, 1], as in training
noise_factor = 0.5
noisy_imgs = torch.clamp(images + noise_factor * torch.randn_like(images), 0., 1.)
noisy_imgs = noisy_imgs.to(device)
images = images.to(device)

# Run inference in eval mode without recording an autograd graph, then restore
# the previous mode so a later re-run of the training cell is unaffected.
was_training = conv_upscale_net.training
conv_upscale_net.eval()
with torch.no_grad():
    output = conv_upscale_net(noisy_imgs)
conv_upscale_net.train(was_training)

# prep images for display
noisy_imgs = noisy_imgs.cpu().numpy()
output = output.cpu().numpy()

# plot the first ten input images and then reconstructed images
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25, 4))

# input images on top row, reconstructions on bottom; a distinct loop name keeps
# `noisy_imgs` from being overwritten with the reconstructions
for imgs, row in zip([noisy_imgs, output], axes):
    for img, ax in zip(imgs, row):
        ax.imshow(np.squeeze(img), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
fig.suptitle('ConvUpscaleDenoiser: noisy inputs (top) vs. reconstructions (bottom)')
plt.show()

In [None]:
# Build the ConvTransposeDenoiser imported from networks.py; leaving the bare
# name as the cell's last expression makes Jupyter display its module summary.
conv_transpose_net = ConvTransposeDenoiser()
conv_transpose_net

In [None]:
# Move the model to the selected device *before* building the optimizer over its parameters.
conv_transpose_net.to(device)
# Pixel-wise mean-squared error between reconstruction and clean image.
criterion = nn.MSELoss()
optimizer = optim.Adam(conv_transpose_net.parameters(), lr=0.01)
# Cut the LR to a third when the monitored loss has not improved for 3 epochs.
# `verbose=True` is deprecated since PyTorch 2.2; read the current LR with
# scheduler.get_last_lr() or optimizer.param_groups instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=1/3, patience=3)

In [None]:
noise_factor = 0.5  # scale of the Gaussian noise added to the input images
num_epochs = 30
# Re-enable training-mode behaviour in case an evaluation cell ran before this one.
conv_transpose_net.train()
for epoch in range(num_epochs):
    train_loss = 0.0  # sum over samples of the per-batch mean loss
    num_seen = 0
    ###################
    # train the model #
    ###################
    loop = tqdm(train_loader, total=len(train_loader), desc=f'Epoch [{epoch+1:2d}/{num_epochs}]')
    for images, _ in loop:
        images = images.to(device)
        # add random noise on the same device, then clip the pixels to be between 0 and 1
        noisy_imgs = torch.clamp(images + noise_factor * torch.randn_like(images), 0., 1.)

        outputs = conv_transpose_net(noisy_imgs)
        loss = criterion(outputs, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # the loss is a batch mean, so weight it by batch size to accumulate a per-sample total
        train_loss += loss.item() * images.size(0)
        num_seen += images.size(0)
        # show the running *mean* loss (not the ever-growing sum) and the current LR
        loop.set_postfix(loss=train_loss / num_seen, lr=optimizer.param_groups[0]['lr'])

    # ReduceLROnPlateau's default relative threshold is scale-invariant, so stepping on the
    # mean epoch loss triggers LR cuts at exactly the same epochs as stepping on the sum.
    epoch_loss = train_loss / max(num_seen, 1)
    scheduler.step(epoch_loss)

In [None]:
# save the model; create the output directory first so torch.save does not
# raise FileNotFoundError when models/ does not exist yet
os.makedirs('models', exist_ok=True)
torch.save(conv_transpose_net.state_dict(), 'models/model-transpose_net.pth')

In [None]:
# obtain one batch of test images
images, labels = next(iter(test_loader))

# add noise to the test images and clip back into [0, 1], as in training
noise_factor = 0.5
noisy_imgs = torch.clamp(images + noise_factor * torch.randn_like(images), 0., 1.)
noisy_imgs = noisy_imgs.to(device)
images = images.to(device)

# Run inference in eval mode without recording an autograd graph, then restore
# the previous mode so a later re-run of the training cell is unaffected.
was_training = conv_transpose_net.training
conv_transpose_net.eval()
with torch.no_grad():
    output = conv_transpose_net(noisy_imgs)
conv_transpose_net.train(was_training)

# prep images for display
noisy_imgs = noisy_imgs.cpu().numpy()
output = output.cpu().numpy()

# plot the first ten input images and then reconstructed images
fig, axes = plt.subplots(nrows=2, ncols=10, sharex=True, sharey=True, figsize=(25, 4))

# input images on top row, reconstructions on bottom; a distinct loop name keeps
# `noisy_imgs` from being overwritten with the reconstructions
for imgs, row in zip([noisy_imgs, output], axes):
    for img, ax in zip(imgs, row):
        ax.imshow(np.squeeze(img), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
fig.suptitle('ConvTransposeDenoiser: noisy inputs (top) vs. reconstructions (bottom)')
plt.show()