In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models
import requests
import json
import random
from tqdm.notebook import tqdm # Use notebook version for Colab
from PIL import Image
import cv2 # For edge detection (Sobel)

SEED = 42

# Reproducibility: seed the Python, NumPy and PyTorch (CPU) random generators.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    # Drop any GPU memory cached by a previous run in this kernel.
    torch.cuda.empty_cache()
    # Seed every GPU and force deterministic cuDNN kernels (slower, but repeatable).
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")

PyTorch Version: 2.6.0+cu124
CUDA Available: True
CUDA Device Name: NVIDIA A100-SXM4-40GB


In [None]:
# Cell 2: Configuration
# Every tunable setting for this run lives here; later cells only read these names.

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# ------------------------------------------------------------------ data
QUICKDRAW_CATEGORIES = [
    'cat', 'dog', 'house', 'tree',
    'bicycle', 'car', 'face', 'flower',
]
DATA_DIR = './data/quickdraw'
MAX_SAMPLES_PER_CATEGORY = 5000

# ------------------------------------------------------------------ model
IMAGE_SIZE = 128                # rasterized sketch side length (pixels)
IN_CHANNELS = OUT_CHANNELS = 1  # grayscale in, grayscale out
FEATURES_G = 48                 # generator base channel width
FEATURES_D = 64                 # critic (discriminator) base channel width

# ------------------------------------------------------------------ optimisation (WGAN-GP)
GENERATOR_LR = DISCRIMINATOR_LR = 1e-4
ADAM_BETA1, ADAM_BETA2 = 0.0, 0.9

# Reconstruction loss weights (rebalanced for this run)
MASK_WEIGHT = 10.0
PERCEPTUAL_WEIGHT = 0.5

LAMBDA_GP = 10  # gradient-penalty coefficient
N_CRITIC = 5    # critic updates per generator update

BATCH_SIZE = 16
NUM_EPOCHS = 150  # upper bound; the run is expected to be stopped earlier

# ------------------------------------------------------------------ masking
MASK_MODE = 'line_aware'
MASK_SQUARE_SIZE = IMAGE_SIZE // 4

# ------------------------------------------------------------------ generator warm-start
# Disabled: this is the first run with the current (attention) architecture.
LOAD_GENERATOR_CHECKPOINT = False
GENERATOR_CHECKPOINT_PATH = None

# ------------------------------------------------------------------ outputs
CHECKPOINT_DIR = './sketch_completion_checkpoints_gan_v9_wgan_attn'
SAMPLE_DIR = './sketch_completion_samples_gan_v9_wgan_attn'
LOSS_PLOT_FILE = 'sketch_completion_loss_gan_v9_wgan_attn.png'

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(SAMPLE_DIR, exist_ok=True)

Using device: cuda


In [None]:
def download_quickdraw_dataset(categories, save_dir, timeout=60):
    """Downloads QuickDraw "simplified" ndjson files for the given categories.

    Files that already exist in ``save_dir`` are skipped. Each download is
    streamed into a temporary ``<category>.ndjson.part`` file and renamed into
    place only once it has completed. An interrupted run therefore never leaves
    a truncated ``<category>.ndjson`` that a later run would mistake for a
    finished download.

    Args:
        categories: QuickDraw category names (spaces are allowed and URL-encoded).
        save_dir: Directory that receives ``<category>.ndjson`` (created if missing).
        timeout: Seconds passed to ``requests.get`` as its connect/read timeout.
            ``None`` waits indefinitely.

    Failures are reported with ``print`` and do not stop the remaining downloads.
    """
    os.makedirs(save_dir, exist_ok=True)
    base_url = "https://storage.googleapis.com/quickdraw_dataset/full/simplified/"

    print(f"Starting QuickDraw download to {save_dir}...")
    for category in tqdm(categories, desc="Downloading categories"):
        url = f"{base_url}{category.replace(' ', '%20')}.ndjson"  # Handle spaces
        file_path = os.path.join(save_dir, f"{category}.ndjson")

        if os.path.exists(file_path):
            continue

        tmp_path = file_path + ".part"
        try:
            # The context manager closes the streamed connection even on error.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            # Rename only after the whole body was written (atomic on the same filesystem).
            os.replace(tmp_path, file_path)
        except requests.exceptions.RequestException as e:
            print(f"Failed to download {category}: {e}")
        except Exception as e:
            print(f"An unexpected error occurred for {category}: {e}")
        finally:
            # On success the .part file was renamed away; anything left over is partial.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    print("QuickDraw download process finished.")

