In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
from tqdm import tqdm


In [19]:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [20]:
device

device(type='cuda')

# IMPORT DATASET

This the load datset function

In [21]:
def load_dataset(folder, transform):
    images = [os.path.join(folder, f) for f in os.listdir(folder) if f.endswith(('.jpg', '.png'))]
    return [transform(Image.open(img).convert('RGB')) for img in images]

Tranform the dataset for reducing the training time by reducing the size of the image

In [22]:
# Prepare datasets
transform = transforms.Compose([
    transforms.Resize((200, 200)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,)*3, (0.5,)*3)
])

Loading the datset to vvariable

In [23]:
from torch.utils.data import Dataset

class StyleDataset(Dataset):
    def __init__(self, image_folder, style_vector, transform=None):
        self.image_paths = [os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.endswith(('.jpg', '.png'))]
        self.style_vector = style_vector
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self.style_vector

In [24]:
# Define the style vectors
vangogh_style_vector = torch.tensor([1, 0], dtype=torch.float)
monet_style_vector = torch.tensor([0, 1], dtype=torch.float)

# Create datasets
vangogh_dataset = StyleDataset("cleandata/augmented_vangogh", vangogh_style_vector, transform)
monet_dataset = StyleDataset("cleandata/augmented_monet", monet_style_vector, transform)

# Combine datasets
style_dataset = torch.utils.data.ConcatDataset([vangogh_dataset, monet_dataset])

In [25]:
content_images = load_dataset("cleandata/augmented_content", transform)
# Load datasets
vangogh_images = load_dataset("cleandata/augmented_vangogh", transform)
monet_images = load_dataset("cleandata/augmented_monet", transform)



# vangogh_images = [(img, torch.tensor([1, 0])) for img in load_dataset("cleandata/augmented_vangogh", transform)]
# monet_images = [(img, torch.tensor([0, 1])) for img in load_dataset("cleandata/augmented_monet", transform)]
# style_images = vangogh_images + monet_images

Converting the images into batches

In [26]:
content_loader = DataLoader(content_images, batch_size=32, shuffle=True)
# style_loader = DataLoader(style_dataset, batch_size=32, shuffle=True)
# Create DataLoaders
vangogh_loader = DataLoader(vangogh_images, batch_size=32, shuffle=True)
monet_loader = DataLoader(monet_images, batch_size=32, shuffle=True)

# BUILDING THE GENERATOR


