In [1]:
# %pip (not !pip) installs into the environment of the running kernel.
# jax is pinned to the version recorded in this notebook's benchmark metadata.
%pip install jax==0.4.38 keras jaxtyping

[0m

In [2]:
# %pip (not !pip) installs into the environment of the running kernel.
# signax is pinned to the version recorded in this notebook's benchmark metadata.
%pip install signax==0.2.1 keras_sig pandas

[0m

In [3]:
import os

# Keras and backend configuration.
# NOTE(review): this cell runs before any keras/jax import in the notebook;
# these variables are presumably read at import time, so keep it first.
BACKEND = 'jax'
os.environ['KERAS_BACKEND'] = BACKEND

print('removing access to CUDA device')
# Hide every CUDA device and restrict JAX to its CPU platform.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "",
    "JAX_PLATFORMS": "cpu",
})

removing access to CUDA device


In [4]:
import numpy as np
import pandas as pd
import time
import json
from datetime import datetime
import platform
from typing import Callable, List, Dict
from IPython.display import display
import signax
import keras_sig
import jax
import jax.numpy as jnp

def get_cpu_info():
    """Return a human-readable CPU description for the benchmark metadata.

    On Linux the first ``model name`` entry of ``/proc/cpuinfo`` is returned.
    On other systems, or when that file cannot be read or has no such entry,
    the result is ``platform.processor()``, falling back to
    ``platform.machine()`` when the former is empty.

    Returns:
        str: the CPU description.
    """
    if platform.system() == "Linux":
        try:
            with open('/proc/cpuinfo', 'r') as f:
                for line in f:
                    if 'model name' in line:
                        # Split on the first colon only, so a model string
                        # that itself contains ':' is returned in full.
                        _, separator, value = line.partition(':')
                        if separator:
                            return value.strip()
        except (OSError, UnicodeError):
            # Best effort: an unreadable /proc/cpuinfo falls through to the
            # generic platform lookup. Other exceptions are not swallowed.
            pass
    return platform.processor() or platform.machine()

def time_function(func: Callable, number: int = 10) -> float:
    """Time a function over multiple runs and return average time in milliseconds.

    ``func`` is called once before timing starts (warm-up, so one-off costs
    such as compilation are excluded), then ``number`` more times under the
    clock.

    Args:
        func: zero-argument callable to benchmark.
        number: how many timed calls to average over; must be at least 1.

    Returns:
        Mean duration of one timed call, in milliseconds.

    Raises:
        ValueError: if ``number`` is smaller than 1.
    """
    if number < 1:
        raise ValueError(f"number must be at least 1, got {number}")

    def run_once() -> None:
        # JAX dispatches work asynchronously, so wait for the result to be
        # ready before stopping the clock. block_until_ready leaves values
        # that are not JAX arrays untouched.
        jax.block_until_ready(func())

    # Initial compilation run
    run_once()

    times = []
    for _ in range(number):
        # perf_counter is a monotonic, high-resolution clock; time.time() can
        # jump when the system clock is adjusted.
        start = time.perf_counter()
        run_once()
        times.append((time.perf_counter() - start) * 1000)  # Convert to milliseconds
    return float(np.mean(times))

def run_benchmark(batch_size: int, seq_len: int, n_features: int, depth: int) -> Dict:
    """Time the three signature implementations on one random batch of paths.

    Args:
        batch_size: number of paths in the batch.
        seq_len: number of steps per path.
        n_features: number of channels per step.
        depth: signature depth passed to every implementation.

    Returns:
        Dict holding the four configuration values plus ``signax_time``,
        ``keras_sig_time`` and ``keras_sig_jax_time`` as returned by
        ``time_function``.
    """
    # Random input paths, kept both as a NumPy array and as a JAX array.
    paths = np.random.randn(batch_size, seq_len, n_features)
    paths_jax = jnp.array(paths)

    # NOTE(review): signax receives the JAX array while both keras_sig
    # variants receive the NumPy array - confirm this asymmetry is intended.
    signax_call = lambda: signax.signature(paths_jax, depth)
    keras_sig_call = lambda: keras_sig.signature(paths, depth)
    keras_sig_jax_call = lambda: keras_sig.jax_gpu_signature(paths, depth)

    # Configuration first, then one timing per implementation
    # (same key order and same measurement order as before).
    timings = dict(
        batch_size=batch_size,
        seq_len=seq_len,
        n_features=n_features,
        depth=depth,
    )
    timings['signax_time'] = time_function(signax_call)
    timings['keras_sig_time'] = time_function(keras_sig_call)
    timings['keras_sig_jax_time'] = time_function(keras_sig_jax_call)
    return timings

