In [None]:
pip install torch torchvision torchaudio numpy matplotlib opencv-python tqdm


In [None]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import cv2
import os
import numpy as np

class UnderwaterDataset(Dataset):
    """Paired image/label dataset backed by two directories of image files.

    Samples are paired by position: the regular files of each directory are
    sorted by name and the i-th image is matched with the i-th label. The
    pairing is therefore only meaningful when corresponding files sort the
    same way in both directories.

    Args:
        image_dir: Directory containing the input images.
        label_dir: Directory containing the label (target) images.
        transform: Optional callable applied to the image and then, in a
            separate call, to the label. Because it is invoked twice, a
            transform with random behaviour will not apply the same
            augmentation to both halves of a pair.

    Raises:
        ValueError: If the two directories contain a different number of
            files, since positional pairing would then be wrong or run out
            of range.
    """

    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform
        self.image_paths = self._list_files(image_dir)
        self.label_paths = self._list_files(label_dir)

        # Fail fast: a count mismatch would otherwise surface as silently
        # mis-paired samples or as an IndexError in the middle of an epoch.
        if len(self.image_paths) != len(self.label_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} files in {image_dir!r} but "
                f"{len(self.label_paths)} files in {label_dir!r}; "
                "images and labels must be paired one-to-one"
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of the regular files directly inside `directory`."""
        # Skip sub-directories (e.g. `.ipynb_checkpoints`), which cannot be
        # decoded as images and would shift the positional pairing.
        return sorted(
            name
            for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _read_rgb(path):
        """Load the image at `path` with OpenCV and return it in RGB channel order."""
        bgr = cv2.imread(path)
        # cv2.imread reports failure by returning None instead of raising.
        if bgr is None:
            raise OSError(f"Could not read image file: {path!r}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __len__(self):
        """Number of image/label pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return the `(image, label)` pair at position `idx`.

        Both items are RGB arrays as loaded by OpenCV, or whatever
        `transform` returns for them when a transform was supplied.

        Raises:
            OSError: If either file cannot be read as an image.
        """
        img_path = os.path.join(self.image_dir, self.image_paths[idx])
        label_path = os.path.join(self.label_dir, self.label_paths[idx])

        image = self._read_rgb(img_path)
        label = self._read_rgb(label_path)

        if self.transform is not None:
            image = self.transform(image)
            label = self.transform(label)

        return image, label

# Per-sample preprocessing: the dataset hands over an RGB NumPy array, which is
# wrapped as a PIL image, resized to 256x256 and converted to a tensor.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ]
)

# The same transform object is handed to the dataset for images and labels.
dataset = UnderwaterDataset(
    image_dir="data/images",
    label_dir="data/labels",
    transform=transform,
)

# Shuffled mini-batches of 8 samples for training.
train_loader = DataLoader(dataset=dataset, batch_size=8, shuffle=True)


In [None]:
import torch.nn as nn

class AttentionBlock(nn.Module):
    """Channel attention that rescales every channel by a learned gate.

    The input is average-pooled to a single value per channel, passed through
    a two-layer bottleneck MLP that ends in a sigmoid, and the resulting
    per-channel weights (each in the range of a sigmoid) multiply the input.

    Args:
        in_channels: Number of channels C of the (B, C, H, W) input.
        reduction: Bottleneck ratio. The hidden layer has
            `in_channels // reduction` units, but never fewer than one.
            The default of 8 reproduces the original layer sizes.
    """

    def __init__(self, in_channels, reduction=8):
        super().__init__()
        # Clamp to at least one unit so that in_channels < reduction does not
        # create a zero-width hidden layer that carries no information.
        hidden_channels = max(1, in_channels // reduction)

        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, in_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return `x` scaled channel-wise; the output has the same shape as `x`.

        Args:
            x: 4-D tensor of shape (B, C, H, W) with C == in_channels.
        """
        batch, channels, _, _ = x.shape
        # (B, C, 1, 1) -> (B, C) so the pooled statistics can feed the MLP.
        pooled = self.global_avg_pool(x).view(batch, channels)
        # Back to (B, C, 1, 1) so the gate broadcasts over H and W.
        gate = self.fc(pooled).view(batch, channels, 1, 1)
        return x * gate

class Generator(nn.Module):
    """Convolutional encoder -> conv stacks -> channel attention -> decoder.

    The encoder applies three stride-2 convolutions (3 -> 32 -> 64 -> 128
    channels), the decoder mirrors them with three stride-2 transposed
    convolutions (128 -> 64 -> 32 -> 3 channels), and the final Tanh bounds
    the output to [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Down-sampling path: each conv uses kernel 4, stride 2, padding 1.
        encoder_layers = []
        for c_in, c_out in ((3, 32), (32, 64), (64, 128)):
            encoder_layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            )
            encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # Six identical conv-ReLU-conv-ReLU stacks at 128 channels.
        # NOTE(review): despite the attribute name, forward() runs these
        # stacks back to back with no skip (x + f(x)) connection -- confirm
        # whether true residual blocks were intended.
        conv_stacks = []
        for _ in range(6):
            conv_stacks.append(
                nn.Sequential(
                    nn.Conv2d(128, 128, kernel_size=3, padding=1),
                    nn.ReLU(),
                    nn.Conv2d(128, 128, kernel_size=3, padding=1),
                    nn.ReLU(),
                )
            )
        self.res_blocks = nn.Sequential(*conv_stacks)

        self.attention = AttentionBlock(128)

        # Up-sampling path back to a 3-channel image.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map a batch of 3-channel images to a batch of generated images."""
        features = self.encoder(x)
        features = self.res_blocks(features)
        features = self.attention(features)
        return self.decoder(features)


