In [None]:
import torch
from model_utils import setup_generator, setup_discriminator

# Pretrained generator weights (state_dict saved with torch.save).
GENERATOR_CHECKPOINT = r'E:\Data\biggan\g.pth'

# NOTE(review): the meaning of the positional args (4, -1) / (4) is defined in model_utils;
# presumably 4 is the number of classes (PalmDataset emits labels 0..3) — confirm.
G = setup_generator(4, -1)
D = setup_discriminator(4)

# weights_only=True restricts unpickling to tensors and plain containers, so a tampered
# checkpoint cannot execute arbitrary code; an ordinary state_dict loads unchanged.
G.load_state_dict(torch.load(GENERATOR_CHECKPOINT,
                             map_location=torch.device('cpu'),
                             weights_only=True))

In [9]:
from support_utils import get_latent_input

################################################################
# //////////////////////////////////////////////////////////// #
################################################################
def compute_gradient_penalty(discriminator, real_images, fake_images, labels, device):
    """WGAN-GP gradient penalty evaluated on random real/fake interpolations.

    Samples one mixing weight ``alpha ~ U[0, 1)`` per example and builds
    ``x_hat = alpha * real + (1 - alpha) * fake``. It then penalises the squared deviation
    of ``||dD(x_hat) / dx_hat||_2`` from 1, averaged over the batch.

    Args:
        discriminator: callable invoked as ``discriminator(images, labels)``.
        real_images: real batch. ``alpha`` is shaped (N, 1, 1, 1), so a 4-D
            (N, C, H, W) batch is assumed.
        fake_images: generated batch, broadcast-compatible with ``real_images``; callers
            are expected to pass it detached.
        labels: conditioning labels forwarded unchanged to ``discriminator``.
        device: device on which ``alpha`` is sampled.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad||_2 - 1) ** 2``. The gradient is
        computed with ``create_graph=True``, so backpropagating through the result reaches
        the discriminator's parameters.
    """
    alpha = torch.rand(real_images.size(0), 1, 1, 1, device=device)

    interpolates = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)

    d_interpolates = discriminator(interpolates, labels)

    # create_graph=True makes the gradient itself differentiable (double backprop) and
    # implies retain_graph=True; only_inputs is deprecated and ignored, so both are omitted.
    # ones_like follows d_interpolates' own device, so no explicit device is needed.
    (gradients,) = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
    )

    # flatten() copies when necessary; view() would raise on non-contiguous gradients
    # (e.g. channels_last image batches).
    gradient_norms = gradients.flatten(start_dim=1).norm(2, dim=1)

    return ((gradient_norms - 1) ** 2).mean()

################################################################
# //////////////////////////////////////////////////////////// #
################################################################
def train_discriminator(generator,
                        discriminator,
                        real_images,
                        labels,
                        d_loss_fn,
                        optimizer_D,
                        lambda_gp,
                        device,
                        scaler):
    """Run one mixed-precision discriminator update with a gradient penalty.

    The generator and discriminator forwards and the adversarial loss run under float16
    autocast. The gradient penalty is computed with autocast disabled. The combined loss
    ``d_adv_loss + lambda_gp * gradient_penalty`` is backpropagated through ``scaler``,
    and ``optimizer_D`` is stepped.

    Args:
        generator: callable ``generator(noise, labels)`` producing fake images.
        discriminator: callable ``discriminator(images, labels)``.
        real_images: batch of real images; its first dimension is the batch size.
        labels: labels passed to ``get_latent_input``.
        d_loss_fn: callable ``d_loss_fn(real_outputs, fake_outputs)`` -> adversarial loss.
        optimizer_D: discriminator optimizer.
        lambda_gp: weight of the gradient-penalty term.
        device: ``str`` (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``) or ``torch.device``.
        scaler: gradient scaler exposing ``scale`` / ``step`` / ``update``
            (e.g. ``torch.amp.GradScaler``).

    Returns:
        ``(d_loss, real_acc, fake_acc)`` as Python floats. ``real_acc`` is the fraction of
        real outputs > 0 and ``fake_acc`` the fraction of fake outputs < 0.
    """
    # torch.amp.autocast requires a bare device-type string ("cuda" / "cpu"); deriving it
    # lets torch.device objects and indexed strings like "cuda:0" work, and keeps both
    # autocast contexts below on the same device type.
    device_type = torch.device(device).type

    # NOTE(review): `labels` is rebound to whatever get_latent_input returns, and that value
    # is also used to condition the *real* images below. Confirm get_latent_input returns
    # the input labels (e.g. moved to `device`) rather than freshly sampled ones.
    noise, labels = get_latent_input(real_images.size(0), labels, device)

    with torch.amp.autocast(device_type=device_type, dtype=torch.float16):
        # Only detached fakes are consumed, so the generator forward needs no autograd graph.
        with torch.no_grad():
            fake_images = generator(noise, labels)

        # Get discriminator outputs
        real_outputs = discriminator(real_images, labels)
        fake_outputs = discriminator(fake_images, labels)

        # Discriminator adversarial loss
        d_adv_loss = d_loss_fn(real_outputs, fake_outputs)

    # Gradient penalty in full precision: disable autocast for the same device type.
    with torch.amp.autocast(device_type=device_type, enabled=False):
        gradient_penalty = compute_gradient_penalty(discriminator, real_images, fake_images, labels, device)

    d_loss = d_adv_loss + lambda_gp * gradient_penalty

    # Backpropagation and optimization through the (mixed-precision) gradient scaler
    optimizer_D.zero_grad()
    scaler.scale(d_loss).backward()
    scaler.step(optimizer_D)
    scaler.update()

    # Accuracies: an output > 0 is read as "real", < 0 as "fake".
    real_acc = (real_outputs > 0).float().mean()
    fake_acc = (fake_outputs < 0).float().mean()

    return d_loss.item(), real_acc.item(), fake_acc.item()

