TODO

explain the architecture of the conv and dense models

explain the training loop and the parameters used (don't forget add_noise)

explain how evaluate works, i.e. which metrics are used and why

explain the function for visualizing the images

blocks where the models are trained and evaluated (decide whether to treat the two DAEs together or not)

conclusions

# Midterm 3, Assignment 1 - Gaetano Barresi [579102]

A Denoising Autoencoder (DAE) is a neural network that learns robust latent representations by reconstructing clean inputs from corrupted data. Unlike standard autoencoders, which simply reconstruct the original inputs, a DAE must remove noise from artificially corrupted inputs and discover stable features that capture the true data distribution. The goal of this work is to train two versions of a DAE on the CIFAR-10 dataset, one using dense layers and one using convolutional layers, and to compare the quality of their reconstructions.

The architecture of a DAE has the following form:

$$
[Input] \rightarrow [Corruption] \rightarrow [Encoder] \rightarrow [Latent\ Code] \rightarrow [Decoder] \rightarrow [Reconstruction]
$$

$[Input]$ is the clean image $x$. $[Corruption]$ is the stage where we inject artificial noise into the input through a noise process $C(\hat{x} \mid x)$, obtaining $\hat{x}$, a corrupted version of the original input. In this work the corruption is additive zero-mean Gaussian noise, so:

$$
\hat{x} = x + \epsilon, \quad \epsilon \sim \mathcal{N}(0, \sigma^2).
$$
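
The corruption step can be checked numerically. The sketch below is only an illustration: `corrupt` is a hypothetical helper, $\sigma = 0.2$ is assumed because it is the default `noise_factor` used later in this notebook, and the clamp to [-1, 1] mirrors what `add_noise()` does rather than anything in the formula itself.

```python
import torch

def corrupt(x, sigma=0.2):
    """Hypothetical helper: additive Gaussian corruption, then a clamp to [-1, 1]."""
    eps = sigma * torch.randn_like(x)   # eps ~ N(0, sigma^2), one sample per pixel
    return torch.clamp(x + eps, -1.0, 1.0)

torch.manual_seed(0)
x = torch.zeros(4, 3, 32, 32)           # stand-in batch of mid-gray images in [-1, 1]
x_hat = corrupt(x)

# The empirical standard deviation of the injected noise should be close to sigma
empirical_sigma = (x_hat - x).std().item()
print(x_hat.shape, round(empirical_sigma, 2))
```

For mid-gray inputs the clamp is almost never active, so the measured standard deviation stays close to the chosen $\sigma$. For pixels near -1 or 1 the clamp truncates part of the noise.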

The $[Encoder]$ $f(\cdot)$ maps the noisy input $\hat{x}$ to a latent representation $z$:

$$
z = f(\hat{x}) \in \mathbb{R}^d,
$$

where $d \ll input\_dim$. The learned $[Latent\ Code]$ $z$ should be robust to partial destruction of the input.
The $[Decoder]$ $g(\cdot)$ reconstructs the clean input from $z$:

$$
x' = g(z) \approx x,
$$

where $x'$ is our $[Reconstruction]$.

The "magic" of DAEs lies in their bottleneck architecture. As data passes through the encoder, the network gradually compresses the input while stripping away noise and preserving only the most essential features. This forced dimensionality reduction creates a distilled latent representation containing only the core patterns needed for reconstruction. The decoder then reverses this process, carefully rebuilding the clean input from these compressed features layer by layer. By training on noisy-clean pairs, the network learns to discard random corruptions during compression while maintaining the structural integrity needed for accurate reconstruction.

More formally, the DAE is trained to minimize a reconstruction loss between the clean input and the reconstruction obtained from its corrupted version,

$$
L\big(x, g(f(\hat{x}))\big),
$$

or, using a probabilistic interpretation, the DAE learns the denoising distribution

$$
P(x \mid \hat{x})
$$

by minimizing the negative log-likelihood

$$
-\log P(x \mid z = f(\hat{x})).
$$
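
The two formulations coincide under an assumption that is not stated above and is introduced here only for illustration: the decoder output $g(z)$ is the mean of an isotropic Gaussian with fixed variance $s^2$ (a hypothetical symbol, distinct from the corruption variance $\sigma^2$). With $D$ denoting the input dimension, the negative log-likelihood becomes

$$
-\log P(x \mid z) = \frac{1}{2 s^2} \lVert x - g(z) \rVert^2 + \frac{D}{2} \log(2 \pi s^2),
$$

so, since the second term does not depend on the network parameters, minimizing it is equivalent to minimizing the squared reconstruction error, which is the MSE up to a constant factor.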

After this brief introduction, we proceed with the implementation. First, we need to load and preprocess the dataset. We normalize the CIFAR-10 images to the range [-1, 1] for several reasons: the range matches the output of the tanh activation that closes both decoders; it is symmetric around zero, so zero-mean Gaussian noise perturbs pixels equally in the positive and negative directions; and zero-centered inputs generally help gradient flow and numerical stability during training.
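
As a small worked example of this normalization, the sketch below applies the per-channel formula `(x - mean) / std` with mean and standard deviation both equal to 0.5, as in the transform used in this notebook. The pixel values are made up for illustration, and plain tensor arithmetic stands in for the torchvision transform.

```python
import torch

mean, std = 0.5, 0.5                            # per-channel values passed to Normalize
pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])    # example intensities after ToTensor(), in [0, 1]

normalized = (pixels - mean) / std              # maps [0, 1] to [-1, 1]
restored = normalized * std + mean              # inverse map, the same as (normalized + 1) / 2

print(normalized.tolist())  # [-1.0, -0.5, 0.0, 1.0]
```

The inverse map is the same denormalization that `visualize_results()` applies before plotting.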

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# transform for CIFAR-10, normalizing to [-1, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Range: [-1, 1]
])
# Load TR set
train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=100, shuffle=True)
# Load TS set
test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=100, shuffle=False)

print("Data loaded successfully.")

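To implement the DAE, we subclass PyTorch’s `nn.Module` base class, inheriting its core functionality for neural network operations. A single class covers both variants: the `mode` argument of the constructor selects either the dense (`'dense'`) or the convolutional (`'conv'`) DAE.

The `forward()` method changes slightly with the selected mode: the convolutional DAE processes the image tensor directly, while the dense DAE first flattens each image to a vector of 3072 values and reshapes the decoder output back to 3×32×32.

The class has methods to train and evaluate the DAE:
- `train_model()` is a classical training loop. We use the MSE loss and the Adam optimizer with a learning rate of 1e-4.
- `evaluate()` computes the average PSNR and SSIM between the clean test images and their reconstructions from noisy inputs.
- `visualize_results()` plots noisy, reconstructed, and clean images side by side.

The `add_noise()` method is used inside the training loop, and again in `evaluate()` and `visualize_results()`, to corrupt each batch with Gaussian noise scaled by `noise_factor` and to clamp the result to [-1, 1].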
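
To make the PSNR values reported by `evaluate()` easier to interpret, the sketch below computes PSNR by hand from its usual definition, $10 \log_{10}(\text{data\_range}^2 / \text{MSE})$, with `data_range=2.0` for images in [-1, 1]. `psnr_manual` is a hypothetical helper written for this illustration. It is not the scikit-image function used in the class, although that function is expected to follow the same definition.

```python
import math
import torch

def psnr_manual(clean, recon, data_range=2.0):
    """Hypothetical helper: PSNR = 10 * log10(data_range^2 / MSE)."""
    mse = torch.mean((clean - recon) ** 2).item()
    return 10.0 * math.log10(data_range ** 2 / mse)

clean = torch.zeros(3, 32, 32)
recon = clean + 0.2     # constant error of 0.2 on every pixel, so MSE = 0.04

print(round(psnr_manual(clean, recon), 2))  # 20.0
```

An error with mean square 0.04 is also what unclamped Gaussian noise with `noise_factor=0.2` produces on average, so the noisy inputs themselves sit at about 20 dB (slightly higher once clamping is taken into account). A useful reconstruction should score above that baseline.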

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt


class DAE(nn.Module):
    def __init__(self, mode=None):
        super(DAE, self).__init__()

        self.mode = mode

        if mode == 'conv':
            # Convolutional Encoder
            self.encoder = nn.Sequential(
                nn.Conv2d(3, 64, 3, padding=1),   # 32x32
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(64, 128, 3, padding=1), # 16x16
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(128, 256, 3, padding=1) # 8x8
            )
            # Convolutional Decoder
            self.decoder = nn.Sequential(
                nn.Conv2d(256, 128, 3, padding=1),  # 8x8
                nn.ReLU(),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(128, 64, 3, padding=1),   # 16x16
                nn.ReLU(),
                nn.Upsample(scale_factor=2),
                nn.Conv2d(64, 3, 3, padding=1),     # 32x32
                nn.Tanh()
            )
        elif mode == 'dense':
            # Dense Encoder
            self.encoder = nn.Sequential(
                nn.Linear(3 * 32 * 32, 1024),
                nn.ReLU(),
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU()
            )
            # Dense Decoder
            self.decoder = nn.Sequential(
                nn.Linear(256, 512),
                nn.ReLU(),
                nn.Linear(512, 1024),
                nn.ReLU(),
                nn.Linear(1024, 3 * 32 * 32),
                nn.Tanh()
            )
        else:
            raise ValueError("Invalid mode. Choose 'conv' or 'dense'.")


    def forward(self, x):
        if self.mode == 'conv':
            x = self.encoder(x)
            x = self.decoder(x)
        else:
            x = x.view(x.size(0), -1)  # Flatten to (batch_size, 3072)
            x = self.encoder(x)
            x = self.decoder(x)
            x = x.view(-1, 3, 32, 32)  # Reshape to image dimensions
        return x
    

    def add_noise(self, inputs, noise_factor=0.2):
        noise = noise_factor * torch.randn_like(inputs)
        noisy = inputs + noise
        return torch.clamp(noisy, -1., 1.)  # CIFAR-10 is normalized to [-1, 1]


    def train_model(self, train_loader, num_epochs=20, lr=1e-4, print_interval=100):
        # print_interval (int): print the loss every N batches

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)
        
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        
        self.train()  # Set to training mode
        for epoch in range(num_epochs):
            for batch_idx, (clean_imgs, _) in enumerate(train_loader):
                clean_imgs = clean_imgs.to(device)
                
                # Add noise and reconstruct
                noisy_imgs = self.add_noise(clean_imgs)
                reconstructed = self(noisy_imgs)
                
                # Compute loss and update
                loss = self.criterion(reconstructed, clean_imgs)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                
                # Print progress
                if batch_idx % print_interval == 0:
                    print(f"Epoch [{epoch+1}/{num_epochs}], Batch {batch_idx}, Loss: {loss.item():.4f}")


    def evaluate(self, test_loader, noise_factor=0.2):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)

        #Compute PSNR and SSIM on test set
        self.eval()  # Set to evaluation mode
        total_psnr = 0.0
        total_ssim = 0.0
        num_samples = 0
        
        with torch.no_grad():
            for clean_imgs, _ in test_loader:
                clean_imgs = clean_imgs.to(device)
                noisy_imgs = self.add_noise(clean_imgs, noise_factor)
                reconstructed = self(noisy_imgs)
                
                # Convert to numpy (CPU) for metric calculation
                clean_np = clean_imgs.cpu().numpy()
                recon_np = reconstructed.cpu().numpy()
                
                # Compute metrics per image
                for i in range(clean_np.shape[0]):
                    # PSNR (higher is better)
                    total_psnr += psnr(clean_np[i], recon_np[i], data_range=2.0)  # data_range=2 for [-1,1]
                    
                    # SSIM (higher is better); images are transposed to HWC, so channel_axis=2 marks the RGB axis
                    total_ssim += ssim(clean_np[i].transpose(1,2,0), 
                                     recon_np[i].transpose(1,2,0), 
                                     data_range=2.0, 
                                     channel_axis=2)
                
                num_samples += clean_np.shape[0]
        
        avg_psnr = total_psnr / num_samples
        avg_ssim = total_ssim / num_samples
        return avg_psnr, avg_ssim
    

    def visualize_results(self, test_loader, num_images=5):
        #Plot noisy vs reconstructed vs clean images
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)

        self.eval()
        with torch.no_grad():
            clean_imgs, _ = next(iter(test_loader))
            clean_imgs = clean_imgs.to(device)[:num_images]
            noisy_imgs = self.add_noise(clean_imgs)
            reconstructed = self(noisy_imgs)
            
            # Denormalize to [0,1] for plotting
            clean_imgs = (clean_imgs + 1) / 2
            noisy_imgs = (noisy_imgs + 1) / 2
            reconstructed = (reconstructed + 1) / 2
            
            # squeeze=False keeps axes 2-D, so axes[i, j] also works when num_images == 1
            fig, axes = plt.subplots(num_images, 3, figsize=(10, num_images*2), squeeze=False)
            for i in range(num_images):
                axes[i,0].imshow(noisy_imgs[i].cpu().permute(1,2,0))
                axes[i,0].set_title("Noisy")
                axes[i,1].imshow(reconstructed[i].cpu().permute(1,2,0))
                axes[i,1].set_title("Reconstructed")
                axes[i,2].imshow(clean_imgs[i].cpu().permute(1,2,0))
                axes[i,2].set_title("Clean")
                for ax in axes[i]:
                    ax.axis('off')
            plt.tight_layout()
            plt.show()
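
As a quick sanity check on the two architectures, the sketch below rebuilds only the encoders, with the same layers as in the `DAE` class, and passes a dummy batch through them to inspect the size of the latent code. The random input and the variable names are assumptions made for this illustration; no training is involved.

```python
import torch
import torch.nn as nn

# Encoders copied from the DAE class above (same layers, untrained weights)
conv_encoder = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(128, 256, 3, padding=1),
)
dense_encoder = nn.Sequential(
    nn.Linear(3 * 32 * 32, 1024), nn.ReLU(),
    nn.Linear(1024, 512), nn.ReLU(),
    nn.Linear(512, 256), nn.ReLU(),
)

x = torch.randn(2, 3, 32, 32)                    # dummy batch of two CIFAR-10-sized images
with torch.no_grad():
    z_conv = conv_encoder(x)
    z_dense = dense_encoder(x.view(x.size(0), -1))

print(tuple(z_conv.shape), z_conv[0].numel())    # (2, 256, 8, 8) 16384
print(tuple(z_dense.shape), z_dense[0].numel())  # (2, 256) 256
```

The dense code has 256 values per image, well below the 3072 input values. The convolutional code has 16384 values, so the convolutional variant is not undercomplete in the sense of $d \ll input\_dim$; this may be worth keeping in mind when the two models are compared.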