In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import time
import os
# useful v1 functions
import import_ipynb 
from metamaterials_GAN_v1 import plot_shape, load_item, quarter, dataset, dataloader

if __name__ == "__main__":
    # Record which PyTorch build and accelerator this kernel sees, so the
    # saved notebook output documents the environment of the run.
    gpu_total = torch.cuda.device_count()
    environment = (
        ("Torch version:", torch.__version__),
        ("CUDA available:", torch.cuda.is_available()),
        ("CUDA version:", torch.version.cuda),
        ("Number of GPUs:", gpu_total),
        ("GPU name:", torch.cuda.get_device_name(0) if gpu_total > 0 else "No GPU detected"),
    )
    for label, value in environment:
        print(label, value)

Torch version: 2.6.0
CUDA available: False
CUDA version: None
Number of GPUs: 0
GPU name: No GPU detected


In [2]:
import torch
import torch.nn as nn
import torch.autograd as autograd

# Generator (unchanged)
class Generator(nn.Module):
    """Conditional generator: condition vector -> (waveguide image, parameters).

    A shared MLP trunk embeds ``cond`` (``cond_dim`` features) into 512
    features, which feed two heads:

    * image head: Linear -> 64x4x4 seed map -> three stride-2 transposed
      convolutions (4 -> 8 -> 16 -> 32) -> Sigmoid, i.e. a single-channel
      32x32 map with values in (0, 1);
    * parameter head: a small MLP producing 4 unconstrained values.

    Attribute names and layer order are part of the saved ``state_dict``
    format and of the seeded initialisation order, so keep them stable.
    """

    def __init__(self, cond_dim=8):
        super().__init__()
        # Shared trunk: cond_dim -> 128 -> 512.
        self.fc = nn.Sequential(
            nn.Linear(cond_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 512),
            nn.ReLU(True),
        )
        # Projects trunk features onto a 64-channel 4x4 seed feature map.
        self.fc_img = nn.Linear(512, 64 * 4 * 4)
        # Each ConvTranspose2d(kernel=4, stride=2, padding=1) doubles H and W.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 4, 2, 1),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 1, 4, 2, 1),
            nn.Sigmoid(),
        )
        # Parameter regression head: 512 -> 32 -> 4 (no output activation).
        self.fc_params = nn.Sequential(
            nn.Linear(512, 32),
            nn.ReLU(True),
            nn.Linear(32, 4),
        )

    def forward(self, cond):
        """Return ``(waveguide, params)`` for a batch of condition vectors."""
        hidden = self.fc(cond)
        seed_map = self.fc_img(hidden).view(-1, 64, 4, 4)
        return self.deconv(seed_map), self.fc_params(hidden)


# Critic (WGAN-GP)
class Critic(nn.Module):
    """WGAN-GP critic scoring a (waveguide, params, condition) triple.

    Three branches are embedded separately and concatenated:

    * image: three stride-2 convolutions (32 -> 16 -> 8 -> 4) then a Linear
      to 128 features; the flattened 64*4*4 size means the input image must
      be 1x32x32;
    * params: 4 values -> 16 features;
    * condition: ``cond_dim`` values -> 16 features.

    The head outputs an unbounded score of shape (B, 1). There is no
    Sigmoid, as required by the Wasserstein objective.
    """

    def __init__(self, cond_dim=8):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, 4, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(16, 32, 4, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, True),
        )
        self.fc_image = nn.Linear(64 * 4 * 4, 128)
        self.fc_params = nn.Sequential(nn.Linear(4, 16), nn.LeakyReLU(0.2, True))
        self.fc_cond = nn.Sequential(nn.Linear(cond_dim, 16), nn.LeakyReLU(0.2, True))
        # Unconstrained scalar output (no Sigmoid).
        self.fc_final = nn.Sequential(
            nn.Linear(128 + 16 + 16, 64),
            nn.LeakyReLU(0.2, True),
            nn.Linear(64, 1),
        )

    def forward(self, waveguide, params, cond):
        """Return one critic score per sample, shape (B, 1)."""
        image_features = self.fc_image(self.cnn(waveguide).view(waveguide.size(0), -1))
        joint = torch.cat(
            (image_features, self.fc_params(params), self.fc_cond(cond)),
            dim=1,
        )
        return self.fc_final(joint)


