In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# --- Data -------------------------------------------------------------
IMAGE_SIZE = 64    # side length (pixels) images are cropped/resized to
CHANNELS = 1       # grayscale
BATCH_SIZE = 128   # samples per batch for both loaders

# --- Model ------------------------------------------------------------
LATENT_DIM = 100   # length of the generator's input noise vector

# --- Optimisation -----------------------------------------------------
EPOCHS = 50
LR_G = 1e-4            # generator learning rate
LR_D = 4e-4            # discriminator learning rate (4x the generator's)
BETAS = (0.5, 0.999)   # Adam (beta1, beta2)
N_CRITIC = 5           # discriminator updates per generator update
LABEL_SMOOTHING = 0.9  # target used for "real" labels when training the discriminator

In [None]:
# Set device and ensure reproducibility
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Root folder holding the `train/` and `validation/` image folders.
# Override with the GAN_DATA_ROOT environment variable instead of editing
# the notebook; the default is the original author's local path.
DATA_ROOT = os.environ.get(
    "GAN_DATA_ROOT",
    "D:/projects/machine learning/Expression-recognition/jonathanheix dataset/images",
)

# Pinned host memory is only requested when a GPU is in use; asking for it
# on a CPU-only run has no benefit.
PIN_MEMORY = device.type == "cuda"

# Training pipeline: random crop/flip augmentation, then grayscale tensors
# normalised with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.Grayscale(num_output_channels=CHANNELS),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Validation pipeline: deterministic resize only, same normalisation.
val_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.Grayscale(num_output_channels=CHANNELS),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

train_data = datasets.ImageFolder(
    f"{DATA_ROOT}/train",
    transform=transform
)

train_loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=PIN_MEMORY,
    drop_last=True  # every training batch has exactly BATCH_SIZE samples
)

val_data = datasets.ImageFolder(
    f"{DATA_ROOT}/validation",
    transform=val_transform
)

# No drop_last here: the final validation batch may be smaller than BATCH_SIZE.
val_loader = DataLoader(
    val_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=PIN_MEMORY
)

In [None]:
# Spectral Normalization for stability
class SpectralNorm(nn.Module):
    """Wrap a module so one of its weights is divided by its spectral norm.

    The wrapper is an ``nn.Module`` so it can be placed inside containers such
    as ``nn.Sequential`` and so the wrapped layer's parameters are visible to
    optimizers, ``.to(device)`` and ``state_dict()``.

    On construction the wrapped module's ``<name>`` parameter is replaced by:

    * ``<name>_bar`` - the trainable, un-normalised weight (a Parameter);
    * ``<name>_u`` / ``<name>_v`` - power-iteration vectors (buffers, so they
      are saved and moved with the module but never optimised);
    * ``<name>`` - a plain tensor attribute that is overwritten on every
      forward with ``<name>_bar / sigma``.

    Args:
        module: The module to wrap (e.g. an ``nn.Conv2d``).
        name: Name of the weight attribute on ``module`` to normalise.
        power_iterations: Power-iteration steps run per training-mode forward.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _made_params(self):
        """Return True if ``module`` already carries the u / v / bar tensors."""
        return all(
            hasattr(self.module, f'{self.name}_{suffix}')
            for suffix in ('u', 'v', 'bar')
        )

    def _make_params(self):
        """Re-register the weight as ``<name>_bar`` and create unit-norm u, v."""
        w = getattr(self.module, self.name)
        height = w.shape[0]
        width = w.view(height, -1).shape[1]

        u = F.normalize(w.data.new_empty(height).normal_(0, 1), dim=0)
        v = F.normalize(w.data.new_empty(width).normal_(0, 1), dim=0)
        w_bar = nn.Parameter(w.data)

        # Remove the original Parameter so `<name>` can later be assigned a
        # plain (non-Parameter) tensor without nn.Module raising TypeError.
        del self.module._parameters[self.name]
        self.module.register_parameter(f'{self.name}_bar', w_bar)
        self.module.register_buffer(f'{self.name}_u', u)
        self.module.register_buffer(f'{self.name}_v', v)
        # Keep `<name>` readable before the first forward pass.
        setattr(self.module, self.name, w_bar.data)

    def _update_u_v(self):
        """Refresh u / v (training mode only) and set the normalised weight."""
        u = getattr(self.module, f'{self.name}_u')
        v = getattr(self.module, f'{self.name}_v')
        w = getattr(self.module, f'{self.name}_bar')

        w_mat = w.view(w.shape[0], -1)
        if self.training:
            # The iteration only estimates the singular vectors; it is not
            # part of the autograd graph.
            with torch.no_grad():
                for _ in range(self.power_iterations):
                    v.copy_(F.normalize(torch.mv(w_mat.t(), u), dim=0))
                    u.copy_(F.normalize(torch.mv(w_mat, v), dim=0))

        # Use clones in the graph: the buffers are updated in place on the
        # next forward, which would otherwise invalidate tensors saved for
        # backward when the module is called several times before backward().
        u_est, v_est = u.clone(), v.clone()
        sigma = torch.dot(u_est, torch.mv(w_mat, v_est))
        setattr(self.module, self.name, w / sigma)

    def forward(self, *args, **kwargs):
        """Normalise the weight, then delegate to the wrapped module."""
        self._update_u_v()
        return self.module(*args, **kwargs)

def spectral_norm(module):
    """Wrap ``module`` in a ``SpectralNorm`` using that class's default arguments."""
    wrapped = SpectralNorm(module)
    return wrapped

