In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import time
import os
# useful v1 functions
import import_ipynb 
from metamaterials_GAN_v1 import plot_shape, load_item, quarter, dataset, dataloader

if __name__ == "__main__":
    # Environment report: which torch build is installed and whether a CUDA
    # device is visible to this kernel.
    gpu_count = torch.cuda.device_count()
    gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "No GPU detected"
    for label, value in (
        ("Torch version:", torch.__version__),
        ("CUDA available:", torch.cuda.is_available()),
        ("CUDA version:", torch.version.cuda),
        ("Number of GPUs:", gpu_count),
        ("GPU name:", gpu_name),
    ):
        print(label, value)

Torch version: 2.6.0
CUDA available: False
CUDA version: None
Number of GPUs: 0
GPU name: No GPU detected


### WGAN implementation attempt 1

In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
# Generator, same as in v1
class Generator(nn.Module):
    """Conditional generator: condition vector -> (waveguide image, 4 parameters).

    The network takes no noise input, so its output is fully determined by
    ``cond`` and the weights.

    Args:
        cond_dim: width of the condition vector (default 8).

    ``forward(cond)`` returns ``(waveguide, params)``:
        waveguide: ``(B, 1, 32, 32)`` tensor squashed into [0, 1] by a Sigmoid.
        params:    ``(B, 4)`` unbounded regression output.
    """

    def __init__(self, cond_dim=8):
        super().__init__()
        # Shared trunk: condition -> 512-dim feature vector.
        self.fc = nn.Sequential(
            nn.Linear(cond_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 512),
            nn.ReLU(inplace=True),
        )
        # Image head: project to a 64x4x4 seed, then three stride-2 transposed
        # convolutions (k=4, s=2, p=1 doubles H and W) upsample 4 -> 8 -> 16 -> 32.
        self.fc_img = nn.Linear(512, 64 * 4 * 4)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )
        # Parameter head: 512 -> 32 -> 4, no output activation.
        self.fc_params = nn.Sequential(
            nn.Linear(512, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 4),
        )

    def forward(self, cond):
        """Map ``cond`` to ``(waveguide, params)``; both heads share the trunk features."""
        features = self.fc(cond)
        seed = self.fc_img(features).view(-1, 64, 4, 4)
        # Tuple elements evaluate left to right: image head first, then params.
        return self.deconv(seed), self.fc_params(features)

# Critic (formerly Discriminator)