def run_parameter_sweep():
    """Run benchmarks varying one parameter at a time.

    Batch size, sequence length and depth are swept in turn while the other
    parameters stay at their defaults. Each configuration is measured with
    ``run_benchmark``.

    Side effects:
        Writes ``signature_benchmarks_<device>_<timestamp>.csv`` with the
        timings and ``signature_benchmarks_<device>_metadata_<timestamp>.json``
        with system information, both in the current working directory.

    Returns:
        tuple: ``(df, metadata)``, the timings as a DataFrame and the
        metadata dict that was written to the JSON file.
    """
    # Local import keeps this cell self-contained.
    from importlib import metadata as importlib_metadata

    # Default parameters
    defaults = {
        'batch_size': 128,
        'seq_len': 100,
        'n_features': 3,
        'depth': 4,
    }

    # (run_benchmark keyword, label used in the progress message, values)
    sweeps = [
        ('batch_size', 'batch size', [32, 64, 128, 256, 512]),
        ('seq_len', 'sequence length', [50, 100, 200, 500, 1000]),
        ('depth', 'depth', [2, 3, 4, 5, 6]),
    ]

    results = []

    # Display JAX configuration
    print("JAX devices:", jax.devices())
    print("Default device:", jax.default_backend())
    is_gpu = jax.default_backend() != "cpu"

    # One sweep per parameter: override a single default at a time.
    for param, label, values in sweeps:
        print(f"\nVarying {label}...")
        for value in values:
            results.append(run_benchmark(**{**defaults, param: value}))
            print(f"Completed {param}={value}")

    # Convert to DataFrame
    df = pd.DataFrame(results)

    # Save results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    device_type = "gpu" if is_gpu else "cpu"
    csv_filename = f'signature_benchmarks_{device_type}_{timestamp}.csv'
    df.to_csv(csv_filename, index=False)

    # Record the keras_sig version that is actually installed instead of a
    # hard-coded string that goes stale when the package is upgraded.
    try:
        keras_sig_version = importlib_metadata.version('keras_sig')
    except importlib_metadata.PackageNotFoundError:
        keras_sig_version = getattr(keras_sig, '__version__', 'unknown')

    # Save metadata about the system
    metadata = {
        'cpu_info': get_cpu_info(),
        'jax_devices': str(jax.devices()),
        'jax_backend': jax.default_backend(),
        'is_gpu_enabled': is_gpu,
        'jax_version': jax.__version__,
        'signax_version': signax.__version__,
        'keras_sig_version': keras_sig_version,
        'timestamp': timestamp,
        'system': platform.system(),
        'python_version': platform.python_version(),
        'platform': platform.platform()
    }

    with open(f'signature_benchmarks_{device_type}_metadata_{timestamp}.json', 'w') as f:
        json.dump(metadata, f, indent=4)

    return df, metadata

# Run the whole sweep; the CSV and JSON files are written by run_parameter_sweep.
print("Starting signature benchmarks...")
df, metadata = run_parameter_sweep()
print("\nBenchmarks complete. Results saved to CSV.")

# Echo the recorded system metadata, one "key: value" line per entry.
print("\nSystem information:")
for key in metadata:
    value = metadata[key]
    print(f"{key}: {value}")

# Rich display of the timing table.
display(df)

Starting signature benchmarks...
JAX devices: [CpuDevice(id=0)]
Default device: cpu

Varying batch size...
Completed batch_size=32
Completed batch_size=64
Completed batch_size=128
Completed batch_size=256
Completed batch_size=512

Varying sequence length...
Completed seq_len=50
Completed seq_len=100
Completed seq_len=200
Completed seq_len=500
Completed seq_len=1000

Varying depth...
Completed depth=2
Completed depth=3
Completed depth=4
Completed depth=5
Completed depth=6

Benchmarks complete. Results saved to CSV.