In [None]:
class SketchDataset(Dataset):
    """QuickDraw sketches rasterized on the fly, each paired with a masked copy.

    ``__getitem__`` returns a dict with keys:
        'complete': the full rasterized sketch,
        'partial':  the same sketch with the masked square zeroed out,
        'mask':     binary mask, 1 inside the removed square and 0 elsewhere.
    All three are passed through ``transform`` (``ToTensor`` if none is given).

    Args:
        data_dir: Directory containing ``<category>.ndjson`` files.
        categories: Category names. A sketch's stored label is its category's index here.
        image_size: Side length of the square rasterized image.
        transform: Optional callable applied to each PIL image.
        max_samples_per_category: Maximum number of recognized drawings to load per
            category. ``None`` loads all of them.
        mask_mode: 'line_aware' centers the square on a stroke pixel and falls back
            to 'random' on blank images. 'random' places the square uniformly.
        mask_square_size: Side length of the square mask in pixels.

    Raises:
        ValueError: If ``mask_mode`` is not one of ``MASK_MODES``.
    """

    MASK_MODES = ('line_aware', 'random')

    def __init__(self, data_dir, categories, image_size=128, transform=None,
                 max_samples_per_category=1000, mask_mode='random', mask_square_size=32):
        if mask_mode not in self.MASK_MODES:
            # An unrecognised mode would otherwise fall through _create_mask and
            # silently mask the *entire* image.
            raise ValueError(f"Unknown mask_mode {mask_mode!r}; expected one of {self.MASK_MODES}")
        self.data_dir = data_dir
        self.categories = categories
        self.image_size = image_size
        self.transform = transform
        self.max_samples = max_samples_per_category
        self.mask_mode = mask_mode
        self.mask_square_size = mask_square_size
        self.sketches = []
        self.labels = []
        print("Loading dataset...")
        self._load_data()
        print(f"Dataset loaded with {len(self.sketches)} samples.")

    def _load_data(self):
        """Reads up to ``max_samples`` recognized drawings per category into memory.

        Each drawing's stroke list goes to ``self.sketches`` and its category index
        to ``self.labels``. Missing files, malformed JSON lines and lines without
        a 'drawing' key are skipped. Any other read error is reported, and the
        rest of that file is abandoned.
        """
        for label, category in enumerate(self.categories):
            file_path = os.path.join(self.data_dir, f"{category}.ndjson")
            if not os.path.exists(file_path):
                print(f"Warning: File not found for category '{category}', skipping.")
                continue
            count = 0
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    for line in tqdm(f, desc=f"Loading '{category}'"):
                        if self.max_samples is not None and count >= self.max_samples:
                            break
                        try:
                            data = json.loads(line)
                            if data.get('recognized', False):
                                self.sketches.append(data['drawing'])
                                self.labels.append(label)
                                count += 1
                        except (json.JSONDecodeError, KeyError):
                            continue  # skip malformed lines / lines missing 'drawing'
            except Exception as e:
                print(f"Error reading file {file_path}: {e}")

    def _sketch_to_image(self, drawing_strokes):
        """Rasterizes strokes into a uint8 image with 1-px white (255) lines on black.

        Each stroke's first two entries are the x and y coordinate lists. They are
        scaled from the 0-255 range onto ``image_size - 1``.
        """
        image = np.zeros((self.image_size, self.image_size), dtype=np.uint8)
        scale = (self.image_size - 1) / 255.0
        for stroke in drawing_strokes:
            xs = np.array(stroke[0]) * scale
            ys = np.array(stroke[1]) * scale
            for i in range(len(xs) - 1):
                p1 = (int(round(xs[i])), int(round(ys[i])))
                p2 = (int(round(xs[i + 1])), int(round(ys[i + 1])))
                cv2.line(image, p1, p2, color=255, thickness=1)
        return image

    def _create_mask(self, complete_img_np):
        """Returns a float32 mask shaped like the image, with 1.0 inside the square to remove.

        'line_aware': Sobel gradient magnitude (rescaled to 0-255) marks pixels near
        strokes. The square is centered on a random such pixel and clipped at the
        border, so it can be smaller than ``mask_square_size`` near edges.
        'random' (also the fallback for blank images): the square is placed uniformly.
        If the square does not fit the image, the whole image is masked.
        """
        mask = np.zeros_like(complete_img_np, dtype=np.float32)
        h, w = complete_img_np.shape
        size = self.mask_square_size
        current_mask_mode = self.mask_mode
        y1, y2, x1, x2 = 0, h, 0, w

        if current_mask_mode == 'line_aware':
            sobelx = cv2.Sobel(complete_img_np, cv2.CV_64F, 1, 0, ksize=5)
            sobely = cv2.Sobel(complete_img_np, cv2.CV_64F, 0, 1, ksize=5)
            gradient_magnitude = np.sqrt(sobelx ** 2 + sobely ** 2)
            if gradient_magnitude.max() > 0:
                gradient_magnitude = (gradient_magnitude / gradient_magnitude.max()) * 255
            gradient_magnitude = gradient_magnitude.astype(np.uint8)
            high_grad_points = np.argwhere(gradient_magnitude > 100)
            if len(high_grad_points) > 0:
                center_y, center_x = high_grad_points[random.randrange(len(high_grad_points))]
                half, rem = divmod(size, 2)
                y1 = max(0, center_y - half)
                y2 = min(h, center_y + half + rem)
                x1 = max(0, center_x - half)
                x2 = min(w, center_x + half + rem)
            else:
                current_mask_mode = 'random'  # blank image: no strokes to aim at

        if current_mask_mode == 'random':
            max_y_start = h - size
            max_x_start = w - size
            if max_y_start >= 0 and max_x_start >= 0:
                y1 = random.randint(0, max_y_start)
                x1 = random.randint(0, max_x_start)
                y2 = y1 + size
                x2 = x1 + size

        mask[y1:y2, x1:x2] = 1.0
        return mask

    def __len__(self):
        return len(self.sketches)

    def __getitem__(self, idx):
        complete_img_np = self._sketch_to_image(self.sketches[idx])
        mask_np = self._create_mask(complete_img_np)
        partial_img_np = complete_img_np * (1 - mask_np)

        convert = self.transform if self.transform else transforms.ToTensor()
        complete_tensor = convert(Image.fromarray(complete_img_np))
        partial_tensor = convert(Image.fromarray(partial_img_np.astype(np.uint8)))
        mask_tensor = convert(Image.fromarray((mask_np * 255).astype(np.uint8)))
        # Re-binarize in case the transform resized/interpolated the mask.
        mask_tensor = (mask_tensor > 0.5).float()
        return {'complete': complete_tensor, 'partial': partial_tensor, 'mask': mask_tensor}

