# In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import os
from matplotlib.ticker import MaxNLocator
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from scipy.interpolate import griddata
import itertools

from baybe.targets import NumericalTarget
from baybe.acquisition.acqfs import LogExpectedImprovement, ExpectedImprovement, ProbabilityOfImprovement, UpperConfidenceBound, qExpectedImprovement
from baybe import Campaign
from baybe.searchspace.continuous import SubspaceContinuous
from baybe.parameters.numerical import NumericalContinuousParameter
from baybe.recommenders import BotorchRecommender
from baybe.objectives import DesirabilityObjective
from baybe.utils.dataframe import add_parameter_noise

from gp_slice_visualization import plot_gp_1d_slices

# Ensure visualization directory exists
os.makedirs('visualization', exist_ok=True)

# Create a log file to keep track of optimization progress.
# The header is written only when the file is missing, so rows from earlier
# runs are preserved; update_log_file() appends one row per run.
log_path = Path('visualization/optimization_log.csv')
if not log_path.exists():
    with open(log_path, 'w') as f:
        f.write('Iteration,Best_Yield,Best_Impurity,Best_ImpurityXRatio,Desirability\n')

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load optimization data (path is relative to the current working directory)
optimization_file = "optimization.csv"
data = pd.read_csv(optimization_file)

# Get the current iteration count: one CSV row is treated as one iteration.
# This is taken before any cleaning, so it counts every row in the file.
iteration = len(data)
print(f"Current iteration: {iteration}")

# Convert all columns to numeric (forcing errors to NaN)
data = data.apply(pd.to_numeric, errors='coerce')
# NOTE(review): after the coercion above every column should already be
# numeric, so this selection probably drops nothing - confirm.
data = data.select_dtypes(include=[np.number])  # Drop non-numeric columns
# Mean imputation is applied to every column, i.e. to target columns as well
# as parameter columns.
data = data.fillna(data.mean())  # Replace NaNs with column means

# Move numeric data to GPU if available.
# The values are cast to float32 here and this tensor (copied back to the CPU)
# is what gets handed to campaign.add_measurements() below, so the campaign
# sees float32-rounded values while `data` keeps the original precision.
data_tensor = torch.tensor(data.values, dtype=torch.float32, device=device)

# Define bounds for optimization as (lower, upper) per parameter.
# ConcentrationMolar spans only 0.0001, so it is effectively held constant.
pbounds = {
    "T1Celsius": (20.0, 200.0),
    "t1min": (10.0, 60.0),
    "T2Celsius": (20.0, 200.0),
    "t2min": (10.0, 60.0),
    "EquivalentsReagent1": (1.0, 2.0),
    "EquivalentsBASE1": (1.0, 5.0),
    "ConcentrationMolar": (0.820, 0.8201)
}
# Rows are labelled "lower"/"upper", columns are the parameter names.
bounds_df = pd.DataFrame(pbounds, index=["lower", "upper"])

# Define the search space
space = SubspaceContinuous.from_bounds(bounds_df)

# Define acquisition function
acquisition_function = ProbabilityOfImprovement()

# Define targets
target1 = NumericalTarget(name="Yield", mode="MAX", bounds=(0, 100))
target2 = NumericalTarget(name="Impurity", mode="MIN", bounds=(0, 100))
target3 = NumericalTarget(name="ImpurityXRatio", mode="MAX", bounds=(0, 100))

# Define the optimization objective: the three targets with equal weights
objective = DesirabilityObjective(
    targets=[target1, target2, target3],
    weights=[1.0, 1.0, 1.0]
)

# Configure the recommender
recommender = BotorchRecommender(acquisition_function=acquisition_function)
# NOTE(review): this is a plain attribute assignment; nothing visible here
# shows that BotorchRecommender reads a `device` attribute - confirm it has
# any effect (or is even permitted) in the installed BayBE version.
recommender.device = device

# Configure the campaign
campaign = Campaign(searchspace=space.to_searchspace(), objective=objective, recommender=recommender)
print(campaign)

# Add measurements (the float32 copy of `data`, with the original column names)
campaign.add_measurements(pd.DataFrame(data_tensor.cpu().numpy(), columns=data.columns))

