[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/utkarshp1161/Active-learning-in-microscopy/blob/apply/apply_nb/GP_and_sGP_beyond_1D.ipynb)

# GP & sGP for Phase Data in Scanning Probe Microscopy (SPM)

This notebook demonstrates Gaussian Process (GP) and structured Gaussian Process (sGP) models — GPs whose mean function encodes prior physical knowledge — applied to phase data from Scanning Probe Microscopy (SPM).

- **Prepared by:** [Utkarsh Pratiush](https://github.com/utkarshp1161)  
- **Data & Ideas Discussion with:** [Richard Liu](https://github.com/RichardLiuCoding) and [SVK](https://github.com/SergeiVKalinin)

### Reference:
This implementation is based on the original notebook: [sGP Notebook](https://github.com/utkarshp1161/Active-learning-in-microscopy/blob/main/notebooks/GP_%26_sGP_beyond_1D.ipynb)

## 1a. Install and Imports

In [3]:
#install
!pip install -q botorch==0.12.0
!pip install -q gpytorch==1.13

# Imports
import torch
import gpytorch
import botorch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.fit import fit_gpytorch_mll
from botorch.optim import optimize_acqf_discrete
from botorch.acquisition import UpperConfidenceBound


# Set random seed for reproducibility
# The seeds are set only once here, so later random draws (training samples,
# candidate sets) are reproducible only when the cells run top-to-bottom in order.
torch.manual_seed(42)
np.random.seed(42)

## 1b. Download data and set the ground-truth function
    - Run only one of the two data-loading cells below (phases.txt or phases_masked.txt); whichever runs last defines `test_function`.

In [8]:
# !gdown --id 1soIQoCWjyZVhvVO3FSdWMxQbLvGM_3P4 #phases.txt
# !gdown --id 1jyef_BgqpqI2QEnrnjSCcPmYDM3Z9wUO #phases_masked.txt
import os

# Check if running in Google Colab
# (the google.colab module can only be imported inside a Colab runtime)
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# Ensure the data directory exists
# data_dir is relative to the notebook's working directory; later cells read
# phases.txt / phases_masked.txt from here.
data_dir = "../data"
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# Download the files if in Colab
# Outside Colab nothing is downloaded: both files must already be in data_dir.
# Files that already exist are not downloaded again.
if IN_COLAB:
    if not os.path.exists(os.path.join(data_dir, "phases.txt")):
        !gdown --id 1soIQoCWjyZVhvVO3FSdWMxQbLvGM_3P4 -O {data_dir}/phases.txt
    if not os.path.exists(os.path.join(data_dir, "phases_masked.txt")):
        !gdown --id 1jyef_BgqpqI2QEnrnjSCcPmYDM3Z9wUO -O {data_dir}/phases_masked.txt



In [None]:
import numpy as np
import torch

# Load the data globally so it's accessible to the function
# t / x_grid / y_grid describe a 100 x 100 grid over the unit square; they are
# not used by test_function below.
t = np.linspace(0, 1, 100)
x_grid, y_grid = np.meshgrid(t, t)
# Load the data
# Reshaped to 100 x 100; test_function indexes it as z_data[row (y), column (x)].
z_data = np.loadtxt(os.path.join(data_dir, "phases.txt")).reshape(100, 100)

def test_function(X, ndim=2):
    """
    Function that returns values from the loaded data array
    Args:
        X: Input tensor of shape (n_points, ndim)
        ndim: Number of dimensions (only 2D is supported in this version)
    Returns:
        Y: Output tensor of shape (n_points, 1)
    """
    if ndim != 2:
        raise ValueError("This function only supports 2D inputs")
    
    # Convert input tensor to numpy array
    X_np = X.numpy()
    
    # Scale inputs from [-5, 5] to [0, 1] range (assuming your visualization uses [-5, 5])
    X_scaled = X_np#(X_np + 5) / 10
    
    # Find nearest indices in the grid
    x_idx = np.clip((X_scaled[:, 0] * 99).astype(int), 0, 99)
    y_idx = np.clip((X_scaled[:, 1] * 99).astype(int), 0, 99)
    
    # Get values from z_data array
    values = z_data[y_idx, x_idx]
    
    # Convert back to torch tensor
    return torch.tensor(values, dtype=torch.float32).reshape(-1, 1)

# The visualization function can remain the same
def visualize_ground_truth_function(func, ndim=2, resolution=50):
    """
    Show ``func`` over the unit square as a contour map and a 3D surface.

    Args:
        func: Callable taking a float32 tensor of shape (resolution**2, 2) and
            returning a tensor that can be reshaped to (resolution, resolution).
        ndim: Kept for call compatibility; not used.
        resolution: Number of grid points along each axis.
    """
    axis = np.linspace(0, 1, resolution)
    grid_a, grid_b = np.meshgrid(axis, axis)
    query = torch.tensor(np.column_stack((grid_a.flatten(), grid_b.flatten()))).float()
    surface = func(query).reshape(resolution, resolution).numpy()

    plt.figure(figsize=(15, 6))

    # Left panel: contour lines
    plt.subplot(121)
    plt.contour(grid_a, grid_b, surface, levels=20)
    plt.colorbar()
    plt.title("2D Visualization")
    plt.xlabel("x1")
    plt.ylabel("x2")

    # Right panel: 3D surface
    axes3d = plt.subplot(122, projection='3d')
    mesh = axes3d.plot_surface(grid_a, grid_b, surface, cmap='viridis')
    plt.colorbar(mesh)
    axes3d.set_title("3D Visualization")
    axes3d.set_xlabel("x1")
    axes3d.set_ylabel("x2")
    axes3d.set_zlabel("f(x)")

    plt.tight_layout()
    plt.show()

# Visualize
# Plots the phases.txt ground truth at the default 50 x 50 resolution.
visualize_ground_truth_function(test_function)

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

import os

# Load and prepare the masked phase map (same 100 x 100 layout as phases.txt)
t = np.linspace(0, 1, 100)
x, y = np.meshgrid(t, t)
# Read from data_dir (set in the download cell) rather than a hard-coded path,
# so this matches where the file was downloaded and how phases.txt is loaded.
z = np.loadtxt(os.path.join(data_dir, "phases_masked.txt"))  # flat array of 10000 values
z_2d = z.reshape(100, 100)  # 2D grid indexed [row (y), column (x)]

# Define new test function
def test_function(X, ndim=2):
    """
    Function that returns values from the loaded data
    Args:
        X: Input tensor of shape (n_points, ndim)
        ndim: Number of input dimensions (only 2D supported in this version)
    Returns:
        Y: Output tensor of shape (n_points, 1)
    """
    if ndim != 2:
        raise ValueError("This function only supports 2D inputs")
    
    # Convert input to numpy for easier indexing
    X_np = X.numpy()
    
    # Scale inputs from [-5, 5] to [0, 1] range
    X_scaled = X_np#(X_np + 5) / 10
    
    # Get nearest indices
    x_idx = np.clip((X_scaled[:, 0] * 99).astype(int), 0, 99)
    y_idx = np.clip((X_scaled[:, 1] * 99).astype(int), 0, 99)
    
    # Get corresponding z values
    z_values = z_2d[y_idx, x_idx]
    
    return torch.tensor(z_values).float().unsqueeze(-1)

# The visualization function remains the same
def visualize_ground_truth_function(func, ndim=2, resolution=50):
    """
    Show ``func`` over the unit square as a contour map and a 3D surface.

    Args:
        func: Callable taking a float32 tensor of shape (resolution**2, 2) and
            returning a tensor that can be reshaped to (resolution, resolution).
        ndim: Kept for call compatibility; not used.
        resolution: Number of grid points along each axis.
    """
    axis = np.linspace(0, 1, resolution)
    grid_a, grid_b = np.meshgrid(axis, axis)
    query = torch.tensor(np.column_stack((grid_a.flatten(), grid_b.flatten()))).float()
    surface = func(query).reshape(resolution, resolution).numpy()

    plt.figure(figsize=(15, 6))

    # Left panel: contour lines
    plt.subplot(121)
    plt.contour(grid_a, grid_b, surface, levels=20)
    plt.colorbar()
    plt.title("2D Visualization")
    plt.xlabel("x1")
    plt.ylabel("x2")

    # Right panel: 3D surface
    axes3d = plt.subplot(122, projection='3d')
    mesh = axes3d.plot_surface(grid_a, grid_b, surface, cmap='viridis')
    plt.colorbar(mesh)
    axes3d.set_title("3D Visualization")
    axes3d.set_xlabel("x1")
    axes3d.set_ylabel("x2")
    axes3d.set_zlabel("f(x)")

    plt.tight_layout()
    plt.show()

# Visualize
# Plots the phases_masked.txt ground truth (this test_function reads z_2d).
visualize_ground_truth_function(test_function)

## 1c. Define kernel
- For more information on defining kernel - see [this notebook](https://github.com/utkarshp1161/Active-learning-in-microscopy/blob/main/notebooks/GP_%26_sGP_BO_BoTorch.ipynb)

In [None]:
## kernel
# Define the custom kernel
class CustomKernel(gpytorch.kernels.Kernel):
    """
    RBF kernel with one lengthscale per input dimension, scaled by an output scale.

    k(x1, x2) = outputscale * RBF(x1, x2; lengthscale)

    Args:
        input_dim: Number of input features (one lengthscale per feature).
        lengthscale_prior: Optional prior on the RBF lengthscales.
        outputscale_prior: Optional prior on the output scale.
    """

    def __init__(self, input_dim, lengthscale_prior=None, outputscale_prior=None):
        super().__init__()
        rbf = gpytorch.kernels.RBFKernel(
            lengthscale_prior=lengthscale_prior,
            ard_num_dims=input_dim,
        )
        # Kept as its own attribute so the lengthscale is reachable as .base_kernel
        self.base_kernel = rbf
        self.scaling_kernel = gpytorch.kernels.ScaleKernel(
            rbf,
            outputscale_prior=outputscale_prior,
        )

    def forward(self, x1, x2, **params):
        # Delegate to the wrapped ScaleKernel's forward (outputscale * RBF).
        return self.scaling_kernel.forward(x1, x2, **params)

# Create sample data: 100 evenly spaced 1D inputs in [-3, 3]
x = torch.linspace(-3, 3, 100).view(-1, 1)
kernel = CustomKernel(input_dim=1)

# Compute kernel matrix
# kernel(...) returns a lazily evaluated operator; to_dense() materialises it
# as a regular tensor (evaluate() is the older spelling of the same call).
K = kernel(x, x).to_dense().detach().numpy()

# Plot the kernel matrix
plt.figure(figsize=(10, 8))
plt.imshow(K, cmap='viridis')
plt.colorbar(label='Kernel value')
plt.title('Custom Kernel Matrix')
plt.xlabel('Index i')
plt.ylabel('Index j')
plt.show()

# Plot a slice of the kernel: k(x, 0) for every sample point
x0 = torch.zeros(1, 1)
k_slice = kernel(x, x0).to_dense().detach().numpy()

plt.figure(figsize=(10, 6))
plt.plot(x.numpy(), k_slice)
plt.title('Kernel Slice (k(x, 0))')
plt.xlabel('x')
plt.ylabel('k(x, 0)')
plt.grid(True)
plt.show()

## 1d. GP model

In [13]:
from gpytorch.models import ExactGP
from gpytorch.means import ConstantMean

class SimpleGP(SingleTaskGP):
    """
    SingleTaskGP with a constant mean and a caller-supplied kernel and likelihood.

    Args:
        train_X: Training inputs, shape (n, d).
        train_Y: Training targets, shape (n, 1).
        covar_module: Kernel to use (e.g. ``CustomKernel``).
        likelihood: GPyTorch likelihood (e.g. ``GaussianLikelihood``).
        **kwargs: Forwarded to ``SingleTaskGP`` (e.g. ``outcome_transform``).
    """

    def __init__(self, train_X, train_Y, covar_module, likelihood, **kwargs):
        # Hand the components to SingleTaskGP's constructor instead of
        # overwriting its attributes afterwards. This avoids building default
        # mean/kernel/likelihood modules only to discard them, and keeps
        # BoTorch's own bookkeeping in line with the modules actually used.
        super().__init__(
            train_X,
            train_Y,
            likelihood=likelihood,
            covar_module=covar_module,
            mean_module=ConstantMean(),  # Constant mean function
            **kwargs,
        )

    def forward(self, x):
        # GP prior at x: constant mean plus kernel covariance
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# Sample points function remains the same
def sample_points(ndim, n_points, bounds=(0, 1)):
    """
    Draw ``n_points`` points uniformly from the box [low, high)^ndim.

    Args:
        ndim: Number of coordinates per point.
        n_points: Number of points to draw.
        bounds: (low, high) pair applied to every coordinate.

    Returns:
        Float tensor of shape (n_points, ndim), drawn with torch's global RNG.
    """
    low, high = bounds[0], bounds[1]
    width = high - low
    return torch.rand(n_points, ndim) * width + low

## 1e. Plotting utility

In [14]:
def plot_step(model, train_X, train_Y, next_point=None, ndim=2):
    """Plot current state of optimization.

    ndim == 2: three panels over the unit square -- ground truth (from
    ``test_function``), posterior mean and posterior standard deviation --
    with training points as red crosses and the next query point, if any,
    as a green circle.
    ndim == 1: posterior mean with a mean +/- 2 std band, the training data,
    and a dashed vertical line at the next query point, if any.
    ndim > 2: prints a message and plots nothing.

    Args:
        model: Fitted BoTorch model exposing ``posterior(X)``.
        train_X: Training inputs, shape (n, ndim).
        train_Y: Training targets; only used by the 1D plot.
        next_point: Optional tensor of shape (1, ndim); an empty tensor is
            treated like ``None``.
        ndim: Input dimensionality.
    """
    if ndim > 2:
        print("Visualization only supported for 1D and 2D inputs")
        return

    if next_point is not None and next_point.numel() == 0:
        next_point = None
    train_X = train_X.detach()

    if ndim == 2:
        _plot_step_2d(model, train_X, next_point)
    elif ndim == 1:
        _plot_step_1d(model, train_X, train_Y, next_point)


def _overlay_points(ax, train_X, next_point, with_labels):
    """Scatter training points (red x) and the optional next point (green o) on ``ax``."""
    ax.scatter(train_X[:, 0], train_X[:, 1], c='red', marker='x',
               label='Training points' if with_labels else None)
    if next_point is not None:
        ax.scatter(next_point[0, 0], next_point[0, 1], c='green', marker='o',
                   label='Next point' if with_labels else None)


def _plot_step_2d(model, train_X, next_point, resolution=100):
    """Draw truth / posterior-mean / posterior-std panels on a unit-square grid."""
    axis = torch.linspace(0, 1, resolution)
    # "ij" is torch's current default; spelling it out silences the meshgrid
    # UserWarning and makes explicit that x1 varies along the first axis.
    x1_grid, x2_grid = torch.meshgrid(axis, axis, indexing="ij")
    grid_points = torch.stack([x1_grid.flatten(), x2_grid.flatten()], dim=-1)

    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        posterior = model.posterior(grid_points)
        mean_surface = posterior.mean.reshape(resolution, resolution)
        # Standard deviation straight from the posterior variance.
        # confidence_region() spans mean +/- 2 std, so half its width is 2 std.
        std_surface = posterior.variance.clamp_min(0).sqrt().reshape(resolution, resolution)
    true_values = test_function(grid_points, 2).reshape(resolution, resolution)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        ('True Function', true_values),
        ('Posterior Mean', mean_surface),
        ('Posterior Std Dev', std_surface),
    )
    for idx, (ax, (title, surface)) in enumerate(zip(axes, panels)):
        filled = ax.contourf(x1_grid.numpy(), x2_grid.numpy(), surface.numpy(), levels=20)
        plt.colorbar(filled, ax=ax)
        ax.set_title(title)
        # Only the first panel carries legend labels
        _overlay_points(ax, train_X, next_point, with_labels=(idx == 0))
    axes[0].legend()

    plt.tight_layout()
    plt.show()


def _plot_step_1d(model, train_X, train_Y, next_point, resolution=100):
    """Draw the 1D posterior mean with a +/-2 std band and the training data."""
    x_plot = torch.linspace(train_X.min().item(), train_X.max().item(), resolution).reshape(-1, 1)

    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        posterior = model.posterior(x_plot)
        mean = posterior.mean.squeeze()
        lower, upper = posterior.confidence_region()

    plt.figure(figsize=(10, 6))
    plt.plot(x_plot.numpy(), mean.numpy(), 'b-', label='Posterior Mean')
    plt.fill_between(x_plot.numpy().flatten(),
                     lower.numpy().flatten(),
                     upper.numpy().flatten(),
                     alpha=0.2,
                     label='95% Confidence')
    plt.scatter(train_X.numpy(), train_Y.detach().numpy(), c='red',
                marker='x', label='Training Points')
    if next_point is not None:
        # Drawn as a vertical line: test_function only accepts 2D inputs, so
        # the true value at a 1D point cannot be looked up here.
        plt.axvline(next_point.flatten()[0].item(), color='green',
                    linestyle='--', label='Next point')
    plt.legend()
    plt.title('GP Posterior')
    plt.show()

## 1f. Training GP

In [None]:
# Main execution
# Fits a plain GP (constant mean + CustomKernel) to n_initial random samples of
# test_function, then plots truth / posterior mean / posterior std.
# model, train_X and train_Y remain defined as globals afterwards.
if __name__ == "__main__":
    # Parameters
    ndim = 2
    n_initial = 50
    n_test = 10000  # not used in this cell
    
    # Generate initial data
    # train_Y has shape (n_initial, 1), as returned by test_function
    train_X = sample_points(ndim, n_initial, bounds=(0, 1))
    train_Y = test_function(train_X, ndim)
    
    # # Define priors
    # lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 6.0)
    # outputscale_prior = gpytorch.priors.GammaPrior(2.0, 0.15)
    # noise_prior = gpytorch.priors.GammaPrior(1.1, 0.05)

    # Conservative/Standard priors
    lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 1.0)  # mean = 3.0, variance = 3.0
    outputscale_prior = gpytorch.priors.GammaPrior(2.0, 2.0)  # mean = 1.0, variance = 0.5
    noise_prior = gpytorch.priors.GammaPrior(1.5, 3.0)        # mean = 0.5, variance = 0.17

    # Define model components
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=noise_prior)
    covar_module = CustomKernel(
        ndim,
        lengthscale_prior=lengthscale_prior,
        outputscale_prior=outputscale_prior
    )
    
    # Initialize and fit model
    # fit_gpytorch_mll maximises the marginal log likelihood (plus the prior terms) over the hyperparameters
    model = SimpleGP(train_X, train_Y, covar_module, likelihood)
    mll = ExactMarginalLogLikelihood(likelihood, model)
    fit_gpytorch_mll(mll)
    
    plot_step(model, train_X, train_Y.squeeze(-1), next_point=None, ndim=ndim)  # Squeeze for plotting


# 2. Active learning with GP

In [None]:
# Active learning with a plain GP: start from one random sample and, each
# iteration, query the candidate with the highest UCB score.
# Parameters
ndim = 2
n_initial = 1
n_iterations = 15
beta = 1e6  # UCB exploration weight --> higher means more exploration

# Generate initial data
# test_function already returns shape (n, 1); new rows are appended along dim 0.
train_X = sample_points(ndim, n_initial)
train_Y = test_function(train_X, ndim)

# Generate candidate points for discrete optimization
n_candidates = 1000
candidates = sample_points(ndim, n_candidates)

for iteration in range(n_iterations):
    # Fresh priors, likelihood and kernel every iteration, so each fit starts
    # from the prior rather than from the previous iteration's hyperparameters.
    # Conservative/Standard priors
    lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 1.0)  # mean = 3.0, variance = 3.0
    outputscale_prior = gpytorch.priors.GammaPrior(2.0, 2.0)  # mean = 1.0, variance = 0.5
    noise_prior = gpytorch.priors.GammaPrior(1.5, 3.0)        # mean = 0.5, variance = 0.17

    # Define model components
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=noise_prior)
    covar_module = CustomKernel(
        ndim,
        lengthscale_prior=lengthscale_prior,
        outputscale_prior=outputscale_prior,
    )

    # Initialize and fit model
    model = SimpleGP(train_X, train_Y, covar_module, likelihood)
    mll = ExactMarginalLogLikelihood(likelihood, model)
    fit_gpytorch_mll(mll)

    # Acquisition: upper confidence bound, maximised over the fixed candidate set
    UCB = UpperConfidenceBound(model, beta=beta)
    next_point, acq_value = optimize_acqf_discrete(
        acq_function=UCB,
        choices=candidates,
        q=1,
    )

    # Plot current posterior together with the chosen next point
    plot_step(model, train_X, train_Y.squeeze(-1), next_point, ndim=ndim)

    # Evaluate next point and update training data
    next_value = test_function(next_point, ndim)  # shape (1, 1)
    train_X = torch.cat([train_X, next_point])
    train_Y = torch.cat([train_Y, next_value])

