In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.utils.checkpoint as checkpoint
from torchvision import models, transforms
from torchvision.models import vgg19
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

In [2]:
# Optional manual cleanup when re-running cells after an out-of-memory error
# (uncomment once these tensors exist in the kernel):
# del noisy, gt, fake_images, fake_pred
# Ask PyTorch's CUDA caching allocator to release memory it is holding but not using.
torch.cuda.empty_cache()


In [3]:
# Dataset Class
class LowLightDataset(Dataset):
    """Low-light image dataset with optional ground-truth targets.

    The files in ``noisy_dir`` (and ``gt_dir`` when given) are listed and
    sorted by name; sample ``idx`` pairs the ``idx``-th entry of each listing.

    Args:
        noisy_dir: Directory containing the input (noisy) images.
        gt_dir: Optional directory containing the ground-truth images. When it
            is ``None`` or empty, samples consist of the noisy image only.
        transform_noisy: Optional callable applied to each noisy PIL image.
        transform_gt: Optional callable applied to each ground-truth PIL image.

    Raises:
        ValueError: If ``gt_dir`` is given but does not hold the same number of
            files as ``noisy_dir``. Pairing is purely positional, so a count
            mismatch would otherwise misalign pairs or fail mid-epoch.
    """

    def __init__(self, noisy_dir, gt_dir=None, transform_noisy=None, transform_gt=None):
        self.noisy_dir = noisy_dir
        self.gt_dir = gt_dir
        self.noisy_files = sorted(os.listdir(noisy_dir))
        self.gt_files = sorted(os.listdir(gt_dir)) if gt_dir else None
        self.transform_noisy = transform_noisy
        self.transform_gt = transform_gt

        if self.gt_files is not None and len(self.gt_files) != len(self.noisy_files):
            raise ValueError(
                f"Cannot pair images by index: {len(self.noisy_files)} files in "
                f"{noisy_dir!r} but {len(self.gt_files)} files in {gt_dir!r}"
            )

    def __len__(self):
        """Return the number of noisy images."""
        return len(self.noisy_files)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return its pixels as an RGB PIL image."""
        with Image.open(path) as image:
            return image.convert('RGB')

    def __getitem__(self, idx):
        """Return ``noisy`` or, when ground truth is configured, ``(noisy, gt)``."""
        # Load noisy image
        noisy_image = self._load_rgb(os.path.join(self.noisy_dir, self.noisy_files[idx]))
        if self.transform_noisy is not None:
            noisy_image = self.transform_noisy(noisy_image)

        # Without a ground-truth directory only the input image is returned
        if not self.gt_dir:
            return noisy_image

        gt_image = self._load_rgb(os.path.join(self.gt_dir, self.gt_files[idx]))
        if self.transform_gt is not None:
            gt_image = self.transform_gt(gt_image)
        return noisy_image, gt_image

# Transformations
# Per-channel statistics used to standardize both inputs and targets
# (the ImageNet mean/std that torchvision's pretrained models expect).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Input pipeline: PIL image -> float tensor -> standardized tensor.
transform_noisy = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Target pipeline: the same two steps, built as its own Compose object.
transform_gt = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)


In [4]:
# Dataset Paths
train_noisy_dir = "/kaggle/input/enhance-the-dark-world/archive/train/train"
train_gt_dir = "/kaggle/input/enhance-the-dark-world/archive/train/gt"
val_noisy_dir = "/kaggle/input/enhance-the-dark-world/archive/val/val"
val_gt_dir = "/kaggle/input/enhance-the-dark-world/archive/val/gt"

test_noisy_dir = "/kaggle/input/enhance-the-dark-world/archive/test"

# Datasets: train and val are paired with ground truth, test has inputs only.
train_dataset = LowLightDataset(
    noisy_dir=train_noisy_dir,
    gt_dir=train_gt_dir,
    transform_noisy=transform_noisy,
    transform_gt=transform_gt,
)
val_dataset = LowLightDataset(
    noisy_dir=val_noisy_dir,
    gt_dir=val_gt_dir,
    transform_noisy=transform_noisy,
    transform_gt=transform_gt,
)
test_dataset = LowLightDataset(noisy_dir=test_noisy_dir, transform_noisy=transform_noisy)

# DataLoaders: only the training loader shuffles its samples.
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True, num_workers=4)
val_loader = DataLoader(dataset=val_dataset, batch_size=8, shuffle=False, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False, num_workers=4)

In [5]:
# Smoke-test the paired loaders (train, then val): print the tensor shapes of
# the first batch and stop.
for paired_loader in (train_loader, val_loader):
    for noisy_img, gt_img in paired_loader:
        print(f"Noisy Image Shape: {noisy_img.shape}, GT Image Shape: {gt_img.shape}")
        break

# The test dataset has no ground truth, so each batch is a single tensor.
for noisy_img in test_loader:
    print(f"Noisy Image Shape: {noisy_img.shape}")
    break


Noisy Image Shape: torch.Size([8, 3, 160, 256]), GT Image Shape: torch.Size([8, 3, 640, 1024])
Noisy Image Shape: torch.Size([8, 3, 160, 256]), GT Image Shape: torch.Size([8, 3, 640, 1024])
Noisy Image Shape: torch.Size([8, 3, 160, 256])


In [6]:
# Model Architecture

In [7]:
# import torch
# import torch.nn as nn
# import torch.utils.checkpoint as checkpoint

# # Define Generator Network
# class Generator(nn.Module):
#     def __init__(self):
#         super(Generator, self).__init__()
#         self.conv1 = nn.Conv2d(3, 16, kernel_size=9, stride=1, padding=4)  # Reduced channels
#         self.prelu = nn.PReLU()
#         self.res_blocks = nn.ModuleList([ResidualBlock(16) for _ in range(2)])  # Fewer residual blocks
#         self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
#         self.bn = nn.BatchNorm2d(16)
#         self.upsample = nn.Sequential(
#             UpsampleBlock(16),
#             nn.Conv2d(16, 3, kernel_size=9, stride=1, padding=4),
#             nn.Tanh()
#         )

#     def forward(self, x):
#         x1 = self.prelu(self.conv1(x))
#         x2 = x1
#         for block in self.res_blocks:
#             x2 = checkpoint.checkpoint(block, x2)
#         x3 = self.bn(self.conv2(x2)) + x1
#         return self.upsample(x3)


# class ResidualBlock(nn.Module):
#     def __init__(self, channels):
#         super(ResidualBlock, self).__init__()
#         self.block = nn.Sequential(
#             nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(channels),
#             nn.PReLU(),
#             nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(channels)
#         )

#     def forward(self, x):
#         return x + self.block(x)


# class UpsampleBlock(nn.Module):
#     def __init__(self, in_channels):
#         super(UpsampleBlock, self).__init__()
#         self.block = nn.Sequential(
#             nn.Conv2d(in_channels, in_channels * 4, kernel_size=3, stride=1, padding=1),
#             nn.PixelShuffle(upscale_factor=2),
#             nn.PReLU()
#         )

#     def forward(self, x):
#         return self.block(x)


# # Define Discriminator Network
# class Discriminator(nn.Module):
#     def __init__(self):
#         super(Discriminator, self).__init__()
#         self.model = nn.Sequential(
#             nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),  # Reduced initial channels
#             nn.LeakyReLU(0.2, inplace=True),

#             self._block(16, 32, stride=2),  # Reduced complexity
#             self._block(32, 32, stride=2),

#             nn.Flatten(),
#             nn.Linear(32 * (640 // 4) * (1024 // 4), 128),  # Smaller dimensions
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Linear(128, 1)
#         )

#     def _block(self, in_channels, out_channels, stride):
#         return nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
#             nn.BatchNorm2d(out_channels),
#             nn.LeakyReLU(0.2, inplace=True)
#         )

#     def forward(self, x):
#         for layer in self.model:
#             x = checkpoint.checkpoint(layer, x) if isinstance(layer, nn.Sequential) else layer(x)
#         return x


# # Additional memory-saving techniques can also be applied during training, such as mixed precision (using PyTorch's AMP).
# # This reduced version should be much more memory-efficient and less computationally intensive.


In [8]:
# Define Generator Network
class Generator(nn.Module):
    """SRGAN-style generator that upsamples its input by 4x in each spatial dimension.

    Structure: a 9x9 input convolution with PReLU, three residual blocks, a
    3x3 convolution + batch norm joined to the input features by a skip
    connection, two ``UpsampleBlock`` stages, and a 9x9 output convolution
    followed by ``Tanh`` (so outputs lie in [-1, 1]).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=9, stride=1, padding=4)  # Reduced channels
        self.prelu = nn.PReLU()
        self.res_blocks = nn.ModuleList([ResidualBlock(16) for _ in range(3)])  # Fewer residual blocks
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(16)
        self.upsample = nn.Sequential(
            UpsampleBlock(16),
            UpsampleBlock(16),
            nn.Conv2d(16, 3, kernel_size=9, stride=1, padding=4),
            nn.Tanh()
        )

    def forward(self, x):
        """Map a ``(N, 3, H, W)`` batch to a ``(N, 3, 4H, 4W)`` batch."""
        x1 = self.prelu(self.conv1(x))
        x2 = x1
        # Activation checkpointing trades compute for memory and only matters
        # when a backward pass will follow; skip it when autograd is disabled.
        use_checkpoint = torch.is_grad_enabled()
        for block in self.res_blocks:
            if use_checkpoint:
                # Pass use_reentrant explicitly: PyTorch warns when it is
                # omitted and recommends the non-reentrant implementation.
                x2 = checkpoint.checkpoint(block, x2, use_reentrant=False)
            else:
                x2 = block(x2)
        x3 = self.bn(self.conv2(x2)) + x1
        return self.upsample(x3)