In [None]:
class SpatialAttention(nn.Module):
    """CBAM-style spatial attention that rescales every spatial position of its input.

    The channel-wise mean map and max map are stacked into a 2-channel descriptor.
    This is passed through one ``kernel_size`` convolution and a sigmoid, and the
    resulting single-channel weight map is multiplied onto the input
    (assumes x is (N, C, H, W)).
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # (k - 1) // 2 padding keeps H and W unchanged for odd kernel sizes.
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        descriptor = torch.cat(
            (
                x.mean(dim=1, keepdim=True),               # average over channels
                torch.max(x, dim=1, keepdim=True).values,  # max over channels
            ),
            dim=1,
        )
        weights = self.sigmoid(self.conv(descriptor))
        return x * weights

In [None]:
# Cell 6: Generator Definition

# U-Net Generator with Bottleneck Spatial Attention
class SimpleSketchUNetWithAttention(nn.Module):
    """Four-level U-Net generator with spatial attention applied at the bottleneck.

    Encoder stages double the channel width (features, 2x, 4x, 8x) and halve the
    resolution with 2x2 max-pooling. The 16x bottleneck is reweighted by
    ``SpatialAttention``. Each decoder stage upsamples with a stride-2 transposed
    convolution, concatenates the matching encoder output, and applies a conv
    block. A 1x1 convolution plus sigmoid produces the output in (0, 1).
    Because of the four pool/upsample pairs, input height and width should be
    divisible by 16 for the skip concatenations to line up.
    """

    def __init__(self, in_channels=1, out_channels=1, features=64):
        super().__init__()
        f1, f2, f4, f8, f16 = (features * k for k in (1, 2, 4, 8, 16))

        # Encoder (creation order kept stable so seeded initialization is unchanged).
        self.enc1 = self._block(in_channels, f1, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc2 = self._block(f1, f2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc3 = self._block(f2, f4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc4 = self._block(f4, f8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck + attention.
        self.bottleneck_conv = self._block(f8, f16, name="bottleneck")
        self.bottleneck_attn = SpatialAttention(kernel_size=7)

        # Decoder: each block sees [upsampled, skip] -> twice its output width.
        self.upconv4 = nn.ConvTranspose2d(f16, f8, kernel_size=2, stride=2)
        self.dec4 = self._block(f8 * 2, f8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(f8, f4, kernel_size=2, stride=2)
        self.dec3 = self._block(f4 * 2, f4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(f4, f2, kernel_size=2, stride=2)
        self.dec2 = self._block(f2 * 2, f2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(f2, f1, kernel_size=2, stride=2)
        self.dec1 = self._block(f1 * 2, f1, name="dec1")

        # Output head.
        self.final_conv = nn.Conv2d(f1, out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        print("Generator initialized with Simple Bottleneck Spatial Attention.")

    def forward(self, x):
        skips = []
        h = x
        for encode, pool in ((self.enc1, self.pool1), (self.enc2, self.pool2),
                             (self.enc3, self.pool3), (self.enc4, self.pool4)):
            h = encode(h)
            skips.append(h)  # kept for the matching decoder stage
            h = pool(h)

        h = self.bottleneck_attn(self.bottleneck_conv(h))

        # skips.pop() yields the encoder outputs deepest-first: enc4, enc3, enc2, enc1.
        for upsample, decode in ((self.upconv4, self.dec4), (self.upconv3, self.dec3),
                                 (self.upconv2, self.dec2), (self.upconv1, self.dec1)):
            h = decode(torch.cat((upsample(h), skips.pop()), dim=1))

        return self.sigmoid(self.final_conv(h))

    def _block(self, in_channels, features, name):
        # Two (3x3 conv -> affine InstanceNorm -> LeakyReLU(0.2)) stages.
        # `name` only labels call sites; it is not used.
        layers = []
        for conv_in in (in_channels, features):
            layers += [
                nn.Conv2d(conv_in, features, kernel_size=3, padding=1, bias=False),
                nn.InstanceNorm2d(features, affine=True),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        return nn.Sequential(*layers)

In [None]:
# Cell 7: Discriminator Definition (Now called a "Critic" since we're using WGAN-GP)

# Discriminator Network
class Discriminator(nn.Module):
    def __init__(self, in_channels=1, features=64, n_layers=3, norm_layer=nn.InstanceNorm2d):
        """PatchGAN critic for WGAN-GP (outputs an unbounded score map, no sigmoid).

        Layout of ``self.model`` (indices matter for saved state_dicts):
          * 4x4 stride-2 conv (with bias) -> LeakyReLU(0.2)
          * ``n_layers - 1`` x [4x4 stride-2 conv -> norm -> LeakyReLU(0.2)]
          * 4x4 stride-1 conv -> norm -> LeakyReLU(0.2)
          * 4x4 stride-1 conv to 1 channel
        Channel width is ``features * min(2**n, 8)`` at depth ``n``.

        Args:
            in_channels: Channels of the input image.
            features: Base channel width.
            n_layers: Number of stride-2 downsampling convolutions.
            norm_layer: Normalization class. Convs feeding it get a bias only
                when it is exactly ``nn.InstanceNorm2d``.
        """
        super().__init__()

        use_bias = norm_layer == nn.InstanceNorm2d

        def width(depth):
            return features * min(2 ** depth, 8)

        layers = [
            nn.Conv2d(in_channels, width(0), kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        ]
        for depth in range(1, n_layers):
            layers += [
                nn.Conv2d(width(depth - 1), width(depth), kernel_size=4, stride=2, padding=1, bias=use_bias),
                norm_layer(width(depth)),
                nn.LeakyReLU(0.2, True),
            ]

        prev_depth = max(n_layers - 1, 0)
        layers += [
            nn.Conv2d(width(prev_depth), width(n_layers), kernel_size=4, stride=1, padding=1, bias=use_bias),
            norm_layer(width(n_layers)),
            nn.LeakyReLU(0.2, True),
            # Raw critic score per patch: WGAN critics must not squash their output.
            nn.Conv2d(width(n_layers), 1, kernel_size=4, stride=1, padding=1),
        ]
        self.model = nn.Sequential(*layers)

        print(f"Discriminator (Critic PatchGAN for WGAN-GP) initialized ({n_layers} downsampling layers).")

    def forward(self, input):
        """Returns the per-patch critic score map for ``input``."""
        return self.model(input)

In [None]:
# Cell 8: Gradient Penalty Function

def compute_gradient_penalty(discriminator, real_samples, fake_samples, device=None):
    """WGAN-GP gradient penalty: mean over samples of (||grad_x D(x_hat)||_2 - 1)^2.

    ``x_hat`` is a random per-sample interpolation between the real and the fake
    sample. Both endpoints are detached, so the penalty only produces gradients
    for the critic, never for whatever produced ``fake_samples``.

    Args:
        discriminator: Callable critic mapping a batch to scores (any output shape).
        real_samples: Real batch of shape (N, ...). Any rank is supported.
        fake_samples: Generated batch with the same shape as ``real_samples``.
        device: Device for the interpolation weights. Defaults to ``real_samples.device``.

    Returns:
        Scalar tensor that is differentiable w.r.t. the critic's parameters
        (built with ``create_graph=True``).
    """
    if device is None:
        device = real_samples.device
    real_samples = real_samples.detach()
    fake_samples = fake_samples.detach()

    # One interpolation weight per sample, broadcast over all remaining dims.
    alpha_shape = (real_samples.size(0),) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)

    d_interpolates = discriminator(interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),  # sum of scores -> per-input gradient
        create_graph=True,  # the penalty itself must be backpropagated into D
        retain_graph=True,
    )[0]

    gradients = gradients.reshape(gradients.size(0), -1)  # flatten per sample
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

In [None]:
class LSGANLoss(nn.Module):
    """Least Squares GAN loss: MSE between predictions and a constant real/fake target.

    Args:
        target_real_label: Target value used when ``target_is_real`` is True.
        target_fake_label: Target value used when ``target_is_real`` is False.
    """
    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        super(LSGANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss()

    def get_target_tensor(self, prediction, target_is_real):
        """Returns the scalar label broadcast (as a view) to ``prediction``'s shape.

        The label is moved to ``prediction``'s device and dtype *before* expanding,
        so no full-size copy is made and mixed-precision predictions work.
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.to(device=prediction.device, dtype=prediction.dtype).expand_as(prediction)

    def forward(self, prediction, target_is_real):
        """MSE between ``prediction`` and the real (True) or fake (False) target.

        Implemented as ``forward`` rather than ``__call__`` so nn.Module hooks run.
        """
        target_tensor = self.get_target_tensor(prediction, target_is_real)
        return self.loss(prediction, target_tensor)