System information:
cpu_info: AMD Ryzen 9 5900X 12-Core Processor
jax_devices: [CpuDevice(id=0)]
jax_backend: cpu
is_gpu_enabled: False
jax_version: 0.4.38
signax_version: 0.2.1
keras_sig_version: 1.0.2
timestamp: 20250105_141221
system: Linux
python_version: 3.10.12
platform: Linux-6.8.0-45-generic-x86_64-with-glibc2.35


Unnamed: 0,batch_size,seq_len,n_features,depth,signax_time,keras_sig_time,keras_sig_jax_time
0,32,100,3,4,0.286245,143.489313,2.032423
1,64,100,3,4,0.823975,144.472289,4.712534
2,128,100,3,4,1.587844,146.006513,6.119704
3,256,100,3,4,3.10359,146.944284,14.458728
4,512,100,3,4,6.60243,147.0824,32.920814
5,128,50,3,4,0.82891,100.53122,2.940273
6,128,100,3,4,2.035236,151.788855,7.009292
7,128,200,3,4,2.262187,43.5287,14.098263
8,128,500,3,4,7.41868,45.985579,45.637083
9,128,1000,3,4,15.78145,132.633924,75.982785


In [5]:
import keras
import numpy as np
import pandas as pd
import time
import json
from datetime import datetime
import platform
import signax

def get_cpu_info():
    """Return a human-readable CPU description for the benchmark metadata.

    On Linux the first ``model name`` entry of ``/proc/cpuinfo`` is returned.
    On other systems, or when that file cannot be read or has no such entry,
    the result is ``platform.processor()``, falling back to
    ``platform.machine()`` when the former is empty.

    Returns:
        str: the CPU description.
    """
    if platform.system() == "Linux":
        try:
            with open('/proc/cpuinfo', 'r') as f:
                for line in f:
                    if 'model name' in line:
                        # Split on the first colon only, so a model string
                        # that itself contains ':' is returned in full.
                        _, separator, value = line.partition(':')
                        if separator:
                            return value.strip()
        except (OSError, UnicodeError):
            # Best effort: an unreadable /proc/cpuinfo falls through to the
            # generic platform lookup. Other exceptions are not swallowed.
            pass
    return platform.processor() or platform.machine()

class SignaxSigLayer(keras.layers.Layer):
    """Keras layer that applies ``signax.signature`` to its input.

    Args:
        depth: forwarded to ``signax.signature`` as ``depth``.
        stream: forwarded to ``signax.signature`` as ``stream``.
        unroll: accepted but not used by this layer.
        *args, **kwargs: passed through to ``keras.layers.Layer``.
    """

    def __init__(self, depth, stream=False, unroll=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.depth, self.stream = depth, stream

    def call(self, inputs):
        """Return ``signax.signature`` of ``inputs`` with the stored settings."""
        signature_kwargs = {'depth': self.depth, 'stream': self.stream}
        return signax.signature(inputs, **signature_kwargs)

class SigNet(keras.Model):
    """Dense projection, then a signature layer, then a dense output head.

    Args:
        in_channels: not used by this model.
        out_dimension: number of units of the final dense layer.
        sig_input_size: number of units of the first dense layer, i.e. the
            feature size handed to the signature layer.
        sig_depth: passed to ``sig_layer_class`` as ``depth``.
        sig_layer_class: callable building the signature layer; invoked as
            ``sig_layer_class(depth=sig_depth)``.
    """

    def __init__(self, in_channels, out_dimension, sig_input_size, sig_depth, sig_layer_class):
        super().__init__()
        # Sub-layers are created in the order they are applied in call().
        self.dense1 = keras.layers.Dense(units=sig_input_size)
        self.signature = sig_layer_class(depth=sig_depth)
        self.linear = keras.layers.Dense(units=out_dimension)

    def call(self, inputs):
        """Apply dense1, then the signature layer, then the output layer."""
        return self.linear(self.signature(self.dense1(inputs)))

def create_data(num_sample, seq_len, n_feature, n_ahead):
    """Draw random float32 inputs and targets from a standard normal.

    Returns:
        tuple: ``(X, y)`` with ``X`` of shape
        ``(num_sample, seq_len, n_feature)`` and ``y`` of shape
        ``(num_sample, n_ahead)``.
    """
    # Inputs are drawn before targets, from NumPy's global random state.
    inputs = np.random.randn(num_sample, seq_len, n_feature)
    targets = np.random.randn(num_sample, n_ahead)
    return inputs.astype(np.float32), targets.astype(np.float32)

def measure_compilation_time(model, X, y, batch_size):
    """Compile ``model`` and time its first single-sample prediction.

    Args:
        model: object exposing Keras-style ``compile`` and ``predict``.
        X: input array; only the first sample (``X[:1]``) is used.
        y: not used by this function.
        batch_size: not used by this function.

    Returns:
        float: seconds taken by the first ``predict`` call.
    """
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss="mse",
        jit_compile=True
    )

    # Time the first prediction which triggers compilation.
    # NOTE(review): this presumably covers only the predict function; any
    # compilation of the training step would happen later - confirm.
    sample_X = X[:1]  # Take just one sample

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system clock and can jump, even backwards, while the interval is measured.
    compilation_start = time.perf_counter()
    model.predict(sample_X, verbose=0)  # First prediction triggers compilation
    compilation_time = time.perf_counter() - compilation_start

    return compilation_time

