In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed torch and numpy so weight init and latent-network sampling are reproducible.
torch.manual_seed(0)
np.random.seed(0)

In [2]:
def generate_3D_ellipse(n, r1 = 1, r2 = 0.5):
    """Sample ``n`` evenly spaced points on an axis-aligned ellipse lying in z = 0.

    The angle runs over [0, 2*pi); the duplicate endpoint at 2*pi is dropped.
    Coordinates are x = r1*cos(t), y = r2*sin(t), z = 0.

    Returns:
        Tensor of shape (1, n, 3).
    """
    angles = torch.linspace(0, 2 * torch.pi, n + 1)
    # Evaluate on the full n+1 grid first, then drop the repeated final angle.
    coords = torch.stack(
        (r1 * angles.cos(), r2 * angles.sin(), 0 * angles.sin()),
        dim=1,
    )
    return coords[:-1].unsqueeze(0)


def generate_ellipse(n, r1 = 1, r2 = 0.5):
    """Sample ``n`` evenly spaced points on an axis-aligned 2-D ellipse.

    The angle runs over [0, 2*pi); the duplicate endpoint at 2*pi is dropped.
    Coordinates are x = r1*cos(t), y = r2*sin(t).

    Returns:
        Tensor of shape (1, n, 2).
    """
    angles = torch.linspace(0, 2 * torch.pi, n + 1)
    # Evaluate on the full n+1 grid first, then drop the repeated final angle.
    coords = torch.stack((r1 * angles.cos(), r2 * angles.sin()), dim=1)
    return coords[:-1].unsqueeze(0)

def generate_data(n, sign = "+"):
    """Build the two 2-D target curves used for training.

    Args:
        n: number of points per curve.
        sign: unused; kept only so existing call sites keep working.

    Returns:
        Tensor of shape (2, n, 2). Index 0 is the ellipse with x-radius 0.5 and
        y-radius 1; index 1 has x-radius 1 and y-radius 0.5.
    """
    radii = ((0.5, 1), (1, 0.5))
    curves = [generate_ellipse(n, r1=rx, r2=ry) for rx, ry in radii]
    return torch.concat(curves, dim=0)


