In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image

In [4]:

class Encoder(nn.Module):
    """Downsampling front end of EnhanceNet.

    Layers:
      * a 7x7 stride-1 convolution lifting 3 input channels to 64;
      * two 3x3 stride-2 convolutions going to 128 and then 256 channels, each
        halving the spatial size (rounding up).
    Every convolution is followed by InstanceNorm2d and ReLU.

    Expects a 3-channel image tensor, e.g. shaped (N, 3, H, W).

    The attribute names (conv1..conv3, norm1..norm3) define the state_dict key
    layout, so they must stay fixed for saved checkpoints to keep loading.
    """

    def __init__(self):
        super().__init__()
        # Registration order mirrors the forward pass; it also fixes the order
        # in which parameters draw from the RNG at initialisation.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3)
        self.norm1 = nn.InstanceNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.norm2 = nn.InstanceNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.norm3 = nn.InstanceNorm2d(256)

    def forward(self, x):
        stages = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
        )
        features = x
        for conv, norm in stages:
            features = F.relu(norm(conv(features)))
        return features

class ResidualBlock(nn.Module):
    """Residual unit at 256 channels: two 3x3 conv + InstanceNorm layers plus a skip.

    Computes ``relu(x + norm2(conv2(relu(norm1(conv1(x))))))``. Padding of 1
    keeps the spatial size unchanged, so the identity shortcut can be added
    directly. The input must have 256 channels.

    Attribute names are part of the state_dict key layout and must not change.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.norm1 = nn.InstanceNorm2d(256)
        self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.norm2 = nn.InstanceNorm2d(256)

    def forward(self, x):
        branch = F.relu(self.norm1(self.conv1(x)))
        branch = self.norm2(self.conv2(branch))
        # Identity shortcut, then the post-addition ReLU.
        return F.relu(branch + x)

class Decoder(nn.Module):
    """Upsampling back end of EnhanceNet.

    Layers:
      * two 3x3 stride-2 transposed convolutions (256 -> 128 -> 64 channels),
        each exactly doubling the spatial size because of ``output_padding=1``
        and each followed by InstanceNorm2d + ReLU;
      * a 7x7 convolution down to 3 channels, squashed by tanh.
    The output therefore lies in [-1, 1].

    Attribute names are part of the state_dict key layout and must not change.
    """

    def __init__(self):
        super().__init__()
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.norm1 = nn.InstanceNorm2d(128)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.norm2 = nn.InstanceNorm2d(64)
        self.conv3 = nn.Conv2d(64, 3, kernel_size=7, padding=3)

    def forward(self, x):
        hidden = x
        for deconv, norm in ((self.deconv1, self.norm1), (self.deconv2, self.norm2)):
            hidden = F.relu(norm(deconv(hidden)))
        # Output with Tanh activation
        return torch.tanh(self.conv3(hidden))

class EnhanceNet(nn.Module):
    """Image-to-image network: Encoder -> stack of ResidualBlocks -> Decoder.

    Args:
        num_res_blocks: number of ResidualBlock modules placed between the
            encoder and decoder (default 9).

    The decoder ends in tanh, so outputs lie in [-1, 1]. For inputs whose
    height and width are multiples of 4, the output has the same spatial size
    as the input (two stride-2 downsamplings, two exact 2x upsamplings).

    The submodule names (encoder, residual_blocks, decoder) prefix every
    state_dict key, so they must stay fixed for checkpoints to load.
    """

    def __init__(self, num_res_blocks=9):
        super().__init__()
        # Construction order (encoder, blocks, decoder) is kept so random
        # initialisation under a fixed seed is unchanged.
        self.encoder = Encoder()
        blocks = [ResidualBlock() for _ in range(num_res_blocks)]
        self.residual_blocks = nn.Sequential(*blocks)
        self.decoder = Decoder()

    def forward(self, x):
        return self.decoder(self.residual_blocks(self.encoder(x)))

In [None]:
import torch
import torch.nn as nn 
from torchvision import transforms
from PIL import Image
import os


# One EnhanceNet checkpoint (a saved state_dict) per degradation condition,
# relative to the notebook's working directory.
model_paths = {
    'Haze': 'enhance_net_haze.pth',
    'Lens Blur': 'enhance_net_blur.pth',
    'Rain': 'enhance_net_rain.pth',
    'Shadow': 'enhance_net_shadow.pth',
    'Snow': 'enhance_net_snow.pth'
}

# CURE-TSR test folders for the same conditions.
# NOTE(review): the "-3" suffix presumably selects a challenge level — confirm
# against the dataset layout.
test_folders = {
    'Haze': '../CURE-TSR/Test/Haze-3',
    'Lens Blur': '../CURE-TSR/Test/Blur-3',
    'Rain': '../CURE-TSR/Test/Rain-3',
    'Shadow': '../CURE-TSR/Test/Shadow-3',
    'Snow': '../CURE-TSR/Test/Snow-3'
}

# Every checkpoint needs a matching test folder (and vice versa); fail here
# rather than with a KeyError halfway through the evaluation loop.
assert model_paths.keys() == test_folders.keys(), "model_paths and test_folders must cover the same conditions"

# Device: prefer CUDA, then Apple MPS, else CPU. Previously CUDA machines
# silently fell back to CPU because only MPS was checked.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# EnhanceNet (together with Encoder, ResidualBlock and Decoder) is defined once,
# in the model-definition cell above. It is intentionally not redefined here:
# an identical copy would silently shadow that definition and go stale as soon
# as the original cell is edited. Run the model-definition cell before this one.


# Test-time preprocessing:
#   * resize every image to a fixed 224x224;
#   * convert the PIL image to a float tensor (ToTensor scales 8-bit pixel
#     values to [0, 1]).
# 224 is a multiple of 4, so EnhanceNet's two stride-2 downsamplings followed
# by two 2x upsamplings give back a 224x224 output.
# NOTE(review): EnhanceNet's decoder ends in tanh (range [-1, 1]) while this
# input is in [0, 1] with no Normalize step — confirm this matches the
# preprocessing used during training.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)


def load_enhance_model(model_path):
    """Instantiate EnhanceNet on the global ``device`` and restore its weights.

    Parameters
    ----------
    model_path : str
        Path to a saved ``state_dict`` for EnhanceNet.

    Returns
    -------
    EnhanceNet
        The network in eval mode, placed on ``device``.

    Notes
    -----
    ``torch.load`` is pickle-based (unless weights-only loading is in effect),
    so only load checkpoints from a trusted source. ``load_state_dict`` is
    strict by default, so missing or unexpected keys raise an error.
    """
    net = EnhanceNet().to(device)
    # map_location lets checkpoints saved on another device load onto this one.
    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint)
    net.eval()
    return net


def test_enhance_model(model, folder_path, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
    """Run ``model`` over every image file in ``folder_path`` and report progress.

    Each image is converted to RGB, preprocessed with the global ``transform``,
    given a batch dimension, moved to the global ``device`` and passed through
    ``model`` under ``torch.no_grad()``.

    Parameters
    ----------
    model : nn.Module
        Network to evaluate (expected to already be in eval mode).
    folder_path : str
        Directory scanned non-recursively for images.
    extensions : tuple of str, optional
        File suffixes treated as images. Matching is case-insensitive, so
        ``IMG.PNG`` or ``photo.JPG`` are picked up too.

    Notes
    -----
    The enhanced tensor is currently discarded after each image; save or
    visualise it at the marked spot if the output is needed.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    # os.listdir order is arbitrary; sort so runs are reproducible.
    for img_file in sorted(os.listdir(folder_path)):
        if not img_file.lower().endswith(suffixes):
            continue
        img_path = os.path.join(folder_path, img_file)
        # The context manager closes the underlying file handle instead of
        # relying on garbage collection.
        with Image.open(img_path) as img:
            image = transform(img.convert("RGB")).unsqueeze(0).to(device)  # Add batch dimension

        with torch.no_grad():
            enhanced_image = model(image)

        # Process the enhanced image (you can save or visualize it)
        print(f"Processed {img_file}")

# Evaluate each condition-specific checkpoint on the test folder for the same
# condition. Note: `condition`, `model_path`, `model` and `test_folder` are
# notebook globals and stay bound to the last condition after the loop ends.
for condition, model_path in model_paths.items():
    print(f"Testing EnhanceNet for {condition}...")
    
    # Load the specific model for this condition (rebinding `model` drops the
    # reference to the previous condition's network)
    model = load_enhance_model(model_path)
    
    # Test the model on the corresponding images; raises KeyError if this
    # condition has no entry in test_folders
    test_folder = test_folders[condition]
    test_enhance_model(model, test_folder)