In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
from sklearn.model_selection import train_test_split

In [None]:
# Select the compute device: first CUDA GPU when one is visible, otherwise CPU.
if torch.cuda.is_available():
  dev = "cuda:0"
else:
  dev = "cpu"
device = torch.device(dev)
print("Device: ", device)

# Seed the global NumPy and PyTorch RNGs so that a fresh "Restart & Run All"
# reproduces the same random draws (anything relying on NumPy's or PyTorch's
# global generators, e.g. weight initialisation and noise sampling).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hyperparameters
N = 100  # size passed as both y_dim and x_dim to the networks, and the plotting grid length
batch_size = 128  # mini-batch size of the training DataLoader
z_dim = 30  # noise dimension
num_epochs = 300  # number of passes over the training set
# Discriminator/generator update ratio: the generator is trained only once for
# every `critic_num` rounds of discriminator training (see the training loop
# for the exact gating).
critic_num = 2
lambd = 10  # weight of the MSE penalty added to the generator loss

In [None]:
# Load the two arrays from disk. They are split and wrapped index-by-index
# below, so row i of X_data is presumably the target for row i of Y_data
# (TODO confirm against whatever produced the .npy files).
X_data = np.load("X_data.npy")
Y_data = np.load("Y_data.npy")

X_tensor = torch.tensor(X_data, dtype=torch.float32)
Y_tensor = torch.tensor(Y_data, dtype=torch.float32)


# Split train and test set (80% / 20%).
# A fixed random_state keeps the partition identical across runs, so the
# held-out samples inspected later are always the same ones.
SPLIT_SEED = 0
Y_train, Y_test, X_train, X_test = train_test_split(
    Y_tensor, X_tensor, test_size=0.2, random_state=SPLIT_SEED
)

# Each dataset item is a (y, x) pair, in that order.
train_dataset = TensorDataset(Y_train, X_train)
test_dataset = TensorDataset(Y_test, X_test)

# DataLoaders: shuffled mini-batches for training; the test loader yields one
# (y, x) pair at a time in fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
class Sin(nn.Module):
    """Parameter-free activation module that applies sine element-wise."""

    def forward(self, x):
        """Return the element-wise sine of tensor ``x`` (same shape as ``x``)."""
        return x.sin()