def train_model(model, X, y, batch_size, epochs=10):
    """Fit ``model`` and report the final loss and timing.

    Args:
        model: object exposing a Keras-style ``fit`` whose return value has a
            ``history`` dict with a ``'loss'`` list.
        X: training inputs, passed to ``fit`` unchanged.
        y: training targets, passed to ``fit`` unchanged.
        batch_size: passed to ``fit``.
        epochs: number of epochs; must be at least 1.

    Returns:
        dict: ``final_loss`` (last recorded loss, as float), ``training_time``
        (seconds spent in ``fit``) and ``avg_epoch_time``
        (``training_time / epochs``).

    Raises:
        ValueError: if ``epochs`` is smaller than 1.
    """
    if epochs < 1:
        # Fail early with a clear message instead of the IndexError or
        # ZeroDivisionError the result dict would otherwise raise.
        raise ValueError(f"epochs must be at least 1, got {epochs}")

    # Time the actual training with a monotonic, high-resolution clock.
    training_start = time.perf_counter()
    history = model.fit(
        X, y,
        batch_size=batch_size,
        epochs=epochs,
        verbose=0
    )
    training_time = time.perf_counter() - training_start

    return {
        'final_loss': float(history.history['loss'][-1]),
        'training_time': training_time,
        'avg_epoch_time': training_time/epochs
    }

