In [4]:
pip install matplotlib

Collecting matplotlib
  Using cached matplotlib-3.9.4-cp39-cp39-win_amd64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Using cached contourpy-1.3.0-cp39-cp39-win_amd64.whl.metadata (5.4 kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.56.0-cp39-cp39-win_amd64.whl.metadata (103 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Using cached kiwisolver-1.4.7-cp39-cp39-win_amd64.whl.metadata (6.4 kB)
Collecting pyparsing>=2.3.1 (from matplotlib)
  Using cached pyparsing-3.2.1-py3-none-any.whl.metadata (5.0 kB)
Collecting importlib-resources>=3.2.0 (from matplotlib)
  Using cached importlib_resources-6.5.2-py3-none-any.whl.metadata (3.9 kB)
Using cached matplotlib-3.9.4-cp39-cp39-win_amd64.whl (7.8 MB)
Using cached contourpy-1.3.0-cp39-cp39-win_amd64.whl (211 kB)
Using cached cycler-0.12.1-py3-none-any.whl (8.3 kB)
Downloading fontt

In [6]:
pip install opencv-python

Collecting opencv-python
  Using cached opencv_python-4.11.0.86-cp37-abi3-win_amd64.whl.metadata (20 kB)
Using cached opencv_python-4.11.0.86-cp37-abi3-win_amd64.whl (39.5 MB)
Installing collected packages: opencv-python
Successfully installed opencv-python-4.11.0.86
Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install scikit-image

Collecting scikit-image
  Using cached scikit_image-0.24.0-cp39-cp39-win_amd64.whl.metadata (14 kB)
Collecting scipy>=1.9 (from scikit-image)
  Using cached scipy-1.13.1-cp39-cp39-win_amd64.whl.metadata (60 kB)
Collecting networkx>=2.8 (from scikit-image)
  Using cached networkx-3.2.1-py3-none-any.whl.metadata (5.2 kB)
Collecting imageio>=2.33 (from scikit-image)
  Using cached imageio-2.37.0-py3-none-any.whl.metadata (5.2 kB)
Collecting tifffile>=2022.8.12 (from scikit-image)
  Using cached tifffile-2024.8.30-py3-none-any.whl.metadata (31 kB)
Collecting lazy-loader>=0.4 (from scikit-image)
  Using cached lazy_loader-0.4-py3-none-any.whl.metadata (7.6 kB)
Using cached scikit_image-0.24.0-cp39-cp39-win_amd64.whl (12.9 MB)
Using cached imageio-2.37.0-py3-none-any.whl (315 kB)
Using cached lazy_loader-0.4-py3-none-any.whl (12 kB)
Using cached networkx-3.2.1-py3-none-any.whl (1.6 MB)
Using cached scipy-1.13.1-cp39-cp39-win_amd64.whl (46.2 MB)
Using cached tifffile-2024.8.30-py3-none-any.wh

In [28]:
#1: Imports
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.ops import DeformConv2d
from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import os
from skimage.feature import canny
import tqdm

print("All imports successful!")
print("PyTorch:", torch.__version__)
print("Torchvision:", torchvision.__version__)
print("CUDA:", torch.cuda.is_available())

All imports successful!
PyTorch: 1.12.1+cu113
Torchvision: 0.13.1+cu113
CUDA: True


In [29]:
#2: Edge Map Function
def get_edge_map(image):
    image_np = image.detach().cpu().numpy().transpose(1, 2, 0)  # Detach before converting to NumPy
    image_np = (image_np + 1) / 2  # [-1, 1] to [0, 1]
    edges = canny(image_np.mean(axis=2), sigma=2)  # Average RGB channels
    return torch.tensor(edges, dtype=torch.float32).unsqueeze(0)  # [1, H, W]

In [30]:
#3: Enhanced Generator

class DeformableBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.offset_conv = nn.Conv2d(in_channels, 2 * 3 * 3, 3, padding=1)
        self.deform_conv = DeformConv2d(in_channels, out_channels, 3, padding=1)
        self.norm = nn.GroupNorm(16, out_channels)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
    
    def forward(self, x):
        offset = self.offset_conv(x)
        x = self.deform_conv(x, offset)
        x = self.norm(x)
        x = self.relu(x)
        return x

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(16, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(16, channels)
    
    def forward(self, x):
        residual = x
        x = F.leaky_relu(self.norm1(self.conv1(x)), 0.2)
        x = self.norm2(self.conv2(x))
        return x + residual

class AttentionModule(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        attn = self.sigmoid(self.conv(x))
        return x * attn

class EnhancedGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder: Downsample to 64x64
        self.enc1 = nn.Conv2d(4, 64, 4, 2, 1)  # 256x256 -> 128x128
        self.enc2 = DeformableBlock(64, 128)    # 128x128
        self.enc3 = nn.Conv2d(128, 256, 4, 2, 1)  # 128x128 -> 64x64
        
        # Residual blocks at bottleneck
        self.res_blocks = nn.ModuleList([ResidualBlock(256) for _ in range(6)])
        self.attn = AttentionModule(256)
        
        # Decoder: Upsample back to 256x256
        self.dec3 = nn.ConvTranspose2d(256, 128, 4, 2, 1)  # 64x64 -> 128x128
        self.dec2 = DeformableBlock(128, 64)               # 128x128
        self.dec1 = nn.ConvTranspose2d(64, 3, 4, 2, 1)    # 128x128 -> 256x256
        
        self.tanh = nn.Tanh()
    
    def forward(self, x):
        e1 = F.leaky_relu(self.enc1(x), 0.2)
        e2 = self.enc2(e1)
        e3 = F.leaky_relu(self.enc3(e2), 0.2)
        
        r = e3
        for block in self.res_blocks:
            r = block(r)
        r = self.attn(r)
        
        d3 = F.leaky_relu(self.dec3(r), 0.2)
        d2 = self.dec2(d3)
        d1 = self.dec1(d2)
        
        return self.tanh(d1)

In [31]:
#4: Multi-Scale Discriminator
class MultiScaleDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale1 = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.GroupNorm(16, 128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 4, 1, 0), nn.Sigmoid()
        )
        self.scale2 = nn.Sequential(
            nn.AvgPool2d(2),
            nn.Conv2d(3, 64, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.GroupNorm(16, 128), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 4, 1, 0), nn.Sigmoid()
        )
    
    def forward(self, x):
        return [self.scale1(x), self.scale2(x)]

In [32]:
#5: Loss Functions
def compute_content_loss(fake, real):
    return F.l1_loss(fake, real)

def compute_adversarial_loss(discriminator, fake, real):
    real_out = discriminator(real)
    fake_out = discriminator(fake)
    loss = 0
    for ro, fo in zip(real_out, fake_out):
        loss += nn.BCELoss()(ro, torch.ones_like(ro)) + nn.BCELoss()(fo, torch.zeros_like(fo))
    return loss / 2

def compute_perceptual_loss(fake, real):
    return F.mse_loss(fake, real)  # Simplified; use VGG if desired

def compute_edge_loss(fake, real):
    fake_edges = get_edge_map(fake)
    real_edges = get_edge_map(real)
    return F.l1_loss(fake_edges, real_edges)

def compute_total_loss(generator, discriminator, fake, real):
    c_loss = compute_content_loss(fake, real)
    a_loss = compute_adversarial_loss(discriminator, fake, real)
    p_loss = compute_perceptual_loss(fake, real)
    e_loss = compute_edge_loss(fake, real)
    lambda1, lambda2, lambda3, lambda4 = 1.0, 3.0, 0.01, 0.1
    total_loss = (lambda1 * c_loss + lambda2 * a_loss + lambda3 * p_loss + lambda4 * e_loss)
    return total_loss, c_loss, a_loss, p_loss, e_loss

In [33]:
#6: Dataset
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class EUVPDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        
        # Collect all trainA and trainB images from subdirectories
        self.distorted_images = []
        self.gt_images = []
        
        # Iterate through all subdirectories in Paired
        for subdir in os.listdir(root_dir):
            trainA_dir = os.path.join(root_dir, subdir, 'trainA')
            trainB_dir = os.path.join(root_dir, subdir, 'trainB')
            
            if os.path.isdir(trainA_dir) and os.path.isdir(trainB_dir):
                trainA_files = sorted(os.listdir(trainA_dir))
                trainB_files = sorted(os.listdir(trainB_dir))
                
                # Ensure pairing matches
                assert len(trainA_files) == len(trainB_files), f"Mismatch in {subdir}: {len(trainA_files)} vs {len(trainB_files)}"
                
                # Add full paths to lists
                self.distorted_images.extend(os.path.join(trainA_dir, f) for f in trainA_files)
                self.gt_images.extend(os.path.join(trainB_dir, f) for f in trainB_files)
        
        assert len(self.distorted_images) == len(self.gt_images), "Total mismatch in image pairs"
        print(f"Total images loaded: {len(self.distorted_images)}")
    
    def __len__(self):
        return len(self.distorted_images)
    
    def __getitem__(self, idx):
        distorted = Image.open(self.distorted_images[idx]).convert('RGB')
        gt = Image.open(self.gt_images[idx]).convert('RGB')
        
        if self.transform:
            distorted = self.transform(distorted)
            gt = self.transform(gt)
        
        edge_map = get_edge_map(distorted).to(distorted.device)
        distorted = torch.cat([distorted, edge_map], dim=0)  # [4, 256, 256]
        return distorted, gt

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Point to the Paired directory
dataset = EUVPDataset(
    r'C:\Users\plawa\anaconda3\envs\underwater_gan_new\EUVP\Paired',
    transform=transform
)
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Test loading
distorted, gt = dataset[0]
print(f"Distorted shape: {distorted.shape}, Ground truth shape: {gt.shape}")

Total images loaded: 11435
Distorted shape: torch.Size([4, 256, 256]), Ground truth shape: torch.Size([3, 256, 256])


In [34]:
#7: Training Loop
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm  # For progress bar

# Assuming EnhancedGenerator and MultiScaleDiscriminator are defined in Cells 3 and 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = EnhancedGenerator().to(device)
discriminator = MultiScaleDiscriminator().to(device)

# Optimizers with adjusted learning rates
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.00002, betas=(0.5, 0.999))  # Slightly higher lr for discriminator

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    # Add tqdm for progress bar
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]")
    
    for i, (distorted, ground_truth) in enumerate(train_loader_tqdm):
        distorted, ground_truth = distorted.to(device), ground_truth.to(device)
        
        # Discriminator update (every 25 batches)
        if i % 25 == 0:
            d_optimizer.zero_grad()
            with torch.no_grad():
                fake = generator(distorted)
            real_noise = ground_truth + torch.randn_like(ground_truth) * 0.1  # Reduced noise intensity
            fake_noise = fake.detach() + torch.randn_like(fake) * 0.1
            d_loss = compute_adversarial_loss(discriminator, fake_noise, real_noise)
            d_loss.backward()
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(), max_norm=1.0)  # Gradient clipping
            d_optimizer.step()
        
        # Generator update
        g_optimizer.zero_grad()
        fake = generator(distorted)
        g_loss, c_loss, a_loss, p_loss, e_loss = compute_total_loss(generator, discriminator, fake, ground_truth)
        g_loss.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)  # Gradient clipping
        g_optimizer.step()
        
        # Update progress bar with metrics every 10 batches
        if i % 10 == 0:
            with torch.no_grad():
                psnr = peak_signal_noise_ratio(fake, ground_truth, data_range=2.0).item()
                ssim = structural_similarity_index_measure(fake, ground_truth, data_range=2.0).item()
            train_loader_tqdm.set_postfix({
                'D Loss': f'{d_loss.item():.4f}',
                'G Loss': f'{g_loss.item():.4f}',
                'PSNR': f'{psnr:.2f}',
                'SSIM': f'{ssim:.4f}'
            })
    
    # Save models after each epoch
    torch.save(generator.state_dict(), f"generator_epoch_{epoch+1}.pth")
    torch.save(discriminator.state_dict(), f"discriminator_epoch_{epoch+1}.pth")
    
    # Print epoch summary
    print(f"Epoch [{epoch+1}/{num_epochs}] completed. Saved models.")