# 3 sGP - with mean function prior
- Good idea to visit [this sGP notebook first](https://github.com/utkarshp1161/Active-learning-in-microscopy/blob/main/notebooks/GP_%26_sGP_BO_BoTorch.ipynb)

## 3a. Define custom mean and sGP model

In [17]:
# GP Model
class CustomGP(SingleTaskGP):
    """
    SingleTaskGP with caller-supplied mean, kernel and likelihood (the sGP model).

    The mean module carries the prior knowledge (e.g. ``PhaseDataMean``); the
    kernel models the residual structure around it.

    Args:
        train_X: Training inputs, shape (n, d).
        train_Y: Training targets, shape (n, 1).
        mean_module: GPyTorch mean module.
        covar_module: Kernel (e.g. ``CustomKernel``).
        likelihood: GPyTorch likelihood (e.g. ``GaussianLikelihood``).
        **kwargs: Forwarded to ``SingleTaskGP``, e.g. ``outcome_transform=None``.

    NOTE(review): recent BoTorch releases apply a ``Standardize`` outcome
    transform by default, in which case the mean function is fitted against
    standardized targets rather than raw phase values. Confirm for the pinned
    BoTorch version and pass ``outcome_transform=None`` if the mean function
    must operate in data units.
    """

    def __init__(self, train_X, train_Y, mean_module, covar_module, likelihood, **kwargs):
        # Hand the components to SingleTaskGP's constructor instead of
        # overwriting its attributes afterwards. This avoids building default
        # modules only to discard them, and keeps BoTorch's bookkeeping in
        # line with the modules actually used.
        super().__init__(
            train_X,
            train_Y,
            likelihood=likelihood,
            covar_module=covar_module,
            mean_module=mean_module,
            **kwargs,
        )

    def forward(self, x):
        # GP prior at x: custom mean plus kernel covariance
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


### 3a. i) Custom mean function 1 - Deterministic function based on the simulator by Richard

In [None]:
import gpytorch
import torch
import numpy as np
import matplotlib.pyplot as plt

class PhaseDataMean(gpytorch.means.Mean):
    """
    Mean function built from the curve x = sqrt(y)/0.75 - 0.5 plus a radial cosine.

        m(x1, x2) = sqrt(|x2|) / scale - 0.5
                    + amplitude * cos(frequency * sqrt(x1**2 + x2**2))
                    + offset

    All four coefficients are trainable. Initial values: amplitude = 1,
    frequency = 1, offset = 0, scale = 0.75.

    Args:
        input_size: Kept for call compatibility; not used.
    """

    def __init__(self, input_size=2):
        super().__init__()
        initial_values = (
            ('amplitude', torch.ones(1)),
            ('frequency', torch.ones(1)),
            ('offset', torch.zeros(1)),
            ('scale', torch.tensor([0.75])),
        )
        for param_name, start in initial_values:
            self.register_parameter(
                name=param_name,
                parameter=torch.nn.Parameter(start),
            )

    def forward(self, x):
        """
        Evaluate the mean at ``x``.

        Args:
            x: Tensor whose last dimension holds (x1, x2), e.g. shape (n_points, 2).
        Returns:
            Tensor with the last dimension removed, e.g. shape (n_points,).
        """
        first, second = x[..., 0], x[..., 1]
        curve = torch.sqrt(torch.abs(second)) / self.scale - 0.5
        radius = torch.sqrt(first ** 2 + second ** 2)
        ripple = self.amplitude * torch.cos(self.frequency * radius)
        return curve + ripple + self.offset

def visualize_mean_function(mean_function, resolution=50):
    """
    Plot a two-input mean function as contour lines and as a 3D surface.

    The grid covers x1 in [-2, 2] and x2 in [0, 4]. Only non-negative x2 is
    used because the phase mean takes a square root of x2.

    Args:
        mean_function: Callable taking a float32 tensor of shape (n, 2).
        resolution: Number of grid points along each axis.
    """
    x1_axis = np.linspace(-2, 2, resolution)
    x2_axis = np.linspace(0, 4, resolution)
    grid_a, grid_b = np.meshgrid(x1_axis, x2_axis)
    query = torch.tensor(np.column_stack((grid_a.flatten(), grid_b.flatten())), dtype=torch.float32)

    with torch.no_grad():
        surface = mean_function(query).reshape(resolution, resolution).numpy()

    plt.figure(figsize=(15, 6))

    # Left panel: contour lines
    plt.subplot(121)
    plt.contour(grid_a, grid_b, surface, levels=20)
    plt.colorbar()
    plt.title("Mean Function (2D)")
    plt.xlabel("x1")
    plt.ylabel("x2")

    # Right panel: 3D surface
    axes3d = plt.subplot(122, projection='3d')
    mesh = axes3d.plot_surface(grid_a, grid_b, surface, cmap='viridis')
    plt.colorbar(mesh)
    axes3d.set_title("Mean Function (3D)")
    axes3d.set_xlabel("x1")
    axes3d.set_ylabel("x2")
    axes3d.set_zlabel("f(x)")

    plt.tight_layout()
    plt.show()

# Initialize and visualize
# Untrained PhaseDataMean at its initial parameters
# (amplitude=1, frequency=1, offset=0, scale=0.75).
mean_function = PhaseDataMean(input_size=2)
visualize_mean_function(mean_function)

# Additional visualization of the specific function
# Plots the curve x = sqrt(y)/0.75 - 0.5 (the fixed-scale version of the
# PhaseDataMean function term) for y in [0, 4].
# NOTE: this rebinds the module-level names x and y.
plt.figure(figsize=(8, 6))
y = np.linspace(0, 4, 100)
x = np.sqrt(y)/0.75 - 0.5
plt.plot(x, y, 'b-', label='x = sqrt(y)/0.75 - 0.5')
plt.grid(True)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Original Function')
plt.legend()
plt.show()

#### 3a. i.a) Train sGP model for custom mean function 1

In [None]:
# Fit the structured GP (PhaseDataMean + CustomKernel) to random samples.
# Parameters
ndim = 2
n_initial = 50
n_test = 10000

# Generate initial data
train_X = sample_points(ndim, n_initial)
train_Y = test_function(train_X, ndim)  # shape (n_initial, 1)

# Priors for the structured GP (RBF kernel with per-dimension lengthscales)
lengthscale_prior = gpytorch.priors.GammaPrior(2.0, 4.0)  # mean=0.5, more local variations
outputscale_prior = gpytorch.priors.GammaPrior(4.0, 4.0)  # mean=1.0, moderate variance
noise_prior = gpytorch.priors.GammaPrior(1.5, 6.0)        # mean=0.25, small noise

# Define model components
likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=noise_prior)
# Fresh mean-function instance: fitting updates its parameters in place, so
# reusing the `mean_function` object plotted above would silently change it
# (and carry the fitted values into later cells).
mean_module = PhaseDataMean(input_size=2)
covar_module = CustomKernel(
    ndim,
    lengthscale_prior=lengthscale_prior,
    outputscale_prior=outputscale_prior,
)

