# Enhanced Quality of Outputs

# 

In [None]:
# import libraries
import numpy as np
import matplotlib.pyplot as plt
import torch

In [None]:
# Load the original (real) hydrofoil training dataset.
# A random 2x8 subset of it is plotted in a later cell.
real_data = np.load(file='./hydFoil_data/resampled_hydrofoils.npy')

In [None]:


def plot_random_airfoils(airfoil_dataset,
                         num_plots_per_row=8,
                         num_rows=2,
                         figsize=(10, 1),
                         save_path='random_airfoils_plot.svg',
                         rng=None):
    """
    Randomly select airfoils from a dataset and save them as a grid of line plots.

    Args:
        airfoil_dataset (array-like): Stack of airfoils indexed along the first
            axis. Each airfoil is a 2D array of (x, y) coordinates: column 0 is
            plotted as x and column 1 as y (e.g. shape (200, 2)). Anything
            ``np.asarray`` can stack works, including an ndarray or a list of
            equally-shaped arrays.
        num_plots_per_row (int): Number of airfoil plots in each row. Default is 8.
        num_rows (int): Number of rows of plots. Default is 2.
        figsize (tuple): Matplotlib figure size (width, height). Default is (10, 1).
        save_path (str): Path the figure is saved to. Default is
            'random_airfoils_plot.svg'.
        rng (np.random.Generator, optional): Random generator used to pick the
            airfoils, for reproducible figures. If None (default), the global
            ``np.random`` state is used.

    Raises:
        ValueError: If ``num_rows`` or ``num_plots_per_row`` is less than 1, or
            if the dataset holds fewer airfoils than ``num_plots_per_row * num_rows``.
    """
    if num_rows < 1 or num_plots_per_row < 1:
        raise ValueError("num_rows and num_plots_per_row must both be at least 1.")

    # Normalise the input so lists of arrays work as well as a stacked ndarray
    # (a plain list has no `.shape`).
    airfoils = np.asarray(airfoil_dataset)
    total_plots = num_plots_per_row * num_rows
    n_available = airfoils.shape[0]

    if n_available < total_plots:
        raise ValueError(
            f"The dataset contains only {n_available} airfoils, "
            f"but you requested to plot {total_plots}."
        )

    # Sample without replacement so no airfoil appears twice in the grid.
    sampler = np.random if rng is None else rng
    selected_indices = sampler.choice(n_available, total_plots, replace=False)
    selected_airfoils = airfoils[selected_indices]

    # squeeze=False keeps `axes` a 2D array even for a 1x1 grid; otherwise
    # matplotlib returns a bare Axes, which has no .flatten()/.ravel().
    fig, axes = plt.subplots(num_rows, num_plots_per_row, figsize=figsize,
                             squeeze=False)

    for ax, airfoil in zip(axes.ravel(), selected_airfoils):
        ax.plot(airfoil[:, 0], airfoil[:, 1], 'k-', linewidth=0.5)  # thin black outline
        ax.set_aspect('equal', adjustable='box')  # same x/y scale so the shape is not distorted
        ax.axis('off')  # hide axes for a cleaner look

    fig.tight_layout()
    fig.savefig(save_path, dpi=300)  # high resolution for raster formats
    plt.close(fig)  # release the figure's memory
    print(f"Plot of {total_plots} randomly selected airfoils saved to '{save_path}'")


In [None]:
# Plot a random grid of real airfoils from the dataset
plot_random_airfoils(airfoil_dataset=real_data, save_path='real_airfoils.svg')

In [None]:
# Load the generator checkpoint.
# Generator hyperparameters; these presumably have to match the values the
# checkpoint was trained with (the "3_10" in the path appears to encode them).
LATENT_DIM = 3
NOISE_DIM = 10

# Earlier checkpoint (epoch 2000), currently disabled:
#checkpoint_bgan = torch.load("./trained_gan/3_10/checkpoints/model_epoch_2000.pth")

# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; load_state_dict later copies the weights into the model,
# which is placed on `device`.
checkpoint_bhgan = torch.load("./trained_gan/3_10/checkpoints/model_epoch_10000.pth",
                              map_location="cpu")

In [None]:
# Load the model definitions (the Generator class used below is expected to come from this import)
from hydFoilGAN.gan import *

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build an untrained Generator with the checkpoint's hyperparameters and move it to `device`.
gen_bhgan = Generator(latent_dim=LATENT_DIM, noise_dim=NOISE_DIM)
gen_bhgan = gen_bhgan.to(device)

In [None]:
# The generator is only used for sampling from here on, so switch it to
# inference mode. Any dropout / batch-norm layers (if Generator has them) then
# stop using training-time behaviour. This assumes Generator is a
# torch.nn.Module, as its .to()/load_state_dict usage suggests.
gen_bhgan.eval()

# Load the trained generator weights from the checkpoint. Kept as the last
# expression so the notebook still displays the key-matching result.
gen_bhgan.load_state_dict(checkpoint_bhgan['model_G_state_dict'])

In [None]:
# Sample 1000 hydrofoil point clouds from the trained generator; batch_size
# is forwarded to the generator's sampling routine.
n_samples, batch_size = 1000, 32
bhgan_pointclouds = gen_bhgan.generate_hydrofoil_pointclouds(n_samples, batch_size)

In [None]:
# Plot a random grid of shapes generated by the Bezier-HingeGAN
plot_random_airfoils(airfoil_dataset=bhgan_pointclouds, save_path='bhgan_airfoils.svg')