In [14]:
import glob
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms

In [15]:
# Encoder
class Encoder(nn.Module):
    """Downsampling front end of EnhanceNet.

    Three conv -> InstanceNorm -> ReLU stages map a 3-channel image to 256
    feature maps. Stage 1 keeps the spatial size (7x7, stride 1, padding 3);
    stages 2 and 3 each halve height and width (3x3, stride 2, padding 1),
    so e.g. a 224x224 input becomes a 56x56 feature map.

    Submodules are registered as ``conv1``/``norm1`` ... ``conv3``/``norm3``.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, kernel_size, stride, padding) per stage.
        stage_specs = (
            (3, 64, 7, 1, 3),
            (64, 128, 3, 2, 1),
            (128, 256, 3, 2, 1),
        )
        for idx, (c_in, c_out, k, s, p) in enumerate(stage_specs, start=1):
            setattr(self, f"conv{idx}", nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p))
            setattr(self, f"norm{idx}", nn.InstanceNorm2d(c_out))

    def forward(self, x):
        """Run the three conv/norm/ReLU stages in order and return the features."""
        stages = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
        )
        features = x
        for conv, norm in stages:
            features = F.relu(norm(conv(features)))
        return features

# Residual block
class ResidualBlock(nn.Module):
    """Two 3x3 conv + InstanceNorm layers wrapped by an identity skip connection.

    Works on 256-channel feature maps. Stride 1 and padding 1 keep the shape
    unchanged, so any number of blocks can be stacked. The final ReLU is
    applied after the skip addition.
    """

    def __init__(self):
        super().__init__()
        # Registers conv1, norm1, conv2, norm2 (in that order).
        for idx in (1, 2):
            setattr(self, f"conv{idx}", nn.Conv2d(256, 256, kernel_size=3, padding=1))
            setattr(self, f"norm{idx}", nn.InstanceNorm2d(256))

    def forward(self, x):
        """Return relu(norm2(conv2(relu(norm1(conv1(x))))) + x)."""
        hidden = F.relu(self.norm1(self.conv1(x)))
        branch = self.norm2(self.conv2(hidden))
        return F.relu(branch + x)

# Decoder
class Decoder(nn.Module):
    """Upsampling back end of EnhanceNet.

    Two transposed-conv -> InstanceNorm -> ReLU stages (256 -> 128 -> 64
    channels) each double height and width (k=3, stride 2, padding 1,
    output_padding 1). A final 7x7 conv then projects to 3 channels, and
    tanh squashes the result into (-1, 1).
    """

    def __init__(self):
        super().__init__()
        # Registers deconv1, norm1, deconv2, norm2 (in that order), then conv3.
        for idx, (c_in, c_out) in enumerate(((256, 128), (128, 64)), start=1):
            setattr(
                self,
                f"deconv{idx}",
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
            )
            setattr(self, f"norm{idx}", nn.InstanceNorm2d(c_out))
        self.conv3 = nn.Conv2d(64, 3, kernel_size=7, padding=3)

    def forward(self, x):
        """Upsample ``x`` twice and map it to a 3-channel tanh-activated image."""
        upsampled = x
        for deconv, norm in ((self.deconv1, self.norm1), (self.deconv2, self.norm2)):
            upsampled = F.relu(norm(deconv(upsampled)))
        return torch.tanh(self.conv3(upsampled))

# Enhance-Net Model
class EnhanceNet(nn.Module):
    """Image-to-image network: Encoder -> residual blocks -> Decoder.

    The encoder downsamples by 4 and the decoder upsamples by 4. The output
    therefore has the input's spatial size when height and width are
    multiples of 4, with 3 channels in (-1, 1) from the decoder's tanh.

    Args:
        num_res_blocks: number of ResidualBlock modules applied at the
            bottleneck (default 9).
    """

    def __init__(self, num_res_blocks=9):
        super().__init__()
        self.encoder = Encoder()
        bottleneck = [ResidualBlock() for _ in range(num_res_blocks)]
        self.residual_blocks = nn.Sequential(*bottleneck)
        self.decoder = Decoder()

    def forward(self, x):
        """Encode, refine at the bottleneck, then decode back to an image."""
        return self.decoder(self.residual_blocks(self.encoder(x)))

In [16]:
# Loss function: Mean Absolute Error (MAE)
def mae_loss(output, target):
    """Mean absolute error (L1): the mean of |output - target| over all elements.

    The two tensors are subtracted with standard broadcasting, and the result
    is a 0-d tensor.
    """
    difference = output - target
    return difference.abs().mean()

# Perceptual Loss using pretrained VGG19
class PerceptualLoss(nn.Module):
    """L1 distance between frozen, ImageNet-pretrained VGG19 feature maps.

    Both image batches are passed through the first ``feature_layers``
    modules of torchvision's ``vgg19().features`` and compared with
    ``mae_loss``. The VGG weights are frozen (``requires_grad=False``), so
    only the inputs receive gradients.

    Args:
        feature_layers: number of leading ``vgg19().features`` modules used
            as the feature extractor (default 16, as before).
        normalize: if True (default), apply ImageNet mean/std normalization
            before VGG. torchvision's pretrained weights expect normalized
            input. This assumes the images are RGB in [0, 1] (as produced by
            ``transforms.ToTensor()``); confirm for other pipelines.
    """

    def __init__(self, feature_layers=16, normalize=True):
        super(PerceptualLoss, self).__init__()
        # Prefer the `weights=` API (torchvision >= 0.13), where `pretrained=True`
        # is deprecated. Older releases only accept the legacy flag.
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is not None:
            backbone = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        else:
            backbone = models.vgg19(pretrained=True)
        vgg = backbone.features[:feature_layers].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg
        self.normalize = normalize
        # Buffers, not parameters: `.to(device)` moves them, but the optimizer
        # never sees them. Non-persistent, so state_dict() is unchanged.
        self.register_buffer(
            "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )

    def forward(self, output, target):
        """Return the mean absolute difference between VGG features of ``output`` and ``target``."""
        if self.normalize:
            output = (output - self.mean) / self.std
            target = (target - self.mean) / self.std
        output_features = self.vgg(output)
        target_features = self.vgg(target)
        return mae_loss(output_features, target_features)

In [17]:
class CustomImageDataset(Dataset):
    """Unlabelled image dataset over the files of a single folder.

    Each item is one file opened with PIL, converted to RGB, and passed
    through ``transform`` if one is given. No labels or paired targets are
    returned.

    Args:
        folder_path: directory scanned (non-recursively) for regular files
            matching ``*.*``.
        transform: optional callable applied to the PIL image.

    Raises:
        FileNotFoundError: if no matching files exist in ``folder_path``.
    """

    def __init__(self, folder_path, transform=None):
        candidates = glob.glob(os.path.join(folder_path, '*.*'))
        # glob's order is filesystem-dependent. Sort so dataset indices (and the
        # unshuffled validation order) are reproducible. Skip directories whose
        # names happen to contain a dot.
        self.images = sorted(p for p in candidates if os.path.isfile(p))
        if not self.images:
            # Fail early with a clear message instead of later inside DataLoader
            # or with a ZeroDivisionError when averaging losses.
            raise FileNotFoundError(f"No files matching '*.*' found in {folder_path!r}")
        self.transform = transform

    def __len__(self):
        """Number of files found in the folder."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply the transform, if any."""
        img_path = self.images[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

In [None]:
# Device selection: use Apple's Metal (MPS) backend when available, otherwise the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Seed torch's RNG so weight initialization and shuffling are reproducible.
SEED = 42
torch.manual_seed(SEED)

learning_rate = 1e-4
batch_size = 1
epochs = 1
# Number of batches whose gradients are accumulated before each optimizer step.
gradient_accumulation_steps = 1

train_paths = '../CURE-TSR/Train/Haze-3'  # Path to the training dataset
val_paths = '../CURE-TSR/Test/Haze-3'  # Path to the validation dataset

# ToTensor yields RGB tensors in [0, 1].
# NOTE(review): the decoder ends in tanh (range (-1, 1)) while the targets are
# in [0, 1]. Consider a sigmoid output, or normalizing inputs/targets to
# [-1, 1]. TODO: confirm intended range.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = CustomImageDataset(train_paths, transform=transform)
val_dataset = CustomImageDataset(val_paths, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = EnhanceNet().to(device)
mae_criterion = mae_loss
perceptual_criterion = PerceptualLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)

# NOTE(review): both losses compare the model output with its own *input*
# images, so the network learns to reconstruct the hazy input rather than to
# remove haze. Dehazing needs paired haze-free targets. TODO: confirm intent.

num_train_batches = len(train_loader)

# Training loop with validation
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    optimizer.zero_grad()

    # Training step
    for i, images in enumerate(train_loader):
        images = images.to(device)
        outputs = model(images)

        # Compute losses
        loss_mae = mae_criterion(outputs, images)  # MAE loss
        loss_perceptual = perceptual_criterion(outputs, images)  # Perceptual loss
        loss = loss_mae + loss_perceptual  # Combined loss

        # Scale so the accumulated gradient is the mean over the window,
        # not the sum (which would inflate the effective learning rate).
        (loss / gradient_accumulation_steps).backward()
        running_loss += loss.item()

        # Step every `gradient_accumulation_steps` batches. Also flush a partial
        # window at the end of the epoch; otherwise the zero_grad() at the next
        # epoch's start would silently discard those gradients.
        is_last_batch = (i + 1) == num_train_batches
        if (i + 1) % gradient_accumulation_steps == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()

    # Validation step
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_images in val_loader:
            val_images = val_images.to(device)
            val_outputs = model(val_images)

            # Validation losses
            val_loss_mae = mae_criterion(val_outputs, val_images)  # MAE loss
            val_loss_perceptual = perceptual_criterion(val_outputs, val_images)  # Perceptual loss
            val_loss += val_loss_mae.item() + val_loss_perceptual.item()  # Accumulate validation loss

    avg_train_loss = running_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)

    # Print epoch information
    print(f'Epoch [{epoch + 1}/{epochs}], '
          f'Training Loss: {avg_train_loss:.4f}, '
          f'Validation Loss: {avg_val_loss:.4f}')

    # Drive the LR schedule with the same per-batch average that is reported.
    scheduler.step(avg_val_loss)

# Save the trained EnhanceNet model
torch.save(model.state_dict(), 'enhance_net_haze.pth')