In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset, TensorDataset
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from tqdm import tqdm

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Ask cuDNN to benchmark its convolution algorithms and keep the fastest one.
torch.backends.cudnn.benchmark = True

In [3]:
def load_dataset(folder, transform):
    """Load every ``.jpg``/``.png`` image in ``folder`` into one in-memory TensorDataset.

    Args:
        folder: Directory scanned (non-recursively) for image files. The
            extension check is case-insensitive.
        transform: Callable applied to each RGB PIL image. Its results are
            stacked with ``torch.stack``, so they must all share one shape.

    Returns:
        A ``TensorDataset`` wrapping a single stacked tensor, one entry per
        image, ordered by sorted file path.

    Raises:
        RuntimeError: If ``folder`` contains no matching image files.
    """
    # os.listdir returns entries in arbitrary order; sort for reproducible indexing.
    image_paths = sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(('.jpg', '.png'))
    )
    if not image_paths:
        # torch.stack([]) would raise a RuntimeError that does not mention the folder.
        raise RuntimeError(f"No .jpg/.png images found in '{folder}'")

    tensors = []
    for path in image_paths:
        # The context manager closes the file handle as soon as the pixels are converted.
        with Image.open(path) as img:
            tensors.append(transform(img.convert('RGB')))
    return TensorDataset(torch.stack(tensors))



In [4]:
# Preprocessing shared by the content and style datasets:
# resize every image to 256x256, then convert the PIL image to a tensor.
_preprocessing_steps = [
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)


In [5]:

class StyleDataset(Dataset):
    """Images from one folder, each paired with the same fixed style vector."""

    def __init__(self, image_folder, style_vector, transform=None):
        """Index the ``.jpg``/``.png`` files of ``image_folder``.

        Args:
            image_folder: Directory scanned (non-recursively) for image files.
                The extension check is case-insensitive.
            style_vector: Sequence (or tensor) of numbers identifying the
                style; stored as a float tensor and returned with every item.
            transform: Optional callable applied to each RGB PIL image.
        """
        # Sorted so that index -> file is stable across runs and platforms.
        self.image_paths = sorted(
            os.path.join(image_folder, name)
            for name in os.listdir(image_folder)
            if name.lower().endswith(('.jpg', '.png'))
        )
        # as_tensor avoids the copy-construct warning when a tensor is passed in;
        # the clone keeps the dataset's copy independent of the caller's object.
        self.style_vector = torch.as_tensor(style_vector, dtype=torch.float).detach().clone()
        self.transform = transform

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, style_vector)`` for the file at position ``idx``."""
        # Close the file handle right after decoding instead of relying on GC.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self.style_vector

In [6]:

# Two-element style codes: position 0 marks Van Gogh, position 1 marks Monet.
vangogh_style_vector = [1, 0]
monet_style_vector = [0, 1]

# One dataset per painter, plus their concatenation.
vangogh_dataset = StyleDataset("cleandata/augmented_vangogh", vangogh_style_vector, transform)
monet_dataset = StyleDataset("cleandata/augmented_monet", monet_style_vector, transform)
style_dataset = ConcatDataset([vangogh_dataset, monet_dataset])

# Content images are loaded eagerly into a single TensorDataset.
content_dataset = load_dataset("cleandata/augmented_content", transform)

# Every loader uses the same batching options.
# NOTE(review): worker processes may fail for notebook-defined datasets on
# platforms that spawn rather than fork — confirm on the target machine.
_loader_options = {"batch_size": 4, "shuffle": True, "num_workers": 4}
content_loader = DataLoader(content_dataset, **_loader_options)
vangogh_loader = DataLoader(vangogh_dataset, **_loader_options)
monet_loader = DataLoader(monet_dataset, **_loader_options)
style_loader = DataLoader(style_dataset, **_loader_options)

print("✅ Dataset, ConcatDataset, and DataLoaders are ready!")

✅ Dataset, ConcatDataset, and DataLoaders are ready!


In [7]:
# #recent generator
# def build_generator():
#     layers = []
#     # Downsampling (More layers, larger filters)
#     layers += [
#         nn.Conv2d(5, 64, 7, 1, 3, bias=False),
#         nn.BatchNorm2d(64),
#         nn.ReLU(True),
#         nn.Conv2d(64, 128, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(128),
#         nn.ReLU(True),
#         nn.Conv2d(128, 256, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(256),
#         nn.ReLU(True),
#         nn.Conv2d(256, 512, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(512),
#         nn.ReLU(True),
#         nn.Conv2d(512, 1024, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(1024),
#         nn.ReLU(True)
#     ]
#     # Residual blocks
#     for _ in range(6):
#         layers += [
#             nn.Conv2d(1024, 1024, 3, 1, 1, bias=False),
#             nn.BatchNorm2d(1024),
#             nn.ReLU(True),
#             nn.Conv2d(1024, 1024, 3, 1, 1, bias=False),
#             nn.BatchNorm2d(1024)
#         ]
#     # Upsampling
#     layers += [
#         nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(512),
#         nn.ReLU(True),
#         nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(256),
#         nn.ReLU(True),
#         nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(128),
#         nn.ReLU(True),
#         nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(64),
#         nn.ReLU(True),
#         nn.Conv2d(64, 3, 7, 1, 3, bias=False),
#         nn.Tanh()
#     ]
#     return nn.Sequential(*layers).to(device)

In [8]:
class _ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions whose output is added back to the input.

    Defined at module level (not nested inside ``build_generator``) so that a
    generator saved whole with ``torch.save(generator, path)`` can be pickled.
    """

    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=0, bias=False),
            nn.InstanceNorm2d(channels),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=0, bias=False),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        # The skip connection is what makes the block residual.
        return x + self.block(x)


