# Surrogate modeling (functional approximation) of the forward model
- Feed-forward neural network, also known as a multilayer perceptron (MLP)
## Develop a neural network model that acts as a lookup table and works with IMP
* Generate a ground-truth dataset to train the NN on

In [28]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
import h5py
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from scipy.optimize import minimize_scalar
from scipy.spatial.distance import cdist
from scipy.ndimage import gaussian_filter

### Monte Carlo code to generate ground-truth data points
* Points $\vec{p}_i$ and $\vec{p}_j$ are picked at random
* Spheres of radius $\sigma_i$ and $\sigma_j$ are placed around the two points
* The sphere volumes are $v_i = \frac{4}{3}\pi \sigma_i^3$ and $v_j = \frac{4}{3}\pi \sigma_j^3$
* Check the distance $d_{ij} = \lVert \vec{p}_i - \vec{p}_j \rVert$
    - if $d_{ij} < \sigma_i + \sigma_j$
        - $\mathrm{xlvol} = \frac{4}{3}\pi \left(\frac{L}{2}\right)^3$
        - $\mathrm{vol}_i = \min(v_i, \mathrm{xlvol})$
        - $\mathrm{vol}_j = \min(v_j, \mathrm{xlvol})$

## We do not implement the analytical model; instead we use a brute-force Monte Carlo code: pick random points, compute the distance $d_{ij}$ between them, and count a success whenever $d_{ij} < L$. Repeating this for a very large number of trials (up to millions) gives a converged estimate of the crosslink probability.

In [31]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from mpl_toolkits.mplot3d import Axes3D
import h5py
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import pickle