# Calculate desirability scores for existing data.
# This module-level name is read directly by plot_optimization_history().
desirability_scores = campaign.objective.transform(campaign.measurements)

# Visualization functions
def plot_optimization_history(data, iteration):
    """Save a 2x2 figure tracing each target and the desirability score per run.

    The three target panels read the 'Yield', 'Impurity' and 'ImpurityXRatio'
    columns of ``data``; the fourth panel reads the module-level
    ``desirability_scores``. The x axis is the 1-based row position.

    Args:
        data: DataFrame holding the three target columns.
        iteration: Iteration number used in the title and the file name.

    The figure is written twice: once with the iteration number in the name
    and once as ``optimization_history_latest.png``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle(f'Optimization History (Iteration {iteration})', fontsize=16)

    run_numbers = range(1, len(data) + 1)

    # (grid cell, source column or None for desirability, line format, title)
    panel_specs = (
        ((0, 0), 'Yield', 'bo-', 'Yield (MAX)'),
        ((0, 1), 'Impurity', 'ro-', 'Impurity (MIN)'),
        ((1, 0), 'ImpurityXRatio', 'go-', 'ImpurityXRatio (MAX)'),
        ((1, 1), None, 'mo-', 'Desirability Score'),
    )

    for cell, column, line_format, title in panel_specs:
        if column is None:
            # The desirability panel draws from the module-level scores.
            series = desirability_scores.values
            y_label = 'Desirability'
        else:
            series = data[column]
            y_label = column

        panel = axes[cell]
        panel.plot(run_numbers, series, line_format)
        panel.set_title(title)
        panel.set_xlabel('Iteration')
        panel.set_ylabel(y_label)
        panel.grid(True)
        # Iterations are whole numbers, so keep the ticks integral.
        panel.xaxis.set_major_locator(MaxNLocator(integer=True))

    plt.tight_layout()
    # Numbered snapshot first, then the rolling "latest" copy.
    for file_name in (f'visualization/optimization_history_{iteration}.png',
                      'visualization/optimization_history_latest.png'):
        plt.savefig(file_name)
    plt.close()

def plot_parameter_effects(data, iteration):
    """Save one scatter-plot figure per target showing each parameter's effect.

    For every target ('Yield', 'Impurity', 'ImpurityXRatio') a 2x4 grid is
    drawn with one panel per parameter. Points are coloured by their position
    in ``data`` (early to recent), a dashed linear trend line is overlaid, and
    the best and the most recent experiments are highlighted. The colour key is
    placed in the otherwise unused eighth panel so it never hides data.

    Args:
        data: DataFrame containing the seven parameter columns and the three
            target columns.
        iteration: Iteration number used in titles and file names.

    Each figure is saved twice under ``visualization/``: once with the
    iteration number in the name and once with the ``_latest`` suffix.
    Nothing is drawn when ``data`` has no rows.
    """
    param_names = ["T1Celsius", "t1min", "T2Celsius", "t2min",
                   "EquivalentsReagent1", "EquivalentsBASE1", "ConcentrationMolar"]
    target_names = ["Yield", "Impurity", "ImpurityXRatio"]

    if len(data) == 0:
        # No experiments yet: best/latest points and trend lines are undefined.
        return

    # Colour by 1-based row position; identical for every panel, so build once.
    cmap = plt.cm.viridis
    norm = plt.Normalize(1, len(data))
    colors = [cmap(norm(position)) for position in range(1, len(data) + 1)]

    # Manual legend entries (the scatter colours are per-point, so an automatic
    # legend would not be meaningful).
    handles = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cmap(0), markersize=10, label='Early'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cmap(0.5), markersize=10, label='Middle'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=cmap(1.0), markersize=10, label='Recent'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='red', markersize=10, label='Best'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='green', markersize=10, label='Latest')
    ]

    for target in target_names:
        fig, axes = plt.subplots(2, 4, figsize=(20, 10))
        fig.suptitle(f'Parameter Effects on {target} (Iteration {iteration})', fontsize=16)

        # Impurity is minimised; the other two targets are maximised.
        if target == "Impurity":
            best_idx = data[target].idxmin()
        else:
            best_idx = data[target].idxmax()

        panels = axes.ravel()
        for ax, param in zip(panels, param_names):
            # One vectorised call instead of one scatter artist per point.
            ax.scatter(data[param], data[target], color=colors, s=50, alpha=0.7)

            # A linear fit is only defined when the parameter actually varies
            # (ConcentrationMolar, for instance, is practically constant).
            if data[param].nunique() > 1:
                trend = np.poly1d(np.polyfit(data[param], data[target], 1))
                ax.plot(data[param], trend(data[param]), "r--", alpha=0.5)

            ax.set_xlabel(param)
            ax.set_ylabel(target)
            ax.grid(True)

            # Highlight the best point
            ax.scatter(data.loc[best_idx, param], data.loc[best_idx, target],
                       color='red', s=100, edgecolor='black', zorder=5, label='Best')

            # Highlight the latest point by position, so this also works when
            # the index is not the default 0..n-1 range.
            ax.scatter(data[param].iloc[-1], data[target].iloc[-1],
                       color='green', s=100, edgecolor='black', zorder=5, label='Latest')

        # Use the first spare panel as a legend holder; drop any others.
        spare_panels = panels[len(param_names):]
        if len(spare_panels) > 0:
            legend_ax = spare_panels[0]
            legend_ax.axis('off')
            legend_ax.legend(handles=handles, loc='center')
            for extra_ax in spare_panels[1:]:
                fig.delaxes(extra_ax)
        else:
            fig.legend(handles=handles, loc='lower right')

        plt.tight_layout()
        plt.savefig(f'visualization/parameter_effects_{target}_{iteration}.png')
        plt.savefig(f'visualization/parameter_effects_{target}_latest.png')
        plt.close(fig)

def plot_pairwise_interactions(data, iteration, top_n=4):
    """Save pairwise scatter plots of the parameters most correlated with each target.

    For every target the ``top_n`` parameters with the highest absolute
    Pearson correlation to that target are selected. Only process parameters
    are ranked; the other targets are never plotted as if they were
    parameters. Each unordered pair of selected parameters gets one panel
    (upper triangle of the grid); the remaining cells are hidden. Points are
    coloured by the target value and, when there are more than five rows,
    linearly interpolated contour lines are overlaid.

    Args:
        data: DataFrame containing the seven parameter columns and the three
            target columns.
        iteration: Iteration number used in titles and file names.
        top_n: Number of parameters to keep per target. Fewer than two
            selected parameters means no figure is produced for that target.

    Each figure is saved twice under ``visualization/``: once with the
    iteration number in the name and once with the ``_latest`` suffix.
    """
    param_names = ["T1Celsius", "t1min", "T2Celsius", "t2min",
                   "EquivalentsReagent1", "EquivalentsBASE1", "ConcentrationMolar"]
    target_names = ["Yield", "Impurity", "ImpurityXRatio"]

    # Calculate absolute correlation values
    corr_matrix = data[param_names + target_names].corr().abs()

    for target in target_names:
        # Rank parameters only. NaN correlations (e.g. constant columns) sort last.
        ranked = corr_matrix.loc[param_names, target].sort_values(ascending=False)
        top_params = ranked.index[:top_n].tolist()

        if len(top_params) < 2:  # Need at least 2 parameters for pairwise plots
            continue

        grid_size = len(top_params) - 1
        # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
        fig, axes = plt.subplots(grid_size, grid_size, figsize=(15, 15), squeeze=False)
        fig.suptitle(f'Pairwise Parameter Interactions for {target} (Iteration {iteration})', fontsize=16)

        # Same colour scale for every panel of this target.
        norm = plt.Normalize(data[target].min(), data[target].max())

        for i, param1 in enumerate(top_params[:-1]):
            for j, param2 in enumerate(top_params[1:]):
                ax = axes[i, j]
                if i > j:
                    # Self-pair or mirror of a pair already drawn: hide the cell.
                    ax.set_visible(False)
                    continue

                # Plot points colored by target value
                ax.scatter(data[param1], data[param2], c=data[target],
                           cmap=plt.cm.coolwarm, norm=norm, s=50, alpha=0.7)

                ax.set_xlabel(param1)
                ax.set_ylabel(param2)
                ax.grid(True)

                # Add contours if enough data points
                if len(data) > 5:
                    # Best effort: interpolation fails for degenerate point
                    # sets (e.g. collinear points); the scatter is still useful.
                    try:
                        xi = np.linspace(data[param1].min(), data[param1].max(), 100)
                        yi = np.linspace(data[param2].min(), data[param2].max(), 100)
                        xi, yi = np.meshgrid(xi, yi)

                        # Interpolate the target values onto the grid
                        zi = griddata((data[param1], data[param2]), data[target],
                                      (xi, yi), method='linear')

                        cs = ax.contour(xi, yi, zi, cmap='viridis', alpha=0.5)
                        fig.colorbar(cs, ax=ax, label=target)
                    except Exception as e:
                        print(f"Could not create contour for {param1} vs {param2}: {e}")

        plt.tight_layout()
        plt.savefig(f'visualization/pairwise_interactions_{target}_{iteration}.png')
        plt.savefig(f'visualization/pairwise_interactions_{target}_latest.png')
        plt.close(fig)

def plot_parallel_coordinates(data, iteration):
    """Save a parallel-coordinates plot of all columns, coloured by a combined score.

    Every column of ``data`` is min-max scaled to [0, 1]; a column whose values
    are all identical is placed at 0.5 so it cannot stretch the shared y axis.
    The combined score is the mean of scaled Yield, (1 - scaled Impurity) and
    scaled ImpurityXRatio. Lines are drawn in ascending score order, and line
    colours use the same normalisation as the colourbar.

    Args:
        data: DataFrame with at least the 'Yield', 'Impurity' and
            'ImpurityXRatio' columns; every column becomes one axis.
        iteration: Iteration number used in the title and the file name.

    The figure is saved twice under ``visualization/``: once with the
    iteration number in the name and once as
    ``parallel_coordinates_latest.png``. Nothing is drawn when ``data`` has no
    rows.
    """
    if len(data) == 0:
        # Min/max of an empty frame is undefined; there is nothing to plot.
        return

    # Normalize all columns to [0, 1] for better visualization
    normalized_data = data.copy()
    for col in data.columns:
        min_val = data[col].min()
        max_val = data[col].max()
        if max_val > min_val:
            normalized_data[col] = (data[col] - min_val) / (max_val - min_val)
        elif max_val == min_val:
            # Constant column: centre it instead of leaving raw values behind.
            normalized_data[col] = 0.5

    # For Impurity, lower is better, so we use 1 - normalized value
    combined_score = (normalized_data['Yield'] + (1 - normalized_data['Impurity']) +
                      normalized_data['ImpurityXRatio']) / 3
    normalized_data['CombinedScore'] = combined_score

    # Draw low scores first so the best experiments end up on top.
    normalized_data = normalized_data.sort_values('CombinedScore')

    # The score is used for colouring only and is not an axis of its own.
    axis_columns = [col for col in normalized_data.columns if col != 'CombinedScore']
    axis_positions = list(range(len(axis_columns)))

    fig, ax = plt.subplots(figsize=(15, 8))

    cmap = plt.cm.viridis
    color_values = normalized_data['CombinedScore'].values
    # One normalisation shared by the lines and the colourbar so they agree.
    color_norm = plt.Normalize(color_values.min(), color_values.max())

    for score, (_, row) in zip(color_values, normalized_data[axis_columns].iterrows()):
        ax.plot(axis_positions, row.values, color=cmap(color_norm(score)), alpha=0.7)

    # Set the x-axis ticks and labels
    ax.set_xticks(axis_positions)
    ax.set_xticklabels(axis_columns, rotation=45)

    # Add grid lines
    ax.grid(True, axis='y')

    # Add colorbar bound to this axes
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=color_norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label('Combined Score (higher is better)')

    ax.set_title(f'Parallel Coordinates Plot (Iteration {iteration})')
    plt.tight_layout()
    plt.savefig(f'visualization/parallel_coordinates_{iteration}.png')
    plt.savefig('visualization/parallel_coordinates_latest.png')
    plt.close(fig)

def create_3d_surface_plot(data, iteration):
    """Save a 3D response-surface figure for each target.

    For every target the two parameters with the highest absolute Pearson
    correlation to it are chosen (only parameters are ranked, never the other
    targets). The target is cubic-interpolated onto a 30x30 grid spanning the
    observed range of those two parameters and drawn as a surface with the
    observed points overlaid.

    Args:
        data: DataFrame containing the six ranked parameter columns and the
            three target columns. Fewer than five rows produces no output.
        iteration: Iteration number used in the file name.

    Failures are reported with ``print`` and never raised. A failure for one
    target does not prevent the remaining targets from being plotted, and the
    figure is closed whether or not plotting succeeded. Each figure is saved
    twice under ``visualization/``: once with the iteration number in the name
    and once with the ``_latest`` suffix.
    """
    if len(data) < 5:  # Need enough points for interpolation
        return

    param_names = ["T1Celsius", "t1min", "T2Celsius", "t2min",
                   "EquivalentsReagent1", "EquivalentsBASE1"]
    target_names = ["Yield", "Impurity", "ImpurityXRatio"]

    try:
        corr_matrix = data[param_names + target_names].corr().abs()
    except Exception as e:
        print(f"Error creating 3D surface plot: {e}")
        return

    for target in target_names:
        fig = None
        try:
            # Two most influential parameters for this target.
            ranked = corr_matrix.loc[param_names, target].sort_values(ascending=False)
            param1, param2 = ranked.index[:2].tolist()

            # Create grid for surface over the observed parameter ranges
            x = np.linspace(data[param1].min(), data[param1].max(), 30)
            y = np.linspace(data[param2].min(), data[param2].max(), 30)
            X, Y = np.meshgrid(x, y)

            # Interpolate target values (NaN outside the convex hull of the data)
            points = data[[param1, param2]].values
            values = data[target].values
            Z = griddata(points, values, (X, Y), method='cubic')

            # Create 3D plot
            fig = plt.figure(figsize=(12, 10))
            ax = fig.add_subplot(111, projection='3d')

            # Plot surface
            surf = ax.plot_surface(X, Y, Z, cmap='viridis', edgecolor='none', alpha=0.7)

            # Add actual data points
            ax.scatter(data[param1], data[param2], data[target], c='red', s=50, label='Observed')

            # Add labels
            ax.set_xlabel(param1)
            ax.set_ylabel(param2)
            ax.set_zlabel(target)
            ax.set_title(f'3D Surface for {target} vs {param1} and {param2}')

            # Add colorbar
            fig.colorbar(surf, ax=ax, shrink=0.5, aspect=5)

            fig.savefig(f'visualization/surface_plot_{target}_{iteration}.png')
            fig.savefig(f'visualization/surface_plot_{target}_latest.png')
        except Exception as e:
            # Best effort: report and carry on with the next target.
            print(f"Error creating 3D surface plot: {e}")
        finally:
            # Always release the figure, including after a failure.
            if fig is not None:
                plt.close(fig)

def update_log_file(data, desirability_scores, iteration, log_file=None):
    """Append one CSV row with the best values found so far.

    The row has the form
    ``iteration,best_yield,best_impurity,best_impurity_ratio,best_desirability``
    where the best yield and ratio are column maxima, the best impurity is the
    column minimum and the best desirability is the maximum score.

    Args:
        data: DataFrame with 'Yield', 'Impurity' and 'ImpurityXRatio' columns.
        desirability_scores: Scores as a Series, a single-column DataFrame or
            anything else whose ``max()`` yields a scalar.
        iteration: Iteration number written in the first field.
        log_file: Path of the CSV file to append to. Defaults to the
            module-level ``log_path``.
    """
    best_yield = data['Yield'].max()
    best_impurity = data['Impurity'].min()
    best_impurity_ratio = data['ImpurityXRatio'].max()

    best_desirability = desirability_scores.max()
    if isinstance(best_desirability, pd.Series):
        # DataFrame.max() returns one value per column; take the first column
        # so a multi-line Series repr is never written into the CSV row.
        best_desirability = best_desirability.iloc[0]
    best_desirability = float(best_desirability)

    destination = log_path if log_file is None else log_file
    with open(destination, 'a') as f:
        f.write(f'{iteration},{best_yield},{best_impurity},{best_impurity_ratio},{best_desirability}\n')

# Generate all visualization plots with current data.
# These run before the 'combined_score' column is added to `data` below, so
# the plots only see the columns loaded from the CSV.
plot_optimization_history(data, iteration)
plot_parameter_effects(data, iteration)
plot_pairwise_interactions(data, iteration)
plot_parallel_coordinates(data, iteration)
create_3d_surface_plot(data, iteration)
update_log_file(data, desirability_scores, iteration)

# Create 1D GP slice visualizations
# Get the best point so far as reference
if "Impurity" in data.columns:  # Ensuring we have the right columns
    # Combined score (higher is better): an unweighted sum on the raw target
    # scales, which is not the same quantity as the desirability objective.
    # The column is added to `data` in place, so it is part of the frame
    # passed to plot_gp_1d_slices() and stays there for the rest of the script.
    data['combined_score'] = data['Yield'] - data['Impurity'] + data['ImpurityXRatio']
    best_idx = data['combined_score'].idxmax()
    
    # Fixed approach to extract reference point - more robust than the original
    try:
        # Keep only search-space parameters that exist as columns, in the
        # order given by space.parameter_names.
        param_list = [p for p in space.parameter_names if p in data.columns]
        
        # Extract reference point more safely: one float per parameter, read
        # from the row with the highest combined score.
        reference_point = []
        for param in param_list:
            reference_point.append(float(data.loc[best_idx, param]))
        
        # Create GP slice visualizations with reference point
        try:
            plot_gp_1d_slices(campaign, data, iteration, reference_point)
        except Exception as e:
            print(f"Error creating GP slice plots with reference point: {e}")
            # Try without reference point as fallback
            try:
                plot_gp_1d_slices(campaign, data, iteration)
            except Exception as e2:
                print(f"Error creating GP slice plots (fallback attempt): {e2}")
    except Exception as e:
        # Reached only when building the reference point itself failed; errors
        # from plot_gp_1d_slices() are handled by the inner handlers above.
        print(f"Error creating reference point: {e}")
        # Try without reference point
        try:
            plot_gp_1d_slices(campaign, data, iteration)
        except Exception as e2:
            print(f"Error creating GP slice plots (fallback attempt): {e2}")

# ✅ Get next recommended parameters
next_params = campaign.recommend(batch_size=1)

# NOISE CONTROL: Apply noise to recommended parameters
# NOTE(review): the return value is discarded, so this only has an effect if
# add_parameter_noise() modifies `next_params` in place - confirm against the
# installed BayBE version.
parameters = list(space.parameters)
add_parameter_noise(
    data=next_params,
    parameters=parameters,
    noise_type="relative_percent",  # Options: "absolute" or "relative_percent"
    noise_level=1.0  # 1% noise for relative_percent mode
)

# NOTE(review): the zero placeholders below are registered with the campaign
# as if they were real measurements, and `desirability` is not used anywhere
# further down in this script - confirm this step is still needed.
next_params["Yield"] = 0  # Placeholder, to be filled later
next_params["Impurity"] = 0  # Placeholder
next_params["ImpurityXRatio"] = 0  # Placeholder
campaign.add_measurements(next_params)
desirability = campaign.objective.transform(campaign.measurements)

# Print summary of the optimization status
print("\nOptimization Status:")
print(f"Completed Iterations: {iteration}")
print(f"Best Yield: {data['Yield'].max():.4f}")
print(f"Best Impurity: {data['Impurity'].min():.4f}")
print(f"Best ImpurityXRatio: {data['ImpurityXRatio'].max():.4f}")

# Fix the desirability score formatting.
# This uses `desirability_scores`, computed before the placeholder row was
# added, so the placeholder zeros do not influence the reported best value.
if isinstance(desirability_scores.max(), pd.Series):
    # If it returns a Series, extract the first value
    max_desirability = float(desirability_scores.max().iloc[0])
else:
    # If it's already a scalar
    max_desirability = float(desirability_scores.max())
print(f"Best Desirability: {max_desirability:.4f}")

# Save suggested parameters.
# Only the search-space parameter columns are kept, which drops the
# placeholder target columns added above.
suggested_params_df = pd.DataFrame(next_params)  # Convert to DataFrame
suggested_params_df = suggested_params_df[list(space.parameter_names)]
suggested_params_df.to_csv("suggestion.csv", index=False)
print("\nSuggested parameters saved to suggestion.csv")