# 🧼 Autoencoder for Face Image Denoising

In [1]:
import os
import io
import random
import numpy as np
import time
import torch
import torch.nn as nn
from datetime import timedelta
from PIL import Image, ImageFilter
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.utils import save_image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

  "class": algorithms.Blowfish,


In [2]:
class DenoisingAutoencoder(nn.Module):
    """Fully convolutional autoencoder mapping a noisy RGB image to a clean one.

    Encoder: four 4x4 convolutions with stride 2 / padding 1
    (channels 3 -> 64 -> 128 -> 256 -> 512), each followed by ReLU, so every
    stage halves the spatial size of an even-sized input.
    Decoder: the mirror image built from 4x4 transposed convolutions
    (512 -> 256 -> 128 -> 64 -> 3), each doubling the spatial size; hidden
    stages use ReLU and the last uses Sigmoid so outputs lie in [0, 1].
    For inputs whose height and width are divisible by 16 the output has the
    same shape as the input.
    """

    def __init__(self):
        super().__init__()
        widths = [3, 64, 128, 256, 512]

        # Conv -> ReLU pairs; Sequential indices 0..7 match the hand-written layout
        # so state_dict keys (encoder.0.weight, encoder.2.weight, ...) are unchanged.
        down_layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            down_layers += [nn.Conv2d(c_in, c_out, 4, stride=2, padding=1), nn.ReLU()]
        self.encoder = nn.Sequential(*down_layers)

        mirrored = widths[::-1]
        stages = list(zip(mirrored, mirrored[1:]))
        up_layers = []
        for idx, (c_in, c_out) in enumerate(stages):
            up_layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1))
            # Only the final stage squashes into [0, 1]; hidden stages use ReLU.
            up_layers.append(nn.Sigmoid() if idx == len(stages) - 1 else nn.ReLU())
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode ``x`` to a 512-channel latent map and decode it back to RGB."""
        latent = self.encoder(x)
        return self.decoder(latent)


In [3]:
def add_random_noise(img_tensor):
    mode = random.choice(["gaussian", "s&p", "occlusion", "blur", "jpeg", "combo"])

    if mode == "gaussian":
        noise = torch.randn_like(img_tensor) * 0.1
        return torch.clamp(img_tensor + noise, 0., 1.)

    elif mode == "s&p":
        noisy = img_tensor.clone()
        prob = 0.02
        rand = torch.rand_like(noisy)
        noisy[rand < prob] = 0
        noisy[rand > 1 - prob] = 1
        return noisy

    elif mode == "occlusion":
        noisy = img_tensor.clone()
        x = random.randint(0, noisy.shape[2] - 128)
        y = random.randint(0, noisy.shape[1] - 128)
        noisy[:, y:y+128, x:x+128] = 0
        return noisy

    elif mode == "blur":
        pil_img = F.to_pil_image(img_tensor)
        blurred = pil_img.filter(ImageFilter.GaussianBlur(radius=random.uniform(1, 3)))
        return F.to_tensor(blurred)

    elif mode == "jpeg":
        pil_img = F.to_pil_image(img_tensor)
        buffer = io.BytesIO()
        pil_img.save(buffer, format='JPEG', quality=random.randint(10, 30))
        jpeg_img = Image.open(buffer)
        return F.to_tensor(jpeg_img)

    elif mode == "combo":
        img = add_random_noise(img_tensor.clone())
        return add_random_noise(img)

    else:
        return img_tensor


In [4]:
class FaceDataset(Dataset):
    """Dataset of ``(noisy_image, clean_image)`` pairs built from a folder tree.

    Every file ending in ``.jpg``/``.jpeg``/``.png`` (case-insensitive) found
    recursively under ``root`` is one sample. The file list is sorted so the
    index -> file mapping does not depend on filesystem enumeration order.

    Args:
        root: directory searched recursively with ``os.walk``.
        transform: optional callable applied to the RGB PIL image. When
            ``noisy`` is True it should return a tensor, because
            ``add_random_noise`` operates on tensors.
        noisy: if True, the first element of each pair is corrupted with
            ``add_random_noise``; otherwise both elements are the same object.
    """

    def __init__(self, root, transform=None, noisy=True):
        self.root = root
        self.transform = transform
        self.noisy = noisy
        # Sorted for reproducibility: os.walk yields entries in arbitrary order.
        self.image_files = sorted(
            os.path.join(dirpath, filename)
            for dirpath, _, filenames in os.walk(root)
            for filename in filenames
            if filename.lower().endswith(('.jpg', '.jpeg', '.png'))
        )

    def __len__(self):
        """Return the number of image files discovered under ``root``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(noisy_image, clean_image)``."""
        # The context manager releases the file handle; convert() returns an
        # independent in-memory image, so it stays valid after the file closes.
        with Image.open(self.image_files[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        noisy_image = add_random_noise(image) if self.noisy else image
        return noisy_image, image


In [5]:
def calculate_psnr_ssim(output_tensor, target_tensor):
    """Return the batch-mean ``(PSNR, SSIM)`` between predictions and targets.

    Each sample ``i`` is taken as a (C, H, W) tensor, moved to CPU, permuted to
    (H, W, C) and scored with scikit-image using ``data_range=1.0`` (i.e. it
    assumes pixel values in [0, 1]). SSIM treats the last axis as channels.

    Args:
        output_tensor: model predictions, batch first.
        target_tensor: ground-truth images, indexed with the same batch indices.

    Returns:
        Tuple ``(mean_psnr, mean_ssim)`` averaged over the batch.
    """
    preds = output_tensor.detach().cpu()
    targets = target_tensor.detach().cpu()
    batch_size = preds.size(0)

    psnr_sum = 0.0
    ssim_sum = 0.0
    for i in range(batch_size):
        pred_hwc = preds[i].permute(1, 2, 0).numpy()
        target_hwc = targets[i].permute(1, 2, 0).numpy()
        psnr_sum += psnr(target_hwc, pred_hwc, data_range=1.0)
        ssim_sum += ssim(target_hwc, pred_hwc, channel_axis=-1, data_range=1.0)

    mean_psnr = psnr_sum / batch_size
    mean_ssim = ssim_sum / batch_size
    return mean_psnr, mean_ssim

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resize every face to 1024x1024 (divisible by 16, so the autoencoder's four
# stride-2 stages reproduce the input size exactly) and convert to tensors.
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor()
])

dataset = FaceDataset("data/", transform=transform, noisy=True)
dataloader = DataLoader(dataset, batch_size=24, shuffle=True)

model = DenoisingAutoencoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

# Load the checkpoint, if one exists, to resume training.
checkpoint_path = "checkpoint.pth"
start_epoch = 0
if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint written on GPU be resumed on a CPU-only
    # machine (and vice versa); without it tensors are restored to the device
    # they were saved from, which fails when that device is unavailable.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # The training loop stores the 0-based index of the last finished epoch.
    start_epoch = checkpoint['epoch'] + 1
    # Message reads: "Loaded checkpoint from epoch <start_epoch>".
    print(f"✔️ Завантажено чекпоінт з епохи {start_epoch}")


✔️ Завантажено чекпоінт з епохи 50


In [None]:
# 🔁 Training
num_epochs = 500

# Fail fast with a clear message instead of a NameError on `outputs` below
# when the dataset root contains no images.
if len(dataloader) == 0:
    raise RuntimeError("No training images found: the dataloader is empty.")

for epoch in range(start_epoch, num_epochs):
    start_time = time.time()
    model.train()
    running_loss = 0.0

    for noisy_imgs, clean_imgs in dataloader:
        noisy_imgs = noisy_imgs.to(device)
        clean_imgs = clean_imgs.to(device)

        outputs = model(noisy_imgs)
        loss = criterion(outputs, clean_imgs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # 💾 Save checkpoint. Write to a temporary file and swap it in with
    # os.replace, so interrupting the kernel mid-save cannot corrupt the
    # previous checkpoint.
    tmp_checkpoint_path = checkpoint_path + ".tmp"
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, checkpoint_path)

    avg_loss = running_loss / len(dataloader)
    # NOTE: PSNR/SSIM are measured on the last mini-batch of the epoch only,
    # so they fluctuate far more than the epoch-averaged loss.
    psnr_val, ssim_val = calculate_psnr_ssim(outputs, clean_imgs)

    # Wall-clock time for training, checkpointing and metrics of this epoch.
    epoch_time = time.time() - start_time
    formatted_time = str(timedelta(seconds=int(epoch_time)))

    print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {avg_loss:.6f} | PSNR: {psnr_val:.2f} | SSIM: {ssim_val:.4f} | Time: {formatted_time}")

    # Save a sample grid: up to 2 noisy inputs, their reconstructions, and the clean targets.
    save_image(torch.cat((noisy_imgs[:2], outputs[:2].detach(), clean_imgs[:2]), dim=0), f"sample_epoch_{epoch+1}.png")


Epoch [51/500] | Loss: 0.000945 | PSNR: 32.20 | SSIM: 0.8536 | Time: 0:15:47
Epoch [52/500] | Loss: 0.000929 | PSNR: 30.67 | SSIM: 0.8571 | Time: 0:15:56
Epoch [53/500] | Loss: 0.000940 | PSNR: 31.13 | SSIM: 0.8401 | Time: 0:15:51
Epoch [54/500] | Loss: 0.000916 | PSNR: 30.92 | SSIM: 0.8485 | Time: 0:15:42
Epoch [55/500] | Loss: 0.000901 | PSNR: 31.76 | SSIM: 0.8687 | Time: 0:15:42
Epoch [56/500] | Loss: 0.000896 | PSNR: 31.51 | SSIM: 0.8623 | Time: 0:15:44
Epoch [57/500] | Loss: 0.000912 | PSNR: 31.60 | SSIM: 0.8522 | Time: 0:15:58
Epoch [58/500] | Loss: 0.000874 | PSNR: 31.28 | SSIM: 0.8631 | Time: 0:16:02
Epoch [59/500] | Loss: 0.000861 | PSNR: 30.81 | SSIM: 0.8312 | Time: 0:16:31
Epoch [60/500] | Loss: 0.000887 | PSNR: 30.49 | SSIM: 0.8300 | Time: 0:17:29
Epoch [61/500] | Loss: 0.000852 | PSNR: 30.62 | SSIM: 0.8483 | Time: 0:17:07
Epoch [62/500] | Loss: 0.000847 | PSNR: 31.55 | SSIM: 0.8748 | Time: 0:16:58
Epoch [63/500] | Loss: 0.000844 | PSNR: 31.75 | SSIM: 0.8582 | Time: 0:18:19

Epoch [157/500] | Loss: 0.000606 | PSNR: 32.91 | SSIM: 0.8759 | Time: 0:15:01
Epoch [158/500] | Loss: 0.000603 | PSNR: 31.81 | SSIM: 0.8853 | Time: 0:14:43
Epoch [159/500] | Loss: 0.000595 | PSNR: 33.02 | SSIM: 0.8781 | Time: 0:15:21
Epoch [160/500] | Loss: 0.000591 | PSNR: 33.27 | SSIM: 0.8825 | Time: 0:15:53
Epoch [161/500] | Loss: 0.000596 | PSNR: 33.34 | SSIM: 0.8870 | Time: 0:15:42
Epoch [162/500] | Loss: 0.000588 | PSNR: 32.96 | SSIM: 0.8772 | Time: 0:15:53
Epoch [163/500] | Loss: 0.000593 | PSNR: 33.32 | SSIM: 0.8922 | Time: 0:15:41
Epoch [164/500] | Loss: 0.000586 | PSNR: 32.52 | SSIM: 0.8812 | Time: 0:14:39
Epoch [165/500] | Loss: 0.000588 | PSNR: 32.75 | SSIM: 0.8960 | Time: 0:14:43
Epoch [166/500] | Loss: 0.000577 | PSNR: 33.93 | SSIM: 0.8956 | Time: 0:14:41
Epoch [167/500] | Loss: 0.000592 | PSNR: 33.50 | SSIM: 0.8957 | Time: 0:14:44
Epoch [168/500] | Loss: 0.000587 | PSNR: 34.04 | SSIM: 0.8962 | Time: 0:14:36
Epoch [169/500] | Loss: 0.000594 | PSNR: 33.61 | SSIM: 0.8947 | 

Epoch [263/500] | Loss: 0.000522 | PSNR: 32.82 | SSIM: 0.8859 | Time: 0:14:42
Epoch [264/500] | Loss: 0.000527 | PSNR: 34.83 | SSIM: 0.9053 | Time: 0:14:34
Epoch [265/500] | Loss: 0.000516 | PSNR: 34.66 | SSIM: 0.9077 | Time: 0:14:35
Epoch [266/500] | Loss: 0.000529 | PSNR: 33.82 | SSIM: 0.9014 | Time: 0:14:42
Epoch [267/500] | Loss: 0.000529 | PSNR: 34.30 | SSIM: 0.8985 | Time: 0:14:39
Epoch [268/500] | Loss: 0.000523 | PSNR: 33.47 | SSIM: 0.8885 | Time: 0:15:02
Epoch [269/500] | Loss: 0.000523 | PSNR: 33.96 | SSIM: 0.9072 | Time: 0:15:19
Epoch [270/500] | Loss: 0.000516 | PSNR: 34.06 | SSIM: 0.9062 | Time: 0:14:49
Epoch [271/500] | Loss: 0.000517 | PSNR: 34.35 | SSIM: 0.9126 | Time: 0:14:36
Epoch [272/500] | Loss: 0.000515 | PSNR: 34.12 | SSIM: 0.9048 | Time: 0:14:43
Epoch [273/500] | Loss: 0.000526 | PSNR: 33.39 | SSIM: 0.8980 | Time: 0:14:41
Epoch [274/500] | Loss: 0.000518 | PSNR: 34.26 | SSIM: 0.9145 | Time: 0:14:37
Epoch [275/500] | Loss: 0.000515 | PSNR: 32.79 | SSIM: 0.9141 | 

Epoch [369/500] | Loss: 0.000498 | PSNR: 34.42 | SSIM: 0.9207 | Time: 0:14:37
Epoch [370/500] | Loss: 0.000484 | PSNR: 33.96 | SSIM: 0.9210 | Time: 0:14:35
Epoch [371/500] | Loss: 0.000490 | PSNR: 34.11 | SSIM: 0.9119 | Time: 0:14:34
Epoch [372/500] | Loss: 0.000487 | PSNR: 34.61 | SSIM: 0.9223 | Time: 0:14:35
Epoch [373/500] | Loss: 0.000483 | PSNR: 33.69 | SSIM: 0.9026 | Time: 0:14:31
Epoch [374/500] | Loss: 0.000485 | PSNR: 33.90 | SSIM: 0.9052 | Time: 0:14:33
Epoch [375/500] | Loss: 0.000482 | PSNR: 34.93 | SSIM: 0.9121 | Time: 0:14:38
Epoch [376/500] | Loss: 0.000497 | PSNR: 36.00 | SSIM: 0.9328 | Time: 0:14:36
Epoch [377/500] | Loss: 0.000472 | PSNR: 34.03 | SSIM: 0.9119 | Time: 0:14:25
Epoch [378/500] | Loss: 0.000483 | PSNR: 34.22 | SSIM: 0.9130 | Time: 0:14:33
Epoch [379/500] | Loss: 0.000492 | PSNR: 34.57 | SSIM: 0.9130 | Time: 0:14:33
Epoch [380/500] | Loss: 0.000494 | PSNR: 33.50 | SSIM: 0.9013 | Time: 0:14:32
Epoch [381/500] | Loss: 0.000485 | PSNR: 34.10 | SSIM: 0.9015 | 