class CrosslinkProbabilityEstimator:
    '''
    Monte Carlo estimator for crosslink probability between two points.
    Generates training data, computes probabilities, and visualizes configurations.

    Each point is modelled as a centre position plus an isotropic Gaussian
    offset with standard deviation sigma. A trial counts as a crosslink when
    the two perturbed positions are closer than L. All random numbers come
    from NumPy's global generator (np.random), so np.random.seed() controls
    reproducibility.

    Parameters:
    -----------
    L : float
        Crosslinker length threshold (default: 21.0 Angstroms)

    Attributes:
    -----------
    configurations : list of dict
        One dict per generated configuration with keys
        'p1', 'p2', 'sigma1', 'sigma2', 'd', 'L', 'probability'
    results : list of list
        One row per configuration: [d, sigma1, sigma2, L, probability]
    '''

    # Largest number of trials simulated in one vectorized batch; bounds the
    # size of the temporary arrays when N is very large.
    _MAX_BATCH = 1000000

    def __init__(self, L=21.0):
        self.L = L
        self.configurations = []
        self.results = []

    def _pick_points(self, p1, p2, sigma1, sigma2):
        '''
        Run a single Monte Carlo trial.

        Draws one Gaussian-perturbed position around each centre and checks
        whether the two draws are within the crosslinker length.

        Parameters:
        -----------
        p1 : np.ndarray
            Center position of first point (shape: (3,))
        p2 : np.ndarray
            Center position of second point (shape: (3,))
        sigma1 : float
            Positional uncertainty of the first point (standard deviation of
            the per-coordinate Gaussian offset)
        sigma2 : float
            Positional uncertainty of the second point (standard deviation of
            the per-coordinate Gaussian offset)

        Returns:
        --------
        bool
            True if distance between sampled points < L, False otherwise
        '''
        q1 = p1 + np.random.normal(0, sigma1, size=p1.shape)
        q2 = p2 + np.random.normal(0, sigma2, size=p2.shape)
        d = np.linalg.norm(q1 - q2)
        return d < self.L

    def _count_successes(self, p1, p2, sigma1, sigma2, n_trials):
        '''
        Simulate n_trials independent trials at once and count the crosslinks.

        Vectorized equivalent of calling _pick_points n_trials times.
        '''
        batch_shape = (n_trials,) + p1.shape
        q1 = p1 + np.random.normal(0, sigma1, size=batch_shape)
        q2 = p2 + np.random.normal(0, sigma2, size=batch_shape)
        # Flatten each trial so the norm is taken per trial, not over the batch.
        d = np.linalg.norm((q1 - q2).reshape(n_trials, -1), axis=1)
        return int(np.count_nonzero(d < self.L))

    def estimate_probability(self, p1, p2, sigma1, sigma2, N=100000):
        '''
        Estimate crosslink probability using Monte Carlo sampling.

        Trials are simulated in vectorized batches rather than one Python
        call per trial. The estimate is statistically the same as looping
        over _pick_points, but far faster.

        Parameters:
        -----------
        p1 : np.ndarray
            Center position of first point (shape: (3,))
        p2 : np.ndarray
            Center position of second point (shape: (3,))
        sigma1 : float
            Positional uncertainty (Gaussian standard deviation) for first point
        sigma2 : float
            Positional uncertainty (Gaussian standard deviation) for second point
        N : int
            Number of Monte Carlo trials (must be positive)

        Returns:
        --------
        float
            Estimated probability of crosslinking (between 0 and 1)

        Raises:
        -------
        ValueError
            If N is not positive
        '''
        if N <= 0:
            raise ValueError("N must be a positive number of Monte Carlo trials.")

        p1 = np.asarray(p1, dtype=float)
        p2 = np.asarray(p2, dtype=float)

        success_count = 0
        remaining = N
        while remaining > 0:
            batch = min(remaining, self._MAX_BATCH)
            success_count += self._count_successes(p1, p2, sigma1, sigma2, batch)
            remaining -= batch

        return success_count / N

    def generate_configurations(self, num_configs,
                               position_range=(-50, 50),
                               sigma_range=(1, 15),
                               N_trials=100000,
                               verbose=True):
        '''
        Generate random configurations and compute crosslink probabilities.

        Any previously generated configurations and results are discarded.

        Parameters:
        -----------
        num_configs : int
            Number of configurations to generate
        position_range : tuple
            Range for random point positions (min, max), applied to each coordinate
        sigma_range : tuple
            Range for random sigma values (min, max)
        N_trials : int
            Number of Monte Carlo trials per configuration
        verbose : bool
            Whether to print progress

        Returns:
        --------
        None (stores results in self.configurations and self.results)
        '''
        self.configurations = []
        self.results = []

        # Report progress roughly every 10% (every configuration if fewer than 10).
        report_every = max(1, num_configs // 10)

        for i in range(num_configs):
            # Generate random configuration
            p1 = np.random.uniform(position_range[0], position_range[1], size=(3,))
            p2 = np.random.uniform(position_range[0], position_range[1], size=(3,))
            sigma1 = np.random.uniform(sigma_range[0], sigma_range[1])
            sigma2 = np.random.uniform(sigma_range[0], sigma_range[1])

            # Compute distance between centers
            d = np.linalg.norm(p1 - p2)

            # Estimate probability
            prob = self.estimate_probability(p1, p2, sigma1, sigma2, N=N_trials)

            # Store the full configuration (used for visualization)
            self.configurations.append({
                'p1': p1,
                'p2': p2,
                'sigma1': sigma1,
                'sigma2': sigma2,
                'd': d,
                'L': self.L,
                'probability': prob
            })

            # Store result for training data (invariant features); the last
            # column is the target.
            self.results.append([d, sigma1, sigma2, self.L, prob])

            if verbose and (i + 1) % report_every == 0:
                print(f"Generated {i + 1}/{num_configs} configurations...")

        if verbose:
            print(f"Completed! Generated {num_configs} configurations.")

    def save_training_data(self, filename='surrogate_model_data', format='npz'):
        '''
        Save generated training data to disk.

        Two arrays are written: 'X' with columns [d, sigma1, sigma2, L] and
        'y' with the single column [probability].

        Parameters:
        -----------
        filename : str
            Base filename (without extension)
        format : str
            'npz' for NumPy compressed or 'hdf5' for HDF5 format

        Raises:
        -------
        ValueError
            If no data has been generated or format is not recognised
        '''
        if not self.results:
            raise ValueError("No data to save. Run generate_configurations() first.")

        results_array = np.array(self.results)
        X_data = results_array[:, :-1]  # Features: [d, sigma1, sigma2, L]
        y_data = results_array[:, -1:]  # Target: [probability]

        if format == 'npz':
            np.savez_compressed(f'{filename}.npz', X=X_data, y=y_data)
            print(f"Saved training data to {filename}.npz")
        elif format == 'hdf5':
            with h5py.File(f'{filename}.h5', 'w') as f:
                f.create_dataset('X', data=X_data, compression='gzip')
                f.create_dataset('y', data=y_data, compression='gzip')
            print(f"Saved training data to {filename}.h5")
        else:
            raise ValueError("Format must be 'npz' or 'hdf5'")

    def prepare_training_data(self, test_size=0.2, save_scaler=True,
                              scaler_path='data_scaler.pkl'):
        '''
        Prepare and normalize training data for neural network.

        Parameters:
        -----------
        test_size : float
            Fraction of data to use for validation
        save_scaler : bool
            Whether to save the fitted scaler to disk
        scaler_path : str
            File the fitted scaler is pickled to when save_scaler is True

        Returns:
        --------
        tuple
            (X_train_scaled, X_val_scaled, y_train, y_val, scaler)

        Raises:
        -------
        ValueError
            If no data has been generated
        '''
        if not self.results:
            raise ValueError("No data to prepare. Run generate_configurations() first.")

        results_array = np.array(self.results)
        X_all = results_array[:, :-1]  # Features: [d, sigma1, sigma2, L]
        y_all = results_array[:, -1:]  # Target: [probability]

        # Split data (fixed random_state so the split is repeatable)
        X_train, X_val, y_train, y_val = train_test_split(
            X_all, y_all, test_size=test_size, random_state=42
        )

        # Fit scaler on training data only so no validation statistics leak in
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_val_scaled = scaler.transform(X_val)

        if save_scaler:
            with open(scaler_path, 'wb') as f:
                pickle.dump(scaler, f)
            print(f"Saved scaler to {scaler_path}")

        print(f"Training set: {X_train_scaled.shape[0]} samples")
        print(f"Validation set: {X_val_scaled.shape[0]} samples")

        return X_train_scaled, X_val_scaled, y_train, y_val, scaler

    @staticmethod
    def _view_half_width(p1, p2, sigma1, sigma2, padding=1.1):
        '''
        Half-width of a cube, centred on the midpoint of p1 and p2, that
        contains both spheres (radius sigma1 around p1, sigma2 around p2).

        Along each axis the farthest sphere surface lies half the centre
        separation plus the larger radius away from the midpoint; the largest
        such value over the axes, times padding, is returned.
        '''
        p1 = np.asarray(p1, dtype=float)
        p2 = np.asarray(p2, dtype=float)
        half_separation = np.abs(p1 - p2) / 2.0
        return padding * (half_separation.max() + max(sigma1, sigma2))

    def visualize_configurations(self, num_visualize=None, output_dir='output_figures'):
        '''
        Visualize generated configurations as 3D sphere plots.

        One page per configuration is written to
        <output_dir>/configurations_visualization.pdf. Each sphere is drawn
        with radius sigma around its centre.

        Parameters:
        -----------
        num_visualize : int or None
            Number of configurations to visualize (None = visualize all)
        output_dir : str
            Directory to save PDF file

        Raises:
        -------
        ValueError
            If no configurations have been generated
        '''
        if not self.configurations:
            raise ValueError("No configurations to visualize. Run generate_configurations() first.")

        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # Determine how many to visualize
        if num_visualize is None:
            num_visualize = len(self.configurations)
        else:
            num_visualize = max(0, min(num_visualize, len(self.configurations)))

        pdf_path = os.path.join(output_dir, 'configurations_visualization.pdf')

        # Unit-sphere mesh; identical for every configuration, so build it once
        # and scale/translate it per sphere.
        u = np.linspace(0, 2 * np.pi, 50)
        v = np.linspace(0, np.pi, 50)
        unit_x = np.outer(np.cos(u), np.sin(v))
        unit_y = np.outer(np.sin(u), np.sin(v))
        unit_z = np.outer(np.ones(np.size(u)), np.cos(v))

        with PdfPages(pdf_path) as pdf:
            for i, config in enumerate(self.configurations[:num_visualize]):
                p1 = config['p1']
                p2 = config['p2']
                sigma1 = config['sigma1']
                sigma2 = config['sigma2']
                d = config['d']
                prob = config['probability']

                fig = plt.figure(figsize=(12, 10))
                try:
                    ax = fig.add_subplot(111, projection='3d')

                    # Sphere 1
                    ax.plot_surface(p1[0] + sigma1 * unit_x,
                                    p1[1] + sigma1 * unit_y,
                                    p1[2] + sigma1 * unit_z,
                                    alpha=0.3, color='blue', label='Sphere 1')

                    # Sphere 2
                    ax.plot_surface(p2[0] + sigma2 * unit_x,
                                    p2[1] + sigma2 * unit_y,
                                    p2[2] + sigma2 * unit_z,
                                    alpha=0.3, color='red', label='Sphere 2')

                    # Plot centers
                    ax.scatter(*p1, color='blue', s=100, marker='o', label='Center 1')
                    ax.scatter(*p2, color='red', s=100, marker='o', label='Center 2')

                    # Draw line between centers
                    ax.plot([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]],
                           'k--', linewidth=2, label=f'd = {d:.2f}Å')

                    # Annotations, placed just above each sphere
                    ax.text(p1[0], p1[1], p1[2] + sigma1 + 5,
                           f'σ₁ = {sigma1:.2f}Å\np₁ = ({p1[0]:.1f}, {p1[1]:.1f}, {p1[2]:.1f})',
                           fontsize=9, ha='center')
                    ax.text(p2[0], p2[1], p2[2] + sigma2 + 5,
                           f'σ₂ = {sigma2:.2f}Å\np₂ = ({p2[0]:.1f}, {p2[1]:.1f}, {p2[2]:.1f})',
                           fontsize=9, ha='center')

                    # Set labels and title
                    ax.set_xlabel('X (Å)')
                    ax.set_ylabel('Y (Å)')
                    ax.set_zlabel('Z (Å)')
                    ax.set_title(f'Configuration {i + 1}\n'
                               f'Crosslinker Length L = {self.L:.1f}Å\n'
                               f'Distance d = {d:.2f}Å\n'
                               f'Crosslink Probability = {prob:.4f}',
                               fontsize=12, fontweight='bold')

                    ax.legend(loc='upper right')

                    # Same-width limits on all three axes, centred on the
                    # midpoint and sized from the midpoint-relative extent so
                    # both spheres are always inside the view.
                    half_width = self._view_half_width(p1, p2, sigma1, sigma2)
                    mid_x = (p1[0] + p2[0]) / 2
                    mid_y = (p1[1] + p2[1]) / 2
                    mid_z = (p1[2] + p2[2]) / 2

                    ax.set_xlim(mid_x - half_width, mid_x + half_width)
                    ax.set_ylim(mid_y - half_width, mid_y + half_width)
                    ax.set_zlim(mid_z - half_width, mid_z + half_width)

                    plt.tight_layout()
                    pdf.savefig(fig, bbox_inches='tight')
                finally:
                    # Always release the figure, even if plotting failed.
                    plt.close(fig)

                if (i + 1) % 10 == 0:
                    print(f"Visualized {i + 1}/{num_visualize} configurations...")

        print(f"Saved visualization to {pdf_path}")