class Critic(nn.Module):
    """Conditional WGAN critic (formerly the discriminator).

    Scores a ``(waveguide, params, cond)`` triple with one unbounded real
    number per sample. There is deliberately no final Sigmoid.

    Args:
        cond_dim: width of the condition vector (default 8).
    """

    def __init__(self, cond_dim=8):
        super().__init__()
        # Image branch: three stride-2 convs halve H and W each time; fc_image's
        # 64*4*4 input width therefore implies a 1x32x32 waveguide.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(16, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        self.fc_image = nn.Linear(64 * 4 * 4, 128)
        # 16-dim embeddings of the 4 parameters and of the condition vector.
        self.fc_params = nn.Sequential(
            nn.Linear(4, 16),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        self.fc_cond = nn.Sequential(
            nn.Linear(cond_dim, 16),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        # Head over the concatenated 128 + 16 + 16 features; raw score output.
        self.fc_final = nn.Sequential(
            nn.Linear(128 + 16 + 16, 64),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(64, 1),
        )

    def forward(self, waveguide, params, cond):
        """Return the critic score, shape ``(B, 1)``, for each sample."""
        batch = waveguide.size(0)
        image_embed = self.fc_image(self.cnn(waveguide).view(batch, -1))
        joint = torch.cat(
            (image_embed, self.fc_params(params), self.fc_cond(cond)), dim=1
        )
        return self.fc_final(joint)  # critic score

# WGAN Training Loop

def train_wgan(
    netG, netC, dataloader,
    epochs=100,
    n_critic=5,
    clip_value=0.01,
    lr=5e-5,
    device=torch.device("cpu")
):
    """Train a conditional generator/critic pair with the weight-clipped WGAN recipe.

    For every batch, the critic is updated ``n_critic`` times with loss
    ``mean C(fake) - mean C(real)``. After each critic update, every critic
    weight is clamped to ``[-clip_value, clip_value]``. The generator is then
    updated once with loss ``-mean C(G(cond))``. Both networks use RMSprop.
    Unlike the notebook's main training cell, no parameter-regression term is
    added to the generator loss.

    Args:
        netG: module; ``netG(cond)`` returns ``(waveguide, params)``.
        netC: module; ``netC(waveguide, params, cond)`` returns one score per sample.
        dataloader: iterable of batches, either ``(real_wave, real_params, cond)``
            or the v1 layout ``(em, wts, prm, real_wg)``; see ``_unpack_wgan_batch``.
        epochs: number of passes over ``dataloader``.
        n_critic: critic updates per generator update.
        clip_value: weight-clipping bound applied to every critic parameter.
        lr: RMSprop learning rate for both networks.
        device: device that both models and all batches are moved to.

    Returns:
        A list with one ``(mean_critic_loss, mean_generator_loss)`` tuple per
        epoch. A mean is ``nan`` when the epoch produced no such loss, e.g. for
        an empty dataloader.
    """
    optC = optim.RMSprop(netC.parameters(), lr=lr)
    optG = optim.RMSprop(netG.parameters(), lr=lr)

    netG.to(device)
    netC.to(device)
    # generate_waveguide-style helpers call .eval(); make sure we train in train mode.
    netG.train()
    netC.train()

    history = []
    for epoch in range(epochs):
        c_losses, g_losses = [], []

        for batch in dataloader:
            real_wave, real_params, cond = _unpack_wgan_batch(batch, device)

            # Train Critic
            for _ in range(n_critic):
                netC.zero_grad()
                real_score = netC(real_wave, real_params, cond).mean()
                # The generator is not being updated here, so build fakes without
                # an autograd graph instead of building one and detaching it.
                with torch.no_grad():
                    fake_wave, fake_params = netG(cond)
                fake_score = netC(fake_wave, fake_params, cond).mean()

                lossC = fake_score - real_score
                lossC.backward()
                optC.step()

                # Weight clipping (original WGAN), done outside autograd.
                with torch.no_grad():
                    for p in netC.parameters():
                        p.clamp_(-clip_value, clip_value)
                c_losses.append(lossC.item())

            # Train Generator
            netG.zero_grad()
            gen_wave, gen_params = netG(cond)
            lossG = -netC(gen_wave, gen_params, cond).mean()
            lossG.backward()
            optG.step()
            g_losses.append(lossG.item())

        # Epoch means; nan (not a NameError) when nothing was recorded.
        mean_c = sum(c_losses) / len(c_losses) if c_losses else float("nan")
        mean_g = sum(g_losses) / len(g_losses) if g_losses else float("nan")
        history.append((mean_c, mean_g))
        print(f"Epoch [{epoch+1}/{epochs}]  Loss_C: {mean_c:.4f}  Loss_G: {mean_g:.4f}")

    return history


def _unpack_wgan_batch(batch, device):
    """Normalize one batch to ``(real_wave, real_params, cond)`` on ``device``.

    Two layouts are accepted:
      * a 3-tuple ``(real_wave, real_params, cond)``, used as-is;
      * a 4-tuple ``(em, wts, prm, real_wg)``, the layout the v1 dataloader yields
        in this notebook. ``cond`` is ``cat([em, wts], dim=1)``, and a channel
        axis is added to a 3-D ``real_wg`` (``(B, H, W) -> (B, 1, H, W)``).
    Any other length raises ValueError from tuple unpacking.
    """
    if len(batch) == 4:
        em, wts, prm, real_wg = batch
        real_wave = real_wg.to(device)
        if real_wave.dim() == 3:
            real_wave = real_wave.unsqueeze(1)
        cond = torch.cat([em.to(device), wts.to(device)], dim=1)
        return real_wave, prm.to(device), cond
    real_wave, real_params, cond = batch
    return real_wave.to(device), real_params.to(device), cond.to(device)

### Initializing Models and Optimizers

In [20]:
# Shared configuration: the training cell below reads `device`, `lr`,
# `n_critic` and `clip_value` from here (it re-creates its own models and
# optimizers, so the instances built here are only placeholders).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lr = 5e-5           # RMSprop learning rate for both networks
n_critic = 5        # how many critic steps per generator step
clip_value = 0.01   # weight-clipping range

# Generator is constructed before Critic, as in the original cell.
generator, critic = Generator().to(device), Critic().to(device)

optimizer_g, optimizer_c = (
    optim.RMSprop(model.parameters(), lr=lr) for model in (generator, critic)
)

### Training the WGAN

In [1]:
if __name__ == "__main__":
    import matplotlib
    # NOTE(review): TkAgg draws the live loss plot in a separate desktop window.
    # It needs a display and the tk bindings, so it fails on headless machines.
    matplotlib.use('TkAgg')  # Or 'QtAgg', 'WXAgg'
    import matplotlib.pyplot as plt

    # — Device and hyperparameters —
    # `device`, `lr`, `n_critic` and `clip_value` come from the model/optimizer
    # initialization cell; run it first.
    epochs       = 100
    lambda_param = 10.0  # weight of the MSE parameter-regression term in g_loss

    # Checkpoint location. It is bound before training so that it exists (for
    # saving and for the visualization cell) even if training is interrupted.
    save_dir  = "./models"
    save_path = os.path.join(save_dir, "generator_wgan.pth")

    # — Models —
    generator = Generator(cond_dim=8).to(device)
    critic    = Critic(cond_dim=8).to(device)

    # — Loss for parameter regression —
    criterion_param = nn.MSELoss()

    # — Optimizers (original WGAN uses RMSprop) —
    optimizer_g = optim.RMSprop(generator.parameters(), lr=lr)
    optimizer_c = optim.RMSprop(critic.parameters(),    lr=lr)

    # — Logging lists —
    train_c_losses = []
    train_g_losses = []

    # — Interactive plot setup —
    plt.ion()
    fig, ax = plt.subplots(figsize=(8,5))

    try:
        for epoch in range(1, epochs+1):
            start = time.perf_counter()
            c_epoch = []
            g_epoch = []

            for em, wts, prm, real_wg in dataloader:
                # Move data to device
                em, wts, prm = em.to(device), wts.to(device), prm.to(device)
                real_wg      = real_wg.to(device).unsqueeze(1)   # (B,1,32,32)
                cond         = torch.cat([em, wts], dim=1)

                # Fake samples for the critic. Generator.forward takes only `cond`
                # (no noise), and only the critic is stepped below, so a single
                # gradient-free forward pass serves all n_critic iterations.
                with torch.no_grad():
                    fake_wg, fake_prm = generator(cond)

                # — Train Critic n_critic times —
                for _ in range(n_critic):
                    critic.zero_grad()
                    real_score = critic(real_wg, prm, cond).mean()
                    fake_score = critic(fake_wg, fake_prm, cond).mean()
                    # Wasserstein critic loss
                    c_loss = fake_score - real_score
                    c_loss.backward()
                    optimizer_c.step()

                    # Weight clipping to [-clip_value, clip_value], outside autograd
                    with torch.no_grad():
                        for p in critic.parameters():
                            p.clamp_(-clip_value, clip_value)

                # — Train Generator —
                generator.zero_grad()
                fake_wg2, fake_prm2 = generator(cond)
                # Adversarial term (want critic to rate fake samples high)
                g_adv = -critic(fake_wg2, fake_prm2, cond).mean()
                # Regression term on the parameters
                g_param = criterion_param(fake_prm2, prm)
                # Combined generator loss
                g_loss = g_adv + lambda_param * g_param
                g_loss.backward()
                optimizer_g.step()

                # Per batch: the last critic step's loss and the generator loss.
                c_epoch.append(c_loss.item())
                g_epoch.append(g_loss.item())

            # — Epoch summaries —
            avg_c = np.mean(c_epoch)
            avg_g = np.mean(g_epoch)
            train_c_losses.append(avg_c)
            train_g_losses.append(avg_g)
            elapsed = time.perf_counter() - start

            print(f"Epoch {epoch:3d}/{epochs} — Critic: {avg_c:.4f}, Generator: {avg_g:.4f} — {elapsed:.1f}s")

            # — Update live plot —
            ax.clear()
            ax.plot(train_c_losses, label='Train Critic Loss')
            ax.plot(train_g_losses, label='Train Generator Loss')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title('WGAN Losses (Live)')
            ax.legend(loc='upper right')
            plt.tight_layout()
            plt.pause(0.1)
    finally:
        plt.ioff()
        # — Save the generator weights —
        # This also runs after an interruption (e.g. a KeyboardInterrupt used to
        # stop training early), so that progress is not lost. It is skipped when
        # no epoch has finished, so that an existing checkpoint is not
        # overwritten with untrained weights. Saving happens before the blocking
        # plt.show() below.
        if train_c_losses:
            os.makedirs(save_dir, exist_ok=True)
            torch.save(generator.state_dict(), save_path)
            print(f"Generator state_dict saved to {save_path}")
            if len(train_c_losses) < epochs:
                print(f"Training stopped early after {len(train_c_losses)}/{epochs} epochs.")

    plt.show()


KeyboardInterrupt



### Visualizing the final Generated Image

In [None]:
def generate_waveguide(generator, eigenmodes_weights):
    """Run a trained WGAN generator on one condition vector or a batch of them.

    Args:
        generator: module whose ``forward(cond)`` returns ``(waveguide, params)``,
            such as ``Generator``. Inputs are placed on the device of its
            first parameter.
        eigenmodes_weights: condition features, given as a flat vector of 8
            values (4 eigenmodes + 4 weights) or as a 2-D batch of such vectors.
            Accepts anything ``torch.as_tensor`` accepts (list, numpy array, tensor).

    Returns:
        ``(waveguide, params)`` as produced by ``generator``, computed without
        gradients. A flat input is treated as a batch of one.

    The generator's previous train/eval mode is restored afterwards, so calling
    this function does not leave a model that is still being trained in eval mode.
    """
    device = next(generator.parameters()).device
    # as_tensor avoids the copy-construct warning that torch.tensor() emits for tensors.
    x = torch.as_tensor(eigenmodes_weights, dtype=torch.float32, device=device)
    if x.dim() == 1:
        x = x.unsqueeze(0)  # single condition -> batch of one

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            waveguide, params = generator(x)
    finally:
        generator.train(was_training)
    return waveguide, params


if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Rebuild the checkpoint path locally. It is the same location the training
    # cell writes to, so this cell does not depend on `save_path` having been
    # bound by a training run in the current kernel.
    save_path = os.path.join("./models", "generator_wgan.pth")

    gen = Generator(cond_dim=8).to(device)
    # The checkpoint holds a plain state_dict, so restricting unpickling to
    # tensors/containers is sufficient and safer.
    gen.load_state_dict(torch.load(save_path, map_location=device, weights_only=True))

    # Condition vector for dataset item 1: eigenmodes followed by weights.
    test_vals = load_item(dataset[1], train=False)
    cond = np.concatenate((test_vals["Eigenmodes"], test_vals["Weights"]))

    gen_waveguide, gen_params = generate_waveguide(gen, cond)
    # Binarize the Sigmoid output at 0.5 for display.
    gen_waveguide = (gen_waveguide >= 0.5).float()
    wg = gen_waveguide.squeeze().cpu().numpy()

    ## Eye test evaluate 
    plot_shape(wg)
    print(f"Conditions:       {cond}")
    print(f"Real Parameters:   {test_vals['Params']}")
    print(f"Generated Params:  {gen_params.squeeze().cpu().numpy()}")

    plt.show()