# Final message
print("Training completed!")

Epoch [1/50]:   0%|                                                                                                                                      | 0/2859 [00:00<?, ?it/s]


ValueError: axes don't match array

In [2]:
#8 visualization

generator.eval()
with torch.no_grad():
    distorted, ground_truth = next(iter(train_loader))
    distorted, ground_truth = distorted.to(device), ground_truth.to(device)
    fake, _, _ = generator(distorted)
    distorted_np = distorted[0, :3, :, :].cpu().numpy().transpose(1, 2, 0)
    ground_truth_np = ground_truth[0].cpu().numpy().transpose(1, 2, 0)
    fake_np = fake[0].cpu().numpy().transpose(1, 2, 0)
    edge_np = distorted[0, 3, :, :].cpu().numpy()
    fig, axs = plt.subplots(1, 4, figsize=(20, 5))
    axs[0].imshow((distorted_np + 1) / 2)
    axs[0].set_title("Distorted")
    axs[1].imshow(edge_np, cmap='gray')
    axs[1].set_title("Edge Map")
    axs[2].imshow((fake_np + 1) / 2)
    axs[2].set_title("Restored")
    axs[3].imshow((ground_truth_np + 1) / 2)
    axs[3].set_title("Ground Truth")
    for ax in axs:
        ax.axis('off')
    plt.show()

NameError: name 'generator' is not defined