In [2]:
import os
import warnings

import torch
from PIL import Image
from torchvision import transforms

class PalmDataset(torch.utils.data.Dataset):
    """Palm-image dataset whose class label is encoded in the filename prefix.

    Every entry of ``root_dir`` is treated as an image. Labels come from the filename:
    ``lr*`` -> 1, ``rw*`` -> 2, ``rb*`` -> 3, anything else -> 0. Each image is randomly
    augmented (horizontal/vertical flips, +/-15 degree rotation, resized crop, colour
    jitter), converted to a tensor and normalised per channel with mean 0.5 / std 0.5
    (mapping ToTensor's [0, 1] range to [-1, 1]).

    Args:
        root_dir: directory containing the image files.
        image_size: side length of the square output crop and of the placeholder used
            for unreadable files (default 512).
    """

    # Filename prefix -> class label; names matching none of these get label 0.
    _PREFIX_LABELS = (('lr', 1), ('rw', 2), ('rb', 3))

    def __init__(self, root_dir, image_size=512):
        self.root_dir = root_dir
        self.image_size = image_size

        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(15),
            transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ])

        # os.listdir order is filesystem-dependent; sort so index -> file is reproducible.
        self.image_paths = sorted(os.listdir(root_dir))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(augmented_image_tensor, label)`` for the idx-th file."""
        name = self.image_paths[idx]
        img_path = os.path.join(self.root_dir, name)

        try:
            # The context manager closes the file; convert() returns an independent copy.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except Exception as exc:
            # Best-effort: keep batches well-formed with a black placeholder, but surface
            # the failure instead of hiding it.
            warnings.warn(
                f"PalmDataset: could not read {img_path!r} ({exc}); using a black placeholder."
            )
            image = Image.new("RGB", (self.image_size, self.image_size))

        label = next(
            (lbl for prefix, lbl in self._PREFIX_LABELS if name.startswith(prefix)), 0
        )

        return self.transform(image), label
    
# Loader over the held-out test split. shuffle defaults to False, so the first batch is
# always the first four entries of the dataset's image_paths.
loader = torch.utils.data.DataLoader(
    PalmDataset(r'E:\Data\Biometric\Set\Palms11kSplit\test'),
    batch_size=4,
)

In [6]:
from support_utils import save_sample_images_by_class

# Save images from the first test batch via save_sample_images_by_class.
# NOTE(review): save_sample_images_by_class is defined in support_utils; the meaning of the
# positional args (3, r'', 4) is not visible here — confirm, especially the empty r'' path.
batch, labels = next(iter(loader))

# Tensors stay on the CPU. Tensor.cuda() is not in-place, so if a GPU copy is ever wanted
# it must be rebound: `batch = batch.cuda()`.

# NOTE(review): keys are batch positions 0..N-1, not the per-image `labels`; the loader is
# unshuffled, so one batch may hold a single class — confirm this is the intended keying.
data = {i: image for i, image in enumerate(batch)}

save_sample_images_by_class(data, 3, r'', 4)