In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image
import cv2
import numpy as np
from PIL import Image
import os
import matplotlib.pyplot as plt
from pathlib import Path


class BorderDataset(Dataset):
    """Dataset of (cut-out object, original image, new background) triples.

    Cut-out objects are paired with original images that share the same file
    stem. Every matched pair is combined with every background image, so the
    dataset length is ``len(matched_pairs) * len(new_backgrounds)``.

    Each item is a dict with the keys ``naive_composite``, ``ground_truth``,
    ``border_mask``, ``cut_object`` and ``new_background``.
    """

    def __init__(self, cut_objects_path, original_images_path, new_backgrounds_path, transform=None, image_size=256):
        """Index the three image directories and match cut-outs to originals.

        Args:
            cut_objects_path: Directory with the cut-out objects (*.jpg, *.png).
            original_images_path: Directory with the original images.
            new_backgrounds_path: Directory with the replacement backgrounds.
            transform: Optional callable applied to each uint8 HxWx3 image.
                It is not applied to the border mask.
            image_size: Side length every image is resized to with OpenCV
                before the mask and composite are built.
        """
        self.cut_objects_path = Path(cut_objects_path)
        self.original_images_path = Path(original_images_path)
        self.new_backgrounds_path = Path(new_backgrounds_path)
        self.transform = transform
        self.image_size = image_size

        # Sorted listings make the index -> sample mapping reproducible;
        # raw glob order depends on the filesystem.
        self.cut_objects = self._list_images(self.cut_objects_path)
        self.original_images = self._list_images(self.original_images_path)
        self.new_backgrounds = self._list_images(self.new_backgrounds_path)

        # Match by filename stem. setdefault keeps the first original seen for
        # a stem, so a .jpg still wins over a .png with the same stem.
        originals_by_stem = {}
        for orig in self.original_images:
            originals_by_stem.setdefault(orig.stem, orig)

        self.matched_pairs = [
            (cut_obj, originals_by_stem[cut_obj.stem])
            for cut_obj in self.cut_objects
            if cut_obj.stem in originals_by_stem
        ]

    @staticmethod
    def _list_images(directory):
        """Return the *.jpg and *.png files directly inside *directory*, sorted."""
        return sorted(list(directory.glob('*.jpg')) + list(directory.glob('*.png')))

    def __len__(self):
        return len(self.matched_pairs) * len(self.new_backgrounds)

    def create_border_mask(self, cut_object, border_width=10, alpha_threshold=0.5):
        """Create a mask for the border region of a cut-out object.

        The object silhouette is binarised, then the mask is the difference
        between its dilation and its erosion with a square kernel.

        Args:
            cut_object: HxWx4 RGBA array or HxWx3 RGB array.
            border_width: Side length of the square morphological kernel.
            alpha_threshold: Pixels whose alpha (scaled to [0, 1]) is above
                this value count as object.

        Returns:
            HxW float32 array. If mask creation fails, the error is printed
            and a constant 0.1 mask is returned instead.
        """
        try:
            if len(cut_object.shape) == 3 and cut_object.shape[2] == 4:  # RGBA
                alpha = cut_object[:, :, 3].astype(np.float32) / 255.0
            else:  # RGB - create alpha from non-black pixels
                alpha = np.any(cut_object > 10, axis=2).astype(np.float32)

            # Binarise with an explicit threshold. Casting the float alpha
            # straight to uint8 would truncate every partially transparent
            # pixel to 0 and keep only fully opaque ones.
            silhouette = (alpha > alpha_threshold).astype(np.uint8)

            # Create border mask using morphological operations
            kernel = np.ones((border_width, border_width), np.uint8)
            dilated = cv2.dilate(silhouette, kernel, iterations=1)
            eroded = cv2.erode(silhouette, kernel, iterations=1)
            border_mask = dilated - eroded

            return border_mask.astype(np.float32)

        except Exception as e:
            # Deliberate best-effort: report the problem and fall back to a
            # weak uniform mask rather than failing the whole sample.
            print(f"Error creating border mask: {e}")
            print(f"Cut object shape: {cut_object.shape}")
            return np.ones((cut_object.shape[0], cut_object.shape[1]), dtype=np.float32) * 0.1

    @staticmethod
    def _load_rgb(path, description):
        """Load *path* with OpenCV and return it as an HxWx3 RGB array.

        Raises:
            ValueError: If the file cannot be read or is not 3-channel.
        """
        image = cv2.imread(str(path))
        if image is None:
            raise ValueError(f"Failed to load {description}: {path}")
        if len(image.shape) != 3 or image.shape[2] != 3:
            raise ValueError(f"Unexpected {description} shape: {image.shape}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _to_tensor(self, image):
        """Convert an HxWx3 float image in [0, 255] to a tensor.

        With a transform, the image is clipped and cast to uint8 and handed
        to the transform. Without one, it becomes a CxHxW float tensor
        scaled to [0, 1].
        """
        if self.transform:
            return self.transform(np.clip(image, 0, 255).astype(np.uint8))
        return torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32) / 255.0

    def _build_sample(self, cut_obj_path, original_path, bg_path):
        """Load the three images and assemble one training sample dict."""
        cut_object = cv2.imread(str(cut_obj_path), cv2.IMREAD_UNCHANGED)
        if cut_object is None:
            raise ValueError(f"Failed to load cut object: {cut_obj_path}")
        original = self._load_rgb(original_path, "original image")
        new_background = self._load_rgb(bg_path, "background")

        # Convert the cut-out from OpenCV's BGR(A) channel order to RGB(A)
        if len(cut_object.shape) != 3:
            raise ValueError(
                f"Unexpected cut object shape: {cut_object.shape}")
        if cut_object.shape[2] == 3:
            cut_object = cv2.cvtColor(cut_object, cv2.COLOR_BGR2RGB)
        elif cut_object.shape[2] == 4:
            cut_object = cv2.cvtColor(cut_object, cv2.COLOR_BGRA2RGBA)

        # Resize images
        size = (self.image_size, self.image_size)
        cut_object = cv2.resize(cut_object, size)
        original = cv2.resize(original, size)
        new_background = cv2.resize(new_background, size)

        # Create border mask
        border_mask = self.create_border_mask(cut_object)

        # Create naive composite (input to generator)
        if cut_object.shape[2] == 4:  # RGBA
            alpha = cut_object[:, :, 3:4] / 255.0
            cut_object_rgb = cut_object[:, :, :3]
        else:  # RGB - treat non-black pixels as fully opaque
            alpha = np.any(cut_object > 10, axis=2, keepdims=True)
            cut_object_rgb = cut_object

        # Ensure all arrays are float32 for consistency
        alpha = alpha.astype(np.float32)
        cut_object_rgb = cut_object_rgb.astype(np.float32)
        new_background = new_background.astype(np.float32)

        naive_composite = cut_object_rgb * alpha + new_background * (1 - alpha)

        # Ground truth is the original image
        ground_truth = original.astype(np.float32)

        # The transform is applied in this fixed order (composite, ground
        # truth, cut-out, background); the mask is only wrapped as 1xHxW.
        naive_composite = self._to_tensor(naive_composite)
        ground_truth = self._to_tensor(ground_truth)
        border_mask = torch.tensor(
            border_mask, dtype=torch.float32).unsqueeze(0)
        cut_object_rgb = self._to_tensor(cut_object_rgb)
        new_background = self._to_tensor(new_background)

        return {
            'naive_composite': naive_composite,
            'ground_truth': ground_truth,
            'border_mask': border_mask,
            'cut_object': cut_object_rgb,
            'new_background': new_background
        }

    def __getitem__(self, idx):
        """Return the sample for flat index *idx*.

        The index enumerates (pair, background) combinations with the
        background varying fastest.

        Raises:
            IndexError: If the dataset has no backgrounds or *idx* is out of
                range.
            ValueError: If an image cannot be loaded or has an unexpected shape.
        """
        num_backgrounds = len(self.new_backgrounds)
        if num_backgrounds == 0:
            raise IndexError("BorderDataset is empty: no background images found")

        # Resolve the paths outside the try block so an out-of-range index
        # surfaces as a plain IndexError and the handler below can always
        # report the paths it was working on.
        pair_idx, bg_idx = divmod(idx, num_backgrounds)
        cut_obj_path, original_path = self.matched_pairs[pair_idx]
        bg_path = self.new_backgrounds[bg_idx]

        try:
            return self._build_sample(cut_obj_path, original_path, bg_path)
        except Exception as e:
            print(f"Error processing item {idx}: {e}")
            print(f"Cut object path: {cut_obj_path}")
            print(f"Original path: {original_path}")
            print(f"Background path: {bg_path}")
            raise