In [None]:
class Discriminator(nn.Module):
    """Convolutional classifier that scores an image with one probability.

    Three stride-2 convolutions (3 -> 64 -> 128 -> 256 channels) are followed
    by a flatten, a single linear unit and a sigmoid, so the output has shape
    (B, 1) with values in [0, 1].

    Args:
        image_size: Spatial size of the inputs the network will receive,
            either one int for square images or an `(height, width)` pair.
            The default of 256 reproduces the original fixed layer size.

    Raises:
        ValueError: If either side is smaller than 8 pixels, which would
            leave no spatial positions after the three down-sampling convs.
    """

    def __init__(self, image_size=256):
        super().__init__()

        if isinstance(image_size, int):
            height = width = image_size
        else:
            height, width = image_size

        # Each conv (kernel 4, stride 2, padding 1) halves a side, rounding
        # down, so three of them divide it by 8.
        feat_h, feat_w = height // 8, width // 8
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f"image_size must be at least 8 on each side, got {image_size!r}"
            )

        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(256 * feat_h * feat_w, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a (B, 1) tensor of scores for a (B, 3, H, W) batch."""
        return self.model(x)


In [None]:
import torch.optim as optim

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU,
# and place both networks on that device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
generator = Generator()
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

# Adversarial term: binary cross-entropy on the discriminator's output.
criterion_gan = nn.BCELoss()
# Reconstruction term: mean squared error between two image batches.
criterion_pixelwise = nn.MSELoss()

# One Adam optimizer per network, both with the same hyper-parameters.
optimizer_g = optim.Adam(
    params=generator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)
optimizer_d = optim.Adam(
    params=discriminator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)


In [None]:
num_epochs = 50

for epoch in range(num_epochs):
    # Running sums so the log line reports the epoch mean instead of whatever
    # the final mini-batch happened to produce.
    loss_d_total = 0.0
    loss_g_total = 0.0
    num_batches = 0

    for input_images, target_images in train_loader:
        # The loader yields (image, label) pairs. The label is the reference
        # image the generator should produce, so it serves both as the
        # pixel-wise target and as the discriminator's "real" sample.
        # Comparing the output with the input would only teach the generator
        # to copy its input.
        input_images = input_images.to(device)
        target_images = target_images.to(device)
        batch_size = input_images.size(0)

        # Discriminator targets, created directly on the training device.
        valid = torch.ones((batch_size, 1), device=device)
        fake = torch.zeros((batch_size, 1), device=device)

        # Train Generator: fool the discriminator and match the reference.
        optimizer_g.zero_grad()
        fake_images = generator(input_images)
        pred_fake = discriminator(fake_images)
        loss_g = criterion_gan(pred_fake, valid) + criterion_pixelwise(fake_images, target_images)
        loss_g.backward()
        optimizer_g.step()

        # Train Discriminator: reference images are real, generated ones are
        # fake. detach() keeps this step from back-propagating into the
        # generator; zero_grad() also clears the gradients the generator
        # step left on the discriminator's parameters.
        optimizer_d.zero_grad()
        pred_real = discriminator(target_images)
        pred_fake = discriminator(fake_images.detach())
        loss_real = criterion_gan(pred_real, valid)
        loss_fake = criterion_gan(pred_fake, fake)
        loss_d = (loss_real + loss_fake) / 2
        loss_d.backward()
        optimizer_d.step()

        loss_d_total += loss_d.item()
        loss_g_total += loss_g.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("train_loader yielded no batches; check the dataset directories")

    print(f"Epoch [{epoch+1}/{num_epochs}] | Loss D: {loss_d_total / num_batches} | Loss G: {loss_g_total / num_batches}")


In [None]:
import torchvision.utils as vutils
import matplotlib.pyplot as plt

def visualize_results(num_samples=5):
    """Plot input images (top row) above the generator's outputs (bottom row).

    Draws one batch from `train_loader`, runs it through `generator` without
    gradient tracking and shows up to `num_samples` input/output pairs.

    Args:
        num_samples: Maximum number of columns to draw. Capped at the size
            of the batch so a short batch does not raise an IndexError.

    Raises:
        ValueError: If there is nothing to draw (`num_samples` < 1 or an
            empty batch).
    """
    real_images, _ = next(iter(train_loader))
    real_images = real_images.to(device)

    with torch.no_grad():
        fake_images = generator(real_images)

    num_cols = min(num_samples, real_images.size(0))
    if num_cols < 1:
        raise ValueError("num_samples must be >= 1 and the batch must not be empty")

    # (B, C, H, W) -> (B, H, W, C), the layout imshow expects.
    real_np = real_images.cpu().permute(0, 2, 3, 1).numpy()
    # The generator ends in Tanh, so its output may fall outside the [0, 1]
    # range imshow accepts for float RGB data; clamp explicitly rather than
    # relying on matplotlib's clipping warning.
    fake_np = fake_images.clamp(0.0, 1.0).cpu().permute(0, 2, 3, 1).numpy()

    # squeeze=False keeps `axs` two-dimensional even when num_cols == 1.
    fig, axs = plt.subplots(2, num_cols, figsize=(2 * num_cols, 5), squeeze=False)
    for i in range(num_cols):
        axs[0, i].imshow(real_np[i])
        axs[0, i].set_title("Input")
        axs[0, i].axis("off")
        axs[1, i].imshow(fake_np[i])
        axs[1, i].set_title("Generated")
        axs[1, i].axis("off")
    plt.show()

visualize_results()
