In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import glob
import numpy as np
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [2]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Enable the cuDNN autotuner so convolution algorithms are benchmarked and the
# fastest one is reused.
torch.backends.cudnn.benchmark = True

In [3]:
# Data Loading
class ImageDataset(Dataset):
    """Unpaired photo / Monet image dataset.

    Each item is a ``(photo, monet)`` tuple. The dataset length follows the
    photo list; Monet images are cycled with a modulo, so the Monet list may be
    shorter than the photo list.

    Args:
        photo_dir: Directory searched (non-recursively) for ``*.jpg`` photos.
        monet_dir: Directory searched (non-recursively) for ``*.jpg`` paintings.
        transform: Optional callable applied to both images of a pair.

    Raises:
        ValueError: If photos were found but ``monet_dir`` contains no
            ``*.jpg`` file, because no pair could ever be formed.
    """

    def __init__(self, photo_dir, monet_dir, transform=None):
        # Sorted so that index -> file mapping is stable between runs.
        self.photo_images = sorted(glob.glob(os.path.join(photo_dir, '*.jpg')))
        self.monet_images = sorted(glob.glob(os.path.join(monet_dir, '*.jpg')))
        self.transform = transform

        # Fail early with a clear message instead of a ZeroDivisionError from
        # ``idx % 0`` on the first __getitem__ call.
        if self.photo_images and not self.monet_images:
            raise ValueError(f"No '*.jpg' files found in monet_dir: {monet_dir!r}")

    def __len__(self):
        """Return the number of photos (one item per photo)."""
        return len(self.photo_images)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return the ``idx``-th photo with a Monet image chosen by modulo."""
        photo_img = self._load_rgb(self.photo_images[idx])
        monet_img = self._load_rgb(self.monet_images[idx % len(self.monet_images)])

        if self.transform:
            # Both images must go through the transform; previously the photo
            # result was bound to a misspelled name and the raw image returned.
            photo_img = self.transform(photo_img)
            monet_img = self.transform(monet_img)

        return photo_img, monet_img

# Preprocessing pipeline: resize to 256x256, convert to a tensor, then
# normalise every channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [4]:
# Create dataset
# Paths are built with os.path.join so they resolve on every OS; the previous
# backslash-separated literals only named these directories on Windows.
photo_ds_path = os.path.join('cleandata', 'augmented_content')
monet_ds_path = os.path.join('cleandata', 'augmented_monet')
dataset = ImageDataset(photo_ds_path, monet_ds_path, transform)

# NOTE(review): on Windows, DataLoader workers are spawned as fresh processes
# that must re-import the dataset class. A class defined in a notebook cell
# usually cannot be imported that way, which tends to hang or crash the first
# batch, so loading stays in-process there. Move ImageDataset into a .py module
# to use workers on Windows as well.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0 if os.name == 'nt' else 4,
)

In [5]:
class Generator(nn.Module):
    """Fully convolutional image-to-image network.

    Layout: a 7x7 stem, two stride-2 convolutions that downsample, two stride-2
    transposed convolutions that upsample again, and a 7x7 projection back to
    3 channels followed by ``tanh`` (so outputs lie in ``[-1, 1]``). For inputs
    whose height and width are multiples of 4 the output has the input's shape.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the network is applied.
        encoder = [
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3),
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(),
        ]
        decoder = [
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.LeakyReLU(),
        ]
        head = [
            nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        ]

        self.main = nn.Sequential(*encoder, *decoder, *head)

    def forward(self, x):
        """Map a ``(N, 3, H, W)`` batch to a translated ``(N, 3, H', W')`` batch."""
        translated = self.main(x)
        return translated