# Initialize and fit model
model = CustomGP(train_X, train_Y, mean_module, covar_module, likelihood)
mll = ExactMarginalLogLikelihood(likelihood, model)
fit_gpytorch_mll(mll)
plot_step(model, train_X, train_Y.squeeze(-1), ndim=ndim)  # Squeeze for plotting



#### 3a. i.b) Active learning for custom mean function 1

In [None]:
# Active learning with the structured GP (PhaseDataMean): start from one
# random sample and, each iteration, query the candidate with the highest UCB.
# Parameters
ndim = 2
n_initial = 1
n_iterations = 15
beta = 1e6  # UCB exploration weight --> higher means more exploration

# Generate initial data
# test_function already returns shape (n, 1); new rows are appended along dim 0.
train_X = sample_points(ndim, n_initial)
train_Y = test_function(train_X, ndim)

# Generate candidate points for discrete optimization
n_candidates = 1000
candidates = sample_points(ndim, n_candidates)

for iteration in range(n_iterations):
    # Conservative/Standard priors
    lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 1.0)  # mean = 3.0, variance = 3.0
    outputscale_prior = gpytorch.priors.GammaPrior(2.0, 2.0)  # mean = 1.0, variance = 0.5
    noise_prior = gpytorch.priors.GammaPrior(1.5, 3.0)        # mean = 0.5, variance = 0.17

    # Fresh model components every iteration. The mean function is rebuilt
    # too, so parameters fitted in earlier cells (e.g. on the 50-point
    # training set) or in earlier iterations cannot leak into this run.
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=noise_prior)
    mean_module = PhaseDataMean(input_size=2)
    covar_module = CustomKernel(
        ndim,
        lengthscale_prior=lengthscale_prior,
        outputscale_prior=outputscale_prior,
    )

    # Initialize and fit model
    model = CustomGP(train_X, train_Y, mean_module, covar_module, likelihood)
    mll = ExactMarginalLogLikelihood(likelihood, model)
    fit_gpytorch_mll(mll)

    # Acquisition: upper confidence bound, maximised over the fixed candidate set
    UCB = UpperConfidenceBound(model, beta=beta)
    next_point, acq_value = optimize_acqf_discrete(
        acq_function=UCB,
        choices=candidates,
        q=1,
    )

    # One plot per iteration: current posterior plus the chosen next point
    plot_step(model, train_X, train_Y.squeeze(-1), next_point, ndim=ndim)

    # Evaluate next point and update training data
    next_value = test_function(next_point, ndim)  # shape (1, 1)
    train_X = torch.cat([train_X, next_point])
    train_Y = torch.cat([train_Y, next_value])

