# Run Stretch and Squeeze Experiment

This notebook runs the Stretch and Squeeze experiment with configurable parameters.

## Parameters

Adjust the variables in the next cell to configure the experiment run. 

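**IMPORTANT:** You MUST set the path variables (`*_DIR`) correctly for your system. Note that, despite their names, `REFERENCES_DIR` and `NATURAL_RECORDINGS_DIR` point to `.pkl` files, not directories.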

In [None]:
import os

# --- Core Paths (EDIT THESE!) ---
WEIGHTS_DIR = '/path/to/your/Kreiman_Generators' # Path to DeePSiM generator weights
REFERENCES_DIR = '/path/to/your/references.pkl'  # Path to the reference code FILE (.pkl output of a MaximizeActivity reference run)
OUTPUT_DIR_BASE = './output/' # Base directory for saving results
CUSTOM_WEIGHTS_DIR = '/path/to/your/custom_weights' # Path to robust models etc.
NATURAL_RECORDINGS_DIR = '/path/to/your/natural_recordings.pkl' # Path to the natural recordings FILE (.pkl)
DATASET_DIR = '/path/to/your/dataset/miniimagenet' # Path to MiniImageNet dataset

# --- Experiment Configuration ---
EXPERIMENT_NAME = "SnS_Invariance_ResNet50_Example" # Name for this specific run
EXPERIMENT_VERSION = 0 # Version number
NUM_ITERATIONS = 10 # Number of optimization iterations (keep low for demo, e.g., 10-50)
RANDOM_SEED = 12345
RENDER_GUI = False # Set to True to show GUI display screens (requires Tkinter)

# --- Generator ---
GENERATOR_VARIANT = "fc7"

# --- Natural Images ---
TEMPLATE = "T" # Mask template ('T'=Synthetic, 'F'=Natural). "T" means no natural images.
SHUFFLE_TEMPLATE = False
BATCH_SIZE = 16

# --- Natural Stats ---
# Method to aggregate natural image stats (if template includes 'F')
# Options: 'max', 'min', 'mean', 'percentile_99', etc. or a specific threshold value as string '15.0'
NAT_STATS_AGGREGATE = 'max'

# --- Subject ---
NETWORK_NAME = 'resnet50'
# Format: "layer_idx=[unit_spec], layer_idx=[unit_spec]", e.g., "26=[], 56=[19]"
# Unit spec: []=all, [idx1 idx2]=list, [(c h w) (c h w)]=tuples, N_r[]=N random, N_radj[]=N adjacent random
RECORDING_LAYERS = "26=[], 56=[19]" # Example: Record all from layer 26, unit 19 from layer 56
ROBUST_VARIANT = '' # e.g., 'imagenet_l2_3_0.pt' for robust ResNet50, or '' for standard
# Options: 'torch_load', 'torch_load_pretrained', 'madryLab_robust_load', 'robustBench_load'
WEIGHT_LOAD_FUNCTION = 'torch_load_pretrained' if not ROBUST_VARIANT else 'madryLab_robust_load' 

# --- Scorer ---
SCORING_LAYERS = "26=[], 56=[19]" # Usually same as recording for this experiment
# Reference info identifies the specific reference code to use
REFERENCE_INFO = "G=fc7, L=56, N=[19], S=1" # Example: Use reference for layer 56, unit [19], seed 1
# Signature: Layer weights. Positive=maximize similarity (minimize distance), Negative=minimize similarity (maximize distance)
SCORING_SIGNATURE = "26=-1, 56=1" # Example (Invariance): maximize distance from the reference in layer 26, minimize it in layer 56
# Bounds: Constraints on activation/distance. N=No bound, <VAL, >VAL, <VAL%, >VAL%
BOUNDS = "26=N, 56=N" 
DISTANCE_METRIC = "euclidean"
# How to handle multiple solutions in the same pareto front ('random', 'crowding', 'onevar')
WITHIN_PARETO_ORDER = 'onevar' # GInv

# --- Optimizer ---
OPTIMIZER_TYPE = 'cmaes' # 'cmaes', 'genetic', 'hybrid'
POPULATION_SIZE = 10 # Number of candidates per iteration (keep low for demo)
SIGMA0 = 1.0 # Initial step size (sigma0, a standard deviation) for CMA-ES
NOISE_STRENGTH = 0.01 # For initializing non-CMAES optimizers near reference

# --- Derived Paths ---
OUTPUT_DIR = os.path.join(OUTPUT_DIR_BASE, EXPERIMENT_NAME, f"{EXPERIMENT_NAME}-{EXPERIMENT_VERSION}")

# --- Sanity Checks ---
# Fail fast if any essential path still holds the template placeholder, naming the offending variables.
PLACEHOLDER_MARKER = '/path/to/your/'
essential_path_vars = {
    'WEIGHTS_DIR': WEIGHTS_DIR,
    'REFERENCES_DIR': REFERENCES_DIR,
    'CUSTOM_WEIGHTS_DIR': CUSTOM_WEIGHTS_DIR,
    'NATURAL_RECORDINGS_DIR': NATURAL_RECORDINGS_DIR,
    'DATASET_DIR': DATASET_DIR,
}
essential_paths = list(essential_path_vars.values())
# str() so pathlib.Path values are accepted as well as plain strings.
unconfigured = [
    name for name, p in essential_path_vars.items()
    if p is None or PLACEHOLDER_MARKER in str(p)
]
if unconfigured:
    print("ERROR: Please edit the placeholder paths (WEIGHTS_DIR, REFERENCES_DIR, etc.) in this cell first!")
    raise ValueError(f"Essential paths not configured: {', '.join(unconfigured)}")

# Non-fatal: report configured-but-missing paths now rather than deep inside the run.
for name, p in essential_path_vars.items():
    if not os.path.exists(p):
        print(f"WARNING: {name} does not exist: {p}")

# Expected location only: the config receives OUTPUT_DIR_BASE and the versioned path is built by snslib.
# NOTE(review): confirm snslib uses this <base>/<name>/<name>-<version> layout; `experiment.dir` is authoritative.
print(f"Experiment results will be saved to: {OUTPUT_DIR}")

## Imports and Setup

In [None]:
import time
import numpy as np
import torch
import matplotlib.pyplot as plt
import warnings

# Load snslib components
from snslib.core.utils.parameters import ArgParams, ParamConfig
from snslib.experiment.utils.args import ExperimentArgParams
from snslib.experiment.stretch_and_squeeze import StretchSqueezeExperiment
from snslib.core.utils.logger import LoguruLogger, DisplayScreen

# Setup autoreload and suppress warnings
%load_ext autoreload
%autoreload 2
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

print("Setup complete.")

## Prepare Configuration

We gather the parameters defined above into the flat dictionary format expected by `StretchSqueezeExperiment.from_config`.

In [None]:
# Assemble the flat parameter mapping consumed by StretchSqueezeExperiment.from_config.
# Each section mirrors a heading of the Parameters cell; sections are merged in order,
# so a later section would win on any (unexpected) duplicate key.
experiment_params = {
    ArgParams.ExperimentName.value: EXPERIMENT_NAME,
    ArgParams.ExperimentVersion.value: EXPERIMENT_VERSION,
    # Base dir only; the versioned subdirectory is presumably created internally by snslib.
    ArgParams.OutputDirectory.value: OUTPUT_DIR_BASE,
    ArgParams.NumIterations.value: NUM_ITERATIONS,
    ArgParams.RandomSeed.value: RANDOM_SEED,
    ArgParams.Render.value: RENDER_GUI,
}

generator_params = {
    ExperimentArgParams.GenWeights.value: WEIGHTS_DIR,
    ExperimentArgParams.GenVariant.value: GENERATOR_VARIANT,
}

natural_image_params = {
    ExperimentArgParams.Template.value: TEMPLATE,
    ExperimentArgParams.Dataset.value: DATASET_DIR,
    ExperimentArgParams.Shuffle.value: SHUFFLE_TEMPLATE,
    ExperimentArgParams.BatchSize.value: BATCH_SIZE,
}

natural_stats_params = {
    ExperimentArgParams.Nat_recs.value: NATURAL_RECORDINGS_DIR,
    ExperimentArgParams.Nrec_aggregate.value: NAT_STATS_AGGREGATE,
}

subject_params = {
    ExperimentArgParams.NetworkName.value: NETWORK_NAME,
    ExperimentArgParams.RecordingLayers.value: RECORDING_LAYERS,
    ExperimentArgParams.CustomWeightsPath.value: CUSTOM_WEIGHTS_DIR,
    ExperimentArgParams.CustomWeightsVariant.value: ROBUST_VARIANT,
    ExperimentArgParams.WeightLoadFunction.value: WEIGHT_LOAD_FUNCTION,
    # Left empty here; other experiment types may fill these in.
    ExperimentArgParams.Rec_low.value: "",
    ExperimentArgParams.Rec_high.value: "",
}

scorer_params = {
    ExperimentArgParams.ScoringLayers.value: SCORING_LAYERS,
    ExperimentArgParams.Reference.value: REFERENCES_DIR,
    ExperimentArgParams.ReferenceInfo.value: REFERENCE_INFO,
    ExperimentArgParams.ScoringSignature.value: SCORING_SIGNATURE,
    ExperimentArgParams.Bounds.value: BOUNDS,
    ExperimentArgParams.Distance.value: DISTANCE_METRIC,
    # Unit and layer reductions are kept fixed at "mean" for SnS.
    ExperimentArgParams.UnitsReduction.value: "mean",
    ExperimentArgParams.LayerReduction.value: "mean",
    ExperimentArgParams.Within_pareto_order.value: WITHIN_PARETO_ORDER,
    ExperimentArgParams.Score_low.value: "",
    ExperimentArgParams.Score_high.value: "",
}

optimizer_params = {
    ExperimentArgParams.OptimType.value: OPTIMIZER_TYPE,
    ExperimentArgParams.PopulationSize.value: POPULATION_SIZE,
    ExperimentArgParams.Sigma0.value: SIGMA0,
    ExperimentArgParams.Noise_strength.value: NOISE_STRENGTH,
}

# Internal flag: this notebook performs a single run, after which the screen is closed.
internal_params = {
    ArgParams.CloseScreen.value: True,
}

config: ParamConfig = {
    **experiment_params,
    **generator_params,
    **natural_image_params,
    **natural_stats_params,
    **subject_params,
    **scorer_params,
    **optimizer_params,
    **internal_params,
}

# The experiment title is read by from_config internals, so it is appended last.
config[ArgParams.ExperimentTitle.value] = StretchSqueezeExperiment.EXPERIMENT_TITLE

print("Configuration dictionary prepared.")

## Instantiate and Run Experiment

We create the experiment object from the configuration and run it.

In [None]:
import traceback

# Unbind results from any previous execution of this cell, so a failed re-run cannot
# leave stale `message`/`experiment` objects for the results cell to display.
message = experiment = None
del message, experiment

# Handle GUI display setup if needed
main_screen = None
if RENDER_GUI:
    print("Setting up main screen for GUI display...")
    main_screen = DisplayScreen.set_main_screen()  # Keep a reference so `finally` can close it
    # Note: Logger inside from_config will handle adding screens if RENDER_GUI is True

try:
    print("Instantiating StretchSqueezeExperiment...")
    experiment = StretchSqueezeExperiment.from_config(config)

    print(f"Starting experiment: {experiment.name}...")
    # perf_counter is monotonic, so the measured duration is unaffected by wall-clock adjustments.
    start_time = time.perf_counter()

    # Run the experiment; `message` is only bound if the run returns normally.
    message = experiment.run()

    end_time = time.perf_counter()
    print("\nExperiment finished.")
    print(f"Results saved to: {experiment.dir}")
    print(f"Total runtime: {end_time - start_time:.2f} seconds")

except FileNotFoundError as e:
    print(f"\nERROR: A required file was not found: {e}")
    print("Please double-check the paths in the 'Parameters' section.")
except Exception as e:
    # Deliberately best-effort: report with a traceback instead of re-raising,
    # so later cells still execute under "Run All".
    print(f"\nAn unexpected error occurred: {e}")
    traceback.print_exc()
finally:
    # Ensure main Tkinter window is closed if it was created
    if main_screen is not None and RENDER_GUI:
        print("Closing main GUI screen.")
        main_screen.quit()
        main_screen.destroy()  # Explicitly destroy

## Basic Results Preview

Let's look at some basic output stored in the message object returned by `experiment.run()`.

In [None]:
def has_items(container) -> bool:
    """Return True if `container` is not None and non-empty.

    Uses ``len`` instead of truthiness so NumPy arrays (whose truth value is
    ambiguous for more than one element) work as well as lists and dicts.
    """
    if container is None:
        return False
    try:
        return len(container) > 0
    except TypeError:  # objects without a length fall back to truthiness
        return bool(container)


# Look the result up explicitly: `message` is only bound if `experiment.run()` returned normally.
run_message = globals().get('message')

if run_message is not None:
    print(f"Experiment run completed in: {run_message.elapsed_time:.2f} seconds")

    # Display final scores if available
    scores_history = run_message.scores_gen_history
    if has_items(scores_history):
        final_scores_gen = scores_history[-1]
        print("\nFinal Generation Scores (Synthetic):")
        print(f"  Mean: {np.mean(final_scores_gen):.3f}")
        print(f"  Std:  {np.std(final_scores_gen):.3f}")
        print(f"  Min:  {np.min(final_scores_gen):.3f}")
        print(f"  Max:  {np.max(final_scores_gen):.3f}")
    else:
        print("\nNo synthetic scores recorded in history.")

    # Display layer-wise scores if available (ParetoMessage specific)
    layer_history = getattr(run_message, 'layer_scores_gen_history', None)
    if has_items(layer_history):
        print("\nFinal Layer-wise Scores (Synthetic, Mean over population):")
        final_layer_scores = {
            layer: np.mean(hist[-1])
            for layer, hist in layer_history.items()
            if has_items(hist)
        }
        for layer, score in final_layer_scores.items():
            print(f"  Layer '{layer}': {score:.3f}")

    # Display info about Pareto Front if available
    pareto_fronts = getattr(run_message, 'Pfront_1', None)
    if has_items(pareto_fronts):
        # NOTE(review): assumes Pfront_1 is a dict in insertion order whose last value is the final front.
        final_front = list(pareto_fronts.values())[-1]
        print(f"\nFinal Global Pareto Front size: {len(final_front)}")

    print(f"\nFull results and state saved in: {experiment.dir}")
else:
    print("Experiment did not run successfully. Cannot display results.")