In [None]:
# Improved Generator with residual connections and self-attention
class ResidualBlock(nn.Module):
    """Two BatchNorm -> ReLU -> 3x3 Conv stages with an identity skip connection.

    The channel count and spatial size are unchanged (3x3 kernels, stride 1,
    padding 1), so the input can be added directly to the block's output.
    """

    def __init__(self, in_channels):
        super().__init__()
        stages = []
        for _ in range(2):
            stages.extend([
                nn.BatchNorm2d(in_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels, in_channels,
                          kernel_size=3, stride=1, padding=1, bias=False),
            ])
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

class SelfAttention(nn.Module):
    """Self-attention over all spatial positions of a 4-D feature map.

    Queries and keys are 1x1-conv projections to ``in_channels // 8``
    channels; values keep ``in_channels``. The attended features are scaled
    by the learnable scalar ``gamma`` (initialised to zero, so the layer
    starts as an identity) and added to the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        self.query = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.key = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        n, c, dim_a, dim_b = x.size()
        num_pos = dim_a * dim_b

        # (n, positions, reduced) x (n, reduced, positions) -> (n, positions, positions)
        queries = self.query(x).view(n, -1, num_pos).transpose(1, 2)
        keys = self.key(x).view(n, -1, num_pos)
        weights = F.softmax(torch.bmm(queries, keys), dim=-1)

        # Each output position is a weighted sum of the values at all positions.
        values = self.value(x).view(n, -1, num_pos)
        attended = torch.bmm(values, weights.transpose(1, 2)).view(n, c, dim_a, dim_b)

        return self.gamma * attended + x


In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator with residual blocks and self-attention.

    For a ``(N, LATENT_DIM, 1, 1)`` input the spatial size grows
    1 -> 4 -> 8 -> 16 -> 32 -> 64, and the output has ``CHANNELS`` channels
    squashed to [-1, 1] by ``tanh``.
    """

    def __init__(self):
        super().__init__()

        # Project the latent vector to a 4x4 map with 512 channels.
        self.initial = nn.Sequential(
            nn.ConvTranspose2d(LATENT_DIM, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        self.up1 = self._upsample_stage(512, 256)
        self.up2 = self._upsample_stage(256, 128)

        # Applied after the second upsampling stage (16x16 for a 1x1 latent).
        self.attention = SelfAttention(128)

        self.up3 = self._upsample_stage(128, 64)

        # Final doubling to the output resolution.
        self.output = nn.Sequential(
            nn.ConvTranspose2d(64, CHANNELS, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    @staticmethod
    def _upsample_stage(in_channels, out_channels):
        """Double the spatial size, change channels, then refine with a residual block."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            ResidualBlock(out_channels),
        )

    def forward(self, x):
        pipeline = (self.initial, self.up1, self.up2, self.attention, self.up3, self.output)
        for stage in pipeline:
            x = stage(x)
        return x


In [None]:
# Improved Discriminator with spectral normalization and self-attention
class Discriminator(nn.Module):
    """Strided-convolution discriminator with spectral norm and self-attention.

    Four stride-2 stages halve the spatial size each time (64 -> 32 -> 16 ->
    8 -> 4 for a 64x64 input); a final 4x4 convolution followed by a sigmoid
    produces one probability per image, shaped ``(N, 1, 1, 1)`` for that
    input size.
    """

    def __init__(self):
        super().__init__()

        # The first stage has no batch norm; later stages do.
        self.layer1 = self._down_stage(CHANNELS, 64, batch_norm=False)
        self.layer2 = self._down_stage(64, 128)

        # Applied after the second stage (16x16 for a 64x64 input).
        self.attention = SelfAttention(128)

        self.layer3 = self._down_stage(128, 256)
        self.layer4 = self._down_stage(256, 512)

        # Collapse the remaining 4x4 map to a single score per image.
        self.output = nn.Sequential(
            spectral_norm(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False)),
            nn.Sigmoid(),
        )

    @staticmethod
    def _down_stage(in_channels, out_channels, batch_norm=True):
        """Spectral-normalised stride-2 conv, optional batch norm, LeakyReLU(0.2)."""
        layers = [
            spectral_norm(nn.Conv2d(in_channels, out_channels,
                                    kernel_size=4, stride=2, padding=1, bias=False)),
        ]
        if batch_norm:
            layers.append(nn.BatchNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        pipeline = (self.layer1, self.layer2, self.attention,
                    self.layer3, self.layer4, self.output)
        for stage in pipeline:
            x = stage(x)
        return x


In [None]:
# Build both networks and move them to the selected device
generator = Generator().to(device)
discriminator = Discriminator().to(device)


def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


print(f"Generator parameters: {count_parameters(generator)}")
print(f"Discriminator parameters: {count_parameters(discriminator)}")

# Generator: Adam plus a cosine schedule that decays to a tenth of LR_G over EPOCHS steps
g_optimizer = torch.optim.Adam(params=generator.parameters(), lr=LR_G, betas=BETAS)
g_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=g_optimizer, T_max=EPOCHS, eta_min=LR_G / 10
)

# Discriminator: same recipe with its own learning rate
d_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=LR_D, betas=BETAS)
d_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer=d_optimizer, T_max=EPOCHS, eta_min=LR_D / 10
)

# Binary cross-entropy on the discriminator's sigmoid outputs
bce_loss = nn.BCELoss()

# Wasserstein-style loss: mean of the element-wise product of target and prediction
def wasserstein_loss(y_pred, y_true):
    """Return the mean of ``y_true * y_pred`` as a scalar tensor."""
    signed_scores = y_true * y_pred
    return signed_scores.mean()


In [None]:
# Enhanced training function with gradient penalty and WGAN features
def compute_gradient_penalty(discriminator, real_samples, fake_samples):
    """Gradient penalty on random interpolations between real and fake samples.

    For each sample a point is drawn uniformly on the line between the real
    and the fake input, the discriminator is evaluated there, and the squared
    deviation of the input-gradient's L2 norm from 1 is averaged over the
    batch.

    Args:
        discriminator: Callable mapping a batch of samples to per-sample scores.
        real_samples: Tensor whose first dimension is the batch; any rank >= 1.
        fake_samples: Tensor with the same shape and device as ``real_samples``.

    Returns:
        Scalar tensor. It stays attached to the autograd graph
        (``create_graph=True``) so it can be back-propagated as part of the
        discriminator loss.
    """
    batch_size = real_samples.size(0)

    # One mixing coefficient per sample, broadcast over the remaining dims.
    # It is created on the inputs' own device/dtype so the function does not
    # depend on a module-level `device`.
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)

    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = discriminator(interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view): the gradient tensor is not guaranteed to be contiguous.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    return gradient_penalty

def train_epoch(generator, discriminator, dataloader, g_optimizer, d_optimizer, epoch, gp_weight=0.1):
    """Run one epoch of adversarial training and return the epoch's mean metrics.

    For every batch the discriminator is updated ``N_CRITIC`` times using BCE
    on the real batch (targets ``LABEL_SMOOTHING``), BCE on freshly generated
    fakes (targets 0) and ``gp_weight`` times the gradient penalty. The
    generator is then updated once with the BCE loss of its fakes against
    targets of 1.

    Args:
        generator: Generator network (put into train mode).
        discriminator: Discriminator network (put into train mode).
        dataloader: Yields ``(images, labels)`` batches; labels are ignored.
        g_optimizer: Optimizer stepping the generator's parameters.
        d_optimizer: Optimizer stepping the discriminator's parameters.
        epoch: Zero-based epoch index, used only in the progress-bar text.
        gp_weight: Multiplier applied to the gradient penalty term.

    Returns:
        Tuple ``(mean_g_loss, mean_d_loss, mean_d_acc)``. The discriminator
        loss recorded per batch is the one from its last critic iteration;
        the accuracy is averaged over every critic iteration.
    """
    generator.train()
    discriminator.train()

    g_losses = []
    d_losses = []
    d_accs = []
    acc_running_sum = 0.0  # running total so the progress bar need not re-average the list

    pbar = tqdm(dataloader, total=len(dataloader))
    for real_images, _ in pbar:
        batch_size = real_images.size(0)
        real_images = real_images.to(device)

        # Targets depend only on the batch size, so build them once per batch.
        label_real_smooth = torch.full((batch_size, 1, 1, 1), LABEL_SMOOTHING, device=device)
        label_fake = torch.zeros(batch_size, 1, 1, 1, device=device)
        label_real = torch.ones(batch_size, 1, 1, 1, device=device)

        # -------------------------
        # Train Discriminator
        # -------------------------
        for _ in range(N_CRITIC):
            discriminator.zero_grad()

            # Real images (smoothed targets)
            output_real = discriminator(real_images)
            d_loss_real = bce_loss(output_real, label_real_smooth)

            # Fake images: no generator graph is needed for the critic step.
            z = torch.randn(batch_size, LATENT_DIM, 1, 1, device=device)
            with torch.no_grad():
                fake_images = generator(z)
            output_fake = discriminator(fake_images)
            d_loss_fake = bce_loss(output_fake, label_fake)

            # Weighted gradient penalty as a regulariser
            gp = gp_weight * compute_gradient_penalty(discriminator, real_images, fake_images)

            d_loss = d_loss_real + d_loss_fake + gp
            d_loss.backward()
            d_optimizer.step()

            # Discriminator accuracy at a 0.5 threshold, for monitoring only
            pred_real = (output_real > 0.5).float()
            pred_fake = (output_fake < 0.5).float()
            acc = 0.5 * (torch.mean(pred_real) + torch.mean(pred_fake))
            acc_value = acc.item()
            d_accs.append(acc_value)
            acc_running_sum += acc_value

        # -------------------------
        # Train Generator
        # -------------------------
        generator.zero_grad()

        z = torch.randn(batch_size, LATENT_DIM, 1, 1, device=device)
        fake_images = generator(z)

        # Adversarial loss: the generator is rewarded when the discriminator
        # scores its fakes as real (un-smoothed targets of 1).
        output_fake = discriminator(fake_images)
        g_loss = bce_loss(output_fake, label_real)

        g_loss.backward()
        g_optimizer.step()

        g_loss_value = g_loss.item()
        d_loss_value = d_loss.item()
        g_losses.append(g_loss_value)
        d_losses.append(d_loss_value)

        # Update progress bar
        running_acc = acc_running_sum / len(d_accs)
        pbar.set_description(f"[{epoch+1}/{EPOCHS}] D: {d_loss_value:.4f}, G: {g_loss_value:.4f}, Acc: {running_acc:.4f}")

    return np.mean(g_losses), np.mean(d_losses), np.mean(d_accs)


In [None]:
# Enhanced evaluation function
def evaluate(generator, discriminator, dataloader, device):
    """Score both networks on ``dataloader`` without updating any weights.

    Both models are switched to eval mode and run under ``torch.no_grad()``.
    For every real batch an equally sized fake batch is generated. Metrics
    are averaged per sample (each batch is weighted by its size), so a short
    final batch does not skew the result.

    Args:
        generator: Generator network.
        discriminator: Discriminator network.
        dataloader: Yields ``(images, labels)`` batches; labels are ignored.
        device: Device the batches and the noise are placed on.

    Returns:
        Tuple ``(g_loss, d_loss, d_acc)`` of Python floats: the generator BCE
        loss (fakes vs. targets of 1), the discriminator BCE loss (real vs. 1
        plus fake vs. 0) and the discriminator accuracy at a 0.5 threshold.
        All three are NaN if ``dataloader`` yields no batches.
    """
    generator.eval()
    discriminator.eval()

    g_loss_sum = 0.0
    d_loss_sum = 0.0
    acc_sum = 0.0
    n_seen = 0

    with torch.no_grad():
        for real_images, _ in dataloader:
            batch_size = real_images.size(0)
            real_images = real_images.to(device)

            ones = torch.ones(batch_size, 1, 1, 1, device=device)
            zeros = torch.zeros(batch_size, 1, 1, 1, device=device)

            # Real images
            output_real = discriminator(real_images)
            d_loss_real = bce_loss(output_real, ones)

            # Fake images
            z = torch.randn(batch_size, LATENT_DIM, 1, 1, device=device)
            output_fake = discriminator(generator(z))
            d_loss_fake = bce_loss(output_fake, zeros)
            d_loss = d_loss_real + d_loss_fake

            # Generator loss: fakes scored against "real" targets
            g_loss = bce_loss(output_fake, ones)

            # Accuracy at a 0.5 decision threshold
            pred_real = (output_real > 0.5).float()
            pred_fake = (output_fake < 0.5).float()
            acc = 0.5 * (torch.mean(pred_real) + torch.mean(pred_fake))

            # Weight each batch-mean by the batch size
            g_loss_sum += g_loss.item() * batch_size
            d_loss_sum += d_loss.item() * batch_size
            acc_sum += acc.item() * batch_size
            n_seen += batch_size

    if n_seen == 0:
        nan = float('nan')
        return nan, nan, nan

    return g_loss_sum / n_seen, d_loss_sum / n_seen, acc_sum / n_seen


In [None]:
# Fixed noise so the per-epoch sample grids are comparable across epochs
fixed_noise = torch.randn(64, LATENT_DIM, 1, 1, device=device)

# Output directories for sample images and checkpoints
os.makedirs("gan_images", exist_ok=True)
os.makedirs("gan_models", exist_ok=True)

# Per-epoch metric histories (consumed by the plotting cell)
train_g_losses, train_d_losses, train_accuracies = [], [], []
val_g_losses, val_d_losses, val_accuracies = [], [], []
best_val_acc = 0
best_fid = float('inf')  # never updated by this loop; kept so the name stays defined

# Early stopping: stop after `patience` consecutive epochs without a new best
patience = 7
no_improve = 0


def _checkpoint_state(epoch, val_acc):
    """Collect everything needed to reload or resume training at ``epoch``."""
    return {
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'g_optimizer_state_dict': g_optimizer.state_dict(),
        'd_optimizer_state_dict': d_optimizer.state_dict(),
        # Scheduler state is included so a resumed run continues the LR schedule.
        'g_scheduler_state_dict': g_scheduler.state_dict(),
        'd_scheduler_state_dict': d_scheduler.state_dict(),
        # Stored as a plain float: a numpy scalar in the checkpoint cannot be
        # read back by torch.load when it runs with weights_only=True.
        'val_acc': float(val_acc),
    }


for epoch in range(EPOCHS):
    # Training
    g_loss, d_loss, train_acc = train_epoch(
        generator, discriminator, train_loader,
        g_optimizer, d_optimizer, epoch
    )
    train_g_losses.append(g_loss)
    train_d_losses.append(d_loss)
    train_accuracies.append(train_acc)

    # Validation (leaves both networks in eval mode)
    val_g_loss, val_d_loss, val_acc = evaluate(
        generator, discriminator, val_loader, device
    )
    val_g_losses.append(val_g_loss)
    val_d_losses.append(val_d_loss)
    val_accuracies.append(val_acc)

    # Generate and save sample images
    with torch.no_grad():
        fake = generator(fixed_noise)
        vutils.save_image(fake, f"gan_images/fake_epoch_{epoch+1}.png", normalize=True)

    # Save best model based on validation accuracy.
    # NOTE(review): val_acc is the *discriminator's* accuracy, so "best" here
    # means the discriminator separates real from fake best - confirm this is
    # the intended selection / early-stopping criterion for the generator.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        no_improve = 0
        torch.save(_checkpoint_state(epoch, val_acc), 'gan_models/best_model.pth')
        print(f"Saved best model with val_acc: {val_acc:.4f}")
    else:
        no_improve += 1

    # Step learning rate schedulers (once per epoch)
    g_scheduler.step()
    d_scheduler.step()

    # Print metrics
    print(f"Epoch [{epoch+1}/{EPOCHS}]")
    print(f"Train - G_loss: {g_loss:.4f}, D_loss: {d_loss:.4f}, Acc: {train_acc:.4f}")
    print(f"Val - G_loss: {val_g_loss:.4f}, D_loss: {val_d_loss:.4f}, Acc: {val_acc:.4f}")
    print(f"Learning rates - G: {g_optimizer.param_groups[0]['lr']:.6f}, D: {d_optimizer.param_groups[0]['lr']:.6f}")

    # Checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(
            _checkpoint_state(epoch, val_acc),
            f'gan_models/checkpoint_epoch_{epoch+1}.pth'
        )

    # Early stopping
    if no_improve >= patience:
        print(f"Early stopping at epoch {epoch+1}")
        break


In [None]:
# Plot training curves: one panel each for generator loss, discriminator loss
# and discriminator accuracy, with train and validation series overlaid.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

panels = [
    # (train series, train label, val series, val label, y-axis label, title)
    (train_g_losses, 'Generator Train', val_g_losses, 'Generator Val', 'Loss', 'Generator Loss'),
    (train_d_losses, 'Discriminator Train', val_d_losses, 'Discriminator Val', 'Loss', 'Discriminator Loss'),
    (train_accuracies, 'Train', val_accuracies, 'Validation', 'Accuracy', 'Discriminator Accuracy'),
]

for ax, (train_series, train_label, val_series, val_label, y_label, title) in zip(axes, panels):
    ax.plot(train_series, label=train_label)
    ax.plot(val_series, label=val_label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    ax.set_title(title)

fig.tight_layout()
fig.savefig('gan_images/training_curves.png')
# The figure is written to disk and then closed.
plt.close(fig)

# Load the best model for inference.
# map_location remaps the stored tensors onto the current `device`, so a
# checkpoint written on a GPU can also be loaded on a CPU-only machine.
best_model = torch.load('gan_models/best_model.pth', map_location=device)
generator.load_state_dict(best_model['generator_state_dict'])
discriminator.load_state_dict(best_model['discriminator_state_dict'])

print(f"Training completed. Best validation accuracy: {best_val_acc:.4f}")


In [None]:
# Function to generate samples with the currently loaded generator
def generate_samples(n_samples=25, z_dim=LATENT_DIM, nrow=5,
                     out_path="gan_images/final_samples.png"):
    """Generate images from random noise, save them as a grid and return them.

    Uses the module-level ``generator`` (switched to eval mode) and ``device``.

    Args:
        n_samples: Number of images to generate.
        z_dim: Length of each noise vector fed to the generator.
        nrow: Number of images per row in the saved grid.
        out_path: File the grid image is written to.

    Returns:
        The generated batch as a tensor (un-normalised generator output).
    """
    generator.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, z_dim, 1, 1, device=device)
        samples = generator(z)
        vutils.save_image(samples, out_path, nrow=nrow, normalize=True)
    return samples

# Generate final samples. The result is bound to a name so the notebook does
# not echo the whole tensor as cell output.
final_samples = generate_samples(25)