### 3a. ii) Custom mean function 2 - based on the idea of a Quadratic boundary - has priors on the terms

In [None]:
import gpytorch
import torch
import numpy as np
import matplotlib.pyplot as plt
from gpytorch.priors import LogNormalPrior, NormalPrior
from gpytorch.constraints import Positive, Interval

class QuadraticBoundaryMean(gpytorch.means.Mean):
    """
    Smooth step across a rotated ellipse: close to 0 inside, close to 1 outside.

        m(x) = sigmoid(k * ((x1_r / a)**2 + (x2_r / b)**2 - scale))

    (x1_r, x2_r) are the inputs shifted by (center_x, center_y) and rotated
    by theta, and k is ``smoothing_factor``. The six shape parameters are
    stored as raw parameters behind constraints and carry priors:

    - center_x, center_y: Interval(-5, 5),   Normal(0, 1) prior
    - scale, a, b:        Positive,          LogNormal(0, 0.5) prior
    - theta:              Interval(-pi, pi), Normal(0, pi/4) prior

    Args:
        batch_shape: Accepted for API compatibility; currently unused (every
            parameter has shape (1,)).
        smoothing_factor: Steepness k of the sigmoid transition (default 10.0).
    """

    _SHAPE_PARAMS = ("center_x", "center_y", "scale", "a", "b", "theta")

    def __init__(self, batch_shape=torch.Size(), smoothing_factor=10.0):
        super().__init__()
        self.smoothing_factor = float(smoothing_factor)

        # Every raw parameter starts at 0. Through the constraints below this
        # gives center = (0, 0) and theta = 0, while scale, a and b start at
        # softplus(0) = ln 2 ~ 0.693 (gpytorch's Positive uses softplus).
        for name in self._SHAPE_PARAMS:
            self.register_parameter(
                name=f"raw_{name}",
                parameter=torch.nn.Parameter(torch.zeros(1)),
            )

        # Register constraints
        constraints = {
            "center_x": Interval(-5.0, 5.0),
            "center_y": Interval(-5.0, 5.0),
            "scale": Positive(),
            "a": Positive(),
            "b": Positive(),
            "theta": Interval(-np.pi, np.pi),
        }
        for name, constraint in constraints.items():
            self.register_constraint(f"raw_{name}", constraint)

        # Register priors (defined on the constrained values)
        self.register_prior(
            "center_x_prior",
            NormalPrior(0.0, 1.0),
            lambda module: module.center_x,
            lambda module, value: module._set_center_x(value)
        )
        self.register_prior(
            "center_y_prior",
            NormalPrior(0.0, 1.0),
            lambda module: module.center_y,
            lambda module, value: module._set_center_y(value)
        )
        self.register_prior(
            "scale_prior",
            LogNormalPrior(0.0, 0.5),
            lambda module: module.scale,
            lambda module, value: module._set_scale(value)
        )
        self.register_prior(
            "a_prior",
            LogNormalPrior(0.0, 0.5),
            lambda module: module.a,
            lambda module, value: module._set_a(value)
        )
        self.register_prior(
            "b_prior",
            LogNormalPrior(0.0, 0.5),
            lambda module: module.b,
            lambda module, value: module._set_b(value)
        )
        self.register_prior(
            "theta_prior",
            NormalPrior(0.0, np.pi/4),
            lambda module: module.theta,
            lambda module, value: module._set_theta(value)
        )

    # Shared helpers for the constrained parameters
    def _constrained(self, name):
        """Return the constrained value of shape parameter ``name``."""
        raw_name = f"raw_{name}"
        return getattr(self, f"{raw_name}_constraint").transform(getattr(self, raw_name))

    def _set_constrained(self, name, value):
        """Set shape parameter ``name`` by storing the inverse-transformed raw value."""
        if not torch.is_tensor(value):
            value = torch.tensor(value)
        raw_name = f"raw_{name}"
        constraint = getattr(self, f"{raw_name}_constraint")
        self.initialize(**{raw_name: constraint.inverse_transform(value)})

    # Properties
    @property
    def center_x(self):
        return self._constrained("center_x")

    @property
    def center_y(self):
        return self._constrained("center_y")

    @property
    def scale(self):
        return self._constrained("scale")

    @property
    def a(self):
        return self._constrained("a")

    @property
    def b(self):
        return self._constrained("b")

    @property
    def theta(self):
        return self._constrained("theta")

    # Setters (used by the prior setting closures)
    def _set_center_x(self, value):
        self._set_constrained("center_x", value)

    def _set_center_y(self, value):
        self._set_constrained("center_y", value)

    def _set_scale(self, value):
        self._set_constrained("scale", value)

    def _set_a(self, value):
        self._set_constrained("a", value)

    def _set_b(self, value):
        self._set_constrained("b", value)

    def _set_theta(self, value):
        self._set_constrained("theta", value)

    def forward(self, x):
        """
        Evaluate the mean at ``x``.

        Args:
            x: Tensor whose last dimension holds (x1, x2), e.g. shape (n, 2).
                A bare 1D tensor is read as a single (x1, x2) point.
        Returns:
            Tensor with the last dimension removed, values in (0, 1).
        Raises:
            ValueError: If the inputs have fewer than 2 features.
        """
        if x.ndimension() == 1:
            x = x.unsqueeze(0)
        if x.size(-1) < 2:
            raise ValueError(
                f"QuadraticBoundaryMean expects inputs with 2 features, got shape {tuple(x.shape)}"
            )

        x1, x2 = x[..., 0], x[..., 1]

        # Translate points to center
        x1_t = x1 - self.center_x
        x2_t = x2 - self.center_y

        # Rotate points
        cos_theta = torch.cos(self.theta)
        sin_theta = torch.sin(self.theta)
        x1_r = x1_t * cos_theta + x2_t * sin_theta
        x2_r = -x1_t * sin_theta + x2_t * cos_theta

        # Normalized elliptical distance; the ellipse boundary is dist == scale
        dist = (x1_r/self.a)**2 + (x2_r/self.b)**2
        boundary = self.scale

        # Smooth transition between inside (-> 0) and outside (-> 1)
        return torch.sigmoid(self.smoothing_factor * (dist - boundary))

