## Part 1: Imports and Setup

In [1]:
import math
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.nn import BCELoss

# Pick the compute device: use the GPU whenever PyTorch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# Fix the RNG state so noise vectors and placeholder images repeat across runs.
# (torch.manual_seed already seeds every device; the explicit CUDA call is a
# harmless belt-and-braces.)
torch.manual_seed(42)
if device == "cuda":
    torch.cuda.manual_seed(42)


Using device: cuda


In [2]:
# Visualization helpers
def tensor_to_flat_numpy(values):
    """Return ``values`` as a 1-D NumPy array on the CPU.

    ``detach()`` makes this safe for tensors that track gradients, and
    ``reshape(-1)`` (unlike ``squeeze()``) keeps a single-element tensor
    one-dimensional so it can still be sliced and histogrammed.
    """
    return values.detach().cpu().numpy().reshape(-1)


def plot_discriminator_analysis(real_preds, fake_preds, predictions, batch_size, bins=4):
    """Visualize discriminator predictions and performance.

    Left panel: histograms of the discriminator outputs for the real and fake
    samples, drawn on shared bin edges so the two distributions are directly
    comparable. Right panel: the output for every sample in the mixed batch,
    with the first ``batch_size // 2`` samples plotted as real and the rest as
    fake.

    Args:
        real_preds: discriminator outputs for the real samples (flattened).
        fake_preds: discriminator outputs for the fake samples (flattened).
        predictions: outputs for the whole mixed batch, real samples first.
        batch_size: number of samples in ``predictions``.
        bins: number of histogram bins shared by both histograms.
    """
    real_preds_np = tensor_to_flat_numpy(real_preds)
    fake_preds_np = tensor_to_flat_numpy(fake_preds)
    predictions_np = tensor_to_flat_numpy(predictions)
    half = batch_size // 2

    fig, axes = plt.subplots(1, 2, figsize=(14, 4))

    # Plot 1: Histogram of predictions.
    # Passing an integer bin count to each hist() call separately would give
    # each series its own bin edges; compute one set of edges from both.
    ax = axes[0]
    bin_edges = np.histogram_bin_edges(
        np.concatenate([real_preds_np, fake_preds_np]), bins=bins
    )
    ax.hist(
        real_preds_np,
        bins=bin_edges,
        alpha=0.6,
        label="Real (target=1)",
        color="green",
        edgecolor="black",
    )
    ax.hist(
        fake_preds_np,
        bins=bin_edges,
        alpha=0.6,
        label="Fake (target=0)",
        color="red",
        edgecolor="black",
    )
    ax.axvline(
        0.5, color="black", linestyle="--", linewidth=2, label="Decision boundary"
    )
    ax.set_xlabel("Discriminator Output", fontsize=11)
    ax.set_ylabel("Frequency", fontsize=11)
    ax.set_title("Prediction Distribution (Untrained)", fontsize=12, fontweight="bold")
    ax.legend()
    ax.grid(alpha=0.3)

    # Plot 2: Individual predictions (real half first, then fake half).
    ax = axes[1]
    x_pos = np.arange(batch_size)

    ax.scatter(
        x_pos[:half],
        predictions_np[:half],
        s=150,
        c="green",
        alpha=0.7,
        edgecolor="black",
        linewidth=2,
        label="Real",
    )
    ax.scatter(
        x_pos[half:],
        predictions_np[half:],
        s=150,
        c="red",
        alpha=0.7,
        edgecolor="black",
        linewidth=2,
        label="Fake",
    )
    ax.axhline(0.5, color="black", linestyle="--", linewidth=2, alpha=0.5)
    ax.set_xlabel("Sample Index", fontsize=11)
    ax.set_ylabel("Predicted Probability", fontsize=11)
    ax.set_title("Individual Predictions", fontsize=12, fontweight="bold")
    ax.set_ylim([-0.1, 1.1])
    ax.legend()
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()


print(" Visualization helpers loaded")


 Visualization helpers loaded


## Part 2: Load Models

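Create the generator from Module 3 and a discriminator using the factory functions in `models.basic_gan`. Both are put in evaluation mode because this notebook only runs inference.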

In [None]:
# Factory functions for both networks come from the project's models package.
from models.basic_gan import create_generator, create_discriminator

# Build the generator and switch it to evaluation mode (inference only here).
generator = create_generator(latent_dim=100, device=device)
generator.eval()

# Build the discriminator and switch it to evaluation mode as well.
discriminator = create_discriminator(device=device)
discriminator.eval()

generator_param_count = sum(param.numel() for param in generator.parameters())
discriminator_param_count = sum(param.numel() for param in discriminator.parameters())

print("✓ Models loaded")
print(f"  Generator params: {generator_param_count:,}")
print(f"  Discriminator params: {discriminator_param_count:,}")


## Part 3: Implement create_mixed_batch()

Complete the TODOs to create a mixed batch of real and fake images.

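**What you need to do:**
- Split `batch_size` in half
- Generate fake images using the generator
- Create placeholder "real" images (random tensors standing in for MNIST)
- Concatenate real and fake images (real first, then fake)
- Create labels (1 = real, 0 = fake)
- Return `(images, targets)`

Note: pass the device as a keyword argument, e.g. `torch.randn(half_batch, latent_dim, device=device)`. Tensor factory functions such as `torch.randn`, `torch.rand`, `torch.ones` and `torch.zeros` interpret extra positional arguments as sizes.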

In [None]:
def create_mixed_batch(generator, batch_size=16, latent_dim=100, device="cpu"):
    """
    Create a mixed batch of real and fake images (reference implementation).

    The first half of the batch holds placeholder "real" images labelled 1.0,
    the second half holds generator samples labelled 0.0.

    Args:
        generator: Generator model, called as ``generator(noise)`` on a
            (batch_size // 2, latent_dim) noise tensor. Its output must hold
            784 values per sample; it is reshaped to (N, 1, 28, 28).
        batch_size: Total batch size; must be a positive even number so the
            real and fake halves have equal size.
        latent_dim: Dimension of noise vector
        device: 'cpu' or 'cuda'

    Returns:
        mixed_images: (batch_size, 1, 28, 28)
        mixed_targets: (batch_size, 1) float tensor with values 0 or 1

    Raises:
        ValueError: if batch_size is not a positive even integer.
    """
    # An odd size would silently return batch_size - 1 samples, breaking the
    # documented output shape, so reject it up front.
    if batch_size <= 0 or batch_size % 2 != 0:
        raise ValueError(
            f"batch_size must be a positive even number, got {batch_size}"
        )

    # Step 1: split the batch in half.
    half_batch = batch_size // 2

    # Step 2: generate fake images. `device` must be a keyword argument:
    # tensor factories treat extra positional arguments as sizes.
    noise = torch.randn(half_batch, latent_dim, device=device)
    with torch.no_grad():  # sampling only; no gradients needed
        fake_images = generator(noise)
    # reshape works whether the generator returns (N, 784) or (N, 1, 28, 28).
    fake_images = fake_images.reshape(half_batch, 1, 28, 28)

    # Step 3: placeholder "real" images, uniform in [0, 1). In practice these
    # would be loaded from MNIST.
    # NOTE(review): if the generator ends in tanh its samples lie in [-1, 1],
    # making these placeholders separable by value range alone - confirm.
    real_images = torch.rand(half_batch, 1, 28, 28, device=device)

    # Step 4: mix real and fake images, real first.
    mixed_images = torch.cat([real_images, fake_images], dim=0)

    # Step 5: labels in the same order - 1.0 for real, 0.0 for fake.
    real_labels = torch.ones(half_batch, 1, device=device)
    fake_labels = torch.zeros(half_batch, 1, device=device)
    mixed_targets = torch.cat([real_labels, fake_labels], dim=0)

    # Step 6: return mixed images and targets.
    return mixed_images, mixed_targets


print(" create_mixed_batch() defined")


## Part 4: Test Your Implementation

**Verification cells** - Run these to check your work.

In [None]:
# Call your function. Fail early with a clear message if it still returns None
# (e.g. the TODOs in create_mixed_batch have not been completed yet).
batch_size = 16
half_batch = batch_size // 2
batch_result = create_mixed_batch(generator, batch_size, device=device)
assert batch_result is not None, (
    "create_mixed_batch() returned None - complete the TODOs first"
)
mixed_images, mixed_targets = batch_result

# Expected shapes are derived from batch_size so the messages stay correct
# if batch_size is changed above.
expected_image_shape = (batch_size, 1, 28, 28)
expected_target_shape = (batch_size, 1)

# Check shapes
print("Shape checks:")
print(f"  Images shape: {mixed_images.shape}")
assert mixed_images.shape == expected_image_shape, (
    f"Images should be {expected_image_shape}, got {mixed_images.shape}"
)
print("  Images shape correct")

print(f"\n  Targets shape: {mixed_targets.shape}")
assert mixed_targets.shape == expected_target_shape, (
    f"Targets should be {expected_target_shape}, got {mixed_targets.shape}"
)
print("  Targets shape correct")

# Check label values: real samples occupy the first half, fake the second.
print("\nLabel checks:")
print(
    f"  First {half_batch} targets (should be ~1.0): {mixed_targets[:half_batch].squeeze().tolist()}"
)
print(
    f"  Last {half_batch} targets (should be ~0.0): {mixed_targets[half_batch:].squeeze().tolist()}"
)

assert torch.allclose(
    mixed_targets[:half_batch], torch.ones(half_batch, 1, device=device)
), "First half should be 1"
print("  Real labels correct (1.0)")

assert torch.allclose(
    mixed_targets[half_batch:], torch.zeros(half_batch, 1, device=device)
), "Second half should be 0"
print("   Fake labels correct (0.0)")

print("\n All shape and label checks passed!")


## Part 5: Discriminator Predictions

Run the discriminator on your mixed batch.

In [None]:
# Score the mixed batch. This is inspection only, so autograd is disabled.
with torch.no_grad():
    predictions = discriminator(mixed_images)

print(f"Predictions shape: {predictions.shape}")
print("Prediction values:")
print(predictions.squeeze().tolist())

# The batch is ordered real-first, so the midpoint separates the two groups.
real_preds, fake_preds = (
    predictions[: batch_size // 2],
    predictions[batch_size // 2 :],
)

print(
    "\nReal image predictions (should be ~0.5 for untrained): "
    f"{real_preds.squeeze().tolist()}"
)
print(
    "Fake image predictions (should be ~0.5 for untrained): "
    f"{fake_preds.squeeze().tolist()}"
)

print(f"\nMean real prediction: {real_preds.mean().item():.3f}")
print(f"Mean fake prediction: {fake_preds.mean().item():.3f}")
print(f"Overall mean: {predictions.mean().item():.3f}")
print("\n(For untrained discriminator, should all be ~0.5 = random guessing)")


## Part 6: Compute Binary Cross-Entropy Loss

In [None]:
# Compute BCE loss.
# NOTE(review): BCELoss expects probabilities in [0, 1], so this assumes the
# discriminator ends in a sigmoid - confirm in models.basic_gan.
criterion = BCELoss()
loss = criterion(predictions, mixed_targets)
loss_value = loss.item()

# Predicting 0.5 for every sample gives BCE = -log(0.5) = ln 2 whatever the
# labels are, so this is the chance-level baseline.
expected_random = -math.log(0.5)
RANDOM_BASELINE_TOL = 0.1  # max distance from ln 2 still treated as "random"

print(f"Binary Cross-Entropy Loss: {loss_value:.4f}")
print("\nInterpretation:")
print(f"  - Random guessing (50%): loss ≈ {expected_random:.4f}")
print(f"  - Our loss: {loss_value:.4f}")
print(f"  - Expected for random: {expected_random:.4f}")

if abs(loss_value - expected_random) < RANDOM_BASELINE_TOL:
    print(
        "\n✓ Loss is close to random baseline (good! discriminator hasn't learned yet)"
    )
else:
    print("\n⚠ Loss differs from expected")


## Part 7: Analyze Classification Accuracy

In [None]:
# Classification accuracy at threshold 0.5. An output >= threshold is read as
# "real" and anything below as "fake", so every sample falls into exactly one
# class (previously an output of exactly 0.5 was counted wrong for both).
DECISION_THRESHOLD = 0.5

n_real = real_preds.numel()
n_fake = fake_preds.numel()

real_correct = (real_preds >= DECISION_THRESHOLD).sum().item()
fake_correct = (fake_preds < DECISION_THRESHOLD).sum().item()
total_correct = real_correct + fake_correct

print("Discriminator Performance:")
print(
    f"  Real images predicted as real: {real_correct}/{n_real} ({100*real_correct/n_real:.0f}%)"
)
print(
    f"  Fake images predicted as fake: {fake_correct}/{n_fake} ({100*fake_correct/n_fake:.0f}%)"
)
print(f"\nOverall accuracy: {100*total_correct/(n_real + n_fake):.1f}%")
print("\n(Expected ~50% for random guessing; 100% would be perfect)")


## Part 8: Visualization

In [None]:
plot_discriminator_analysis(
    real_preds=real_preds,
    fake_preds=fake_preds,
    predictions=predictions,
    batch_size=batch_size,
)

print(" Visualization complete")

## Summary and Key Takeaways

### What You Learned
1. **Mixed Batches**: Combining real and fake images with proper labels (1=real, 0=fake)
2. **Binary Classification**: Discriminator outputs probability of being real (0-1)
3. **BCE Loss**: a loss of about $\ln 2 \approx 0.693$ is what a discriminator that outputs 0.5 for every image gets, i.e. chance level
4. **Untrained Performance**: ~50% accuracy before any training

### Key Concepts
- **Batch composition**: First half real, second half fake
- **Label convention**: 1.0 for real, 0.0 for fake
- **Loss baseline**: $-\log(0.5) = \ln 2 \approx 0.6931$ when every prediction is 0.5 (chance-level classification)
- **No gradients**: Use `torch.no_grad()` for inference
