In [1]:
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, random_split
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
class ImagePairsDataset(Dataset):
    """Dataset of (original, filtered) image pairs matched by file name.

    Args:
        originals_dir: directory of source images; its regular files define the
            dataset entries.
        filtered_dir: directory of target images; each original must have a file
            with the same name here.
        transform: optional callable applied independently to both images.
    """

    def __init__(self, originals_dir, filtered_dir, transform=None):
        self.originals_dir = originals_dir
        self.filtered_dir = filtered_dir
        self.transform = transform
        # os.listdir order is arbitrary; sorting makes the index -> file mapping (and
        # hence any random_split over it) reproducible. Sub-directories are skipped
        # because Image.open cannot read them.
        self.image_names = sorted(
            name for name in os.listdir(originals_dir)
            if os.path.isfile(os.path.join(originals_dir, name))
        )

    def __len__(self):
        return len(self.image_names)

    @staticmethod
    def _load_rgb(path):
        """Open an image, force 3-channel RGB, and close the file handle eagerly."""
        # convert() loads the pixel data and returns a new image, so it stays valid
        # after the `with` block closes the file. Forcing RGB keeps RGBA/L/P inputs
        # compatible with 3-channel transforms.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        image_name = self.image_names[idx]
        original_image = self._load_rgb(os.path.join(self.originals_dir, image_name))
        filtered_image = self._load_rgb(os.path.join(self.filtered_dir, image_name))

        if self.transform is not None:
            original_image = self.transform(original_image)
            filtered_image = self.transform(filtered_image)

        return original_image, filtered_image

In [3]:
# Downsampling block: stride-2 convolution -> optional BatchNorm -> LeakyReLU(0.2)
class Downsample(nn.Module):
    """Stride-2 convolution block that keeps the channel count.

    Args:
        filters: channel count used for both the input and the output of the conv.
        size: convolution kernel size (stride 2, padding 1, no bias).
        apply_batchnorm: if True, insert a BatchNorm2d after the convolution.
    """

    def __init__(self, filters, size, apply_batchnorm=True):
        super().__init__()
        conv = nn.Conv2d(
            in_channels=filters,
            out_channels=filters,
            kernel_size=size,
            stride=2,
            padding=1,
            bias=False,
        )
        self.model = nn.Sequential(
            conv,
            *([nn.BatchNorm2d(filters)] if apply_batchnorm else []),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.model(x)


# Upsampling block: stride-2 transposed conv -> BatchNorm -> optional Dropout -> ReLU
class Upsample(nn.Module):
    """Transposed-convolution upsampling block that halves the channel count.

    Args:
        filters: number of input channels; the block outputs ``filters // 2`` channels.
        size: transposed-convolution kernel size (stride 2, padding 1, no bias).
        apply_dropout: if True, apply Dropout(0.5) after the BatchNorm.
    """

    def __init__(self, filters, size, apply_dropout=False):
        super().__init__()
        half = filters // 2
        deconv = nn.ConvTranspose2d(
            in_channels=filters,
            out_channels=half,
            kernel_size=size,
            stride=2,
            padding=1,
            bias=False,
        )
        self.model = nn.Sequential(
            deconv,
            nn.BatchNorm2d(half),
            *([nn.Dropout(0.5)] if apply_dropout else []),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.model(x)

In [4]:
# Residual block: two 3x3 conv + BatchNorm layers around an identity shortcut
class ResNetBlock(nn.Module):
    """Residual block that preserves channel count and spatial size.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + x)``.

    Args:
        filters: number of input and output channels.
    """

    def __init__(self, filters):
        super().__init__()
        self.conv1 = nn.Conv2d(filters, filters, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(filters)
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(filters)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(hidden))
        return F.relu(branch + x)

# Simplified Transformer block: two pre-LayerNorm residual attention sub-layers
class TransformerBlock(nn.Module):
    """Global attention over the spatial positions of a feature map.

    The (B, C, H, W) input is flattened into a sequence of H*W tokens of size C
    and passed through two residual sub-layers of the form
    ``seq = seq + attention(norm(seq))``. The second sub-layer is named
    ``cross_attention``, but with no second input it attends over the same
    sequence, so it acts as a second self-attention.

    Args:
        filters: channel count C; also the attention embedding size.
        num_heads: number of attention heads; must divide ``filters``.
    """

    def __init__(self, filters, num_heads=4):
        super(TransformerBlock, self).__init__()
        self.self_attention = nn.MultiheadAttention(filters, num_heads)
        self.cross_attention = nn.MultiheadAttention(filters, num_heads)
        self.norm1 = nn.LayerNorm(filters)
        self.norm2 = nn.LayerNorm(filters)

    def forward(self, x):
        b, c, h, w = x.shape
        # (B, C, H, W) -> (H*W, B, C): nn.MultiheadAttention is sequence-first by default.
        seq = x.flatten(2).permute(2, 0, 1)

        # Pre-norm residual sub-layers. The LayerNorms were previously registered but
        # never applied, leaving dead parameters and unnormalised attention inputs.
        normed = self.norm1(seq)
        seq = seq + self.self_attention(normed, normed, normed, need_weights=False)[0]

        normed = self.norm2(seq)
        seq = seq + self.cross_attention(normed, normed, normed, need_weights=False)[0]

        # (H*W, B, C) -> (B, C, H, W)
        return seq.permute(1, 2, 0).reshape(b, c, h, w)

In [5]:
# Encoder / decoder stages used by the generator and discriminator below.
def _down_block(in_channels, out_channels, kernel_size=4, apply_batchnorm=True):
    """Conv2d(stride 2, padding 1, no bias) -> [BatchNorm2d] -> LeakyReLU(0.2); halves H, W for kernel_size=4."""
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=2, padding=1, bias=False)
    ]
    if apply_batchnorm:
        layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)