def plot_mean_function(mean_module, resolution=100):
    """
    Visualize a two-input mean module in 2D and 3D on the square [-3, 3]^2.

    The 2D (contour) panel is annotated with the module's current center,
    scale, a, b and theta, so the module must expose those attributes (as
    QuadraticBoundaryMean does).

    Args:
        mean_module: Mean module taking a float32 tensor of shape (n, 2).
        resolution: Number of grid points along each axis.
    """
    plt.figure(figsize=(15, 6))

    # Create grid of points
    x = np.linspace(-3, 3, resolution)
    y = np.linspace(-3, 3, resolution)
    X, Y = np.meshgrid(x, y)
    points = np.column_stack((X.flatten(), Y.flatten()))

    # Get predictions
    with torch.no_grad():
        Z = mean_module(torch.tensor(points, dtype=torch.float32))
        Z = Z.reshape(resolution, resolution).numpy()

    # 2D Visualization
    ax2d = plt.subplot(121)
    plt.contour(X, Y, Z, levels=20)
    plt.colorbar()
    plt.title("Mean Function (2D)")
    plt.xlabel("x1")
    plt.ylabel("x2")

    # Annotate the 2D panel with the current parameter values
    param_text = (f'Parameters:\nCenter: ({mean_module.center_x.item():.3f}, '
                  f'{mean_module.center_y.item():.3f})\n'
                  f'Scale: {mean_module.scale.item():.3f}\n'
                  f'a: {mean_module.a.item():.3f}\n'
                  f'b: {mean_module.b.item():.3f}\n'
                  f'θ: {mean_module.theta.item():.3f}')
    ax2d.text(0.02, 0.98, param_text,
              transform=ax2d.transAxes,
              verticalalignment='top',
              bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # 3D Visualization
    ax = plt.subplot(122, projection='3d')
    surf = ax.plot_surface(X, Y, Z, cmap='viridis')
    plt.colorbar(surf)
    ax.set_title("Mean Function (3D)")
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
    ax.set_zlabel("f(x)")

    plt.tight_layout()
    plt.show()

# Initialize and visualize
# Untrained QuadraticBoundaryMean: every raw parameter starts at 0, so the
# ellipse is centred at (0, 0) with theta = 0.
mean_module = QuadraticBoundaryMean()
plot_mean_function(mean_module)

#### 3a. ii.a) Train sGP model for custom mean function 2

In [None]:
# Fit the structured GP with the quadratic-boundary mean to random samples.
# Parameters
ndim = 2
n_initial = 50
n_test = 10000

# Generate initial data
train_X = sample_points(ndim, n_initial)
train_Y = test_function(train_X, ndim)  # shape (n_initial, 1)

# Priors for the structured GP (RBF kernel with per-dimension lengthscales)
lengthscale_prior = gpytorch.priors.GammaPrior(2.0, 4.0)  # mean=0.5, more local variations
outputscale_prior = gpytorch.priors.GammaPrior(4.0, 4.0)  # mean=1.0, moderate variance
noise_prior = gpytorch.priors.GammaPrior(1.5, 6.0)        # mean=0.25, small noise

# Define model components
likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=noise_prior)
# This section is about the quadratic-boundary mean. Assigning `mean_function`
# here (the PhaseDataMean from section 3a.i) would silently train the wrong
# model, so build a fresh QuadraticBoundaryMean.
mean_module = QuadraticBoundaryMean()
covar_module = CustomKernel(
    ndim,
    lengthscale_prior=lengthscale_prior,
    outputscale_prior=outputscale_prior,
)

# Initialize and fit model
model = CustomGP(train_X, train_Y, mean_module, covar_module, likelihood)
mll = ExactMarginalLogLikelihood(likelihood, model)
fit_gpytorch_mll(mll)
plot_step(model, train_X, train_Y.squeeze(-1), ndim=ndim)  # Squeeze for plotting



#### 3a. ii.b) Active learning for custom mean function 2

In [None]:
# Active learning with the structured GP (QuadraticBoundaryMean): start from
# one random sample and, each iteration, query the candidate with the highest UCB.
# Parameters
ndim = 2
n_initial = 1
n_iterations = 15
beta = 1e6  # UCB exploration weight --> higher means more exploration

# Generate initial data
# test_function already returns shape (n, 1); new rows are appended along dim 0.
train_X = sample_points(ndim, n_initial)
train_Y = test_function(train_X, ndim)

# Generate candidate points for discrete optimization
n_candidates = 1000
candidates = sample_points(ndim, n_candidates)

for iteration in range(n_iterations):
    # Conservative/Standard priors
    lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 1.0)  # mean = 3.0, variance = 3.0
    outputscale_prior = gpytorch.priors.GammaPrior(2.0, 2.0)  # mean = 1.0, variance = 0.5
    noise_prior = gpytorch.priors.GammaPrior(1.5, 3.0)        # mean = 0.5, variance = 0.17

    # Fresh model components every iteration. The mean function is rebuilt
    # explicitly as a QuadraticBoundaryMean, so this loop neither picks up
    # whatever `mean_module` an earlier cell left behind nor reuses
    # parameters fitted on the 50-point training set.
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=noise_prior)
    mean_module = QuadraticBoundaryMean()
    covar_module = CustomKernel(
        ndim,
        lengthscale_prior=lengthscale_prior,
        outputscale_prior=outputscale_prior,
    )

    # Initialize and fit model
    model = CustomGP(train_X, train_Y, mean_module, covar_module, likelihood)
    mll = ExactMarginalLogLikelihood(likelihood, model)
    fit_gpytorch_mll(mll)

    # Acquisition: upper confidence bound, maximised over the fixed candidate set
    UCB = UpperConfidenceBound(model, beta=beta)
    next_point, acq_value = optimize_acqf_discrete(
        acq_function=UCB,
        choices=candidates,
        q=1,
    )

    # One plot per iteration: current posterior plus the chosen next point
    plot_step(model, train_X, train_Y.squeeze(-1), next_point, ndim=ndim)

    # Evaluate next point and update training data
    next_value = test_function(next_point, ndim)  # shape (1, 1)
    train_X = torch.cat([train_X, next_point])
    train_Y = torch.cat([train_Y, next_value])

### 3a. iii) Custom mean function 3 - based on the idea of a Quadratic boundary + linear_term + constant - has priors on the terms