In [27]:
def build_generator():
    layers = []
    # Downsampling
    layers += [
        nn.Conv2d(5, 64, 7, 1, 3, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(True),
        nn.Conv2d(64, 128, 3, 2, 1, bias=False),
        nn.BatchNorm2d(128),
        nn.ReLU(True),
        nn.Conv2d(128, 256, 3, 2, 1, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(True)
    ]
    # Residual blocks
    for _ in range(4):
        layers += [
            nn.Conv2d(256, 256, 3, 1, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.Conv2d(256, 256, 3, 1, 1, bias=False),
            nn.BatchNorm2d(256)
        ]
    # Upsampling
    layers += [
        nn.ConvTranspose2d(256, 128, 3, 2, 1, 1, bias=False),
        nn.BatchNorm2d(128),
        nn.ReLU(True),
        nn.ConvTranspose2d(128, 64, 3, 2, 1, 1, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(True),
        nn.Conv2d(64, 3, 7, 1, 3, bias=False),
        nn.Tanh()
    ]
    return nn.Sequential(*layers).to(device)

# BUILDING DESCRIMINATOR

In [28]:
def build_discriminator():
    layers = [
        nn.Conv2d(3, 64, 4, 2, 1, bias=False),
        nn.LeakyReLU(0.2, True),
        nn.Conv2d(64, 128, 4, 2, 1, bias=False),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2, True),
        nn.Conv2d(128, 256, 4, 2, 1, bias=False),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2, True),
        nn.Conv2d(256, 1, 4, 1, 1, bias=False),
        nn.Sigmoid()
    ]
    return nn.Sequential(*layers).to(device)


# TRAINING FUNCTION

In [29]:
def train_step(generator, discriminator, content_imgs, style_imgs, style_vector, opt_gen, opt_disc, criterion, cycle_criterion):
    
    
    content_imgs = content_imgs.to(device)
    style_imgs = style_imgs.to(device)
    style_vector = style_vector.to(device)

    batch_size, _, h, w = content_imgs.size()
    style_vector = style_vector.view(batch_size, 2, 1, 1).expand(-1, -1, h, w)
    combined_input = torch.cat([content_imgs, style_vector], 1)

    fake_style = generator(combined_input)
    real_pred = discriminator(style_imgs)
    fake_pred = discriminator(fake_style.detach())

    loss_disc = criterion(real_pred, torch.ones_like(real_pred)) + criterion(fake_pred, torch.zeros_like(fake_pred))

    opt_disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    fake_pred_for_gen = discriminator(fake_style)
    loss_gen = criterion(fake_pred_for_gen, torch.ones_like(fake_pred_for_gen))

    cycle_loss = cycle_criterion(content_imgs, generator(torch.cat([fake_style, style_vector], 1)))
    total_loss = loss_gen + 10 * cycle_loss

    opt_gen.zero_grad()
    total_loss.backward()
    opt_gen.step()

    return loss_disc.item(), total_loss.item()

In [30]:
# Initialize models and optimizers
generator = build_generator()
discriminator = build_discriminator()
opt_gen = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_disc = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion = nn.BCELoss()
cycle_criterion = nn.L1Loss()

# TRAINING

In [31]:
# import torch
# from tqdm import tqdm

# # Training loop
# best_loss = float('inf')
# stagnant_epochs = 0

# for epoch in range(20):
#     epoch_gen_loss = 0
#     epoch_disc_loss = 0
#     batch_count = 0

#     with tqdm(total=len(content_loader), desc=f"Epoch {epoch+1}", unit="batch") as pbar:
#         for content_batch, style_batch in zip(content_loader, style_loader):
#             style_vector = torch.tensor([[1, 0]] * content_batch.size(0), dtype=torch.float, device=device)
#             loss_disc, loss_gen = train_step(generator, discriminator, content_batch, style_batch, style_vector, opt_gen, opt_disc, criterion, cycle_criterion)

#             epoch_gen_loss += loss_gen
#             epoch_disc_loss += loss_disc
#             batch_count += 1

#             pbar.set_postfix({
#                 "Gen Loss": f"{loss_gen:.4f}",
#                 "Disc Loss": f"{loss_disc:.4f}",
#                 "Batch": batch_count
#             })
#             pbar.update(1)

#     avg_gen_loss = epoch_gen_loss / batch_count
#     avg_disc_loss = epoch_disc_loss / batch_count

#     print(f"Epoch {epoch+1} completed. Generator Loss: {avg_gen_loss:.4f}, Discriminator Loss: {avg_disc_loss:.4f}")

#     if avg_gen_loss < best_loss:
#         best_loss = avg_gen_loss
#         stagnant_epochs = 0
#         torch.save(generator.state_dict(), "best_model.h5")
#         print("Saved new best model.")
#     else:
#         stagnant_epochs += 1

#     if stagnant_epochs >= 8:
#         print("Early stopping can be triggered due to no improvement in generator loss for 5 consecutive epochs.")
        

# torch.save(generator.state_dict(), "final_model.h5")
# print("Training complete. Final model saved as final_model.h5")


In [32]:
# import torch
# from tqdm import tqdm
# epochs=20
# best_gen_loss = float('inf')  # Track the best generator loss

# for epoch in range(epochs):
#     epoch_gen_loss = 0
#     epoch_disc_loss = 0
#     batch_count = 0

#     # Create iterators for style loaders
#     vangogh_iter = iter(vangogh_loader)
#     monet_iter = iter(monet_loader)

#     # Initialize a single progress bar for the entire epoch
#     pbar = tqdm(total=len(content_loader), desc=f"Epoch {epoch+1}/{epochs}", unit="batch", dynamic_ncols=True)

#     for content_batch in content_loader:
#         # Randomly select a style
#         if torch.rand(1).item() < 0.5:
#             style_batch = next(vangogh_iter, None)
#             style_vector = torch.tensor([[1, 0]] * content_batch.size(0), dtype=torch.float, device=device)
#         else:
#             style_batch = next(monet_iter, None)
#             style_vector = torch.tensor([[0, 1]] * content_batch.size(0), dtype=torch.float, device=device)

#         # Reinitialize iterators if a style batch is None
#         if style_batch is None:
#             vangogh_iter = iter(vangogh_loader)
#             monet_iter = iter(monet_loader)
#             continue

#         # Perform a training step
#         loss_disc, loss_gen = train_step(generator, discriminator, content_batch, style_batch, style_vector, opt_gen, opt_disc, criterion, cycle_criterion)

#         epoch_gen_loss += loss_gen
#         epoch_disc_loss += loss_disc
#         batch_count += 1

#         # Update the progress bar with average losses
#         pbar.set_postfix({
#             "Gen Loss": f"{epoch_gen_loss / batch_count:.4f}",
#             "Disc Loss": f"{epoch_disc_loss / batch_count:.4f}",
#             "Batches": batch_count
#         })
#         pbar.update(1)

#     pbar.close()

#     # Calculate average losses for the epoch
#     avg_gen_loss = epoch_gen_loss / batch_count
#     avg_disc_loss = epoch_disc_loss / batch_count
#     print(f"✅ Epoch {epoch+1}/{epochs} - Generator Loss: {avg_gen_loss:.4f}, Discriminator Loss: {avg_disc_loss:.4f}\n")

#     # Save the model if the generator's performance improves
#     if avg_gen_loss < best_gen_loss:
#         best_gen_loss = avg_gen_loss
#         model_path = f"best_generator_epoch.h5"
#         torch.save(generator.state_dict(), model_path)
#         print(f"📁 Model improved! Saved as {model_path}")


In [33]:
import torch
from tqdm import tqdm

def train_model(generator, discriminator, content_loader, vangogh_loader, monet_loader, 
                train_step, opt_gen, opt_disc, criterion, cycle_criterion, device, epochs=20):
    best_gen_loss = float('inf')  # Track the best generator loss

    for epoch in range(epochs):
        epoch_gen_loss = 0
        epoch_disc_loss = 0
        batch_count = 0

        # Create iterators for style loaders
        vangogh_iter = iter(vangogh_loader)
        monet_iter = iter(monet_loader)

        # Initialize a single progress bar for the epoch
        pbar = tqdm(total=len(content_loader), desc=f"Epoch {epoch+1}/{epochs}", unit="batch", dynamic_ncols=True)

        for content_batch in content_loader:
            # Ensure images have 4 dimensions (batch_size, channels, height, width)
            if content_batch.dim() == 3:
                content_batch = content_batch.unsqueeze(1).to(device)  # Add channel dimension if missing
            else:
                content_batch = content_batch.to(device)

            # Randomly select a style
            if torch.rand(1).item() < 0.5:
                style_batch = next(vangogh_iter, None)
                style_vector = torch.tensor([[1, 0]] * content_batch.size(0), dtype=torch.float, device=device)
            else:
                style_batch = next(monet_iter, None)
                style_vector = torch.tensor([[0, 1]] * content_batch.size(0), dtype=torch.float, device=device)

            # Reinitialize iterators if a style batch is None
            if style_batch is None:
                vangogh_iter = iter(vangogh_loader)
                monet_iter = iter(monet_loader)
                continue

            # Ensure style images also have the correct dimensions
            if style_batch.dim() == 3:
                style_batch = style_batch.unsqueeze(1).to(device)
            else:
                style_batch = style_batch.to(device)

            # Perform a training step
            loss_disc, loss_gen = train_step(generator, discriminator, content_batch, style_batch,
                                             style_vector, opt_gen, opt_disc, criterion, cycle_criterion)

            epoch_gen_loss += loss_gen
            epoch_disc_loss += loss_disc
            batch_count += 1

            # Update the progress bar with average losses
            pbar.set_postfix({
                "Gen Loss": f"{epoch_gen_loss / batch_count:.4f}",
                "Disc Loss": f"{epoch_disc_loss / batch_count:.4f}",
                "Batches": batch_count
            })
            pbar.update(1)

        pbar.close()

        # Calculate average losses for the epoch
        avg_gen_loss = epoch_gen_loss / batch_count
        avg_disc_loss = epoch_disc_loss / batch_count
        print(f"✅ Epoch {epoch+1}/{epochs} - Generator Loss: {avg_gen_loss:.4f}, Discriminator Loss: {avg_disc_loss:.4f}\n")

        # Save the model if the generator's performance improves
        if avg_gen_loss < best_gen_loss:
            best_gen_loss = avg_gen_loss
            model_path = "best_generator_epoch.h5"
            torch.save(generator.state_dict(), model_path)
            print(f"📁 Model improved! Saved as {model_path}")

    # Save the final model after all epochs
    final_model_path = "final_generator.h5"
    #torch.save(generator.state_dict(), final_model_path)
    """Use this below to save the whole model
    """
    torch.save(generator, "best_generator_epoch.pth")
    print(f"🎯 Training complete. Final model saved as {final_model_path}")


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_model(
    generator=generator, 
    discriminator=discriminator, 
    content_loader=content_loader, 
    vangogh_loader=vangogh_loader, 
    monet_loader=monet_loader, 
    train_step=train_step, 
    opt_gen=opt_gen, 
    opt_disc=opt_disc, 
    criterion=criterion, 
    cycle_criterion=cycle_criterion, 
    device=device,
    epochs=50
)




In [3]:
import torch
from torchvision import transforms
from PIL import Image
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def generate_styled_image(model_path, image_path, style_type):
    # Load the entire generator model
    generator = torch.load(model_path, map_location=device)
    generator.eval()

    # Define image transformations
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # Load and preprocess the content image
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    # Prepare the style vector (ensure dimension match)
    style_vector = torch.tensor([[1, 0]] if style_type == "vangogh" else [[0, 1]], dtype=torch.float, device=device)
    style_vector = style_vector.unsqueeze(-1).unsqueeze(-1)  # Add spatial dims
    style_vector = style_vector.expand(-1, -1, 128, 128)  # Match image dimensions

    # Generate the styled image
    with torch.no_grad():
        combined_input = torch.cat([image, style_vector], dim=1)
        styled_image = generator(combined_input)

    # Convert the output image to a displayable format
    styled_image = styled_image.squeeze(0).cpu().detach()
    styled_image = (styled_image + 1) / 2  # Denormalize to [0, 1]

    # Save and display the styled image
    os.makedirs("output", exist_ok=True)
    save_path = f"output/styled_{style_type}.png"
    transforms.ToPILImage()(styled_image).save(save_path)

    print(f"✅ Styled image saved to: {save_path}")
    return Image.open(save_path)


# Call the function
model_path = "best_generator_epoch.pth"
image_path = r"data\ContentImage\2014-08-02 15_56_41.jpg"
style_type = "monat"

styled_image = generate_styled_image(model_path, image_path, style_type)
styled_image.show()


  generator = torch.load(model_path, map_location=device)


✅ Styled image saved to: output/styled_monat.png


# TESTING