def _up_block(in_channels, out_channels, kernel_size=4, apply_dropout=False):
    """ConvTranspose2d(stride 2, padding 1, no bias) -> BatchNorm2d -> [Dropout(0.5)] -> ReLU; doubles H, W for kernel_size=4."""
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
    ]
    if apply_dropout:
        layers.append(nn.Dropout(0.5))
    layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)


# Generator: U-Net encoder/decoder with a ResNet + attention bottleneck
class Generator(nn.Module):
    """Image-to-image generator.

    Input is (B, input_channels, H, W) with H and W divisible by 16; output is
    (B, output_channels, H, W) in [-1, 1] (tanh). Shape comments assume 256x256 inputs.

    Args:
        output_channels: channels of the generated image.
        input_channels: channels of the source image (default 3, RGB).
    """

    def __init__(self, output_channels, input_channels=3):
        super(Generator, self).__init__()

        # Encoder: every stage halves H and W. Channels are explicit so stages chain.
        self.down_stack = nn.ModuleList([
            _down_block(input_channels, 64, apply_batchnorm=False),  # (B, 64, 128, 128)
            _down_block(64, 128),    # (B, 128, 64, 64)
            _down_block(128, 256),   # (B, 256, 32, 32)
            _down_block(256, 512),   # (B, 512, 16, 16)
        ])

        # Bottleneck at (B, 512, 16, 16)
        self.resnet_block1 = ResNetBlock(512)
        self.resnet_block2 = ResNetBlock(512)
        self.transformer_block = TransformerBlock(512)

        # Decoder: every stage doubles H and W, then its output is concatenated with
        # the encoder map of the same resolution. The next stage's input channels
        # are therefore this stage's output channels plus the skip's channels.
        self.up_stack = nn.ModuleList([
            _up_block(512, 256, apply_dropout=True),  # (B, 256, 32, 32) + skip 256 -> 512
            _up_block(512, 128),                      # (B, 128, 64, 64) + skip 128 -> 256
            _up_block(256, 64),                       # (B, 64, 128, 128) + skip 64 -> 128
        ])

        # Final upsampling back to full resolution: (B, output_channels, 256, 256)
        self.last = nn.ConvTranspose2d(in_channels=128, out_channels=output_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        skips = []
        for down in self.down_stack:
            x = down(x)
            skips.append(x)

        x = self.resnet_block1(x)
        x = self.resnet_block2(x)
        x = self.transformer_block(x)

        # The deepest encoder output *is* the bottleneck input, so it is not reused
        # as a skip; pair the remaining ones deepest-first with the decoder stages.
        for up, skip in zip(self.up_stack, reversed(skips[:-1])):
            x = up(x)
            x = torch.cat((x, skip), dim=1)

        return torch.tanh(self.last(x))


# PatchGAN-style discriminator conditioned on the source image
class Discriminator(nn.Module):
    """Scores (source, target) image pairs patch-wise.

    Both inputs are (B, input_channels, H, W) and are concatenated along channels.
    Output is (B, 1, H/16 - 2, W/16 - 2) probabilities in [0, 1]; 14x14 for
    256x256 inputs.

    Args:
        input_channels: channels of each of the two images (default 3, RGB).
    """

    def __init__(self, input_channels=3):
        super(Discriminator, self).__init__()

        # Downsampling over the concatenated (source, target) pair
        self.down1 = _down_block(input_channels * 2, 64, apply_batchnorm=False)
        self.down2 = _down_block(64, 128)
        self.down3 = _down_block(128, 256)
        self.down4 = _down_block(256, 512)  # Additional downsampling for extra capacity

        self.zero_pad1 = nn.ZeroPad2d(1)
        self.conv = nn.Conv2d(512, 512, kernel_size=4, stride=1, padding=0, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(512)
        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)

        self.zero_pad2 = nn.ZeroPad2d(1)
        self.last = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()  # Probabilities, as expected by nn.BCELoss

    def forward(self, inp, tar):
        x = torch.cat([inp, tar], dim=1)  # (B, 2*C, 256, 256)

        x = self.down1(x)  # (B, 64, 128, 128)
        x = self.down2(x)  # (B, 128, 64, 64)
        x = self.down3(x)  # (B, 256, 32, 32)
        x = self.down4(x)  # (B, 512, 16, 16)

        x = self.zero_pad1(x)  # (B, 512, 18, 18)
        x = self.conv(x)       # (B, 512, 15, 15)
        x = self.batchnorm1(x)
        x = self.leaky_relu(x)

        x = self.zero_pad2(x)  # (B, 512, 17, 17)
        x = self.last(x)       # (B, 1, 14, 14)
        return self.sigmoid(x)

In [6]:
class GANTrainer:
    """Runs conditional-GAN training epochs for a generator/discriminator pair.

    The discriminator is called as ``discriminator(source, image)`` and must return
    probabilities in [0, 1], since it is trained with ``nn.BCELoss``. Its output may
    be a patch map of any shape, because the labels are built to match it.

    Args:
        generator: maps a source batch to a generated target batch.
        discriminator: scores (source, target) pairs.
        device: device both models are moved to and batches are placed on.
        lr: Adam learning rate for both optimizers (betas=(0.5, 0.999)).
        l1_lambda: weight of an optional L1 reconstruction term added to the
            generator loss. 0 disables it; pix2pix uses 100.
    """

    def __init__(self, generator, discriminator, device, lr=2e-4, l1_lambda=0.0):
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)
        self.device = device
        self.l1_lambda = l1_lambda

        # Optimizers for generator and discriminator
        self.optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(0.5, 0.999))
        self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

        # Loss function (expects probabilities, not logits)
        self.criterion = nn.BCELoss()

    def train_epoch(self, train_loader, epoch):
        """Run one pass over ``train_loader`` and return the average (loss_D, loss_G).

        Args:
            train_loader: iterable of ``(source, target)`` tensor batches.
            epoch: zero-based epoch index, used only in the progress message.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        self.generator.train()
        self.discriminator.train()

        total_loss_G = 0.0
        total_loss_D = 0.0
        num_batches = 0

        for original, filtered in train_loader:
            original, filtered = original.to(self.device), filtered.to(self.device)

            # ---- Discriminator step: real pairs -> 1, generated pairs -> 0 ----
            self.optimizer_D.zero_grad()
            real_output = self.discriminator(original, filtered)
            # Labels take the discriminator's output shape: BCELoss requires equal
            # input/target sizes, and a PatchGAN returns a (B, 1, H', W') map.
            loss_real = self.criterion(real_output, torch.ones_like(real_output))

            fake_images = self.generator(original)
            # detach(): the discriminator loss must not backpropagate into the generator.
            fake_output = self.discriminator(original, fake_images.detach())
            loss_fake = self.criterion(fake_output, torch.zeros_like(fake_output))

            loss_D = (loss_real + loss_fake) / 2
            loss_D.backward()
            self.optimizer_D.step()

            # ---- Generator step: fool the (just updated) discriminator ----
            self.optimizer_G.zero_grad()
            fake_output = self.discriminator(original, fake_images)
            loss_G = self.criterion(fake_output, torch.ones_like(fake_output))
            if self.l1_lambda:
                loss_G = loss_G + self.l1_lambda * F.l1_loss(fake_images, filtered)
            loss_G.backward()
            self.optimizer_G.step()

            total_loss_D += loss_D.item()
            total_loss_G += loss_G.item()
            num_batches += 1

        # Counting batches (instead of len()) also supports loaders without __len__.
        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        avg_loss_G = total_loss_G / num_batches
        avg_loss_D = total_loss_D / num_batches

        print(f"Epoch [{epoch+1}], Loss D: {avg_loss_D:.4f}, Loss G: {avg_loss_G:.4f}")
        return avg_loss_D, avg_loss_G


### Training

In [7]:
# Function to train the GAN and save models
def train_and_save_gan(generator, discriminator, data_loader, epochs=100, save_path="gan_model",
                       device=None, lr=2e-4, checkpoint_every=10):
    """Train a conditional GAN on (source, target) batches and save the final weights.

    Training is delegated to ``GANTrainer``, with one ``train_epoch`` call per epoch.

    Args:
        generator: image-to-image generator to train.
        discriminator: discriminator called as ``discriminator(source, image)``.
        data_loader: iterable of ``(source, target)`` tensor batches.
        epochs: number of passes over ``data_loader``.
        save_path: prefix for the final ``{save_path}_generator.pth`` and
            ``{save_path}_discriminator.pth`` state-dict files.
        device: training device. Defaults to the device of the generator's
            parameters, or CPU if it has none.
        lr: Adam learning rate forwarded to ``GANTrainer``.
        checkpoint_every: every this many epochs, save models and optimizer states
            to ``checkpoint_epoch_{N}.pt``. 0 or None disables checkpoints.

    Returns:
        List of per-epoch ``(loss_D, loss_G)`` averages.
    """
    if device is None:
        first_param = next(generator.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # The trainer owns the optimizers and loss; it also moves both models to `device`.
    trainer = GANTrainer(generator, discriminator, device, lr=lr)

    history = []
    for epoch in range(epochs):
        history.append(trainer.train_epoch(data_loader, epoch))

        # Save checkpoints
        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            torch.save({
                'epoch': epoch,
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': trainer.optimizer_G.state_dict(),
                'optimizer_D_state_dict': trainer.optimizer_D.state_dict(),
            }, f'checkpoint_epoch_{epoch+1}.pt')

    # Save the generator and discriminator models
    torch.save(generator.state_dict(), f"{save_path}_generator.pth")
    torch.save(discriminator.state_dict(), f"{save_path}_discriminator.pth")
    print(f"Models saved to {save_path}_generator.pth and {save_path}_discriminator.pth")
    return history

# Main

In [8]:
# Reproducibility: seed torch's global random number generator
torch.manual_seed(42)

# Run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing shared by originals and targets: resize to 256x256, convert to a
# tensor, then map each channel to [-1, 1] to match the generator's tanh output
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Dataset and DataLoader: locations of the paired images
originals_dir = '/kaggle/input/e2gan-images/original_images'
filtered_dir = '/kaggle/input/e2gan-images/modified_images'

In [None]:
# Create the dataset and split it into train, validation, and test sets
dataset = ImagePairsDataset(originals_dir, filtered_dir, transform=transform)

# Split dataset into train (80%), validation (10%), and test (10%).
# A dedicated seeded RNG keeps the split identical on every run of this cell.
# The global RNG advances with each re-run (and with model initialisation below),
# which would otherwise silently reshuffle the split.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set, test_set = random_split(dataset, [0.8, 0.1, 0.1], generator=split_generator)

# DataLoader for loading batches of data
batch_size = 32
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)

# Initialize the generator and discriminator models
generator = Generator(output_channels=3).to(device)  # output_channels=3 for RGB images
discriminator = Discriminator(input_channels=3).to(device)  # input_channels=3 for RGB images

# Run the training and save the model
train_and_save_gan(generator, discriminator, train_loader, epochs=10, save_path="simple_gan")



In [None]:
# # Initialize the generator and discriminator models
# generator = Generator(output_channels=3).to(device)  # output_channels=3 for RGB images
# discriminator = Discriminator(input_channels=3).to(device)  # input_channels=3 for RGB images


# # Initialize the GANTrainer class
# trainer = GANTrainer(generator, discriminator, device)


# # Number of epochs to train
# epochs = 100
# for epoch in range(epochs):
#     # Train for one epoch
#     loss_D, loss_G = trainer.train_epoch(train_loader, epoch)

#     # Save model checkpoints every 10 epochs
#     if (epoch + 1) % 10 == 0:

#         checkpoint = {
#             'epoch': epoch,
#             'generator_state_dict': generator.state_dict(),
#             'discriminator_state_dict': discriminator.state_dict(),
#             'optimizer_G_state_dict': trainer.optimizer_G.state_dict(),
#             'optimizer_D_state_dict': trainer.optimizer_D.state_dict(),
#         }

#         torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pt')
#         print(f"Checkpoint saved for epoch {epoch+1}")

# print("Training completed!")