In [3]:
class H_theta(nn.Module):
    """Generator MLP mapping each input row to an ``output_dim`` point.

    The network is a stack of Linear -> ReLU blocks through ``hidden_dims``,
    followed by a final Linear layer to ``output_dim``. The layers act on the
    last axis, so any leading batch axes pass through unchanged.

    Args:
        input_dim: width of each input row.
        output_dim: width of each output row.
        hidden_dims: hidden layer widths. The default (400, 300, 200, 100)
            reproduces the original architecture, including the
            ``fc.<index>`` parameter names.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=(400, 300, 200, 100)):
        super().__init__()
        layers = []
        width = input_dim
        for hidden in hidden_dims:
            layers += [nn.Linear(width, hidden), nn.ReLU()]
            width = hidden
        layers.append(nn.Linear(width, output_dim))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        return self.fc(x)
    

def generate_NN_latent_functions(num_samples, xdim=1, zdim=2, bias=0, gain=0.5, output_scale=5):
    """Sample ``num_samples`` randomly initialised MLPs to use as latent functions.

    Each network maps an ``xdim``-wide input to a ``zdim``-wide latent code
    through Linear(xdim, 100) -> ReLU -> Linear(100, 50) -> ReLU ->
    Linear(50, 50) -> ReLU -> Linear(50, zdim), and multiplies the result by
    ``output_scale``. The networks are returned on the default (CPU) device.

    Args:
        num_samples: number of networks to create.
        xdim: input width.
        zdim: output (latent) width.
        bias: constant value assigned to every Linear bias.
        gain: gain passed to ``nn.init.xavier_normal_`` for every Linear weight.
        output_scale: multiplier applied to the final layer's output.

    Returns:
        A list of ``num_samples`` ``nn.Module`` instances.
    """
    class NN(nn.Module):
        def __init__(self, input_dim, output_dim):
            super().__init__()
            self.fc1 = nn.Linear(input_dim, 100)
            self.fc2 = nn.Linear(100, 50)
            self.fc3 = nn.Linear(50, 50)
            self.fc4 = nn.Linear(50, output_dim)

        def forward(self, x):
            x = torch.relu(self.fc1(x))
            x = torch.relu(self.fc2(x))
            x = torch.relu(self.fc3(x))
            return self.fc4(x) * output_scale

    def weights_init_normal(m):
        # Overrides PyTorch's default Linear init for every Linear submodule.
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight, gain=gain)
            if m.bias is not None:
                nn.init.constant_(m.bias, bias)

    networks = []
    for _ in range(num_samples):
        net = NN(xdim, zdim)
        net.apply(weights_init_normal)
        networks.append(net)
    return networks

def find_nns(Y, G):
    """Return the index along axis 0 of ``G`` of the candidate closest to ``Y``.

    ``Y`` is broadcast against ``G``. The squared difference is summed over
    axis 2 (coordinates) and averaged over axis 1 (points), giving one score
    per candidate. In ``train`` this is called with one target curve of shape
    (num_points, D) and ``G`` of shape (num_samples, num_points, D).

    Returns:
        The arg-min index as a Python int.
    """
    per_point = (Y - G).pow(2).sum(dim=2)
    per_candidate = per_point.mean(dim=1)
    best = per_candidate.min(dim=0).indices
    return best.item()

def f_loss(Y, G):
    """Mean squared point distance between targets ``Y`` and outputs ``G``.

    The squared difference is summed over axis 2 (coordinates), then averaged
    over axis 1 and afterwards over axis 0, giving a scalar tensor.
    """
    per_point = (G - Y).pow(2).sum(dim=2)
    per_curve = per_point.mean(dim=1)
    return per_curve.mean(dim=0)

In [4]:
import itertools
import os

def train(concat=True, multiply_output=False, zdim=5, xdim=1, output_dim=2, bias=0, include_zero=False, epochs=100, num_points=1600, lr=0.001, staleness=15, num_Z_samples=30):
    """Fit an ``H_theta`` generator to two target ellipses with an IMLE-style loop.

    Every ``staleness`` epochs, ``num_Z_samples`` random latent networks are
    drawn (``generate_NN_latent_functions``). Each one maps the shared input
    ``x`` to a latent code ``Z``; the generator input is ``[Z, x]`` when
    ``concat`` is set, else ``Z``. For each target curve, the sample whose
    generator output is nearest (``find_nns``) is selected. Between refreshes,
    the generator is trained with ``f_loss`` to pull those selected samples
    onto the targets.

    Figures and the loss curve are written to ``cases_results/<case folder>``.

    Args:
        concat: append ``x`` to the latent code before feeding ``H_theta``.
        multiply_output: scale latent-network outputs by 50.
        zdim: latent code width.
        xdim: 1 (x is a line of ``num_points`` values) or 2 (x is a square grid).
        output_dim: 3 selects the flat 3-D ellipses as targets; otherwise the
            2-D ellipses from ``generate_data`` are used.
        bias: constant bias for the latent networks' Linear layers.
        include_zero: x spans [-1, 2] if True, else [1, 2].
        epochs: number of optimisation steps.
        num_points: points per target curve; must be a perfect square when
            ``xdim == 2`` so the grid has the same number of points.
        lr: Adam learning rate.
        staleness: epochs between nearest-neighbour refreshes (>= 1).
        num_Z_samples: latent networks sampled per refresh.

    Returns:
        The trained ``H_theta`` module.

    Raises:
        ValueError: unsupported ``xdim``, non-square ``num_points`` with
            ``xdim == 2``, or ``staleness < 1``.
    """
    if xdim not in (1, 2):
        raise ValueError(f"xdim must be 1 or 2, got {xdim}")
    grid_side = int(round(max(num_points, 0) ** 0.5))
    if xdim == 2 and grid_side * grid_side != num_points:
        # The generator output must have exactly num_points rows to line up
        # with the targets in f_loss, so the 2-D grid must be square.
        raise ValueError(f"num_points must be a perfect square when xdim == 2, got {num_points}")
    if staleness < 1:
        raise ValueError(f"staleness must be >= 1, got {staleness}")

    folder_name = f"case_bias_{bias}_zero_{include_zero}_concat_{concat}_multiply_{multiply_output}_{output_dim}D"
    output_dir = os.path.join("cases_results", folder_name)
    os.makedirs(output_dir, exist_ok=True)

    print(f'Training: {folder_name}')

    input_dim = zdim + xdim if concat else zdim

    H_t = H_theta(input_dim=input_dim, output_dim=output_dim).to(device)
    optimizer = optim.Adam(H_t.parameters(), lr=lr)
    losses = []

    # x: shared input locations fed to every latent network, shape (num_points, xdim).
    x_low = -1.0 if include_zero else 1.0
    if xdim == 1:
        x = torch.linspace(x_low, 2.0, num_points).to(device).unsqueeze(1)
    else:
        axis = torch.linspace(x_low, 2.0, grid_side)
        grid_x1, grid_x2 = torch.meshgrid((axis, axis), indexing='ij')
        x = torch.stack((grid_x1, grid_x2), dim=-1).reshape(-1, 2).to(device)

    # Targets: shape (2, num_points, 2) or (2, num_points, 3).
    if output_dim == 3:
        points1 = generate_3D_ellipse(num_points)
        points2 = generate_3D_ellipse(num_points, r1=0.5, r2=1.0)
        data = torch.concat((points1, points2), dim=0).to(device)
    else:
        data = generate_data(num_points).to(device)

    # Guard against a zero modulus when epochs < 2.
    plot_every = max(1, epochs // 2)

    for e in tqdm(range(epochs)):
        if e % staleness == 0:
            # Refresh the IMLE nearest-neighbour assignment with fresh latent samples.
            with torch.no_grad():
                H_t.eval()
                latent_inputs = []
                for latent_net in generate_NN_latent_functions(num_samples=num_Z_samples, xdim=xdim, zdim=zdim, bias=bias):
                    Z = latent_net.to(device)(x)
                    if multiply_output:
                        Z = Z * 50
                    latent_inputs.append(torch.cat((Z, x), dim=1) if concat else Z)
                Zxs = torch.stack(latent_inputs)
                generated = H_t(Zxs)
                imle_nns = torch.tensor([find_nns(d, generated) for d in data], dtype=torch.long, device=Zxs.device)
                imle_transformed_points = Zxs[imle_nns]
                H_t.train()

        optimizer.zero_grad()
        outs = H_t(imle_transformed_points)
        loss = f_loss(data, outs)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()

        # Snapshot at e = 1 + k * plot_every and at the last epoch, skipping e = 1.
        if ((e - 1) % plot_every == 0 or e == epochs - 1) and e - 1 != 0:
            fig = plt.figure(figsize=(3, 3))
            outs_disp = outs.detach().cpu().numpy()
            points_disp = data.detach().cpu().numpy()
            plt.plot(points_disp[0, :, 0], points_disp[0, :, 1], c='red', alpha=0.5)
            plt.plot(points_disp[1, :, 0], points_disp[1, :, 1], c='blue', alpha=0.5)
            plt.scatter(outs_disp[0, :, 0], outs_disp[0, :, 1], s=1.5, c='red', marker='*')
            plt.scatter(outs_disp[1, :, 0], outs_disp[1, :, 1], s=1.5, c='blue', marker='*')
            plt.title(f'Output at epoch {e}')
            plt.savefig(os.path.join(output_dir, f'output_epoch_{e}.png'))
            plt.close(fig)

    # Plot and save the loss curve
    fig = plt.figure(figsize=(15, 5))
    plt.plot(losses)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.savefig(os.path.join(output_dir, 'loss_curve.png'))
    plt.close(fig)

    # Report the final loss; nothing to report when epochs == 0.
    if losses:
        print(losses[-1])

    return H_t


def test(zdim, xdim, output_dim, H_t, include_zero=False, num_points=1600, concat = False, multiply_output = False, bias=None, num_samples=10):
    """Plot samples from a trained generator over the target curves.

    Draws ``num_samples`` fresh latent networks, pushes their latent codes
    through ``H_t``, and saves a scatter of the first two output coordinates,
    drawn on top of the target curves, to
    ``cases_results/<case folder>/Generated_Results.png``.

    Pass the same configuration that was used for ``train``. That way the
    latent inputs come from the same distribution and the figure lands in the
    matching case folder.

    Args:
        zdim, xdim, output_dim, include_zero, num_points, concat,
        multiply_output: same meaning as in ``train``.
        H_t: the trained generator.
        bias: latent-network bias. If None, the notebook-global ``bias`` is
            used (the legacy behaviour), or 0 if no such global exists.
        num_samples: number of latent networks to sample.

    Raises:
        ValueError: unsupported ``xdim`` or non-square ``num_points`` with
            ``xdim == 2``.
    """
    if bias is None:
        # Legacy fallback: earlier versions read the notebook-global `bias`
        # implicitly. Prefer passing it explicitly.
        bias = globals().get("bias", 0)

    if xdim not in (1, 2):
        raise ValueError(f"xdim must be 1 or 2, got {xdim}")
    grid_side = int(round(max(num_points, 0) ** 0.5))
    if xdim == 2 and grid_side * grid_side != num_points:
        raise ValueError(f"num_points must be a perfect square when xdim == 2, got {num_points}")

    # Create the output directory for test results
    folder_name = f"case_bias_{bias}_zero_{include_zero}_concat_{concat}_multiply_{multiply_output}_{output_dim}D"
    output_dir = os.path.join("cases_results", folder_name)
    os.makedirs(output_dir, exist_ok=True)

    # x: shared input locations fed to every latent network, shape (num_points, xdim).
    x_low = -1.0 if include_zero else 1.0
    if xdim == 1:
        x = torch.linspace(x_low, 2.0, num_points).to(device).unsqueeze(1)
    else:
        axis = torch.linspace(x_low, 2.0, grid_side)
        grid_x1, grid_x2 = torch.meshgrid((axis, axis), indexing='ij')
        x = torch.stack((grid_x1, grid_x2), dim=-1).reshape(-1, 2).to(device)

    # Target curves for reference.
    if output_dim == 3:
        points1 = generate_3D_ellipse(num_points)
        points2 = generate_3D_ellipse(num_points, r1=0.5, r2=1.0)
        data = torch.concat((points1, points2), dim=0)
    else:
        data = generate_data(num_points)
    targets = data.cpu().numpy()

    # Inference only: no autograd graph through the latent nets or H_t.
    with torch.no_grad():
        latent_inputs = []
        for latent_net in generate_NN_latent_functions(num_samples=num_samples, xdim=xdim, zdim=zdim, bias=bias):
            Z = latent_net.to(device)(x)
            if multiply_output:
                Z = Z * 50
            latent_inputs.append(torch.cat((Z, x), dim=1) if concat else Z)
        generated = H_t(torch.stack(latent_inputs)).cpu().numpy()

    fig = plt.figure(figsize=(5, 5))
    # Draw each target curve once, then overlay one scatter per sample.
    for curve in targets:
        plt.plot(curve[:, 0], curve[:, 1], alpha=0.5)
    for sample in generated:
        plt.scatter(sample[:, 0], sample[:, 1], alpha=1, s=1)

    plt.xlabel("X-axis")
    plt.ylabel("Y-axis")
    plt.savefig(os.path.join(output_dir, "Generated_Results.png"))
    plt.close(fig)

# Experiment grid: every combination of latent-network bias, x range,
# generator-input construction, latent output scaling and target dimensionality.
ZDIM = 15      # latent code width for every run
EPOCHS = 5000  # training steps per run

bias_options = [0, 1]
include_zero_options = [True, False]
concat_options = [True, False]
multiply_output_options = [True, False]
ellipse_options = ['2D', '3D']

# Generate all combinations
combinations = list(itertools.product(bias_options, include_zero_options, concat_options, multiply_output_options, ellipse_options))

for combination in combinations:
    # Keep the name `bias`: test() may read it as a notebook global.
    bias, include_zero, concat, multiply_output, ellipse = combination

    # 2-D targets use a 1-D line of x values; 3-D targets use a 2-D x grid.
    xdim, output_dim = (1, 2) if ellipse == '2D' else (2, 3)

    H_t = train(concat=concat, multiply_output=multiply_output, zdim=ZDIM, xdim=xdim, output_dim=output_dim, bias=bias, include_zero=include_zero, epochs=EPOCHS)
    # Evaluate under the same x range and latent scaling as training. Omitting
    # these would fall back to test()'s defaults and write into the
    # *_zero_False_*_multiply_False_* folder of a different case.
    test(zdim=ZDIM, xdim=xdim, output_dim=output_dim, H_t=H_t, include_zero=include_zero, concat=concat, multiply_output=multiply_output)


Training: case_bias_0_zero_True_concat_True_multiply_True_2D


 71%|███████   | 3540/5000 [00:15<00:06, 233.96it/s]


KeyboardInterrupt: 