# Gradient Penalty Helper
def compute_gradient_penalty(critic, real_imgs, fake_imgs, real_params, cond, device, λ_gp=10.0,
                             fake_params=None):
    """WGAN-GP gradient penalty: λ_gp * E[(||∇ critic(x̂)||₂ − 1)²].

    ``x̂`` are random per-sample convex combinations of real and fake samples.

    Args:
        critic: callable ``critic(imgs, params, cond)`` returning per-sample scores.
        real_imgs, fake_imgs: tensors of identical shape with the batch on
            dim 0; any number of trailing dims is supported. Pass fakes
            detached from the generator graph.
        real_params: parameter tensor for the real samples (batch on dim 0).
        cond: conditioning tensor, passed to the critic unchanged.
        device: device on which the interpolation factors are drawn.
        λ_gp: penalty weight.
        fake_params: optional parameters for the fake samples.
            When ``None`` (default, the original behaviour), every
            interpolate is scored with ``real_params`` and only the image
            gradient is penalised. When given, parameters are interpolated
            with the same factor and the norm is taken over the joint
            (image, params) gradient, so the Lipschitz constraint also
            covers the critic's parameter branch.

    Returns:
        A scalar tensor. The graph is built with ``create_graph=True`` so the
        penalty can be back-propagated into the critic's weights.
    """
    bsz = real_imgs.size(0)
    # One interpolation factor per sample, broadcast over all trailing dims
    # (the original hard-coded (bsz,1,1,1), which mis-broadcasts non-4-D inputs).
    α = torch.rand(bsz, *([1] * (real_imgs.dim() - 1)), device=device)
    interp_imgs = (α * real_imgs + (1 - α) * fake_imgs).requires_grad_(True)

    if fake_params is None:
        interp_params = real_params
        grad_inputs = [interp_imgs]
    else:
        α_p = α.reshape(bsz, *([1] * (real_params.dim() - 1)))
        interp_params = (α_p * real_params + (1 - α_p) * fake_params).requires_grad_(True)
        grad_inputs = [interp_imgs, interp_params]

    interp_scores = critic(interp_imgs, interp_params, cond)
    # create_graph=True makes the penalty differentiable. It also implies
    # retain_graph=True, so that flag need not be passed separately.
    grads = autograd.grad(
        outputs=interp_scores,
        inputs=grad_inputs,
        grad_outputs=torch.ones_like(interp_scores),
        create_graph=True,
    )
    # Flatten each input's gradient per sample and join them so the norm is
    # taken over the whole critic input being penalised.
    flat_grads = torch.cat([g.reshape(bsz, -1) for g in grads], dim=1)
    return λ_gp * ((flat_grads.norm(2, dim=1) - 1) ** 2).mean()


In [None]:
# Hyper-parameters shared by the training cell below
lr         = 1e-4        # Adam learning rate (slightly higher than a typical RMSprop WGAN)
n_critic   = 5           # critic updates per generator update
lambda_gp  = 10.0        # gradient-penalty weight

# Use the GPU when one is visible, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Models. The generator is built before the critic, which fixes the order
# in which the random initialisation is drawn.
generator = Generator(cond_dim=8).to(device)
critic    = Critic(cond_dim=8).to(device)

# Adam optimisers with betas=(0.5, 0.9). The gradient penalty, rather than
# weight clipping, keeps the critic Lipschitz, so Adam can be used here.
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
optimizer_c = optim.Adam(critic.parameters(), lr=lr, betas=(0.5, 0.9))

