In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset
import torchvision.transforms as transforms
import os
from PIL import Image
from pytorch_msssim import ssim
import torch.nn.functional as F
from math import log10
from tqdm import tqdm
from sklearn.model_selection import KFold
import json
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.amp import GradScaler, autocast


In [None]:
# Read the run configuration from config.json in the current working directory
# and parse it into the module-level `config` object used by the later cells.
with open('config.json', 'r') as config_file:
    config = json.loads(config_file.read())

In [None]:
class DenoisingAutoencoder(nn.Module):
    """Convolutional denoising autoencoder with skip connections and channel attention.

    Encoder: three stride-2 3x3 convolutions (3 -> 64 -> 128 -> 256 channels),
    each followed by batch norm, LeakyReLU(0.2) and dropout(0.3).
    Bottleneck: the 256-channel feature map is multiplied by a per-channel gate
    computed from its global average (1x1 conv down to 16 channels, ReLU,
    1x1 conv back to 256 channels, sigmoid).
    Decoder: three stride-2 transposed convolutions; the outputs of the first
    two are concatenated with the matching encoder activations before the next
    layer. A final sigmoid keeps the output in (0, 1).

    Input is a batch of shape (N, 3, H, W); the output has the same shape.
    H and W do not need to be multiples of 8: decoder activations are cropped
    to the size of the tensor they are combined with (see ``forward``).
    """

    def __init__(self):
        super().__init__()

        # Attribute names and registration order are kept stable so that
        # previously saved weights/state dicts keep loading.
        self.encoder_conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        self.encoder_bn1 = nn.BatchNorm2d(64)
        self.encoder_relu1 = nn.LeakyReLU(negative_slope=0.2)

        self.encoder_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.encoder_bn2 = nn.BatchNorm2d(128)
        self.encoder_relu2 = nn.LeakyReLU(negative_slope=0.2)

        self.encoder_conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.encoder_bn3 = nn.BatchNorm2d(256)
        self.encoder_relu3 = nn.LeakyReLU(negative_slope=0.2)

        # One dropout module shared by all three encoder stages.
        self.dropout = nn.Dropout(0.3)

        # Per-channel gate in (0, 1) computed from the globally pooled bottleneck.
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(256, 256 // 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(256 // 16, 256, kernel_size=1),
            nn.Sigmoid()
        )

        self.decoder_conv1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder_bn1 = nn.BatchNorm2d(128)
        self.decoder_relu1 = nn.ReLU()

        # Input channels are doubled because of the concatenated skip tensor.
        self.decoder_conv2 = nn.ConvTranspose2d(128 + 128, 64, kernel_size=2, stride=2)
        self.decoder_bn2 = nn.BatchNorm2d(64)
        self.decoder_relu2 = nn.ReLU()

        self.decoder_conv3 = nn.ConvTranspose2d(64 + 64, 3, kernel_size=2, stride=2)
        self.decoder_sigmoid = nn.Sigmoid()

    @staticmethod
    def _crop_to(tensor, reference):
        """Return ``tensor`` cropped (top-left aligned) to ``reference``'s height/width.

        The tensor is returned untouched when the spatial sizes already agree.
        """
        height, width = reference.shape[-2:]
        if tensor.shape[-2] == height and tensor.shape[-1] == width:
            return tensor
        return tensor[..., :height, :width]

    def forward(self, x):
        """Denoise a batch ``x`` of shape (N, 3, H, W) and return a same-shaped tensor."""
        x1 = self.encoder_relu1(self.encoder_bn1(self.encoder_conv1(x)))
        x1 = self.dropout(x1)

        x2 = self.encoder_relu2(self.encoder_bn2(self.encoder_conv2(x1)))
        x2 = self.dropout(x2)

        x3 = self.encoder_relu3(self.encoder_bn3(self.encoder_conv3(x2)))
        x3 = self.dropout(x3)

        # Re-weight bottleneck channels; attn has shape (N, 256, 1, 1) and broadcasts.
        attn = self.channel_attention(x3)
        x3 = x3 * attn

        # Each encoder conv maps a side of length s to ceil(s / 2), while each
        # transposed conv exactly doubles it. For odd sizes the decoder output
        # is therefore one pixel larger than the skip tensor, so it is cropped
        # before concatenation (a no-op when the sizes already match).
        x4 = self.decoder_relu1(self.decoder_bn1(self.decoder_conv1(x3)))
        x4 = torch.cat((self._crop_to(x4, x2), x2), dim=1)

        x5 = self.decoder_relu2(self.decoder_bn2(self.decoder_conv2(x4)))
        x5 = torch.cat((self._crop_to(x5, x1), x1), dim=1)

        x6 = self.decoder_sigmoid(self.decoder_conv3(x5))

        # Guarantee the output has exactly the input's spatial size.
        return self._crop_to(x6, x)

