In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

# Create output directory
os.makedirs("synthetic_images", exist_ok=True)

# Hyperparameters
# G and D below are hard-wired to 64x64 images (three 2x resampling stages
# around an 8x8 feature map), so image_size must stay 64 unless they change.
image_size = 64
batch_size = 64
latent_dim = 100
epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dummy dataset (replace with medical image dataset later)
transform = transforms.Compose([
    # Resize(int) only matches the *shorter* edge to image_size; CenterCrop
    # then guarantees every sample is exactly image_size x image_size, which
    # the discriminator's final Linear layer requires for non-square inputs.
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    # ToTensor yields [0, 1]; this maps it to [-1, 1], matching G's Tanh range.
    transforms.Normalize([0.5], [0.5]),
])
dataset = torchvision.datasets.FakeData(image_size=(1, image_size, image_size), transform=transform)
# The generator's BatchNorm1d raises in training mode on a single-sample batch,
# so drop the final batch only in the case where it would hold exactly one image.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=len(dataset) % batch_size == 1,
)

# Generator
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to single-channel 64x64 images.

    A ``latent_dim`` vector is projected to a 256x8x8 feature map, then
    upsampled three times (8 -> 16 -> 32 -> 64) by stride-2 transposed
    convolutions. A final Tanh squashes pixels into (-1, 1).
    """

    def __init__(self):
        super().__init__()
        # Project and reshape: (N, latent_dim) -> (N, 256, 8, 8).
        stem = [
            nn.Linear(latent_dim, 256 * 8 * 8),
            nn.BatchNorm1d(256 * 8 * 8),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 8, 8)),
        ]
        # kernel 4 / stride 2 / padding 1 doubles the spatial size each stage.
        upsample = []
        for c_in, c_out in ((256, 128), (128, 64)):
            upsample += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Last stage: 32x32 -> 64x64, collapse to one channel, no BatchNorm.
        head = [nn.ConvTranspose2d(64, 1, 4, 2, 1), nn.Tanh()]
        # Flattened into one Sequential so layer indices (and thus
        # state_dict keys such as "net.0.weight") match the original layout.
        self.net = nn.Sequential(*stem, *upsample, *head)

    def forward(self, z):
        """Return images of shape (N, 1, 64, 64) for noise ``z`` of shape (N, latent_dim)."""
        images = self.net(z)
        return images

# Discriminator
class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring single-channel 64x64 images.

    Three stride-2 convolutions halve the spatial size (64 -> 32 -> 16 -> 8).
    The 256x8x8 result is flattened into one logit, and a Sigmoid turns that
    logit into a probability that the input is real.
    """

    def __init__(self):
        super().__init__()
        # First stage has no BatchNorm: 1 -> 64 channels, 64x64 -> 32x32.
        layers = [nn.Conv2d(1, 64, 4, 2, 1), nn.LeakyReLU(0.2)]
        # Remaining downsampling stages: 32 -> 16 -> 8 with BatchNorm.
        for c_in, c_out in ((64, 128), (128, 256)):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]
        # Classifier head; the 256*8*8 input size pins inputs to 64x64.
        layers += [nn.Flatten(), nn.Linear(256 * 8 * 8, 1), nn.Sigmoid()]
        # Single Sequential keeps the original layer indices / state_dict keys.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (N, 1) tensor of real-probabilities for images ``x`` of shape (N, 1, 64, 64)."""
        scores = self.net(x)
        return scores

# Initialize models and optimizers
# G is built before D, as before, so each receives the same random initial
# weights it would have had under the original construction order.
G, D = Generator().to(device), Discriminator().to(device)
# D ends in a Sigmoid (outputs probabilities), so plain BCE is the matching loss.
loss_fn = nn.BCELoss()
# Both players use Adam with lr=2e-4 and beta1=0.5 (the DCGAN paper's settings).
opt_G, opt_D = (
    torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.999))
    for net in (G, D)
)

# Training Loop
# One fixed batch of latent vectors, reused at every checkpoint, so saved
# sample grids show how G's mapping evolves instead of fresh random draws.
fixed_z = torch.randn(16, latent_dim, device=device)

for epoch in range(epochs):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.to(device)
        # Local name on purpose: the last batch may be smaller, and rebinding
        # `batch_size` here would silently overwrite the hyperparameter.
        n = real_imgs.size(0)
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # Train Discriminator: push D(real) toward 1 and D(fake) toward 0.
        z = torch.randn(n, latent_dim, device=device)
        fake_imgs = G(z)
        # detach() stops this backward pass from reaching G's parameters.
        d_loss = loss_fn(D(real_imgs), real_labels) + loss_fn(D(fake_imgs.detach()), fake_labels)
        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # Train Generator: label fakes as real so G learns to fool the
        # just-updated D. This backward also writes D's .grad; the next
        # opt_D.zero_grad() clears it before D steps again.
        g_loss = loss_fn(D(fake_imgs), real_labels)
        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

    # Checkpoint every 10 epochs, and always after the final epoch.
    if epoch % 10 == 0 or epoch == epochs - 1:
        print(f"[Epoch {epoch}] D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")
        # Render samples in inference mode: no autograd graph is built, and
        # BatchNorm uses running statistics rather than being updated here.
        G.eval()
        with torch.no_grad():
            samples = G(fixed_z)
        G.train()
        save_image(samples, f"synthetic_images/fake_epoch_{epoch}.png", normalize=True)
print("✅ Synthetic image generation complete.")

[Epoch 0] D Loss: 1.7627 | G Loss: 3.5350
[Epoch 10] D Loss: 0.5505 | G Loss: 9.3545
[Epoch 20] D Loss: 0.5841 | G Loss: 4.9787
[Epoch 30] D Loss: 0.7987 | G Loss: 5.3233
[Epoch 40] D Loss: 1.5441 | G Loss: 4.2200
[Epoch 50] D Loss: 1.1980 | G Loss: 2.5274
[Epoch 60] D Loss: 0.7230 | G Loss: 2.4803
[Epoch 70] D Loss: 0.9331 | G Loss: 2.8462
[Epoch 80] D Loss: 1.0909 | G Loss: 2.9899
[Epoch 90] D Loss: 0.8538 | G Loss: 2.5069
✅ Synthetic image generation complete.