In [None]:
if __name__ == "__main__":
    import matplotlib
    matplotlib.use('TkAgg')  # Or 'QtAgg', 'WXAgg' — interactive window for the live loss plot
    import matplotlib.pyplot as plt

    # — Hyperparameters —
    epochs       = 100
    lambda_param = 10.0    # weight for parameter-MSE term

    # — Models (assumes Generator, Critic, compute_gradient_penalty are defined above) —
    # Re-created here so re-running this cell restarts training from scratch
    # instead of silently continuing from a previous run's weights.
    generator = Generator(cond_dim=8).to(device)
    critic    = Critic(cond_dim=8).to(device)

    # — Loss for the parameter regression branch —
    criterion_param = nn.MSELoss()

    # — Optimizers (WGAN-GP typically uses Adam) —
    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.9))
    optimizer_c = optim.Adam(critic.parameters(),    lr=lr, betas=(0.5, 0.9))

    # — Tracking losses (one mean value per epoch) —
    train_c_losses = []
    train_g_losses = []

    # — Interactive plotting setup —
    plt.ion()
    fig, ax = plt.subplots(figsize=(8, 5))

    for epoch in range(1, epochs + 1):
        start = time.perf_counter()
        c_epoch = []   # last critic-iteration loss of each batch
        g_epoch = []   # generator loss of each batch

        for em, wts, prm, real_wg in dataloader:
            # Move data to device; the condition is eigenmodes concatenated with weights.
            em, wts, prm = em.to(device), wts.to(device), prm.to(device)
            real_wg      = real_wg.to(device).unsqueeze(1)  # add channel dim, e.g. (B,1,32,32)
            cond         = torch.cat([em, wts], dim=1)

            # — Train Critic n_critic times (reusing the same real batch) —
            for _ in range(n_critic):
                critic.zero_grad()
                # Fakes are constants for the critic update. Generating them
                # without autograd avoids building, then discarding, the
                # generator's graph n_critic times per batch.
                with torch.no_grad():
                    fake_wg, fake_prm = generator(cond)
                real_score = critic(real_wg, prm, cond).mean()
                fake_score = critic(fake_wg, fake_prm, cond).mean()
                # Gradient penalty on image interpolates (params passed as the real values)
                gp = compute_gradient_penalty(
                    critic, real_wg, fake_wg, prm, cond, device, λ_gp=lambda_gp
                )
                # Critic loss: E[D(fake)] - E[D(real)] + GP
                c_loss = fake_score - real_score + gp
                c_loss.backward()
                optimizer_c.step()

            c_epoch.append(c_loss.item())

            # — Train Generator —
            # Freeze the critic's weights: the generator step only needs
            # gradients w.r.t. the critic's inputs, so critic weight gradients
            # would be computed and then discarded.
            critic.requires_grad_(False)
            generator.zero_grad()
            fake_wg2, fake_prm2 = generator(cond)
            # Adversarial loss
            g_adv   = -critic(fake_wg2, fake_prm2, cond).mean()
            # Param regression loss
            g_param = criterion_param(fake_prm2, prm)
            g_loss  = g_adv + lambda_param * g_param
            g_loss.backward()
            optimizer_g.step()
            critic.requires_grad_(True)

            g_epoch.append(g_loss.item())

        # — Epoch logging —
        avg_c  = np.mean(c_epoch)
        avg_g  = np.mean(g_epoch)
        train_c_losses.append(avg_c)
        train_g_losses.append(avg_g)
        elapsed = time.perf_counter() - start
        print(f"Epoch {epoch:3d}/{epochs} — Critic: {avg_c:.4f}, Generator: {avg_g:.4f} — {elapsed:.1f}s")

        # — Update live plot —
        ax.clear()
        ax.plot(train_c_losses, label='Train Critic Loss')
        ax.plot(train_g_losses, label='Train Generator Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title('WGAN‑GP Losses (Live)')
        ax.legend(loc='upper right')
        plt.tight_layout()
        plt.pause(0.1)

    plt.ioff()
    plt.show()

    # — Save the generator —
    os.makedirs("./models", exist_ok=True)
    save_path = "./models/generator_wgan_gp.pth"
    torch.save(generator.state_dict(), save_path)
    print(f"Generator state_dict saved to {save_path}")


In [None]:
def generate_waveguide(generator, eigenmodes_weights):
    """
    Given a trained WGAN generator and a flat vector of 8 cond features
    (4 eigenmodes + 4 weights), returns the generated waveguide and params.

    Args:
        generator: module called as ``generator(cond)`` and returning
            ``(waveguide, params)``. Inputs are moved to the device of its
            first parameter.
        eigenmodes_weights: array-like or tensor of condition features.
            A flat vector is treated as a batch of one. A 2-D
            ``(B, features)`` batch is passed through as-is.

    Returns:
        ``(waveguide, params)`` exactly as produced by the generator,
        computed without autograd.

    The generator's train/eval mode is restored afterwards, so calling this
    mid-training does not leave the module stuck in eval mode.
    """
    device = next(generator.parameters()).device
    # as_tensor accepts lists, numpy arrays and tensors (torch.tensor warns on tensors).
    x = torch.as_tensor(eigenmodes_weights, dtype=torch.float32, device=device)
    if x.dim() == 1:
        # Add the batch dimension the generator expects.
        x = x.unsqueeze(0)

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            waveguide, params = generator(x)
    finally:
        generator.train(was_training)
    return waveguide, params

if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Checkpoint written by the training cell. It is spelled out here so this
    # cell also works in a fresh kernel where the training cell (which
    # defines `save_path`) has not been run.
    checkpoint_path = "./models/generator_wgan_gp.pth"
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"No generator checkpoint at {checkpoint_path}; run the training cell first."
        )

    gen = Generator(cond_dim=8).to(device)
    # weights_only=True: a state_dict is plain tensors, so there is no need to
    # allow arbitrary pickled objects.
    gen.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

    test_vals = load_item(dataset[1], train=False)
    cond = np.concatenate((test_vals["Eigenmodes"], test_vals["Weights"]))

    gen_waveguide, gen_params = generate_waveguide(gen, cond)
    # Binarise the Sigmoid output at 0.5 to get a 0/1 shape mask.
    gen_waveguide = (gen_waveguide >= 0.5).float()
    wg = gen_waveguide.squeeze().cpu().numpy()

    ## Eye test evaluate
    plot_shape(wg)
    print(f"Conditions:       {cond}")
    print(f"Real Parameters:   {test_vals['Params']}")
    print(f"Generated Params:  {gen_params.squeeze().cpu().numpy()}")

    plt.show()