# Benchmarks

Our code currently takes a long time to run when evaluating a large number of perturbation directions. We'd like to understand which parts take the longest, so we can speed them up.

### Loss landscape and model

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import List

class LossLandscape:
    """One-dimensional loss surface with a sharp and a wide Gaussian well.

    loss(x) = baseline - g_sharp(x) - g_wide(x), where each g is an
    unnormalised Gaussian of height ``amplitude`` centred on its minimum
    location, with the given width as its standard deviation.
    """

    def __init__(self,
                 minima_sharp_loc=-1.0,
                 minima_wide_loc=1.0,
                 sharp_width=0.1,
                 wide_width=0.2,
                 amplitude=1.0,
                 baseline=1.0):
        """
        Initialize a customizable loss landscape with two minima.

        Args:
            minima_sharp_loc: Centre of the sharp (narrow) well.
            minima_wide_loc: Centre of the wide well.
            sharp_width: Standard deviation of the sharp Gaussian.
            wide_width: Standard deviation of the wide Gaussian.
            amplitude: Height of each Gaussian subtracted from ``baseline``.
            baseline: Loss value far away from both wells.
        """
        self.minima_sharp_loc = minima_sharp_loc
        self.minima_wide_loc = minima_wide_loc
        self.sharp_width = sharp_width
        self.wide_width = wide_width
        self.amplitude = amplitude
        self.baseline = baseline

    def get_minima_parameters(self):
        """Return the key parameters defining the minima"""
        return {
            'minima_sharp_loc': self.minima_sharp_loc,
            'minima_wide_loc': self.minima_wide_loc,
            'sharp_width': self.sharp_width,
            'wide_width': self.wide_width
        }

    def __call__(self, x):
        """Compute the loss at point(s) x.

        Args:
            x: A tensor of points, or anything ``torch.as_tensor`` accepts
               (Python float, list, NumPy array).

        Returns:
            Tensor with the same shape as ``x`` holding the elementwise loss.
        """
        # as_tensor returns tensors unchanged (same dtype/device/autograd
        # graph), so existing callers are unaffected; plain numbers and arrays
        # now work instead of failing inside torch.exp.
        x = torch.as_tensor(x)
        gaussian_sharp = self.amplitude * torch.exp(
            -0.5 * ((x - self.minima_sharp_loc) / self.sharp_width)**2
        )
        gaussian_wide = self.amplitude * torch.exp(
            -0.5 * ((x - self.minima_wide_loc) / self.wide_width)**2
        )
        return self.baseline - gaussian_sharp - gaussian_wide

    def visualize(self, x_range=(-2, 2), num_points=500):
        """Plot the loss over ``x_range`` and mark both minima."""
        x = torch.linspace(x_range[0], x_range[1], num_points)
        loss = self(x)

        plt.figure(figsize=(8, 5))
        plt.plot(x.numpy(), loss.numpy())
        plt.xlabel('x')
        plt.ylabel('Loss')
        plt.title('Loss Landscape')

        plt.axvline(x=self.minima_sharp_loc, color='r',
                    linestyle='--', alpha=0.3, label='Sharp minimum')
        plt.axvline(x=self.minima_wide_loc, color='b',
                    linestyle='--', alpha=0.3, label='Wide minimum')

        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.show()

    def random_parameter_search(self, dim: int, num_samples: int = 10000, upper = 10, low = -10):
        """
        Count how often the product of random parameters lands near each minimum.

        Draws ``num_samples`` vectors uniformly from ``[low, upper)`` in each of
        ``dim`` coordinates (NumPy's global RNG), multiplies each vector's
        entries together, and counts products strictly within one width of
        each minimum. A product may be counted for both minima if the two
        windows overlap.

        Args:
            dim: Dimension of parameter space
            num_samples: Number of random samples to generate
            upper: Upper bound (exclusive) of the uniform sampling range
            low: Lower bound (inclusive) of the uniform sampling range

        Returns:
            Tuple of (sharp_count, wide_count, samples), where ``samples`` is
            the list of sampled products in draw order.
        """
        minima_sharp_count = 0
        minima_wide_count = 0
        samples = []

        for _ in range(num_samples):
            params = np.random.uniform(low=low, high=upper, size=dim)
            product = np.prod(params)
            samples.append(product)

            if abs(product - self.minima_sharp_loc) < self.sharp_width:
                minima_sharp_count += 1
            if abs(product - self.minima_wide_loc) < self.wide_width:
                minima_wide_count += 1

        return minima_sharp_count, minima_wide_count, samples
        
class NParameterModel(torch.nn.Module):
    """Toy model whose output is the product of N scalar parameters.

    Each initial value becomes its own 1-element float32
    ``torch.nn.Parameter`` held in a ``ParameterList`` named ``params``.
    """

    def __init__(self, initial_values: List[float]):
        super().__init__()
        # One trainable 1-element float32 tensor per initial value.
        scalar_params = []
        for value in initial_values:
            scalar_params.append(
                torch.nn.Parameter(torch.tensor([value], dtype=torch.float32))
            )
        self.params = torch.nn.ParameterList(scalar_params)

    def forward(self):
        """Return the product of all parameters as a 0-d tensor."""
        stacked = torch.stack(list(self.params))
        return stacked.prod()

    def get_parameter_values(self):
        """Return the current parameter values as Python floats."""
        values = []
        for param in self.params:
            values.append(param.item())
        return values

### Perturbation Functions

These are the likely bottlenecks that we want to benchmark.

In [2]:
# Import the existing perturbation helpers from the project's parent directory.
import os
import sys
from pathlib import Path

# Make the parent directory importable. The membership check keeps sys.path
# from growing with a duplicate entry every time this notebook cell is re-run.
parent_dir = Path.cwd().parent
if str(parent_dir) not in sys.path:
    sys.path.append(str(parent_dir))

# Project-local module, presumably located in parent_dir (confirm); only the
# random-direction generator is needed in this notebook.
from perturb_simple import (
    generate_random_perturbations,
)

# Class for implementing perturbations
class ModelPerturber:
    """Temporarily shift a model's weights and restore them afterwards.

    A detached copy of every parameter is taken at construction time;
    ``reset`` always restores that snapshot, regardless of how many
    perturbations were applied since.
    """

    def __init__(self, model):
        self.model = model
        snapshot = {}
        for param_name, tensor in model.named_parameters():
            snapshot[param_name] = tensor.detach().clone()
        self.original_state = snapshot

    def apply_perturbation(self, perturbation_dict):
        """
        Add ``perturbation_dict[name]`` in place to each matching parameter.

        perturbation_dict: {param_name: tensor_with_same_shape}. Parameters
        without an entry are left untouched; keys that are not parameter
        names of the model are ignored.
        """
        with torch.no_grad():
            for param_name, tensor in self.model.named_parameters():
                if param_name not in perturbation_dict:
                    continue
                tensor.add_(perturbation_dict[param_name])

    def reset(self):
        """Copy the construction-time snapshot back into every parameter."""
        with torch.no_grad():
            for param_name, tensor in self.model.named_parameters():
                tensor.copy_(self.original_state[param_name])

def wiggle_simple(simple_model, loss_fn, perturbation_direction, coefficients):
    """
    Evaluate the loss along one perturbation direction at several step sizes.

    For each coefficient ``c`` the weights are temporarily shifted by
    ``c * perturbation_direction[name]``, the model is run, the loss is
    recorded, and the original weights are restored. Restoration also happens
    if applying the perturbation, the forward pass or ``loss_fn`` raises.

    Args:
        simple_model: Module whose ``forward()`` takes no arguments; its
            output is passed to ``loss_fn``.
        loss_fn: Callable mapping the model output to a single-element loss
            tensor (``.item()`` is called on it).
        perturbation_direction: Dict ``{param_name: tensor}`` giving the
            direction for each parameter; names the model does not have are
            ignored.
        coefficients: Iterable of scalar multipliers applied to the direction.

    Returns:
        Dictionary containing:
        - 'losses': np.ndarray with one loss per coefficient
        - 'coefficients': the coefficients passed in (unchanged)
        - 'perturbations': the perturbation direction passed in
    """
    perturber = ModelPerturber(simple_model)
    losses = []

    for coeff in coefficients:
        perturbation = {
            name: direction * coeff
            for name, direction in perturbation_direction.items()
        }
        try:
            perturber.apply_perturbation(perturbation)
            with torch.no_grad():
                losses.append(loss_fn(simple_model()).item())
        finally:
            # Never leave the caller's model in a perturbed state, even if a
            # partial apply or the evaluation raised.
            perturber.reset()

    return {
        'losses': np.array(losses),
        'coefficients': coefficients,
        'perturbations': perturbation_direction,
    }

def wiggle_multiple_directions(
    model,
    loss_fn,
    perturbation_directions,  # List of {param_name: tensor} dicts
    coefficients,             # List of coefficients (same for all directions)
    verbose=False,            # Print progress
):
    """
    Run ``wiggle_simple`` once for every perturbation direction.

    Args:
        model: Model passed to ``wiggle_simple`` as ``simple_model``.
        loss_fn: Loss callable passed through to ``wiggle_simple``.
        perturbation_directions: Directions to evaluate, one at a time.
        coefficients: Step sizes shared by every direction.
        verbose: If True, print a progress line before each direction.

    Returns:
        List of result dicts (same format as ``wiggle_simple``), one per
        direction, in input order.
    """

    def _evaluate(position, direction):
        # `position` is already 1-based, so it is printed directly.
        if verbose:
            print(f"Evaluating direction {position}/{len(perturbation_directions)}...")
        return wiggle_simple(
            simple_model=model,
            loss_fn=loss_fn,
            perturbation_direction=direction,
            coefficients=coefficients,
        )

    return [
        _evaluate(position, direction)
        for position, direction in enumerate(perturbation_directions, start=1)
    ]

def loss_threshold_crossing(wiggle_results, loss_threshold):
    """
    Find, per direction, the last radius before the loss exceeds a threshold.

    Args:
        wiggle_results: List of result dicts containing 'losses' and
            'coefficients'. Access is in the form
            wiggle_results[list_index]['losses'][loss for specific coeff]
        loss_threshold: Loss value threshold to search for (strict ``>``).

    Returns:
        Tuple of (r_values, valid_directions), one entry per direction:
        - r_values: radius for each direction.
            * crossed after the first sample: ``|coefficient|`` of the last
              sample still at or below the threshold;
            * never crossed: ``|coefficient|`` of the last sample, a lower
              bound on the true radius;
            * already above the threshold at the first sample:
              ``|coefficients[0]|``, an upper bound on the true radius.
        - valid_directions: True only when the crossing was bracketed by a
          sample below and a sample above the threshold.
    """
    r_values = []
    valid_directions = []

    for result in wiggle_results:
        coefficients = result['coefficients']
        losses = result['losses']

        first_above = next(
            (i for i, loss in enumerate(losses) if loss > loss_threshold),
            None,
        )

        if first_above is None:
            # Never crossed: record the last evaluated radius (an underestimate).
            r_values.append(abs(coefficients[len(losses) - 1]))
            valid_directions.append(False)
        elif first_above == 0:
            # Already above the threshold at the first sample, so the crossing
            # can't be bracketed; the first radius is the best available bound
            # (previously the *largest* radius was recorded here by mistake).
            r_values.append(abs(coefficients[0]))
            valid_directions.append(False)
        else:
            r_values.append(abs(coefficients[first_above - 1]))
            valid_directions.append(True)

    return r_values, valid_directions


# Experiments

Here, you can run the experiment. Change experimental parameters at will.

In [3]:
# Initialize our loss landscape (amplitude and baseline keep their defaults of 1.0)
minima_sharp_loc=-1.0
minima_wide_loc=0.5
sharp_width=0.1
wide_width=0.2
loss_fn = LossLandscape(minima_sharp_loc=minima_sharp_loc, 
                 minima_wide_loc=minima_wide_loc,
                 sharp_width=sharp_width,
                 wide_width=wide_width)

# 1 - exp(-0.5) ≈ 0.39347 (rounded to 4 dp): the value of 1 - gaussian_sharp
# one sharp_width away from the sharp minimum (the wide-well term is
# negligible there). Presumably chosen as the edge of the sharp basin — confirm.
loss_threshold = 0.3935

# A family of models located at the wide minimum: each model has parameters
# (sqrt(a) * scale, sqrt(a) / scale), whose product is a = minima_wide_loc
# for every scale factor, so only the parameterisation differs.
scale_factors = [1.0]
model_family = [NParameterModel([np.sqrt(minima_wide_loc)*scale, np.sqrt(minima_wide_loc)/scale]) for scale in scale_factors]

# Maximum number of perturbation directions
num_perturb_directions = 50
seed = 10
# Seed torch's RNG right before generating directions; presumably
# generate_random_perturbations draws from torch's RNG (confirm).
torch.manual_seed(seed)
random_perturb_vectors = generate_random_perturbations(model_family[0], n = num_perturb_directions)

# Coefficients to sample: N evenly spaced step sizes in [0, 10] (spacing 0.01)
N = 1001
coefficients = np.linspace(0, 10, N)

# Actual Code

Here is the actual code we want to benchmark. First, we check how long `wiggle_multiple_directions` and `loss_threshold_crossing` each take. We suspect that `wiggle_multiple_directions` is the real bottleneck.

In [4]:
import time

# Initialize storage for r values (one slot per model, overwritten below)
random_r_array = [[None] for _ in range(len(model_family))]

# The minima and the width
a = minima_wide_loc
w = wide_width

# Go through each member of the model family.
# NOTE: a later cell reads `model_idx` after this loop, so keep this name.
for model_idx, model in enumerate(model_family):
    print(f"\nProcessing model {model_idx + 1}/{len(model_family)}")

    ################ RANDOM ################
    ## Numeric random r

    # perf_counter is monotonic and uses the highest-resolution clock
    # available, which is what elapsed-time benchmarks need; time.time() is
    # wall-clock time and may jump or be coarse on some platforms.
    start_wiggle = time.perf_counter()

    random_loss_results = wiggle_multiple_directions(
        model=model,
        loss_fn=loss_fn,
        perturbation_directions=random_perturb_vectors,
        coefficients=coefficients,
        #verbose = True,
    )

    end_wiggle = time.perf_counter()
    wiggle_time = end_wiggle - start_wiggle
    print(f"wiggle_multiple_directions took {wiggle_time:.4f} seconds")

    start_threshold = time.perf_counter()

    random_r_values, valid_directions = loss_threshold_crossing(random_loss_results, loss_threshold)

    end_threshold = time.perf_counter()
    threshold_time = end_threshold - start_threshold
    print(f"loss_threshold_crossing took {threshold_time:.4f} seconds")

    # Check for invalid directions
    if not all(valid_directions):
        invalid_count = sum(not v for v in valid_directions)
        print(f"Random Warning: {invalid_count}/{len(valid_directions)} directions failed threshold")

    # Store r_values
    random_r_array[model_idx] = random_r_values
    print("Done random!")

print("Finished computing radii!")


Processing model 1/1
wiggle_multiple_directions took 12.7266 seconds
loss_threshold_crossing took 0.0000 seconds
Done random!
Finished computing radii!


As expected, `wiggle_multiple_directions` takes the majority of the time: about 12.7 s for 50 directions, i.e. roughly a quarter of a second (≈0.25 s) per direction.

### Wiggle Multiple Breakdown

In [6]:
def wiggle_simple_benchmark(simple_model, loss_fn, perturbation_direction, coefficients):
    """
    Evaluate one perturbation direction like ``wiggle_simple`` and time each stage.

    For every coefficient ``c`` the weights are shifted by
    ``c * perturbation_direction[name]``, the loss is computed, and the
    weights are reset. The time spent in each of those three stages is summed
    over all coefficients and printed as a breakdown.

    Args:
        simple_model: Module whose ``forward()`` takes no arguments.
        loss_fn: Callable mapping the model output to a single-element loss.
        perturbation_direction: Dict ``{param_name: tensor}``.
        coefficients: Iterable of scalar multipliers.

    Returns:
        Dict with 'losses' (np.ndarray), 'coefficients' and 'perturbations'.
    """
    perturber = ModelPerturber(simple_model)
    losses = []

    # Cumulative seconds per stage, summed over all coefficients.
    apply_time = 0.0
    loss_time = 0.0
    reset_time = 0.0

    for coeff in coefficients:
        # Time perturbation construction + application
        start_apply = time.perf_counter()
        perturbation = {
            name: direction * coeff
            for name, direction in perturbation_direction.items()
        }
        try:
            perturber.apply_perturbation(perturbation)
            apply_time += time.perf_counter() - start_apply

            # Time forward pass + loss + .item()
            start_loss = time.perf_counter()
            with torch.no_grad():
                losses.append(loss_fn(simple_model()).item())
            loss_time += time.perf_counter() - start_loss
        finally:
            # Time reset; always runs so the model is never left perturbed.
            start_reset = time.perf_counter()
            perturber.reset()
            reset_time += time.perf_counter() - start_reset

    total_time = apply_time + loss_time + reset_time

    def _pct(part):
        # Guard against ZeroDivisionError when nothing was timed
        # (e.g. an empty coefficient list).
        return 100 * part / total_time if total_time > 0 else 0.0

    # The stage times are totals over all coefficients of this direction.
    print(f"\nTiming breakdown for direction (summed over {len(losses)} coefficients):")
    print(f"  - Apply perturbation: {apply_time:.4f}s ({_pct(apply_time):.1f}%)")
    print(f"  - Loss computation:   {loss_time:.4f}s ({_pct(loss_time):.1f}%)")
    print(f"  - Reset:              {reset_time:.4f}s ({_pct(reset_time):.1f}%)")
    print(f"Total time per direction: {total_time:.4f}s")

    return {
        'losses': np.array(losses),
        'coefficients': coefficients,
        'perturbations': perturbation_direction,
    }


In [7]:
# Time the first random direction only. `model_idx` is the loop variable
# left over from the radius-computation cell, so run that cell first.
random_loss_results = wiggle_simple_benchmark(
    model_family[model_idx],
    loss_fn,
    random_perturb_vectors[0],
    coefficients,
)


Per-coefficient breakdown for direction:
  - Apply perturbation: 0.0354s (18.8%)
  - Loss computation:   0.1344s (71.6%)
  - Reset:              0.0180s (9.6%)
Total time per direction: 0.1878s


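The majority of the time (about 72%) is spent in the loss-computation step. That step covers the model forward pass, the loss function and the `.item()` call for each of the 1001 coefficients. This work is dictated by our algorithm, so it can't simply be removed. However, with only 2 parameters, per-call overhead likely dominates over actual arithmetic. Evaluating many coefficients in one batched call, rather than in a Python loop, might therefore still speed it up. Perhaps this speed benchmarking is premature: our system is so trivial (2 parameters) that the results may not reflect the actual use case.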