In [None]:
class QuadraticBoundaryMean(gpytorch.means.Mean):
    """2-D mean: smoothed elliptical boundary + linear trend + constant.

    For an input ``x`` whose last dimension holds ``(x1, x2)``::

        (u, v) = (x1 - center_x, x2 - center_y) rotated by theta
        quad   = sigmoid(sharpness * ((u / a)**2 + (v / b)**2 - scale))
        lin    = slope_x * x1 + slope_y * x2
        m(x)   = w1 * constant + w2 * quad + w3 * lin

    ``quad`` tends to 0 inside the ellipse and to 1 outside it; ``sharpness``
    sets how abrupt that transition is.

    Every hyperparameter is an unconstrained ``raw_<name>`` tensor of shape
    (1,), mapped through a GPyTorch constraint and given a prior, so it is
    fitted together with the kernel hyperparameters when the marginal
    log-likelihood is optimised.

    Args:
        batch_shape: accepted for interface compatibility with other GPyTorch
            means; currently unused (parameters are not batched).
        sharpness: steepness of the sigmoid boundary (default 10.0).
    """

    def __init__(self, batch_shape=torch.Size(), sharpness=10.0):
        super().__init__()
        self.sharpness = sharpness

        # (name, constraint, prior) for each parameter. Parameters, then
        # constraints, then priors are registered in this order.
        specs = (
            ("center_x", Interval(-5.0, 5.0), NormalPrior(0.0, 1.0)),
            ("center_y", Interval(-5.0, 5.0), NormalPrior(0.0, 1.0)),
            ("scale", Positive(), LogNormalPrior(0.0, 0.5)),
            ("a", Positive(), LogNormalPrior(0.0, 0.5)),
            ("b", Positive(), LogNormalPrior(0.0, 0.5)),
            ("theta", Interval(-np.pi, np.pi), NormalPrior(0.0, np.pi / 4)),
            ("w1", Positive(), LogNormalPrior(0.0, 0.5)),  # weight for constant term
            ("w2", Positive(), LogNormalPrior(0.0, 0.5)),  # weight for quadratic term
            ("w3", Positive(), LogNormalPrior(0.0, 0.5)),  # weight for linear term
            ("slope_x", Interval(-5.0, 5.0), NormalPrior(0.0, 1.0)),
            ("slope_y", Interval(-5.0, 5.0), NormalPrior(0.0, 1.0)),
            ("constant", Interval(-5.0, 5.0), NormalPrior(0.0, 1.0)),
        )

        for name, _, _ in specs:
            self.register_parameter(
                name=f"raw_{name}",
                parameter=torch.nn.Parameter(torch.zeros(1)),
            )
        for name, constraint, _ in specs:
            self.register_constraint(f"raw_{name}", constraint)
        for name, _, prior in specs:
            self.register_prior(
                f"{name}_prior",
                prior,
                # `_name` is bound through a default argument; a plain closure
                # would see only the last value of the loop variable.
                lambda module, _name=name: getattr(module, _name),
                lambda module, value, _name=name: module._set_constrained(_name, value),
            )

    # ------------------------------------------------------------------
    # Generic helpers for constrained parameters
    # ------------------------------------------------------------------
    def _get_constrained(self, name):
        """Return ``raw_<name>`` mapped through its registered constraint."""
        raw_name = f"raw_{name}"
        return getattr(self, f"{raw_name}_constraint").transform(getattr(self, raw_name))

    def _set_constrained(self, name, value):
        """Set parameter ``name`` from a constrained value (tensor or number)."""
        raw_name = f"raw_{name}"
        raw = getattr(self, raw_name)
        if not torch.is_tensor(value):
            # Match the raw parameter's dtype/device instead of creating a
            # default-dtype CPU tensor (mirrors GPyTorch's built-in setters).
            value = torch.as_tensor(value).to(raw)
        constraint = getattr(self, f"{raw_name}_constraint")
        self.initialize(**{raw_name: constraint.inverse_transform(value)})

    # ------------------------------------------------------------------
    # Constrained values (read-only properties)
    # ------------------------------------------------------------------
    @property
    def center_x(self):
        return self._get_constrained("center_x")

    @property
    def center_y(self):
        return self._get_constrained("center_y")

    @property
    def scale(self):
        return self._get_constrained("scale")

    @property
    def a(self):
        return self._get_constrained("a")

    @property
    def b(self):
        return self._get_constrained("b")

    @property
    def theta(self):
        return self._get_constrained("theta")

    @property
    def w1(self):
        return self._get_constrained("w1")

    @property
    def w2(self):
        return self._get_constrained("w2")

    @property
    def w3(self):
        return self._get_constrained("w3")

    @property
    def slope_x(self):
        return self._get_constrained("slope_x")

    @property
    def slope_y(self):
        return self._get_constrained("slope_y")

    @property
    def constant(self):
        return self._get_constrained("constant")

    # ------------------------------------------------------------------
    # Per-parameter setters (kept for backward compatibility)
    # ------------------------------------------------------------------
    def _set_center_x(self, value):
        self._set_constrained("center_x", value)

    def _set_center_y(self, value):
        self._set_constrained("center_y", value)

    def _set_scale(self, value):
        self._set_constrained("scale", value)

    def _set_a(self, value):
        self._set_constrained("a", value)

    def _set_b(self, value):
        self._set_constrained("b", value)

    def _set_theta(self, value):
        self._set_constrained("theta", value)

    def _set_w1(self, value):
        self._set_constrained("w1", value)

    def _set_w2(self, value):
        self._set_constrained("w2", value)

    def _set_w3(self, value):
        self._set_constrained("w3", value)

    def _set_slope_x(self, value):
        self._set_constrained("slope_x", value)

    def _set_slope_y(self, value):
        self._set_constrained("slope_y", value)

    def _set_constant(self, value):
        self._set_constrained("constant", value)

    def forward(self, x):
        """Evaluate the mean at ``x`` (``(..., n, d)`` with d >= 2; uses the first two features)."""
        # GPyTorch convention: a 1-D input is n points with a single feature.
        if x.ndimension() == 1:
            x = x.unsqueeze(-1)
        if x.shape[-1] < 2:
            raise ValueError(
                "QuadraticBoundaryMean needs at least 2 input features "
                f"(got input of shape {tuple(x.shape)})"
            )

        x1, x2 = x[..., 0], x[..., 1]

        # Quadratic term: shift to the ellipse centre, rotate by theta, then a
        # sigmoid turns the ellipse equation into a smooth step.
        dx = x1 - self.center_x
        dy = x2 - self.center_y
        cos_theta = torch.cos(self.theta)
        sin_theta = torch.sin(self.theta)
        u = dx * cos_theta + dy * sin_theta
        v = -dx * sin_theta + dy * cos_theta
        quad_term = torch.sigmoid(
            self.sharpness * ((u / self.a) ** 2 + (v / self.b) ** 2 - self.scale)
        )

        # Linear trend in the original (unshifted) coordinates.
        linear_term = self.slope_x * x1 + self.slope_y * x2

        # Combine all terms with their (positive) weights.
        return (self.w1 * self.constant
                + self.w2 * quad_term
                + self.w3 * linear_term)
        
        
def plot_mean_function(mean_module, resolution=100, x_range=(-3, 3), y_range=(-3, 3),
                       show_params=False):
    """
    Visualize a 2-D mean function as a contour map and a 3-D surface.

    Args:
        mean_module: module/callable evaluated on an (N, 2) float32 tensor of
            grid points; its output is reshaped to (resolution, resolution).
        resolution: number of grid points along each axis.
        x_range, y_range: (min, max) extent of the grid; defaults (-3, 3).
        show_params: if True, annotate the contour panel with the module's
            center_x/center_y, scale, a, b and theta (attributes provided by
            ``QuadraticBoundaryMean``). Off by default so that any 2-D mean
            module can be plotted.
    """
    # Evaluate the mean on a regular grid
    x = np.linspace(x_range[0], x_range[1], resolution)
    y = np.linspace(y_range[0], y_range[1], resolution)
    X, Y = np.meshgrid(x, y)
    points = np.column_stack((X.flatten(), Y.flatten()))

    with torch.no_grad():
        Z = mean_module(torch.tensor(points, dtype=torch.float32))
        Z = Z.reshape(resolution, resolution).cpu().numpy()

    fig = plt.figure(figsize=(15, 6))

    # 2D Visualization
    ax2d = fig.add_subplot(121)
    contours = ax2d.contour(X, Y, Z, levels=20)
    fig.colorbar(contours, ax=ax2d)
    ax2d.set_title("Mean Function (2D)")
    ax2d.set_xlabel("x1")
    ax2d.set_ylabel("x2")

    # 3D Visualization
    ax3d = fig.add_subplot(122, projection='3d')
    surf = ax3d.plot_surface(X, Y, Z, cmap='viridis')
    fig.colorbar(surf, ax=ax3d)
    ax3d.set_title("Mean Function (3D)")
    ax3d.set_xlabel("x1")
    ax3d.set_ylabel("x2")
    ax3d.set_zlabel("f(x)")

    # Optional parameter read-out; only touched when requested because it
    # needs the QuadraticBoundaryMean-specific attributes.
    if show_params:
        param_text = (f'Parameters:\nCenter: ({mean_module.center_x.item():.3f}, '
                      f'{mean_module.center_y.item():.3f})\n'
                      f'Scale: {mean_module.scale.item():.3f}\n'
                      f'a: {mean_module.a.item():.3f}\n'
                      f'b: {mean_module.b.item():.3f}\n'
                      f'θ: {mean_module.theta.item():.3f}')
        ax2d.text(0.02, 0.98, param_text,
                  transform=ax2d.transAxes,
                  verticalalignment='top',
                  bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()
    plt.show()

# Initialize and visualize
mean_module = QuadraticBoundaryMean()
plot_mean_function(mean_module)

#### 3a. iii.a) Train sGP model for custom mean function 3

In [None]:
# Parameters
ndim = 2
n_initial = 50
n_test = 10000

# Generate initial data
train_X = sample_points(ndim, n_initial)
train_Y = test_function(train_X, ndim)

# # Define priors
# lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 6.0)
# outputscale_prior = gpytorch.priors.GammaPrior(2.0, 0.15)
# noise_prior = gpytorch.priors.GammaPrior(1.1, 0.05)

# Conservative/Standard priors
# lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 1.0)  # mean = 3.0, variance = 3.0
# outputscale_prior = gpytorch.priors.GammaPrior(2.0, 2.0)  # mean = 1.0, variance = 0.5
# noise_prior = gpytorch.priors.GammaPrior(1.5, 3.0)        # mean = 0.5, variance = 0.17
# Priors for Sparse GP with RBF/Matern kernel
lengthscale_prior = gpytorch.priors.GammaPrior(2.0, 4.0)  # mean=0.5, more local variations
outputscale_prior = gpytorch.priors.GammaPrior(4.0, 4.0)  # mean=1.0, moderate variance
noise_prior = gpytorch.priors.GammaPrior(1.5, 6.0)        # mean=0.25, small noise

# Define model components
likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=noise_prior)
# Custom mean function 3 (quadratic boundary + linear + constant). A fresh
# instance keeps this cell independent of whichever mean module earlier
# cells defined.
mean_module = QuadraticBoundaryMean()
covar_module = CustomKernel(
    ndim,
    lengthscale_prior=lengthscale_prior,
    outputscale_prior=outputscale_prior
)

