# SciTeX Torch Utilities Tutorial

This notebook demonstrates the PyTorch utilities in SciTeX, focusing on NaN-safe operations and tensor manipulation functions.

## 1. Setup and Imports

In [None]:
# Standard library
import warnings

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import torch

# SciTeX: provides the stx.torch helpers used throughout this notebook
import scitex as stx

# Silence all warnings so the tutorial output stays readable
warnings.filterwarnings('ignore')

# Pick the GPU when one is visible. This is informational only: nothing below
# moves tensors to `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Seed both RNGs so every random tensor / array below is reproducible
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

## 2. Understanding NaN-Safe Operations

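In scientific computing, missing or invalid data (NaN values) is common. Most standard PyTorch reductions propagate NaN, so a single missing value turns the whole result into NaN. PyTorch ships a few NaN-aware reductions (`torch.nanmean`, `torch.nansum`); SciTeX supplies NaN-safe versions of the others used below (e.g. `stx.torch.nanmax`, `stx.torch.nanstd`).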

### 2.1 The Problem with Standard Operations

In [None]:
# A 1-D tensor with two NaN gaps
data = torch.tensor([1.0, 2.0, float('nan'), 4.0, 5.0, float('nan'), 7.0])
print(f"Data with NaNs: {data}")
print(f"NaN locations: {torch.isnan(data)}")

# Plain reductions propagate NaN: every line in this group prints nan
print("\nStandard PyTorch operations:")
for label, reduce_fn in (
    ("torch.max", torch.max),
    ("torch.mean", torch.mean),
    ("torch.std", torch.std),
    ("torch.sum", torch.sum),
):
    print(f"{label}(data): {reduce_fn(data)}")

# NaN-aware counterparts skip the missing entries. nanmean / nansum are
# PyTorch built-ins; nanmax / nanstd come from SciTeX.
print("\nSciTeX NaN-safe operations:")
for label, reduce_fn in (
    ("stx.torch.nanmax", stx.torch.nanmax),
    ("torch.nanmean", torch.nanmean),
    ("stx.torch.nanstd", stx.torch.nanstd),
    ("torch.nansum", torch.nansum),
):
    print(f"{label}(data): {reduce_fn(data)}")

### 2.2 NaN-Safe Min/Max Operations

In [None]:
# 5x4 standard-normal tensor with three NaNs planted at fixed positions
data_2d = torch.randn(5, 4)
for row, col in [(1, 2), (3, 0), (4, 3)]:
    data_2d[row, col] = float('nan')

print("2D data with NaNs:")
print(data_2d)

# Standard reductions return NaN for any row/column containing a NaN;
# the stx.torch versions ignore the missing entries.
print("\nMax along rows (dim=1):")
print(f"Standard max: {torch.max(data_2d, dim=1)[0]}")
print(f"NaN-safe max: {stx.torch.nanmax(data_2d, dim=1)[0]}")

print("\nMin along columns (dim=0):")
print(f"Standard min: {torch.min(data_2d, dim=0)[0]}")
print(f"NaN-safe min: {stx.torch.nanmin(data_2d, dim=0)[0]}")

# NaN-ignoring argmax / argmin of each row (the header covers both results)
print("\nArgmax/argmin along rows (dim=1):")
print(f"NaN-safe argmax: {stx.torch.nanargmax(data_2d, dim=1)}")
print(f"NaN-safe argmin: {stx.torch.nanargmin(data_2d, dim=1)}")

### 2.3 NaN-Safe Statistical Operations

In [None]:
# 100 time points x 3 channels of standard-normal samples
time_series = torch.randn(100, 3)

# Simulate missing data: each entry independently becomes NaN with p = 0.1
mask = torch.rand_like(time_series) < 0.1
time_series[mask] = float('nan')

print(f"Time series shape: {time_series.shape}")
print(f"NaN count: {torch.isnan(time_series).sum().item()} / {time_series.numel()}")

# Column-wise (per-channel) statistics that skip NaN entries
print("\nPer-channel statistics (ignoring NaNs):")
print(f"Mean: {torch.nanmean(time_series, dim=0)}")
print(f"Std:  {stx.torch.nanstd(time_series, dim=0)}")
print(f"Var:  {stx.torch.nanvar(time_series, dim=0)}")

# One panel per channel: observed samples, NaN markers, NaN-safe mean +/- std
fig, axes = plt.subplots(3, 1, figsize=(10, 8), sharex=True)
t = torch.arange(len(time_series))

for i, ax in enumerate(axes):
    channel_data = time_series[:, i]
    nan_mask = torch.isnan(channel_data)
    valid_mask = ~nan_mask

    # Line through the observed samples only
    ax.plot(t[valid_mask], channel_data[valid_mask], 'b-', alpha=0.7)

    # Red crosses on y=0 wherever the channel is missing data
    if nan_mask.any():
        ax.scatter(t[nan_mask], torch.zeros(nan_mask.sum()),
                   color='red', marker='x', s=50, label='NaN')

    # NaN-safe mean as a dashed line, +/-1 std as a shaded band
    mean = torch.nanmean(channel_data)
    std = stx.torch.nanstd(channel_data)
    ax.axhline(mean, color='green', linestyle='--', alpha=0.5, label=f'Mean: {mean:.2f}')
    ax.fill_between(t, mean - std, mean + std, alpha=0.2, color='green')

    ax.set_ylabel(f'Channel {i+1}')
    ax.legend()
    ax.grid(True, alpha=0.3)

axes[-1].set_xlabel('Time')
plt.suptitle('Time Series with Missing Data (NaN-safe statistics)')
plt.tight_layout()
plt.show()

### 2.4 NaN-Safe Cumulative Operations

In [None]:
# 1-D series with NaNs at positions 2 and 4
data = torch.tensor([1.0, 2.0, float('nan'), 3.0, float('nan'), 4.0, 5.0])

print("Original data:")
print(data)

# Compute each NaN-safe cumulative result once; reused for printing and plotting
safe_cumsum = stx.torch.nancumsum(data, dim=0)
safe_cumprod = stx.torch.nancumprod(data, dim=0)

print("\nCumulative sum:")
print(f"Standard cumsum: {torch.cumsum(data, dim=0)}")
print(f"NaN-safe cumsum: {safe_cumsum}")

print("\nCumulative product:")
print(f"Standard cumprod: {torch.cumprod(data, dim=0)}")
print(f"NaN-safe cumprod: {safe_cumprod}")

# Side-by-side panels: NaN-safe curve with its value at each NaN position marked
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

x = torch.arange(len(data))
is_nan = torch.isnan(data)

panel_specs = (
    (ax1, safe_cumsum, 'b-', 'NaN-safe cumsum', 'Cumulative Sum', 'NaN-Safe Cumulative Sum'),
    (ax2, safe_cumprod, 'g-', 'NaN-safe cumprod', 'Cumulative Product', 'NaN-Safe Cumulative Product'),
)
for panel, curve, fmt, curve_label, y_label, title in panel_specs:
    panel.plot(x, curve, fmt, linewidth=2, label=curve_label)
    panel.scatter(x[is_nan], curve[is_nan],
                  color='red', s=100, zorder=5, label='NaN positions')
    panel.set_xlabel('Index')
    panel.set_ylabel(y_label)
    panel.set_title(title)
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 3. The apply_to Function

The `apply_to` function is a powerful utility for applying any function along a specific dimension of a tensor.

### 3.1 Basic Usage

In [None]:
# Random 2x3x4 tensor, reduced with apply_to along each axis in turn
data = torch.randn(2, 3, 4)
print(f"Data shape: {data.shape}")
print(f"Data:\n{data}")

print("\nApplying sum along different dimensions:")
result_dim0, result_dim1, result_dim2 = (
    stx.torch.apply_to(torch.sum, data, dim=axis) for axis in range(3)
)
for axis, result in enumerate((result_dim0, result_dim1, result_dim2)):
    print(f"Sum along dim={axis}, shape: {result.shape}")

### 3.2 Custom Functions with apply_to

In [None]:
# Define custom functions to apply
def normalize_slice(x):
    """Normalize a tensor slice to have mean=0 and std=1.

    Empty input is returned unchanged. Inputs with a single element or zero
    spread cannot be scaled, so they are only centered (all zeros for a
    constant slice).

    Args:
        x: Floating-point tensor.

    Returns:
        Tensor of the same shape as ``x``.
    """
    if x.numel() == 0:
        return x
    mean = x.mean()
    # torch.std is unbiased by default, so one element gives 0/0 = NaN. That
    # NaN would slip past the `std == 0` check and NaN-fill the result, so
    # center the slice instead.
    if x.numel() < 2:
        return x - mean
    std = x.std()
    if std == 0:
        return x - mean
    return (x - mean) / std

def compute_range(x):
    """Return the spread of ``x``: its largest element minus its smallest."""
    largest = x.max()
    smallest = x.min()
    return largest - smallest

def percentile_95(x):
    """Return the 95th percentile of ``x``; values are cast to float first because quantile needs a float dtype."""
    as_float = x.float()
    return as_float.quantile(0.95)

# 4x5x6 tensor of normal samples scaled to std 10 and shifted by 5
data = torch.randn(4, 5, 6) * 10 + 5

print("Original data shape:", data.shape)

# Normalize each slice along the last axis, then print the per-slice mean and
# std along dim=2 as a check
normalized = stx.torch.apply_to(normalize_slice, data, dim=2)
print("\nAfter normalization along dim=2:")
print(f"Mean along dim=2: {normalized.mean(dim=2)}")
print(f"Std along dim=2: {normalized.std(dim=2)}")

# Range (max - min) along the first two axes
range_dim0, range_dim1 = (
    stx.torch.apply_to(compute_range, data, dim=axis) for axis in (0, 1)
)
print(f"\nRange along dim=0 shape: {range_dim0.shape}")
print(f"Range along dim=1 shape: {range_dim1.shape}")

# 95th percentile of each slice along the last axis
p95_dim2 = stx.torch.apply_to(percentile_95, data, dim=2)
print(f"\n95th percentile along dim=2 shape: {p95_dim2.shape}")

### 3.3 Advanced Example: Sliding Window Statistics

In [None]:
# Three random-walk channels (running sum of Gaussian steps), 1000 samples each
n_samples, n_channels = 1000, 3
time_series = torch.randn(n_samples, n_channels).cumsum(dim=0)

# Define sliding window function
def sliding_window_stats(data, window_size=50, step=10):
    """Compute mean and standard deviation over sliding windows of a series.

    Args:
        data: Series to scan, expected to be a 1-D tensor (sliced with
            ``data[start:end]`` and measured with ``len``).
        window_size: Number of samples per window.
        step: Offset between consecutive window starts.

    Returns:
        Tensor of shape ``(2, n_windows)``. Row 0 holds the window means and
        row 1 the (unbiased) window standard deviations. ``n_windows`` counts
        the complete windows that fit and is 0 when ``data`` is shorter than
        one window.
    """
    # Clamp at 0: a series shorter than `window_size` would otherwise give a
    # negative count, and torch.zeros(<negative>) raises.
    n_windows = max(0, (len(data) - window_size) // step + 1)

    means = torch.zeros(n_windows)
    stds = torch.zeros(n_windows)

    for i in range(n_windows):
        start = i * step
        window = data[start:start + window_size]
        means[i] = window.mean()
        stds[i] = window.std()

    return torch.stack([means, stds])

# Window parameters shared by the statistics and the plot's x-axis
window_size, step = 50, 10

# Apply sliding window statistics to each channel.
# NOTE(review): in sections 3.1/3.2 `dim` names the axis the function runs
# along (dim=2 normalizes the last axis). Here the window function has to run
# along time, which is dim=1 of `time_series.T`. Confirm apply_to's convention,
# because this call may need dim=1 rather than dim=0.
window_stats = stx.torch.apply_to(
    lambda x: sliding_window_stats(x, window_size=window_size, step=step),
    time_series.T,  # Transpose to have channels as first dimension
    dim=0
)

print(f"Time series shape: {time_series.shape}")
print(f"Window stats shape: {window_stats.shape}  # [channels, 2 (mean/std), n_windows]")

# Visualize
fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True)
t = torch.arange(n_samples)
# x position (centre) of every window. The count comes from window_stats so it
# always matches len(means). With 1000 samples there are 96 windows (starts
# 0..950), but an exclusive arange(0, n_samples - 50, 10) yields only 95
# points, and plot()/fill_between() reject the length mismatch.
n_windows = window_stats.shape[-1]
window_t = torch.arange(n_windows) * step + window_size // 2

for i in range(n_channels):
    # Plot original time series
    axes[i].plot(t, time_series[:, i], 'b-', alpha=0.5, label='Original')

    # Plot sliding window statistics
    means = window_stats[i, 0, :]
    stds = window_stats[i, 1, :]

    axes[i].plot(window_t, means, 'r-', linewidth=2, label='Window mean')
    axes[i].fill_between(window_t, means - stds, means + stds,
                         alpha=0.3, color='red', label='±1 std')

    axes[i].set_ylabel(f'Channel {i+1}')
    axes[i].legend()
    axes[i].grid(True, alpha=0.3)

axes[-1].set_xlabel('Time')
plt.suptitle('Sliding Window Statistics with apply_to')
plt.tight_layout()
plt.show()

## 4. Practical Applications

### 4.1 Data Quality Assessment

In [None]:
# Simulate multi-channel sensor data with deliberately injected quality problems
n_sensors = 8
n_timepoints = 500

# Sensor k (0-based) is a sine of frequency 0.5 + 0.2*k (cycles per unit t)
# plus Gaussian noise with std 0.1
t = torch.linspace(0, 10, n_timepoints)
signals = torch.zeros(n_sensors, n_timepoints)
for k in range(n_sensors):
    freq = 0.5 + k * 0.2
    signals[k] = torch.sin(2 * np.pi * freq * t) + 0.1 * torch.randn(n_timepoints)

# Injected faults use 0-based row indices; the plots title sensors 1-based
# (f'Sensor {i+1}').

# Index 2 ("Sensor 3"): intermittent failures, ~15% of samples set to NaN
failure_mask = torch.rand(n_timepoints) < 0.15
signals[2, failure_mask] = float('nan')

# Index 5 ("Sensor 6"): linear drift of 0.3 per unit t
signals[5] += 0.3 * t

# Index 7 ("Sensor 8"): additional Gaussian noise with std 0.5
signals[7] += 0.5 * torch.randn(n_timepoints)

# Compute quality metrics
print("Data Quality Assessment:")
print("=" * 50)

# NaN count per sensor
nan_counts = torch.isnan(signals).sum(dim=1)
print(f"\nNaN counts per sensor: {nan_counts.tolist()}")


def per_row(fn, rows):
    """Apply ``fn`` to each row of ``rows`` and stack the results.

    Iterating rows explicitly pins the axis. For ``signals`` each row is one
    sensor's time series, so the result has one entry per sensor and lines up
    with the ``dim=1`` reductions used for ``signal_power``.
    """
    return torch.stack([fn(row) for row in rows])


# Signal-to-noise ratio (using NaN-safe operations)
signal_power = stx.torch.nanvar(signals, dim=1)
# Noise variance estimated as half the variance of first differences
# (for i.i.d. noise, Var(x[t+1] - x[t]) = 2 * sigma^2)
noise_estimate = per_row(lambda x: stx.torch.nanvar(x[1:] - x[:-1]) / 2, signals)
snr_db = 10 * torch.log10(signal_power / noise_estimate)

print(f"\nSNR (dB) per sensor: {snr_db.tolist()}")

# Data range (detect outliers)
ranges = per_row(lambda x: stx.torch.nanmax(x) - stx.torch.nanmin(x), signals)
print(f"\nData ranges per sensor: {ranges.tolist()}")

# Visualize sensor quality.
# sharey=True autoscales one common y-range over all panels. A fixed ylim of
# (-3, 3) would cut off the drifting sensor (index 5), whose 0.3 * t drift
# pushes it to roughly 4 by t = 10.
fig, axes = plt.subplots(2, 4, figsize=(15, 8), sharey=True)
axes = axes.flatten()

for i in range(n_sensors):
    ax = axes[i]

    # Plot signal (observed samples only)
    valid_mask = ~torch.isnan(signals[i])
    ax.plot(t[valid_mask], signals[i][valid_mask], 'b-', linewidth=0.5)

    # Highlight NaN positions with red dots on y=0
    if nan_counts[i] > 0:
        nan_mask = torch.isnan(signals[i])
        ax.scatter(t[nan_mask], torch.zeros(nan_mask.sum()),
                   color='red', s=20, alpha=0.5)

    # Add quality info
    quality_text = f"SNR: {snr_db[i]:.1f} dB\nNaN: {nan_counts[i]}/{n_timepoints}"
    ax.text(0.02, 0.98, quality_text, transform=ax.transAxes,
            verticalalignment='top', fontsize=8,
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    ax.set_title(f'Sensor {i+1}')
    ax.grid(True, alpha=0.3)

plt.suptitle('Multi-Sensor Data Quality Assessment')
plt.tight_layout()
plt.show()

### 4.2 Robust Feature Extraction

In [None]:
# Extract robust features from noisy multi-channel data
def extract_robust_features(data):
    """Extract per-row summary features that tolerate NaN values.

    Args:
        data: 2-D tensor in which each row is one series (e.g. one sensor).
            All statistics are taken along dim=1.

    Returns:
        dict mapping feature name to a 1-D tensor with one entry per row:
        'mean', 'std', 'max', 'min' (NaN-safe reductions along dim=1);
        'q25', 'q75' (quantiles of the non-NaN samples, NaN for a row with no
        valid samples); and 'trend' (least-squares slope against sample index,
        NaN for a row with 10 or fewer valid samples).
    """
    features = {}

    # Basic statistics (NaN-safe), one value per row
    features['mean'] = torch.nanmean(data, dim=1)
    features['std'] = stx.torch.nanstd(data, dim=1)
    features['max'] = stx.torch.nanmax(data, dim=1)[0]
    features['min'] = stx.torch.nanmin(data, dim=1)[0]

    def over_rows(fn):
        # Iterate rows explicitly so every feature lines up with the dim=1
        # reductions above (one entry per row) and can be stacked with them.
        return torch.stack([fn(row) for row in data])

    def nan_scalar():
        # Same dtype as the data so all per-row results stack cleanly
        return torch.tensor(float('nan'), dtype=data.dtype)

    def quantile_or_nan(x, q):
        x_valid = x[~torch.isnan(x)]
        if x_valid.numel() == 0:
            return nan_scalar()
        return torch.quantile(x_valid, q)

    # Percentiles over the non-NaN samples only
    features['q25'] = over_rows(lambda x: quantile_or_nan(x, 0.25))
    features['q75'] = over_rows(lambda x: quantile_or_nan(x, 0.75))

    # Trend: slope of a degree-1 fit against sample index, ignoring NaNs
    t = torch.arange(data.shape[1], dtype=torch.float32)

    def slope_or_nan(x):
        keep = ~torch.isnan(x)
        # Too few points for a meaningful fit. Report NaN (like the quantile
        # fallback) rather than 0.0, which would read as a real flat trend.
        if keep.sum() <= 10:
            return nan_scalar()
        slope = np.polyfit(t[keep].numpy(), x[keep].numpy(), 1)[0]
        # polyfit returns float64; cast so stacking keeps one dtype
        return torch.tensor(slope, dtype=data.dtype)

    features['trend'] = over_rows(slope_or_nan)

    return features

# Extract features from the sensor data
features = extract_robust_features(signals)

# Feature matrix: one row per sensor, one column per feature
feature_names = list(features.keys())
feature_matrix = torch.stack([features[name] for name in feature_names], dim=1)

print(f"Feature matrix shape: {feature_matrix.shape}  # [n_sensors, n_features]")
print(f"Features: {feature_names}")

# Visualize feature matrix
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(feature_matrix.T, aspect='auto', cmap='RdBu_r')
ax.set_yticks(range(len(feature_names)))
ax.set_yticklabels(feature_names)
ax.set_xticks(range(n_sensors))
ax.set_xticklabels([f'S{i+1}' for i in range(n_sensors)])
ax.set_xlabel('Sensors')
ax.set_title('Robust Feature Extraction (NaN-safe)')
plt.colorbar(im, ax=ax)

# Annotate each cell with its value. The label-colour threshold is half the
# largest |value|, computed once with NaN ignored. A plain
# feature_matrix.max() is NaN as soon as any feature is NaN, and it ignores
# large negative values, so it would leave every label black.
values = feature_matrix.numpy()
contrast_threshold = np.nanmax(np.abs(values)) / 2
for i in range(len(feature_names)):
    for j in range(n_sensors):
        val = values[j, i]
        if not np.isnan(val):
            ax.text(j, i, f'{val:.2f}', ha='center', va='center',
                    color='white' if abs(val) > contrast_threshold else 'black',
                    fontsize=8)

plt.tight_layout()
plt.show()

## 5. Performance Comparison

In [None]:
import timeit

import pandas as pd

# Compare performance of NaN-safe operations
sizes = [100, 1000, 10000, 100000]
nan_percentages = [0, 0.1, 0.3, 0.5]

# One call on a small tensor takes microseconds. That is below the resolution,
# and well within the jitter, of a single time.time() reading, which can even
# read 0. Instead, time a batch of calls with timeit (time.perf_counter by
# default) and keep the fastest of several repeats.
N_CALLS = 20
N_REPEATS = 5


def best_time_ms(fn):
    """Return the best-of-N_REPEATS per-call time of ``fn()`` in milliseconds."""
    batch_seconds = timeit.repeat(fn, number=N_CALLS, repeat=N_REPEATS)
    return min(batch_seconds) / N_CALLS * 1000


results = []

for size in sizes:
    for nan_pct in nan_percentages:
        # Create data with specified NaN percentage
        data = torch.randn(size)
        nan_mask = torch.rand(size) < nan_pct
        data[nan_mask] = float('nan')

        # Standard op (returns NaN if any NaN is present) vs NaN-safe op
        std_time = best_time_ms(lambda: torch.max(data))
        nan_safe_time = best_time_ms(lambda: stx.torch.nanmax(data))

        results.append({
            'size': size,
            'nan_pct': nan_pct,
            'std_time': std_time,  # ms
            'nan_safe_time': nan_safe_time,  # ms
            'overhead': (nan_safe_time / std_time - 1) * 100 if std_time > 0 else 0
        })

# Display results
df_results = pd.DataFrame(results)

print("Performance Comparison: Standard vs NaN-safe Operations")
print("=" * 60)
print("\nExecution times (milliseconds):")
pivot_table = df_results.pivot_table(
    values='nan_safe_time',
    index='size',
    columns='nan_pct'
)
print(pivot_table.round(3))

# Visualize overhead
fig, ax = plt.subplots(figsize=(10, 6))
for nan_pct in nan_percentages:
    subset = df_results[df_results['nan_pct'] == nan_pct]
    ax.plot(subset['size'], subset['overhead'],
            marker='o', label=f'{int(nan_pct*100)}% NaN')

ax.set_xscale('log')
ax.set_xlabel('Tensor Size')
ax.set_ylabel('Overhead (%)')
ax.set_title('NaN-safe Operation Overhead vs Standard Operations')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

print("\nNote: NaN-safe operations have some overhead but provide correct results")
print("with missing data, while standard operations fail (return NaN).")

## 6. Integration with Neural Networks

In [None]:
import torch.nn as nn

class NaNRobustLayer(nn.Module):
    """Linear layer that zero-fills NaN inputs before the affine map.

    ``forward(x)`` returns ``(output, nan_mask)``, where ``nan_mask`` is
    ``torch.isnan(x)`` for the original input, so callers can see which
    entries were imputed.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        nan_mask = torch.isnan(x)
        # Zero imputation; mean imputation would be an alternative
        x_clean = x.masked_fill(nan_mask, 0.0)
        return self.linear(x_clean), nan_mask

class RobustFeatureExtractor(nn.Module):
    """Turn a [batch, channels, time] tensor into NaN-safe per-channel statistics.

    Each sample's output row concatenates the per-channel mean, std, max and
    min taken over the time axis, in that order. Four statistics per channel
    give an output of shape [batch, 4 * channels].
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Unpacking enforces a 3-D input: [batch, channels, time]
        batch_size, n_channels, _ = x.shape

        rows = []
        for sample in x:  # one [channels, time] slice per batch element
            stats = (
                torch.nanmean(sample, dim=1),
                stx.torch.nanstd(sample, dim=1),
                stx.torch.nanmax(sample, dim=1)[0],
                stx.torch.nanmin(sample, dim=1)[0],
            )
            rows.append(torch.cat(stats))
        return torch.stack(rows)

# Example usage: a random [batch, channels, time] batch with ~10% NaN entries
batch_size, n_channels, seq_length = 16, 8, 100

data = torch.randn(batch_size, n_channels, seq_length)
nan_mask = torch.rand_like(data) < 0.1
data[nan_mask] = float('nan')

# Feature extraction: 4 statistics per channel, one row per batch element
extractor = RobustFeatureExtractor()
features = extractor(data)

print(f"Input shape: {data.shape}")
print(f"Feature shape: {features.shape}  # [batch, n_features]")
print(f"NaN count in input: {torch.isnan(data).sum().item()}")
print(f"NaN count in features: {torch.isnan(features).sum().item()}")

# Use in a complete model
class TimeSeriesClassifier(nn.Module):
    """Classify [batch, channels, time] series from NaN-robust summary features."""

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.feature_extractor = RobustFeatureExtractor()
        # The extractor emits mean, std, max and min for every channel
        n_features = 4 * n_channels
        self.classifier = nn.Sequential(
            NaNRobustLayer(n_features, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, n_classes)
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        # classifier[0] returns (output, nan_mask), so nn.Sequential cannot
        # chain it. Call it directly, then run the remaining layers as a slice.
        hidden, _ = self.classifier[0](features)
        return self.classifier[1:](hidden)

# Test the classifier on the NaN-containing batch. eval() turns Dropout off so
# the forward pass is deterministic. no_grad() skips autograd bookkeeping,
# since nothing is trained here.
model = TimeSeriesClassifier(n_channels=8, n_classes=4)
model.eval()
with torch.no_grad():
    output = model(data)
print(f"\nClassifier output shape: {output.shape}")

## 7. Summary and Best Practices

### Key Takeaways

1. **NaN-Safe Operations**: Essential for real-world data with missing values
2. **apply_to Function**: Flexible way to apply any function along tensor dimensions
3. **Performance**: Small overhead for NaN-safety is worth the robustness
4. **Integration**: Easy to integrate with PyTorch models and workflows

### Best Practices

1. **Data Validation**:
   ```python
   # Always check for NaN before processing
   if torch.isnan(data).any():
       # Use NaN-safe operations
       result = torch.nanmean(data)  # PyTorch built-in, as used throughout this notebook
   else:
       # Use standard operations for better performance
       result = torch.mean(data)
   ```

2. **Feature Engineering**:
   ```python
   # Create robust features that handle missing data
   features = {
       'mean': torch.nanmean(data, dim=-1),
       'std': stx.torch.nanstd(data, dim=-1),
       'valid_ratio': (~torch.isnan(data)).float().mean(dim=-1)
   }
   ```

3. **Custom Functions with apply_to**:
   ```python
   # Define reusable processing functions
   def robust_normalize(x):
       mean = torch.nanmean(x)
       std = stx.torch.nanstd(x)
       return (x - mean) / (std + 1e-8)
   
   normalized = stx.torch.apply_to(robust_normalize, data, dim=1)
   ```

4. **Model Design**:
   ```python
   # Build models that gracefully handle missing data
   class RobustModel(nn.Module):
       def forward(self, x):
           # Track data quality
           nan_ratio = torch.isnan(x).float().mean()
           if nan_ratio > 0.5:
               warnings.warn(f"High NaN ratio: {nan_ratio:.2f}")
           # Process with NaN-safe operations
           return self.process(x)
   ```

In [None]:
print("\nTorch utilities tutorial completed!")
print("\nNext steps:")
next_steps = [
    "Apply NaN-safe operations to your data pipelines",
    "Use apply_to for custom tensor transformations",
    "Build robust models that handle missing data",
    "Monitor data quality in production systems",
]
for number, step_text in enumerate(next_steps, start=1):
    print(f"{number}. {step_text}")