In [None]:
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss: mean L1 distance between frozen VGG19 feature maps.

    Args:
        feature_layers: Indices (int or str) of modules in ``vgg19().features`` whose
            outputs are compared. Defaults to indices 2, 7, 12, 21 and 30.
        use_input_norm: If True, inputs are normalized with ImageNet mean/std
            (presumably inputs are in [0, 1] - TODO confirm for all callers).
        device: Device the frozen VGG backbone is moved to.

    Single-channel inputs are repeated to 3 channels before being fed to VGG.
    """
    def __init__(self, feature_layers=None, use_input_norm=True, device='cuda'):
        super(VGGPerceptualLoss, self).__init__()
        weights_enum = getattr(models, 'VGG19_Weights', None)
        if weights_enum is not None:
            # torchvision >= 0.13: `pretrained=True` is deprecated in favour of `weights=`.
            vgg = models.vgg19(weights=weights_enum.DEFAULT)
        else:
            vgg = models.vgg19(pretrained=True)  # older torchvision fallback
        self.vgg = vgg.features.to(device).eval()
        if feature_layers is None:
            self.feature_layers = {'2', '7', '12', '21', '30'}  # Default VGG19 layers
        else:
            self.feature_layers = set(str(l) for l in feature_layers)
        self.loss_fn = nn.L1Loss()
        self.use_input_norm = use_input_norm
        if use_input_norm:
            self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        for param in self.vgg.parameters():
            param.requires_grad = False
        print(f"VGGPerceptualLoss initialized. Using feature layers: {sorted(list(self.feature_layers))}")

    def _preprocess_input(self, x):
        """Repeats 1-channel input to 3 channels and optionally applies ImageNet normalization."""
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)
        if self.use_input_norm:
            x = self.normalize(x)
        return x

    def forward(self, input_img, target_img):
        """Returns the L1 feature distance averaged over ``feature_layers`` as a tensor.

        Stops running VGG once every requested layer has been reached. With no
        feature layers, a zero tensor is returned (never a bare float), so
        ``.item()`` is always safe.
        """
        x_in = self._preprocess_input(input_img)
        x_tgt = self._preprocess_input(target_img)
        loss = x_in.new_zeros(())
        if not self.feature_layers:
            return loss

        remaining = set(self.feature_layers)
        for name, module in self.vgg.named_children():
            x_in = module(x_in)
            x_tgt = module(x_tgt)
            if name in remaining:
                loss = loss + self.loss_fn(x_in, x_tgt)
                remaining.discard(name)
                if not remaining:
                    break  # deeper layers cannot contribute to the loss
        return loss / len(self.feature_layers)

In [None]:
# Cell 11: GAN Trainer Class

class GANTrainer:
    """WGAN-GP training loop for the sketch-completion generator and its critic.

    For each batch the critic ("discriminator" internally) is updated ``n_critic``
    times on that same batch, then the generator is updated once with:

        -D(G(partial)) + mask_weight * L1_masked + perceptual_weight * VGG_perceptual

    Per-epoch averages are kept in the ``*_losses`` lists, sample grids are
    written to ``sample_dir``, and checkpoints go to ``checkpoint_dir``.
    """

    def __init__(self, generator, discriminator, dataloader, device,
                 lr_g, lr_d, adam_beta1, adam_beta2,
                 checkpoint_dir, sample_dir,
                 lambda_gp, n_critic,
                 mask_weight, perceptual_weight):
        self.generator = generator.to(device)
        self.discriminator = discriminator.to(device)  # referred to as discriminator internally, but it's the Critic
        self.dataloader = dataloader
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.sample_dir = sample_dir
        self.lambda_gp = lambda_gp
        self.n_critic = n_critic
        self.mask_weight = mask_weight
        self.perceptual_weight = perceptual_weight

        # Optimizers (betas supplied by the caller, e.g. (0.0, 0.9) for WGAN-GP)
        self.g_optimizer = optim.Adam(generator.parameters(), lr=lr_g, betas=(adam_beta1, adam_beta2))
        self.d_optimizer = optim.Adam(discriminator.parameters(), lr=lr_d, betas=(adam_beta1, adam_beta2))

        # Reconstruction losses (the adversarial term is computed inline)
        self.l1_loss_fn = nn.L1Loss()
        self.perceptual_loss_fn = VGGPerceptualLoss(use_input_norm=True, device=device)

        # Per-epoch training history
        self.g_losses = []               # Total G loss
        self.d_losses = []               # Total critic loss
        self.g_adv_losses = []           # Adversarial component (-D(G(z)))
        self.g_l1_masked_losses = []
        self.g_perceptual_losses = []
        self.d_gp_losses = []            # Gradient penalty (unweighted)
        self.wasserstein_distances = []  # D(real) - D(fake) estimate

        # A fixed batch used to visualize progress across epochs.
        if len(dataloader) > 0:
            self.vis_batch = next(iter(dataloader))
        else:
            self.vis_batch = None
            print("WARNING: Dataloader empty.")
        self.current_epoch = 0
        self.g_steps = 0  # Number of generator updates performed

    @staticmethod
    def _scalar(value):
        """Returns a Python float for either a 0-d tensor or a plain number."""
        return value.item() if torch.is_tensor(value) else float(value)

    def train_epoch(self, epoch):
        """Runs one epoch of WGAN-GP training and records per-sample average losses.

        Args:
            epoch: Zero-based epoch index (used for display and sample file names).

        Returns:
            ``(g_loss, d_loss, g_adv, g_l1_masked, g_perceptual, d_gp, w_dist)``
            averaged per sample over the epoch. All zeros if no batches were seen.
        """
        self.generator.train()
        self.discriminator.train()
        running_g_loss = 0.0
        running_d_loss = 0.0
        running_g_adv = 0.0
        running_g_l1m = 0.0
        running_g_perc = 0.0
        running_d_gp = 0.0
        running_wasserstein_d = 0.0
        samples_processed = 0

        progress_bar = tqdm(self.dataloader, desc=f"Epoch {epoch+1}", leave=False)

        for batch in progress_bar:
            complete_sketches = batch['complete'].to(self.device)  # Real images
            partial_sketches = batch['partial'].to(self.device)    # Generator input
            masks = batch['mask'].to(self.device)
            batch_size = complete_sketches.size(0)

            # ---------------- Critic updates ----------------
            # G is not updated during the critic steps, so (assuming its forward
            # pass is deterministic, e.g. no dropout) the fakes are identical on
            # every step. Generate them once, without building G's autograd graph.
            with torch.no_grad():
                fake_sketches = self.generator(partial_sketches)

            d_loss_accum = 0.0
            gp_loss_accum = 0.0
            wd_accum = 0.0
            for _ in range(self.n_critic):
                self.d_optimizer.zero_grad()

                real_output = self.discriminator(complete_sketches)
                fake_output = self.discriminator(fake_sketches)
                gradient_penalty = compute_gradient_penalty(
                    self.discriminator, complete_sketches.detach(), fake_sketches, self.device)

                # Critic maximizes D(real) - D(fake), i.e. minimizes D(fake) - D(real) (+ GP).
                loss_d = fake_output.mean() - real_output.mean() + self.lambda_gp * gradient_penalty
                loss_d.backward()
                self.d_optimizer.step()

                d_loss_accum += loss_d.item()
                gp_loss_accum += gradient_penalty.item()
                wd_accum += (real_output.mean() - fake_output.mean()).item()  # Wasserstein distance estimate

            avg_d_loss_batch = d_loss_accum / self.n_critic
            avg_gp_loss_batch = gp_loss_accum / self.n_critic
            avg_wd_batch = wd_accum / self.n_critic

            # Weight by batch size so the epoch average is per sample.
            running_d_loss += avg_d_loss_batch * batch_size
            running_d_gp += avg_gp_loss_batch * batch_size
            running_wasserstein_d += avg_wd_batch * batch_size

            # ---------------- Generator update ----------------
            self.g_optimizer.zero_grad()

            fake_sketches_for_g = self.generator(partial_sketches)  # with grad, for G's update
            pred_fake_for_g = self.discriminator(fake_sketches_for_g)

            # Generator wants high critic scores for fakes -> minimize -D(G(z)).
            loss_g_adv = -pred_fake_for_g.mean()
            loss_masked = self.l1_loss_fn(fake_sketches_for_g * masks, complete_sketches * masks)

            loss_g_perc = 0.0  # stays a float when the perceptual term is disabled
            if self.perceptual_weight > 0:
                loss_g_perc = self.perceptual_loss_fn(fake_sketches_for_g, complete_sketches)

            loss_g = loss_g_adv + \
                     (self.mask_weight * loss_masked) + \
                     (self.perceptual_weight * loss_g_perc)

            loss_g.backward()
            self.g_optimizer.step()
            self.g_steps += 1

            g_loss_value = loss_g.item()
            g_adv_value = loss_g_adv.item()
            g_l1m_value = loss_masked.item()
            g_perc_value = self._scalar(loss_g_perc)  # works whether tensor or 0.0

            running_g_loss += g_loss_value * batch_size
            running_g_adv += g_adv_value * batch_size
            running_g_l1m += g_l1m_value * batch_size
            running_g_perc += g_perc_value * batch_size
            samples_processed += batch_size

            progress_bar.set_postfix({
                'G': f"{g_loss_value:.3f}",
                'D': f"{avg_d_loss_batch:.3f}",  # per-batch average critic loss
                'GP': f"{avg_gp_loss_batch:.3f}",
                'W-Dist': f"{avg_wd_batch:.3f}",
                'G_Adv': f"{g_adv_value:.3f}",
                'G_L1M': f"{g_l1m_value:.4f}",
                'G_Perc': f"{g_perc_value:.3f}"
                })

        if samples_processed == 0:
            return 0, 0, 0, 0, 0, 0, 0  # Empty dataloader

        epoch_g_loss = running_g_loss / samples_processed
        epoch_d_loss = running_d_loss / samples_processed
        epoch_g_adv = running_g_adv / samples_processed
        epoch_g_l1m = running_g_l1m / samples_processed
        epoch_g_perc = running_g_perc / samples_processed
        epoch_d_gp = running_d_gp / samples_processed
        epoch_wd = running_wasserstein_d / samples_processed

        self.g_losses.append(epoch_g_loss)
        self.d_losses.append(epoch_d_loss)
        self.g_adv_losses.append(epoch_g_adv)
        self.g_l1_masked_losses.append(epoch_g_l1m)
        self.g_perceptual_losses.append(epoch_g_perc)
        self.d_gp_losses.append(epoch_d_gp)
        self.wasserstein_distances.append(epoch_wd)

        if self.vis_batch:
            self.save_samples(epoch)
        return epoch_g_loss, epoch_d_loss, epoch_g_adv, epoch_g_l1m, epoch_g_perc, epoch_d_gp, epoch_wd

    def save_samples(self, epoch, num_samples=4):
        """Saves a (partial | generated | ground truth) grid for the fixed visualization batch.

        Writes ``sample_epoch_<epoch+1:03d>.png`` into ``sample_dir``. Does nothing
        without a visualization batch. The generator is put back in train mode afterwards.
        """
        if not self.vis_batch:
            return  # checked before eval() so G is never left in eval mode
        self.generator.eval()
        try:
            with torch.no_grad():
                complete = self.vis_batch['complete'].to(self.device)
                partial = self.vis_batch['partial'].to(self.device)
                generated = self.generator(partial)
            num_samples = min(num_samples, complete.size(0))
            if num_samples <= 0:
                return
            # squeeze=False keeps `axes` 2-D even when only one row is plotted.
            fig, axes = plt.subplots(num_samples, 3, figsize=(12, 3 * num_samples), squeeze=False)
            try:
                fig.suptitle(f'Epoch {epoch+1} Samples (WGAN-GP)', fontsize=16)
                for i in range(num_samples):
                    panels = (
                        (partial[i], 'Partial Input'),
                        (generated[i], 'Generated Output'),
                        (complete[i], 'Ground Truth'),
                    )
                    for col, (img, title) in enumerate(panels):
                        ax = axes[i, col]
                        ax.imshow(img.cpu().squeeze().numpy(), cmap='gray', vmin=0, vmax=1)
                        ax.set_title(title)
                        ax.axis('off')
                fig.tight_layout(rect=[0, 0.03, 1, 0.95])
                sample_path = os.path.join(self.sample_dir, f'sample_epoch_{epoch+1:03d}.png')
                fig.savefig(sample_path)
            finally:
                plt.close(fig)
        finally:
            self.generator.train()

    def save_checkpoint(self, epoch):
        """Saves model/optimizer state and loss history as ``checkpoint_epoch_<epoch+1:03d>.pth``."""
        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': self.generator.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'g_optimizer_state_dict': self.g_optimizer.state_dict(),
            'd_optimizer_state_dict': self.d_optimizer.state_dict(),
            'g_losses': self.g_losses,
            'd_losses': self.d_losses,
            'g_adv_losses': self.g_adv_losses,
            'g_l1_masked_losses': self.g_l1_masked_losses,
            'g_perceptual_losses': self.g_perceptual_losses,
            'd_gp_losses': self.d_gp_losses,
            'wasserstein_distances': self.wasserstein_distances,
        }
        checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch+1:03d}.pth')
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path):
        """Restores state written by ``save_checkpoint``.

        Returns the epoch to resume from, or 0 if the file is missing or the
        generator weights cannot be loaded. Failures for the critic/optimizers
        only print a warning.
        """
        if not os.path.exists(checkpoint_path):
            print(f"Checkpoint not found at {checkpoint_path}. Starting from scratch.")
            return 0
        print(f"Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        try:
            self.generator.load_state_dict(checkpoint.get('generator_state_dict'))
            print(" G state loaded.")
        except Exception as e:
            print(f" ERROR loading G state: {e}. Weights not loaded.")
            return 0
        try:
            self.discriminator.load_state_dict(checkpoint.get('discriminator_state_dict'))
            print(" D state loaded.")
        except Exception as e:
            print(f" WARNING loading D state: {e}.")
        try:
            self.g_optimizer.load_state_dict(checkpoint.get('g_optimizer_state_dict'))
            print(" G Optim loaded.")
        except Exception as e:
            print(f" WARNING loading G Optim state: {e}.")
        try:
            self.d_optimizer.load_state_dict(checkpoint.get('d_optimizer_state_dict'))
            print(" D Optim loaded.")
        except Exception as e:
            print(f" WARNING loading D Optim state: {e}.")
        self.g_losses = checkpoint.get('g_losses', [])
        self.d_losses = checkpoint.get('d_losses', [])
        self.g_adv_losses = checkpoint.get('g_adv_losses', [])
        self.g_l1_masked_losses = checkpoint.get('g_l1_masked_losses', [])
        self.g_perceptual_losses = checkpoint.get('g_perceptual_losses', [])
        self.d_gp_losses = checkpoint.get('d_gp_losses', [])
        self.wasserstein_distances = checkpoint.get('wasserstein_distances', [])
        start_epoch = checkpoint.get('epoch', -1) + 1
        self.current_epoch = start_epoch
        print(f"Resuming training from epoch {start_epoch}")
        return start_epoch

    def plot_losses(self, save_path):
        """Plots loss history in a 2x2 grid, saves it to ``save_path`` and shows it."""
        if not self.g_losses or not self.d_losses:
            print("No loss history to plot.")
            return
        epochs = range(1, len(self.g_losses) + 1)
        # Perceptual history may be missing/shorter (e.g. restored from an older checkpoint).
        plot_perceptual = self.perceptual_weight > 0 and len(self.g_perceptual_losses) == len(epochs)
        plt.figure(figsize=(14, 10))

        # Total losses
        plt.subplot(2, 2, 1)
        plt.plot(epochs, self.g_losses, label='G Total', color='blue')
        plt.plot(epochs, self.d_losses, label='D (Critic) Total', color='red')
        plt.title('Overall WGAN Losses')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Generator components, weighted as in the loss
        plt.subplot(2, 2, 2)
        plt.plot(epochs, self.g_adv_losses, label='G Adv Term [-D(G(z))]', linestyle='-', color='green')
        plt.plot(epochs, [l * self.mask_weight for l in self.g_l1_masked_losses], label=f'G L1 Masked (w={self.mask_weight})', linestyle='--', color='orange')
        if plot_perceptual:
            plt.plot(epochs, [l * self.perceptual_weight for l in self.g_perceptual_losses], label=f'G Perceptual (w={self.perceptual_weight})', linestyle='-.', color='cyan')
        plt.title('Generator Loss Components (Weighted)')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        # Critic components (W-distance estimate and weighted GP)
        plt.subplot(2, 2, 3)
        plt.plot(epochs, self.wasserstein_distances, label='Wasserstein Dist Est (D(x)-D(G(z)))', linestyle='-', color='purple')
        plt.plot(epochs, [l * self.lambda_gp for l in self.d_gp_losses], label=f'Gradient Penalty (Term * {self.lambda_gp})', linestyle='--', color='brown')
        plt.title('Critic Loss Components')
        plt.xlabel('Epoch')
        plt.ylabel('Value / Loss Term')
        plt.legend()
        plt.grid(True)

        # Generator components, unweighted
        plt.subplot(2, 2, 4)
        plt.plot(epochs, self.g_adv_losses, label='G Adv Term', linestyle='-', color='green')
        plt.plot(epochs, self.g_l1_masked_losses, label='G L1 Masked', linestyle='--', color='orange')
        if plot_perceptual:
            plt.plot(epochs, self.g_perceptual_losses, label='G Perceptual', linestyle='-.', color='cyan')
        plt.title('G Loss Components (Unweighted)')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig(save_path)
        print(f"Loss plot saved to {save_path}")
        plt.show()

    def train(self, num_epochs, start_epoch=0, checkpoint_every=5, loss_plot_file=None):
        """Trains for ``num_epochs`` epochs starting at ``start_epoch``.

        Args:
            num_epochs: Number of epochs to run.
            start_epoch: Zero-based index of the first epoch (e.g. from ``load_checkpoint``).
            checkpoint_every: Save a checkpoint when ``(epoch + 1)`` is a multiple of this.
                A falsy value disables periodic saves. The final epoch is always saved,
                and is never saved twice.
            loss_plot_file: Path for the loss plot. Defaults to the notebook's ``LOSS_PLOT_FILE``.
        """
        print(f"Starting WGAN-GP Training (Epochs {start_epoch+1} to {start_epoch + num_epochs}) [L1MW={self.mask_weight}, PercepW={self.perceptual_weight}, GP={self.lambda_gp}, N_Critic={self.n_critic}]")
        last_saved_epoch = None
        for epoch in range(start_epoch, start_epoch + num_epochs):
            self.current_epoch = epoch
            g_loss, d_loss, g_adv, g_l1m, g_perc, d_gp, wd = self.train_epoch(epoch)
            print(f"Epoch {epoch+1}/{start_epoch + num_epochs} | G Loss: {g_loss:.4f} | D Loss: {d_loss:.4f} | GP: {d_gp:.4f} | W-Dist: {wd:.4f} | G Adv: {g_adv:.4f} | G L1M: {g_l1m:.4f} | G Perc: {g_perc:.4f}")
            if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
                self.save_checkpoint(epoch)
                last_saved_epoch = epoch
        final_epoch = start_epoch + num_epochs - 1
        if num_epochs > 0 and last_saved_epoch != final_epoch:
            self.save_checkpoint(final_epoch)
        if self.g_losses:
            self.plot_losses(LOSS_PLOT_FILE if loss_plot_file is None else loss_plot_file)
        print("Training finished.")

In [None]:
# Cell 12: Main Execution Block

def _find_latest_checkpoint(checkpoint_dir):
    """Return the path of the newest checkpoint in ``checkpoint_dir``, or None.

    Files named ``checkpoint_epoch_<N>.pth`` (e.g. ``checkpoint_epoch_005.pth``)
    are ranked by their numeric epoch ``N``, so the choice does not depend on
    zero-padding width or on unrelated ``.pth`` files in the directory. If no
    file follows that pattern, falls back to the lexicographically last
    ``*.pth`` file (the previous behaviour).
    """
    if not os.path.isdir(checkpoint_dir):
        return None
    pth_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
    if not pth_files:
        return None

    prefix = 'checkpoint_epoch_'
    numbered = []
    for fname in pth_files:
        stem = fname[:-len('.pth')]
        epoch_part = stem[len(prefix):]
        if stem.startswith(prefix) and epoch_part.isdecimal():
            numbered.append((int(epoch_part), fname))

    if numbered:
        return os.path.join(checkpoint_dir, max(numbered)[1])
    return os.path.join(checkpoint_dir, sorted(pth_files)[-1])


def _load_generator_weights(generator, checkpoint_path, device):
    """Try to load ``generator_state_dict`` from ``checkpoint_path`` into ``generator``.

    Returns True on success. Any failure is reported as a warning and returns
    False so the caller can fall back to training from scratch. Failures
    include a missing path, an unreadable/corrupt file, a missing key, or an
    incompatible state dict.
    """
    if not (checkpoint_path and os.path.exists(checkpoint_path)):
        print(f" WARN: Checkpoint specified but not found at {checkpoint_path}. Training from scratch.")
        return False

    print(f"Loading GENERATOR weights ONLY from: {checkpoint_path}")
    # torch.load itself can fail (truncated file, unpickling error), so it is
    # guarded too; previously such an error aborted the whole run.
    try:
        loaded_checkpoint = torch.load(checkpoint_path, map_location=device)
    except Exception as e:
        print(f" WARN: Error reading checkpoint {checkpoint_path}: {e}. Training from scratch.")
        return False

    if not isinstance(loaded_checkpoint, dict) or 'generator_state_dict' not in loaded_checkpoint:
        print(f" WARN: 'generator_state_dict' not found in {checkpoint_path}. Training from scratch.")
        return False

    try:
        generator.load_state_dict(loaded_checkpoint['generator_state_dict'])
    except Exception as e:
        print(f" WARN: Error loading G weights: {e}. Training from scratch.")
        return False
    print(" G weights loaded.")
    return True


def main():
    """Build data, models and the WGAN-GP trainer, then train (resuming if possible).

    Steps:
      1. Build the SketchDataset / DataLoader from the configuration constants.
         Returns early if the dataset is empty.
      2. Create the attention U-Net generator and the critic on DEVICE.
      3. If LOAD_GENERATOR_CHECKPOINT is set, warm-start the generator from
         GENERATOR_CHECKPOINT_PATH.
      4. Resume trainer state from the newest checkpoint in CHECKPOINT_DIR,
         if one exists.
      5. Train for the epochs still remaining out of NUM_EPOCHS. If none
         remain, just re-plot the recorded losses.
    """
    print("--- Initializing WGAN-GP + Attention Training Stage ---")
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
    ])

    print("Creating dataset...")
    dataset = SketchDataset(
        data_dir=DATA_DIR, categories=QUICKDRAW_CATEGORIES, image_size=IMAGE_SIZE, transform=transform,
        max_samples_per_category=MAX_SAMPLES_PER_CATEGORY, mask_mode=MASK_MODE, mask_square_size=MASK_SQUARE_SIZE
    )
    if len(dataset) == 0:
        print("ERROR: Dataset empty!")
        return

    print("Creating dataloader...")
    # DEVICE is a torch.device, which never compares equal to the plain string
    # 'cuda'; normalise via torch.device(...) and inspect .type instead so that
    # pinned memory is actually enabled on GPU runs.
    use_cuda = torch.device(DEVICE).type == 'cuda'
    dataloader = DataLoader(
        dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
        pin_memory=use_cuda, prefetch_factor=2 if use_cuda else None
    )

    # --- Initialize Models ---
    print("Creating Generator (with Attention)...")
    generator = SimpleSketchUNetWithAttention(
        in_channels=IN_CHANNELS, out_channels=OUT_CHANNELS, features=FEATURES_G
    ).to(DEVICE)

    print("Creating Discriminator (Critic for WGAN-GP)...")
    discriminator = Discriminator(
        in_channels=IN_CHANNELS, features=FEATURES_D
    ).to(DEVICE)

    # --- Load Generator Checkpoint ---
    generator_loaded_successfully = False
    if LOAD_GENERATOR_CHECKPOINT:
        generator_loaded_successfully = _load_generator_weights(
            generator, GENERATOR_CHECKPOINT_PATH, DEVICE
        )
    if not generator_loaded_successfully:
        print("Generator will be trained from scratch.")

    g_params = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    d_params = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Generator (Attn) parameter count: {g_params:,}")
    print(f"Discriminator (Critic) parameter count: {d_params:,}")

    print("Creating WGAN-GP trainer...")
    trainer = GANTrainer(
        generator=generator, discriminator=discriminator, dataloader=dataloader, device=DEVICE,
        lr_g=GENERATOR_LR, lr_d=DISCRIMINATOR_LR,
        adam_beta1=ADAM_BETA1, adam_beta2=ADAM_BETA2,
        checkpoint_dir=CHECKPOINT_DIR, sample_dir=SAMPLE_DIR,
        lambda_gp=LAMBDA_GP, n_critic=N_CRITIC,
        mask_weight=MASK_WEIGHT, perceptual_weight=PERCEPTUAL_WEIGHT
        # adversarial_weight is handled internally by WGAN loss structure
    )

    # --- Resume from the newest checkpoint, if any ---
    latest_checkpoint_path = _find_latest_checkpoint(CHECKPOINT_DIR)
    start_epoch = 0
    if latest_checkpoint_path:
        print(f"Attempting to resume training from latest V9 checkpoint: {latest_checkpoint_path}")
        start_epoch = trainer.load_checkpoint(latest_checkpoint_path)
        if start_epoch == 0:
            print("Resuming failed, will start from epoch 0.")

    # --- Start Training ---
    remaining_epochs = NUM_EPOCHS - start_epoch
    if remaining_epochs > 0:
        trainer.train(num_epochs=remaining_epochs, start_epoch=start_epoch)
    else:
        print("Training already completed according to epochs and found checkpoint.")
        if trainer.g_losses:
            trainer.plot_losses(LOSS_PLOT_FILE)

# --- Run Main ---
# Script entry point: make sure QuickDraw .ndjson files are present locally
# (downloading them when the directory is missing or holds none), then train.
if __name__ == '__main__':
    if os.path.exists(DATA_DIR) and any(name.endswith('.ndjson') for name in os.listdir(DATA_DIR)):
        print("Data directory found.")
    else:
        print("Data directory not found or empty, downloading...")
        download_quickdraw_dataset(QUICKDRAW_CATEGORIES, DATA_DIR)
    main()

Data directory not found or empty, downloading...
Starting QuickDraw download to ./data/quickdraw...


Downloading categories:   0%|          | 0/8 [00:00<?, ?it/s]

QuickDraw download process finished.
--- Initializing WGAN-GP + Attention Training Stage ---
Creating dataset...
Loading dataset...


Loading 'cat': 0it [00:00, ?it/s]

Loading 'dog': 0it [00:00, ?it/s]

Loading 'house': 0it [00:00, ?it/s]

Loading 'tree': 0it [00:00, ?it/s]

Loading 'bicycle': 0it [00:00, ?it/s]

Loading 'car': 0it [00:00, ?it/s]

Loading 'face': 0it [00:00, ?it/s]

Loading 'flower': 0it [00:00, ?it/s]

Dataset loaded with 40000 samples.
Creating dataloader...
Creating Generator (with Attention)...
Generator initialized WITH Simple Bottleneck Spatial Attention.
Creating Discriminator (Critic for WGAN-GP)...
Discriminator (Critic PatchGAN for WGAN-GP) initialized (3 downsampling layers).
Generator will be trained from scratch.
Generator (Attn) parameter count: 17,460,627
Discriminator (Critic) parameter count: 2,762,689
Creating WGAN-GP trainer...


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:02<00:00, 221MB/s]


VGGPerceptualLoss initialized. Using feature layers: ['12', '2', '21', '30', '7']
Starting WGAN-GP Training (Epochs 1 to 150) [L1MW=10.0, PercepW=0.5, GP=10, N_Critic=5]


Epoch 1:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 1/150 | G Loss: 1.1476 | D Loss: 3.1395 | GP: 0.3211 | W-Dist: 0.0718 | G Adv: -0.0147 | G L1M: 0.0102 | G Perc: 2.1198


Epoch 2:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 2/150 | G Loss: 0.2382 | D Loss: -0.0009 | GP: 0.0005 | W-Dist: 0.0054 | G Adv: -0.1295 | G L1M: 0.0041 | G Perc: 0.6526


Epoch 3:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 3/150 | G Loss: 0.0904 | D Loss: -0.0012 | GP: 0.0003 | W-Dist: 0.0038 | G Adv: -0.2521 | G L1M: 0.0040 | G Perc: 0.6058


Epoch 4:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 4/150 | G Loss: 0.0585 | D Loss: -0.0016 | GP: 0.0002 | W-Dist: 0.0040 | G Adv: -0.2756 | G L1M: 0.0039 | G Perc: 0.5893


Epoch 5:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 5/150 | G Loss: 0.5258 | D Loss: -0.0017 | GP: 0.0002 | W-Dist: 0.0041 | G Adv: 0.1975 | G L1M: 0.0039 | G Perc: 0.5783
Checkpoint saved to ./sketch_completion_checkpoints_gan_v9_wgan_attn/checkpoint_epoch_005.pth


Epoch 6:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 6/150 | G Loss: 0.5761 | D Loss: -0.0019 | GP: 0.0002 | W-Dist: 0.0041 | G Adv: 0.2520 | G L1M: 0.0039 | G Perc: 0.5699


Epoch 7:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 7/150 | G Loss: 0.2055 | D Loss: -0.0015 | GP: 0.0002 | W-Dist: 0.0040 | G Adv: -0.1162 | G L1M: 0.0039 | G Perc: 0.5656


Epoch 8:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 8/150 | G Loss: 0.1571 | D Loss: -0.0017 | GP: 0.0002 | W-Dist: 0.0041 | G Adv: -0.1620 | G L1M: 0.0039 | G Perc: 0.5605


Epoch 9:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 9/150 | G Loss: -0.4449 | D Loss: -0.0016 | GP: 0.0002 | W-Dist: 0.0041 | G Adv: -0.7610 | G L1M: 0.0039 | G Perc: 0.5546


Epoch 10:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 10/150 | G Loss: -1.1932 | D Loss: -0.0019 | GP: 0.0002 | W-Dist: 0.0041 | G Adv: -1.5065 | G L1M: 0.0039 | G Perc: 0.5496
Checkpoint saved to ./sketch_completion_checkpoints_gan_v9_wgan_attn/checkpoint_epoch_010.pth


Epoch 11:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 11/150 | G Loss: -1.3025 | D Loss: -0.0022 | GP: 0.0002 | W-Dist: 0.0042 | G Adv: -1.6147 | G L1M: 0.0039 | G Perc: 0.5472


Epoch 12:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 12/150 | G Loss: -1.1173 | D Loss: -0.0019 | GP: 0.0002 | W-Dist: 0.0042 | G Adv: -1.4275 | G L1M: 0.0038 | G Perc: 0.5435


Epoch 13:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 13/150 | G Loss: -1.4097 | D Loss: -0.0021 | GP: 0.0002 | W-Dist: 0.0042 | G Adv: -1.7198 | G L1M: 0.0039 | G Perc: 0.5430


Epoch 14:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 14/150 | G Loss: -1.5502 | D Loss: -0.0020 | GP: 0.0002 | W-Dist: 0.0041 | G Adv: -1.8578 | G L1M: 0.0039 | G Perc: 0.5380


Epoch 15:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 15/150 | G Loss: -1.1339 | D Loss: -0.0017 | GP: 0.0002 | W-Dist: 0.0041 | G Adv: -1.4401 | G L1M: 0.0038 | G Perc: 0.5356
Checkpoint saved to ./sketch_completion_checkpoints_gan_v9_wgan_attn/checkpoint_epoch_015.pth


Epoch 16:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 16/150 | G Loss: -1.1196 | D Loss: -0.0023 | GP: 0.0002 | W-Dist: 0.0043 | G Adv: -1.4257 | G L1M: 0.0038 | G Perc: 0.5354


Epoch 17:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 17/150 | G Loss: -1.8396 | D Loss: -0.0020 | GP: 0.0002 | W-Dist: 0.0042 | G Adv: -2.1440 | G L1M: 0.0038 | G Perc: 0.5321


Epoch 18:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 18/150 | G Loss: -2.0137 | D Loss: -0.0020 | GP: 0.0002 | W-Dist: 0.0042 | G Adv: -2.3156 | G L1M: 0.0038 | G Perc: 0.5273


Epoch 19:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 19/150 | G Loss: -1.9238 | D Loss: -0.0020 | GP: 0.0002 | W-Dist: 0.0042 | G Adv: -2.2265 | G L1M: 0.0038 | G Perc: 0.5286


Epoch 20:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 20/150 | G Loss: -1.2738 | D Loss: -0.0023 | GP: 0.0002 | W-Dist: 0.0042 | G Adv: -1.5765 | G L1M: 0.0038 | G Perc: 0.5283
Checkpoint saved to ./sketch_completion_checkpoints_gan_v9_wgan_attn/checkpoint_epoch_020.pth


Epoch 21:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 21/150 | G Loss: -1.6275 | D Loss: -0.0021 | GP: 0.0002 | W-Dist: 0.0042 | G Adv: -1.9303 | G L1M: 0.0038 | G Perc: 0.5287


Epoch 22:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 22/150 | G Loss: -1.6264 | D Loss: -0.0021 | GP: 0.0002 | W-Dist: 0.0042 | G Adv: -1.9262 | G L1M: 0.0038 | G Perc: 0.5231


Epoch 23:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 23/150 | G Loss: -1.3934 | D Loss: -0.0021 | GP: 0.0002 | W-Dist: 0.0042 | G Adv: -1.6927 | G L1M: 0.0038 | G Perc: 0.5221


Epoch 24:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 24/150 | G Loss: -1.2043 | D Loss: -0.0024 | GP: 0.0002 | W-Dist: 0.0042 | G Adv: -1.5038 | G L1M: 0.0038 | G Perc: 0.5222


Epoch 25:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 25/150 | G Loss: -1.0646 | D Loss: -0.0022 | GP: 0.0002 | W-Dist: 0.0041 | G Adv: -1.3619 | G L1M: 0.0038 | G Perc: 0.5184
Checkpoint saved to ./sketch_completion_checkpoints_gan_v9_wgan_attn/checkpoint_epoch_025.pth


Epoch 26:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 26/150 | G Loss: -1.0964 | D Loss: -0.0021 | GP: 0.0002 | W-Dist: 0.0041 | G Adv: -1.3943 | G L1M: 0.0038 | G Perc: 0.5195


Epoch 27:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 27/150 | G Loss: -1.6495 | D Loss: -0.0017 | GP: 0.0002 | W-Dist: 0.0040 | G Adv: -1.9469 | G L1M: 0.0038 | G Perc: 0.5186


Epoch 28:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 28/150 | G Loss: -1.2028 | D Loss: -0.0022 | GP: 0.0002 | W-Dist: 0.0041 | G Adv: -1.4994 | G L1M: 0.0038 | G Perc: 0.5167


Epoch 29:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 29/150 | G Loss: -1.0388 | D Loss: -0.0025 | GP: 0.0002 | W-Dist: 0.0041 | G Adv: -1.3351 | G L1M: 0.0038 | G Perc: 0.5161


Epoch 30:   0%|          | 0/2500 [00:00<?, ?it/s]

Epoch 30/150 | G Loss: -0.4664 | D Loss: -0.0025 | GP: 0.0002 | W-Dist: 0.0041 | G Adv: -0.7617 | G L1M: 0.0038 | G Perc: 0.5142
Checkpoint saved to ./sketch_completion_checkpoints_gan_v9_wgan_attn/checkpoint_epoch_030.pth


Epoch 31:   0%|          | 0/2500 [00:00<?, ?it/s]

KeyboardInterrupt: 