def build_generator(input_channels=5, output_channels=3, num_residual_blocks=9):
    """Build the encoder / residual / decoder generator as an ``nn.Sequential``.

    Args:
        input_channels: Channels of the generator input. Defaults to 5 because
            the training and inference code concatenate a 2-channel style code
            onto the 3-channel RGB image before calling the generator.
        output_channels: Channels of the generated image (3 for RGB).
        num_residual_blocks: Number of residual blocks at the bottleneck.

    Returns:
        An ``nn.Sequential`` (not yet moved to any device) whose output has the
        same height and width as its input when both are divisible by 4, with
        values in ``[-1, 1]`` from the final ``Tanh``.
    """
    layers = []

    # Stem: 7x7 convolution at full resolution.
    layers += [
        nn.ReflectionPad2d(3),
        nn.Conv2d(input_channels, 64, kernel_size=7, stride=1, padding=0, bias=False),
        nn.InstanceNorm2d(64),
        nn.ReLU(True)
    ]

    # Two stride-2 convolutions: resolution / 4, channels 64 -> 256.
    layers += [
        nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
        nn.InstanceNorm2d(128),
        nn.ReLU(True),
        nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
        nn.InstanceNorm2d(256),
        nn.ReLU(True),
    ]

    # Bottleneck of residual blocks at 256 channels.
    layers += [_ResidualBlock(256) for _ in range(num_residual_blocks)]

    # Two stride-2 transposed convolutions restore the input resolution.
    layers += [
        nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
        nn.InstanceNorm2d(128),
        nn.ReLU(True),
        nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
        nn.InstanceNorm2d(64),
        nn.ReLU(True),
    ]

    # Head: 7x7 convolution to the output channels, squashed by Tanh.
    layers += [
        nn.ReflectionPad2d(3),
        nn.Conv2d(64, output_channels, kernel_size=7, stride=1, padding=0, bias=False),
        nn.Tanh()
    ]

    return nn.Sequential(*layers)

In [9]:

# def build_discriminator():
#     layers = [
#         nn.Conv2d(3, 64, 4, 2, 1, bias=False),
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(64, 128, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(128),
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(128, 256, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(256),
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(256, 1, 4, 1, 1, bias=False),
#         nn.Sigmoid()
#     ]
#     return nn.Sequential(*layers).to(device)

In [10]:
# import torch.nn as nn

# class ResidualBlock(nn.Module):
#     def __init__(self, in_channels):
#         super(ResidualBlock, self).__init__()
#         self.conv = nn.Sequential(
#             nn.Conv2d(in_channels, in_channels, 3, 1, 1, bias=False),
#             nn.BatchNorm2d(in_channels),
#             nn.LeakyReLU(0.2, True),
#             nn.Conv2d(in_channels, in_channels, 3, 1, 1, bias=False),
#             nn.BatchNorm2d(in_channels)
#         )

#     def forward(self, x):
#         return x + self.conv(x)

# def build_discriminator():
#     layers = [
#         nn.Conv2d(3, 64, 4, 2, 1, bias=False),  # 3x256x256 -> 64x128x128
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(64, 128, 4, 2, 1, bias=False),  # 64x128x128 -> 128x64x64
#         nn.BatchNorm2d(128),
#         nn.LeakyReLU(0.2, True),
#         ResidualBlock(128),  # Add residual block
#         nn.Conv2d(128, 256, 4, 2, 1, bias=False),  # 128x64x64 -> 256x32x32
#         nn.BatchNorm2d(256),
#         nn.LeakyReLU(0.2, True),
#         ResidualBlock(256),  # Add residual block
#         nn.Conv2d(256, 512, 4, 2, 1, bias=False),  # 256x32x32 -> 512x16x16
#         nn.BatchNorm2d(512),
#         nn.LeakyReLU(0.2, True),
#         ResidualBlock(512),  # Add residual block
#         nn.Conv2d(512, 1024, 4, 2, 1, bias=False),  # 512x16x16 -> 1024x8x8
#         nn.BatchNorm2d(1024),
#         nn.LeakyReLU(0.2, True),
#         ResidualBlock(1024),  # Add residual block
#         nn.Conv2d(1024, 1, 4, 1, 1, bias=False),  # 1024x8x8 -> 1x7x7
#         nn.Sigmoid()  # Output: Probability that the image is real (0-1)
#     ]
#     return nn.Sequential(*layers).to(device)


In [11]:

# def build_discriminator():
#     layers = [
#         nn.Conv2d(3, 64, 4, 2, 1, bias=False),
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(64, 128, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(128),
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(128, 128, 3, 1, 1, bias=False),
#         nn.BatchNorm2d(128),
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(128, 256, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(256),
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(256, 256, 3, 1, 1, bias=False),
#         nn.BatchNorm2d(256),
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(256, 512, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(512),
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(512, 512, 3, 1, 1, bias=False),
#         nn.BatchNorm2d(512),
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(512, 1024, 4, 2, 1, bias=False),
#         nn.BatchNorm2d(1024),
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(1024, 1024, 3, 1, 1, bias=False),
#         nn.BatchNorm2d(1024),
#         nn.LeakyReLU(0.2, True),
#         nn.Conv2d(1024, 1, 4, 1, 1, bias=False),
#         nn.Sigmoid()
#     ]
#     return nn.Sequential(*layers).to(device)



def build_discriminator():
    """Build a patch-level discriminator as a flat ``nn.Sequential``.

    Four stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels)
    each halve the resolution; all but the first are followed by
    ``InstanceNorm2d``, and all four by ``LeakyReLU(0.2)``. A final unpadded
    4x4 convolution maps to one channel of raw, un-squashed scores.

    For a 256x256 input the score map is 13x13 (256 / 16 = 16, then
    16 - 4 + 1 = 13). The model is not moved to any device here.
    """
    stage_widths = (64, 128, 256, 512)
    modules = []
    in_ch = 3
    for stage, out_ch in enumerate(stage_widths):
        modules.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False))
        if stage > 0:
            # The first stage has no normalization layer.
            modules.append(nn.InstanceNorm2d(out_ch))
        modules.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
        in_ch = out_ch

    # Patch-level decision: one score per receptive-field patch.
    modules.append(nn.Conv2d(in_ch, 1, kernel_size=4, stride=1, padding=0, bias=False))
    return nn.Sequential(*modules)



# TRAIN STEP

### Overview

Generator training: making the fake images realistic

Discriminator training: distinguishing real vs. fake images

Cycle consistency: ensuring the original images can be reconstructed

Identity loss: preserving input images when necessary (listed for completeness; the `train_step` below does not compute an identity term)

In [12]:
def train_step(generator, discriminator, content_imgs, style_imgs, style_vector, opt_gen, opt_disc, criterion, cycle_criterion):
    """Run one discriminator update followed by one generator update.

    Args:
        generator: Model called on the channel-wise concatenation of an image
            batch and the spatially expanded style code.
        discriminator: Model scoring real style images and generated images.
        content_imgs: 4-D batch of content images (unpacked as N, C, H, W).
        style_imgs: Batch of real images of the target style.
        style_vector: Per-sample style codes, viewable as ``(N, 2, 1, 1)``.
        opt_gen, opt_disc: Optimizers for the generator and the discriminator.
        criterion: Adversarial loss comparing predictions with all-ones or
            all-zeros targets.
        cycle_criterion: Loss between the content batch and its reconstruction.

    Returns:
        ``(discriminator_loss, generator_total_loss)`` as Python floats; the
        generator value is the adversarial term plus 10x the cycle term.
    """
    # All tensors go to the module-level `device`.
    content_imgs = content_imgs.to(device)
    style_imgs = style_imgs.to(device)
    style_vector = style_vector.to(device)

    # Turn each 2-element style code into two constant H x W planes.
    n_samples, _, height, width = content_imgs.size()
    style_planes = style_vector.view(n_samples, 2, 1, 1).expand(-1, -1, height, width)

    fake_style = generator(torch.cat((content_imgs, style_planes), dim=1))

    # ---- Discriminator update: real -> 1, fake -> 0 ----
    real_pred = discriminator(style_imgs)
    # detach() keeps this backward pass from reaching the generator.
    fake_pred = discriminator(fake_style.detach())
    real_term = criterion(real_pred, torch.ones_like(real_pred))
    fake_term = criterion(fake_pred, torch.zeros_like(fake_pred))
    loss_disc = real_term + fake_term

    opt_disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    # ---- Generator update: fool the just-updated discriminator ----
    gen_pred = discriminator(fake_style)
    loss_gen = criterion(gen_pred, torch.ones_like(gen_pred))

    # Feed the generated image back through the generator with the same style
    # planes and compare the result with the original content batch.
    reconstructed = generator(torch.cat((fake_style, style_planes), dim=1))
    cycle_loss = cycle_criterion(content_imgs, reconstructed)
    total_loss = loss_gen + 10 * cycle_loss

    opt_gen.zero_grad()
    total_loss.backward()
    opt_gen.step()

    return loss_disc.item(), total_loss.item()

In [13]:
def _next_style_batch(style_iter, style_loader):
    """Return ``(batch, iterator)``, restarting ``style_loader`` once if ``style_iter`` is exhausted."""
    batch = next(style_iter, None)
    if batch is None:
        style_iter = iter(style_loader)
        batch = next(style_iter, None)
        if batch is None:
            raise RuntimeError("Style DataLoader yielded no batches")
    return batch, style_iter


def train_model(generator, discriminator, content_loader, vangogh_loader, monet_loader,
                train_step, opt_gen, opt_disc, criterion, cycle_criterion, device, epochs=20):
    """Train the generator/discriminator pair and checkpoint the generator.

    Each content batch is paired, with probability 0.5 each, with the next
    Van Gogh batch (style code ``[1, 0]``) or the next Monet batch (``[0, 1]``).
    A style loader that runs out mid-epoch is restarted on its own, so no
    content batch is skipped.

    Args:
        generator, discriminator: Models handed to ``train_step``.
        content_loader: Loader whose batches hold the content tensor at index 0.
        vangogh_loader, monet_loader: Loaders whose batches hold the style
            images at index 0.
        train_step: Callable with the signature ``(generator, discriminator,
            content_batch, style_batch, style_vector, opt_gen, opt_disc,
            criterion, cycle_criterion)`` returning ``(disc_loss, gen_loss)``.
        opt_gen, opt_disc: Optimizers handed to ``train_step``.
        criterion, cycle_criterion: Loss functions handed to ``train_step``.
        device: Device the batches and style codes are moved to.
        epochs: Number of passes over ``content_loader``.

    Raises:
        RuntimeError: If a style loader yields no batches at all.

    Side effects:
        Creates ``models/`` if needed. Writes
        ``models/best_generator_epoch_150.pth`` whenever the epoch-average
        generator loss improves, and ``models/final_generator.pth`` at the end.
        Both files hold the whole pickled model object, not a state dict.
    """
    best_model_path = "models/best_generator_epoch_150.pth"
    final_model_path = "models/final_generator.pth"
    # torch.save does not create missing parent directories.
    os.makedirs("models", exist_ok=True)

    best_gen_loss = float('inf')  # Lowest epoch-average generator loss so far

    for epoch in range(epochs):
        epoch_gen_loss = 0.0
        epoch_disc_loss = 0.0
        batch_count = 0

        # Fresh style iterators for every epoch.
        vangogh_iter = iter(vangogh_loader)
        monet_iter = iter(monet_loader)

        with tqdm(total=len(content_loader), desc=f"Epoch {epoch+1}/{epochs}",
                  unit="batch", dynamic_ncols=True) as pbar:
            for content_batch in content_loader:
                # TensorDataset batches are 1-element sequences; take the tensor.
                content_batch = content_batch[0].to(device)

                # Pick the target style for this batch at random.
                if torch.rand(1).item() < 0.5:
                    style_batch, vangogh_iter = _next_style_batch(vangogh_iter, vangogh_loader)
                    style_code = [1, 0]
                else:
                    style_batch, monet_iter = _next_style_batch(monet_iter, monet_loader)
                    style_code = [0, 1]

                style_vector = torch.tensor([style_code] * content_batch.size(0),
                                            dtype=torch.float, device=device)
                style_batch = style_batch[0].to(device)

                # Perform a training step
                loss_disc, loss_gen = train_step(generator, discriminator, content_batch, style_batch,
                                                 style_vector, opt_gen, opt_disc, criterion, cycle_criterion)

                epoch_gen_loss += loss_gen
                epoch_disc_loss += loss_disc
                batch_count += 1

                pbar.set_postfix({
                    "Gen Loss": f"{epoch_gen_loss / batch_count:.4f}",
                    "Disc Loss": f"{epoch_disc_loss / batch_count:.4f}",
                    "Batches": batch_count
                })
                pbar.update(1)

        if batch_count == 0:
            # An empty content loader would otherwise divide by zero below.
            print(f"⚠️ Epoch {epoch+1}/{epochs} - no batches were processed; skipping.")
            continue

        avg_gen_loss = epoch_gen_loss / batch_count
        avg_disc_loss = epoch_disc_loss / batch_count
        print(f"✅ Epoch {epoch+1}/{epochs} - Generator Loss: {avg_gen_loss:.4f}, Discriminator Loss: {avg_disc_loss:.4f}\n")

        if avg_gen_loss < best_gen_loss:
            best_gen_loss = avg_gen_loss
            # Save the entire model (architecture + weights). Loading it again
            # unpickles arbitrary objects, so only load files you trust.
            torch.save(generator, best_model_path)
            print(f"📁 Model improved! Saved as {best_model_path}")

    # Save the final model (entire pickled object, as above).
    torch.save(generator, final_model_path)
    print(f"🎯 Training complete. Final model saved as {final_model_path}")


In [None]:
# Initialize the generator and discriminator models on the training device.
# train_step moves every batch to `device`, so the models must live there too;
# leaving them on the CPU fails with a device mismatch on a CUDA machine.
generator = build_generator().to(device)
discriminator = build_discriminator().to(device)

# Initialize the optimizers
opt_gen = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
opt_disc = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Initialize the loss functions
criterion = nn.MSELoss()  # Mean squared error for the adversarial terms
cycle_criterion = nn.L1Loss()  # L1 loss for cycle consistency

# Call the training function
train_model(generator, discriminator, content_loader, vangogh_loader, monet_loader,
            train_step, opt_gen, opt_disc, criterion, cycle_criterion, device, epochs=150)


In [None]:
def generate_styled_image(model_path, content_img, style_type):
    """Stylize one PIL image with a generator saved whole by ``torch.save``.

    Args:
        model_path: Path to a file holding the entire pickled generator.
        content_img: PIL image to stylize; it is converted to RGB first.
        style_type: ``'vangogh'`` selects style code ``[1, 0]``; any other
            value selects ``[0, 1]`` (Monet).

    Returns:
        The generated image as an 8-bit RGB PIL image.
    """
    # Force three channels so RGBA / grayscale inputs match the generator input.
    content_tensor = transform(content_img.convert('RGB')).unsqueeze(0).to(device)

    # Load the entire generator model.
    # SECURITY: this unpickles arbitrary objects, so only load files you trust.
    # weights_only=False is explicit because recent PyTorch releases default to
    # True, which rejects whole-model pickles. map_location lets a model saved
    # on a GPU load on a CPU-only machine.
    generator = torch.load(model_path, map_location=device, weights_only=False)
    generator.eval()

    # Expand the style code to (1, 2, H, W); torch.cat requires matching
    # spatial dimensions on every non-concatenated axis.
    style_code = [1, 0] if style_type == 'vangogh' else [0, 1]
    _, _, height, width = content_tensor.shape
    style_planes = (
        torch.tensor(style_code, dtype=torch.float, device=device)
        .view(1, 2, 1, 1)
        .expand(-1, -1, height, width)
    )

    with torch.no_grad():
        fake_image = generator(torch.cat([content_tensor, style_planes], 1))

    # (1, C, H, W) -> (H, W, C); squeeze only the batch axis.
    fake_image = fake_image.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    # Scale to 0-255 and clip anything outside that range.
    fake_image = np.clip(fake_image * 255, 0, 255).astype(np.uint8)

    return Image.fromarray(fake_image)


In [None]:
from PIL import Image
import torch

# Path of the checkpoint written by train_model
# (use "models/final_generator.pth" for the last-epoch model instead).
model_path = "models/best_generator_epoch_150.pth"

# Load a sample content image (ensure it's a PIL image).
# os.path.join keeps the path portable and avoids backslash escapes such as
# "\201" being interpreted inside a non-raw string literal.
content_img_path = os.path.join("data", "ContentImage", "2015-04-29 16_19_50.jpg")
content_img = Image.open(content_img_path)

# Specify the style you want ('vangogh' or 'monet')
style_type = 'vangogh'  # or 'monet'

# Call the generate_styled_image function to get the styled image
styled_image = generate_styled_image(model_path, content_img, style_type)

# Show the generated image
styled_image.show()

# Optionally, save the result if needed
styled_image.save("generated_image.jpg")


FileNotFoundError: [Errno 2] No such file or directory: 'data\\ContentImage\x815-04-29 16_19_50.jpg'