class ResidualBlock(nn.Module):
    """Conv -> BatchNorm -> PReLU -> Conv -> BatchNorm, added back onto the input.

    All convolutions are 3x3 with stride 1 and padding 1, so the channel count
    and spatial size of the input are preserved.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            return nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

        layers = [
            conv3x3(),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            conv3x3(),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional stack."""
        residual = self.block(x)
        return x + residual


class UpsampleBlock(nn.Module):
    """Double the spatial resolution with a sub-pixel (PixelShuffle) convolution.

    A 3x3 convolution expands the channels by a factor of four, PixelShuffle
    rearranges them into a 2x larger feature map with the original channel
    count, and a PReLU activation follows.
    """

    def __init__(self, in_channels):
        super().__init__()
        scale = 2
        expand = nn.Conv2d(
            in_channels,
            in_channels * scale * scale,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.block = nn.Sequential(
            expand,
            nn.PixelShuffle(upscale_factor=scale),
            nn.PReLU(),
        )

    def forward(self, x):
        """Return the 2x-upsampled feature map."""
        upsampled = self.block(x)
        return upsampled


# Define Discriminator Network
class Discriminator(nn.Module):
    """Convolutional real/fake classifier that returns one logit per image.

    Args:
        input_size: ``(height, width)`` of the images the network will see.
            It fixes the input width of the first fully connected layer, so
            inputs of any other size fail at that layer. The default
            ``(640, 1024)`` produces exactly the same layer shapes as the
            previous hard-coded version.

    Memory note: the first ``Linear`` layer dominates the parameter count. At
    the default size it has 64 * 160 * 256 inputs and 256 outputs, i.e.
    671,088,640 weights -- about 2.5 GiB in float32 before gradients and
    optimizer state are counted.
    """

    def __init__(self, input_size=(640, 1024)):
        super().__init__()
        height, width = input_size
        # The network contains two stride-2 convolutions; the stride-1 ones
        # keep the spatial size unchanged.
        feat_h = self._downsampled(self._downsampled(height))
        feat_w = self._downsampled(self._downsampled(width))

        self.model = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),  # Reduced channels
            nn.LeakyReLU(0.2, inplace=True),

            self._block(16, 32, stride=2),
            self._block(32, 32, stride=1),
            self._block(32, 64, stride=2),
            self._block(64, 64, stride=1),

            nn.Flatten(),
            nn.Linear(64 * feat_h * feat_w, 256),  # Reduced dimensions
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1)
        )

    @staticmethod
    def _downsampled(size):
        """Output length of a kernel-3, padding-1, stride-2 convolution."""
        return (size + 1) // 2

    def _block(self, in_channels, out_channels, stride):
        """Build a Conv -> BatchNorm -> LeakyReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x):
        """Return raw (pre-sigmoid) logits of shape ``(N, 1)``."""
        # Checkpointing only helps when a backward pass follows; skip it when
        # autograd is disabled.
        use_checkpoint = torch.is_grad_enabled()
        for layer in self.model:
            if use_checkpoint and isinstance(layer, nn.Sequential):
                # Pass use_reentrant explicitly: PyTorch warns when it is
                # omitted and recommends the non-reentrant implementation.
                x = checkpoint.checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x


Getting out-of-memory errors again and again. How can I calculate how much memory the model will consume?

In [9]:
# # Define Generator Network
# class Generator(nn.Module):
#     def __init__(self):
#         super(Generator, self).__init__()
#         self.conv1 = nn.Conv2d(3, 32, kernel_size=9, stride=1, padding=4)
#         self.prelu = nn.PReLU()
#         self.res_blocks = nn.ModuleList([ResidualBlock(32) for _ in range(4)])  # ModuleList for checkpointing
#         self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
#         self.bn = nn.BatchNorm2d(32)
#         self.upsample = nn.Sequential(
#             UpsampleBlock(32),
#             UpsampleBlock(32),
#             nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4),
#             nn.Tanh()
#         )

#     def forward(self, x):
#         x1 = self.prelu(self.conv1(x))
#         x2 = x1
#         for block in self.res_blocks:
#             x2 = checkpoint.checkpoint(block, x2)  # Apply checkpointing here
#         x3 = self.bn(self.conv2(x2)) + x1
#         return self.upsample(x3)


# class ResidualBlock(nn.Module):
#     def __init__(self, channels):
#         super(ResidualBlock, self).__init__()
#         self.block = nn.Sequential(
#             nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(channels),
#             nn.PReLU(),
#             nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(channels)
#         )

#     def forward(self, x):
#         return x + self.block(x)


# class UpsampleBlock(nn.Module):
#     def __init__(self, in_channels):
#         super(UpsampleBlock, self).__init__()
#         self.block = nn.Sequential(
#             nn.Conv2d(in_channels, in_channels * 4, kernel_size=3, stride=1, padding=1),
#             nn.PixelShuffle(upscale_factor=2),
#             nn.PReLU()
#         )

#     def forward(self, x):
#         return self.block(x)


# # Define Discriminator Network
# class Discriminator(nn.Module):
#     def __init__(self):
#         super(Discriminator, self).__init__()
#         self.model = nn.Sequential(
#             nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
#             nn.LeakyReLU(0.2, inplace=True),

#             self._block(32, 32, stride=2),
#             self._block(32, 64, stride=1),
#             self._block(64, 64, stride=2),
#             self._block(64, 128, stride=1),
#             self._block(128, 128, stride=2),

#             nn.Flatten(),
#             nn.Linear(128 * (640 // 8) * (1024 // 8), 1024),  # Adjusted dimensions
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Linear(1024, 1)
#         )

#     def _block(self, in_channels, out_channels, stride):
#         return nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
#             nn.BatchNorm2d(out_channels),
#             nn.LeakyReLU(0.2, inplace=True)
#         )

#     def forward(self, x):
#         for layer in self.model:
#             x = checkpoint.checkpoint(layer, x) if isinstance(layer, nn.Sequential) else layer(x)  # Apply checkpointing
#         return x


In [10]:
# # Define Generator Network
# class Generator(nn.Module):
#     def __init__(self):
#         super(Generator, self).__init__()
#         self.conv1 = nn.Conv2d(3, 32, kernel_size=9, stride=1, padding=4)
#         self.prelu = nn.PReLU()
#         self.res_blocks = nn.Sequential(*[ResidualBlock(32) for _ in range(4)])  # Reduced depth and channels
#         self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
#         self.bn = nn.BatchNorm2d(32)
#         self.upsample = nn.Sequential(
#             UpsampleBlock(32),
#             UpsampleBlock(32),
#             nn.Conv2d(32, 3, kernel_size=9, stride=1, padding=4),
#             nn.Tanh()
#         )

#     def forward(self, x):
#         x1 = self.prelu(self.conv1(x))
#         x2 = self.res_blocks(x1)
#         x3 = self.bn(self.conv2(x2)) + x1
#         return self.upsample(x3)


# class ResidualBlock(nn.Module):
#     def __init__(self, channels):
#         super(ResidualBlock, self).__init__()
#         self.block = nn.Sequential(
#             nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(channels),
#             nn.PReLU(),
#             nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(channels)
#         )

#     def forward(self, x):
#         return x + self.block(x)


# class UpsampleBlock(nn.Module):
#     def __init__(self, in_channels):
#         super(UpsampleBlock, self).__init__()
#         self.block = nn.Sequential(
#             nn.Conv2d(in_channels, in_channels * 4, kernel_size=3, stride=1, padding=1),
#             nn.PixelShuffle(upscale_factor=2),
#             nn.PReLU()
#         )

#     def forward(self, x):
#         return self.block(x)


# # Define Discriminator Network
# class Discriminator(nn.Module):
#     def __init__(self):
#         super(Discriminator, self).__init__()
#         self.model = nn.Sequential(
#             nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
#             nn.LeakyReLU(0.2, inplace=True),
            
#             self._block(32, 32, stride=2),
#             self._block(32, 64, stride=1),
#             self._block(64, 64, stride=2),
#             self._block(64, 128, stride=1),
#             self._block(128, 128, stride=2),
            
#             nn.Flatten(),
#             nn.Linear(128 * (640 // 8) * (1024 // 8), 1024),  # Adjusted dimensions
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Linear(1024, 1)
#         )

#     def _block(self, in_channels, out_channels, stride):
#         return nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
#             nn.BatchNorm2d(out_channels),
#             nn.LeakyReLU(0.2, inplace=True)
#         )

#     def forward(self, x):
#         return self.model(x)


In [11]:
# # Define Generator Network
# class Generator(nn.Module):
#     def __init__(self):
#         super(Generator, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4)
#         self.prelu = nn.PReLU()
#         # self.res_blocks = nn.Sequential(*[ResidualBlock() for _ in range(16)])
#         self.res_blocks = nn.Sequential(*[ResidualBlock() for _ in range(4)])
#         self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
#         self.bn = nn.BatchNorm2d(64)
#         self.upsample = nn.Sequential(
#             UpsampleBlock(64),
#             UpsampleBlock(64),
#             nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4),
#             nn.Tanh()
#         )

#     def forward(self, x):
#         x1 = self.prelu(self.conv1(x))
#         x2 = self.res_blocks(x1)
#         x3 = self.bn(self.conv2(x2)) + x1
#         return self.upsample(x3)

# class ResidualBlock(nn.Module):
#     def __init__(self):
#         super(ResidualBlock, self).__init__()
#         self.block = nn.Sequential(
#             nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(64),
#             nn.PReLU(),
#             nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
#             nn.BatchNorm2d(64)
#         )

#     def forward(self, x):
#         return x + self.block(x)

# class UpsampleBlock(nn.Module):
#     def __init__(self, in_channels):
#         super(UpsampleBlock, self).__init__()
#         self.block = nn.Sequential(
#             nn.Conv2d(in_channels, in_channels * 4, kernel_size=3, stride=1, padding=1),
#             nn.PixelShuffle(upscale_factor=2),
#             nn.PReLU()
#         )

#     def forward(self, x):
#         return self.block(x)

# # Define Discriminator Network
# class Discriminator(nn.Module):
#     def __init__(self):
#         super(Discriminator, self).__init__()
#         self.model = nn.Sequential(
#             nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
#             nn.LeakyReLU(0.2, inplace=True),
            
#             self._block(64, 64, stride=2),
#             self._block(64, 128, stride=1),
#             self._block(128, 128, stride=2),
#             self._block(128, 256, stride=1),
#             self._block(256, 256, stride=2),
#             self._block(256, 512, stride=1),
#             self._block(512, 512, stride=2),

#             nn.Flatten(),
#             nn.Linear(512 * (640 // 16) * (1024 // 16), 1024),
#             nn.LeakyReLU(0.2, inplace=True),
#             nn.Linear(1024, 1)
#         )

#     def _block(self, in_channels, out_channels, stride):
#         return nn.Sequential(
#             nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
#             nn.BatchNorm2d(out_channels),
#             nn.LeakyReLU(0.2, inplace=True)
#         )

#     def forward(self, x):
#         return self.model(x)


In [12]:
# Define Perceptual Loss
class PerceptualLoss(nn.Module):
    """Mean-squared error between VGG-19 feature maps of two image batches.

    The first 36 layers of torchvision's pretrained VGG-19 ``features`` stack
    are used as a fixed feature extractor: it is put in eval mode and its
    parameters are excluded from gradient updates.
    """

    def __init__(self):
        super().__init__()
        backbone = vgg19(pretrained=True).features
        extractor = nn.Sequential(*list(backbone)[:36])
        extractor.eval()
        for weight in extractor.parameters():
            weight.requires_grad = False
        self.features = extractor

    def forward(self, sr, hr):
        """Return the MSE between the features of ``sr`` and those of ``hr``."""
        criterion = nn.MSELoss()
        return criterion(self.features(sr), self.features(hr))


In [13]:
# # Define Perceptual Loss
# class PerceptualLoss(nn.Module):
#     def __init__(self):
#         super(PerceptualLoss, self).__init__()
#         vgg = vgg19(pretrained=True).features
#         self.features = nn.Sequential(*list(vgg)[:36]).eval()
#         for param in self.features.parameters():
#             param.requires_grad = False

#     def forward(self, sr, hr):
#         sr_features = self.features(sr)
#         hr_features = self.features(hr)
#         return nn.MSELoss()(sr_features, hr_features)


In [14]:
# Save and Load Checkpoints
def save_checkpoint(generator, discriminator, gen_optimizer, disc_optimizer, epoch, filepath):
    """Write model and optimizer state plus the epoch index to ``filepath``.

    Args:
        generator, discriminator: Modules whose ``state_dict()`` is saved.
        gen_optimizer, disc_optimizer: Optimizers whose ``state_dict()`` is saved.
        epoch: Value stored under the ``'epoch'`` key.
        filepath: Destination path, or a writable file-like object.

    When ``filepath`` is a path, the data is first written to
    ``<filepath>.tmp`` and then renamed over the destination, so an
    interrupted save cannot leave a truncated file in place of the previous
    checkpoint.
    """
    state = {
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'gen_optimizer_state_dict': gen_optimizer.state_dict(),
        'disc_optimizer_state_dict': disc_optimizer.state_dict(),
        'epoch': epoch
    }

    # File-like objects cannot be renamed; write to them directly.
    if not isinstance(filepath, (str, os.PathLike)):
        torch.save(state, filepath)
        return

    tmp_path = f"{os.fspath(filepath)}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, filepath)
    finally:
        # Only present here if saving or renaming failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def load_checkpoint(filepath, generator, discriminator, gen_optimizer, disc_optimizer, map_location="cpu"):
    """Restore model and optimizer state from ``filepath`` and return the saved epoch.

    Args:
        filepath: Checkpoint file holding the keys ``'generator_state_dict'``,
            ``'discriminator_state_dict'``, ``'gen_optimizer_state_dict'``,
            ``'disc_optimizer_state_dict'`` and ``'epoch'``.
        generator, discriminator: Modules updated in place via ``load_state_dict``.
        gen_optimizer, disc_optimizer: Optimizers updated in place via
            ``load_state_dict``.
        map_location: Forwarded to ``torch.load``. The default ``"cpu"`` lets a
            checkpoint saved on a GPU be read on a machine without one;
            ``load_state_dict`` then copies the values into the existing
            parameters.

    Returns:
        The value stored under the checkpoint's ``'epoch'`` key.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
    """
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all this checkpoint format contains.
    state = torch.load(filepath, map_location=map_location, weights_only=True)
    generator.load_state_dict(state['generator_state_dict'])
    discriminator.load_state_dict(state['discriminator_state_dict'])
    gen_optimizer.load_state_dict(state['gen_optimizer_state_dict'])
    disc_optimizer.load_state_dict(state['disc_optimizer_state_dict'])
    return state['epoch']


In [15]:
# Training Loop
def train_srgan(generator, discriminator, train_loader, val_loader, gen_optimizer, disc_optimizer, perceptual_loss, adversarial_loss, device, epochs=10, checkpoint_path="srgan_checkpoint.pth"):
    """Alternately train the discriminator and the generator, one checkpoint per epoch.

    For every batch the discriminator is updated on real targets and detached
    generator outputs (loss = mean of the two adversarial terms), then the
    generator is updated with ``perceptual_loss + 1e-3 * adversarial_loss``.

    Args:
        generator, discriminator: Networks to train; expected on ``device``.
        train_loader: Iterable yielding ``(noisy, gt)`` batches.
        val_loader: Accepted for interface compatibility; not used here.
        gen_optimizer, disc_optimizer: Optimizers for the two networks.
        perceptual_loss: Callable ``(fake, gt) -> loss``.
        adversarial_loss: Callable ``(prediction, target) -> loss``.
        device: Device the batches are moved to.
        epochs: Total number of epochs, counting any already completed ones
            restored from a checkpoint.
        checkpoint_path: File to resume from and save to after each epoch.
            Pass ``None`` or ``""`` to disable checkpointing entirely.

    The printed per-epoch losses are sums over batches, not averages.
    """
    start_epoch = 0
    if checkpoint_path:
        try:
            last_completed = load_checkpoint(checkpoint_path, generator, discriminator, gen_optimizer, disc_optimizer)
            # The checkpoint is written after an epoch finishes and stores that
            # epoch's index, so training continues with the following one.
            start_epoch = last_completed + 1
            print(f"Resumed training from epoch {start_epoch+1}")
        except FileNotFoundError:
            print("No checkpoint found, starting training from scratch.")

    # Discriminator parameters that are trainable; they are frozen temporarily
    # during the generator step and restored right afterwards.
    disc_trainable = [p for p in discriminator.parameters() if p.requires_grad]

    for epoch in range(start_epoch, epochs):
        torch.cuda.empty_cache()

        generator.train()
        discriminator.train()
        g_loss_epoch, d_loss_epoch = 0.0, 0.0

        for noisy, gt in tqdm(train_loader):
            noisy, gt = noisy.to(device), gt.to(device)

            # Train Discriminator
            disc_optimizer.zero_grad()
            real_pred = discriminator(gt)
            fake_images = generator(noisy)
            # detach(): the discriminator update must not backpropagate into the generator
            fake_pred = discriminator(fake_images.detach())

            real_loss = adversarial_loss(real_pred, torch.ones_like(real_pred))
            fake_loss = adversarial_loss(fake_pred, torch.zeros_like(fake_pred))
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            disc_optimizer.step()

            # Train Generator
            gen_optimizer.zero_grad()
            # Only the gradient with respect to the generated images is needed
            # here; freezing the discriminator avoids computing and storing
            # weight gradients that the next zero_grad() would discard anyway.
            for param in disc_trainable:
                param.requires_grad_(False)
            try:
                fake_pred = discriminator(fake_images)
                g_loss = perceptual_loss(fake_images, gt) + 1e-3 * adversarial_loss(fake_pred, torch.ones_like(fake_pred))
                g_loss.backward()
            finally:
                for param in disc_trainable:
                    param.requires_grad_(True)
            gen_optimizer.step()

            g_loss_epoch += g_loss.item()
            d_loss_epoch += d_loss.item()

        print(f"Epoch [{epoch+1}/{epochs}] Generator Loss: {g_loss_epoch:.4f}, Discriminator Loss: {d_loss_epoch:.4f}")

        # Save checkpoint after each epoch (skipped when checkpointing is disabled)
        if checkpoint_path:
            save_checkpoint(generator, discriminator, gen_optimizer, disc_optimizer, epoch, checkpoint_path)


In [16]:
# # Training Loop
# def train_srgan(generator, discriminator, train_loader, val_loader, gen_optimizer, disc_optimizer, perceptual_loss, adversarial_loss, device, epochs=10):
#     for epoch in range(epochs):
#         torch.cuda.empty_cache()

#         generator.train()
#         discriminator.train()
#         g_loss_epoch, d_loss_epoch = 0.0, 0.0

#         for noisy, gt in tqdm(train_loader):
#             noisy, gt = noisy.to(device), gt.to(device)

#             # Train Discriminator
#             disc_optimizer.zero_grad()
#             real_pred = discriminator(gt)
#             fake_images = generator(noisy)
#             fake_pred = discriminator(fake_images.detach())

#             real_loss = adversarial_loss(real_pred, torch.ones_like(real_pred, device=device))
#             fake_loss = adversarial_loss(fake_pred, torch.zeros_like(fake_pred, device=device))
#             d_loss = (real_loss + fake_loss) / 2
#             d_loss.backward()
#             disc_optimizer.step()

#             # Train Generator
#             gen_optimizer.zero_grad()
#             fake_pred = discriminator(fake_images)
#             g_loss = perceptual_loss(fake_images, gt) + 1e-3 * adversarial_loss(fake_pred, torch.ones_like(fake_pred, device=device))
#             g_loss.backward()
#             gen_optimizer.step()

#             g_loss_epoch += g_loss.item()
#             d_loss_epoch += d_loss.item()

#         print(f"Epoch [{epoch+1}/{epochs}] Generator Loss: {g_loss_epoch:.4f}, Discriminator Loss: {d_loss_epoch:.4f}")


In [17]:
# Initialize Models and Optimizers
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Networks
generator = Generator()
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

# Losses: VGG-feature loss for content, BCE-with-logits for the adversarial term.
perceptual_loss = PerceptualLoss()
perceptual_loss.to(device)
adversarial_loss = nn.BCEWithLogitsLoss()
adversarial_loss.to(device)

# One Adam optimizer per network, both with the same learning rate.
gen_optimizer, disc_optimizer = (
    optim.Adam(network.parameters(), lr=1e-4)
    for network in (generator, discriminator)
)


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:03<00:00, 183MB/s] 


In [18]:
# # Initialize Models and Optimizers
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# generator = Generator().to(device)
# discriminator = Discriminator().to(device)

# perceptual_loss = PerceptualLoss().to(device)
# adversarial_loss = nn.BCEWithLogitsLoss().to(device)

# gen_optimizer = optim.Adam(generator.parameters(), lr=1e-4)
# disc_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)


In [19]:
# # Training Setup
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Instantiate the model
# model = ResNetUNet(pretrained=True).to(device)

# # Define the loss function and optimizer
# criterion = nn.MSELoss()  # Or SSIMLoss for perceptual quality
# optimizer = optim.Adam(model.parameters(), lr=1e-4)


In [20]:
# # Initialize Models and Optimizers for Multi-GPU Training
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Wrap the models with DataParallel
# generator = nn.DataParallel(Generator()).to(device)
# discriminator = nn.DataParallel(Discriminator()).to(device)

# perceptual_loss = PerceptualLoss().to(device)
# adversarial_loss = nn.BCEWithLogitsLoss().to(device)

# gen_optimizer = optim.Adam(generator.parameters(), lr=1e-4)
# disc_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)

# # Train the Model on Multi-GPUs
# train_srgan(generator, discriminator, train_loader, val_loader, gen_optimizer, disc_optimizer,
#             perceptual_loss, adversarial_loss, device, epochs=8, checkpoint_path="srgan_checkpoint.pth")

In [None]:
# Train the Model (resumes from srgan_checkpoint.pth when that file exists)
train_srgan(
    generator=generator,
    discriminator=discriminator,
    train_loader=train_loader,
    val_loader=val_loader,
    gen_optimizer=gen_optimizer,
    disc_optimizer=disc_optimizer,
    perceptual_loss=perceptual_loss,
    adversarial_loss=adversarial_loss,
    device=device,
    epochs=8,
    checkpoint_path="srgan_checkpoint.pth",
)


  checkpoint = torch.load(filepath)


No checkpoint found, starting training from scratch.


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
100%|██████████| 139/139 [7:21:33<00:00, 190.60s/it]  


Epoch [1/8] Generator Loss: 39.0977, Discriminator Loss: 94.3174


  0%|          | 0/139 [00:00<?, ?it/s]

In [None]:
# # Train the Model
# train_srgan(generator, discriminator, train_loader, val_loader, gen_optimizer, disc_optimizer, perceptual_loss, adversarial_loss, device, epochs=8)


In [None]:
class LowLightDataset(Dataset):
    """Low-light image dataset that also returns each sample's filename.

    The files in ``noisy_dir`` (and ``gt_dir`` when given) are listed and
    sorted by name; sample ``idx`` pairs the ``idx``-th entry of each listing.

    Args:
        noisy_dir: Directory containing the input (noisy) images.
        gt_dir: Optional directory containing the ground-truth images. When it
            is ``None`` or empty, samples carry no ground-truth image.
        transform_noisy: Optional callable applied to each noisy PIL image.
        transform_gt: Optional callable applied to each ground-truth PIL image.

    Raises:
        ValueError: If ``gt_dir`` is given but does not hold the same number of
            files as ``noisy_dir``. Pairing is purely positional, so a count
            mismatch would otherwise misalign pairs or fail mid-epoch.
    """

    def __init__(self, noisy_dir, gt_dir=None, transform_noisy=None, transform_gt=None):
        self.noisy_dir = noisy_dir
        self.gt_dir = gt_dir
        self.noisy_files = sorted(os.listdir(noisy_dir))
        self.gt_files = sorted(os.listdir(gt_dir)) if gt_dir else None
        self.transform_noisy = transform_noisy
        self.transform_gt = transform_gt

        if self.gt_files is not None and len(self.gt_files) != len(self.noisy_files):
            raise ValueError(
                f"Cannot pair images by index: {len(self.noisy_files)} files in "
                f"{noisy_dir!r} but {len(self.gt_files)} files in {gt_dir!r}"
            )

    def __len__(self):
        """Return the number of noisy images."""
        return len(self.noisy_files)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return its pixels as an RGB PIL image."""
        with Image.open(path) as image:
            return image.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(noisy, filename)`` or, with ground truth, ``(noisy, gt, filename)``."""
        # Load noisy image
        noisy_path = os.path.join(self.noisy_dir, self.noisy_files[idx])
        noisy_image = self._load_rgb(noisy_path)
        if self.transform_noisy is not None:
            noisy_image = self.transform_noisy(noisy_image)

        # The filename (without directory) of the noisy image identifies the sample
        filename = os.path.basename(noisy_path)

        # Without a ground-truth directory only the input and its name are returned
        if not self.gt_dir:
            return noisy_image, filename

        gt_image = self._load_rgb(os.path.join(self.gt_dir, self.gt_files[idx]))
        if self.transform_gt is not None:
            gt_image = self.transform_gt(gt_image)
        return noisy_image, gt_image, filename

# Rebuild the test dataset and loader with the class above, so every batch
# carries the filenames alongside the image tensors.
test_dataset = LowLightDataset(
    noisy_dir=test_noisy_dir,
    transform_noisy=transform_noisy,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)


In [None]:
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
# i = 0
# Generate Submission
def save_test_predictions(model, dataloader, output_dir, device):
    """Run ``model`` over ``dataloader`` and save every output as an RGB PNG.

    Args:
        model: Network mapping a batch of input tensors to a batch of
            ``(3, H, W)`` output tensors.
        dataloader: Iterable yielding ``(images, filenames)`` batches.
        output_dir: Directory for the PNG files; created if missing.
        device: Device the input batches are moved to before inference.

    Each image is written to ``<output_dir>/<name without extension>.png``.
    """
    # Inverse of the Normalize(mean, std) step that this notebook's transforms
    # apply to the training targets: x * std + mean maps back to [0, 1].
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

    model.eval()
    os.makedirs(output_dir, exist_ok=True)
    with torch.no_grad():
        # Iterate the loader that was passed in, not a notebook-level global.
        for noisy_img, filenames in tqdm(dataloader):
            noisy_img = noisy_img.to(device)
            outputs = model(noisy_img).cpu()
            outputs = (outputs * std + mean).clamp(0, 1)  # Denormalize to [0, 1]
            for output, filename in zip(outputs, filenames):
                # splitext strips only the final extension, so names that
                # contain extra dots stay distinct.
                stem = os.path.splitext(filename)[0]
                output_path = os.path.join(output_dir, f"{stem}.png")
                curr_img = output.permute(1, 2, 0).numpy()      # CHW -> HWC
                curr_img = (curr_img * 255).astype('uint8')     # Convert [0, 1] to [0, 255]
                Image.fromarray(curr_img).save(output_path)     # Save as an RGB image

# Directory to save test predictions
test_output_dir = "test_outputs_for_pred"
# The trained network in this notebook is `generator`; no variable named
# `model` is ever defined, so passing it raised NameError on a fresh kernel.
save_test_predictions(generator, test_loader, test_output_dir, device)


In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image

def images_to_csv(folder_path, output_csv):
    """Flatten every image in ``folder_path`` into one CSV row.

    Each image is converted to 8-bit grayscale, flattened, and subsampled by
    keeping every 8th pixel. The resulting row is ``[ID, pixel_0, pixel_1, ...]``
    where ``ID`` is the filename up to its first dot. Rows are written in
    sorted filename order so the output is reproducible.

    Args:
        folder_path: Directory to scan (non-recursive) for image files.
        output_csv: Path of the CSV file to write.

    Raises:
        ValueError: If the folder contains no image files, or if the images
            do not all produce the same number of pixel columns.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    # os.listdir returns entries in arbitrary order; sort for a stable row order.
    image_names = sorted(
        name for name in os.listdir(folder_path)
        if name.lower().endswith(image_extensions)
    )
    if not image_names:
        raise ValueError(f"No image files found in {folder_path!r}")

    data_rows = []
    for filename in image_names:
        image_path = os.path.join(folder_path, filename)
        with Image.open(image_path) as image:
            image_array = np.array(image.convert('L')).flatten()[::8]
        data_rows.append([filename.split('.')[0], *image_array])

    if len({len(row) for row in data_rows}) != 1:
        raise ValueError("Images differ in size, so their rows have different numbers of pixel columns")

    column_names = ['ID'] + [f'pixel_{i}' for i in range(len(data_rows[0]) - 1)]
    df = pd.DataFrame(data_rows, columns=column_names)
    df.to_csv(output_csv, index=False)
    print(f'Successfully saved to {output_csv}')

# Convert the saved prediction images into the submission CSV.
folder_path = '/kaggle/working/test_outputs_for_pred'
output_csv = 'submission.csv'
images_to_csv(folder_path=folder_path, output_csv=output_csv)


In [None]:
# Read the submission file back and display it as a sanity check.
out_csv = pd.read_csv('submission.csv')
out_csv