In [None]:
%load_ext autoreload
%autoreload 2

import os
import shutil

import matplotlib.pyplot as plt

from utils.dataloader import load_base_dataset, load_processed_dataset, create_torch_dataloader
from utils.image_processing import process_images
from unet.unet_model import UNet
from utils.train import train_model
from utils.checkpoint import save_checkpoint, load_checkpoint

In [None]:
import os

# Drop the `sharp` sub-folder from the custom test set so the ImageFolder built
# from 'dataset/custom_test' below only sees the remaining class folder(s)
# (presumably just `blur` -- TODO confirm).
CUSTOM_TEST_SHARP_DIR = 'dataset/custom_test/sharp'
if os.path.isdir(CUSTOM_TEST_SHARP_DIR):
    # rmtree also removes nested sub-folders, on which os.remove() would fail.
    shutil.rmtree(CUSTOM_TEST_SHARP_DIR)

In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
# Evaluation set: every image found under dataset/custom_test, turned into a tensor.
# shuffle=False keeps the iteration order fixed, so output file indices are reproducible.
test_dataset = datasets.ImageFolder(
    root='dataset/custom_test',
    transform=transforms.ToTensor(),
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=2,
    shuffle=False,
)

In [None]:
# Summary of the evaluation dataset.
print(test_dataset)
# A DataLoader's default repr is just its class name and memory address, so
# report the facts that matter about it explicitly instead.
print(
    f"DataLoader: {len(test_dataloader)} batches of up to "
    f"{test_dataloader.batch_size} images; classes: {test_dataset.class_to_idx}"
)

In [None]:
# Build the network. The arguments are presumably (input channels, output channels),
# i.e. RGB in / RGB out -- TODO confirm against unet.unet_model.UNet.
model = UNet(3, 3)
# Count only parameters that will receive gradients, reported in millions.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{trainable_params / 1_000_000}M Parameters")

In [None]:
# Load the previously saved checkpoint; its layout is defined by utils.checkpoint.
CHECKPOINT_PATH = "checkpoint.pth"
prev_checkpoint = load_checkpoint(CHECKPOINT_PATH)

In [None]:
# Restore the trained weights, stored in the checkpoint under 'model_state'.
# The call is kept as the last expression so the notebook still shows its result.
model_weights = prev_checkpoint['model_state']
model.load_state_dict(model_weights)

In [None]:
# Run the model over the test set. Each model output is written to `sharp/`
# and the matching network input to `blur/`, as <index>.png.
OUTPUT_ROOT = 'dataset/final_output'

# Start from empty output folders so stale files from an earlier run never
# mix with this run's results.
for subdir in ('sharp', 'blur'):
    out_dir = f"{OUTPUT_ROOT}/{subdir}"
    if os.path.isdir(out_dir):
        # rmtree also copes with nested sub-folders, which os.remove() cannot.
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)

# Fall back to the CPU when no GPU is available instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()
model.to(device)

# Build the tensor -> PIL converter once instead of once per image.
to_pil = transforms.ToPILImage()
# Running image counter. Unlike `batch_idx * 2 + j`, it stays correct whatever
# batch_size test_dataloader was built with.
sample_idx = 0

with torch.no_grad():
    for batch_idx, (data, _) in enumerate(test_dataloader):
        data = data.to(device)
        # Clamp the whole batch at once so saved and plotted outputs lie in [0, 1].
        output = model(data).clamp(0, 1)
        # One device -> host copy per batch, reused for saving and for plotting.
        data_cpu = data.cpu()
        output_cpu = output.cpu()
        for blur_img, sharp_img in zip(data_cpu, output_cpu):
            to_pil(sharp_img).save(f"{OUTPUT_ROOT}/sharp/{sample_idx}.png")
            to_pil(blur_img).save(f"{OUTPUT_ROOT}/blur/{sample_idx}.png")
            sample_idx += 1
        # Preview the first input/output pair of every 10th batch.
        if batch_idx % 10 == 0:
            fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
            axes[0].imshow(data_cpu[0].permute(1, 2, 0))
            axes[1].imshow(output_cpu[0].permute(1, 2, 0))
            axes[0].set_title('Input (blur)')
            axes[1].set_title('Model output (sharp)')
            for ax in axes:
                ax.axis('off')
            fig.tight_layout()
            plt.show()
            # Release the figure explicitly so the loop does not accumulate figures.
            plt.close(fig)

In [None]:
import utils.eval