In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
def visualize_results(original, reconstructed, num_images=5):
    """Show original and reconstructed images side by side, one figure per sample.

    Args:
        original: 4-D batch of ground-truth images with channels on dim 1; it is
            permuted to channel-last because ``imshow`` expects (H, W[, C]).
            NOTE(review): pixel values are assumed to belong in [0, 1] (they are
            clipped to that range) — confirm against the data pipeline.
        reconstructed: Decoder outputs with the same layout as ``original``.
        num_images: Maximum number of samples from the batch to display.
    """
    # .numpy() raises on tensors that require grad or live on the GPU, so
    # detach and move to CPU up front; both are no-ops for plain CPU tensors.
    original = original.detach().cpu().permute(0, 2, 3, 1)
    reconstructed = reconstructed.detach().cpu().permute(0, 2, 3, 1)

    def _show(ax, image, title):
        # Draw one channel-last image on `ax`, clipped to [0, 1].
        array = image.numpy().clip(0, 1)
        if array.ndim == 3 and array.shape[-1] == 1:
            # imshow does not accept (H, W, 1); render single-channel images as
            # 2-D grayscale with the same fixed [0, 1] range as the clip.
            ax.imshow(array[..., 0], cmap="gray", vmin=0, vmax=1)
        else:
            ax.imshow(array)
        ax.set_title(title)
        ax.axis("off")

    for i in range(min(num_images, original.shape[0])):
        # Exactly two panels; the previous 1x3 grid left an empty third axis.
        fig, (ax_original, ax_reconstructed) = plt.subplots(1, 2, figsize=(10, 4))
        _show(ax_original, original[i], "Original Image")
        _show(ax_reconstructed, reconstructed[i], "Reconstructed Image")
        plt.show()
        # Release the figure so figures created in this loop do not accumulate.
        plt.close(fig)

In [None]:
def test_decoder(decoder, test_loader, num_images=5):

    # Decoder to evaluation mode
    decoder.eval()  

    total_loss = 0

    # Disable gradient computation for testing , since the model is not being updated. 
    with torch.no_grad(): 

        for encoded_patches, original_images, binary_masks in test_loader:

            if torch.cuda.is_available():
                
                encoded_patches = encoded_patches.cuda()
                original_images = original_images.cuda()
                binary_masks = binary_masks.cuda()

            # Forward pass
            reconstructed_images = decoder(encoded_patches)

            # Loss for evaluation 
            loss = mse_loss(reconstructed_images, original_images, binary_masks)
            total_loss += loss.item()

        # Average loss over the test set
        avg_loss = total_loss / len(test_loader)
        print(f"Test Loss: {avg_loss}")

    # Visualization
    for encoded_patches, original_images, binary_masks in test_loader:

        if torch.cuda.is_available():

            encoded_patches = encoded_patches.cuda()
            original_images = original_images.cuda()
            binary_masks = binary_masks.cuda()

        reconstructed_images = decoder(encoded_patches)

        visualize_results(original_images.cpu(), reconstructed_images.cpu(), num_images)
        
        break  # Visualize only the first batch


In [None]:
# Example usage: evaluate the decoder on the held-out split and preview a few
# reconstructions. The trailing ';' keeps the cell from echoing a return value.
# NOTE(review): `decoder` and `test_loader` are not defined in this section —
# presumably created in earlier cells; confirm the notebook survives Restart & Run All.
test_decoder(decoder=decoder, test_loader=test_loader);