class Generator(nn.Module):
    """Convolutional encoder-decoder with additive skip connections.

    Five stride-2 convolutions halve the spatial size each time; five stride-2
    transposed convolutions double it again. The output of each of the first
    four encoder stages passes through a 1x1 convolution and is added to the
    decoder activation of matching shape. The final activation is Tanh.
    """

    def __init__(self, input_channels=3, output_channels=3):
        super().__init__()

        # Encoder: the first stage has no batch norm, the remaining four do.
        # Shapes in comments assume a 3 x 256 x 256 input.
        down = [
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]  # -> 64 x 128 x 128
        for c_in, c_out in ((64, 128), (128, 256), (256, 512), (512, 512)):
            down.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.encoder = nn.Sequential(*down)  # -> 512 x 8 x 8

        # Decoder: four upsampling stages with batch norm + ReLU, then the
        # output stage with Tanh.
        up = []
        for c_in, c_out in ((512, 512), (512, 256), (256, 128), (128, 64)):
            up.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        up.extend([
            nn.ConvTranspose2d(64, output_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ])
        self.decoder = nn.Sequential(*up)  # -> output_channels x 256 x 256

        # 1x1 convolutions applied to encoder features before they are added
        # to the decoder features
        self.skip_conv1 = nn.Conv2d(64, 64, kernel_size=1)
        self.skip_conv2 = nn.Conv2d(128, 128, kernel_size=1)
        self.skip_conv3 = nn.Conv2d(256, 256, kernel_size=1)
        self.skip_conv4 = nn.Conv2d(512, 512, kernel_size=1)

    def forward(self, x):
        """Run the encoder, then decode while adding the skip features."""
        # Encoder stages as [start, stop) index ranges into self.encoder
        encoder_stages = ((0, 2), (2, 5), (5, 8), (8, 11), (11, 14))
        features = []
        hidden = x
        for start, stop in encoder_stages:
            hidden = self.encoder[start:stop](hidden)
            features.append(hidden)

        # The deepest feature map is the decoder input; the rest are popped
        # deepest-first as skip inputs.
        hidden = features.pop()
        laterals = (self.skip_conv4, self.skip_conv3,
                    self.skip_conv2, self.skip_conv1)
        for stage, lateral in enumerate(laterals):
            first = 3 * stage
            upsampled = self.decoder[first:first + 3](hidden)
            hidden = upsampled + lateral(features.pop())

        # Final transposed convolution + Tanh
        return self.decoder[12:](hidden)


class Discriminator(nn.Module):
    """Convolutional classifier producing one probability per image.

    Five stride-2 convolutions extract a 512-channel feature map, which is
    average-pooled to a single 512-vector and mapped through a linear layer
    and a sigmoid.
    """

    def __init__(self, input_channels=3):
        super().__init__()

        # Channel widths after each of the five convolutions.
        # Shapes in comments assume a 3 x 256 x 256 input.
        widths = (64, 128, 256, 512, 512)

        # First stage has no batch norm: -> 64 x 128 x 128
        stages = [
            nn.Conv2d(input_channels, widths[0], kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining stages: conv + batch norm + LeakyReLU, ending at 512 x 8 x 8
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.features = nn.Sequential(*stages)

        # Global average pooling to reduce spatial dimensions
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Final classification layer
        self.classifier = nn.Sequential(
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return a tensor of shape [batch_size] with values in (0, 1)."""
        pooled = self.global_pool(self.features(x))
        vectors = pooled.flatten(1)  # [batch_size, 512]
        scores = self.classifier(vectors)  # [batch_size, 1]
        return scores.squeeze(-1)  # [batch_size]


class BorderGAN:
    """Training and inference wrapper around a Generator/Discriminator pair.

    Holds both networks, their Adam optimizers and the loss functions, and
    provides training, checkpointing, sample generation and single-image
    blending.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Build the networks, losses and optimizers on *device*."""
        self.device = device

        # Initialize networks
        self.generator = Generator().to(device)
        self.discriminator = Discriminator().to(device)

        # Loss functions
        self.adversarial_loss = nn.BCELoss()
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()

        # Optimizers
        self.g_optimizer = optim.Adam(
            self.generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
        self.d_optimizer = optim.Adam(
            self.discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

        # Initialize weights
        self.generator.apply(self.weights_init)
        self.discriminator.apply(self.weights_init)

    def weights_init(self, m):
        """Re-initialise one module; intended for use with ``Module.apply``.

        Modules whose class name contains 'Conv' get N(0, 0.02) weights.
        Modules whose class name contains 'BatchNorm' get N(1, 0.02) weights
        and zero bias. Other modules are left untouched.
        """
        classname = m.__class__.__name__
        if 'Conv' in classname:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif 'BatchNorm' in classname:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    def train_step(self, batch):
        """Run one discriminator update followed by one generator update.

        Args:
            batch: Mapping with the tensors 'naive_composite', 'ground_truth'
                and 'border_mask'.

        Returns:
            Dict of Python floats: 'd_loss', 'g_loss', 'g_adv_loss',
            'l1_loss' and 'perceptual_loss'.
        """
        naive_composite = batch['naive_composite'].to(self.device)
        ground_truth = batch['ground_truth'].to(self.device)
        border_mask = batch['border_mask'].to(self.device)

        batch_size = naive_composite.size(0)

        # Labels for adversarial loss
        real_label = torch.ones(batch_size, device=self.device)
        fake_label = torch.zeros(batch_size, device=self.device)

        # Train Discriminator
        self.d_optimizer.zero_grad()

        # Real images
        real_output = self.discriminator(ground_truth)
        d_real_loss = self.adversarial_loss(real_output, real_label)

        # Fake images; detached so this backward pass does not reach the
        # generator
        fake_images = self.generator(naive_composite)
        fake_output = self.discriminator(fake_images.detach())
        d_fake_loss = self.adversarial_loss(fake_output, fake_label)

        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        self.d_optimizer.step()

        # Train Generator
        self.g_optimizer.zero_grad()

        # Adversarial loss: the generator wants its output labelled as real
        fake_output = self.discriminator(fake_images)
        g_adv_loss = self.adversarial_loss(fake_output, real_label)

        # L1 loss on the masked images; pixels where the mask is 0
        # contribute zero error but still count towards the mean
        l1_loss = self.l1_loss(fake_images * border_mask,
                               ground_truth * border_mask)

        # Reported as 'perceptual_loss': plain MSE on the entire image
        perceptual_loss = self.mse_loss(fake_images, ground_truth)

        # Combined generator loss
        g_loss = g_adv_loss + 100 * l1_loss + 10 * perceptual_loss
        g_loss.backward()
        self.g_optimizer.step()

        return {
            'd_loss': d_loss.item(),
            'g_loss': g_loss.item(),
            'g_adv_loss': g_adv_loss.item(),
            'l1_loss': l1_loss.item(),
            'perceptual_loss': perceptual_loss.item()
        }

    def train(self, dataloader, num_epochs=100, save_interval=10):
        """Train both networks for *num_epochs* passes over *dataloader*.

        Every *save_interval* epochs a checkpoint
        ``border_gan_epoch_<n>.pth`` and a sample grid are written to the
        current directory.

        Raises:
            ValueError: If *dataloader* yields no batches.
        """
        num_batches = len(dataloader)
        if num_batches == 0:
            # Fail early with a clear message instead of dividing by zero
            # when the per-epoch averages are computed.
            raise ValueError("Cannot train: dataloader yields no batches")

        self.generator.train()
        self.discriminator.train()

        for epoch in range(num_epochs):
            epoch_d_loss = 0
            epoch_g_loss = 0

            for i, batch in enumerate(dataloader):
                losses = self.train_step(batch)
                epoch_d_loss += losses['d_loss']
                epoch_g_loss += losses['g_loss']

                if i % 100 == 0:
                    print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{num_batches}], '
                          f'D_loss: {losses["d_loss"]:.4f}, G_loss: {losses["g_loss"]:.4f}')

            avg_d_loss = epoch_d_loss / num_batches
            avg_g_loss = epoch_g_loss / num_batches

            print(
                f'Epoch [{epoch+1}/{num_epochs}], Avg D_loss: {avg_d_loss:.4f}, Avg G_loss: {avg_g_loss:.4f}')

            if (epoch + 1) % save_interval == 0:
                self.save_models(f'border_gan_epoch_{epoch+1}.pth')
                self.generate_samples(dataloader, epoch+1)

    def save_models(self, path):
        """Save network and optimizer state dicts to *path*."""
        torch.save({
            'generator_state_dict': self.generator.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'g_optimizer_state_dict': self.g_optimizer.state_dict(),
            'd_optimizer_state_dict': self.d_optimizer.state_dict(),
        }, path)

    def load_models(self, path):
        """Restore network and optimizer state from a ``save_models`` file."""
        # map_location remaps stored tensors onto this instance's device, so
        # a checkpoint written on a GPU can be restored on a CPU-only host.
        checkpoint = torch.load(path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.discriminator.load_state_dict(
            checkpoint['discriminator_state_dict'])
        self.g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
        self.d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])

    def generate_samples(self, dataloader, epoch):
        """Write ``samples_epoch_<epoch>.png`` comparing input, output, target.

        Uses up to four items of the first batch drawn from *dataloader*.
        Each row shows naive composite, generator output and ground truth
        side by side. The generator is put back into training mode afterwards.
        """
        self.generator.eval()
        with torch.no_grad():
            batch = next(iter(dataloader))
            naive_composite = batch['naive_composite'][:4].to(self.device)
            ground_truth = batch['ground_truth'][:4].to(self.device)

            fake_images = self.generator(naive_composite)

            # Concatenate along the width axis so each row is one comparison
            comparison = torch.cat(
                [naive_composite, fake_images, ground_truth], dim=3)
            save_image(
                comparison, f'samples_epoch_{epoch}.png', nrow=1, normalize=True)

        self.generator.train()

    def blend_image(self, cut_object_path, new_background_path, output_path):
        """Blend a cut object with a new background using the trained generator.

        Both images are resized to 256x256 and alpha-composited. The
        composite is normalised to [-1, 1], passed through the generator,
        and the result is written to *output_path*. The generator is left in
        eval mode.

        Raises:
            ValueError: If either image cannot be read, or the cut object is
                not a 3- or 4-channel image.
        """
        self.generator.eval()

        transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        # Load images. cv2.imread signals failure by returning None rather
        # than raising, so check explicitly.
        cut_object = cv2.imread(str(cut_object_path), cv2.IMREAD_UNCHANGED)
        if cut_object is None:
            raise ValueError(f"Failed to load cut object: {cut_object_path}")
        new_background = cv2.imread(str(new_background_path))
        if new_background is None:
            raise ValueError(f"Failed to load background: {new_background_path}")

        # IMREAD_UNCHANGED can return a 2-D array for single-channel files
        if cut_object.ndim != 3 or cut_object.shape[2] not in (3, 4):
            raise ValueError(
                f"Unexpected cut object shape: {cut_object.shape}")

        # Convert from OpenCV's BGR(A) channel order to RGB(A)
        if cut_object.shape[2] == 3:
            cut_object = cv2.cvtColor(cut_object, cv2.COLOR_BGR2RGB)
        else:
            cut_object = cv2.cvtColor(cut_object, cv2.COLOR_BGRA2RGBA)
        new_background = cv2.cvtColor(new_background, cv2.COLOR_BGR2RGB)

        # Resize
        cut_object = cv2.resize(cut_object, (256, 256))
        new_background = cv2.resize(new_background, (256, 256))

        # Create naive composite
        if cut_object.shape[2] == 4:
            alpha = cut_object[:, :, 3:4] / 255.0
            cut_object_rgb = cut_object[:, :, :3]
        else:
            # No alpha channel: treat non-black pixels as fully opaque
            alpha = np.any(cut_object > 10, axis=2,
                           keepdims=True).astype(np.float32)
            cut_object_rgb = cut_object

        naive_composite = cut_object_rgb * alpha + new_background * (1 - alpha)

        # Transform and generate
        naive_composite_tensor = transform(
            naive_composite.astype(np.uint8)).unsqueeze(0).to(self.device)

        with torch.no_grad():
            blended_image = self.generator(naive_composite_tensor)

        # Save result
        save_image(blended_image, output_path, normalize=True)
        print(f"Blended image saved to {output_path}")


def main():
    """Build the dataset, smoke-test loading, then train for one epoch.

    Prints dataset statistics, loads one sample and one batch as a sanity
    check (returning early with a message if either fails), and finally
    trains a fresh BorderGAN for a single epoch, saving after that epoch.
    """
    # Input directories, relative to the working directory
    cut_dir = "cut_objects"
    original_dir = "original_images"
    background_dir = "new_backgrounds"

    # Per-image preprocessing: resize to 256x256 and normalise to [-1, 1]
    preprocess = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset = BorderDataset(
        cut_dir, original_dir, background_dir, transform=preprocess)

    # Nothing to train on: bail out with a hint
    if not len(dataset):
        print("Error: Dataset is empty. Check your file paths and ensure images exist.")
        return

    print(f"Dataset size: {len(dataset)}")
    print(f"Number of matched pairs: {len(dataset.matched_pairs)}")
    print(f"Number of backgrounds: {len(dataset.new_backgrounds)}")

    # Smoke test 1: a single item straight from the dataset
    try:
        first = dataset[0]
        print("Successfully loaded first sample")
        print(f"Sample keys: {first.keys()}")
        for key in first:
            value = first[key]
            if isinstance(value, torch.Tensor):
                print(f"{key} shape: {value.shape}")
    except Exception as e:
        print(f"Error loading first sample: {e}")
        return

    # num_workers=0 keeps loading in the main process, which avoids
    # multiprocessing issues during debugging
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

    # Smoke test 2: one collated batch through the dataloader
    try:
        batch = next(iter(dataloader))
        print("Successfully loaded first batch")
        print(f"Batch size: {batch['naive_composite'].shape[0]}")
    except Exception as e:
        print(f"Error loading batch: {e}")
        return

    # Initialize and train model
    border_gan = BorderGAN()
    border_gan.train(dataloader, num_epochs=1, save_interval=1)

    # Example usage for blending
    # border_gan.load_models('border_gan_epoch_100.pth')
    # border_gan.blend_image('cut_objects/example.png', 'new_backgrounds/bg1.jpg', 'result.png')


if __name__ == "__main__":
    main()

Dataset size: 51
Number of matched pairs: 51
Number of backgrounds: 1
Successfully loaded first sample
Sample keys: dict_keys(['naive_composite', 'ground_truth', 'border_mask', 'cut_object', 'new_background'])
naive_composite shape: torch.Size([3, 256, 256])
ground_truth shape: torch.Size([3, 256, 256])
border_mask shape: torch.Size([1, 256, 256])
cut_object shape: torch.Size([3, 256, 256])
new_background shape: torch.Size([3, 256, 256])
Successfully loaded first batch
Batch size: 4
Epoch [1/1], Step [1/13], D_loss: 0.6988, G_loss: 5.5829
Epoch [1/1], Avg D_loss: 0.5976, Avg G_loss: 3.9409


In [None]:
# Build a fresh model wrapper, restore the checkpoint written after epoch 1,
# and blend one cut-out onto a new background.
border_gan = BorderGAN()
border_gan.load_models(path='border_gan_epoch_1.pth')
border_gan.blend_image(
    cut_object_path='cut_objects/Tardigrade_01_0002.png',
    new_background_path='new_backgrounds/Bandwurm_01_0001.png',
    output_path='result.png',
)

Blended image saved to result.png