# Initialize and fit model
model = CustomGP(train_X, train_Y, mean_module, covar_module, likelihood)
mll = ExactMarginalLogLikelihood(likelihood, model)
fit_gpytorch_mll(mll)
plot_step(model, train_X, train_Y.squeeze(-1), ndim=ndim)  # Squeeze for plotting



#### 3a. iii.b) Active learning for custom mean function 3

In [None]:
# Parameters
ndim = 2
n_initial = 1
n_iterations = 15
beta = 1e6  # UCB exploration weight --> higher means more exploration

# Generate initial data
train_X = sample_points(ndim, n_initial)
train_Y = test_function(train_X, ndim).unsqueeze(-1)  # Add extra dimension here --> so concatenation works fine in the loop

# Generate candidate points for discrete optimization
n_candidates = 1000
candidates = sample_points(ndim, n_candidates)

# Custom mean function 3, created here so this cell does not silently reuse
# whatever `mean_module` a previous cell left behind. The same instance is
# kept across iterations, so each fit starts from the previous iteration's
# fitted mean parameters.
mean_module = QuadraticBoundaryMean()

for iteration in range(n_iterations):
    # Define priors
    # lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 6.0)
    # outputscale_prior = gpytorch.priors.GammaPrior(2.0, 0.15)
    # noise_prior = gpytorch.priors.GammaPrior(1.1, 0.05)
    # Conservative/Standard priors
    lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 1.0)  # mean = 3.0, variance = 3.0
    outputscale_prior = gpytorch.priors.GammaPrior(2.0, 2.0)  # mean = 1.0, variance = 0.5
    noise_prior = gpytorch.priors.GammaPrior(1.5, 3.0)        # mean = 0.5, variance = 0.17

    # Define model components (likelihood and kernel are rebuilt every iteration)
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=noise_prior)
    covar_module = CustomKernel(
        ndim,
        lengthscale_prior=lengthscale_prior,
        outputscale_prior=outputscale_prior
    )

    # Initialize and fit model
    model = CustomGP(train_X, train_Y.squeeze(-1), mean_module, covar_module, likelihood)  # Squeeze here for CustomGP
    mll = ExactMarginalLogLikelihood(likelihood, model)
    fit_gpytorch_mll(mll)

    # Plot current state
    plot_step(model, train_X, train_Y.squeeze(-1), ndim=ndim)  # Squeeze for plotting

    # Define acquisition function (UCB with exploration weight `beta`)
    UCB = UpperConfidenceBound(model, beta=beta)

    # Optimize acquisition function over the discrete candidate set
    next_point, acq_value = optimize_acqf_discrete(
        acq_function=UCB,
        choices=candidates,
        q=1,
    )

    # Plot with next point
    plot_step(model, train_X, train_Y.squeeze(-1), next_point, ndim=ndim)  # Squeeze for plotting

    # Evaluate next point and update training data
    next_value = test_function(next_point, ndim).unsqueeze(-1)  # Add extra dimension
    train_X = torch.cat([train_X, next_point])
    train_Y = torch.cat([train_Y, next_value])  # Now dimensions match

### 3a. iv) Custom mean function 4: a radial Fourier + polynomial basis as the mean function

In [None]:
class FlexibleMean(gpytorch.means.Mean):
    """Radially symmetric 2-D mean: constant + Fourier series + polynomial in r.

    With ``r = sqrt(x1**2 + x2**2)`` computed from the first two features::

        m(x) = constant
               + sum_k (w_sin[k] * sin(f[k] * r) + w_cos[k] * cos(f[k] * r))
               + sum_j p[j] * r**j        (j = 0 .. num_polynomial_terms - 1)

    Frequencies are constrained positive; the weights and the constant are
    constrained to (-5, 5). Every parameter carries a prior, so all of them
    are fitted together with the GP hyperparameters.

    Args:
        num_fourier_terms: number of Fourier frequencies.
        num_polynomial_terms: number of polynomial coefficients (degrees
            0 .. num_polynomial_terms - 1).
        batch_shape: accepted for interface compatibility with other GPyTorch
            means; currently unused.
    """

    def __init__(self, num_fourier_terms=3, num_polynomial_terms=3, batch_shape=torch.Size()):
        super().__init__()

        self.num_fourier_terms = num_fourier_terms
        self.num_polynomial_terms = num_polynomial_terms

        # Register Fourier frequencies
        self.register_parameter(
            name="raw_frequencies",
            parameter=torch.nn.Parameter(torch.zeros(num_fourier_terms))
        )

        # Register Fourier weights (sine and cosine terms)
        self.register_parameter(
            name="raw_fourier_weights_sin",
            parameter=torch.nn.Parameter(torch.zeros(num_fourier_terms))
        )
        self.register_parameter(
            name="raw_fourier_weights_cos",
            parameter=torch.nn.Parameter(torch.zeros(num_fourier_terms))
        )

        # Register polynomial weights
        self.register_parameter(
            name="raw_poly_weights",
            parameter=torch.nn.Parameter(torch.zeros(num_polynomial_terms))
        )

        # Register constant term
        self.register_parameter(
            name="raw_constant",
            parameter=torch.nn.Parameter(torch.zeros(1))
        )

        # Register constraints
        self.register_constraint("raw_frequencies", Positive())
        self.register_constraint("raw_fourier_weights_sin", Interval(-5.0, 5.0))
        self.register_constraint("raw_fourier_weights_cos", Interval(-5.0, 5.0))
        self.register_constraint("raw_poly_weights", Interval(-5.0, 5.0))
        self.register_constraint("raw_constant", Interval(-5.0, 5.0))

        # Start the frequencies at 1, 2, ..., K. With every raw frequency at
        # zero they would all share one value, every Fourier term would get
        # identical gradients, and gradient-based fitting could never tell
        # them apart. All weights and the constant start at zero, so the
        # initial mean is still identically zero.
        self._set_frequencies(
            torch.arange(1, num_fourier_terms + 1, dtype=self.raw_frequencies.dtype)
        )

        # Register priors
        # Frequencies prior (log-normal to ensure positivity and reasonable spread)
        self.register_prior(
            "frequencies_prior",
            LogNormalPrior(0.0, 1.0),
            lambda module: module.frequencies,
            lambda module, value: module._set_frequencies(value)
        )

        # Fourier weights priors (normal distribution)
        self.register_prior(
            "fourier_weights_sin_prior",
            NormalPrior(0.0, 1.0),
            lambda module: module.fourier_weights_sin,
            lambda module, value: module._set_fourier_weights_sin(value)
        )
        self.register_prior(
            "fourier_weights_cos_prior",
            NormalPrior(0.0, 1.0),
            lambda module: module.fourier_weights_cos,
            lambda module, value: module._set_fourier_weights_cos(value)
        )

        # Polynomial weights prior (normal distribution)
        self.register_prior(
            "poly_weights_prior",
            NormalPrior(0.0, 1.0),
            lambda module: module.poly_weights,
            lambda module, value: module._set_poly_weights(value)
        )

        # Constant term prior
        self.register_prior(
            "constant_prior",
            NormalPrior(0.0, 1.0),
            lambda module: module.constant,
            lambda module, value: module._set_constant(value)
        )

    # Properties: raw parameters mapped through their constraints
    @property
    def frequencies(self):
        return self.raw_frequencies_constraint.transform(self.raw_frequencies)

    @property
    def fourier_weights_sin(self):
        return self.raw_fourier_weights_sin_constraint.transform(self.raw_fourier_weights_sin)

    @property
    def fourier_weights_cos(self):
        return self.raw_fourier_weights_cos_constraint.transform(self.raw_fourier_weights_cos)

    @property
    def poly_weights(self):
        return self.raw_poly_weights_constraint.transform(self.raw_poly_weights)

    @property
    def constant(self):
        return self.raw_constant_constraint.transform(self.raw_constant)

    # Setters: accept constrained values and store the corresponding raw value
    def _set_constrained(self, raw_name, value):
        """Set ``raw_name`` from a constrained value (tensor or number)."""
        raw = getattr(self, raw_name)
        if not torch.is_tensor(value):
            # Match the raw parameter's dtype/device (mirrors GPyTorch's setters).
            value = torch.as_tensor(value).to(raw)
        constraint = getattr(self, f"{raw_name}_constraint")
        self.initialize(**{raw_name: constraint.inverse_transform(value)})

    def _set_frequencies(self, value):
        self._set_constrained("raw_frequencies", value)

    def _set_fourier_weights_sin(self, value):
        self._set_constrained("raw_fourier_weights_sin", value)

    def _set_fourier_weights_cos(self, value):
        self._set_constrained("raw_fourier_weights_cos", value)

    def _set_poly_weights(self, value):
        self._set_constrained("raw_poly_weights", value)

    def _set_constant(self, value):
        self._set_constrained("raw_constant", value)

    def forward(self, x):
        """Evaluate the mean at ``x`` (``(..., n, d)`` with d >= 2; uses the first two features)."""
        # GPyTorch convention: a 1-D input is n points with a single feature.
        if x.ndimension() == 1:
            x = x.unsqueeze(-1)
        if x.shape[-1] < 2:
            raise ValueError(
                "FlexibleMean needs at least 2 input features "
                f"(got input of shape {tuple(x.shape)})"
            )

        x1, x2 = x[..., 0], x[..., 1]
        r = torch.sqrt(x1**2 + x2**2)

        # Fourier terms, vectorised over frequencies: (..., n, 1) * (K,) -> (..., n, K)
        phase = r.unsqueeze(-1) * self.frequencies
        fourier_basis = (
            self.fourier_weights_sin * torch.sin(phase)
            + self.fourier_weights_cos * torch.cos(phase)
        ).sum(dim=-1)

        # Polynomial terms p[j] * r**j
        poly_basis = torch.zeros_like(r)
        for j, weight in enumerate(self.poly_weights):
            poly_basis = poly_basis + weight * r**j

        return self.constant + fourier_basis + poly_basis