# Example usage
if __name__ == "__main__":
    # Create estimator with a crosslinker length threshold of 21.0
    estimator = CrosslinkProbabilityEstimator(L=21.0)
    
    # Generate 100 random configurations, each scored with 100000 Monte Carlo trials
    estimator.generate_configurations(num_configs=100, N_trials=100000)
    
    # Save features X and target y to surrogate_model_data.h5
    estimator.save_training_data(filename='surrogate_model_data', format='hdf5')
    
    # Split into training/validation sets and standardize the features
    # (defaults: 20% validation split, fitted scaler pickled to data_scaler.pkl)
    X_train, X_val, y_train, y_val, scaler = estimator.prepare_training_data()
    
    # Visualize first 20 configurations (one PDF page each, under output_figures/)
    estimator.visualize_configurations(num_visualize=20)

Generated 10/100 configurations...
Generated 20/100 configurations...
Generated 30/100 configurations...
Generated 40/100 configurations...
Generated 50/100 configurations...
Generated 60/100 configurations...
Generated 70/100 configurations...
Generated 80/100 configurations...
Generated 90/100 configurations...
Generated 100/100 configurations...
Completed! Generated 100 configurations.
Saved training data to surrogate_model_data.h5
Saved scaler to data_scaler.pkl
Training set: 80 samples
Validation set: 20 samples
Visualized 10/20 configurations...
Visualized 20/20 configurations...
Saved visualization to output_figures/configurations_visualization.pdf