def run_benchmarks():
    """Benchmark ``keras_sig.SigLayer`` against ``SignaxSigLayer`` inside ``SigNet``.

    Every combination of sequence length, signature input size and depth is
    run once per implementation. For each run the first-prediction time
    (``measure_compilation_time``) and the training time (``train_model``)
    are recorded.

    Side effects:
        Writes ``keras_benchmarks_<device>_<timestamp>.csv`` with the results
        and ``keras_benchmarks_<device>_metadata_<timestamp>.json`` with
        system information, both in the current working directory.

    Returns:
        tuple: ``(df, metadata)``, the results as a DataFrame and the
        metadata dict that was written to the JSON file.
    """
    # Local import keeps this cell self-contained.
    from importlib import metadata as importlib_metadata

    # Parameters to test
    seq_lens = [100, 200, 350, 500]
    sig_input_sizes = [2, 4, 6, 10]
    depths = [2, 3, 4]

    # Fixed parameters
    batch_size = 128
    n_feature = 20
    n_ahead = 10
    epochs = 10
    num_sample = batch_size * 100 - 35  # Not exactly divisible by batch size

    results = []
    sig_layers = {
        'keras_sig': keras_sig.SigLayer,
        'signax': SignaxSigLayer
    }

    total_runs = len(seq_lens) * len(sig_input_sizes) * len(depths) * len(sig_layers)
    current_run = 0

    for seq_len in seq_lens:
        for sig_input_size in sig_input_sizes:
            for depth in depths:
                for layer_name, layer_class in sig_layers.items():
                    current_run += 1
                    print(f"\nRun {current_run}/{total_runs}")
                    print(f"Parameters: seq_len={seq_len}, sig_input_size={sig_input_size}, depth={depth}, implementation={layer_name}")

                    # Create data
                    X, y = create_data(num_sample, seq_len, n_feature, n_ahead)

                    # Create model
                    model = SigNet(n_feature, n_ahead, sig_input_size, depth, layer_class)

                    # Measure compilation time
                    compilation_time = measure_compilation_time(model, X, y, batch_size)

                    # Train model and measure training time
                    training_results = train_model(model, X, y, batch_size, epochs)

                    results.append({
                        'seq_len': seq_len,
                        'sig_input_size': sig_input_size,
                        'depth': depth,
                        'implementation': layer_name,
                        'compilation_time': compilation_time,
                        'training_time': training_results['training_time'],
                        'avg_epoch_time': training_results['avg_epoch_time'],
                        'final_loss': training_results['final_loss']
                    })
                    # Clear model and free memory
                    del model
                    keras.backend.clear_session()

    # Convert results to DataFrame
    df = pd.DataFrame(results)

    # Save results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Label the output by the device JAX actually runs on. The backend being
    # "jax" says nothing about CPU vs GPU.
    if keras.backend.backend() == "jax":
        import jax  # local import: this cell does not import jax itself
        device_type = "cpu" if jax.default_backend() == "cpu" else "gpu"
    else:
        # TODO(review): accelerator detection is only implemented for the JAX
        # backend; other backends keep the previous "cpu" label.
        device_type = "cpu"

    # Save benchmark results
    csv_filename = f'keras_benchmarks_{device_type}_{timestamp}.csv'
    df.to_csv(csv_filename, index=False)

    # Record the keras_sig version that is actually installed instead of a
    # hard-coded string that goes stale when the package is upgraded.
    try:
        keras_sig_version = importlib_metadata.version('keras_sig')
    except importlib_metadata.PackageNotFoundError:
        keras_sig_version = getattr(keras_sig, '__version__', 'unknown')

    # Save metadata
    metadata = {
        'cpu_info': get_cpu_info(),
        'keras_backend': keras.backend.backend(),
        'keras_version': keras.__version__,
        'signax_version': signax.__version__,
        'keras_sig_version': keras_sig_version,
        'timestamp': timestamp,
        'system': platform.system(),
        'python_version': platform.python_version(),
        'platform': platform.platform(),
        'batch_size': batch_size,
        'n_feature': n_feature,
        'n_ahead': n_ahead,
        'epochs': epochs
    }

    with open(f'keras_benchmarks_{device_type}_metadata_{timestamp}.json', 'w') as f:
        json.dump(metadata, f, indent=4)

    return df, metadata

# Run the full Keras benchmark grid; result files are written by run_benchmarks.
print("Starting Keras signature benchmarks...")
df, metadata = run_benchmarks()
print("\nBenchmarks complete. Results saved to CSV and JSON files.")

# Echo the recorded system metadata, one "key: value" line per entry.
print("\nSystem information:")
for key in metadata:
    value = metadata[key]
    print(f"{key}: {value}")

# Display results
# Lift pandas' row and column display limits so the whole table is printed.
for _option in ('display.max_rows', 'display.max_columns'):
    pd.set_option(_option, None)
print("\nBenchmark Results:")
print(df)

Starting Keras signature benchmarks...

Run 1/96
Parameters: seq_len=100, sig_input_size=2, depth=2, implementation=keras_sig

Run 2/96
Parameters: seq_len=100, sig_input_size=2, depth=2, implementation=signax

Run 3/96
Parameters: seq_len=100, sig_input_size=2, depth=3, implementation=keras_sig

Run 4/96
Parameters: seq_len=100, sig_input_size=2, depth=3, implementation=signax

Run 5/96
Parameters: seq_len=100, sig_input_size=2, depth=4, implementation=keras_sig

Run 6/96
Parameters: seq_len=100, sig_input_size=2, depth=4, implementation=signax

Run 7/96
Parameters: seq_len=100, sig_input_size=4, depth=2, implementation=keras_sig

Run 8/96
Parameters: seq_len=100, sig_input_size=4, depth=2, implementation=signax

Run 9/96
Parameters: seq_len=100, sig_input_size=4, depth=3, implementation=keras_sig

Run 10/96
Parameters: seq_len=100, sig_input_size=4, depth=3, implementation=signax

Run 11/96
Parameters: seq_len=100, sig_input_size=4, depth=4, implementation=keras_sig

Run 12/96
Parame