def plot_flexible_mean(mean_module, x_range=(-5, 5), y_range=(-5, 5), resolution=100):
    """
    Draw a 2-D mean function as a contour map (left) and a 3-D surface (right).

    The function is evaluated on a ``resolution x resolution`` grid spanning
    ``x_range`` by ``y_range``, with the points passed to ``mean_module`` as
    an (N, 2) float32 tensor.
    """
    grid_x, grid_y = np.meshgrid(
        np.linspace(x_range[0], x_range[1], resolution),
        np.linspace(y_range[0], y_range[1], resolution),
    )
    flat_points = torch.tensor(
        np.column_stack((grid_x.flatten(), grid_y.flatten())), dtype=torch.float32
    )

    with torch.no_grad():
        grid_z = mean_module(flat_points).reshape(resolution, resolution).numpy()

    fig = plt.figure(figsize=(15, 5))

    # Left panel: contour lines
    ax_contour = fig.add_subplot(121)
    contour_set = ax_contour.contour(grid_x, grid_y, grid_z, levels=20)
    fig.colorbar(contour_set, ax=ax_contour)
    ax_contour.set_title("Mean Function (2D)")
    ax_contour.set_xlabel("x")
    ax_contour.set_ylabel("y")

    # Right panel: surface
    ax_surface = fig.add_subplot(122, projection='3d')
    surface = ax_surface.plot_surface(grid_x, grid_y, grid_z, cmap='viridis')
    fig.colorbar(surface, ax=ax_surface)
    ax_surface.set_title("Mean Function (3D)")
    ax_surface.set_xlabel("x")
    ax_surface.set_ylabel("y")
    ax_surface.set_zlabel("f(x, y)")

    plt.tight_layout()
    plt.show()

# Create and visualize
mean_module = FlexibleMean(num_fourier_terms=3, num_polynomial_terms=3)
plot_flexible_mean(mean_module)

#### 3a. iv.a) Train sGP model for custom mean function 4

In [None]:
# Parameters
ndim = 2
n_initial = 50
n_test = 10000

# Generate initial data
train_X = sample_points(ndim, n_initial)
train_Y = test_function(train_X, ndim)

# # Define priors
# lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 6.0)
# outputscale_prior = gpytorch.priors.GammaPrior(2.0, 0.15)
# noise_prior = gpytorch.priors.GammaPrior(1.1, 0.05)

# Conservative/Standard priors
# lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 1.0)  # mean = 3.0, variance = 3.0
# outputscale_prior = gpytorch.priors.GammaPrior(2.0, 2.0)  # mean = 1.0, variance = 0.5
# noise_prior = gpytorch.priors.GammaPrior(1.5, 3.0)        # mean = 0.5, variance = 0.17
# Priors for Sparse GP with RBF/Matern kernel
lengthscale_prior = gpytorch.priors.GammaPrior(2.0, 4.0)  # mean=0.5, more local variations
outputscale_prior = gpytorch.priors.GammaPrior(4.0, 4.0)  # mean=1.0, moderate variance
noise_prior = gpytorch.priors.GammaPrior(1.5, 6.0)        # mean=0.25, small noise

# Define model components
likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=noise_prior)
# Custom mean function 4 (radial Fourier + polynomial basis). A fresh instance
# keeps this cell independent of whichever mean module earlier cells defined.
mean_module = FlexibleMean(num_fourier_terms=3, num_polynomial_terms=3)
covar_module = CustomKernel(
    ndim,
    lengthscale_prior=lengthscale_prior,
    outputscale_prior=outputscale_prior
)

# Initialize and fit model
model = CustomGP(train_X, train_Y, mean_module, covar_module, likelihood)
mll = ExactMarginalLogLikelihood(likelihood, model)
fit_gpytorch_mll(mll)
plot_step(model, train_X, train_Y.squeeze(-1), ndim=ndim)  # Squeeze for plotting



#### 3a. iv.b) Active learning for custom mean function 4

In [None]:
# Parameters
ndim = 2
n_initial = 1
n_iterations = 15
beta = 1e6  # UCB exploration weight --> higher means more exploration

# Generate initial data
train_X = sample_points(ndim, n_initial)
train_Y = test_function(train_X, ndim).unsqueeze(-1)  # Add extra dimension here --> so concatenation works fine in the loop

# Generate candidate points for discrete optimization
n_candidates = 1000
candidates = sample_points(ndim, n_candidates)

# Custom mean function 4 (radial Fourier + polynomial basis), created here so
# this cell does not depend on whichever mean module earlier cells defined.
# The same instance is kept across iterations, so each fit starts from the
# previous iteration's fitted mean parameters.
mean_module = FlexibleMean(num_fourier_terms=3, num_polynomial_terms=3)

for iteration in range(n_iterations):
    # Conservative/Standard priors
    lengthscale_prior = gpytorch.priors.GammaPrior(3.0, 1.0)  # mean = 3.0, variance = 3.0
    outputscale_prior = gpytorch.priors.GammaPrior(2.0, 2.0)  # mean = 1.0, variance = 0.5
    noise_prior = gpytorch.priors.GammaPrior(1.5, 3.0)        # mean = 0.5, variance = 0.17

    # Define model components (likelihood and kernel are rebuilt every iteration)
    likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_prior=noise_prior)
    covar_module = CustomKernel(
        ndim,
        lengthscale_prior=lengthscale_prior,
        outputscale_prior=outputscale_prior
    )

    # Initialize and fit model
    model = CustomGP(train_X, train_Y.squeeze(-1), mean_module, covar_module, likelihood)  # Squeeze here for CustomGP
    mll = ExactMarginalLogLikelihood(likelihood, model)
    fit_gpytorch_mll(mll)

    # Plot current state
    plot_step(model, train_X, train_Y.squeeze(-1), ndim=ndim)  # Squeeze for plotting

    # Define acquisition function (UCB with exploration weight `beta`)
    UCB = UpperConfidenceBound(model, beta=beta)

    # Optimize acquisition function over the discrete candidate set
    next_point, acq_value = optimize_acqf_discrete(
        acq_function=UCB,
        choices=candidates,
        q=1,
    )

    # Plot with next point
    plot_step(model, train_X, train_Y.squeeze(-1), next_point, ndim=ndim)  # Squeeze for plotting

    # Evaluate next point and update training data
    next_value = test_function(next_point, ndim).unsqueeze(-1)  # Add extra dimension
    train_X = torch.cat([train_X, next_point])
    train_Y = torch.cat([train_Y, next_value])  # Now dimensions match