In [None]:
# Image directories for the training and validation splits, taken from the
# loaded configuration (one clean and one noisy directory per split).
TRAIN_CLEAN_PATH, TRAIN_NOISY_PATH = config["train_clean_dir"], config["train_noisy_dir"]
VALID_CLEAN_PATH, VALID_NOISY_PATH = config["valid_clean_dir"], config["valid_noisy_dir"]

class NoisyCleanDataset(Dataset):
    """Dataset yielding ``(noisy_image, clean_image)`` pairs from two directories.

    Pairing is positional: item ``i`` combines the ``i``-th entry of the sorted
    noisy listing with the ``i``-th entry of the sorted clean listing.
    NOTE(review): this presumes both directories use file names that sort into
    the same order (e.g. identical names) — confirm against the data layout.

    Args:
        clean_dir: Directory containing the clean target images.
        noisy_dir: Directory containing the noisy input images.
        transform: Optional callable applied to each loaded RGB image.

    Raises:
        AssertionError: If the two directories hold a different number of entries.
    """

    def __init__(self, clean_dir, noisy_dir, transform=None):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
        # os.listdir returns entries in arbitrary order, and that order can
        # differ between the two directories. Sorting both listings makes the
        # index-based pairing deterministic and consistent.
        self.clean_images = sorted(os.listdir(self.clean_dir))
        self.noisy_images = sorted(os.listdir(self.noisy_dir))
        self.transform = transform

        assert len(self.clean_images) == len(self.noisy_images), (
            f"clean/noisy image count mismatch: "
            f"{len(self.clean_images)} in {self.clean_dir!r} vs "
            f"{len(self.noisy_images)} in {self.noisy_dir!r}"
        )

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.clean_images)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path``, convert it to RGB and close the file handle."""
        with Image.open(path) as image:
            return image.convert('RGB')

    def __getitem__(self, idx):
        """Return the ``(noisy_image, clean_image)`` pair at position ``idx``."""
        clean_image = self._load_rgb(os.path.join(self.clean_dir, self.clean_images[idx]))
        noisy_image = self._load_rgb(os.path.join(self.noisy_dir, self.noisy_images[idx]))

        if self.transform:
            # The transform is called once per image; a transform with random
            # behaviour would therefore treat the two images of a pair differently.
            clean_image = self.transform(clean_image)
            noisy_image = self.transform(noisy_image)

        return noisy_image, clean_image


In [None]:

# Preprocessing shared by both splits: resize to the configured shape, then
# convert the image to a tensor.
transform = transforms.Compose([
    transforms.Resize(tuple(config["train_resize_shape"])),
    transforms.ToTensor(),
])

# Paired noisy/clean datasets for training and validation.
train_dataset = NoisyCleanDataset(clean_dir=TRAIN_CLEAN_PATH, noisy_dir=TRAIN_NOISY_PATH, transform=transform)
val_dataset = NoisyCleanDataset(clean_dir=VALID_CLEAN_PATH, noisy_dir=VALID_NOISY_PATH, transform=transform)

# Hyper-parameters for this model live under the "denoise_model" config section.
_denoise_cfg = config["denoise_model"]

# Only the training loader shuffles; validation keeps dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=_denoise_cfg["batch_size"], shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=_denoise_cfg["batch_size"], shuffle=False)

model = DenoisingAutoencoder()
optimizer = optim.Adam(
    params=model.parameters(),
    lr=_denoise_cfg["learning_rate"],
    weight_decay=_denoise_cfg["optimizer_weight_decay"],
)

# Learning-rate schedule driven by the metric passed to scheduler.step(...).
_scheduler_cfg = _denoise_cfg["scheduler"]
scheduler = ReduceLROnPlateau(
    optimizer,
    mode=_scheduler_cfg["mode"],
    factor=_scheduler_cfg["factor"],
    patience=_scheduler_cfg["patience"],
)

# Gradient scaler used together with autocast in the training loop.
scaler = GradScaler()

In [None]:
# Training / validation loop.
#
# Each epoch trains with gradient accumulation on a weighted MSE + (1 - SSIM)
# loss, then evaluates the same loss on the validation loader. The averaged
# validation loss is passed to the LR scheduler. The model is checkpointed
# whenever that validation loss improves on the best value seen so far.
accumulation_step = config["denoise_model"]["gradient_accumulation_steps"]
mse_alpha = config["denoise_model"]["mse_alpha"]
ssim_beta = config["denoise_model"]["ssim_beta"]

# Create the checkpoint directory up front: torch.save does not create missing
# directories, and failing after a full epoch of training would waste the run.
checkpoint_path = "./models/DenoiseAutoencoderV4.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Run every batch on whatever device the model's parameters live on, and only
# enable autocast when that device is CUDA.
device = next(model.parameters()).device
use_amp = device.type == "cuda"


def _psnr_from_mse(mse_value, eps=1e-10):
    """Return PSNR in dB for a peak signal value of 1; ``eps`` avoids dividing by zero."""
    return 10 * log10(1 / max(mse_value, eps))


# Lowest average validation loss seen so far. Starts at +inf so the first
# epoch always produces a checkpoint.
previous_avg_val_loss = float("inf")

for epoch in range(config["denoise_model"]["epochs"]):
    model.train()
    # Drop any gradients left over from an interrupted run of this cell.
    optimizer.zero_grad()
    running_loss = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    num_train_batches = len(train_loader)

    with tqdm(total=num_train_batches, desc=f"Training Epoch {epoch}") as pbar:
        for batch_idx, (noisy_imgs, clean_imgs) in enumerate(train_loader, start=1):
            noisy_imgs = noisy_imgs.to(device)
            clean_imgs = clean_imgs.to(device)

            with autocast(device_type=device.type, enabled=use_amp):
                outputs = model(noisy_imgs)
                mse_loss = F.mse_loss(outputs, clean_imgs)
                ssim_loss = 1 - ssim(outputs, clean_imgs, data_range=1.0, size_average=True)
                loss = mse_alpha * mse_loss + ssim_beta * ssim_loss

            # Divide by the accumulation length so the summed gradients match
            # the average over the accumulated batches.
            scaler.scale(loss / accumulation_step).backward()

            # Step once per accumulation window, and on the final batch so a
            # trailing partial window is not discarded.
            if (batch_idx % accumulation_step == 0) or (batch_idx == num_train_batches):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # `loss` is the undivided per-batch loss, so it is logged directly.
            batch_loss = loss.item()
            running_loss += batch_loss
            total_psnr += _psnr_from_mse(mse_loss.item())
            total_ssim += (1 - ssim_loss.item())

            pbar.set_postfix({
                'Batch': f"{batch_idx}/{num_train_batches}",
                'MSE Loss': f"{mse_loss.item():.6f}",
                'SSIM Loss': f"{ssim_loss.item():.6f}",
                'Combined Loss': f"{batch_loss:.6f}"
            })
            pbar.update(1)

    epoch_loss = running_loss / num_train_batches
    avg_psnr = total_psnr / num_train_batches
    avg_ssim = total_ssim / num_train_batches
    print(f"Epoch {epoch}, Average Training Loss: {epoch_loss:.6f}, Average PSNR: {avg_psnr:.6f}, Average SSIM: {avg_ssim:.6f}")

    model.eval()
    val_loss = 0.0
    total_psnr = 0.0
    total_ssim = 0.0
    num_val_batches = len(val_loader)
    with tqdm(total=num_val_batches, desc="Validating") as pbar:
        with torch.no_grad():
            for batch_idx, (noisy_imgs, clean_imgs) in enumerate(val_loader, start=1):
                noisy_imgs = noisy_imgs.to(device)
                clean_imgs = clean_imgs.to(device)

                outputs = model(noisy_imgs)
                mse_loss = F.mse_loss(outputs, clean_imgs)
                ssim_loss = 1 - ssim(outputs, clean_imgs, data_range=1.0, size_average=True)
                loss = mse_alpha * mse_loss + ssim_beta * ssim_loss
                val_loss += loss.item()

                total_psnr += _psnr_from_mse(mse_loss.item())
                total_ssim += (1 - ssim_loss.item())

                pbar.set_postfix({
                    'Batch': f"{batch_idx}/{num_val_batches}",
                    'MSE Loss': f"{mse_loss.item():.6f}",
                    'SSIM Loss': f"{ssim_loss.item():.6f}",
                    'Combined Validation Loss': f"{loss.item():.6f}"
                })
                pbar.update(1)

    avg_val_loss = val_loss / num_val_batches
    avg_val_psnr = total_psnr / num_val_batches
    avg_val_ssim = total_ssim / num_val_batches
    scheduler.step(avg_val_loss)
    print(f"Validation Loss: {avg_val_loss:.6f}, Average PSNR: {avg_val_psnr:.6f}, Average SSIM: {avg_val_ssim:.6f}\n")

    # Checkpoint only when validation loss improves (lower is better).
    if avg_val_loss < previous_avg_val_loss:
        torch.save(model, checkpoint_path)
        previous_avg_val_loss = avg_val_loss

In [None]:
# Training has finished: report it and persist the model as it stands after
# the last epoch.
# NOTE(review): this writes to the same path the training loop checkpoints to,
# so it replaces whatever that loop last saved there — confirm that is intended.
print('Done')
os.makedirs("./models", exist_ok=True)  # torch.save does not create missing directories
torch.save(model, "./models/DenoiseAutoencoderV4.pth")