In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os

# Navigate to your project directory in Google Drive
%cd /content/drive/MyDrive/CAF-GAN/

# Unzip the image files (this might take a few minutes)
# The -q makes the output quiet, -n prevents unzipping if already done
!unzip -q -n mimic-cxr-jpg-2.0.0.zip

print("✅ Workspace ready and images unzipped.")

/content/drive/MyDrive/CAF-GAN
✅ Workspace ready and images unzipped.


In [2]:
!pip install pyyaml pandas scikit-learn albumentations segmentation-models-pytorch -q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/154.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """
    CAF-GAN Generator Network (DCGAN architecture).
    Takes a latent vector z and outputs a 256x256 grayscale image.
    Output is normalized between -1 and 1 using Tanh.
    """
    def __init__(self, latent_dim, channels=1):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim

        self.main = nn.Sequential(
            # Input: latent_dim x 1 x 1
            nn.ConvTranspose2d(latent_dim, 1024, 4, 1, 0, bias=False),
            nn.BatchNorm2d(1024),
            nn.ReLU(True),
            # State: 1024 x 4 x 4
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # State: 512 x 8 x 8
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # State: 256 x 16 x 16
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # State: 128 x 32 x 32
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # State: 64 x 64 x 64
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            # State: 32 x 128 x 128
            nn.ConvTranspose2d(32, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: channels x 256 x 256
        )

    def forward(self, input):
        return self.main(input)

In [4]:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """
    CAF-GAN Discriminator Network (Critic for WGAN-GP).
    Takes a 256x256 grayscale image and outputs a single scalar score.
    Uses InstanceNorm instead of BatchNorm for stability with WGAN-GP.
    No final activation function.
    """
    def __init__(self, channels=1):
        super(Discriminator, self).__init__()

        self.main = nn.Sequential(
            # Input: channels x 256 x 256
            nn.Conv2d(channels, 32, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 32 x 128 x 128
            nn.Conv2d(32, 64, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(64, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 64 x 64 x 64
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 128 x 32 x 32
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 256 x 16 x 16
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 512 x 8 x 8
            nn.Conv2d(512, 1024, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(1024, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # State: 1024 x 4 x 4
            nn.Conv2d(1024, 1, 4, 1, 0, bias=False)
            # Output: 1 x 1 x 1
        )

    def forward(self, input):
        return self.main(input)

In [5]:
import os
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image

# This single dataset file will serve both critic training scripts.

class MIMICCXRClassifierDataset(Dataset):
    """
    Dataset for the Cdiag (classification) task.
    - Loads a JPG image.
    - Converts it to RGB (as required by ResNet).
    - Returns the image and its corresponding Pneumonia label.
    """
    def __init__(self, df, image_dir, transform=None):
        self.df = df
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        subject_id = str(row['subject_id'])
        study_id = str(row['study_id'])
        dicom_id = row['dicom_id']

        # Construct the path to the JPG image
        image_path = os.path.join(
            self.image_dir, f'p{subject_id[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg'
        )

        # Load image and convert to a numpy array in RGB format
        image = Image.open(image_path).convert("RGB")
        image = np.array(image)

        # Apply augmentations
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        # Get the label
        label = torch.tensor(row['Pneumonia'], dtype=torch.float32)

        return image, label.unsqueeze(0)


class MIMICXRSegmentationDataset(Dataset):
    """
    Dataset for the Cseg (segmentation) task.
    - Loads a JPG image (as 3-channel RGB). <--- UPDATED
    - Loads its corresponding pre-generated PNG mask.
    - Returns both the image and the mask.
    """
    def __init__(self, df, image_dir, mask_dir, transform=None):
        self.df = df
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        subject_id = str(row['subject_id'])
        study_id = str(row['study_id'])
        dicom_id = row['dicom_id']

        image_path = os.path.join(
            self.image_dir, f'p{subject_id[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg'
        )
        mask_path = os.path.join(self.mask_dir, f"{dicom_id}.png")

        # --- KEY CHANGE ---
        # Load image and convert to RGB to match the model's expected input channels.
        image = np.array(Image.open(image_path).convert("RGB"), dtype=np.float32)

        # Load mask as grayscale
        mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32)

        # RESIZE THE IMAGE to match the mask size (256×256) BEFORE augmentation
        image = np.array(Image.fromarray(image.astype(np.uint8)).resize((256, 256), Image.BILINEAR))

        # Normalize mask values from [0, 255] to [0.0, 1.0]
        mask[mask == 255.0] = 1.0

        # Apply augmentations (Albumentations will now see matching input sizes from the Resize transform)
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Add a channel dimension for the mask for consistency
        return image, mask.unsqueeze(0)

class MIMICCXR_GANDataset(Dataset):
    """
    Dataset for the main GAN training.
    - Loads a JPG image (as grayscale).
    - Returns the image, its Pneumonia label, and the one-hot encoded race group.
    """
    def __init__(self, df, image_dir, transform=None):
        self.df = df
        self.image_dir = image_dir
        self.transform = transform

        # Pre-process sensitive attributes
        self.df['race_group'] = self.df['race_group'].astype('category')
        self.race_categories = self.df['race_group'].cat.categories
        self.one_hot_races = pd.get_dummies(self.df['race_group'])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        subject_id = str(row['subject_id'])
        study_id = str(row['study_id'])
        dicom_id = row['dicom_id']

        image_path = os.path.join(
            self.image_dir, f'p{subject_id[:2]}', f'p{subject_id}', f's{study_id}', f'{dicom_id}.jpg'
        )

        # Load the image and convert to "RGB"
        image = Image.open(image_path).convert("RGB") # <-- CHANGE "L" to "RGB"
        image = np.array(image, dtype=np.float32)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        label = torch.tensor(row['Pneumonia'], dtype=torch.float32)
        race = torch.tensor(self.one_hot_races.iloc[idx].values, dtype=torch.float32)

        return image, label, race

In [6]:
CONFIG = {
    # --- Data Paths ---
    "IMAGE_DIR": "/content/drive/MyDrive/CAF-GAN/mimic-cxr-jpg-2.0.0/files/",
    "TRAIN_CSV_PATH": "/content/drive/MyDrive/CAF-GAN/data/splits/train.csv",
    "CDIAG_PATH": "/content/drive/MyDrive/CAF-GAN/outputs/cdiag/best_cdiag_colab.pth",
    "CSEG_PATH": "/content/drive/MyDrive/CAF-GAN/outputs/cseg/best_cseg_colab.pth",

    # --- Output Paths ---
    "OUTPUT_DIR": "/content/drive/MyDrive/CAF-GAN/outputs/gan/",
    "IMAGE_OUTPUT_DIR": "/content/drive/MyDrive/CAF-GAN/outputs/gan/generated_images/",

    # --- Model & Training Hyperparameters ---
    "IMG_SIZE": 256,
    "CHANNELS": 3,
    "LATENT_DIM": 128,
    "BATCH_SIZE": 16,  # GANs are memory intensive
    "EPOCHS": 200,     # GANs need many epochs to converge
   "G_LR": 0.0001,       # ⬅️ Optimized for your critic strengths
    "D_LR": 0.0001,       # ⬅️ Balanced with generator
    "B1": 0.5,
    "B2": 0.999,
    "N_CRITIC": 5,        # ⬅️ Standard WGAN-GP value

    # --- Loss Weights (CALIBRATED for your critic performance) ---
    "LAMBDA_GP": 10,
    "LAMBDA_FAIR": 0.15,   # ⬅️ Increased slightly (Cdiag: 66% accuracy)
    "LAMBDA_CLINIC": 0.08, # ⬅️ Increased slightly (Cseg: 57.6% Dice)

    # --- Adaptive Training Parameters ---
    "USE_GRADIENT_CLIPPING": True,
    "MAX_GRAD_NORM": 1.0,
    "LR_DECAY_FACTOR": 0.8,
    "LR_DECAY_PATIENCE": 10,

    # --- System & Logging ---
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "NUM_WORKERS": 2,
    "LOG_FREQ": 50,
    "SAVE_IMG_FREQ": 1,
    "SAVE_MODEL_FREQ": 10
}


In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
import torchvision.utils as vutils
import pandas as pd
from tqdm import tqdm
import os
import yaml
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import hashlib
import shutil

# --- Utility Functions ---
def weights_init(m):
    """Custom weights initialization called on netG and netD."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def gradient_penalty(critic, real, fake, device):
    """Calculates the gradient penalty for WGAN-GP."""
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)

    # Create interpolated images WITH gradient tracking
    interpolated_images = real * alpha + fake * (1 - alpha)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)

    # Compute gradients
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gp = torch.mean((gradient_norm - 1) ** 2)
    return gp

def is_valid_checkpoint(file_path):
    """Check if a checkpoint file is valid."""
    if not os.path.exists(file_path):
        return False

    try:
        checkpoint = torch.load(file_path, map_location='cpu', weights_only=False)
        required_keys = ['epoch', 'netG_state_dict', 'netD_state_dict',
                        'optimizerG_state_dict', 'optimizerD_state_dict',
                        'G_losses', 'D_losses']

        if not all(key in checkpoint for key in required_keys):
            return False

        return True
    except Exception as e:
        return False

def save_checkpoint(epoch, netG, netD, optimizerG, optimizerD, G_losses, D_losses, path):
    """Save training checkpoint with integrity verification."""
    try:
        temp_path = path + ".temp"

        torch.save({
            'epoch': epoch,
            'netG_state_dict': netG.state_dict(),
            'netD_state_dict': netD.state_dict(),
            'optimizerG_state_dict': optimizerG.state_dict(),
            'optimizerD_state_dict': optimizerD.state_dict(),
            'G_losses': G_losses,
            'D_losses': D_losses
        }, temp_path)

        if is_valid_checkpoint(temp_path):
            if os.path.exists(path):
                os.remove(path)
            shutil.move(temp_path, path)
            print(f"✅ Checkpoint saved successfully for epoch {epoch}")
            return True
        else:
            if os.path.exists(temp_path):
                os.remove(temp_path)
            return False
    except Exception as e:
        if os.path.exists(temp_path):
            os.remove(temp_path)
        return False

def load_checkpoint(path, netG, netD, optimizerG, optimizerD, device):
    """Load training checkpoint with error handling."""
    if not os.path.exists(path):
        return 0, [], []

    if not is_valid_checkpoint(path):
        corrupted_path = path + ".corrupted"
        if os.path.exists(path):
            shutil.move(path, corrupted_path)
        return 0, [], []

    try:
        checkpoint = torch.load(path, map_location=device, weights_only=False)
        netG.load_state_dict(checkpoint['netG_state_dict'])
        netD.load_state_dict(checkpoint['netD_state_dict'])
        optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
        optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
        print(f"✅ Checkpoint loaded successfully from epoch {checkpoint['epoch']}")
        return checkpoint['epoch'], checkpoint['G_losses'], checkpoint['D_losses']

    except Exception as e:
        corrupted_path = path + ".corrupted"
        if os.path.exists(path):
            shutil.move(path, corrupted_path)
        return 0, [], []

# --- 🚀 Main Training Execution ---
def run_gan_training(CONFIG):
    DEVICE = CONFIG['DEVICE']
    os.makedirs(CONFIG['IMAGE_OUTPUT_DIR'], exist_ok=True)
    os.makedirs(CONFIG['OUTPUT_DIR'], exist_ok=True)
    print(f"Using device: {DEVICE}")

    # --- Data Loading ---
    transform = A.Compose([
        A.Resize(CONFIG['IMG_SIZE'], CONFIG['IMG_SIZE']),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ])
    train_df = pd.read_csv(CONFIG['TRAIN_CSV_PATH'])
    train_dataset = MIMICCXR_GANDataset(train_df, CONFIG['IMAGE_DIR'], transform=transform)
    dataloader = DataLoader(train_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True,
                           num_workers=CONFIG['NUM_WORKERS'])

    # --- Initialize Models ---
    netG = Generator(CONFIG['LATENT_DIM'], CONFIG['CHANNELS']).to(DEVICE)
    netD = Discriminator(CONFIG['CHANNELS']).to(DEVICE)

    # --- Load Frozen Critics ---
    Cdiag = models.resnet50()
    Cdiag.fc = nn.Linear(Cdiag.fc.in_features, 1)
    Cdiag.load_state_dict(torch.load(CONFIG['CDIAG_PATH'], map_location=DEVICE))
    Cdiag.eval()
    for param in Cdiag.parameters():
        param.requires_grad = False
    Cdiag.to(DEVICE)

    Cseg = smp.Unet(encoder_name="resnet34", in_channels=3, classes=1)
    Cseg.load_state_dict(torch.load(CONFIG['CSEG_PATH'], map_location=DEVICE))
    Cseg.eval()
    for param in Cseg.parameters():
        param.requires_grad = False
    Cseg.to(DEVICE)

    # --- Optimizers ---
    optimizerD = optim.Adam(netD.parameters(), lr=CONFIG['D_LR'], betas=(CONFIG['B1'], CONFIG['B2']))
    optimizerG = optim.Adam(netG.parameters(), lr=CONFIG['G_LR'], betas=(CONFIG['B1'], CONFIG['B2']))

    # --- Load from Epoch 60 ---
    checkpoint_path = f"{CONFIG['OUTPUT_DIR']}/training_checkpoint.pth"
    netG_weights_path = f"{CONFIG['OUTPUT_DIR']}/netG_epoch_60.pth"
    netD_weights_path = f"{CONFIG['OUTPUT_DIR']}/netD_epoch_60.pth"

    if os.path.exists(checkpoint_path) and is_valid_checkpoint(checkpoint_path):
        start_epoch, G_losses, D_losses = load_checkpoint(checkpoint_path, netG, netD, optimizerG, optimizerD, DEVICE)
        print(f"Resuming from epoch {start_epoch}")
    elif os.path.exists(netG_weights_path) and os.path.exists(netD_weights_path):
        print("📦 Loading weights from epoch 60")
        netG.load_state_dict(torch.load(netG_weights_path, map_location=DEVICE))
        netD.load_state_dict(torch.load(netD_weights_path, map_location=DEVICE))
        start_epoch = 60
        G_losses, D_losses = [], []
    else:
        print("🚀 Starting from scratch")
        netG.apply(weights_init)
        netD.apply(weights_init)
        start_epoch = 0
        G_losses, D_losses = [], []

    print(f"✅ Starting from epoch {start_epoch}")

    fixed_noise = torch.randn(64, CONFIG['LATENT_DIM'], 1, 1, device=DEVICE)

    # --- Training Variables ---
    best_wasserstein = float('inf')
    lr_decay_counter = 0

    # --- 🏟️ The Grand Training Loop ---
    print("Starting Training Loop...")

    for epoch in range(start_epoch, CONFIG['EPOCHS']):
        try:
            epoch_g_losses = []
            epoch_d_losses = []
            epoch_gp_values = []
            epoch_wasserstein_values = []

            for i, (real_imgs, labels, races) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{CONFIG['EPOCHS']}")):
                real_imgs, labels, races = real_imgs.to(DEVICE), labels.to(DEVICE), races.to(DEVICE)

                # ---------------------------
                # Train Discriminator (Critic)
                # ---------------------------
                netD.zero_grad()

                # Generate fake images
                noise = torch.randn(real_imgs.size(0), CONFIG['LATENT_DIM'], 1, 1, device=DEVICE)
                with torch.no_grad():
                    fake_imgs = netG(noise)

                # Real images forward pass
                critic_real = netD(real_imgs).reshape(-1)
                # Fake images forward pass
                critic_fake = netD(fake_imgs).reshape(-1)

                # Calculate Wasserstein distance
                wasserstein_dist = torch.mean(critic_real) - torch.mean(critic_fake)
                epoch_wasserstein_values.append(wasserstein_dist.item())

                # Gradient penalty
                gp = gradient_penalty(netD, real_imgs, fake_imgs, DEVICE)

                # Discriminator loss
                loss_critic = -wasserstein_dist + CONFIG['LAMBDA_GP'] * gp
                loss_critic.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(netD.parameters(), max_norm=1.0)
                optimizerD.step()

                # Update generator every N_CRITIC steps
                if i % CONFIG['N_CRITIC'] == 0:
                    # ---------------------
                    # Train Generator
                    # ---------------------
                    netG.zero_grad()

                    # Generate fresh fake images
                    noise = torch.randn(real_imgs.size(0), CONFIG['LATENT_DIM'], 1, 1, device=DEVICE)
                    fake_imgs = netG(noise)

                    # 1. Adversarial Loss
                    critic_output = netD(fake_imgs)
                    adversarial_loss = -torch.mean(critic_output)

                    # 2. Fairness Loss
                    with torch.no_grad():
                        diag_output = Cdiag(fake_imgs)
                        diag_preds = torch.sigmoid(diag_output).squeeze()

                    fairness_loss = 0.0
                    tpr_per_group = []
                    for j in range(races.shape[1]):
                        group_mask = races[:, j] == 1
                        if group_mask.sum() > 0:
                            tpr_num = torch.sum(diag_preds[group_mask] * labels[group_mask])
                            tpr_den = torch.sum(labels[group_mask]) + 1e-6
                            tpr_per_group.append(tpr_num / tpr_den)

                    if len(tpr_per_group) > 1:
                        tpr_tensor = torch.stack(tpr_per_group)
                        if not torch.isnan(tpr_tensor).any():
                            fairness_loss = torch.std(tpr_tensor)  # Use standard deviation for better fairness

                    # 3. Clinical Loss
                    with torch.no_grad():
                        seg_masks = torch.sigmoid(Cseg(fake_imgs))
                        clinical_loss = torch.mean((1 - seg_masks) ** 2)

                    # Total Generator Loss
                    loss_gen = (adversarial_loss +
                               CONFIG['LAMBDA_FAIR'] * fairness_loss +
                               CONFIG['LAMBDA_CLINIC'] * clinical_loss)

                    loss_gen.backward()

                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(netG.parameters(), max_norm=1.0)
                    optimizerG.step()

                    # Store losses for logging
                    G_losses.append(loss_gen.item())
                    D_losses.append(loss_critic.item())
                    epoch_g_losses.append(loss_gen.item())
                    epoch_d_losses.append(loss_critic.item())
                    epoch_gp_values.append(gp.item())

                # --- Logging ---
                if i % CONFIG['LOG_FREQ'] == 0:
                    avg_g_loss = np.mean(epoch_g_losses[-10:]) if epoch_g_losses else 0
                    avg_d_loss = np.mean(epoch_d_losses[-10:]) if epoch_d_losses else 0
                    avg_gp = np.mean(epoch_gp_values[-10:]) if epoch_gp_values else 0
                    avg_wasserstein = np.mean(epoch_wasserstein_values[-10:]) if epoch_wasserstein_values else 0

                    print(f"[{epoch+1}/{CONFIG['EPOCHS']}][{i}/{len(dataloader)}] "
                          f"Loss_D: {avg_d_loss:.4f} "
                          f"Loss_G: {avg_g_loss:.4f} "
                          f"GP: {avg_gp:.4f} "
                          f"W_dist: {avg_wasserstein:.4f}")

                # --- Save checkpoint every 20 batches ---
                if i % 20 == 0:
                    save_checkpoint(epoch, netG, netD, optimizerG, optimizerD, G_losses, D_losses, checkpoint_path)

            # --- Validation & Monitoring ---
            # Calculate epoch statistics
            epoch_avg_g_loss = np.mean(epoch_g_losses) if epoch_g_losses else 0
            epoch_avg_d_loss = np.mean(epoch_d_losses) if epoch_d_losses else 0
            epoch_avg_wasserstein = np.mean(epoch_wasserstein_values) if epoch_wasserstein_values else 0

            print(f"\n📊 Epoch {epoch+1} Summary:")
            print(f"   Avg G Loss: {epoch_avg_g_loss:.4f}")
            print(f"   Avg D Loss: {epoch_avg_d_loss:.4f}")
            print(f"   Avg Wasserstein: {epoch_avg_wasserstein:.4f}")

            # Learning rate scheduling
            if epoch_avg_wasserstein < best_wasserstein:
                best_wasserstein = epoch_avg_wasserstein
                lr_decay_counter = 0
            else:
                lr_decay_counter += 1
                if lr_decay_counter >= 5:  # Patience of 5 epochs
                    for param_group in optimizerG.param_groups:
                        param_group['lr'] *= 0.8
                    for param_group in optimizerD.param_groups:
                        param_group['lr'] *= 0.8
                    lr_decay_counter = 0
                    print(f"📉 Learning rate reduced to: {optimizerG.param_groups[0]['lr']}")

            # --- Save Images ---
            if (epoch + 1) % CONFIG['SAVE_IMG_FREQ'] == 0:
                with torch.no_grad():
                    fake_grid = netG(fixed_noise).detach().cpu()
                vutils.save_image(fake_grid, f"{CONFIG['IMAGE_OUTPUT_DIR']}/epoch_{epoch+1}.png",
                                 normalize=True, nrow=8)
                print(f"🖼️  Saved generated images for epoch {epoch+1}")

            # --- Save model checkpoints ---
            if (epoch + 1) % CONFIG.get('SAVE_MODEL_FREQ', 10) == 0:
                torch.save(netG.state_dict(), f"{CONFIG['OUTPUT_DIR']}/netG_epoch_{epoch+1}.pth")
                torch.save(netD.state_dict(), f"{CONFIG['OUTPUT_DIR']}/netD_epoch_{epoch+1}.pth")
                print(f"💾 Saved model weights for epoch {epoch+1}")

            # --- Save checkpoint at end of epoch ---
            save_checkpoint(epoch, netG, netD, optimizerG, optimizerD, G_losses, D_losses, checkpoint_path)
            print(f"💾 Epoch {epoch+1} completed. Checkpoint saved.")

        except Exception as e:
            print(f"❌ Error occurred during epoch {epoch+1}: {e}")
            emergency_path = f"{CONFIG['OUTPUT_DIR']}/emergency_checkpoint_epoch_{epoch}.pth"
            save_checkpoint(epoch, netG, netD, optimizerG, optimizerD, G_losses, D_losses, emergency_path)
            print(f"💾 Emergency checkpoint saved: {emergency_path}")
            continue

    # Final cleanup and model saving
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)

    torch.save(netG.state_dict(), f"{CONFIG['OUTPUT_DIR']}/netG_final.pth")
    torch.save(netD.state_dict(), f"{CONFIG['OUTPUT_DIR']}/netD_final.pth")
    print("💾 Saved final models")
    print("✅ GAN Training Complete!")

    return netG, netD, G_losses, D_losses

# Run training
generator, discriminator, G_losses, D_losses = run_gan_training(CONFIG)

Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]

✅ Checkpoint loaded successfully from epoch 179
Resuming from epoch 179
✅ Starting from epoch 179
Starting Training Loop...


Epoch 180/200:   0%|          | 0/88 [00:00<?, ?it/s]

[180/200][0/88] Loss_D: -171.4308 Loss_G: 474.8717 GP: 7.1852 W_dist: 243.2830


Epoch 180/200:   1%|          | 1/88 [00:22<32:59, 22.75s/it]

✅ Checkpoint saved successfully for epoch 179


Epoch 180/200:  24%|██▍       | 21/88 [03:00<09:41,  8.68s/it]

✅ Checkpoint saved successfully for epoch 179


Epoch 180/200:  47%|████▋     | 41/88 [05:37<07:02,  8.99s/it]

✅ Checkpoint saved successfully for epoch 179


Epoch 180/200:  58%|█████▊    | 51/88 [07:02<05:36,  9.11s/it]

[180/200][50/88] Loss_D: -142.6861 Loss_G: 474.2354 GP: 7.9765 W_dist: 235.1087


Epoch 180/200:  69%|██████▉   | 61/88 [08:19<04:13,  9.40s/it]

✅ Checkpoint saved successfully for epoch 179


Epoch 180/200:  92%|█████████▏| 81/88 [10:55<01:05,  9.35s/it]

✅ Checkpoint saved successfully for epoch 179


Epoch 180/200: 100%|██████████| 88/88 [11:43<00:00,  8.00s/it]



📊 Epoch 180 Summary:
   Avg G Loss: 478.9924
   Avg D Loss: -147.2567
   Avg Wasserstein: 225.9839
🖼️  Saved generated images for epoch 180
💾 Saved model weights for epoch 180
✅ Checkpoint saved successfully for epoch 179
💾 Epoch 180 completed. Checkpoint saved.


Epoch 181/200:   0%|          | 0/88 [00:00<?, ?it/s]

[181/200][0/88] Loss_D: -125.1906 Loss_G: 491.2031 GP: 7.3790 W_dist: 198.9803


Epoch 181/200:   1%|          | 1/88 [00:06<09:40,  6.67s/it]

✅ Checkpoint saved successfully for epoch 180


Epoch 181/200:  24%|██▍       | 21/88 [00:44<02:57,  2.65s/it]

✅ Checkpoint saved successfully for epoch 180


Epoch 181/200:  47%|████▋     | 41/88 [01:20<01:57,  2.50s/it]

✅ Checkpoint saved successfully for epoch 180


Epoch 181/200:  58%|█████▊    | 51/88 [01:41<01:55,  3.12s/it]

[181/200][50/88] Loss_D: -148.8904 Loss_G: 477.2690 GP: 8.3979 W_dist: 210.0389


Epoch 181/200:  69%|██████▉   | 61/88 [01:57<01:02,  2.32s/it]

✅ Checkpoint saved successfully for epoch 180


Epoch 181/200:  92%|█████████▏| 81/88 [02:31<00:17,  2.53s/it]

✅ Checkpoint saved successfully for epoch 180


Epoch 181/200: 100%|██████████| 88/88 [02:38<00:00,  1.81s/it]



📊 Epoch 181 Summary:
   Avg G Loss: 474.1788
   Avg D Loss: -157.0204
   Avg Wasserstein: 227.4945
🖼️  Saved generated images for epoch 181
✅ Checkpoint saved successfully for epoch 180
💾 Epoch 181 completed. Checkpoint saved.


Epoch 182/200:   0%|          | 0/88 [00:00<?, ?it/s]

[182/200][0/88] Loss_D: -222.3769 Loss_G: 477.2451 GP: 3.8328 W_dist: 260.7048


Epoch 182/200:   1%|          | 1/88 [00:05<07:50,  5.41s/it]

✅ Checkpoint saved successfully for epoch 181


Epoch 182/200:  24%|██▍       | 21/88 [00:42<02:56,  2.64s/it]

✅ Checkpoint saved successfully for epoch 181


Epoch 182/200:  47%|████▋     | 41/88 [01:17<01:45,  2.25s/it]

✅ Checkpoint saved successfully for epoch 181


Epoch 182/200:  58%|█████▊    | 51/88 [01:36<01:29,  2.41s/it]

[182/200][50/88] Loss_D: -153.5938 Loss_G: 482.6190 GP: 7.5230 W_dist: 215.3286


Epoch 182/200:  69%|██████▉   | 61/88 [01:54<01:00,  2.24s/it]

✅ Checkpoint saved successfully for epoch 181


Epoch 182/200:  92%|█████████▏| 81/88 [02:30<00:16,  2.30s/it]

✅ Checkpoint saved successfully for epoch 181


Epoch 182/200: 100%|██████████| 88/88 [02:37<00:00,  1.79s/it]



📊 Epoch 182 Summary:
   Avg G Loss: 479.7925
   Avg D Loss: -162.7237
   Avg Wasserstein: 227.2374
🖼️  Saved generated images for epoch 182
✅ Checkpoint saved successfully for epoch 181
💾 Epoch 182 completed. Checkpoint saved.


Epoch 183/200:   0%|          | 0/88 [00:00<?, ?it/s]

[183/200][0/88] Loss_D: -118.7720 Loss_G: 484.4948 GP: 13.3427 W_dist: 252.1994


Epoch 183/200:   1%|          | 1/88 [00:06<09:53,  6.83s/it]

✅ Checkpoint saved successfully for epoch 182


Epoch 183/200:  24%|██▍       | 21/88 [00:43<03:15,  2.92s/it]

✅ Checkpoint saved successfully for epoch 182


Epoch 183/200:  47%|████▋     | 41/88 [01:20<02:07,  2.72s/it]

✅ Checkpoint saved successfully for epoch 182


Epoch 183/200:  58%|█████▊    | 51/88 [01:38<02:07,  3.44s/it]

[183/200][50/88] Loss_D: -151.8575 Loss_G: 473.8981 GP: 6.7849 W_dist: 224.7384


Epoch 183/200:  69%|██████▉   | 61/88 [01:56<01:01,  2.28s/it]

✅ Checkpoint saved successfully for epoch 182


Epoch 183/200:  92%|█████████▏| 81/88 [02:33<00:18,  2.68s/it]

✅ Checkpoint saved successfully for epoch 182


Epoch 183/200: 100%|██████████| 88/88 [02:38<00:00,  1.80s/it]



📊 Epoch 183 Summary:
   Avg G Loss: 471.7072
   Avg D Loss: -155.0500
   Avg Wasserstein: 230.9317
🖼️  Saved generated images for epoch 183
✅ Checkpoint saved successfully for epoch 182
💾 Epoch 183 completed. Checkpoint saved.


Epoch 184/200:   0%|          | 0/88 [00:00<?, ?it/s]

[184/200][0/88] Loss_D: -126.2955 Loss_G: 478.5457 GP: 10.5046 W_dist: 231.3416


Epoch 184/200:   1%|          | 1/88 [00:12<18:16, 12.60s/it]

✅ Checkpoint saved successfully for epoch 183


Epoch 184/200:  24%|██▍       | 21/88 [00:48<03:00,  2.70s/it]

✅ Checkpoint saved successfully for epoch 183


Epoch 184/200:  47%|████▋     | 41/88 [01:23<01:46,  2.26s/it]

✅ Checkpoint saved successfully for epoch 183


Epoch 184/200:  58%|█████▊    | 51/88 [01:41<02:11,  3.57s/it]

[184/200][50/88] Loss_D: -156.7090 Loss_G: 489.9765 GP: 8.1070 W_dist: 223.4107


Epoch 184/200:  69%|██████▉   | 61/88 [01:59<01:10,  2.60s/it]

✅ Checkpoint saved successfully for epoch 183


Epoch 184/200:  92%|█████████▏| 81/88 [02:33<00:14,  2.04s/it]

✅ Checkpoint saved successfully for epoch 183


Epoch 184/200: 100%|██████████| 88/88 [02:41<00:00,  1.84s/it]



📊 Epoch 184 Summary:
   Avg G Loss: 487.1384
   Avg D Loss: -151.9399
   Avg Wasserstein: 226.7221
🖼️  Saved generated images for epoch 184
✅ Checkpoint saved successfully for epoch 183
💾 Epoch 184 completed. Checkpoint saved.


Epoch 185/200:   0%|          | 0/88 [00:00<?, ?it/s]

[185/200][0/88] Loss_D: -147.7766 Loss_G: 484.1916 GP: 9.2723 W_dist: 240.4998


Epoch 185/200:   1%|          | 1/88 [00:07<10:11,  7.03s/it]

✅ Checkpoint saved successfully for epoch 184


Epoch 185/200:  24%|██▍       | 21/88 [00:38<02:50,  2.54s/it]

✅ Checkpoint saved successfully for epoch 184


Epoch 185/200:  47%|████▋     | 41/88 [01:09<01:44,  2.22s/it]

✅ Checkpoint saved successfully for epoch 184


Epoch 185/200:  58%|█████▊    | 51/88 [01:19<00:44,  1.20s/it]

[185/200][50/88] Loss_D: -141.4761 Loss_G: 486.9664 GP: 7.8230 W_dist: 247.7185


Epoch 185/200:  69%|██████▉   | 61/88 [01:41<01:08,  2.54s/it]

✅ Checkpoint saved successfully for epoch 184


Epoch 185/200:  92%|█████████▏| 81/88 [02:13<00:16,  2.36s/it]

✅ Checkpoint saved successfully for epoch 184


Epoch 185/200: 100%|██████████| 88/88 [02:17<00:00,  1.57s/it]



📊 Epoch 185 Summary:
   Avg G Loss: 488.4705
   Avg D Loss: -147.0326
   Avg Wasserstein: 230.7912
📉 Learning rate reduced to: 1.677721600000001e-05
🖼️  Saved generated images for epoch 185
✅ Checkpoint saved successfully for epoch 184
💾 Epoch 185 completed. Checkpoint saved.


Epoch 186/200:   0%|          | 0/88 [00:00<?, ?it/s]

[186/200][0/88] Loss_D: -181.7454 Loss_G: 492.0872 GP: 4.9094 W_dist: 230.8396


Epoch 186/200:   1%|          | 1/88 [00:05<08:22,  5.77s/it]

✅ Checkpoint saved successfully for epoch 185


Epoch 186/200:  24%|██▍       | 21/88 [00:35<02:25,  2.17s/it]

✅ Checkpoint saved successfully for epoch 185


Epoch 186/200:  47%|████▋     | 41/88 [01:04<01:36,  2.05s/it]

✅ Checkpoint saved successfully for epoch 185


Epoch 186/200:  58%|█████▊    | 51/88 [01:14<00:46,  1.25s/it]

[186/200][50/88] Loss_D: -162.4968 Loss_G: 487.9405 GP: 7.1554 W_dist: 244.1122


Epoch 186/200:  69%|██████▉   | 61/88 [01:34<00:59,  2.19s/it]

✅ Checkpoint saved successfully for epoch 185


Epoch 186/200:  92%|█████████▏| 81/88 [02:02<00:12,  1.83s/it]

✅ Checkpoint saved successfully for epoch 185


Epoch 186/200: 100%|██████████| 88/88 [02:08<00:00,  1.46s/it]



📊 Epoch 186 Summary:
   Avg G Loss: 485.5737
   Avg D Loss: -165.7050
   Avg Wasserstein: 231.5528
🖼️  Saved generated images for epoch 186
✅ Checkpoint saved successfully for epoch 185
💾 Epoch 186 completed. Checkpoint saved.


Epoch 187/200:   0%|          | 0/88 [00:00<?, ?it/s]

[187/200][0/88] Loss_D: -158.3387 Loss_G: 489.9925 GP: 8.4621 W_dist: 242.9592


Epoch 187/200:   1%|          | 1/88 [00:12<17:59, 12.41s/it]

✅ Checkpoint saved successfully for epoch 186


Epoch 187/200:  24%|██▍       | 21/88 [00:48<02:23,  2.14s/it]

✅ Checkpoint saved successfully for epoch 186


Epoch 187/200:  47%|████▋     | 41/88 [01:24<02:09,  2.75s/it]

✅ Checkpoint saved successfully for epoch 186


Epoch 187/200:  58%|█████▊    | 51/88 [01:41<02:06,  3.42s/it]

[187/200][50/88] Loss_D: -141.9164 Loss_G: 483.4769 GP: 8.4378 W_dist: 230.4830


Epoch 187/200:  69%|██████▉   | 61/88 [01:58<01:01,  2.29s/it]

✅ Checkpoint saved successfully for epoch 186


Epoch 187/200:  92%|█████████▏| 81/88 [02:33<00:18,  2.65s/it]

✅ Checkpoint saved successfully for epoch 186


Epoch 187/200: 100%|██████████| 88/88 [02:38<00:00,  1.81s/it]



📊 Epoch 187 Summary:
   Avg G Loss: 483.0196
   Avg D Loss: -155.3135
   Avg Wasserstein: 232.9001
🖼️  Saved generated images for epoch 187
✅ Checkpoint saved successfully for epoch 186
💾 Epoch 187 completed. Checkpoint saved.


Epoch 188/200:   0%|          | 0/88 [00:00<?, ?it/s]

[188/200][0/88] Loss_D: -145.7444 Loss_G: 463.8416 GP: 8.3025 W_dist: 228.7691


Epoch 188/200:   1%|          | 1/88 [00:11<16:36, 11.46s/it]

✅ Checkpoint saved successfully for epoch 187


Epoch 188/200:  24%|██▍       | 21/88 [00:44<02:45,  2.47s/it]

✅ Checkpoint saved successfully for epoch 187


Epoch 188/200:  47%|████▋     | 41/88 [01:13<01:52,  2.40s/it]

✅ Checkpoint saved successfully for epoch 187


Epoch 188/200:  58%|█████▊    | 51/88 [01:21<00:43,  1.18s/it]

[188/200][50/88] Loss_D: -164.1176 Loss_G: 484.6324 GP: 8.6970 W_dist: 241.2760


Epoch 188/200:  69%|██████▉   | 61/88 [01:43<01:07,  2.50s/it]

✅ Checkpoint saved successfully for epoch 187


Epoch 188/200:  92%|█████████▏| 81/88 [02:12<00:14,  2.14s/it]

✅ Checkpoint saved successfully for epoch 187


Epoch 188/200: 100%|██████████| 88/88 [02:16<00:00,  1.55s/it]



📊 Epoch 188 Summary:
   Avg G Loss: 481.9117
   Avg D Loss: -164.7181
   Avg Wasserstein: 232.5160
🖼️  Saved generated images for epoch 188
✅ Checkpoint saved successfully for epoch 187
💾 Epoch 188 completed. Checkpoint saved.


Epoch 189/200:   0%|          | 0/88 [00:00<?, ?it/s]

[189/200][0/88] Loss_D: -220.6870 Loss_G: 441.7076 GP: 5.1191 W_dist: 271.8784


Epoch 189/200:   1%|          | 1/88 [00:05<08:08,  5.61s/it]

✅ Checkpoint saved successfully for epoch 188


Epoch 189/200:  24%|██▍       | 21/88 [00:44<03:03,  2.74s/it]

✅ Checkpoint saved successfully for epoch 188


Epoch 189/200:  47%|████▋     | 41/88 [01:15<01:45,  2.24s/it]

✅ Checkpoint saved successfully for epoch 188


Epoch 189/200:  58%|█████▊    | 51/88 [01:26<00:54,  1.47s/it]

[189/200][50/88] Loss_D: -150.0056 Loss_G: 482.9015 GP: 7.7582 W_dist: 230.1839


Epoch 189/200:  69%|██████▉   | 61/88 [01:46<00:54,  2.01s/it]

✅ Checkpoint saved successfully for epoch 188


Epoch 189/200:  92%|█████████▏| 81/88 [02:19<00:17,  2.56s/it]

✅ Checkpoint saved successfully for epoch 188


Epoch 189/200: 100%|██████████| 88/88 [02:24<00:00,  1.64s/it]



📊 Epoch 189 Summary:
   Avg G Loss: 480.9448
   Avg D Loss: -149.0154
   Avg Wasserstein: 230.0408
🖼️  Saved generated images for epoch 189
✅ Checkpoint saved successfully for epoch 188
💾 Epoch 189 completed. Checkpoint saved.


Epoch 190/200:   0%|          | 0/88 [00:00<?, ?it/s]

[190/200][0/88] Loss_D: -184.3097 Loss_G: 480.5105 GP: 6.3680 W_dist: 247.9893


Epoch 190/200:   1%|          | 1/88 [00:12<18:09, 12.52s/it]

✅ Checkpoint saved successfully for epoch 189


Epoch 190/200:  24%|██▍       | 21/88 [00:43<02:10,  1.95s/it]

✅ Checkpoint saved successfully for epoch 189


Epoch 190/200:  47%|████▋     | 41/88 [01:14<01:33,  1.99s/it]

✅ Checkpoint saved successfully for epoch 189


Epoch 190/200:  58%|█████▊    | 51/88 [01:31<01:30,  2.43s/it]

[190/200][50/88] Loss_D: -172.2256 Loss_G: 481.0980 GP: 6.8985 W_dist: 228.0785


Epoch 190/200:  69%|██████▉   | 61/88 [01:46<00:56,  2.10s/it]

✅ Checkpoint saved successfully for epoch 189


Epoch 190/200:  92%|█████████▏| 81/88 [02:16<00:13,  1.96s/it]

✅ Checkpoint saved successfully for epoch 189


Epoch 190/200: 100%|██████████| 88/88 [02:20<00:00,  1.60s/it]



📊 Epoch 190 Summary:
   Avg G Loss: 485.9466
   Avg D Loss: -166.8772
   Avg Wasserstein: 231.4606
📉 Learning rate reduced to: 1.3421772800000009e-05
🖼️  Saved generated images for epoch 190
💾 Saved model weights for epoch 190
✅ Checkpoint saved successfully for epoch 189
💾 Epoch 190 completed. Checkpoint saved.


Epoch 191/200:   0%|          | 0/88 [00:00<?, ?it/s]

[191/200][0/88] Loss_D: -147.3006 Loss_G: 484.7753 GP: 5.7602 W_dist: 204.9022


Epoch 191/200:   1%|          | 1/88 [00:05<08:00,  5.52s/it]

✅ Checkpoint saved successfully for epoch 190


Epoch 191/200:  24%|██▍       | 21/88 [00:36<02:13,  2.00s/it]

✅ Checkpoint saved successfully for epoch 190


Epoch 191/200:  47%|████▋     | 41/88 [01:05<01:37,  2.07s/it]

✅ Checkpoint saved successfully for epoch 190


Epoch 191/200:  58%|█████▊    | 51/88 [01:15<00:45,  1.24s/it]

[191/200][50/88] Loss_D: -147.7907 Loss_G: 488.1047 GP: 7.3662 W_dist: 220.0658


Epoch 191/200:  69%|██████▉   | 61/88 [01:34<00:53,  1.98s/it]

✅ Checkpoint saved successfully for epoch 190


Epoch 191/200:  92%|█████████▏| 81/88 [02:03<00:13,  1.97s/it]

✅ Checkpoint saved successfully for epoch 190


Epoch 191/200: 100%|██████████| 88/88 [02:08<00:00,  1.46s/it]



📊 Epoch 191 Summary:
   Avg G Loss: 488.2689
   Avg D Loss: -159.8587
   Avg Wasserstein: 235.4449
🖼️  Saved generated images for epoch 191
✅ Checkpoint saved successfully for epoch 190
💾 Epoch 191 completed. Checkpoint saved.


Epoch 192/200:   0%|          | 0/88 [00:00<?, ?it/s]

[192/200][0/88] Loss_D: -193.5824 Loss_G: 477.3322 GP: 7.3640 W_dist: 267.2220


Epoch 192/200:   1%|          | 1/88 [00:12<18:36, 12.83s/it]

✅ Checkpoint saved successfully for epoch 191


Epoch 192/200:  24%|██▍       | 21/88 [00:41<02:28,  2.22s/it]

✅ Checkpoint saved successfully for epoch 191


Epoch 192/200:  47%|████▋     | 41/88 [01:11<01:41,  2.16s/it]

✅ Checkpoint saved successfully for epoch 191


Epoch 192/200:  58%|█████▊    | 51/88 [01:21<00:42,  1.15s/it]

[192/200][50/88] Loss_D: -153.5020 Loss_G: 487.3551 GP: 9.1457 W_dist: 237.0662


Epoch 192/200:  69%|██████▉   | 61/88 [01:42<01:01,  2.29s/it]

✅ Checkpoint saved successfully for epoch 191


Epoch 192/200:  92%|█████████▏| 81/88 [02:11<00:15,  2.18s/it]

✅ Checkpoint saved successfully for epoch 191


Epoch 192/200: 100%|██████████| 88/88 [02:16<00:00,  1.55s/it]



📊 Epoch 192 Summary:
   Avg G Loss: 486.0107
   Avg D Loss: -165.8953
   Avg Wasserstein: 233.0714
🖼️  Saved generated images for epoch 192
✅ Checkpoint saved successfully for epoch 191
💾 Epoch 192 completed. Checkpoint saved.


Epoch 193/200:   0%|          | 0/88 [00:00<?, ?it/s]

[193/200][0/88] Loss_D: -219.9183 Loss_G: 491.7072 GP: 3.2458 W_dist: 252.3765


Epoch 193/200:   1%|          | 1/88 [00:05<07:24,  5.11s/it]

✅ Checkpoint saved successfully for epoch 192


Epoch 193/200:  24%|██▍       | 21/88 [00:33<02:14,  2.01s/it]

✅ Checkpoint saved successfully for epoch 192


Epoch 193/200:  47%|████▋     | 41/88 [01:02<01:23,  1.77s/it]

✅ Checkpoint saved successfully for epoch 192


Epoch 193/200:  58%|█████▊    | 51/88 [01:18<01:23,  2.27s/it]

[193/200][50/88] Loss_D: -156.0754 Loss_G: 487.1824 GP: 7.7925 W_dist: 235.2051


Epoch 193/200:  69%|██████▉   | 61/88 [01:32<00:52,  1.96s/it]

✅ Checkpoint saved successfully for epoch 192


Epoch 193/200:  92%|█████████▏| 81/88 [02:05<00:15,  2.17s/it]

✅ Checkpoint saved successfully for epoch 192


Epoch 193/200: 100%|██████████| 88/88 [02:09<00:00,  1.47s/it]



📊 Epoch 193 Summary:
   Avg G Loss: 489.8462
   Avg D Loss: -156.0787
   Avg Wasserstein: 236.5941
🖼️  Saved generated images for epoch 193
✅ Checkpoint saved successfully for epoch 192
💾 Epoch 193 completed. Checkpoint saved.


Epoch 194/200:   0%|          | 0/88 [00:00<?, ?it/s]

[194/200][0/88] Loss_D: -50.0643 Loss_G: 501.8801 GP: 11.9358 W_dist: 169.4219


Epoch 194/200:   1%|          | 1/88 [00:06<09:18,  6.42s/it]

✅ Checkpoint saved successfully for epoch 193


Epoch 194/200:  24%|██▍       | 21/88 [00:41<02:24,  2.16s/it]

✅ Checkpoint saved successfully for epoch 193


Epoch 194/200:  47%|████▋     | 41/88 [01:11<01:27,  1.86s/it]

✅ Checkpoint saved successfully for epoch 193


Epoch 194/200:  58%|█████▊    | 51/88 [01:20<00:36,  1.01it/s]

[194/200][50/88] Loss_D: -157.9696 Loss_G: 496.4160 GP: 7.6409 W_dist: 229.1746


Epoch 194/200:  69%|██████▉   | 61/88 [01:40<00:59,  2.22s/it]

✅ Checkpoint saved successfully for epoch 193


Epoch 194/200:  92%|█████████▏| 81/88 [02:10<00:16,  2.42s/it]

✅ Checkpoint saved successfully for epoch 193


Epoch 194/200: 100%|██████████| 88/88 [02:14<00:00,  1.53s/it]



📊 Epoch 194 Summary:
   Avg G Loss: 492.6748
   Avg D Loss: -147.1527
   Avg Wasserstein: 232.1793
🖼️  Saved generated images for epoch 194
✅ Checkpoint saved successfully for epoch 193
💾 Epoch 194 completed. Checkpoint saved.


Epoch 195/200:   0%|          | 0/88 [00:00<?, ?it/s]

[195/200][0/88] Loss_D: -141.0205 Loss_G: 501.8983 GP: 6.9269 W_dist: 210.2899


Epoch 195/200:   1%|          | 1/88 [00:06<09:10,  6.32s/it]

✅ Checkpoint saved successfully for epoch 194


Epoch 195/200:  24%|██▍       | 21/88 [00:43<02:46,  2.48s/it]

✅ Checkpoint saved successfully for epoch 194


Epoch 195/200:  47%|████▋     | 41/88 [01:12<01:49,  2.32s/it]

✅ Checkpoint saved successfully for epoch 194


Epoch 195/200:  58%|█████▊    | 51/88 [01:21<00:39,  1.07s/it]

[195/200][50/88] Loss_D: -150.1340 Loss_G: 485.8425 GP: 7.6897 W_dist: 223.9058


Epoch 195/200:  69%|██████▉   | 61/88 [01:41<00:53,  1.99s/it]

✅ Checkpoint saved successfully for epoch 194


Epoch 195/200:  92%|█████████▏| 81/88 [02:10<00:14,  2.14s/it]

✅ Checkpoint saved successfully for epoch 194


Epoch 195/200: 100%|██████████| 88/88 [02:15<00:00,  1.54s/it]



📊 Epoch 195 Summary:
   Avg G Loss: 484.2170
   Avg D Loss: -149.4112
   Avg Wasserstein: 235.4900
📉 Learning rate reduced to: 1.0737418240000008e-05
🖼️  Saved generated images for epoch 195
✅ Checkpoint saved successfully for epoch 194
💾 Epoch 195 completed. Checkpoint saved.


Epoch 196/200:   0%|          | 0/88 [00:00<?, ?it/s]

[196/200][0/88] Loss_D: -228.4435 Loss_G: 485.6363 GP: 2.8812 W_dist: 257.2557


Epoch 196/200:   1%|          | 1/88 [00:06<08:56,  6.16s/it]

✅ Checkpoint saved successfully for epoch 195


Epoch 196/200:  24%|██▍       | 21/88 [00:34<02:31,  2.26s/it]

✅ Checkpoint saved successfully for epoch 195


Epoch 196/200:  47%|████▋     | 41/88 [01:06<01:32,  1.98s/it]

✅ Checkpoint saved successfully for epoch 195


Epoch 196/200:  58%|█████▊    | 51/88 [01:15<00:41,  1.12s/it]

[196/200][50/88] Loss_D: -158.9699 Loss_G: 493.9249 GP: 7.0180 W_dist: 236.0588


Epoch 196/200:  69%|██████▉   | 61/88 [01:35<00:52,  1.95s/it]

✅ Checkpoint saved successfully for epoch 195


Epoch 196/200:  92%|█████████▏| 81/88 [02:03<00:12,  1.76s/it]

✅ Checkpoint saved successfully for epoch 195


Epoch 196/200: 100%|██████████| 88/88 [02:07<00:00,  1.45s/it]



📊 Epoch 196 Summary:
   Avg G Loss: 491.3757
   Avg D Loss: -164.2313
   Avg Wasserstein: 234.8584
🖼️  Saved generated images for epoch 196
✅ Checkpoint saved successfully for epoch 195
💾 Epoch 196 completed. Checkpoint saved.


Epoch 197/200:   0%|          | 0/88 [00:00<?, ?it/s]

[197/200][0/88] Loss_D: -117.5037 Loss_G: 492.2493 GP: 9.0760 W_dist: 208.2639


Epoch 197/200:   1%|          | 1/88 [00:04<06:52,  4.74s/it]

✅ Checkpoint saved successfully for epoch 196


Epoch 197/200:  24%|██▍       | 21/88 [00:33<02:07,  1.91s/it]

✅ Checkpoint saved successfully for epoch 196


Epoch 197/200:  47%|████▋     | 41/88 [01:02<01:35,  2.02s/it]

✅ Checkpoint saved successfully for epoch 196


Epoch 197/200:  58%|█████▊    | 51/88 [01:12<00:46,  1.24s/it]

[197/200][50/88] Loss_D: -150.7611 Loss_G: 491.6723 GP: 8.4748 W_dist: 232.3374


Epoch 197/200:  69%|██████▉   | 61/88 [01:32<00:57,  2.11s/it]

✅ Checkpoint saved successfully for epoch 196


Epoch 197/200:  92%|█████████▏| 81/88 [02:04<00:16,  2.31s/it]

✅ Checkpoint saved successfully for epoch 196


Epoch 197/200: 100%|██████████| 88/88 [02:08<00:00,  1.46s/it]



📊 Epoch 197 Summary:
   Avg G Loss: 490.1603
   Avg D Loss: -155.2775
   Avg Wasserstein: 236.3759
🖼️  Saved generated images for epoch 197
✅ Checkpoint saved successfully for epoch 196
💾 Epoch 197 completed. Checkpoint saved.


Epoch 198/200:   0%|          | 0/88 [00:00<?, ?it/s]

[198/200][0/88] Loss_D: -142.3983 Loss_G: 510.5681 GP: 6.9995 W_dist: 212.3931


Epoch 198/200:   1%|          | 1/88 [00:12<18:24, 12.69s/it]

✅ Checkpoint saved successfully for epoch 197


Epoch 198/200:  24%|██▍       | 21/88 [00:41<02:20,  2.10s/it]

✅ Checkpoint saved successfully for epoch 197


Epoch 198/200:  47%|████▋     | 41/88 [01:21<01:38,  2.09s/it]

✅ Checkpoint saved successfully for epoch 197


Epoch 198/200:  58%|█████▊    | 51/88 [01:30<00:45,  1.22s/it]

[198/200][50/88] Loss_D: -167.1311 Loss_G: 498.1936 GP: 8.5751 W_dist: 232.3080


Epoch 198/200:  69%|██████▉   | 61/88 [01:50<00:52,  1.95s/it]

✅ Checkpoint saved successfully for epoch 197


Epoch 198/200:  92%|█████████▏| 81/88 [02:22<00:12,  1.85s/it]

✅ Checkpoint saved successfully for epoch 197


Epoch 198/200: 100%|██████████| 88/88 [02:28<00:00,  1.68s/it]



📊 Epoch 198 Summary:
   Avg G Loss: 495.8750
   Avg D Loss: -163.4584
   Avg Wasserstein: 235.1242
🖼️  Saved generated images for epoch 198
✅ Checkpoint saved successfully for epoch 197
💾 Epoch 198 completed. Checkpoint saved.


Epoch 199/200:   0%|          | 0/88 [00:00<?, ?it/s]

[199/200][0/88] Loss_D: -31.6336 Loss_G: 499.9905 GP: 24.8507 W_dist: 280.1410


Epoch 199/200:   1%|          | 1/88 [00:05<08:37,  5.95s/it]

✅ Checkpoint saved successfully for epoch 198


Epoch 199/200:  24%|██▍       | 21/88 [00:34<02:08,  1.91s/it]

✅ Checkpoint saved successfully for epoch 198


Epoch 199/200:  47%|████▋     | 41/88 [01:03<01:25,  1.83s/it]

✅ Checkpoint saved successfully for epoch 198


Epoch 199/200:  58%|█████▊    | 51/88 [01:19<01:41,  2.74s/it]

[199/200][50/88] Loss_D: -171.4927 Loss_G: 481.6232 GP: 7.6678 W_dist: 229.7490


Epoch 199/200:  69%|██████▉   | 61/88 [01:32<00:48,  1.79s/it]

✅ Checkpoint saved successfully for epoch 198


Epoch 199/200:  92%|█████████▏| 81/88 [02:03<00:14,  2.11s/it]

✅ Checkpoint saved successfully for epoch 198


Epoch 199/200: 100%|██████████| 88/88 [02:08<00:00,  1.46s/it]



📊 Epoch 199 Summary:
   Avg G Loss: 488.5639
   Avg D Loss: -160.4400
   Avg Wasserstein: 236.8000
🖼️  Saved generated images for epoch 199
✅ Checkpoint saved successfully for epoch 198
💾 Epoch 199 completed. Checkpoint saved.


Epoch 200/200:   0%|          | 0/88 [00:00<?, ?it/s]

[200/200][0/88] Loss_D: -137.2825 Loss_G: 488.0953 GP: 6.5719 W_dist: 203.0019


Epoch 200/200:   1%|          | 1/88 [00:04<06:40,  4.60s/it]

✅ Checkpoint saved successfully for epoch 199


Epoch 200/200:  24%|██▍       | 21/88 [00:34<02:21,  2.11s/it]

✅ Checkpoint saved successfully for epoch 199


Epoch 200/200:  47%|████▋     | 41/88 [01:06<01:51,  2.38s/it]

✅ Checkpoint saved successfully for epoch 199


Epoch 200/200:  58%|█████▊    | 51/88 [01:14<00:40,  1.10s/it]

[200/200][50/88] Loss_D: -139.1160 Loss_G: 484.8988 GP: 8.7766 W_dist: 224.4685


Epoch 200/200:  69%|██████▉   | 61/88 [01:35<01:00,  2.25s/it]

✅ Checkpoint saved successfully for epoch 199


Epoch 200/200:  92%|█████████▏| 81/88 [02:05<00:15,  2.17s/it]

✅ Checkpoint saved successfully for epoch 199


Epoch 200/200: 100%|██████████| 88/88 [02:09<00:00,  1.47s/it]



📊 Epoch 200 Summary:
   Avg G Loss: 489.1017
   Avg D Loss: -139.5694
   Avg Wasserstein: 231.5773
📉 Learning rate reduced to: 8.589934592000007e-06
🖼️  Saved generated images for epoch 200
💾 Saved model weights for epoch 200
✅ Checkpoint saved successfully for epoch 199
💾 Epoch 200 completed. Checkpoint saved.
💾 Saved final models
✅ GAN Training Complete!