In [6]:
class Discriminator(nn.Module):
    """Convolutional critic that emits a spatial map of raw scores.

    Three 4x4 stride-2 convolutions (each followed by LeakyReLU) halve the
    resolution, then a 4x4 stride-1 convolution reduces to a single channel.
    No sigmoid is applied, so the output values are unbounded.
    """

    def __init__(self):
        super().__init__()

        channel_steps = (3, 64, 128, 256)
        layers = []
        # Downsampling stack: 3 -> 64 -> 128 -> 256 channels.
        for in_ch, out_ch in zip(channel_steps[:-1], channel_steps[1:]):
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU())
        # Final projection to a one-channel score map.
        layers.append(nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1))

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map for a ``(N, 3, H, W)`` image batch."""
        scores = self.main(x)
        return scores

In [7]:
# Initialize Models
# One generator per translation direction, one discriminator per image domain.
# Construction order (generators first, then discriminators) is kept so that
# seeded weight initialisation is unchanged.
monet_generator, photo_generator = (
    Generator().to(device),
    Generator().to(device),
)
monet_discriminator, photo_discriminator = (
    Discriminator().to(device),
    Discriminator().to(device),
)

In [8]:
# Loss Functions and Optimizers
# Mean-squared error between discriminator scores and the 1/0 targets.
criterion = nn.MSELoss()

# A single Adam instance updates both generators together.
optimizer_G = torch.optim.Adam(
    [*monet_generator.parameters(), *photo_generator.parameters()],
    lr=2e-4,
    betas=(0.5, 0.999),
)

# Likewise, one Adam instance covers both discriminators.
optimizer_D = torch.optim.Adam(
    [*monet_discriminator.parameters(), *photo_discriminator.parameters()],
    lr=2e-4,
    betas=(0.5, 0.999),
)

In [9]:
import torch
from tqdm import tqdm
import os

# Training Loop
def train(dataloader, epochs):
    """Adversarially train both generators and both discriminators.

    Uses the module-level ``device``, ``monet_generator``, ``photo_generator``,
    ``monet_discriminator``, ``photo_discriminator``, ``criterion``,
    ``optimizer_G`` and ``optimizer_D``.

    Args:
        dataloader: Iterable yielding ``(real_photo, real_monet)`` tensor
            batches.
        epochs: Number of passes over ``dataloader``.

    Raises:
        ValueError: If an epoch yields no batches, since there would be no
            loss to report or compare.

    Side effects:
        Creates the ``models`` directory. Whenever the epoch-mean generator
        loss improves, writes both generators to ``best_*_generator_medium.pth``;
        after the last epoch writes ``final_*_generator_medium.pth``.

    NOTE(review): only adversarial terms are optimised here; no
    cycle-consistency or identity loss ties the two generators together.
    Confirm that this is intended.
    """
    best_loss_G = float('inf')  # Track best (epoch-mean) generator loss
    save_dir = "models"
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        print(f"\nEpoch [{epoch+1}/{epochs}]")
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=True)

        # Running totals so the epoch is judged on its mean loss rather than
        # on whichever single batch happened to come last.
        sum_loss_G = 0.0
        sum_loss_D = 0.0
        num_batches = 0

        for real_photo, real_monet in progress_bar:
            real_photo, real_monet = real_photo.to(device), real_monet.to(device)

            # Train generators: push discriminator scores on fakes towards 1.
            fake_monet = monet_generator(real_photo)
            fake_photo = photo_generator(real_monet)

            # Each discriminator is evaluated once; the target tensor is built
            # from that same prediction instead of a second forward pass.
            pred_fake_monet = monet_discriminator(fake_monet)
            pred_fake_photo = photo_discriminator(fake_photo)
            loss_G = criterion(pred_fake_monet, torch.ones_like(pred_fake_monet)) + \
                     criterion(pred_fake_photo, torch.ones_like(pred_fake_photo))

            optimizer_G.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            # Train discriminators: real -> 1, fake -> 0. Fakes are detached so
            # no gradient flows back into the generators.
            pred_real_monet = monet_discriminator(real_monet)
            pred_fake_monet = monet_discriminator(fake_monet.detach())
            loss_D_monet = criterion(pred_real_monet, torch.ones_like(pred_real_monet)) + \
                           criterion(pred_fake_monet, torch.zeros_like(pred_fake_monet))

            pred_real_photo = photo_discriminator(real_photo)
            pred_fake_photo = photo_discriminator(fake_photo.detach())
            loss_D_photo = criterion(pred_real_photo, torch.ones_like(pred_real_photo)) + \
                           criterion(pred_fake_photo, torch.zeros_like(pred_fake_photo))

            loss_D = loss_D_monet + loss_D_photo
            # zero_grad also clears the discriminator gradients accumulated
            # during the generator backward pass above.
            optimizer_D.zero_grad()
            loss_D.backward()
            optimizer_D.step()

            batch_loss_G = loss_G.item()
            batch_loss_D = loss_D.item()
            sum_loss_G += batch_loss_G
            sum_loss_D += batch_loss_D
            num_batches += 1

            # set_postfix refreshes the bar itself, so no explicit refresh().
            progress_bar.set_postfix({
                'Loss_G': f"{batch_loss_G:.4f}",
                'Loss_D': f"{batch_loss_D:.4f}"
            })

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        loss_G_val = sum_loss_G / num_batches
        loss_D_val = sum_loss_D / num_batches

        # Save best-performing generator.
        # NOTE: torch.save on a module pickles the whole object, so the
        # Generator class must be importable when the file is loaded.
        if loss_G_val < best_loss_G:
            best_loss_G = loss_G_val
            torch.save(monet_generator, os.path.join(save_dir, "best_monet_generator_medium.pth"))
            torch.save(photo_generator, os.path.join(save_dir, "best_photo_generator_medium.pth"))
            print(f"Best model saved at Epoch {epoch+1} with Loss_G: {best_loss_G:.4f}")

        print(f"Epoch [{epoch+1}/{epochs}] completed: Loss_G={loss_G_val:.4f}, Loss_D={loss_D_val:.4f}")

    # Save final models after training
    torch.save(monet_generator, os.path.join(save_dir, "final_monet_generator_medium.pth"))
    torch.save(photo_generator, os.path.join(save_dir, "final_photo_generator_medium.pth"))
    print("\nTraining complete! Final models saved.")


In [None]:
# Train model
# Run five epochs over the photo/Monet dataloader.
train(dataloader=dataloader, epochs=5)


Epoch [1/5]


Epoch 1:   0%|          | 0/1634 [00:00<?, ?it/s]