# Define Generator
class Generator(nn.Module):
    """Conditional generator mapping a condition ``y`` and noise ``z`` to ``x``.

    The network is a three-layer MLP (hidden width 128) with ``Sin``
    activations between the linear layers and no activation on the output.

    Args:
        y_dim: number of features in the conditioning input ``y``.
        z_dim: number of features in the noise input ``z``.
        x_dim: number of features in the generated output.
    """

    def __init__(self, y_dim, z_dim, x_dim):
        super().__init__()
        hidden = 128
        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the layer indices.
        layers = [
            nn.Linear(y_dim + z_dim, hidden),
            Sin(),
            nn.Linear(hidden, hidden),
            Sin(),
            nn.Linear(hidden, x_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, y, z):
        """Generate ``x`` from ``y`` and ``z``.

        ``y`` and ``z`` are concatenated along dim 1, so they are expected to
        be 2-D with matching batch size: ``(batch, y_dim)`` and
        ``(batch, z_dim)``. Returns the output of the final linear layer.
        """
        conditioned_noise = torch.cat((y, z), dim=1)
        return self.net(conditioned_noise)

# Define Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator scoring a ``(y, x)`` pair with a value in [0, 1].

    The network is a three-layer MLP (hidden width 128) with ReLU activations
    and a final Sigmoid, so the single output can be fed to a probability-based
    loss.

    Args:
        y_dim: number of features in the conditioning input ``y``.
        x_dim: number of features in the sample ``x`` being judged.
    """

    def __init__(self, y_dim, x_dim):
        super().__init__()
        hidden = 128
        # Layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as the layer indices.
        layers = [
            nn.Linear(y_dim + x_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, y, x):
        """Score the pair ``(y, x)``.

        ``y`` and ``x`` are concatenated along dim 1, so they are expected to
        be 2-D with matching batch size: ``(batch, y_dim)`` and
        ``(batch, x_dim)``. Returns the sigmoid output, one value per row.
        """
        pair = torch.cat((y, x), dim=1)
        return self.net(pair)

In [None]:
# Instantiate both networks on the selected device (generator first, then
# discriminator). Both the condition y and the target x have N features.
G = Generator(y_dim=N, z_dim=z_dim, x_dim=N).to(device)
D = Discriminator(y_dim=N, x_dim=N).to(device)

# Print total params: trainable parameter count of G, then of D
for model in (G, D):
    print(sum(w.numel() for w in model.parameters() if w.requires_grad))

# One Adam optimizer per network, sharing the same initial learning rate,
# each paired with a StepLR schedule that multiplies the learning rate by 0.8
# every 50 scheduler steps.
lr = 5e-3

optimizer_G = optim.Adam(G.parameters(), lr=lr)
scheduler_G = StepLR(optimizer_G, step_size=50, gamma=0.8)

optimizer_D = optim.Adam(D.parameters(), lr=lr)
scheduler_D = StepLR(optimizer_D, step_size=50, gamma=0.8)

# Loss: binary cross-entropy on probabilities
criterion = nn.BCELoss()

In [None]:

# Adversarial training loop.
#
# Every mini-batch updates the discriminator; the generator is updated on every
# `critic_num`-th mini-batch. The gate uses a running batch counter rather than
# the epoch index: gating on the epoch froze the generator for entire passes
# over the data while the discriminator kept training.
global_step = 0            # mini-batches processed so far, across all epochs
loss_d = loss_g = None     # most recent losses; stay None if no batch was seen

for epoch in range(num_epochs):
    G.train()
    D.train()
    for Y_batch, X_batch in train_loader:
        Y_batch = Y_batch.to(device)
        X_batch = X_batch.to(device)
        batch_size_cur = Y_batch.shape[0]  # the last batch may be smaller

        # BCE targets: 1 for real pairs, 0 for generated pairs
        real_labels = torch.ones(batch_size_cur, 1, device=device)
        fake_labels = torch.zeros(batch_size_cur, 1, device=device)

        # Train Discriminator
        #--------------------
        optimizer_D.zero_grad()

        d_real = D(Y_batch, X_batch)
        loss_d_real = criterion(d_real, real_labels)

        z = torch.randn(batch_size_cur, z_dim, device=device)
        X_fake = G(Y_batch, z)
        # detach() keeps the discriminator loss from back-propagating into G
        d_fake = D(Y_batch, X_fake.detach())
        loss_d_fake = criterion(d_fake, fake_labels)

        # Calculate loss
        loss_d = loss_d_real + loss_d_fake
        loss_d.backward()
        optimizer_D.step()

        # Train Generator
        #----------------
        if global_step % critic_num == 0:
            optimizer_G.zero_grad()
            # Re-score the same fakes with the freshly updated discriminator,
            # this time keeping the graph back to G.
            d_fake = D(Y_batch, X_fake)
            # Adversarial term (fool D into predicting "real") plus an MSE
            # penalty pulling the generated x towards the paired true x.
            loss_g = criterion(d_fake, real_labels) + lambd * F.mse_loss(X_fake, X_batch)
            loss_g.backward()
            optimizer_G.step()
        global_step += 1

    # LR scheduler update (once per epoch)
    scheduler_G.step()
    scheduler_D.step()

    # Report the losses of the most recent updates every 50 epochs; skipped
    # when no batch has produced a loss yet (e.g. an empty training loader).
    if epoch % 50 == 0 and loss_d is not None and loss_g is not None:
        print(f"Epoch {epoch}: D loss {loss_d.item():.4f}, G loss {loss_g.item():.4f}")


In [None]:
# Qualitative evaluation: for the first `num_plots` test pairs, draw
# `num_samples` generated x for the same y (different noise each time) and plot
# their mean and a +/- 2 standard-deviation band against the true x.
G.eval()
num_samples = 100
num_plots = 6
fig, axs = plt.subplots(2, 3, figsize=(15, 8))
axs = axs.flatten()

# Horizontal axis shared by every panel (N evenly spaced points on [0, 1]);
# it does not depend on the sample, so it is built once.
t = np.linspace(0, 1, N)

with torch.no_grad():
    for i, (Y_sample, X_sample) in enumerate(test_loader):
        if i >= num_plots:
            break
        Y_sample = Y_sample.to(device)
        X_sample = X_sample.to(device)
        generated_samples = []
        for _ in range(num_samples): # generate multiple x given y
            z = torch.randn(1, z_dim, device=device)
            X_gen = G(Y_sample, z).cpu().numpy().flatten()
            generated_samples.append(X_gen)

        generated_samples = np.array(generated_samples)

        mean_gen = generated_samples.mean(axis=0) # get mean of generated x
        std_gen = generated_samples.std(axis=0)# get SD of generated x

        # Plot true vs generated
        ax = axs[i]
        ax.plot(t, X_sample.cpu().numpy().flatten(), label='True')
        ax.plot(t, mean_gen, label='Mean Generated', linestyle='--')
        # The band must use the same grid `t` as the curves (the original
        # referenced an undefined name here and raised NameError).
        ax.fill_between(t, mean_gen - 2*std_gen, mean_gen + 2*std_gen, color='grey',  label='95% CI')
        ax.set_title(f"Sample {i+1}")
        ax.set_xlabel("t")
        ax.set_ylabel(" x")
        ax.legend()

# The figure-level title only needs to be set once, not once per panel.
fig.suptitle("True vs Generated x", fontsize=18)
plt.tight_layout(pad=1.0)
plt.show()
