In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.utils as vutils
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

# Custom Dataset

In [None]:
class CustomDataset(Dataset):
    """Label-free image dataset backed by a single flat directory.

    Every regular file directly inside ``image_dir`` whose name ends in
    ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) is one sample. The
    directory is not searched recursively.

    Args:
        image_dir: Directory containing the image files.
        transform: Optional callable applied to the RGB ``PIL.Image``. When
            it is ``None`` the PIL image itself is returned.
        image_size: Two trailing dimensions of the all-zero placeholder
            tensor (shape ``(3, *image_size)``) returned when a file cannot
            be loaded or transformed.
    """

    def __init__(self, image_dir, transform=None, image_size=(128, 128)):
        self.image_dir = image_dir
        self.transform = transform
        self.image_size = image_size
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and platforms. The isfile
        # check drops directories that merely happen to be named like images.
        self.image_files = sorted(
            name for name in os.listdir(image_dir)
            if name.lower().endswith(('.png', '.jpg', '.jpeg'))
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        """Return the number of image files discovered at construction time."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load sample ``idx`` as RGB and apply ``transform`` if one was given.

        Loading is best-effort: any failure while opening or transforming the
        file is reported on stdout and an all-zero tensor of shape
        ``(3, *image_size)`` is returned instead, so one corrupt file does not
        abort a whole training run.

        Raises:
            IndexError: If ``idx`` is outside the range of discovered files.
        """
        # Resolved outside the try block so an out-of-range index surfaces as
        # a plain IndexError rather than being routed through the fallback.
        file_name = self.image_files[idx]
        try:
            img_path = os.path.join(self.image_dir, file_name)
            # The context manager releases the file handle as soon as the
            # pixel data has been copied into the converted image.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')

            if self.transform:
                image = self.transform(image)
            return image
        except Exception as e:
            print(f"Error loading image {file_name}: {str(e)}")
            return torch.zeros((3, *self.image_size))


# Generator

In [None]:

class Generator(nn.Module):
    """Transposed-convolution generator that turns latent vectors into images.

    The latent vector is treated as a 1x1 feature map with ``latent_dim``
    channels. The first block expands it to 4x4, and each of the five
    following stride-2 transposed convolutions doubles the resolution
    (4 -> 8 -> 16 -> 32 -> 64 -> 128). The final ``Tanh`` bounds the output
    to ``[-1, 1]``.

    Args:
        latent_dim: Number of elements in each latent vector.
        channels: Number of channels in the generated image.
    """

    def __init__(self, latent_dim, channels=3):
        super().__init__()
        self.latent_dim = latent_dim

        # Stem: 1x1 latent "pixel" -> 4x4 map with 1024 channels.
        layers = [
            nn.ConvTranspose2d(latent_dim, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
        ]

        # Upsampling stages: halve the channel count, double the resolution.
        widths = (1024, 512, 256, 128, 64)
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))

        # Head: last upsampling step straight to image channels, no BatchNorm.
        layers.append(nn.ConvTranspose2d(widths[-1], channels, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Reshape ``z`` to ``(-1, latent_dim, 1, 1)`` and run it through the network."""
        latent_map = z.view(-1, self.latent_dim, 1, 1)
        return self.model(latent_map)

# Discriminator


In [None]:


class Discriminator(nn.Module):
    """Convolutional real/fake classifier for image batches.

    Five stride-2 convolutions each halve the spatial resolution; a final
    4x4 convolution followed by ``Sigmoid`` produces scores in ``[0, 1]``.
    For 128x128 inputs the feature map is 4x4 before the last convolution,
    which yields exactly one score per image.

    Args:
        channels: Number of channels in the input images.
    """

    def __init__(self, channels=3):
        super().__init__()

        # First stage has no BatchNorm.
        layers = [
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Downsampling stages: double the channel count, halve the resolution.
        widths = (64, 128, 256, 512, 1024)
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Head: collapse the remaining 4x4 map into a single probability.
        layers.extend([
            nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Score ``x`` and flatten the result to a column of shape ``(-1, 1)``."""
        scores = self.model(x)
        return scores.view(-1, 1)


# Save, load, and sample from the generator

In [None]:

def save_generator(generator, path='generator.pth'):
    """Serialize the generator's ``state_dict`` to ``path`` and print the location.

    Only the parameters and buffers are written, not the module object itself.
    """
    weights = generator.state_dict()
    torch.save(weights, path)
    print(f"Generator saved to {path}")

def load_generator(latent_dim, channels=3, path='generator.pth'):
    """Rebuild a ``Generator`` and load weights saved by ``save_generator``.

    Args:
        latent_dim: Latent size the checkpoint was trained with.
        channels: Number of image channels the checkpoint was trained with.
        path: Checkpoint file containing the generator ``state_dict``.

    Returns:
        The generator on the CPU, switched to eval mode.

    Note:
        ``torch.load`` unpickles the file, so only load checkpoints from a
        trusted source.
    """
    generator = Generator(latent_dim, channels)
    # Checkpoints written during training hold CUDA tensors when a GPU was
    # used. Without map_location, loading them on a CPU-only machine fails.
    # The freshly built generator lives on the CPU either way.
    state_dict = torch.load(path, map_location="cpu")
    generator.load_state_dict(state_dict)
    generator.eval()
    return generator

def generate_images(generator, num_images=16, latent_dim=100, output_path=None):
    """Sample images from ``generator`` and optionally save them as a grid.

    The generator is moved to CUDA when available (otherwise CPU) and is run
    in eval mode without gradient tracking. Its previous train/eval mode is
    restored before returning, so this can be called in the middle of
    training without leaving BatchNorm layers frozen for later steps.

    Args:
        generator: Module mapping a ``(num_images, latent_dim)`` noise batch
            to an image batch.
        num_images: Number of latent vectors to sample.
        latent_dim: Size of each latent vector; must match the generator.
        output_path: If truthy, the batch is written there as an image grid
            with four images per row.

    Returns:
        The generated batch after the ``(x + 1) / 2`` rescale, which maps a
        ``[-1, 1]`` (tanh) output to ``[0, 1]``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = generator.to(device)

    # Remember the caller's mode; eval() below must not leak out of here.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(num_images, latent_dim).to(device)
            fake_images = generator(z)
            fake_images = (fake_images + 1) / 2

            if output_path:
                # normalize=True additionally min-max rescales the saved grid;
                # the returned tensor is not affected by it.
                vutils.save_image(fake_images, output_path, normalize=True, nrow=4)
    finally:
        generator.train(was_training)

    return fake_images



# Training loop

In [None]:
def train_gan(generator, discriminator, dataloader, num_epochs=100, latent_dim=100, save_interval=50):
    """Train a generator/discriminator pair with the BCE GAN objective.

    Both networks are moved to CUDA when available (otherwise CPU) and
    optimised with Adam (lr=0.0002, betas=(0.5, 0.999)). All artifacts go
    into a fresh ``gan_training_<timestamp>`` directory under the current
    working directory:

    * every 10 epochs: a log line and a one-image sample PNG;
    * every ``save_interval`` epochs: ``generator_epoch_<n>.pth``;
    * on normal completion: ``generator_final.pth``;
    * on an exception: ``generator_error_state.pth``;
    * on ``KeyboardInterrupt``: ``generator_interrupted.pth``.

    Args:
        generator: Maps ``(batch, latent_dim)`` noise to images.
        discriminator: Maps images to ``(batch, 1)`` probabilities in [0, 1].
        dataloader: Iterable yielding batches of real image tensors.
        num_epochs: Number of passes over ``dataloader``.
        latent_dim: Size of the noise vectors fed to the generator.
        save_interval: Checkpoint period in epochs; ``0`` or ``None``
            disables periodic checkpoints.

    Returns:
        The generator (same object, on the training device). Exceptions
        raised during training are reported and swallowed, so the returned
        generator may be only partially trained. ``KeyboardInterrupt`` is
        re-raised after the current weights have been saved.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    generator = generator.to(device)
    discriminator = discriminator.to(device)

    criterion = nn.BCELoss()
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_dir = f"gan_training_{timestamp}"
    os.makedirs(save_dir, exist_ok=True)

    # Placeholders so the periodic log line still works (printing nan) if the
    # dataloader never yields a batch.
    d_loss = g_loss = torch.tensor(float("nan"))

    try:
        for epoch in range(num_epochs):
            # Sampling via generate_images() puts the generator in eval mode;
            # re-assert train mode so BatchNorm keeps using batch statistics.
            generator.train()
            discriminator.train()

            for real_images in dataloader:
                batch_size = real_images.size(0)
                real_images = real_images.to(device)

                label_real = torch.ones(batch_size, 1, device=device)
                label_fake = torch.zeros(batch_size, 1, device=device)

                # --- Discriminator step: real -> 1, generated -> 0 ---
                d_optimizer.zero_grad()
                output_real = discriminator(real_images)
                d_loss_real = criterion(output_real, label_real)

                z = torch.randn(batch_size, latent_dim).to(device)
                fake_images = generator(z)
                # detach(): this step must not backpropagate into the generator.
                output_fake = discriminator(fake_images.detach())
                d_loss_fake = criterion(output_fake, label_fake)

                d_loss = d_loss_real + d_loss_fake
                d_loss.backward()
                d_optimizer.step()

                # --- Generator step: make the discriminator answer "real" ---
                g_optimizer.zero_grad()
                output_fake = discriminator(fake_images)
                g_loss = criterion(output_fake, label_real)
                g_loss.backward()
                g_optimizer.step()

            epoch_number = epoch + 1

            if epoch_number % 10 == 0:
                # Losses shown are those of the last batch of the epoch.
                print(f"Epoch [{epoch_number}/{num_epochs}] "
                      f"D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")

                sample_path = os.path.join(save_dir, f"samples_epoch_{epoch_number}.png")
                generate_images(generator, num_images=1, latent_dim=latent_dim,
                                output_path=sample_path)

            if save_interval and save_interval > 0 and epoch_number % save_interval == 0:
                checkpoint_path = os.path.join(save_dir, f"generator_epoch_{epoch_number}.pth")
                save_generator(generator, checkpoint_path)

        final_path = os.path.join(save_dir, "generator_final.pth")
        save_generator(generator, final_path)

    except KeyboardInterrupt:
        # Interrupting the kernel raises KeyboardInterrupt, which is not an
        # Exception subclass. Save the progress, then let it propagate.
        print("Training interrupted by user")
        interrupted_path = os.path.join(save_dir, "generator_interrupted.pth")
        save_generator(generator, interrupted_path)
        raise
    except Exception as e:
        print(f"Training interrupted: {str(e)}")
        error_path = os.path.join(save_dir, "generator_error_state.pth")
        save_generator(generator, error_path)

    return generator


In [None]:

def main(image_dir=r'D:\fake_defect\sample_data\Positive'):
    """Train the GAN on a folder of images and save a grid of final samples.

    Builds the 128x128 preprocessing pipeline, dataset and dataloader,
    trains a fresh ``Generator``/``Discriminator`` pair for 300 epochs, then
    writes four generated images to ``final_samples.png`` in the current
    working directory.

    Args:
        image_dir: Directory holding the training images. Defaults to the
            location used by the original notebook; pass another path to run
            elsewhere.

    Raises:
        ValueError: If ``image_dir`` contains no usable image files.
    """
    latent_dim = 100
    batch_size = 16
    num_epochs = 300
    image_size = 128
    channels = 3

    transform = transforms.Compose([
        # Resize the shorter side, then crop to a square.
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # Map [0, 1] pixels to [-1, 1] to match the generator's Tanh range.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = CustomDataset(
        image_dir=image_dir,
        transform=transform,
        image_size=(image_size, image_size)
    )
    if len(dataset) == 0:
        # Fail here with a clear message rather than inside the DataLoader.
        raise ValueError(f"No .png/.jpg/.jpeg images found in {image_dir}")

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )

    generator = Generator(latent_dim, channels)
    discriminator = Discriminator(channels)

    trained_generator = train_gan(generator, discriminator, dataloader,
                                  num_epochs, latent_dim, save_interval=50)

    # Pass latent_dim explicitly so sampling stays in sync if the value above
    # is changed, instead of relying on generate_images' default of 100.
    generate_images(trained_generator, num_images=4, latent_dim=latent_dim,
                    output_path="final_samples.png")





In [None]:
# Run the full pipeline: build the dataset, train the GAN, and write final_samples.png.
main()


# if __name__ == "__main__":
#     torch.multiprocessing.freeze_support()
#     torch.backends.cudnn.benchmark = True
#     main()